From 47e6a33a4505e13ba4edf37ff1f8fbdadb279ee3 Mon Sep 17 00:00:00 2001 From: JoramWilander Date: Wed, 16 Sep 2015 15:49:12 -0400 Subject: Implement OAuth2 service provider functionality. --- api/api.go | 1 + api/api_test.go | 2 +- api/channel_test.go | 4 +- api/command.go | 8 +- api/context.go | 77 ++++-- api/oauth.go | 165 +++++++++++ api/oauth_test.go | 157 +++++++++++ api/post_test.go | 4 +- api/team_test.go | 2 +- api/user.go | 42 +-- api/user_test.go | 17 +- config/config.json | 3 +- docker/dev/config_docker.json | 3 +- docker/local/config_docker.json | 3 +- manualtesting/manual_testing.go | 4 +- model/access.go | 56 +++- model/access_test.go | 41 +++ model/authorize.go | 103 +++++++ model/authorize_test.go | 66 +++++ model/client.go | 192 ++++++++----- model/oauth.go | 151 ++++++++++ model/oauth_test.go | 95 +++++++ model/session.go | 9 +- model/utils.go | 2 + store/sql_oauth_store.go | 334 +++++++++++++++++++++++ store/sql_oauth_store_test.go | 182 ++++++++++++ store/sql_session_store.go | 25 +- store/sql_session_store_test.go | 4 +- store/sql_store.go | 27 +- store/store.go | 19 +- utils/config.go | 30 +- web/react/components/authorize.jsx | 72 +++++ web/react/components/popover_list_members.jsx | 2 +- web/react/components/register_app_modal.jsx | 249 +++++++++++++++++ web/react/components/user_settings.jsx | 10 + web/react/components/user_settings_developer.jsx | 93 +++++++ web/react/components/user_settings_modal.jsx | 11 +- web/react/pages/authorize.jsx | 21 ++ web/react/pages/channel.jsx | 6 + web/react/utils/client.jsx | 33 +++ web/sass-files/sass/partials/_signup.scss | 15 + web/templates/authorize.html | 26 ++ web/templates/channel.html | 1 + web/web.go | 204 +++++++++++++- web/web_test.go | 134 ++++++++- 45 files changed, 2506 insertions(+), 199 deletions(-) create mode 100644 api/oauth.go create mode 100644 api/oauth_test.go create mode 100644 model/access_test.go create mode 100644 model/authorize.go create mode 100644 model/authorize_test.go create mode 100644 model/oauth.go create mode 100644 model/oauth_test.go create mode 100644 store/sql_oauth_store.go create mode 100644 store/sql_oauth_store_test.go create mode 100644 web/react/components/authorize.jsx create mode 100644 web/react/components/register_app_modal.jsx create mode 100644 web/react/components/user_settings_developer.jsx create mode 100644 web/react/pages/authorize.jsx create mode 100644 web/templates/authorize.html diff --git a/api/api.go b/api/api.go index 8203b07a6..c8f97c5af 100644 --- a/api/api.go +++ b/api/api.go @@ -43,6 +43,7 @@ func InitApi() { InitFile(r) InitCommand(r) InitAdmin(r) + InitOAuth(r) templatesDir := utils.FindDir("api/templates") l4g.Debug("Parsing server templates at %v", templatesDir) diff --git a/api/api_test.go b/api/api_test.go index 0c2e57891..642db581e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -17,7 +17,7 @@ func Setup() { NewServer() StartServer() InitApi() - Client = model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port + "/api/v1") + Client = model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port) } } diff --git a/api/channel_test.go b/api/channel_test.go index d65aff66c..7e9267192 100644 --- a/api/channel_test.go +++ b/api/channel_test.go @@ -62,7 +62,7 @@ func TestCreateChannel(t *testing.T) { } } - if _, err := Client.DoPost("/channels/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/channels/create", "garbage"); err == nil { t.Fatal("should have been an error") } @@ -627,7 +627,7 @@ func TestGetChannelExtraInfo(t *testing.T) { currentEtag = cache_result.Etag } - Client2 := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port + "/api/v1") + Client2 := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port) user2 := &model.User{TeamId: team.Id, Email: model.NewId() + "tester2@test.com", Nickname: "Tester 2", Password: "pwd"} user2 = Client2.Must(Client2.CreateUser(user2, "")).Data.(*model.User) diff --git a/api/command.go b/api/command.go index 2919e93a0..be1d3229b 100644 --- a/api/command.go +++ b/api/command.go @@ -315,7 +315,7 @@ func loadTestSetupCommand(c *Context, command *model.Command) bool { numPosts, _ = strconv.Atoi(tokens[numArgs+2]) } } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) if doTeams { if err := CreateBasicUser(client); err != nil { @@ -375,7 +375,7 @@ func loadTestUsersCommand(c *Context, command *model.Command) bool { if err == false { usersr = utils.Range{10, 15} } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) userCreator := NewAutoUserCreator(client, c.Session.TeamId) userCreator.Fuzzy = doFuzz userCreator.CreateTestUsers(usersr) @@ -405,7 +405,7 @@ func loadTestChannelsCommand(c *Context, command *model.Command) bool { if err == false { channelsr = utils.Range{20, 30} } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) client.MockSession(c.Session.Id) channelCreator := NewAutoChannelCreator(client, c.Session.TeamId) channelCreator.Fuzzy = doFuzz @@ -457,7 +457,7 @@ func loadTestPostsCommand(c *Context, command *model.Command) bool { } } - client := model.NewClient(c.GetSiteURL() + "/api/v1") + client := model.NewClient(c.GetSiteURL()) client.MockSession(c.Session.Id) testPoster := NewAutoPostCreator(client, command.ChannelId) testPoster.Fuzzy = doFuzz diff --git a/api/context.go b/api/context.go index 5dcdfaf96..b1b4d2d10 100644 --- a/api/context.go +++ b/api/context.go @@ -80,9 +80,36 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.RequestId = model.NewId() c.IpAddress = GetIpAddress(r) + token := "" + isTokenFromQueryString := false + + // Attempt to parse token out of the header + authHeader := r.Header.Get(model.HEADER_AUTH) + if len(authHeader) > 6 && strings.ToUpper(authHeader[0:6]) == model.HEADER_BEARER { + // Default session token + token = authHeader[7:] + + } else if len(authHeader) > 5 && strings.ToLower(authHeader[0:5]) == model.HEADER_TOKEN { + // OAuth token + token = authHeader[6:] + } + + // Attempt to parse the token from the cookie + if len(token) == 0 { + if cookie, err := r.Cookie(model.SESSION_TOKEN); err == nil { + token = cookie.Value + } + } + + // Attempt to parse token out of the query string + if len(token) == 0 { + token = r.URL.Query().Get("access_token") + isTokenFromQueryString = true + } + protocol := "http" - // if the request came from the ELB then assume this is produciton + // If the request came from the ELB then assume this is produciton // and redirect all http requests to https if utils.Cfg.ServiceSettings.UseSSL { forwardProto := r.Header.Get(model.HEADER_FORWARDED_PROTO) @@ -105,36 +132,19 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("Content-Security-Policy", "frame-ancestors none") } else { - // All api response bodies will be JSON formatted + // All api response bodies will be JSON formatted by default w.Header().Set("Content-Type", "application/json") } - sessionId := "" - - // attempt to parse the session token from the header - if ah := r.Header.Get(model.HEADER_AUTH); ah != "" { - if len(ah) > 6 && strings.ToUpper(ah[0:6]) == "BEARER" { - sessionId = ah[7:] - } - } - - // attempt to parse the session token from the cookie - if sessionId == "" { - if cookie, err := r.Cookie(model.SESSION_TOKEN); err == nil { - sessionId = cookie.Value - } - } - - if sessionId != "" { - + if len(token) != 0 { var session *model.Session - if ts, ok := sessionCache.Get(sessionId); ok { + if ts, ok := sessionCache.Get(token); ok { session = ts.(*model.Session) } if session == nil { - if sessionResult := <-Srv.Store.Session().Get(sessionId); sessionResult.Err != nil { - c.LogError(model.NewAppError("ServeHTTP", "Invalid session", "id="+sessionId+", err="+sessionResult.Err.DetailedError)) + if sessionResult := <-Srv.Store.Session().Get(token); sessionResult.Err != nil { + c.LogError(model.NewAppError("ServeHTTP", "Invalid session", "token="+token+", err="+sessionResult.Err.DetailedError)) } else { session = sessionResult.Data.(*model.Session) } @@ -142,7 +152,10 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if session == nil || session.IsExpired() { c.RemoveSessionCookie(w) - c.Err = model.NewAppError("ServeHTTP", "Invalid or expired session, please login again.", "id="+sessionId) + c.Err = model.NewAppError("ServeHTTP", "Invalid or expired session, please login again.", "token="+token) + c.Err.StatusCode = http.StatusUnauthorized + } else if !session.IsOAuth && isTokenFromQueryString { + c.Err = model.NewAppError("ServeHTTP", "Session is not OAuth but token was provided in the query string", "token="+token) c.Err.StatusCode = http.StatusUnauthorized } else { c.Session = *session @@ -166,10 +179,10 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.SystemAdminRequired() } - if c.Err == nil && h.isUserActivity && sessionId != "" && len(c.Session.UserId) > 0 { + if c.Err == nil && h.isUserActivity && token != "" && len(c.Session.UserId) > 0 { go func() { - if err := (<-Srv.Store.User().UpdateUserAndSessionActivity(c.Session.UserId, sessionId, model.GetMillis())).Err; err != nil { - l4g.Error("Failed to update LastActivityAt for user_id=%v and session_id=%v, err=%v", c.Session.UserId, sessionId, err) + if err := (<-Srv.Store.User().UpdateUserAndSessionActivity(c.Session.UserId, c.Session.Id, model.GetMillis())).Err; err != nil { + l4g.Error("Failed to update LastActivityAt for user_id=%v and session_id=%v, err=%v", c.Session.UserId, c.Session.Id, err) } }() } @@ -197,7 +210,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (c *Context) LogAudit(extraInfo string) { - audit := &model.Audit{UserId: c.Session.UserId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.AltId} + audit := &model.Audit{UserId: c.Session.UserId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.Id} if r := <-Srv.Store.Audit().Save(audit); r.Err != nil { c.LogError(r.Err) } @@ -209,7 +222,7 @@ func (c *Context) LogAuditWithUserId(userId, extraInfo string) { extraInfo = strings.TrimSpace(extraInfo + " session_user=" + c.Session.UserId) } - audit := &model.Audit{UserId: userId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.AltId} + audit := &model.Audit{UserId: userId, IpAddress: c.IpAddress, Action: c.Path, ExtraInfo: extraInfo, SessionId: c.Session.Id} if r := <-Srv.Store.Audit().Save(audit); r.Err != nil { c.LogError(r.Err) } @@ -315,7 +328,7 @@ func (c *Context) IsTeamAdmin(userId string) bool { func (c *Context) RemoveSessionCookie(w http.ResponseWriter) { - sessionCache.Remove(c.Session.Id) + sessionCache.Remove(c.Session.Token) cookie := &http.Cookie{ Name: model.SESSION_TOKEN, @@ -471,3 +484,7 @@ func Handle404(w http.ResponseWriter, r *http.Request) { l4g.Error("%v: code=404 ip=%v", r.URL.Path, GetIpAddress(r)) RenderWebError(err, w, r) } + +func AddSessionToCache(session *model.Session) { + sessionCache.Add(session.Token, session) +} diff --git a/api/oauth.go b/api/oauth.go new file mode 100644 index 000000000..26c3c5da8 --- /dev/null +++ b/api/oauth.go @@ -0,0 +1,165 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package api + +import ( + l4g "code.google.com/p/log4go" + "fmt" + "github.com/gorilla/mux" + "github.com/mattermost/platform/model" + "github.com/mattermost/platform/utils" + "net/http" + "net/url" +) + +func InitOAuth(r *mux.Router) { + l4g.Debug("Initializing oauth api routes") + + sr := r.PathPrefix("/oauth").Subrouter() + + sr.Handle("/register", ApiUserRequired(registerOAuthApp)).Methods("POST") + sr.Handle("/allow", ApiUserRequired(allowOAuth)).Methods("GET") +} + +func registerOAuthApp(c *Context, w http.ResponseWriter, r *http.Request) { + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + c.Err = model.NewAppError("registerOAuthApp", "The system admin has turned off OAuth service providing.", "") + c.Err.StatusCode = http.StatusNotImplemented + return + } + + app := model.OAuthAppFromJson(r.Body) + + if app == nil { + c.SetInvalidParam("registerOAuthApp", "app") + return + } + + secret := model.NewId() + + app.ClientSecret = secret + app.CreatorId = c.Session.UserId + + if result := <-Srv.Store.OAuth().SaveApp(app); result.Err != nil { + c.Err = result.Err + return + } else { + app = result.Data.(*model.OAuthApp) + app.ClientSecret = secret + + c.LogAudit("client_id=" + app.Id) + + w.Write([]byte(app.ToJson())) + return + } + +} + +func allowOAuth(c *Context, w http.ResponseWriter, r *http.Request) { + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + c.Err = model.NewAppError("allowOAuth", "The system admin has turned off OAuth service providing.", "") + c.Err.StatusCode = http.StatusNotImplemented + return + } + + c.LogAudit("attempt") + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + responseData := map[string]string{} + + responseType := r.URL.Query().Get("response_type") + if len(responseType) == 0 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Bad response_type", "") + return + } + + clientId := r.URL.Query().Get("client_id") + if len(clientId) != 26 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Bad client_id", "") + return + } + + redirectUri := r.URL.Query().Get("redirect_uri") + if len(redirectUri) == 0 { + c.Err = model.NewAppError("allowOAuth", "invalid_request: Missing or bad redirect_uri", "") + return + } + + scope := r.URL.Query().Get("scope") + state := r.URL.Query().Get("state") + + var app *model.OAuthApp + if result := <-Srv.Store.OAuth().GetApp(clientId); result.Err != nil { + c.Err = model.NewAppError("allowOAuth", "server_error: Error accessing the database", "") + return + } else { + app = result.Data.(*model.OAuthApp) + } + + if !app.IsValidRedirectURL(redirectUri) { + c.LogAudit("fail - redirect_uri did not match registered callback") + c.Err = model.NewAppError("allowOAuth", "invalid_request: Supplied redirect_uri did not match registered callback_url", "") + return + } + + if responseType != model.AUTHCODE_RESPONSE_TYPE { + responseData["redirect"] = redirectUri + "?error=unsupported_response_type&state=" + state + w.Write([]byte(model.MapToJson(responseData))) + return + } + + authData := &model.AuthData{UserId: c.Session.UserId, ClientId: clientId, CreateAt: model.GetMillis(), RedirectUri: redirectUri, State: state, Scope: scope} + authData.Code = model.HashPassword(fmt.Sprintf("%v:%v:%v:%v", clientId, redirectUri, authData.CreateAt, c.Session.UserId)) + + if result := <-Srv.Store.OAuth().SaveAuthData(authData); result.Err != nil { + responseData["redirect"] = redirectUri + "?error=server_error&state=" + state + w.Write([]byte(model.MapToJson(responseData))) + return + } + + c.LogAudit("success") + + responseData["redirect"] = redirectUri + "?code=" + url.QueryEscape(authData.Code) + "&state=" + url.QueryEscape(authData.State) + + w.Write([]byte(model.MapToJson(responseData))) +} + +func RevokeAccessToken(token string) *model.AppError { + + schan := Srv.Store.Session().Remove(token) + sessionCache.Remove(token) + + var accessData *model.AccessData + if result := <-Srv.Store.OAuth().GetAccessData(token); result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error getting access token from DB before deletion", "") + } else { + accessData = result.Data.(*model.AccessData) + } + + tchan := Srv.Store.OAuth().RemoveAccessData(token) + cchan := Srv.Store.OAuth().RemoveAuthData(accessData.AuthCode) + + if result := <-tchan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting access token from DB", "") + } + + if result := <-cchan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting authorization code from DB", "") + } + + if result := <-schan; result.Err != nil { + return model.NewAppError("RevokeAccessToken", "Error deleting session from DB", "") + } + + return nil +} + +func GetAuthData(code string) *model.AuthData { + if result := <-Srv.Store.OAuth().GetAuthData(code); result.Err != nil { + l4g.Error("Couldn't find auth code for code=%s", code) + return nil + } else { + return result.Data.(*model.AuthData) + } +} diff --git a/api/oauth_test.go b/api/oauth_test.go new file mode 100644 index 000000000..18db49bc5 --- /dev/null +++ b/api/oauth_test.go @@ -0,0 +1,157 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package api + +import ( + "github.com/mattermost/platform/model" + "github.com/mattermost/platform/store" + "github.com/mattermost/platform/utils" + "net/url" + "strings" + "testing" +) + +func TestRegisterApp(t *testing.T) { + Setup() + + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} + rteam, _ := Client.CreateTeam(&team) + + user := model.User{TeamId: rteam.Data.(*model.Team).Id, Email: strings.ToLower(model.NewId()) + "corey@test.com", Password: "pwd"} + ruser := Client.Must(Client.CreateUser(&user, "")).Data.(*model.User) + store.Must(Srv.Store.User().VerifyEmail(ruser.Id)) + + app := &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("should have failed - oauth providing turned off") + } + + } else { + + Client.Logout() + + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("not logged in - should have failed") + } + + Client.Must(Client.LoginById(ruser.Id, "pwd")) + + if result, err := Client.RegisterApp(app); err != nil { + t.Fatal(err) + } else { + rapp := result.Data.(*model.OAuthApp) + if len(rapp.Id) != 26 { + t.Fatal("clientid didn't return properly") + } + if len(rapp.ClientSecret) != 26 { + t.Fatal("client secret didn't return properly") + } + } + + app = &model.OAuthApp{Name: "", Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing name - should have failed") + } + + app = &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing homepage - should have failed") + } + + app = &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{}} + if _, err := Client.RegisterApp(app); err == nil { + t.Fatal("missing callback url - should have failed") + } + } +} + +func TestAllowOAuth(t *testing.T) { + Setup() + + team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN} + rteam, _ := Client.CreateTeam(&team) + + user := model.User{TeamId: rteam.Data.(*model.Team).Id, Email: strings.ToLower(model.NewId()) + "corey@test.com", Password: "pwd"} + ruser := Client.Must(Client.CreateUser(&user, "")).Data.(*model.User) + store.Must(Srv.Store.User().VerifyEmail(ruser.Id)) + + app := &model.OAuthApp{Name: "TestApp" + model.NewId(), Homepage: "https://nowhere.com", Description: "test", CallbackUrls: []string{"https://nowhere.com"}} + + Client.Must(Client.LoginById(ruser.Id, "pwd")) + + state := "123" + + if !utils.Cfg.ServiceSettings.EnableOAuthServiceProvider { + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "12345678901234567890123456", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - oauth service providing turned off") + } + } else { + app = Client.Must(Client.RegisterApp(app)).Data.(*model.OAuthApp) + + if result, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, app.CallbackUrls[0], "all", state); err != nil { + t.Fatal(err) + } else { + redirect := result.Data.(map[string]string)["redirect"] + if len(redirect) == 0 { + t.Fatal("redirect url should be set") + } + + ru, _ := url.Parse(redirect) + if ru == nil { + t.Fatal("redirect url unparseable") + } else { + if len(ru.Query().Get("code")) == 0 { + t.Fatal("authorization code not returned") + } + if ru.Query().Get("state") != state { + t.Fatal("returned state doesn't match") + } + } + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "", "all", state); err == nil { + t.Fatal("should have failed - no redirect_url given") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "", "", state); err == nil { + t.Fatal("should have failed - no redirect_url given") + } + + if result, err := Client.AllowOAuth("junk", app.Id, app.CallbackUrls[0], "all", state); err != nil { + t.Fatal(err) + } else { + redirect := result.Data.(map[string]string)["redirect"] + if len(redirect) == 0 { + t.Fatal("redirect url should be set") + } + + ru, _ := url.Parse(redirect) + if ru == nil { + t.Fatal("redirect url unparseable") + } else { + if ru.Query().Get("error") != "unsupported_response_type" { + t.Fatal("wrong error returned") + } + if ru.Query().Get("state") != state { + t.Fatal("returned state doesn't match") + } + } + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - empty client id") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, "junk", app.CallbackUrls[0], "all", state); err == nil { + t.Fatal("should have failed - bad client id") + } + + if _, err := Client.AllowOAuth(model.AUTHCODE_RESPONSE_TYPE, app.Id, "https://somewhereelse.com", "all", state); err == nil { + t.Fatal("should have failed - redirect uri host does not match app host") + } + } +} diff --git a/api/post_test.go b/api/post_test.go index 85d92de3a..4cccfd62a 100644 --- a/api/post_test.go +++ b/api/post_test.go @@ -118,7 +118,7 @@ func TestCreatePost(t *testing.T) { t.Fatal("Should have been forbidden") } - if _, err = Client.DoPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { + if _, err = Client.DoApiPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { t.Fatal("should have been an error") } } @@ -203,7 +203,7 @@ func TestCreateValetPost(t *testing.T) { t.Fatal("Should have been forbidden") } - if _, err = Client.DoPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { + if _, err = Client.DoApiPost("/channels/"+channel3.Id+"/create", "garbage"); err == nil { t.Fatal("should have been an error") } } else { diff --git a/api/team_test.go b/api/team_test.go index 2723eff57..4f1b9e5f0 100644 --- a/api/team_test.go +++ b/api/team_test.go @@ -103,7 +103,7 @@ func TestCreateTeam(t *testing.T) { } } - if _, err := Client.DoPost("/teams/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/teams/create", "garbage"); err == nil { t.Fatal("should have been an error") } } diff --git a/api/user.go b/api/user.go index cdd9a68be..b42d156ae 100644 --- a/api/user.go +++ b/api/user.go @@ -336,7 +336,7 @@ func Login(c *Context, w http.ResponseWriter, r *http.Request, user *model.User, return } - session := &model.Session{UserId: user.Id, TeamId: user.TeamId, Roles: user.Roles, DeviceId: deviceId} + session := &model.Session{UserId: user.Id, TeamId: user.TeamId, Roles: user.Roles, DeviceId: deviceId, IsOAuth: false} maxAge := model.SESSION_TIME_WEB_IN_SECS @@ -378,13 +378,13 @@ func Login(c *Context, w http.ResponseWriter, r *http.Request, user *model.User, return } else { session = result.Data.(*model.Session) - sessionCache.Add(session.Id, session) + AddSessionToCache(session) } - w.Header().Set(model.HEADER_TOKEN, session.Id) + w.Header().Set(model.HEADER_TOKEN, session.Token) sessionCookie := &http.Cookie{ Name: model.SESSION_TOKEN, - Value: session.Id, + Value: session.Token, Path: "/", MaxAge: maxAge, HttpOnly: true, @@ -430,25 +430,27 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) { func revokeSession(c *Context, w http.ResponseWriter, r *http.Request) { props := model.MapFromJson(r.Body) - altId := props["id"] + id := props["id"] - if result := <-Srv.Store.Session().GetSessions(c.Session.UserId); result.Err != nil { + if result := <-Srv.Store.Session().Get(id); result.Err != nil { c.Err = result.Err return } else { - sessions := result.Data.([]*model.Session) + session := result.Data.(*model.Session) - for _, session := range sessions { - if session.AltId == altId { - c.LogAudit("session_id=" + session.AltId) - sessionCache.Remove(session.Id) - if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { - c.Err = result.Err - return - } else { - w.Write([]byte(model.MapToJson(props))) - return - } + c.LogAudit("session_id=" + session.Id) + + if session.IsOAuth { + RevokeAccessToken(session.Token) + } else { + sessionCache.Remove(session.Token) + + if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { + c.Err = result.Err + return + } else { + w.Write([]byte(model.MapToJson(props))) + return } } } @@ -462,8 +464,8 @@ func RevokeAllSession(c *Context, userId string) { sessions := result.Data.([]*model.Session) for _, session := range sessions { - c.LogAuditWithUserId(userId, "session_id="+session.AltId) - sessionCache.Remove(session.Id) + c.LogAuditWithUserId(userId, "session_id="+session.Id) + sessionCache.Remove(session.Token) if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { c.Err = result.Err return diff --git a/api/user_test.go b/api/user_test.go index fe5a4a27f..986365bd0 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -68,7 +68,7 @@ func TestCreateUser(t *testing.T) { } } - if _, err := Client.DoPost("/users/create", "garbage"); err == nil { + if _, err := Client.DoApiPost("/users/create", "garbage"); err == nil { t.Fatal("should have been an error") } } @@ -190,11 +190,11 @@ func TestSessions(t *testing.T) { for _, session := range sessions { if session.DeviceId == deviceId { - otherSession = session.AltId + otherSession = session.Id } - if len(session.Id) != 0 { - t.Fatal("shouldn't return sessions") + if len(session.Token) != 0 { + t.Fatal("shouldn't return session tokens") } } @@ -212,11 +212,6 @@ func TestSessions(t *testing.T) { if len(sessions2) != 1 { t.Fatal("invalid number of sessions") } - - if _, err := Client.RevokeSession(otherSession); err != nil { - t.Fatal(err) - } - } func TestGetUser(t *testing.T) { @@ -355,7 +350,7 @@ func TestUserCreateImage(t *testing.T) { Client.LoginByEmail(team.Name, user.Email, "pwd") - Client.DoGet("/users/"+user.Id+"/image", "", "") + Client.DoApiGet("/users/"+user.Id+"/image", "", "") if utils.IsS3Configured() && !utils.Cfg.ServiceSettings.UseLocalStorage { var auth aws.Auth @@ -453,7 +448,7 @@ func TestUserUploadProfileImage(t *testing.T) { t.Fatal(upErr) } - Client.DoGet("/users/"+user.Id+"/image", "", "") + Client.DoApiGet("/users/"+user.Id+"/image", "", "") if utils.IsS3Configured() && !utils.Cfg.ServiceSettings.UseLocalStorage { var auth aws.Auth diff --git a/config/config.json b/config/config.json index b97f3f310..4c4fbb255 100644 --- a/config/config.json +++ b/config/config.json @@ -23,7 +23,8 @@ "UseLocalStorage": true, "StorageDirectory": "./data/", "AllowedLoginAttempts": 10, - "DisableEmailSignUp": false + "DisableEmailSignUp": false, + "EnableOAuthServiceProvider": false }, "SSOSettings": { "gitlab": { diff --git a/docker/dev/config_docker.json b/docker/dev/config_docker.json index 794ac95ae..bc42951b8 100644 --- a/docker/dev/config_docker.json +++ b/docker/dev/config_docker.json @@ -23,7 +23,8 @@ "UseLocalStorage": true, "StorageDirectory": "/mattermost/data/", "AllowedLoginAttempts": 10, - "DisableEmailSignUp": false + "DisableEmailSignUp": false, + "EnableOAuthServiceProvider": false }, "SSOSettings": { "gitlab": { diff --git a/docker/local/config_docker.json b/docker/local/config_docker.json index 794ac95ae..bc42951b8 100644 --- a/docker/local/config_docker.json +++ b/docker/local/config_docker.json @@ -23,7 +23,8 @@ "UseLocalStorage": true, "StorageDirectory": "/mattermost/data/", "AllowedLoginAttempts": 10, - "DisableEmailSignUp": false + "DisableEmailSignUp": false, + "EnableOAuthServiceProvider": false }, "SSOSettings": { "gitlab": { diff --git a/manualtesting/manual_testing.go b/manualtesting/manual_testing.go index f7408b814..86b173c6a 100644 --- a/manualtesting/manual_testing.go +++ b/manualtesting/manual_testing.go @@ -53,7 +53,7 @@ func manualTest(c *api.Context, w http.ResponseWriter, r *http.Request) { } // Create a client for tests to use - client := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port + "/api/v1") + client := model.NewClient("http://localhost:" + utils.Cfg.ServiceSettings.Port) // Check for username parameter and create a user if present username, ok1 := params["username"] @@ -65,7 +65,7 @@ func manualTest(c *api.Context, w http.ResponseWriter, r *http.Request) { // Create team for testing team := &model.Team{ DisplayName: teamDisplayName[0], - Name: utils.RandomName(utils.Range{20, 20}, utils.LOWERCASE), + Name: utils.RandomName(utils.Range{20, 20}, utils.LOWERCASE), Email: utils.RandomEmail(utils.Range{20, 20}, utils.LOWERCASE), Type: model.TEAM_OPEN, } diff --git a/model/access.go b/model/access.go index f9e36ce07..44a0463ac 100644 --- a/model/access.go +++ b/model/access.go @@ -9,17 +9,69 @@ import ( ) const ( - ACCESS_TOKEN_GRANT_TYPE = "authorization_code" - ACCESS_TOKEN_TYPE = "bearer" + ACCESS_TOKEN_GRANT_TYPE = "authorization_code" + ACCESS_TOKEN_TYPE = "bearer" + REFRESH_TOKEN_GRANT_TYPE = "refresh_token" ) +type AccessData struct { + AuthCode string `json:"auth_code"` + Token string `json"token"` + RefreshToken string `json:"refresh_token"` + RedirectUri string `json:"redirect_uri"` +} + type AccessResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int32 `json:"expires_in"` + Scope string `json:"scope"` RefreshToken string `json:"refresh_token"` } +// IsValid validates the AccessData and returns an error if it isn't configured +// correctly. +func (ad *AccessData) IsValid() *AppError { + + if len(ad.AuthCode) == 0 || len(ad.AuthCode) > 128 { + return NewAppError("AccessData.IsValid", "Invalid auth code", "") + } + + if len(ad.Token) != 26 { + return NewAppError("AccessData.IsValid", "Invalid access token", "") + } + + if len(ad.RefreshToken) > 26 { + return NewAppError("AccessData.IsValid", "Invalid refresh token", "") + } + + if len(ad.RedirectUri) > 256 { + return NewAppError("AccessData.IsValid", "Invalid redirect uri", "") + } + + return nil +} + +func (ad *AccessData) ToJson() string { + b, err := json.Marshal(ad) + if err != nil { + return "" + } else { + return string(b) + } +} + +func AccessDataFromJson(data io.Reader) *AccessData { + decoder := json.NewDecoder(data) + var ad AccessData + err := decoder.Decode(&ad) + if err == nil { + return &ad + } else { + return nil + } +} + func (ar *AccessResponse) ToJson() string { b, err := json.Marshal(ar) if err != nil { diff --git a/model/access_test.go b/model/access_test.go new file mode 100644 index 000000000..e385c0586 --- /dev/null +++ b/model/access_test.go @@ -0,0 +1,41 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "strings" + "testing" +) + +func TestAccessJson(t *testing.T) { + a1 := AccessData{} + a1.AuthCode = NewId() + a1.Token = NewId() + a1.RefreshToken = NewId() + + json := a1.ToJson() + ra1 := AccessDataFromJson(strings.NewReader(json)) + + if a1.Token != ra1.Token { + t.Fatal("tokens didn't match") + } +} + +func TestAccessIsValid(t *testing.T) { + ad := AccessData{} + + if err := ad.IsValid(); err == nil { + t.Fatal("should have failed") + } + + ad.AuthCode = NewId() + if err := ad.IsValid(); err == nil { + t.Fatal("should have failed") + } + + ad.Token = NewId() + if err := ad.IsValid(); err != nil { + t.Fatal(err) + } +} diff --git a/model/authorize.go b/model/authorize.go new file mode 100644 index 000000000..6eaac97f1 --- /dev/null +++ b/model/authorize.go @@ -0,0 +1,103 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "encoding/json" + "io" +) + +const ( + AUTHCODE_EXPIRE_TIME = 60 * 10 // 10 minutes + AUTHCODE_RESPONSE_TYPE = "code" +) + +type AuthData struct { + ClientId string `json:"client_id"` + UserId string `json:"user_id"` + Code string `json:"code"` + ExpiresIn int32 `json:"expires_in"` + CreateAt int64 `json:"create_at"` + RedirectUri string `json:"redirect_uri"` + State string `json:"state"` + Scope string `json:"scope"` +} + +// IsValid validates the AuthData and returns an error if it isn't configured +// correctly. +func (ad *AuthData) IsValid() *AppError { + + if len(ad.ClientId) != 26 { + return NewAppError("AuthData.IsValid", "Invalid client id", "") + } + + if len(ad.UserId) != 26 { + return NewAppError("AuthData.IsValid", "Invalid user id", "") + } + + if len(ad.Code) == 0 || len(ad.Code) > 128 { + return NewAppError("AuthData.IsValid", "Invalid authorization code", "client_id="+ad.ClientId) + } + + if ad.ExpiresIn == 0 { + return NewAppError("AuthData.IsValid", "Expires in must be set", "") + } + + if ad.CreateAt <= 0 { + return NewAppError("AuthData.IsValid", "Create at must be a valid time", "client_id="+ad.ClientId) + } + + if len(ad.RedirectUri) > 256 { + return NewAppError("AuthData.IsValid", "Invalid redirect uri", "client_id="+ad.ClientId) + } + + if len(ad.State) > 128 { + return NewAppError("AuthData.IsValid", "Invalid state", "client_id="+ad.ClientId) + } + + if len(ad.Scope) > 128 { + return NewAppError("AuthData.IsValid", "Invalid scope", "client_id="+ad.ClientId) + } + + return nil +} + +func (ad *AuthData) PreSave() { + if ad.ExpiresIn == 0 { + ad.ExpiresIn = AUTHCODE_EXPIRE_TIME + } + + if ad.CreateAt == 0 { + ad.CreateAt = GetMillis() + } +} + +func (ad *AuthData) ToJson() string { + b, err := json.Marshal(ad) + if err != nil { + return "" + } else { + return string(b) + } +} + +func AuthDataFromJson(data io.Reader) *AuthData { + decoder := json.NewDecoder(data) + var ad AuthData + err := decoder.Decode(&ad) + if err == nil { + return &ad + } else { + return nil + } +} + +func (ad *AuthData) IsExpired() bool { + + if GetMillis() > ad.CreateAt+int64(ad.ExpiresIn*1000) { + return true + } + + return false +} diff --git a/model/authorize_test.go b/model/authorize_test.go new file mode 100644 index 000000000..14524ad84 --- /dev/null +++ b/model/authorize_test.go @@ -0,0 +1,66 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "strings" + "testing" +) + +func TestAuthJson(t *testing.T) { + a1 := AuthData{} + a1.ClientId = NewId() + a1.UserId = NewId() + a1.Code = NewId() + + json := a1.ToJson() + ra1 := AuthDataFromJson(strings.NewReader(json)) + + if a1.Code != ra1.Code { + t.Fatal("codes didn't match") + } +} + +func TestAuthPreSave(t *testing.T) { + a1 := AuthData{} + a1.ClientId = NewId() + a1.UserId = NewId() + a1.Code = NewId() + a1.PreSave() + a1.IsExpired() +} + +func TestAuthIsValid(t *testing.T) { + + ad := AuthData{} + + if err := ad.IsValid(); err == nil { + t.Fatal() + } + + ad.ClientId = NewId() + if err := ad.IsValid(); err == nil { + t.Fatal() + } + + ad.UserId = NewId() + if err := ad.IsValid(); err == nil { + t.Fatal() + } + + ad.Code = NewId() + if err := ad.IsValid(); err == nil { + t.Fatal() + } + + ad.ExpiresIn = 1 + if err := ad.IsValid(); err == nil { + t.Fatal() + } + + ad.CreateAt = 1 + if err := ad.IsValid(); err != nil { + t.Fatal() + } +} diff --git a/model/client.go b/model/client.go index 204d08e69..9a89e8208 100644 --- a/model/client.go +++ b/model/client.go @@ -23,7 +23,9 @@ const ( HEADER_FORWARDED = "X-Forwarded-For" HEADER_FORWARDED_PROTO = "X-Forwarded-Proto" HEADER_TOKEN = "token" + HEADER_BEARER = "BEARER" HEADER_AUTH = "Authorization" + API_URL_SUFFIX = "/api/v1" ) type Result struct { @@ -33,22 +35,37 @@ type Result struct { } type Client struct { - Url string // The location of the server like "http://localhost/api/v1" + Url string // The location of the server like "http://localhost:8065" + ApiUrl string // The api location of the server like "http://localhost:8065/api/v1" HttpClient *http.Client // The http client AuthToken string + AuthType string } // NewClient constructs a new client with convienence methods for talking to // the server. func NewClient(url string) *Client { - return &Client{url, &http.Client{}, ""} + return &Client{url, url + API_URL_SUFFIX, &http.Client{}, "", ""} } -func (c *Client) DoPost(url string, data string) (*http.Response, *AppError) { +func (c *Client) DoPost(url string, data, contentType string) (*http.Response, *AppError) { rq, _ := http.NewRequest("POST", c.Url+url, strings.NewReader(data)) + rq.Header.Set("Content-Type", contentType) + + if rp, err := c.HttpClient.Do(rq); err != nil { + return nil, NewAppError(url, "We encountered an error while connecting to the server", err.Error()) + } else if rp.StatusCode >= 300 { + return nil, AppErrorFromJson(rp.Body) + } else { + return rp, nil + } +} + +func (c *Client) DoApiPost(url string, data string) (*http.Response, *AppError) { + rq, _ := http.NewRequest("POST", c.ApiUrl+url, strings.NewReader(data)) if len(c.AuthToken) > 0 { - rq.Header.Set(HEADER_AUTH, "BEARER "+c.AuthToken) + rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) } if rp, err := c.HttpClient.Do(rq); err != nil { @@ -60,15 +77,15 @@ func (c *Client) DoPost(url string, data string) (*http.Response, *AppError) { } } -func (c *Client) DoGet(url string, data string, etag string) (*http.Response, *AppError) { - rq, _ := http.NewRequest("GET", c.Url+url, strings.NewReader(data)) +func (c *Client) DoApiGet(url string, data string, etag string) (*http.Response, *AppError) { + rq, _ := http.NewRequest("GET", c.ApiUrl+url, strings.NewReader(data)) if len(etag) > 0 { rq.Header.Set(HEADER_ETAG_CLIENT, etag) } if len(c.AuthToken) > 0 { - rq.Header.Set(HEADER_AUTH, "BEARER "+c.AuthToken) + rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) } if rp, err := c.HttpClient.Do(rq); err != nil { @@ -106,7 +123,7 @@ func (c *Client) SignupTeam(email string, displayName string) (*Result, *AppErro m := make(map[string]string) m["email"] = email m["display_name"] = displayName - if r, err := c.DoPost("/teams/signup", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/teams/signup", MapToJson(m)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -115,7 +132,7 @@ func (c *Client) SignupTeam(email string, displayName string) (*Result, *AppErro } func (c *Client) CreateTeamFromSignup(teamSignup *TeamSignup) (*Result, *AppError) { - if r, err := c.DoPost("/teams/create_from_signup", teamSignup.ToJson()); err != nil { + if r, err := c.DoApiPost("/teams/create_from_signup", teamSignup.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -124,7 +141,7 @@ func (c *Client) CreateTeamFromSignup(teamSignup *TeamSignup) (*Result, *AppErro } func (c *Client) CreateTeam(team *Team) (*Result, *AppError) { - if r, err := c.DoPost("/teams/create", team.ToJson()); err != nil { + if r, err := c.DoApiPost("/teams/create", team.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -136,7 +153,7 @@ func (c *Client) FindTeamByName(name string, allServers bool) (*Result, *AppErro m := make(map[string]string) m["name"] = name m["all"] = fmt.Sprintf("%v", allServers) - if r, err := c.DoPost("/teams/find_team_by_name", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/teams/find_team_by_name", MapToJson(m)); err != nil { return nil, err } else { val := false @@ -152,7 +169,7 @@ func (c *Client) FindTeamByName(name string, allServers bool) (*Result, *AppErro func (c *Client) FindTeams(email string) (*Result, *AppError) { m := make(map[string]string) m["email"] = email - if r, err := c.DoPost("/teams/find_teams", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/teams/find_teams", MapToJson(m)); err != nil { return nil, err } else { @@ -164,7 +181,7 @@ func (c *Client) FindTeams(email string) (*Result, *AppError) { func (c *Client) FindTeamsSendEmail(email string) (*Result, *AppError) { m := make(map[string]string) m["email"] = email - if r, err := c.DoPost("/teams/email_teams", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/teams/email_teams", MapToJson(m)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -173,7 +190,7 @@ func (c *Client) FindTeamsSendEmail(email string) (*Result, *AppError) { } func (c *Client) InviteMembers(invites *Invites) (*Result, *AppError) { - if r, err := c.DoPost("/teams/invite_members", invites.ToJson()); err != nil { + if r, err := c.DoApiPost("/teams/invite_members", invites.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -182,7 +199,7 @@ func (c *Client) InviteMembers(invites *Invites) (*Result, *AppError) { } func (c *Client) UpdateTeamDisplayName(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/teams/update_name", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/teams/update_name", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -191,7 +208,7 @@ func (c *Client) UpdateTeamDisplayName(data map[string]string) (*Result, *AppErr } func (c *Client) UpdateValetFeature(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/teams/update_valet_feature", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/teams/update_valet_feature", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -200,7 +217,7 @@ func (c *Client) UpdateValetFeature(data map[string]string) (*Result, *AppError) } func (c *Client) CreateUser(user *User, hash string) (*Result, *AppError) { - if r, err := c.DoPost("/users/create", user.ToJson()); err != nil { + if r, err := c.DoApiPost("/users/create", user.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -209,7 +226,7 @@ func (c *Client) CreateUser(user *User, hash string) (*Result, *AppError) { } func (c *Client) CreateUserFromSignup(user *User, data string, hash string) (*Result, *AppError) { - if r, err := c.DoPost("/users/create?d="+data+"&h="+hash, user.ToJson()); err != nil { + if r, err := c.DoApiPost("/users/create?d="+data+"&h="+hash, user.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -218,7 +235,7 @@ func (c *Client) CreateUserFromSignup(user *User, data string, hash string) (*Re } func (c *Client) GetUser(id string, etag string) (*Result, *AppError) { - if r, err := c.DoGet("/users/"+id, "", etag); err != nil { + if r, err := c.DoApiGet("/users/"+id, "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -227,7 +244,7 @@ func (c *Client) GetUser(id string, etag string) (*Result, *AppError) { } func (c *Client) GetMe(etag string) (*Result, *AppError) { - if r, err := c.DoGet("/users/me", "", etag); err != nil { + if r, err := c.DoApiGet("/users/me", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -236,7 +253,7 @@ func (c *Client) GetMe(etag string) (*Result, *AppError) { } func (c *Client) GetProfiles(teamId string, etag string) (*Result, *AppError) { - if r, err := c.DoGet("/users/profiles", "", etag); err != nil { + if r, err := c.DoApiGet("/users/profiles", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -269,13 +286,14 @@ func (c *Client) LoginByEmailWithDevice(name string, email string, password stri } func (c *Client) login(m map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/users/login", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/users/login", MapToJson(m)); err != nil { return nil, err } else { c.AuthToken = r.Header.Get(HEADER_TOKEN) - sessionId := getCookie(SESSION_TOKEN, r) + c.AuthType = HEADER_BEARER + sessionToken := getCookie(SESSION_TOKEN, r) - if c.AuthToken != sessionId.Value { + if c.AuthToken != sessionToken.Value { NewAppError("/users/login", "Authentication tokens didn't match", "") } @@ -285,21 +303,32 @@ func (c *Client) login(m map[string]string) (*Result, *AppError) { } func (c *Client) Logout() (*Result, *AppError) { - if r, err := c.DoPost("/users/logout", ""); err != nil { + if r, err := c.DoApiPost("/users/logout", ""); err != nil { return nil, err } else { c.AuthToken = "" + c.AuthType = HEADER_BEARER return &Result{r.Header.Get(HEADER_REQUEST_ID), r.Header.Get(HEADER_ETAG_SERVER), MapFromJson(r.Body)}, nil } } +func (c *Client) SetOAuthToken(token string) { + c.AuthToken = token + c.AuthType = HEADER_TOKEN +} + +func (c *Client) ClearOAuthToken() { + c.AuthToken = "" + c.AuthType = HEADER_BEARER +} + func (c *Client) RevokeSession(sessionAltId string) (*Result, *AppError) { m := make(map[string]string) m["id"] = sessionAltId - if r, err := c.DoPost("/users/revoke_session", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/users/revoke_session", MapToJson(m)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -308,7 +337,7 @@ func (c *Client) RevokeSession(sessionAltId string) (*Result, *AppError) { } func (c *Client) GetSessions(id string) (*Result, *AppError) { - if r, err := c.DoGet("/users/"+id+"/sessions", "", ""); err != nil { + if r, err := c.DoApiGet("/users/"+id+"/sessions", "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -321,7 +350,7 @@ func (c *Client) Command(channelId string, command string, suggest bool) (*Resul m["command"] = command m["channelId"] = channelId m["suggest"] = strconv.FormatBool(suggest) - if r, err := c.DoPost("/command", MapToJson(m)); err != nil { + if r, err := c.DoApiPost("/command", MapToJson(m)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -330,7 +359,7 @@ func (c *Client) Command(channelId string, command string, suggest bool) (*Resul } func (c *Client) GetAudits(id string, etag string) (*Result, *AppError) { - if r, err := c.DoGet("/users/"+id+"/audits", "", etag); err != nil { + if r, err := c.DoApiGet("/users/"+id+"/audits", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -339,7 +368,7 @@ func (c *Client) GetAudits(id string, etag string) (*Result, *AppError) { } func (c *Client) GetLogs() (*Result, *AppError) { - if r, err := c.DoGet("/admin/logs", "", ""); err != nil { + if r, err := c.DoApiGet("/admin/logs", "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -348,7 +377,7 @@ func (c *Client) GetLogs() (*Result, *AppError) { } func (c *Client) GetClientProperties() (*Result, *AppError) { - if r, err := c.DoGet("/admin/client_props", "", ""); err != nil { + if r, err := c.DoApiGet("/admin/client_props", "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -357,7 +386,7 @@ func (c *Client) GetClientProperties() (*Result, *AppError) { } func (c *Client) CreateChannel(channel *Channel) (*Result, *AppError) { - if r, err := c.DoPost("/channels/create", channel.ToJson()); err != nil { + if r, err := c.DoApiPost("/channels/create", channel.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -366,7 +395,7 @@ func (c *Client) CreateChannel(channel *Channel) (*Result, *AppError) { } func (c *Client) CreateDirectChannel(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/create_direct", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/channels/create_direct", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -375,7 +404,7 @@ func (c *Client) CreateDirectChannel(data map[string]string) (*Result, *AppError } func (c *Client) UpdateChannel(channel *Channel) (*Result, *AppError) { - if r, err := c.DoPost("/channels/update", channel.ToJson()); err != nil { + if r, err := c.DoApiPost("/channels/update", channel.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -384,7 +413,7 @@ func (c *Client) UpdateChannel(channel *Channel) (*Result, *AppError) { } func (c *Client) UpdateChannelDesc(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/update_desc", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/channels/update_desc", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -393,7 +422,7 @@ func (c *Client) UpdateChannelDesc(data map[string]string) (*Result, *AppError) } func (c *Client) UpdateNotifyLevel(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/update_notify_level", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/channels/update_notify_level", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -402,7 +431,7 @@ func (c *Client) UpdateNotifyLevel(data map[string]string) (*Result, *AppError) } func (c *Client) GetChannels(etag string) (*Result, *AppError) { - if r, err := c.DoGet("/channels/", "", etag); err != nil { + if r, err := c.DoApiGet("/channels/", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -411,7 +440,7 @@ func (c *Client) GetChannels(etag string) (*Result, *AppError) { } func (c *Client) GetChannel(id, etag string) (*Result, *AppError) { - if r, err := c.DoGet("/channels/"+id+"/", "", etag); err != nil { + if r, err := c.DoApiGet("/channels/"+id+"/", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -420,7 +449,7 @@ func (c *Client) GetChannel(id, etag string) (*Result, *AppError) { } func (c *Client) GetMoreChannels(etag string) (*Result, *AppError) { - if r, err := c.DoGet("/channels/more", "", etag); err != nil { + if r, err := c.DoApiGet("/channels/more", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -429,7 +458,7 @@ func (c *Client) GetMoreChannels(etag string) (*Result, *AppError) { } func (c *Client) GetChannelCounts(etag string) (*Result, *AppError) { - if r, err := c.DoGet("/channels/counts", "", etag); err != nil { + if r, err := c.DoApiGet("/channels/counts", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -438,7 +467,7 @@ func (c *Client) GetChannelCounts(etag string) (*Result, *AppError) { } func (c *Client) JoinChannel(id string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+id+"/join", ""); err != nil { + if r, err := c.DoApiPost("/channels/"+id+"/join", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -447,7 +476,7 @@ func (c *Client) JoinChannel(id string) (*Result, *AppError) { } func (c *Client) LeaveChannel(id string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+id+"/leave", ""); err != nil { + if r, err := c.DoApiPost("/channels/"+id+"/leave", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -456,7 +485,7 @@ func (c *Client) LeaveChannel(id string) (*Result, *AppError) { } func (c *Client) DeleteChannel(id string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+id+"/delete", ""); err != nil { + if r, err := c.DoApiPost("/channels/"+id+"/delete", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -467,7 +496,7 @@ func (c *Client) DeleteChannel(id string) (*Result, *AppError) { func (c *Client) AddChannelMember(id, user_id string) (*Result, *AppError) { data := make(map[string]string) data["user_id"] = user_id - if r, err := c.DoPost("/channels/"+id+"/add", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/channels/"+id+"/add", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -478,7 +507,7 @@ func (c *Client) AddChannelMember(id, user_id string) (*Result, *AppError) { func (c *Client) RemoveChannelMember(id, user_id string) (*Result, *AppError) { data := make(map[string]string) data["user_id"] = user_id - if r, err := c.DoPost("/channels/"+id+"/remove", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/channels/"+id+"/remove", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -487,7 +516,7 @@ func (c *Client) RemoveChannelMember(id, user_id string) (*Result, *AppError) { } func (c *Client) UpdateLastViewedAt(channelId string) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+channelId+"/update_last_viewed_at", ""); err != nil { + if r, err := c.DoApiPost("/channels/"+channelId+"/update_last_viewed_at", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -496,7 +525,7 @@ func (c *Client) UpdateLastViewedAt(channelId string) (*Result, *AppError) { } func (c *Client) GetChannelExtraInfo(id string, etag string) (*Result, *AppError) { - if r, err := c.DoGet("/channels/"+id+"/extra_info", "", etag); err != nil { + if r, err := c.DoApiGet("/channels/"+id+"/extra_info", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -505,7 +534,7 @@ func (c *Client) GetChannelExtraInfo(id string, etag string) (*Result, *AppError } func (c *Client) CreatePost(post *Post) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+post.ChannelId+"/create", post.ToJson()); err != nil { + if r, err := c.DoApiPost("/channels/"+post.ChannelId+"/create", post.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -514,7 +543,7 @@ func (c *Client) CreatePost(post *Post) (*Result, *AppError) { } func (c *Client) CreateValetPost(post *Post) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+post.ChannelId+"/valet_create", post.ToJson()); err != nil { + if r, err := c.DoApiPost("/channels/"+post.ChannelId+"/valet_create", post.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -523,7 +552,7 @@ func (c *Client) CreateValetPost(post *Post) (*Result, *AppError) { } func (c *Client) UpdatePost(post *Post) (*Result, *AppError) { - if r, err := c.DoPost("/channels/"+post.ChannelId+"/update", post.ToJson()); err != nil { + if r, err := c.DoApiPost("/channels/"+post.ChannelId+"/update", post.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -532,7 +561,7 @@ func (c *Client) UpdatePost(post *Post) (*Result, *AppError) { } func (c *Client) GetPosts(channelId string, offset int, limit int, etag string) (*Result, *AppError) { - if r, err := c.DoGet(fmt.Sprintf("/channels/%v/posts/%v/%v", channelId, offset, limit), "", etag); err != nil { + if r, err := c.DoApiGet(fmt.Sprintf("/channels/%v/posts/%v/%v", channelId, offset, limit), "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -541,7 +570,7 @@ func (c *Client) GetPosts(channelId string, offset int, limit int, etag string) } func (c *Client) GetPostsSince(channelId string, time int64) (*Result, *AppError) { - if r, err := c.DoGet(fmt.Sprintf("/channels/%v/posts/%v", channelId, time), "", ""); err != nil { + if r, err := c.DoApiGet(fmt.Sprintf("/channels/%v/posts/%v", channelId, time), "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -550,7 +579,7 @@ func (c *Client) GetPostsSince(channelId string, time int64) (*Result, *AppError } func (c *Client) GetPost(channelId string, postId string, etag string) (*Result, *AppError) { - if r, err := c.DoGet(fmt.Sprintf("/channels/%v/post/%v", channelId, postId), "", etag); err != nil { + if r, err := c.DoApiGet(fmt.Sprintf("/channels/%v/post/%v", channelId, postId), "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -559,7 +588,7 @@ func (c *Client) GetPost(channelId string, postId string, etag string) (*Result, } func (c *Client) DeletePost(channelId string, postId string) (*Result, *AppError) { - if r, err := c.DoPost(fmt.Sprintf("/channels/%v/post/%v/delete", channelId, postId), ""); err != nil { + if r, err := c.DoApiPost(fmt.Sprintf("/channels/%v/post/%v/delete", channelId, postId), ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -568,7 +597,7 @@ func (c *Client) DeletePost(channelId string, postId string) (*Result, *AppError } func (c *Client) SearchPosts(terms string) (*Result, *AppError) { - if r, err := c.DoGet("/posts/search?terms="+url.QueryEscape(terms), "", ""); err != nil { + if r, err := c.DoApiGet("/posts/search?terms="+url.QueryEscape(terms), "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -577,7 +606,7 @@ func (c *Client) SearchPosts(terms string) (*Result, *AppError) { } func (c *Client) UploadFile(url string, data []byte, contentType string) (*Result, *AppError) { - rq, _ := http.NewRequest("POST", c.Url+url, bytes.NewReader(data)) + rq, _ := http.NewRequest("POST", c.ApiUrl+url, bytes.NewReader(data)) rq.Header.Set("Content-Type", contentType) if len(c.AuthToken) > 0 { @@ -599,7 +628,7 @@ func (c *Client) GetFile(url string, isFullUrl bool) (*Result, *AppError) { if isFullUrl { rq, _ = http.NewRequest("GET", url, nil) } else { - rq, _ = http.NewRequest("GET", c.Url+"/files/get"+url, nil) + rq, _ = http.NewRequest("GET", c.ApiUrl+"/files/get"+url, nil) } if len(c.AuthToken) > 0 { @@ -618,7 +647,7 @@ func (c *Client) GetFile(url string, isFullUrl bool) (*Result, *AppError) { func (c *Client) GetFileInfo(url string) (*Result, *AppError) { var rq *http.Request - rq, _ = http.NewRequest("GET", c.Url+"/files/get_info"+url, nil) + rq, _ = http.NewRequest("GET", c.ApiUrl+"/files/get_info"+url, nil) if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, "BEARER "+c.AuthToken) @@ -635,7 +664,7 @@ func (c *Client) GetFileInfo(url string) (*Result, *AppError) { } func (c *Client) GetPublicLink(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/files/get_public_link", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/files/get_public_link", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -644,7 +673,7 @@ func (c *Client) GetPublicLink(data map[string]string) (*Result, *AppError) { } func (c *Client) UpdateUser(user *User) (*Result, *AppError) { - if r, err := c.DoPost("/users/update", user.ToJson()); err != nil { + if r, err := c.DoApiPost("/users/update", user.ToJson()); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -653,7 +682,7 @@ func (c *Client) UpdateUser(user *User) (*Result, *AppError) { } func (c *Client) UpdateUserRoles(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/users/update_roles", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/update_roles", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -665,7 +694,7 @@ func (c *Client) UpdateActive(userId string, active bool) (*Result, *AppError) { data := make(map[string]string) data["user_id"] = userId data["active"] = strconv.FormatBool(active) - if r, err := c.DoPost("/users/update_active", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/update_active", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -674,7 +703,7 @@ func (c *Client) UpdateActive(userId string, active bool) (*Result, *AppError) { } func (c *Client) UpdateUserNotify(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/users/update_notify", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/update_notify", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -688,7 +717,7 @@ func (c *Client) UpdateUserPassword(userId, currentPassword, newPassword string) data["new_password"] = newPassword data["user_id"] = userId - if r, err := c.DoPost("/users/newpassword", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/newpassword", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -697,7 +726,7 @@ func (c *Client) UpdateUserPassword(userId, currentPassword, newPassword string) } func (c *Client) SendPasswordReset(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/users/send_password_reset", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/send_password_reset", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -706,7 +735,7 @@ func (c *Client) SendPasswordReset(data map[string]string) (*Result, *AppError) } func (c *Client) ResetPassword(data map[string]string) (*Result, *AppError) { - if r, err := c.DoPost("/users/reset_password", MapToJson(data)); err != nil { + if r, err := c.DoApiPost("/users/reset_password", MapToJson(data)); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -715,7 +744,7 @@ func (c *Client) ResetPassword(data map[string]string) (*Result, *AppError) { } func (c *Client) GetStatuses() (*Result, *AppError) { - if r, err := c.DoGet("/users/status", "", ""); err != nil { + if r, err := c.DoApiGet("/users/status", "", ""); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -724,7 +753,7 @@ func (c *Client) GetStatuses() (*Result, *AppError) { } func (c *Client) GetMyTeam(etag string) (*Result, *AppError) { - if r, err := c.DoGet("/teams/me", "", etag); err != nil { + if r, err := c.DoApiGet("/teams/me", "", etag); err != nil { return nil, err } else { return &Result{r.Header.Get(HEADER_REQUEST_ID), @@ -732,6 +761,33 @@ func (c *Client) GetMyTeam(etag string) (*Result, *AppError) { } } +func (c *Client) RegisterApp(app *OAuthApp) (*Result, *AppError) { + if r, err := c.DoApiPost("/oauth/register", app.ToJson()); err != nil { + return nil, err + } else { + return &Result{r.Header.Get(HEADER_REQUEST_ID), + r.Header.Get(HEADER_ETAG_SERVER), OAuthAppFromJson(r.Body)}, nil + } +} + +func (c *Client) AllowOAuth(rspType, clientId, redirect, scope, state string) (*Result, *AppError) { + if r, err := c.DoApiGet("/oauth/allow?response_type="+rspType+"&client_id="+clientId+"&redirect_uri="+url.QueryEscape(redirect)+"&scope="+scope+"&state="+url.QueryEscape(state), "", ""); err != nil { + return nil, err + } else { + return &Result{r.Header.Get(HEADER_REQUEST_ID), + r.Header.Get(HEADER_ETAG_SERVER), MapFromJson(r.Body)}, nil + } +} + +func (c *Client) GetAccessToken(data url.Values) (*Result, *AppError) { + if r, err := c.DoPost("/oauth/access_token", data.Encode(), "application/x-www-form-urlencoded"); err != nil { + return nil, err + } else { + return &Result{r.Header.Get(HEADER_REQUEST_ID), + r.Header.Get(HEADER_ETAG_SERVER), AccessResponseFromJson(r.Body)}, nil + } +} + func (c *Client) MockSession(sessionToken string) { c.AuthToken = sessionToken } diff --git a/model/oauth.go b/model/oauth.go new file mode 100644 index 000000000..3b31e677d --- /dev/null +++ b/model/oauth.go @@ -0,0 +1,151 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "encoding/json" + "fmt" + "io" +) + +type OAuthApp struct { + Id string `json:"id"` + CreatorId string `json:"creator_id"` + CreateAt int64 `json:"update_at"` + UpdateAt int64 `json:"update_at"` + ClientSecret string `json:"client_secret"` + Name string `json:"name"` + Description string `json:"description"` + CallbackUrls StringArray `json:"callback_urls"` + Homepage string `json:"homepage"` +} + +// IsValid validates the app and returns an error if it isn't configured +// correctly. +func (a *OAuthApp) IsValid() *AppError { + + if len(a.Id) != 26 { + return NewAppError("OAuthApp.IsValid", "Invalid app id", "") + } + + if a.CreateAt == 0 { + return NewAppError("OAuthApp.IsValid", "Create at must be a valid time", "app_id="+a.Id) + } + + if a.UpdateAt == 0 { + return NewAppError("OAuthApp.IsValid", "Update at must be a valid time", "app_id="+a.Id) + } + + if len(a.CreatorId) != 26 { + return NewAppError("OAuthApp.IsValid", "Invalid creator id", "app_id="+a.Id) + } + + if len(a.ClientSecret) == 0 || len(a.ClientSecret) > 128 { + return NewAppError("OAuthApp.IsValid", "Invalid client secret", "app_id="+a.Id) + } + + if len(a.Name) == 0 || len(a.Name) > 64 { + return NewAppError("OAuthApp.IsValid", "Invalid name", "app_id="+a.Id) + } + + if len(a.CallbackUrls) == 0 || len(fmt.Sprintf("%s", a.CallbackUrls)) > 1024 { + return NewAppError("OAuthApp.IsValid", "Invalid callback urls", "app_id="+a.Id) + } + + if len(a.Homepage) == 0 || len(a.Homepage) > 256 { + return NewAppError("OAuthApp.IsValid", "Invalid homepage", "app_id="+a.Id) + } + + if len(a.Description) > 512 { + return NewAppError("OAuthApp.IsValid", "Invalid description", "app_id="+a.Id) + } + + return nil +} + +// PreSave will set the Id and ClientSecret if missing. It will also fill +// in the CreateAt, UpdateAt times. It should be run before saving the app to the db. +func (a *OAuthApp) PreSave() { + if a.Id == "" { + a.Id = NewId() + } + + if a.ClientSecret == "" { + a.ClientSecret = NewId() + } + + a.CreateAt = GetMillis() + a.UpdateAt = a.CreateAt + + if len(a.ClientSecret) > 0 { + a.ClientSecret = HashPassword(a.ClientSecret) + } +} + +// PreUpdate should be run before updating the app in the db. +func (a *OAuthApp) PreUpdate() { + a.UpdateAt = GetMillis() +} + +// ToJson convert a User to a json string +func (a *OAuthApp) ToJson() string { + b, err := json.Marshal(a) + if err != nil { + return "" + } else { + return string(b) + } +} + +// Generate a valid strong etag so the browser can cache the results +func (a *OAuthApp) Etag() string { + return Etag(a.Id, a.UpdateAt) +} + +// Remove any private data from the app object +func (a *OAuthApp) Sanitize() { + a.ClientSecret = "" +} + +func (a *OAuthApp) IsValidRedirectURL(url string) bool { + for _, u := range a.CallbackUrls { + if u == url { + return true + } + } + + return false +} + +// OAuthAppFromJson will decode the input and return a User +func OAuthAppFromJson(data io.Reader) *OAuthApp { + decoder := json.NewDecoder(data) + var app OAuthApp + err := decoder.Decode(&app) + if err == nil { + return &app + } else { + return nil + } +} + +func OAuthAppMapToJson(a map[string]*OAuthApp) string { + b, err := json.Marshal(a) + if err != nil { + return "" + } else { + return string(b) + } +} + +func OAuthAppMapFromJson(data io.Reader) map[string]*OAuthApp { + decoder := json.NewDecoder(data) + var apps map[string]*OAuthApp + err := decoder.Decode(&apps) + if err == nil { + return apps + } else { + return nil + } +} diff --git a/model/oauth_test.go b/model/oauth_test.go new file mode 100644 index 000000000..2530ead98 --- /dev/null +++ b/model/oauth_test.go @@ -0,0 +1,95 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "strings" + "testing" +) + +func TestOAuthAppJson(t *testing.T) { + a1 := OAuthApp{} + a1.Id = NewId() + a1.Name = "TestOAuthApp" + NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + a1.ClientSecret = NewId() + + json := a1.ToJson() + ra1 := OAuthAppFromJson(strings.NewReader(json)) + + if a1.Id != ra1.Id { + t.Fatal("ids did not match") + } +} + +func TestOAuthAppPreSave(t *testing.T) { + a1 := OAuthApp{} + a1.Id = NewId() + a1.Name = "TestOAuthApp" + NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + a1.ClientSecret = NewId() + a1.PreSave() + a1.Etag() + a1.Sanitize() +} + +func TestOAuthAppPreUpdate(t *testing.T) { + a1 := OAuthApp{} + a1.Id = NewId() + a1.Name = "TestOAuthApp" + NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + a1.ClientSecret = NewId() + a1.PreUpdate() +} + +func TestOAuthAppIsValid(t *testing.T) { + app := OAuthApp{} + + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.Id = NewId() + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.CreateAt = 1 + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.UpdateAt = 1 + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.CreatorId = NewId() + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.ClientSecret = NewId() + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.Name = "TestOAuthApp" + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.CallbackUrls = []string{"https://nowhere.com"} + if err := app.IsValid(); err == nil { + t.Fatal() + } + + app.Homepage = "https://nowhere.com" + if err := app.IsValid(); err != nil { + t.Fatal() + } +} diff --git a/model/session.go b/model/session.go index c812f83e2..3c7c75eb4 100644 --- a/model/session.go +++ b/model/session.go @@ -14,6 +14,8 @@ const ( SESSION_TIME_WEB_IN_SECS = 60 * 60 * 24 * SESSION_TIME_WEB_IN_DAYS SESSION_TIME_MOBILE_IN_DAYS = 30 SESSION_TIME_MOBILE_IN_SECS = 60 * 60 * 24 * SESSION_TIME_MOBILE_IN_DAYS + SESSION_TIME_OAUTH_IN_DAYS = 365 + SESSION_TIME_OAUTH_IN_SECS = 60 * 60 * 24 * SESSION_TIME_OAUTH_IN_DAYS SESSION_CACHE_IN_SECS = 60 * 10 SESSION_CACHE_SIZE = 10000 SESSION_PROP_PLATFORM = "platform" @@ -23,7 +25,7 @@ const ( type Session struct { Id string `json:"id"` - AltId string `json:"alt_id"` + Token string `json:"token"` CreateAt int64 `json:"create_at"` ExpiresAt int64 `json:"expires_at"` LastActivityAt int64 `json:"last_activity_at"` @@ -31,6 +33,7 @@ type Session struct { TeamId string `json:"team_id"` DeviceId string `json:"device_id"` Roles string `json:"roles"` + IsOAuth bool `json:"is_oauth"` Props StringMap `json:"props"` } @@ -59,7 +62,7 @@ func (me *Session) PreSave() { me.Id = NewId() } - me.AltId = NewId() + me.Token = NewId() me.CreateAt = GetMillis() me.LastActivityAt = me.CreateAt @@ -70,7 +73,7 @@ func (me *Session) PreSave() { } func (me *Session) Sanitize() { - me.Id = "" + me.Token = "" } func (me *Session) IsExpired() bool { diff --git a/model/utils.go b/model/utils.go index d5122e805..04b92947b 100644 --- a/model/utils.go +++ b/model/utils.go @@ -32,6 +32,7 @@ type AppError struct { RequestId string `json:"request_id"` // The RequestId that's also set in the header StatusCode int `json:"status_code"` // The http status code Where string `json:"-"` // The function where it happened in the form of Struct.Func + IsOAuth bool `json:"is_oauth"` // Whether the error is OAuth specific } func (er *AppError) Error() string { @@ -65,6 +66,7 @@ func NewAppError(where string, message string, details string) *AppError { ap.Where = where ap.DetailedError = details ap.StatusCode = 500 + ap.IsOAuth = false return ap } diff --git a/store/sql_oauth_store.go b/store/sql_oauth_store.go new file mode 100644 index 000000000..2a6fa3118 --- /dev/null +++ b/store/sql_oauth_store.go @@ -0,0 +1,334 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package store + +import ( + "github.com/mattermost/platform/model" + "strings" +) + +type SqlOAuthStore struct { + *SqlStore +} + +func NewSqlOAuthStore(sqlStore *SqlStore) OAuthStore { + as := &SqlOAuthStore{sqlStore} + + for _, db := range sqlStore.GetAllConns() { + table := db.AddTableWithName(model.OAuthApp{}, "OAuthApps").SetKeys(false, "Id") + table.ColMap("Id").SetMaxSize(26) + table.ColMap("CreatorId").SetMaxSize(26) + table.ColMap("ClientSecret").SetMaxSize(128) + table.ColMap("Name").SetMaxSize(64) + table.ColMap("Description").SetMaxSize(512) + table.ColMap("CallbackUrls").SetMaxSize(1024) + table.ColMap("Homepage").SetMaxSize(256) + + tableAuth := db.AddTableWithName(model.AuthData{}, "OAuthAuthData").SetKeys(false, "Code") + tableAuth.ColMap("UserId").SetMaxSize(26) + tableAuth.ColMap("ClientId").SetMaxSize(26) + tableAuth.ColMap("Code").SetMaxSize(128) + tableAuth.ColMap("RedirectUri").SetMaxSize(256) + tableAuth.ColMap("State").SetMaxSize(128) + tableAuth.ColMap("Scope").SetMaxSize(128) + + tableAccess := db.AddTableWithName(model.AccessData{}, "OAuthAccessData").SetKeys(false, "Token") + tableAccess.ColMap("AuthCode").SetMaxSize(128) + tableAccess.ColMap("Token").SetMaxSize(26) + tableAccess.ColMap("RefreshToken").SetMaxSize(26) + tableAccess.ColMap("RedirectUri").SetMaxSize(256) + } + + return as +} + +func (as SqlOAuthStore) UpgradeSchemaIfNeeded() { +} + +func (as SqlOAuthStore) CreateIndexesIfNotExists() { + as.CreateIndexIfNotExists("idx_oauthapps_creator_id", "OAuthApps", "CreatorId") + as.CreateIndexIfNotExists("idx_oauthaccessdata_auth_code", "OAuthAccessData", "AuthCode") + as.CreateIndexIfNotExists("idx_oauthauthdata_client_id", "OAuthAuthData", "Code") +} + +func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + if len(app.Id) > 0 { + result.Err = model.NewAppError("SqlOAuthStore.SaveApp", "Must call update for exisiting app", "app_id="+app.Id) + storeChannel <- result + close(storeChannel) + return + } + + app.PreSave() + if result.Err = app.IsValid(); result.Err != nil { + storeChannel <- result + close(storeChannel) + return + } + + if err := as.GetMaster().Insert(app); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.SaveApp", "We couldn't save the app.", "app_id="+app.Id+", "+err.Error()) + } else { + result.Data = app + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + app.PreUpdate() + + if result.Err = app.IsValid(); result.Err != nil { + storeChannel <- result + close(storeChannel) + return + } + + if oldAppResult, err := as.GetMaster().Get(model.OAuthApp{}, app.Id); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.UpdateApp", "We encounted an error finding the app", "app_id="+app.Id+", "+err.Error()) + } else if oldAppResult == nil { + result.Err = model.NewAppError("SqlOAuthStore.UpdateApp", "We couldn't find the existing app to update", "app_id="+app.Id) + } else { + oldApp := oldAppResult.(*model.OAuthApp) + app.CreateAt = oldApp.CreateAt + app.ClientSecret = oldApp.ClientSecret + app.CreatorId = oldApp.CreatorId + + if count, err := as.GetMaster().Update(app); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.UpdateApp", "We encounted an error updating the app", "app_id="+app.Id+", "+err.Error()) + } else if count != 1 { + result.Err = model.NewAppError("SqlOAuthStore.UpdateApp", "We couldn't update the app", "app_id="+app.Id) + } else { + result.Data = [2]*model.OAuthApp{app, oldApp} + } + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) GetApp(id string) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + if obj, err := as.GetReplica().Get(model.OAuthApp{}, id); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.GetApp", "We encounted an error finding the app", "app_id="+id+", "+err.Error()) + } else if obj == nil { + result.Err = model.NewAppError("SqlOAuthStore.GetApp", "We couldn't find the existing app", "app_id="+id) + } else { + result.Data = obj.(*model.OAuthApp) + } + + storeChannel <- result + close(storeChannel) + + }() + + return storeChannel +} + +func (as SqlOAuthStore) GetAppByUser(userId string) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + var apps []*model.OAuthApp + + if _, err := as.GetReplica().Select(&apps, "SELECT * FROM OAuthApps WHERE CreatorId = :UserId", map[string]interface{}{"UserId": userId}); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.GetAppByUser", "We couldn't find any existing apps", "user_id="+userId+", "+err.Error()) + } + + result.Data = apps + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) SaveAccessData(accessData *model.AccessData) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + if result.Err = accessData.IsValid(); result.Err != nil { + storeChannel <- result + close(storeChannel) + return + } + + if err := as.GetMaster().Insert(accessData); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.SaveAccessData", "We couldn't save the access token.", err.Error()) + } else { + result.Data = accessData + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) GetAccessData(token string) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + accessData := model.AccessData{} + + if err := as.GetReplica().SelectOne(&accessData, "SELECT * FROM OAuthAccessData WHERE Token = :Token", map[string]interface{}{"Token": token}); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.GetAccessData", "We encounted an error finding the access token", err.Error()) + } else { + result.Data = &accessData + } + + storeChannel <- result + close(storeChannel) + + }() + + return storeChannel +} + +func (as SqlOAuthStore) GetAccessDataByAuthCode(authCode string) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + accessData := model.AccessData{} + + if err := as.GetReplica().SelectOne(&accessData, "SELECT * FROM OAuthAccessData WHERE AuthCode = :AuthCode", map[string]interface{}{"AuthCode": authCode}); err != nil { + if strings.Contains(err.Error(), "no rows") { + result.Data = nil + } else { + result.Err = model.NewAppError("SqlOAuthStore.GetAccessDataByAuthCode", "We encountered an error finding the access token", err.Error()) + } + } else { + result.Data = &accessData + } + + storeChannel <- result + close(storeChannel) + + }() + + return storeChannel +} + +func (as SqlOAuthStore) RemoveAccessData(token string) StoreChannel { + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + if _, err := as.GetMaster().Exec("DELETE FROM OAuthAccessData WHERE Token = :Token", map[string]interface{}{"Token": token}); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.RemoveAccessData", "We couldn't remove the access token", "err="+err.Error()) + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) SaveAuthData(authData *model.AuthData) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + authData.PreSave() + if result.Err = authData.IsValid(); result.Err != nil { + storeChannel <- result + close(storeChannel) + return + } + + if err := as.GetMaster().Insert(authData); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.SaveAuthData", "We couldn't save the authorization code.", err.Error()) + } else { + result.Data = authData + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} + +func (as SqlOAuthStore) GetAuthData(code string) StoreChannel { + + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + if obj, err := as.GetReplica().Get(model.AuthData{}, code); err != nil { + result.Err = model.NewAppError("SqlOAuthStore.GetAuthData", "We encounted an error finding the authorization code", err.Error()) + } else if obj == nil { + result.Err = model.NewAppError("SqlOAuthStore.GetAuthData", "We couldn't find the existing authorization code", "") + } else { + result.Data = obj.(*model.AuthData) + } + + storeChannel <- result + close(storeChannel) + + }() + + return storeChannel +} + +func (as SqlOAuthStore) RemoveAuthData(code string) StoreChannel { + storeChannel := make(StoreChannel) + + go func() { + result := StoreResult{} + + _, err := as.GetMaster().Exec("DELETE FROM OAuthAuthData WHERE Code = :Code", map[string]interface{}{"Code": code}) + if err != nil { + result.Err = model.NewAppError("SqlOAuthStore.RemoveAuthData", "We couldn't remove the authorization code", "err="+err.Error()) + } + + storeChannel <- result + close(storeChannel) + }() + + return storeChannel +} diff --git a/store/sql_oauth_store_test.go b/store/sql_oauth_store_test.go new file mode 100644 index 000000000..08e1388e0 --- /dev/null +++ b/store/sql_oauth_store_test.go @@ -0,0 +1,182 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +package store + +import ( + "github.com/mattermost/platform/model" + "testing" +) + +func TestOAuthStoreSaveApp(t *testing.T) { + Setup() + + a1 := model.OAuthApp{} + a1.CreatorId = model.NewId() + a1.Name = "TestApp" + model.NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + + if err := (<-store.OAuth().SaveApp(&a1)).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreGetApp(t *testing.T) { + Setup() + + a1 := model.OAuthApp{} + a1.CreatorId = model.NewId() + a1.Name = "TestApp" + model.NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + Must(store.OAuth().SaveApp(&a1)) + + if err := (<-store.OAuth().GetApp(a1.Id)).Err; err != nil { + t.Fatal(err) + } + + if err := (<-store.OAuth().GetAppByUser(a1.CreatorId)).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreUpdateApp(t *testing.T) { + Setup() + + a1 := model.OAuthApp{} + a1.CreatorId = model.NewId() + a1.Name = "TestApp" + model.NewId() + a1.CallbackUrls = []string{"https://nowhere.com"} + a1.Homepage = "https://nowhere.com" + Must(store.OAuth().SaveApp(&a1)) + + a1.CreateAt = 1 + a1.ClientSecret = "pwd" + a1.CreatorId = "12345678901234567890123456" + a1.Name = "NewName" + if result := <-store.OAuth().UpdateApp(&a1); result.Err != nil { + t.Fatal(result.Err) + } else { + ua1 := (result.Data.([2]*model.OAuthApp)[0]) + if ua1.Name != "NewName" { + t.Fatal("name did not update") + } + if ua1.CreateAt == 1 { + t.Fatal("create at should not have updated") + } + if ua1.ClientSecret == "pwd" { + t.Fatal("client secret should not have updated") + } + if ua1.CreatorId == "12345678901234567890123456" { + t.Fatal("creator id should not have updated") + } + } +} + +func TestOAuthStoreSaveAccessData(t *testing.T) { + Setup() + + a1 := model.AccessData{} + a1.AuthCode = model.NewId() + a1.Token = model.NewId() + a1.RefreshToken = model.NewId() + + if err := (<-store.OAuth().SaveAccessData(&a1)).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreGetAccessData(t *testing.T) { + Setup() + + a1 := model.AccessData{} + a1.AuthCode = model.NewId() + a1.Token = model.NewId() + a1.RefreshToken = model.NewId() + Must(store.OAuth().SaveAccessData(&a1)) + + if result := <-store.OAuth().GetAccessData(a1.Token); result.Err != nil { + t.Fatal(result.Err) + } else { + ra1 := result.Data.(*model.AccessData) + if a1.Token != ra1.Token { + t.Fatal("tokens didn't match") + } + } + + if err := (<-store.OAuth().GetAccessDataByAuthCode(a1.AuthCode)).Err; err != nil { + t.Fatal(err) + } + + if err := (<-store.OAuth().GetAccessDataByAuthCode("junk")).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreRemoveAccessData(t *testing.T) { + Setup() + + a1 := model.AccessData{} + a1.AuthCode = model.NewId() + a1.Token = model.NewId() + a1.RefreshToken = model.NewId() + Must(store.OAuth().SaveAccessData(&a1)) + + if err := (<-store.OAuth().RemoveAccessData(a1.Token)).Err; err != nil { + t.Fatal(err) + } + + if result := <-store.OAuth().GetAccessDataByAuthCode(a1.AuthCode); result.Err != nil { + t.Fatal(result.Err) + } else { + if result.Data != nil { + t.Fatal("did not delete access token") + } + } +} + +func TestOAuthStoreSaveAuthData(t *testing.T) { + Setup() + + a1 := model.AuthData{} + a1.ClientId = model.NewId() + a1.UserId = model.NewId() + a1.Code = model.NewId() + + if err := (<-store.OAuth().SaveAuthData(&a1)).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreGetAuthData(t *testing.T) { + Setup() + + a1 := model.AuthData{} + a1.ClientId = model.NewId() + a1.UserId = model.NewId() + a1.Code = model.NewId() + Must(store.OAuth().SaveAuthData(&a1)) + + if err := (<-store.OAuth().GetAuthData(a1.Code)).Err; err != nil { + t.Fatal(err) + } +} + +func TestOAuthStoreRemoveAuthData(t *testing.T) { + Setup() + + a1 := model.AuthData{} + a1.ClientId = model.NewId() + a1.UserId = model.NewId() + a1.Code = model.NewId() + Must(store.OAuth().SaveAuthData(&a1)) + + if err := (<-store.OAuth().RemoveAuthData(a1.Code)).Err; err != nil { + t.Fatal(err) + } + + if err := (<-store.OAuth().GetAuthData(a1.Code)).Err; err == nil { + t.Fatal("should have errored - auth code removed") + } +} diff --git a/store/sql_session_store.go b/store/sql_session_store.go index 12004ab78..c1d2c852b 100644 --- a/store/sql_session_store.go +++ b/store/sql_session_store.go @@ -18,7 +18,7 @@ func NewSqlSessionStore(sqlStore *SqlStore) SessionStore { for _, db := range sqlStore.GetAllConns() { table := db.AddTableWithName(model.Session{}, "Sessions").SetKeys(false, "Id") table.ColMap("Id").SetMaxSize(26) - table.ColMap("AltId").SetMaxSize(26) + table.ColMap("Token").SetMaxSize(26) table.ColMap("UserId").SetMaxSize(26) table.ColMap("TeamId").SetMaxSize(26) table.ColMap("DeviceId").SetMaxSize(128) @@ -34,7 +34,7 @@ func (me SqlSessionStore) UpgradeSchemaIfNeeded() { func (me SqlSessionStore) CreateIndexesIfNotExists() { me.CreateIndexIfNotExists("idx_sessions_user_id", "Sessions", "UserId") - me.CreateIndexIfNotExists("idx_sessions_alt_id", "Sessions", "AltId") + me.CreateIndexIfNotExists("idx_sessions_token", "Sessions", "Token") } func (me SqlSessionStore) Save(session *model.Session) StoreChannel { @@ -70,19 +70,21 @@ func (me SqlSessionStore) Save(session *model.Session) StoreChannel { return storeChannel } -func (me SqlSessionStore) Get(id string) StoreChannel { +func (me SqlSessionStore) Get(sessionIdOrToken string) StoreChannel { storeChannel := make(StoreChannel) go func() { result := StoreResult{} - if obj, err := me.GetReplica().Get(model.Session{}, id); err != nil { - result.Err = model.NewAppError("SqlSessionStore.Get", "We encounted an error finding the session", "id="+id+", "+err.Error()) - } else if obj == nil { - result.Err = model.NewAppError("SqlSessionStore.Get", "We couldn't find the existing session", "id="+id) + var sessions []*model.Session + + if _, err := me.GetReplica().Select(&sessions, "SELECT * FROM Sessions WHERE Token = :Token OR Id = :Id LIMIT 1", map[string]interface{}{"Token": sessionIdOrToken, "Id": sessionIdOrToken}); err != nil { + result.Err = model.NewAppError("SqlSessionStore.Get", "We encounted an error finding the session", "sessionIdOrToken="+sessionIdOrToken+", "+err.Error()) + } else if sessions == nil || len(sessions) == 0 { + result.Err = model.NewAppError("SqlSessionStore.Get", "We encounted an error finding the session", "sessionIdOrToken="+sessionIdOrToken) } else { - result.Data = obj.(*model.Session) + result.Data = sessions[0] } storeChannel <- result @@ -120,15 +122,15 @@ func (me SqlSessionStore) GetSessions(userId string) StoreChannel { return storeChannel } -func (me SqlSessionStore) Remove(sessionIdOrAlt string) StoreChannel { +func (me SqlSessionStore) Remove(sessionIdOrToken string) StoreChannel { storeChannel := make(StoreChannel) go func() { result := StoreResult{} - _, err := me.GetMaster().Exec("DELETE FROM Sessions WHERE Id = :Id Or AltId = :AltId", map[string]interface{}{"Id": sessionIdOrAlt, "AltId": sessionIdOrAlt}) + _, err := me.GetMaster().Exec("DELETE FROM Sessions WHERE Id = :Id Or Token = :Token", map[string]interface{}{"Id": sessionIdOrToken, "Token": sessionIdOrToken}) if err != nil { - result.Err = model.NewAppError("SqlSessionStore.RemoveSession", "We couldn't remove the session", "id="+sessionIdOrAlt+", err="+err.Error()) + result.Err = model.NewAppError("SqlSessionStore.RemoveSession", "We couldn't remove the session", "id="+sessionIdOrToken+", err="+err.Error()) } storeChannel <- result @@ -181,7 +183,6 @@ func (me SqlSessionStore) UpdateRoles(userId, roles string) StoreChannel { go func() { result := StoreResult{} - if _, err := me.GetMaster().Exec("UPDATE Sessions SET Roles = :Roles WHERE UserId = :UserId", map[string]interface{}{"Roles": roles, "UserId": userId}); err != nil { result.Err = model.NewAppError("SqlSessionStore.UpdateRoles", "We couldn't update the roles", "userId="+userId) } else { diff --git a/store/sql_session_store_test.go b/store/sql_session_store_test.go index 581aff971..4ae680556 100644 --- a/store/sql_session_store_test.go +++ b/store/sql_session_store_test.go @@ -80,7 +80,7 @@ func TestSessionRemove(t *testing.T) { } } -func TestSessionRemoveAlt(t *testing.T) { +func TestSessionRemoveToken(t *testing.T) { Setup() s1 := model.Session{} @@ -96,7 +96,7 @@ func TestSessionRemoveAlt(t *testing.T) { } } - Must(store.Session().Remove(s1.AltId)) + Must(store.Session().Remove(s1.Token)) if rs2 := (<-store.Session().Get(s1.Id)); rs2.Err == nil { t.Fatal("should have been removed") diff --git a/store/sql_store.go b/store/sql_store.go index 98c67d668..c0b3c2021 100644 --- a/store/sql_store.go +++ b/store/sql_store.go @@ -38,6 +38,7 @@ type SqlStore struct { user UserStore audit AuditStore session SessionStore + oauth OAuthStore } func NewSqlStore() Store { @@ -55,28 +56,36 @@ func NewSqlStore() Store { utils.Cfg.SqlSettings.Trace) } + // Temporary upgrade code, remove after 0.8.0 release + if sqlStore.DoesColumnExist("Sessions", "AltId") { + sqlStore.GetMaster().Exec("DROP TABLE IF EXISTS Sessions") + } + sqlStore.team = NewSqlTeamStore(sqlStore) sqlStore.channel = NewSqlChannelStore(sqlStore) sqlStore.post = NewSqlPostStore(sqlStore) sqlStore.user = NewSqlUserStore(sqlStore) sqlStore.audit = NewSqlAuditStore(sqlStore) sqlStore.session = NewSqlSessionStore(sqlStore) + sqlStore.oauth = NewSqlOAuthStore(sqlStore) sqlStore.master.CreateTablesIfNotExists() - sqlStore.team.(*SqlTeamStore).CreateIndexesIfNotExists() - sqlStore.channel.(*SqlChannelStore).CreateIndexesIfNotExists() - sqlStore.post.(*SqlPostStore).CreateIndexesIfNotExists() - sqlStore.user.(*SqlUserStore).CreateIndexesIfNotExists() - sqlStore.audit.(*SqlAuditStore).CreateIndexesIfNotExists() - sqlStore.session.(*SqlSessionStore).CreateIndexesIfNotExists() - sqlStore.team.(*SqlTeamStore).UpgradeSchemaIfNeeded() sqlStore.channel.(*SqlChannelStore).UpgradeSchemaIfNeeded() sqlStore.post.(*SqlPostStore).UpgradeSchemaIfNeeded() sqlStore.user.(*SqlUserStore).UpgradeSchemaIfNeeded() sqlStore.audit.(*SqlAuditStore).UpgradeSchemaIfNeeded() sqlStore.session.(*SqlSessionStore).UpgradeSchemaIfNeeded() + sqlStore.oauth.(*SqlOAuthStore).UpgradeSchemaIfNeeded() + + sqlStore.team.(*SqlTeamStore).CreateIndexesIfNotExists() + sqlStore.channel.(*SqlChannelStore).CreateIndexesIfNotExists() + sqlStore.post.(*SqlPostStore).CreateIndexesIfNotExists() + sqlStore.user.(*SqlUserStore).CreateIndexesIfNotExists() + sqlStore.audit.(*SqlAuditStore).CreateIndexesIfNotExists() + sqlStore.session.(*SqlSessionStore).CreateIndexesIfNotExists() + sqlStore.oauth.(*SqlOAuthStore).CreateIndexesIfNotExists() return sqlStore } @@ -363,6 +372,10 @@ func (ss SqlStore) Audit() AuditStore { return ss.audit } +func (ss SqlStore) OAuth() OAuthStore { + return ss.oauth +} + type mattermConverter struct{} func (me mattermConverter) ToDb(val interface{}) (interface{}, error) { diff --git a/store/store.go b/store/store.go index 959e93fa4..0218bc757 100644 --- a/store/store.go +++ b/store/store.go @@ -34,6 +34,7 @@ type Store interface { User() UserStore Audit() AuditStore Session() SessionStore + OAuth() OAuthStore Close() } @@ -104,9 +105,9 @@ type UserStore interface { type SessionStore interface { Save(session *model.Session) StoreChannel - Get(id string) StoreChannel + Get(sessionIdOrToken string) StoreChannel GetSessions(userId string) StoreChannel - Remove(sessionIdOrAlt string) StoreChannel + Remove(sessionIdOrToken string) StoreChannel UpdateLastActivityAt(sessionId string, time int64) StoreChannel UpdateRoles(userId string, roles string) StoreChannel } @@ -115,3 +116,17 @@ type AuditStore interface { Save(audit *model.Audit) StoreChannel Get(user_id string, limit int) StoreChannel } + +type OAuthStore interface { + SaveApp(app *model.OAuthApp) StoreChannel + UpdateApp(app *model.OAuthApp) StoreChannel + GetApp(id string) StoreChannel + GetAppByUser(userId string) StoreChannel + SaveAuthData(authData *model.AuthData) StoreChannel + GetAuthData(code string) StoreChannel + RemoveAuthData(code string) StoreChannel + SaveAccessData(accessData *model.AccessData) StoreChannel + GetAccessData(token string) StoreChannel + GetAccessDataByAuthCode(authCode string) StoreChannel + RemoveAccessData(token string) StoreChannel +} diff --git a/utils/config.go b/utils/config.go index 9f66e9f37..212a1a559 100644 --- a/utils/config.go +++ b/utils/config.go @@ -21,20 +21,21 @@ const ( ) type ServiceSettings struct { - SiteName string - Mode string - AllowTesting bool - UseSSL bool - Port string - Version string - InviteSalt string - PublicLinkSalt string - ResetSalt string - AnalyticsUrl string - UseLocalStorage bool - StorageDirectory string - AllowedLoginAttempts int - DisableEmailSignUp bool + SiteName string + Mode string + AllowTesting bool + UseSSL bool + Port string + Version string + InviteSalt string + PublicLinkSalt string + ResetSalt string + AnalyticsUrl string + UseLocalStorage bool + StorageDirectory string + AllowedLoginAttempts int + DisableEmailSignUp bool + EnableOAuthServiceProvider bool } type SSOSetting struct { @@ -286,6 +287,7 @@ func getClientProperties(c *Config) map[string]string { props["ProfileHeight"] = fmt.Sprintf("%v", c.ImageSettings.ProfileHeight) props["ProfileWidth"] = fmt.Sprintf("%v", c.ImageSettings.ProfileWidth) props["ProfileWidth"] = fmt.Sprintf("%v", c.ImageSettings.ProfileWidth) + props["EnableOAuthServiceProvider"] = strconv.FormatBool(c.ServiceSettings.EnableOAuthServiceProvider) return props } diff --git a/web/react/components/authorize.jsx b/web/react/components/authorize.jsx new file mode 100644 index 000000000..dd4479ad4 --- /dev/null +++ b/web/react/components/authorize.jsx @@ -0,0 +1,72 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +var Client = require('../utils/client.jsx'); + +export default class Authorize extends React.Component { + constructor(props) { + super(props); + + this.handleAllow = this.handleAllow.bind(this); + this.handleDeny = this.handleDeny.bind(this); + + this.state = {}; + } + handleAllow() { + const responseType = this.props.responseType; + const clientId = this.props.clientId; + const redirectUri = this.props.redirectUri; + const state = this.props.state; + const scope = this.props.scope; + + Client.allowOAuth2(responseType, clientId, redirectUri, state, scope, + (data) => { + if (data.redirect) { + window.location.replace(data.redirect); + } + }, + () => {} + ); + } + handleDeny() { + window.location.replace(this.props.redirectUri + '?error=access_denied'); + } + render() { + return ( +
+
+

{'An application would like to connect to your '}{this.props.teamName}{' account'}

+ +
+
+ +
+ + +
+
+ ); + } +} + +Authorize.propTypes = { + appName: React.PropTypes.string, + teamName: React.PropTypes.string, + responseType: React.PropTypes.string, + clientId: React.PropTypes.string, + redirectUri: React.PropTypes.string, + state: React.PropTypes.string, + scope: React.PropTypes.string +}; diff --git a/web/react/components/popover_list_members.jsx b/web/react/components/popover_list_members.jsx index fb9522afb..ec873dd00 100644 --- a/web/react/components/popover_list_members.jsx +++ b/web/react/components/popover_list_members.jsx @@ -25,7 +25,7 @@ export default class PopoverListMembers extends React.Component { $('#member_popover').popover({placement: 'bottom', trigger: 'click', html: true}); $('body').on('click', function onClick(e) { - if ($(e.target.parentNode.parentNode)[0] !== $('#member_popover')[0] && $(e.target).parents('.popover.in').length === 0) { + if (e.target.parentNode && $(e.target.parentNode.parentNode)[0] !== $('#member_popover')[0] && $(e.target).parents('.popover.in').length === 0) { $('#member_popover').popover('hide'); } }); diff --git a/web/react/components/register_app_modal.jsx b/web/react/components/register_app_modal.jsx new file mode 100644 index 000000000..3dd5c094e --- /dev/null +++ b/web/react/components/register_app_modal.jsx @@ -0,0 +1,249 @@ +// Copyright (c) 2015 Spinpunch, Inc. All Rights Reserved. +// See License.txt for license information. + +var Client = require('../utils/client.jsx'); + +export default class RegisterAppModal extends React.Component { + constructor() { + super(); + + this.register = this.register.bind(this); + this.onHide = this.onHide.bind(this); + this.save = this.save.bind(this); + + this.state = {clientId: '', clientSecret: '', saved: false}; + } + componentDidMount() { + $(React.findDOMNode(this)).on('hide.bs.modal', this.onHide); + } + register() { + var state = this.state; + state.serverError = null; + + var app = {}; + + var name = this.refs.name.getDOMNode().value; + if (!name || name.length === 0) { + state.nameError = 'Application name must be filled in.'; + this.setState(state); + return; + } + state.nameError = null; + app.name = name; + + var homepage = this.refs.homepage.getDOMNode().value; + if (!homepage || homepage.length === 0) { + state.homepageError = 'Homepage must be filled in.'; + this.setState(state); + return; + } + state.homepageError = null; + app.homepage = homepage; + + var desc = this.refs.desc.getDOMNode().value; + app.description = desc; + + var rawCallbacks = this.refs.callback.getDOMNode().value.trim(); + if (!rawCallbacks || rawCallbacks.length === 0) { + state.callbackError = 'At least one callback URL must be filled in.'; + this.setState(state); + return; + } + state.callbackError = null; + app.callback_urls = rawCallbacks.split('\n'); + + Client.registerOAuthApp(app, + (data) => { + state.clientId = data.id; + state.clientSecret = data.client_secret; + this.setState(state); + }, + (err) => { + state.serverError = err.message; + this.setState(state); + } + ); + } + onHide(e) { + if (!this.state.saved && this.state.clientId !== '') { + e.preventDefault(); + return; + } + + this.setState({clientId: '', clientSecret: '', saved: false}); + } + save() { + this.setState({saved: this.refs.save.getDOMNode().checked}); + } + render() { + var nameError; + if (this.state.nameError) { + nameError =
; + } + var homepageError; + if (this.state.homepageError) { + homepageError =
; + } + var callbackError; + if (this.state.callbackError) { + callbackError =
; + } + var serverError; + if (this.state.serverError) { + serverError =
; + } + + var body = ''; + if (this.state.clientId === '') { + body = ( +
+

{'Register a New Application'}

+
+ +
+ + {nameError} +
+
+
+ +
+ + {homepageError} +
+
+
+ +
+ +
+
+
+ +
+