From 640d3018c9a75e7c85da55c3483396e31a6de994 Mon Sep 17 00:00:00 2001 From: =Corey Hulen Date: Wed, 20 Jan 2016 08:56:09 -0600 Subject: PLT-7 adding loc db calls for oauth table --- api/oauth.go | 24 ++++++++++++----------- api/team.go | 2 +- api/user.go | 6 +++--- store/sql_oauth_store.go | 25 ++++++++++++------------ store/sql_oauth_store_test.go | 45 ++++++++++++++++++++++--------------------- store/store.go | 24 +++++++++++------------ web/web.go | 12 ++++++------ 7 files changed, 71 insertions(+), 67 deletions(-) diff --git a/api/oauth.go b/api/oauth.go index eb5e0e496..64848d0ce 100644 --- a/api/oauth.go +++ b/api/oauth.go @@ -5,12 +5,14 @@ package api import ( "fmt" + "net/http" + "net/url" + l4g "github.com/alecthomas/log4go" "github.com/gorilla/mux" "github.com/mattermost/platform/model" "github.com/mattermost/platform/utils" - "net/http" - "net/url" + goi18n "github.com/nicksnyder/go-i18n/i18n" ) func InitOAuth(r *mux.Router) { @@ -41,7 +43,7 @@ func registerOAuthApp(c *Context, w http.ResponseWriter, r *http.Request) { app.ClientSecret = secret app.CreatorId = c.Session.UserId - if result := <-Srv.Store.OAuth().SaveApp(app); result.Err != nil { + if result := <-Srv.Store.OAuth().SaveApp(c.T, app); result.Err != nil { c.Err = result.Err return } else { @@ -90,7 +92,7 @@ func allowOAuth(c *Context, w http.ResponseWriter, r *http.Request) { state := r.URL.Query().Get("state") var app *model.OAuthApp - if result := <-Srv.Store.OAuth().GetApp(clientId); result.Err != nil { + if result := <-Srv.Store.OAuth().GetApp(c.T, clientId); result.Err != nil { c.Err = model.NewAppError("allowOAuth", "server_error: Error accessing the database", "") return } else { @@ -112,7 +114,7 @@ func allowOAuth(c *Context, w http.ResponseWriter, r *http.Request) { authData := &model.AuthData{UserId: c.Session.UserId, ClientId: clientId, CreateAt: model.GetMillis(), RedirectUri: redirectUri, State: state, Scope: scope} authData.Code = model.HashPassword(fmt.Sprintf("%v:%v:%v:%v", clientId, redirectUri, authData.CreateAt, c.Session.UserId)) - if result := <-Srv.Store.OAuth().SaveAuthData(authData); result.Err != nil { + if result := <-Srv.Store.OAuth().SaveAuthData(c.T, authData); result.Err != nil { responseData["redirect"] = redirectUri + "?error=server_error&state=" + state w.Write([]byte(model.MapToJson(responseData))) return @@ -125,20 +127,20 @@ func allowOAuth(c *Context, w http.ResponseWriter, r *http.Request) { w.Write([]byte(model.MapToJson(responseData))) } -func RevokeAccessToken(token string) *model.AppError { +func RevokeAccessToken(T goi18n.TranslateFunc, token string) *model.AppError { schan := Srv.Store.Session().Remove(token) sessionCache.Remove(token) var accessData *model.AccessData - if result := <-Srv.Store.OAuth().GetAccessData(token); result.Err != nil { + if result := <-Srv.Store.OAuth().GetAccessData(T, token); result.Err != nil { return model.NewAppError("RevokeAccessToken", "Error getting access token from DB before deletion", "") } else { accessData = result.Data.(*model.AccessData) } - tchan := Srv.Store.OAuth().RemoveAccessData(token) - cchan := Srv.Store.OAuth().RemoveAuthData(accessData.AuthCode) + tchan := Srv.Store.OAuth().RemoveAccessData(T, token) + cchan := Srv.Store.OAuth().RemoveAuthData(T, accessData.AuthCode) if result := <-tchan; result.Err != nil { return model.NewAppError("RevokeAccessToken", "Error deleting access token from DB", "") @@ -155,8 +157,8 @@ func RevokeAccessToken(token string) *model.AppError { return nil } -func GetAuthData(code string) *model.AuthData { - if result := <-Srv.Store.OAuth().GetAuthData(code); result.Err != nil { +func GetAuthData(T goi18n.TranslateFunc, code string) *model.AuthData { + if result := <-Srv.Store.OAuth().GetAuthData(T, code); result.Err != nil { l4g.Error("Couldn't find auth code for code=%s", code) return nil } else { diff --git a/api/team.go b/api/team.go index 7ee7b41c9..e55d454e0 100644 --- a/api/team.go +++ b/api/team.go @@ -340,7 +340,7 @@ func revokeAllSessions(c *Context, w http.ResponseWriter, r *http.Request) { c.LogAudit("revoked_all=" + id) if session.IsOAuth { - RevokeAccessToken(session.Token) + RevokeAccessToken(c.T, session.Token) } else { sessionCache.Remove(session.Token) diff --git a/api/user.go b/api/user.go index 66125d242..71426acaa 100644 --- a/api/user.go +++ b/api/user.go @@ -713,7 +713,7 @@ func RevokeSessionById(c *Context, sessionId string) { c.LogAudit("session_id=" + session.Id) if session.IsOAuth { - RevokeAccessToken(session.Token) + RevokeAccessToken(c.T, session.Token) } else { sessionCache.Remove(session.Token) @@ -734,7 +734,7 @@ func RevokeAllSession(c *Context, userId string) { for _, session := range sessions { c.LogAuditWithUserId(userId, "session_id="+session.Id) if session.IsOAuth { - RevokeAccessToken(session.Token) + RevokeAccessToken(c.T, session.Token) } else { sessionCache.Remove(session.Token) if result := <-Srv.Store.Session().Remove(session.Id); result.Err != nil { @@ -1440,7 +1440,7 @@ func PermanentDeleteUser(c *Context, user *model.User) *model.AppError { return result.Err } - if result := <-Srv.Store.OAuth().PermanentDeleteAuthDataByUser(user.Id); result.Err != nil { + if result := <-Srv.Store.OAuth().PermanentDeleteAuthDataByUser(c.T, user.Id); result.Err != nil { return result.Err } diff --git a/store/sql_oauth_store.go b/store/sql_oauth_store.go index 43a5bee31..20184e6a4 100644 --- a/store/sql_oauth_store.go +++ b/store/sql_oauth_store.go @@ -5,6 +5,7 @@ package store import ( "github.com/mattermost/platform/model" + goi18n "github.com/nicksnyder/go-i18n/i18n" "strings" ) @@ -52,7 +53,7 @@ func (as SqlOAuthStore) CreateIndexesIfNotExists() { as.CreateIndexIfNotExists("idx_oauthauthdata_client_id", "OAuthAuthData", "Code") } -func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) StoreChannel { +func (as SqlOAuthStore) SaveApp(T goi18n.TranslateFunc, app *model.OAuthApp) StoreChannel { storeChannel := make(StoreChannel) @@ -86,7 +87,7 @@ func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) StoreChannel { return storeChannel } -func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) StoreChannel { +func (as SqlOAuthStore) UpdateApp(T goi18n.TranslateFunc, app *model.OAuthApp) StoreChannel { storeChannel := make(StoreChannel) @@ -127,7 +128,7 @@ func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) StoreChannel { return storeChannel } -func (as SqlOAuthStore) GetApp(id string) StoreChannel { +func (as SqlOAuthStore) GetApp(T goi18n.TranslateFunc, id string) StoreChannel { storeChannel := make(StoreChannel) @@ -150,7 +151,7 @@ func (as SqlOAuthStore) GetApp(id string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) GetAppByUser(userId string) StoreChannel { +func (as SqlOAuthStore) GetAppByUser(T goi18n.TranslateFunc, userId string) StoreChannel { storeChannel := make(StoreChannel) @@ -172,7 +173,7 @@ func (as SqlOAuthStore) GetAppByUser(userId string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) SaveAccessData(accessData *model.AccessData) StoreChannel { +func (as SqlOAuthStore) SaveAccessData(T goi18n.TranslateFunc, accessData *model.AccessData) StoreChannel { storeChannel := make(StoreChannel) @@ -198,7 +199,7 @@ func (as SqlOAuthStore) SaveAccessData(accessData *model.AccessData) StoreChanne return storeChannel } -func (as SqlOAuthStore) GetAccessData(token string) StoreChannel { +func (as SqlOAuthStore) GetAccessData(T goi18n.TranslateFunc, token string) StoreChannel { storeChannel := make(StoreChannel) @@ -221,7 +222,7 @@ func (as SqlOAuthStore) GetAccessData(token string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) GetAccessDataByAuthCode(authCode string) StoreChannel { +func (as SqlOAuthStore) GetAccessDataByAuthCode(T goi18n.TranslateFunc, authCode string) StoreChannel { storeChannel := make(StoreChannel) @@ -248,7 +249,7 @@ func (as SqlOAuthStore) GetAccessDataByAuthCode(authCode string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) RemoveAccessData(token string) StoreChannel { +func (as SqlOAuthStore) RemoveAccessData(T goi18n.TranslateFunc, token string) StoreChannel { storeChannel := make(StoreChannel) go func() { @@ -265,7 +266,7 @@ func (as SqlOAuthStore) RemoveAccessData(token string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) SaveAuthData(authData *model.AuthData) StoreChannel { +func (as SqlOAuthStore) SaveAuthData(T goi18n.TranslateFunc, authData *model.AuthData) StoreChannel { storeChannel := make(StoreChannel) @@ -292,7 +293,7 @@ func (as SqlOAuthStore) SaveAuthData(authData *model.AuthData) StoreChannel { return storeChannel } -func (as SqlOAuthStore) GetAuthData(code string) StoreChannel { +func (as SqlOAuthStore) GetAuthData(T goi18n.TranslateFunc, code string) StoreChannel { storeChannel := make(StoreChannel) @@ -315,7 +316,7 @@ func (as SqlOAuthStore) GetAuthData(code string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) RemoveAuthData(code string) StoreChannel { +func (as SqlOAuthStore) RemoveAuthData(T goi18n.TranslateFunc, code string) StoreChannel { storeChannel := make(StoreChannel) go func() { @@ -333,7 +334,7 @@ func (as SqlOAuthStore) RemoveAuthData(code string) StoreChannel { return storeChannel } -func (as SqlOAuthStore) PermanentDeleteAuthDataByUser(userId string) StoreChannel { +func (as SqlOAuthStore) PermanentDeleteAuthDataByUser(T goi18n.TranslateFunc, userId string) StoreChannel { storeChannel := make(StoreChannel) go func() { diff --git a/store/sql_oauth_store_test.go b/store/sql_oauth_store_test.go index c3f6ea7ac..f8d035a0c 100644 --- a/store/sql_oauth_store_test.go +++ b/store/sql_oauth_store_test.go @@ -5,6 +5,7 @@ package store import ( "github.com/mattermost/platform/model" + "github.com/mattermost/platform/utils" "testing" ) @@ -17,7 +18,7 @@ func TestOAuthStoreSaveApp(t *testing.T) { a1.CallbackUrls = []string{"https://nowhere.com"} a1.Homepage = "https://nowhere.com" - if err := (<-store.OAuth().SaveApp(&a1)).Err; err != nil { + if err := (<-store.OAuth().SaveApp(utils.T, &a1)).Err; err != nil { t.Fatal(err) } } @@ -30,13 +31,13 @@ func TestOAuthStoreGetApp(t *testing.T) { a1.Name = "TestApp" + model.NewId() a1.CallbackUrls = []string{"https://nowhere.com"} a1.Homepage = "https://nowhere.com" - Must(store.OAuth().SaveApp(&a1)) + Must(store.OAuth().SaveApp(utils.T, &a1)) - if err := (<-store.OAuth().GetApp(a1.Id)).Err; err != nil { + if err := (<-store.OAuth().GetApp(utils.T, a1.Id)).Err; err != nil { t.Fatal(err) } - if err := (<-store.OAuth().GetAppByUser(a1.CreatorId)).Err; err != nil { + if err := (<-store.OAuth().GetAppByUser(utils.T, a1.CreatorId)).Err; err != nil { t.Fatal(err) } } @@ -49,13 +50,13 @@ func TestOAuthStoreUpdateApp(t *testing.T) { a1.Name = "TestApp" + model.NewId() a1.CallbackUrls = []string{"https://nowhere.com"} a1.Homepage = "https://nowhere.com" - Must(store.OAuth().SaveApp(&a1)) + Must(store.OAuth().SaveApp(utils.T, &a1)) a1.CreateAt = 1 a1.ClientSecret = "pwd" a1.CreatorId = "12345678901234567890123456" a1.Name = "NewName" - if result := <-store.OAuth().UpdateApp(&a1); result.Err != nil { + if result := <-store.OAuth().UpdateApp(utils.T, &a1); result.Err != nil { t.Fatal(result.Err) } else { ua1 := (result.Data.([2]*model.OAuthApp)[0]) @@ -82,7 +83,7 @@ func TestOAuthStoreSaveAccessData(t *testing.T) { a1.Token = model.NewId() a1.RefreshToken = model.NewId() - if err := (<-store.OAuth().SaveAccessData(&a1)).Err; err != nil { + if err := (<-store.OAuth().SaveAccessData(utils.T, &a1)).Err; err != nil { t.Fatal(err) } } @@ -94,9 +95,9 @@ func TestOAuthStoreGetAccessData(t *testing.T) { a1.AuthCode = model.NewId() a1.Token = model.NewId() a1.RefreshToken = model.NewId() - Must(store.OAuth().SaveAccessData(&a1)) + Must(store.OAuth().SaveAccessData(utils.T, &a1)) - if result := <-store.OAuth().GetAccessData(a1.Token); result.Err != nil { + if result := <-store.OAuth().GetAccessData(utils.T, a1.Token); result.Err != nil { t.Fatal(result.Err) } else { ra1 := result.Data.(*model.AccessData) @@ -105,11 +106,11 @@ func TestOAuthStoreGetAccessData(t *testing.T) { } } - if err := (<-store.OAuth().GetAccessDataByAuthCode(a1.AuthCode)).Err; err != nil { + if err := (<-store.OAuth().GetAccessDataByAuthCode(utils.T, a1.AuthCode)).Err; err != nil { t.Fatal(err) } - if err := (<-store.OAuth().GetAccessDataByAuthCode("junk")).Err; err != nil { + if err := (<-store.OAuth().GetAccessDataByAuthCode(utils.T, "junk")).Err; err != nil { t.Fatal(err) } } @@ -121,13 +122,13 @@ func TestOAuthStoreRemoveAccessData(t *testing.T) { a1.AuthCode = model.NewId() a1.Token = model.NewId() a1.RefreshToken = model.NewId() - Must(store.OAuth().SaveAccessData(&a1)) + Must(store.OAuth().SaveAccessData(utils.T, &a1)) - if err := (<-store.OAuth().RemoveAccessData(a1.Token)).Err; err != nil { + if err := (<-store.OAuth().RemoveAccessData(utils.T, a1.Token)).Err; err != nil { t.Fatal(err) } - if result := <-store.OAuth().GetAccessDataByAuthCode(a1.AuthCode); result.Err != nil { + if result := <-store.OAuth().GetAccessDataByAuthCode(utils.T, a1.AuthCode); result.Err != nil { t.Fatal(result.Err) } else { if result.Data != nil { @@ -144,7 +145,7 @@ func TestOAuthStoreSaveAuthData(t *testing.T) { a1.UserId = model.NewId() a1.Code = model.NewId() - if err := (<-store.OAuth().SaveAuthData(&a1)).Err; err != nil { + if err := (<-store.OAuth().SaveAuthData(utils.T, &a1)).Err; err != nil { t.Fatal(err) } } @@ -156,9 +157,9 @@ func TestOAuthStoreGetAuthData(t *testing.T) { a1.ClientId = model.NewId() a1.UserId = model.NewId() a1.Code = model.NewId() - Must(store.OAuth().SaveAuthData(&a1)) + Must(store.OAuth().SaveAuthData(utils.T, &a1)) - if err := (<-store.OAuth().GetAuthData(a1.Code)).Err; err != nil { + if err := (<-store.OAuth().GetAuthData(utils.T, a1.Code)).Err; err != nil { t.Fatal(err) } } @@ -170,13 +171,13 @@ func TestOAuthStoreRemoveAuthData(t *testing.T) { a1.ClientId = model.NewId() a1.UserId = model.NewId() a1.Code = model.NewId() - Must(store.OAuth().SaveAuthData(&a1)) + Must(store.OAuth().SaveAuthData(utils.T, &a1)) - if err := (<-store.OAuth().RemoveAuthData(a1.Code)).Err; err != nil { + if err := (<-store.OAuth().RemoveAuthData(utils.T, a1.Code)).Err; err != nil { t.Fatal(err) } - if err := (<-store.OAuth().GetAuthData(a1.Code)).Err; err == nil { + if err := (<-store.OAuth().GetAuthData(utils.T, a1.Code)).Err; err == nil { t.Fatal("should have errored - auth code removed") } } @@ -188,9 +189,9 @@ func TestOAuthStoreRemoveAuthDataByUser(t *testing.T) { a1.ClientId = model.NewId() a1.UserId = model.NewId() a1.Code = model.NewId() - Must(store.OAuth().SaveAuthData(&a1)) + Must(store.OAuth().SaveAuthData(utils.T, &a1)) - if err := (<-store.OAuth().PermanentDeleteAuthDataByUser(a1.UserId)).Err; err != nil { + if err := (<-store.OAuth().PermanentDeleteAuthDataByUser(utils.T, a1.UserId)).Err; err != nil { t.Fatal(err) } } diff --git a/store/store.go b/store/store.go index 5b711fdc7..fe103032e 100644 --- a/store/store.go +++ b/store/store.go @@ -147,18 +147,18 @@ type AuditStore interface { } type OAuthStore interface { - SaveApp(app *model.OAuthApp) StoreChannel - UpdateApp(app *model.OAuthApp) StoreChannel - GetApp(id string) StoreChannel - GetAppByUser(userId string) StoreChannel - SaveAuthData(authData *model.AuthData) StoreChannel - GetAuthData(code string) StoreChannel - RemoveAuthData(code string) StoreChannel - PermanentDeleteAuthDataByUser(userId string) StoreChannel - SaveAccessData(accessData *model.AccessData) StoreChannel - GetAccessData(token string) StoreChannel - GetAccessDataByAuthCode(authCode string) StoreChannel - RemoveAccessData(token string) StoreChannel + SaveApp(T goi18n.TranslateFunc, app *model.OAuthApp) StoreChannel + UpdateApp(T goi18n.TranslateFunc, app *model.OAuthApp) StoreChannel + GetApp(T goi18n.TranslateFunc, id string) StoreChannel + GetAppByUser(T goi18n.TranslateFunc, userId string) StoreChannel + SaveAuthData(T goi18n.TranslateFunc, authData *model.AuthData) StoreChannel + GetAuthData(T goi18n.TranslateFunc, code string) StoreChannel + RemoveAuthData(T goi18n.TranslateFunc, code string) StoreChannel + PermanentDeleteAuthDataByUser(T goi18n.TranslateFunc, userId string) StoreChannel + SaveAccessData(T goi18n.TranslateFunc, accessData *model.AccessData) StoreChannel + GetAccessData(T goi18n.TranslateFunc, token string) StoreChannel + GetAccessDataByAuthCode(T goi18n.TranslateFunc, authCode string) StoreChannel + RemoveAccessData(T goi18n.TranslateFunc, token string) StoreChannel } type SystemStore interface { diff --git a/web/web.go b/web/web.go index c7a7b8666..a98b213be 100644 --- a/web/web.go +++ b/web/web.go @@ -846,7 +846,7 @@ func authorizeOAuth(c *api.Context, w http.ResponseWriter, r *http.Request) { } var app *model.OAuthApp - if result := <-api.Srv.Store.OAuth().GetApp(clientId); result.Err != nil { + if result := <-api.Srv.Store.OAuth().GetApp(c.T, clientId); result.Err != nil { c.Err = result.Err return } else { @@ -909,10 +909,10 @@ func getAccessToken(c *api.Context, w http.ResponseWriter, r *http.Request) { redirectUri := r.FormValue("redirect_uri") - achan := api.Srv.Store.OAuth().GetApp(clientId) - tchan := api.Srv.Store.OAuth().GetAccessDataByAuthCode(code) + achan := api.Srv.Store.OAuth().GetApp(c.T, clientId) + tchan := api.Srv.Store.OAuth().GetAccessDataByAuthCode(c.T, code) - authData := api.GetAuthData(code) + authData := api.GetAuthData(c.T, code) if authData == nil { c.LogAudit("fail - invalid auth code") @@ -967,7 +967,7 @@ func getAccessToken(c *api.Context, w http.ResponseWriter, r *http.Request) { accessData := result.Data.(*model.AccessData) // Revoke access token, related auth code, and session from DB as well as from cache - if err := api.RevokeAccessToken(accessData.Token); err != nil { + if err := api.RevokeAccessToken(c.T, accessData.Token); err != nil { l4g.Error("Encountered an error revoking an access token, err=" + err.Message) } @@ -995,7 +995,7 @@ func getAccessToken(c *api.Context, w http.ResponseWriter, r *http.Request) { accessData := &model.AccessData{AuthCode: authData.Code, Token: session.Token, RedirectUri: callback} - if result := <-api.Srv.Store.OAuth().SaveAccessData(accessData); result.Err != nil { + if result := <-api.Srv.Store.OAuth().SaveAccessData(c.T, accessData); result.Err != nil { l4g.Error(result.Err) c.Err = model.NewAppError("getAccessToken", "server_error: Encountered internal server error while saving access token to database", "") return -- cgit v1.2.3-1-g7c22