summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--api/user.go7
-rw-r--r--store/sql_user_store.go21
-rw-r--r--store/sql_user_store_test.go2
-rw-r--r--store/store.go2
4 files changed, 26 insertions, 6 deletions
diff --git a/api/user.go b/api/user.go
index 91c8c022a..8b2df7143 100644
--- a/api/user.go
+++ b/api/user.go
@@ -2087,13 +2087,16 @@ func switchToSSO(c *Context, w http.ResponseWriter, r *http.Request) {
func CompleteSwitchWithOAuth(c *Context, w http.ResponseWriter, r *http.Request, service string, userData io.ReadCloser, team *model.Team, email string) {
authData := ""
+ ssoEmail := ""
provider := einterfaces.GetOauthProvider(service)
if provider == nil {
c.Err = model.NewLocAppError("CompleteClaimWithOAuth", "api.user.complete_switch_with_oauth.unavailable.app_error",
map[string]interface{}{"Service": service}, "")
return
} else {
- authData = provider.GetAuthDataFromJson(userData)
+ ssoUser := provider.GetUserFromJson(userData)
+ authData = ssoUser.AuthData
+ ssoEmail = ssoUser.Email
}
if len(authData) == 0 {
@@ -2120,7 +2123,7 @@ func CompleteSwitchWithOAuth(c *Context, w http.ResponseWriter, r *http.Request,
return
}
- if result := <-Srv.Store.User().UpdateAuthData(user.Id, service, authData); result.Err != nil {
+ if result := <-Srv.Store.User().UpdateAuthData(user.Id, service, authData, ssoEmail); result.Err != nil {
c.Err = result.Err
return
}
diff --git a/store/sql_user_store.go b/store/sql_user_store.go
index 0b6970c96..b1544289d 100644
--- a/store/sql_user_store.go
+++ b/store/sql_user_store.go
@@ -305,7 +305,7 @@ func (us SqlUserStore) UpdateFailedPasswordAttempts(userId string, attempts int)
return storeChannel
}
-func (us SqlUserStore) UpdateAuthData(userId, service, authData string) StoreChannel {
+func (us SqlUserStore) UpdateAuthData(userId, service, authData, email string) StoreChannel {
storeChannel := make(StoreChannel)
@@ -314,7 +314,24 @@ func (us SqlUserStore) UpdateAuthData(userId, service, authData string) StoreCha
updateAt := model.GetMillis()
- if _, err := us.GetMaster().Exec("UPDATE Users SET Password = '', LastPasswordUpdate = :LastPasswordUpdate, UpdateAt = :UpdateAt, FailedAttempts = 0, AuthService = :AuthService, AuthData = :AuthData WHERE Id = :UserId", map[string]interface{}{"LastPasswordUpdate": updateAt, "UpdateAt": updateAt, "UserId": userId, "AuthService": service, "AuthData": authData}); err != nil {
+ query := `
+ UPDATE
+ Users
+ SET
+ Password = '',
+ LastPasswordUpdate = :LastPasswordUpdate,
+ UpdateAt = :UpdateAt,
+ FailedAttempts = 0,
+ AuthService = :AuthService,
+ AuthData = :AuthData`
+
+ if len(email) != 0 {
+ query += ", Email = :Email"
+ }
+
+ query += " WHERE Id = :UserId"
+
+ if _, err := us.GetMaster().Exec(query, map[string]interface{}{"LastPasswordUpdate": updateAt, "UpdateAt": updateAt, "UserId": userId, "AuthService": service, "AuthData": authData, "Email": email}); err != nil {
result.Err = model.NewLocAppError("SqlUserStore.UpdateAuthData", "store.sql_user.update_auth_data.app_error", nil, "id="+userId+", "+err.Error())
} else {
result.Data = userId
diff --git a/store/sql_user_store_test.go b/store/sql_user_store_test.go
index d1ee5e647..2350bad30 100644
--- a/store/sql_user_store_test.go
+++ b/store/sql_user_store_test.go
@@ -402,7 +402,7 @@ func TestUserStoreUpdateAuthData(t *testing.T) {
service := "someservice"
authData := "1"
- if err := (<-store.User().UpdateAuthData(u1.Id, service, authData)).Err; err != nil {
+ if err := (<-store.User().UpdateAuthData(u1.Id, service, authData, "")).Err; err != nil {
t.Fatal(err)
}
diff --git a/store/store.go b/store/store.go
index cfc679706..2aa627734 100644
--- a/store/store.go
+++ b/store/store.go
@@ -111,7 +111,7 @@ type UserStore interface {
UpdateLastActivityAt(userId string, time int64) StoreChannel
UpdateUserAndSessionActivity(userId string, sessionId string, time int64) StoreChannel
UpdatePassword(userId, newPassword string) StoreChannel
- UpdateAuthData(userId, service, authData string) StoreChannel
+ UpdateAuthData(userId, service, authData, email string) StoreChannel
Get(id string) StoreChannel
GetProfiles(teamId string) StoreChannel
GetByEmail(teamId string, email string) StoreChannel