diff options
Diffstat (limited to 'app')
-rw-r--r-- | app/admin.go | 3 | ||||
-rw-r--r-- | app/app.go | 58 | ||||
-rw-r--r-- | app/app_test.go | 3 | ||||
-rw-r--r-- | app/apptestlib.go | 11 | ||||
-rw-r--r-- | app/channel.go | 30 | ||||
-rw-r--r-- | app/channel_test.go | 4 | ||||
-rw-r--r-- | app/config.go | 109 | ||||
-rw-r--r-- | app/config_test.go | 9 | ||||
-rw-r--r-- | app/diagnostics.go | 3 | ||||
-rw-r--r-- | app/email.go | 5 | ||||
-rw-r--r-- | app/email_batching.go | 18 | ||||
-rw-r--r-- | app/file.go | 42 | ||||
-rw-r--r-- | app/import.go | 6 | ||||
-rw-r--r-- | app/ldap.go | 4 | ||||
-rw-r--r-- | app/license.go | 109 | ||||
-rw-r--r-- | app/license_test.go | 75 | ||||
-rw-r--r-- | app/login.go | 3 | ||||
-rw-r--r-- | app/notification.go | 12 | ||||
-rw-r--r-- | app/notification_test.go | 6 | ||||
-rw-r--r-- | app/oauth.go | 13 | ||||
-rw-r--r-- | app/plugin.go | 1 | ||||
-rw-r--r-- | app/post_test.go | 2 | ||||
-rw-r--r-- | app/role.go | 6 | ||||
-rw-r--r-- | app/server.go | 12 | ||||
-rw-r--r-- | app/server_test.go | 50 | ||||
-rw-r--r-- | app/session_test.go | 22 | ||||
-rw-r--r-- | app/team.go | 2 | ||||
-rw-r--r-- | app/team_test.go | 1 | ||||
-rw-r--r-- | app/user.go | 18 | ||||
-rw-r--r-- | app/web_hub.go | 133 |
30 files changed, 614 insertions, 156 deletions
diff --git a/app/admin.go b/app/admin.go index b838ed3bd..154fa8899 100644 --- a/app/admin.go +++ b/app/admin.go @@ -237,7 +237,8 @@ func (a *App) TestEmail(userId string, cfg *model.Config) *model.AppError { return err } else { T := utils.GetUserTranslations(user.Locale) - if err := utils.SendMailUsingConfig(user.Email, T("api.admin.test_email.subject"), T("api.admin.test_email.body"), cfg); err != nil { + license := a.License() + if err := utils.SendMailUsingConfig(user.Email, T("api.admin.test_email.subject"), T("api.admin.test_email.body"), cfg, license != nil && *license.Features.Compliance); err != nil { return err } } diff --git a/app/app.go b/app/app.go index 1e46d29d0..26aed4c73 100644 --- a/app/app.go +++ b/app/app.go @@ -4,6 +4,7 @@ package app import ( + "crypto/ecdsa" "html/template" "net" "net/http" @@ -58,15 +59,22 @@ type App struct { configFile string configListeners map[string]func(*model.Config, *model.Config) + licenseValue atomic.Value + clientLicenseValue atomic.Value + licenseListeners map[string]func() + + siteURL string + newStore func() store.Store - htmlTemplateWatcher *utils.HTMLTemplateWatcher - sessionCache *utils.Cache - roles map[string]*model.Role - configListenerId string - licenseListenerId string - disableConfigWatch bool - configWatcher *utils.ConfigWatcher + htmlTemplateWatcher *utils.HTMLTemplateWatcher + sessionCache *utils.Cache + roles map[string]*model.Role + configListenerId string + licenseListenerId string + disableConfigWatch bool + configWatcher *utils.ConfigWatcher + asymmetricSigningKey *ecdsa.PrivateKey pluginCommands []*PluginCommand pluginCommandsLock sync.RWMutex @@ -80,7 +88,7 @@ var appCount = 0 // New creates a new App. You must call Shutdown when you're done with it. // XXX: For now, only one at a time is allowed as some resources are still shared. -func New(options ...Option) (*App, error) { +func New(options ...Option) (outApp *App, outErr error) { appCount++ if appCount > 1 { panic("Only one App should exist at a time. Did you forget to call Shutdown()?") @@ -91,11 +99,17 @@ func New(options ...Option) (*App, error) { Srv: &Server{ Router: mux.NewRouter(), }, - sessionCache: utils.NewLru(model.SESSION_CACHE_SIZE), - configFile: "config.json", - configListeners: make(map[string]func(*model.Config, *model.Config)), - clientConfig: make(map[string]string), - } + sessionCache: utils.NewLru(model.SESSION_CACHE_SIZE), + configFile: "config.json", + configListeners: make(map[string]func(*model.Config, *model.Config)), + clientConfig: make(map[string]string), + licenseListeners: map[string]func(){}, + } + defer func() { + if outErr != nil { + app.Shutdown() + } + }() for _, option := range options { option(app) @@ -118,9 +132,9 @@ func New(options ...Option) (*App, error) { app.configListenerId = app.AddConfigListener(func(_, _ *model.Config) { app.configOrLicenseListener() }) - app.licenseListenerId = utils.AddLicenseListener(app.configOrLicenseListener) + app.licenseListenerId = app.AddLicenseListener(app.configOrLicenseListener) app.regenerateClientConfig() - app.SetDefaultRolesBasedOnConfig() + app.setDefaultRolesBasedOnConfig() l4g.Info(utils.T("api.server.new_server.init.info")) @@ -139,6 +153,10 @@ func New(options ...Option) (*App, error) { } app.Srv.Store = app.newStore() + if err := app.ensureAsymmetricSigningKey(); err != nil { + return nil, errors.Wrapf(err, "unable to ensure asymmetric signing key") + } + app.initJobs() app.initBuiltInPlugins() @@ -157,7 +175,7 @@ func New(options ...Option) (*App, error) { func (a *App) configOrLicenseListener() { a.regenerateClientConfig() - a.SetDefaultRolesBasedOnConfig() + a.setDefaultRolesBasedOnConfig() } func (a *App) Shutdown() { @@ -171,7 +189,9 @@ func (a *App) Shutdown() { a.ShutDownPlugins() a.WaitForGoroutines() - a.Srv.Store.Close() + if a.Srv.Store != nil { + a.Srv.Store.Close() + } a.Srv = nil if a.htmlTemplateWatcher != nil { @@ -179,7 +199,7 @@ func (a *App) Shutdown() { } a.RemoveConfigListener(a.configListenerId) - utils.RemoveLicenseListener(a.licenseListenerId) + a.RemoveLicenseListener(a.licenseListenerId) l4g.Info(utils.T("api.server.stop_server.stopped.info")) a.DisableConfigWatch() @@ -448,5 +468,5 @@ func (a *App) Handle404(w http.ResponseWriter, r *http.Request) { l4g.Debug("%v: code=404 ip=%v", r.URL.Path, utils.GetIpAddress(r)) - utils.RenderWebError(err, w, r) + utils.RenderWebAppError(w, r, err, a.AsymmetricSigningKey()) } diff --git a/app/app_test.go b/app/app_test.go index 25b19ead8..09f8725d7 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -51,7 +51,8 @@ func TestAppRace(t *testing.T) { a, err := New() require.NoError(t, err) a.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" }) - a.StartServer() + serverErr := a.StartServer() + require.NoError(t, serverErr) a.Shutdown() } } diff --git a/app/apptestlib.go b/app/apptestlib.go index 09afc8f76..c7846c9b5 100644 --- a/app/apptestlib.go +++ b/app/apptestlib.go @@ -96,15 +96,20 @@ func setupTestHelper(enterprise bool) *TestHelper { if testStore != nil { th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" }) } - th.App.StartServer() + serverErr := th.App.StartServer() + if serverErr != nil { + panic(serverErr) + } + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = prevListenAddress }) th.App.Srv.Store.MarkSystemRanUnitTests() th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableOpenServer = true }) - utils.SetIsLicensed(enterprise) if enterprise { - utils.License().Features.SetDefaults() + th.App.SetLicense(model.NewTestLicense()) + } else { + th.App.SetLicense(nil) } return th diff --git a/app/channel.go b/app/channel.go index e4bf48654..8ac1f421c 100644 --- a/app/channel.go +++ b/app/channel.go @@ -1359,7 +1359,7 @@ func (a *App) PermanentDeleteChannel(channel *model.Channel) *model.AppError { // This function is intended for use from the CLI. It is not robust against people joining the channel while the move // is in progress, and therefore should not be used from the API without first fixing this potential race condition. -func (a *App) MoveChannel(team *model.Team, channel *model.Channel) *model.AppError { +func (a *App) MoveChannel(team *model.Team, channel *model.Channel, user *model.User) *model.AppError { // Check that all channel members are in the destination team. if channelMembers, err := a.GetChannelMembersPage(channel.Id, 0, 10000000); err != nil { return err @@ -1378,11 +1378,37 @@ func (a *App) MoveChannel(team *model.Team, channel *model.Channel) *model.AppEr } } - // Change the Team ID of the channel. + // keep instance of the previous team + var previousTeam *model.Team + if result := <-a.Srv.Store.Team().Get(channel.TeamId); result.Err != nil { + return result.Err + } else { + previousTeam = result.Data.(*model.Team) + } channel.TeamId = team.Id if result := <-a.Srv.Store.Channel().Update(channel); result.Err != nil { return result.Err } + a.postChannelMoveMessage(user, channel, previousTeam) + + return nil +} + +func (a *App) postChannelMoveMessage(user *model.User, channel *model.Channel, previousTeam *model.Team) *model.AppError { + + post := &model.Post{ + ChannelId: channel.Id, + Message: fmt.Sprintf(utils.T("api.team.move_channel.success"), previousTeam.Name), + Type: model.POST_MOVE_CHANNEL, + UserId: user.Id, + Props: model.StringInterface{ + "username": user.Username, + }, + } + + if _, err := a.CreatePost(post, channel, false); err != nil { + return model.NewAppError("postChannelMoveMessage", "api.team.move_channel.post.error", nil, err.Error(), http.StatusInternalServerError) + } return nil } diff --git a/app/channel_test.go b/app/channel_test.go index d83590a27..e4a0e4320 100644 --- a/app/channel_test.go +++ b/app/channel_test.go @@ -97,7 +97,7 @@ func TestMoveChannel(t *testing.T) { t.Fatal(err) } - if err := th.App.MoveChannel(targetTeam, channel1); err == nil { + if err := th.App.MoveChannel(targetTeam, channel1, th.BasicUser); err == nil { t.Fatal("Should have failed due to mismatched members.") } @@ -105,7 +105,7 @@ func TestMoveChannel(t *testing.T) { t.Fatal(err) } - if err := th.App.MoveChannel(targetTeam, channel1); err != nil { + if err := th.App.MoveChannel(targetTeam, channel1, th.BasicUser); err != nil { t.Fatal(err) } } diff --git a/app/config.go b/app/config.go index a2398f9e9..35a0c9a3f 100644 --- a/app/config.go +++ b/app/config.go @@ -4,10 +4,17 @@ package app import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/md5" + "crypto/rand" + "crypto/x509" + "encoding/base64" "encoding/json" "fmt" + "net/url" "runtime/debug" + "strings" l4g "github.com/alecthomas/log4go" @@ -48,7 +55,7 @@ func (a *App) LoadConfig(configFile string) *model.AppError { a.config.Store(cfg) - utils.SetSiteURL(*cfg.ServiceSettings.SiteURL) + a.siteURL = strings.TrimRight(*cfg.ServiceSettings.SiteURL, "/") a.InvokeConfigListeners(old, cfg) return nil @@ -116,8 +123,91 @@ func (a *App) InvokeConfigListeners(old, current *model.Config) { } } +// EnsureAsymmetricSigningKey ensures that an asymmetric signing key exists and future calls to +// AsymmetricSigningKey will always return a valid signing key. +func (a *App) ensureAsymmetricSigningKey() error { + if a.asymmetricSigningKey != nil { + return nil + } + + var key *model.SystemAsymmetricSigningKey + + result := <-a.Srv.Store.System().GetByName(model.SYSTEM_ASYMMETRIC_SIGNING_KEY) + if result.Err == nil { + if err := json.Unmarshal([]byte(result.Data.(*model.System).Value), &key); err != nil { + return err + } + } + + // If we don't already have a key, try to generate one. + if key == nil { + newECDSAKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + newKey := &model.SystemAsymmetricSigningKey{ + ECDSAKey: &model.SystemECDSAKey{ + Curve: "P-256", + X: newECDSAKey.X, + Y: newECDSAKey.Y, + D: newECDSAKey.D, + }, + } + system := &model.System{ + Name: model.SYSTEM_ASYMMETRIC_SIGNING_KEY, + } + v, err := json.Marshal(newKey) + if err != nil { + return err + } + system.Value = string(v) + if result = <-a.Srv.Store.System().Save(system); result.Err == nil { + // If we were able to save the key, use it, otherwise ignore the error. + key = newKey + } + } + + // If we weren't able to save a new key above, another server must have beat us to it. Get the + // key from the database, and if that fails, error out. + if key == nil { + result := <-a.Srv.Store.System().GetByName(model.SYSTEM_ASYMMETRIC_SIGNING_KEY) + if result.Err != nil { + return result.Err + } else if err := json.Unmarshal([]byte(result.Data.(*model.System).Value), &key); err != nil { + return err + } + } + + var curve elliptic.Curve + switch key.ECDSAKey.Curve { + case "P-256": + curve = elliptic.P256() + default: + return fmt.Errorf("unknown curve: " + key.ECDSAKey.Curve) + } + a.asymmetricSigningKey = &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: key.ECDSAKey.X, + Y: key.ECDSAKey.Y, + }, + D: key.ECDSAKey.D, + } + a.regenerateClientConfig() + return nil +} + +// AsymmetricSigningKey will return a private key that can be used for asymmetric signing. +func (a *App) AsymmetricSigningKey() *ecdsa.PrivateKey { + return a.asymmetricSigningKey +} + func (a *App) regenerateClientConfig() { - a.clientConfig = utils.GenerateClientConfig(a.Config(), a.DiagnosticId()) + a.clientConfig = utils.GenerateClientConfig(a.Config(), a.DiagnosticId(), a.License()) + if key := a.AsymmetricSigningKey(); key != nil { + der, _ := x509.MarshalPKIXPublicKey(&key.PublicKey) + a.clientConfig["AsymmetricSigningPublicKey"] = base64.StdEncoding.EncodeToString(der) + } clientConfigJSON, _ := json.Marshal(a.clientConfig) a.clientConfigHash = fmt.Sprintf("%x", md5.Sum(clientConfigJSON)) } @@ -167,10 +257,15 @@ func (a *App) Desanitize(cfg *model.Config) { } } -// License returns the currently active license or nil if the application is unlicensed. -func (a *App) License() *model.License { - if utils.IsLicensed() { - return utils.License() +func (a *App) GetCookieDomain() string { + if *a.Config().ServiceSettings.AllowCookiesForSubdomains { + if siteURL, err := url.Parse(*a.Config().ServiceSettings.SiteURL); err == nil { + return siteURL.Hostname() + } } - return nil + return "" +} + +func (a *App) GetSiteURL() string { + return a.siteURL } diff --git a/app/config_test.go b/app/config_test.go index e3d50b958..5ee999f0f 100644 --- a/app/config_test.go +++ b/app/config_test.go @@ -6,6 +6,8 @@ package app import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/mattermost/mattermost-server/model" ) @@ -54,3 +56,10 @@ func TestConfigListener(t *testing.T) { t.Fatal("listener 2 should've been called") } } + +func TestAsymmetricSigningKey(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + assert.NotNil(t, th.App.AsymmetricSigningKey()) + assert.NotEmpty(t, th.App.ClientConfig()["AsymmetricSigningPublicKey"]) +} diff --git a/app/diagnostics.go b/app/diagnostics.go index 809d9ff1e..12553afc8 100644 --- a/app/diagnostics.go +++ b/app/diagnostics.go @@ -243,6 +243,8 @@ func (a *App) trackConfig() { "isdefault_image_proxy_type": isDefault(*cfg.ServiceSettings.ImageProxyType, ""), "isdefault_image_proxy_url": isDefault(*cfg.ServiceSettings.ImageProxyURL, ""), "isdefault_image_proxy_options": isDefault(*cfg.ServiceSettings.ImageProxyOptions, ""), + "websocket_url": isDefault(*cfg.ServiceSettings.WebsocketURL, ""), + "allow_cookies_for_subdomains": *cfg.ServiceSettings.AllowCookiesForSubdomains, }) a.SendDiagnostic(TRACK_CONFIG_TEAM, map[string]interface{}{ @@ -501,6 +503,7 @@ func (a *App) trackConfig() { a.SendDiagnostic(TRACK_CONFIG_MESSAGE_EXPORT, map[string]interface{}{ "enable_message_export": *cfg.MessageExportSettings.EnableExport, + "export_format": *cfg.MessageExportSettings.ExportFormat, "daily_run_time": *cfg.MessageExportSettings.DailyRunTime, "default_export_from_timestamp": *cfg.MessageExportSettings.ExportFromTimestamp, "batch_size": *cfg.MessageExportSettings.BatchSize, diff --git a/app/email.go b/app/email.go index 764dc017a..8ee3e79e2 100644 --- a/app/email.go +++ b/app/email.go @@ -191,7 +191,7 @@ func (a *App) SendUserAccessTokenAddedEmail(email, locale string) *model.AppErro bodyPage := a.NewEmailTemplate("password_change_body", locale) bodyPage.Props["Title"] = T("api.templates.user_access_token_body.title") bodyPage.Html["Info"] = utils.TranslateAsHtml(T, "api.templates.user_access_token_body.info", - map[string]interface{}{"SiteName": a.ClientConfig()["SiteName"], "SiteURL": utils.GetSiteURL()}) + map[string]interface{}{"SiteName": a.ClientConfig()["SiteName"], "SiteURL": a.GetSiteURL()}) if err := a.SendMail(email, subject, bodyPage.Render()); err != nil { return model.NewAppError("SendUserAccessTokenAddedEmail", "api.user.send_user_access_token.error", nil, err.Error(), http.StatusInternalServerError) @@ -317,5 +317,6 @@ func (a *App) NewEmailTemplate(name, locale string) *utils.HTMLTemplate { } func (a *App) SendMail(to, subject, htmlBody string) *model.AppError { - return utils.SendMailUsingConfig(to, subject, htmlBody, a.Config()) + license := a.License() + return utils.SendMailUsingConfig(to, subject, htmlBody, a.Config(), license != nil && *license.Features.Compliance) } diff --git a/app/email_batching.go b/app/email_batching.go index 2a33d7d3e..07adda674 100644 --- a/app/email_batching.go +++ b/app/email_batching.go @@ -7,6 +7,7 @@ import ( "fmt" "html/template" "strconv" + "sync" "time" "github.com/mattermost/mattermost-server/model" @@ -57,6 +58,8 @@ type EmailBatchingJob struct { app *App newNotifications chan *batchedNotification pendingNotifications map[string][]*batchedNotification + task *model.ScheduledTask + taskMutex sync.Mutex } func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob { @@ -68,12 +71,17 @@ func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob { } func (job *EmailBatchingJob) Start() { - if task := model.GetTaskByName(EMAIL_BATCHING_TASK_NAME); task != nil { - task.Cancel() - } - l4g.Debug(utils.T("api.email_batching.start.starting"), *job.app.Config().EmailSettings.EmailBatchingInterval) - model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second) + newTask := model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second) + + job.taskMutex.Lock() + oldTask := job.task + job.task = newTask + job.taskMutex.Unlock() + + if oldTask != nil { + oldTask.Cancel() + } } func (job *EmailBatchingJob) Add(user *model.User, post *model.Post, team *model.Team) bool { diff --git a/app/file.go b/app/file.go index d66c64adb..06ee61c92 100644 --- a/app/file.go +++ b/app/file.go @@ -58,7 +58,8 @@ const ( ) func (a *App) FileBackend() (utils.FileBackend, *model.AppError) { - return utils.NewFileBackend(&a.Config().FileSettings) + license := a.License() + return utils.NewFileBackend(&a.Config().FileSettings, license != nil && *license.Features.Compliance) } func (a *App) ReadFile(path string) ([]byte, *model.AppError) { @@ -279,11 +280,38 @@ func GeneratePublicLinkHash(fileId, salt string) string { return base64.RawURLEncoding.EncodeToString(hash.Sum(nil)) } -func (a *App) UploadFiles(teamId string, channelId string, userId string, fileHeaders []*multipart.FileHeader, clientIds []string) (*model.FileUploadResponse, *model.AppError) { +func (a *App) UploadMultipartFiles(teamId string, channelId string, userId string, fileHeaders []*multipart.FileHeader, clientIds []string) (*model.FileUploadResponse, *model.AppError) { + files := make([]io.ReadCloser, len(fileHeaders)) + filenames := make([]string, len(fileHeaders)) + + for i, fileHeader := range fileHeaders { + file, fileErr := fileHeader.Open() + if fileErr != nil { + return nil, model.NewAppError("UploadFiles", "api.file.upload_file.bad_parse.app_error", nil, fileErr.Error(), http.StatusBadRequest) + } + + // Will be closed after UploadFiles returns + defer file.Close() + + files[i] = file + filenames[i] = fileHeader.Filename + } + + return a.UploadFiles(teamId, channelId, userId, files, filenames, clientIds) +} + +// Uploads some files to the given team and channel as the given user. files and filenames should have +// the same length. clientIds should either not be provided or have the same length as files and filenames. +// The provided files should be closed by the caller so that they are not leaked. +func (a *App) UploadFiles(teamId string, channelId string, userId string, files []io.ReadCloser, filenames []string, clientIds []string) (*model.FileUploadResponse, *model.AppError) { if len(*a.Config().FileSettings.DriverName) == 0 { return nil, model.NewAppError("uploadFile", "api.file.upload_file.storage.app_error", nil, "", http.StatusNotImplemented) } + if len(filenames) != len(files) || (len(clientIds) > 0 && len(clientIds) != len(files)) { + return nil, model.NewAppError("UploadFiles", "api.file.upload_file.incorrect_number_of_files.app_error", nil, "", http.StatusBadRequest) + } + resStruct := &model.FileUploadResponse{ FileInfos: []*model.FileInfo{}, ClientIds: []string{}, @@ -293,18 +321,12 @@ func (a *App) UploadFiles(teamId string, channelId string, userId string, fileHe thumbnailPathList := []string{} imageDataList := [][]byte{} - for i, fileHeader := range fileHeaders { - file, fileErr := fileHeader.Open() - if fileErr != nil { - return nil, model.NewAppError("UploadFiles", "api.file.upload_file.bad_parse.app_error", nil, fileErr.Error(), http.StatusBadRequest) - } - defer file.Close() - + for i, file := range files { buf := bytes.NewBuffer(nil) io.Copy(buf, file) data := buf.Bytes() - info, err := a.DoUploadFile(time.Now(), teamId, channelId, userId, fileHeader.Filename, data) + info, err := a.DoUploadFile(time.Now(), teamId, channelId, userId, filenames[i], data) if err != nil { return nil, err } diff --git a/app/import.go b/app/import.go index 6291794b0..5a3158fab 100644 --- a/app/import.go +++ b/app/import.go @@ -817,6 +817,12 @@ func (a *App) ImportUserTeams(user *model.User, data *[]UserTeamImportData) *mod } } + if defaultChannel, err := a.GetChannelByName(model.DEFAULT_CHANNEL, team.Id); err != nil { + return err + } else if _, err = a.addUserToChannel(user, defaultChannel, member); err != nil { + return err + } + if err := a.ImportUserChannels(user, team, member, tdata.Channels); err != nil { return err } diff --git a/app/ldap.go b/app/ldap.go index 179529c52..ff7a5ed21 100644 --- a/app/ldap.go +++ b/app/ldap.go @@ -67,7 +67,7 @@ func (a *App) SwitchEmailToLdap(email, password, code, ldapId, ldapPassword stri } a.Go(func() { - if err := a.SendSignInChangeEmail(user.Email, "AD/LDAP", user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendSignInChangeEmail(user.Email, "AD/LDAP", user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -113,7 +113,7 @@ func (a *App) SwitchLdapToEmail(ldapPassword, code, email, newPassword string) ( T := utils.GetUserTranslations(user.Locale) a.Go(func() { - if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) diff --git a/app/license.go b/app/license.go index c7fd07197..c12f23d1d 100644 --- a/app/license.go +++ b/app/license.go @@ -4,16 +4,19 @@ package app import ( + "crypto/md5" + "fmt" "net/http" "strings" l4g "github.com/alecthomas/log4go" + "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/utils" ) func (a *App) LoadLicense() { - utils.RemoveLicense() + a.SetLicense(nil) licenseId := "" if result := <-a.Srv.Store.System().Get(); result.Err == nil { @@ -36,7 +39,7 @@ func (a *App) LoadLicense() { if result := <-a.Srv.Store.License().Get(licenseId); result.Err == nil { record := result.Data.(*model.LicenseRecord) - utils.LoadLicense([]byte(record.Bytes)) + a.ValidateAndSetLicenseBytes([]byte(record.Bytes)) l4g.Info("License key valid unlocking enterprise features.") } else { l4g.Info(utils.T("mattermost.load_license.find.warn")) @@ -59,7 +62,7 @@ func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) } } - if ok := utils.SetLicense(license); !ok { + if ok := a.SetLicense(license); !ok { return nil, model.NewAppError("addLicense", model.EXPIRED_LICENSE_ERROR, nil, "", http.StatusBadRequest) } @@ -102,21 +105,117 @@ func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) return license, nil } +// License returns the currently active license or nil if the application is unlicensed. +func (a *App) License() *model.License { + license, _ := a.licenseValue.Load().(*model.License) + return license +} + +func (a *App) SetLicense(license *model.License) bool { + defer func() { + a.setDefaultRolesBasedOnConfig() + for _, listener := range a.licenseListeners { + listener() + } + }() + + if license != nil { + license.Features.SetDefaults() + + if !license.IsExpired() { + a.licenseValue.Store(license) + a.clientLicenseValue.Store(utils.GetClientLicense(license)) + return true + } + } + + a.licenseValue.Store((*model.License)(nil)) + a.clientLicenseValue.Store(map[string]string(nil)) + return false +} + +func (a *App) ValidateAndSetLicenseBytes(b []byte) { + if success, licenseStr := utils.ValidateLicense(b); success { + license := model.LicenseFromJson(strings.NewReader(licenseStr)) + a.SetLicense(license) + return + } + + l4g.Warn(utils.T("utils.license.load_license.invalid.warn")) +} + +func (a *App) SetClientLicense(m map[string]string) { + a.clientLicenseValue.Store(m) +} + +func (a *App) ClientLicense() map[string]string { + if clientLicense, _ := a.clientLicenseValue.Load().(map[string]string); clientLicense != nil { + return clientLicense + } + return map[string]string{"IsLicensed": "false"} +} + func (a *App) RemoveLicense() *model.AppError { - utils.RemoveLicense() + if license, _ := a.licenseValue.Load().(*model.License); license == nil { + return nil + } sysVar := &model.System{} sysVar.Name = model.SYSTEM_ACTIVE_LICENSE_ID sysVar.Value = "" if result := <-a.Srv.Store.System().SaveOrUpdate(sysVar); result.Err != nil { - utils.RemoveLicense() return result.Err } + a.SetLicense(nil) a.ReloadConfig() a.InvalidateAllCaches() return nil } + +func (a *App) AddLicenseListener(listener func()) string { + id := model.NewId() + a.licenseListeners[id] = listener + return id +} + +func (a *App) RemoveLicenseListener(id string) { + delete(a.licenseListeners, id) +} + +func (a *App) GetClientLicenseEtag(useSanitized bool) string { + value := "" + + lic := a.ClientLicense() + + if useSanitized { + lic = a.GetSanitizedClientLicense() + } + + for k, v := range lic { + value += fmt.Sprintf("%s:%s;", k, v) + } + + return model.Etag(fmt.Sprintf("%x", md5.Sum([]byte(value)))) +} + +func (a *App) GetSanitizedClientLicense() map[string]string { + sanitizedLicense := make(map[string]string) + + for k, v := range a.ClientLicense() { + sanitizedLicense[k] = v + } + + delete(sanitizedLicense, "Id") + delete(sanitizedLicense, "Name") + delete(sanitizedLicense, "Email") + delete(sanitizedLicense, "PhoneNumber") + delete(sanitizedLicense, "IssuedAt") + delete(sanitizedLicense, "StartsAt") + delete(sanitizedLicense, "ExpiresAt") + + return sanitizedLicense +} diff --git a/app/license_test.go b/app/license_test.go index 5b73d9d18..f86d604d1 100644 --- a/app/license_test.go +++ b/app/license_test.go @@ -4,8 +4,9 @@ package app import ( - //"github.com/mattermost/mattermost-server/model" "testing" + + "github.com/mattermost/mattermost-server/model" ) func TestLoadLicense(t *testing.T) { @@ -37,3 +38,75 @@ func TestRemoveLicense(t *testing.T) { t.Fatal("should have removed license") } } + +func TestSetLicense(t *testing.T) { + th := Setup() + defer th.TearDown() + + l1 := &model.License{} + l1.Features = &model.Features{} + l1.Customer = &model.Customer{} + l1.StartsAt = model.GetMillis() - 1000 + l1.ExpiresAt = model.GetMillis() + 100000 + if ok := th.App.SetLicense(l1); !ok { + t.Fatal("license should have worked") + } + + l2 := &model.License{} + l2.Features = &model.Features{} + l2.Customer = &model.Customer{} + l2.StartsAt = model.GetMillis() - 1000 + l2.ExpiresAt = model.GetMillis() - 100 + if ok := th.App.SetLicense(l2); ok { + t.Fatal("license should have failed") + } + + l3 := &model.License{} + l3.Features = &model.Features{} + l3.Customer = &model.Customer{} + l3.StartsAt = model.GetMillis() + 10000 + l3.ExpiresAt = model.GetMillis() + 100000 + if ok := th.App.SetLicense(l3); !ok { + t.Fatal("license should have passed") + } +} + +func TestClientLicenseEtag(t *testing.T) { + th := Setup() + defer th.TearDown() + + etag1 := th.App.GetClientLicenseEtag(false) + + th.App.SetClientLicense(map[string]string{"SomeFeature": "true", "IsLicensed": "true"}) + + etag2 := th.App.GetClientLicenseEtag(false) + if etag1 == etag2 { + t.Fatal("etags should not match") + } + + th.App.SetClientLicense(map[string]string{"SomeFeature": "true", "IsLicensed": "false"}) + + etag3 := th.App.GetClientLicenseEtag(false) + if etag2 == etag3 { + t.Fatal("etags should not match") + } +} + +func TestGetSanitizedClientLicense(t *testing.T) { + th := Setup() + defer th.TearDown() + + l1 := &model.License{} + l1.Features = &model.Features{} + l1.Customer = &model.Customer{} + l1.Customer.Name = "TestName" + l1.StartsAt = model.GetMillis() - 1000 + l1.ExpiresAt = model.GetMillis() + 100000 + th.App.SetLicense(l1) + + m := th.App.GetSanitizedClientLicense() + + if _, ok := m["Name"]; ok { + t.Fatal("should have been sanatized") + } +} diff --git a/app/login.go b/app/login.go index ecc0f0163..e01566bcd 100644 --- a/app/login.go +++ b/app/login.go @@ -113,6 +113,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User, secure = true } + domain := a.GetCookieDomain() expiresAt := time.Unix(model.GetMillis()/1000+int64(maxAge), 0) sessionCookie := &http.Cookie{ Name: model.SESSION_COOKIE_TOKEN, @@ -121,6 +122,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User, MaxAge: maxAge, Expires: expiresAt, HttpOnly: true, + Domain: domain, Secure: secure, } @@ -130,6 +132,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User, Path: "/", MaxAge: maxAge, Expires: expiresAt, + Domain: domain, Secure: secure, } diff --git a/app/notification.go b/app/notification.go index 24e84500b..8cb63fbaf 100644 --- a/app/notification.go +++ b/app/notification.go @@ -362,7 +362,7 @@ func (a *App) sendNotificationEmail(post *model.Post, user *model.User, channel emailNotificationContentsType = *a.Config().EmailSettings.EmailNotificationContentsType } - teamURL := utils.GetSiteURL() + "/" + team.Name + teamURL := a.GetSiteURL() + "/" + team.Name var bodyText = a.getNotificationEmailBody(user, post, channel, senderName, team.Name, teamURL, emailNotificationContentsType, translateFunc) a.Go(func() { @@ -421,7 +421,7 @@ func (a *App) getNotificationEmailBody(recipient *model.User, post *model.Post, bodyPage = a.NewEmailTemplate("post_body_generic", recipient.Locale) } - bodyPage.Props["SiteURL"] = utils.GetSiteURL() + bodyPage.Props["SiteURL"] = a.GetSiteURL() if teamName != "select_team" { bodyPage.Props["TeamLink"] = teamURL + "/pl/" + post.Id } else { @@ -623,7 +623,9 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool, message := "" category := "" - if *a.Config().EmailSettings.PushNotificationContents == model.FULL_NOTIFICATION { + contentsConfig := *a.Config().EmailSettings.PushNotificationContents + + if contentsConfig == model.FULL_NOTIFICATION { category = model.CATEGORY_CAN_REPLY if channelType == model.CHANNEL_DIRECT { @@ -631,7 +633,7 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool, } else { message = senderName + userLocale("api.post.send_notifications_and_forget.push_in") + channelName + ": " + model.ClearMentionTags(postMessage) } - } else if *a.Config().EmailSettings.PushNotificationContents == model.GENERIC_NO_CHANNEL_NOTIFICATION { + } else if contentsConfig == model.GENERIC_NO_CHANNEL_NOTIFICATION { if channelType == model.CHANNEL_DIRECT { category = model.CATEGORY_CAN_REPLY @@ -659,6 +661,8 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool, if len(postMessage) == 0 && hasFiles { if channelType == model.CHANNEL_DIRECT { message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only_dm") + } else if contentsConfig == model.GENERIC_NO_CHANNEL_NOTIFICATION { + message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only_no_channel") } else { message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only") + channelName } diff --git a/app/notification_test.go b/app/notification_test.go index 43703c019..5fc1d152c 100644 --- a/app/notification_test.go +++ b/app/notification_test.go @@ -1414,6 +1414,12 @@ func TestGetPushNotificationMessage(t *testing.T) { ExpectedMessage: "user uploaded one or more files in a direct message", ExpectedCategory: model.CATEGORY_CAN_REPLY, }, + "only files without channel, public channel": { + HasFiles: true, + PushNotificationContents: model.GENERIC_NO_CHANNEL_NOTIFICATION, + ChannelType: model.CHANNEL_OPEN, + ExpectedMessage: "user uploaded one or more files", + }, } { t.Run(name, func(t *testing.T) { locale := tc.Locale diff --git a/app/oauth.go b/app/oauth.go index 5a66f542e..630fd3e2d 100644 --- a/app/oauth.go +++ b/app/oauth.go @@ -527,7 +527,7 @@ func (a *App) CompleteSwitchWithOAuth(service string, userData io.ReadCloser, em } a.Go(func() { - if err := a.SendSignInChangeEmail(user.Email, strings.Title(service)+" SSO", user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendSignInChangeEmail(user.Email, strings.Title(service)+" SSO", user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -600,7 +600,12 @@ func (a *App) GetAuthorizationCode(w http.ResponseWriter, r *http.Request, servi props["token"] = stateToken.Token state := b64.StdEncoding.EncodeToString([]byte(model.MapToJson(props))) - redirectUri := utils.GetSiteURL() + "/signup/" + service + "/complete" + siteUrl := a.GetSiteURL() + if strings.TrimSpace(siteUrl) == "" { + siteUrl = GetProtocol(r) + "://" + r.Host + } + + redirectUri := siteUrl + "/signup/" + service + "/complete" authUrl := endpoint + "?response_type=code&client_id=" + clientId + "&redirect_uri=" + url.QueryEscape(redirectUri) + "&state=" + url.QueryEscape(state) @@ -736,7 +741,7 @@ func (a *App) SwitchEmailToOAuth(w http.ResponseWriter, r *http.Request, email, stateProps["email"] = email if service == model.USER_AUTH_SERVICE_SAML { - return utils.GetSiteURL() + "/login/sso/saml?action=" + model.OAUTH_ACTION_EMAIL_TO_SSO + "&email=" + utils.UrlEncode(email), nil + return a.GetSiteURL() + "/login/sso/saml?action=" + model.OAUTH_ACTION_EMAIL_TO_SSO + "&email=" + utils.UrlEncode(email), nil } else { if authUrl, err := a.GetAuthorizationCode(w, r, service, stateProps, ""); err != nil { return "", err @@ -768,7 +773,7 @@ func (a *App) SwitchOAuthToEmail(email, password, requesterId string) (string, * T := utils.GetUserTranslations(user.Locale) a.Go(func() { - if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) diff --git a/app/plugin.go b/app/plugin.go index 3f06a000f..fe671d26a 100644 --- a/app/plugin.go +++ b/app/plugin.go @@ -565,6 +565,7 @@ func (a *App) RegisterPluginCommand(pluginId string, command *model.Command) err TeamId: command.TeamId, AutoComplete: command.AutoComplete, AutoCompleteDesc: command.AutoCompleteDesc, + AutoCompleteHint: command.AutoCompleteHint, DisplayName: command.DisplayName, } diff --git a/app/post_test.go b/app/post_test.go index ebe973270..2472e40c6 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -345,4 +345,4 @@ func TestMakeOpenGraphURLsAbsolute(t *testing.T) { } }) } -}
\ No newline at end of file +} diff --git a/app/role.go b/app/role.go index 5f39dd623..9f271ea7a 100644 --- a/app/role.go +++ b/app/role.go @@ -12,8 +12,6 @@ func (a *App) Role(id string) *model.Role { return a.roles[id] } -// Updates the roles based on the app config and the global license check. You may need to invoke -// this when license changes are made. -func (a *App) SetDefaultRolesBasedOnConfig() { - a.roles = utils.DefaultRolesBasedOnConfig(a.Config()) +func (a *App) setDefaultRolesBasedOnConfig() { + a.roles = utils.DefaultRolesBasedOnConfig(a.Config(), a.License() != nil) } diff --git a/app/server.go b/app/server.go index 1659908b6..afa282ad6 100644 --- a/app/server.go +++ b/app/server.go @@ -17,6 +17,7 @@ import ( l4g "github.com/alecthomas/log4go" "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/pkg/errors" "golang.org/x/crypto/acme/autocert" "github.com/mattermost/mattermost-server/model" @@ -116,7 +117,7 @@ func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, url.String(), http.StatusFound) } -func (a *App) StartServer() { +func (a *App) StartServer() error { l4g.Info(utils.T("api.server.start_server.starting.info")) var handler http.Handler = &CorsWrapper{a.Config, a.Srv.Router} @@ -126,8 +127,7 @@ func (a *App) StartServer() { rateLimiter, err := NewRateLimiter(&a.Config().RateLimitSettings) if err != nil { - l4g.Critical(err.Error()) - return + return err } a.Srv.RateLimiter = rateLimiter @@ -151,8 +151,8 @@ func (a *App) StartServer() { listener, err := net.Listen("tcp", addr) if err != nil { - l4g.Critical(utils.T("api.server.start_server.starting.critical"), err) - return + errors.Wrapf(err, utils.T("api.server.start_server.starting.critical"), err) + return err } a.Srv.ListenAddr = listener.Addr().(*net.TCPAddr) @@ -214,6 +214,8 @@ func (a *App) StartServer() { } close(a.Srv.didFinishListen) }() + + return nil } type tcpKeepAliveListener struct { diff --git a/app/server_test.go b/app/server_test.go new file mode 100644 index 000000000..de358b976 --- /dev/null +++ b/app/server_test.go @@ -0,0 +1,50 @@ +// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "testing" + + "github.com/mattermost/mattermost-server/model" + "github.com/stretchr/testify/require" +) + +func TestStartServerSuccess(t *testing.T) { + a, err := New() + require.NoError(t, err) + + a.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" }) + serverErr := a.StartServer() + a.Shutdown() + require.NoError(t, serverErr) +} + +func TestStartServerRateLimiterCriticalError(t *testing.T) { + a, err := New() + require.NoError(t, err) + + // Attempt to use Rate Limiter with an invalid config + a.UpdateConfig(func(cfg *model.Config) { + *cfg.RateLimitSettings.Enable = true + *cfg.RateLimitSettings.MaxBurst = -100 + }) + + serverErr := a.StartServer() + a.Shutdown() + require.Error(t, serverErr) +} + +func TestStartServerPortUnavailable(t *testing.T) { + a, err := New() + require.NoError(t, err) + + // Attempt to listen on a system-reserved port + a.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.ListenAddress = ":21" + }) + + serverErr := a.StartServer() + a.Shutdown() + require.Error(t, serverErr) +} diff --git a/app/session_test.go b/app/session_test.go index bca3b59b7..bf8198a4e 100644 --- a/app/session_test.go +++ b/app/session_test.go @@ -6,11 +6,10 @@ package app import ( "testing" - "github.com/mattermost/mattermost-server/model" - "github.com/mattermost/mattermost-server/utils" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost-server/model" ) func TestCache(t *testing.T) { @@ -48,18 +47,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) { session, _ = th.App.CreateSession(session) - isLicensed := utils.IsLicensed() - license := utils.License() - timeout := *th.App.Config().ServiceSettings.SessionIdleTimeoutInMinutes - defer func() { - utils.SetIsLicensed(isLicensed) - utils.SetLicense(license) - th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = timeout }) - }() - utils.SetIsLicensed(true) - utils.SetLicense(&model.License{Features: &model.Features{}}) - utils.License().Features.SetDefaults() - *utils.License().Features.Compliance = true + th.App.SetLicense(model.NewTestLicense("compliance")) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = 5 }) rsession, err := th.App.GetSession(session.Token) @@ -122,7 +110,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) { assert.Nil(t, err) // Test regular session with license off, should not timeout - *utils.License().Features.Compliance = false + th.App.SetLicense(nil) session = &model.Session{ UserId: model.NewId(), @@ -136,7 +124,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) { _, err = th.App.GetSession(session.Token) assert.Nil(t, err) - *utils.License().Features.Compliance = true + th.App.SetLicense(model.NewTestLicense("compliance")) // Test regular session with timeout set to 0, should not timeout th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = 0 }) diff --git a/app/team.go b/app/team.go index a15c64c3f..d8750bfbb 100644 --- a/app/team.go +++ b/app/team.go @@ -741,7 +741,7 @@ func (a *App) InviteNewUsersToTeam(emailList []string, teamId, senderId string) } nameFormat := *a.Config().TeamSettings.TeammateNameDisplay - a.SendInviteEmails(team, user.GetDisplayName(nameFormat), emailList, utils.GetSiteURL()) + a.SendInviteEmails(team, user.GetDisplayName(nameFormat), emailList, a.GetSiteURL()) return nil } diff --git a/app/team_test.go b/app/team_test.go index a2bf44a57..cdfec12da 100644 --- a/app/team_test.go +++ b/app/team_test.go @@ -481,7 +481,6 @@ func TestJoinUserToTeam(t *testing.T) { maxUsersPerTeam := th.App.Config().TeamSettings.MaxUsersPerTeam defer func() { th.App.UpdateConfig(func(cfg *model.Config) { cfg.TeamSettings.MaxUsersPerTeam = maxUsersPerTeam }) - th.App.SetDefaultRolesBasedOnConfig() th.App.PermanentDeleteTeam(team) }() one := 1 diff --git a/app/user.go b/app/user.go index 69c6d072b..f915f35cb 100644 --- a/app/user.go +++ b/app/user.go @@ -106,7 +106,7 @@ func (a *App) CreateUserWithInviteId(user *model.User, inviteId string) (*model. a.AddDirectChannels(team.Id, ruser) - if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } @@ -119,7 +119,7 @@ func (a *App) CreateUserAsAdmin(user *model.User) (*model.User, *model.AppError) return nil, err } - if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } @@ -143,7 +143,7 @@ func (a *App) CreateUserFromSignup(user *model.User) (*model.User, *model.AppErr return nil, err } - if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } @@ -1027,7 +1027,7 @@ func (a *App) UpdateUser(user *model.User, sendNotifications bool) (*model.User, if sendNotifications { if rusers[0].Email != rusers[1].Email { a.Go(func() { - if err := a.SendEmailChangeEmail(rusers[1].Email, rusers[0].Email, rusers[0].Locale, utils.GetSiteURL()); err != nil { + if err := a.SendEmailChangeEmail(rusers[1].Email, rusers[0].Email, rusers[0].Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -1041,7 +1041,7 @@ func (a *App) UpdateUser(user *model.User, sendNotifications bool) (*model.User, if rusers[0].Username != rusers[1].Username { a.Go(func() { - if err := a.SendChangeUsernameEmail(rusers[1].Username, rusers[0].Username, rusers[0].Email, rusers[0].Locale, utils.GetSiteURL()); err != nil { + if err := a.SendChangeUsernameEmail(rusers[1].Username, rusers[0].Username, rusers[0].Email, rusers[0].Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -1091,7 +1091,7 @@ func (a *App) UpdateMfa(activate bool, userId, token string) *model.AppError { return } - if err := a.SendMfaChangeEmail(user.Email, activate, user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendMfaChangeEmail(user.Email, activate, user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -1129,7 +1129,7 @@ func (a *App) UpdatePasswordSendEmail(user *model.User, newPassword, method stri } a.Go(func() { - if err := a.SendPasswordChangeEmail(user.Email, method, user.Locale, utils.GetSiteURL()); err != nil { + if err := a.SendPasswordChangeEmail(user.Email, method, user.Locale, a.GetSiteURL()); err != nil { l4g.Error(err.Error()) } }) @@ -1342,9 +1342,9 @@ func (a *App) SendEmailVerification(user *model.User) *model.AppError { } if _, err := a.GetStatus(user.Id); err != nil { - return a.SendVerifyEmail(user.Email, user.Locale, utils.GetSiteURL(), token.Token) + return a.SendVerifyEmail(user.Email, user.Locale, a.GetSiteURL(), token.Token) } else { - return a.SendEmailChangeVerifyEmail(user.Email, user.Locale, utils.GetSiteURL(), token.Token) + return a.SendEmailChangeVerifyEmail(user.Email, user.Locale, a.GetSiteURL(), token.Token) } } diff --git a/app/web_hub.go b/app/web_hub.go index eeae13e09..c1c8cb7bb 100644 --- a/app/web_hub.go +++ b/app/web_hub.go @@ -30,7 +30,6 @@ type Hub struct { // See https://github.com/mattermost/mattermost-server/pull/7281 connectionCount int64 app *App - connections []*WebConn connectionIndex int register chan *WebConn unregister chan *WebConn @@ -47,7 +46,6 @@ func (a *App) NewWebHub() *Hub { app: a, register: make(chan *WebConn, 1), unregister: make(chan *WebConn, 1), - connections: make([]*WebConn, 0, model.SESSION_CACHE_SIZE), broadcast: make(chan *model.WebSocketEvent, BROADCAST_QUEUE_SIZE), stop: make(chan struct{}), didStop: make(chan struct{}), @@ -170,8 +168,14 @@ func (a *App) Publish(message *model.WebSocketEvent) { } func (a *App) PublishSkipClusterSend(message *model.WebSocketEvent) { - for _, hub := range a.Hubs { - hub.Broadcast(message) + if message.Broadcast.UserId != "" { + if len(a.Hubs) != 0 { + a.GetHubForUserId(message.Broadcast.UserId).Broadcast(message) + } + } else { + for _, hub := range a.Hubs { + hub.Broadcast(message) + } } } @@ -362,80 +366,53 @@ func (h *Hub) Start() { var doRecover func() doStart = func() { - h.goroutineId = getGoroutineId() l4g.Debug("Hub for index %v is starting with goroutine %v", h.connectionIndex, h.goroutineId) + connections := newHubConnectionIndex() + for { select { case webCon := <-h.register: - h.connections = append(h.connections, webCon) - atomic.StoreInt64(&h.connectionCount, int64(len(h.connections))) - + connections.Add(webCon) + atomic.StoreInt64(&h.connectionCount, int64(len(connections.All()))) case webCon := <-h.unregister: - userId := webCon.UserId - - found := false - indexToDel := -1 - for i, webConCandidate := range h.connections { - if webConCandidate == webCon { - indexToDel = i - continue - } - if userId == webConCandidate.UserId { - found = true - if indexToDel != -1 { - break - } - } - } + connections.Remove(webCon) - if indexToDel != -1 { - // Delete the webcon we are unregistering - h.connections[indexToDel] = h.connections[len(h.connections)-1] - h.connections = h.connections[:len(h.connections)-1] - } - - if len(userId) == 0 { + if len(webCon.UserId) == 0 { continue } - if !found { + if len(connections.ForUser(webCon.UserId)) == 0 { h.app.Go(func() { - h.app.SetStatusOffline(userId, false) + h.app.SetStatusOffline(webCon.UserId, false) }) } - case userId := <-h.invalidateUser: - for _, webCon := range h.connections { - if webCon.UserId == userId { - webCon.InvalidateCache() - } + for _, webCon := range connections.ForUser(userId) { + webCon.InvalidateCache() } - case msg := <-h.broadcast: - for _, webCon := range h.connections { + candidates := connections.All() + if msg.Broadcast.UserId != "" { + candidates = connections.ForUser(msg.Broadcast.UserId) + } + msg.PrecomputeJSON() + for _, webCon := range candidates { if webCon.ShouldSendEvent(msg) { select { case webCon.Send <- msg: default: l4g.Error(fmt.Sprintf("webhub.broadcast: cannot send, closing websocket for userId=%v", webCon.UserId)) close(webCon.Send) - for i, webConCandidate := range h.connections { - if webConCandidate == webCon { - h.connections[i] = h.connections[len(h.connections)-1] - h.connections = h.connections[:len(h.connections)-1] - break - } - } + connections.Remove(webCon) } } } - case <-h.stop: userIds := make(map[string]bool) - for _, webCon := range h.connections { + for _, webCon := range connections.All() { userIds[webCon.UserId] = true webCon.Close() } @@ -444,7 +421,6 @@ func (h *Hub) Start() { h.app.SetStatusOffline(userId, false) } - h.connections = make([]*WebConn, 0, model.SESSION_CACHE_SIZE) h.ExplicitStop = true close(h.didStop) @@ -474,3 +450,60 @@ func (h *Hub) Start() { go doRecoverableStart() } + +type hubConnectionIndexIndexes struct { + connections int + connectionsByUserId int +} + +// hubConnectionIndex provides fast addition, removal, and iteration of web connections. +type hubConnectionIndex struct { + connections []*WebConn + connectionsByUserId map[string][]*WebConn + connectionIndexes map[*WebConn]*hubConnectionIndexIndexes +} + +func newHubConnectionIndex() *hubConnectionIndex { + return &hubConnectionIndex{ + connections: make([]*WebConn, 0, model.SESSION_CACHE_SIZE), + connectionsByUserId: make(map[string][]*WebConn), + connectionIndexes: make(map[*WebConn]*hubConnectionIndexIndexes), + } +} + +func (i *hubConnectionIndex) Add(wc *WebConn) { + i.connections = append(i.connections, wc) + i.connectionsByUserId[wc.UserId] = append(i.connectionsByUserId[wc.UserId], wc) + i.connectionIndexes[wc] = &hubConnectionIndexIndexes{ + connections: len(i.connections) - 1, + connectionsByUserId: len(i.connectionsByUserId[wc.UserId]) - 1, + } +} + +func (i *hubConnectionIndex) Remove(wc *WebConn) { + indexes, ok := i.connectionIndexes[wc] + if !ok { + return + } + + last := i.connections[len(i.connections)-1] + i.connections[indexes.connections] = last + i.connections = i.connections[:len(i.connections)-1] + i.connectionIndexes[last].connections = indexes.connections + + userConnections := i.connectionsByUserId[wc.UserId] + last = userConnections[len(userConnections)-1] + userConnections[indexes.connectionsByUserId] = last + i.connectionsByUserId[wc.UserId] = userConnections[:len(userConnections)-1] + i.connectionIndexes[last].connectionsByUserId = indexes.connectionsByUserId + + delete(i.connectionIndexes, wc) +} + +func (i *hubConnectionIndex) ForUser(id string) []*WebConn { + return i.connectionsByUserId[id] +} + +func (i *hubConnectionIndex) All() []*WebConn { + return i.connections +} |