summaryrefslogtreecommitdiffstats
path: root/app
diff options
context:
space:
mode:
Diffstat (limited to 'app')
-rw-r--r--app/admin.go3
-rw-r--r--app/app.go58
-rw-r--r--app/app_test.go3
-rw-r--r--app/apptestlib.go11
-rw-r--r--app/channel.go30
-rw-r--r--app/channel_test.go4
-rw-r--r--app/config.go109
-rw-r--r--app/config_test.go9
-rw-r--r--app/diagnostics.go3
-rw-r--r--app/email.go5
-rw-r--r--app/email_batching.go18
-rw-r--r--app/file.go42
-rw-r--r--app/import.go6
-rw-r--r--app/ldap.go4
-rw-r--r--app/license.go109
-rw-r--r--app/license_test.go75
-rw-r--r--app/login.go3
-rw-r--r--app/notification.go12
-rw-r--r--app/notification_test.go6
-rw-r--r--app/oauth.go13
-rw-r--r--app/plugin.go1
-rw-r--r--app/post_test.go2
-rw-r--r--app/role.go6
-rw-r--r--app/server.go12
-rw-r--r--app/server_test.go50
-rw-r--r--app/session_test.go22
-rw-r--r--app/team.go2
-rw-r--r--app/team_test.go1
-rw-r--r--app/user.go18
-rw-r--r--app/web_hub.go133
30 files changed, 614 insertions, 156 deletions
diff --git a/app/admin.go b/app/admin.go
index b838ed3bd..154fa8899 100644
--- a/app/admin.go
+++ b/app/admin.go
@@ -237,7 +237,8 @@ func (a *App) TestEmail(userId string, cfg *model.Config) *model.AppError {
return err
} else {
T := utils.GetUserTranslations(user.Locale)
- if err := utils.SendMailUsingConfig(user.Email, T("api.admin.test_email.subject"), T("api.admin.test_email.body"), cfg); err != nil {
+ license := a.License()
+ if err := utils.SendMailUsingConfig(user.Email, T("api.admin.test_email.subject"), T("api.admin.test_email.body"), cfg, license != nil && *license.Features.Compliance); err != nil {
return err
}
}
diff --git a/app/app.go b/app/app.go
index 1e46d29d0..26aed4c73 100644
--- a/app/app.go
+++ b/app/app.go
@@ -4,6 +4,7 @@
package app
import (
+ "crypto/ecdsa"
"html/template"
"net"
"net/http"
@@ -58,15 +59,22 @@ type App struct {
configFile string
configListeners map[string]func(*model.Config, *model.Config)
+ licenseValue atomic.Value
+ clientLicenseValue atomic.Value
+ licenseListeners map[string]func()
+
+ siteURL string
+
newStore func() store.Store
- htmlTemplateWatcher *utils.HTMLTemplateWatcher
- sessionCache *utils.Cache
- roles map[string]*model.Role
- configListenerId string
- licenseListenerId string
- disableConfigWatch bool
- configWatcher *utils.ConfigWatcher
+ htmlTemplateWatcher *utils.HTMLTemplateWatcher
+ sessionCache *utils.Cache
+ roles map[string]*model.Role
+ configListenerId string
+ licenseListenerId string
+ disableConfigWatch bool
+ configWatcher *utils.ConfigWatcher
+ asymmetricSigningKey *ecdsa.PrivateKey
pluginCommands []*PluginCommand
pluginCommandsLock sync.RWMutex
@@ -80,7 +88,7 @@ var appCount = 0
// New creates a new App. You must call Shutdown when you're done with it.
// XXX: For now, only one at a time is allowed as some resources are still shared.
-func New(options ...Option) (*App, error) {
+func New(options ...Option) (outApp *App, outErr error) {
appCount++
if appCount > 1 {
panic("Only one App should exist at a time. Did you forget to call Shutdown()?")
@@ -91,11 +99,17 @@ func New(options ...Option) (*App, error) {
Srv: &Server{
Router: mux.NewRouter(),
},
- sessionCache: utils.NewLru(model.SESSION_CACHE_SIZE),
- configFile: "config.json",
- configListeners: make(map[string]func(*model.Config, *model.Config)),
- clientConfig: make(map[string]string),
- }
+ sessionCache: utils.NewLru(model.SESSION_CACHE_SIZE),
+ configFile: "config.json",
+ configListeners: make(map[string]func(*model.Config, *model.Config)),
+ clientConfig: make(map[string]string),
+ licenseListeners: map[string]func(){},
+ }
+ defer func() {
+ if outErr != nil {
+ app.Shutdown()
+ }
+ }()
for _, option := range options {
option(app)
@@ -118,9 +132,9 @@ func New(options ...Option) (*App, error) {
app.configListenerId = app.AddConfigListener(func(_, _ *model.Config) {
app.configOrLicenseListener()
})
- app.licenseListenerId = utils.AddLicenseListener(app.configOrLicenseListener)
+ app.licenseListenerId = app.AddLicenseListener(app.configOrLicenseListener)
app.regenerateClientConfig()
- app.SetDefaultRolesBasedOnConfig()
+ app.setDefaultRolesBasedOnConfig()
l4g.Info(utils.T("api.server.new_server.init.info"))
@@ -139,6 +153,10 @@ func New(options ...Option) (*App, error) {
}
app.Srv.Store = app.newStore()
+ if err := app.ensureAsymmetricSigningKey(); err != nil {
+ return nil, errors.Wrapf(err, "unable to ensure asymmetric signing key")
+ }
+
app.initJobs()
app.initBuiltInPlugins()
@@ -157,7 +175,7 @@ func New(options ...Option) (*App, error) {
func (a *App) configOrLicenseListener() {
a.regenerateClientConfig()
- a.SetDefaultRolesBasedOnConfig()
+ a.setDefaultRolesBasedOnConfig()
}
func (a *App) Shutdown() {
@@ -171,7 +189,9 @@ func (a *App) Shutdown() {
a.ShutDownPlugins()
a.WaitForGoroutines()
- a.Srv.Store.Close()
+ if a.Srv.Store != nil {
+ a.Srv.Store.Close()
+ }
a.Srv = nil
if a.htmlTemplateWatcher != nil {
@@ -179,7 +199,7 @@ func (a *App) Shutdown() {
}
a.RemoveConfigListener(a.configListenerId)
- utils.RemoveLicenseListener(a.licenseListenerId)
+ a.RemoveLicenseListener(a.licenseListenerId)
l4g.Info(utils.T("api.server.stop_server.stopped.info"))
a.DisableConfigWatch()
@@ -448,5 +468,5 @@ func (a *App) Handle404(w http.ResponseWriter, r *http.Request) {
l4g.Debug("%v: code=404 ip=%v", r.URL.Path, utils.GetIpAddress(r))
- utils.RenderWebError(err, w, r)
+ utils.RenderWebAppError(w, r, err, a.AsymmetricSigningKey())
}
diff --git a/app/app_test.go b/app/app_test.go
index 25b19ead8..09f8725d7 100644
--- a/app/app_test.go
+++ b/app/app_test.go
@@ -51,7 +51,8 @@ func TestAppRace(t *testing.T) {
a, err := New()
require.NoError(t, err)
a.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" })
- a.StartServer()
+ serverErr := a.StartServer()
+ require.NoError(t, serverErr)
a.Shutdown()
}
}
diff --git a/app/apptestlib.go b/app/apptestlib.go
index 09afc8f76..c7846c9b5 100644
--- a/app/apptestlib.go
+++ b/app/apptestlib.go
@@ -96,15 +96,20 @@ func setupTestHelper(enterprise bool) *TestHelper {
if testStore != nil {
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" })
}
- th.App.StartServer()
+ serverErr := th.App.StartServer()
+ if serverErr != nil {
+ panic(serverErr)
+ }
+
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = prevListenAddress })
th.App.Srv.Store.MarkSystemRanUnitTests()
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableOpenServer = true })
- utils.SetIsLicensed(enterprise)
if enterprise {
- utils.License().Features.SetDefaults()
+ th.App.SetLicense(model.NewTestLicense())
+ } else {
+ th.App.SetLicense(nil)
}
return th
diff --git a/app/channel.go b/app/channel.go
index e4bf48654..8ac1f421c 100644
--- a/app/channel.go
+++ b/app/channel.go
@@ -1359,7 +1359,7 @@ func (a *App) PermanentDeleteChannel(channel *model.Channel) *model.AppError {
// This function is intended for use from the CLI. It is not robust against people joining the channel while the move
// is in progress, and therefore should not be used from the API without first fixing this potential race condition.
-func (a *App) MoveChannel(team *model.Team, channel *model.Channel) *model.AppError {
+func (a *App) MoveChannel(team *model.Team, channel *model.Channel, user *model.User) *model.AppError {
// Check that all channel members are in the destination team.
if channelMembers, err := a.GetChannelMembersPage(channel.Id, 0, 10000000); err != nil {
return err
@@ -1378,11 +1378,37 @@ func (a *App) MoveChannel(team *model.Team, channel *model.Channel) *model.AppEr
}
}
- // Change the Team ID of the channel.
+ // keep instance of the previous team
+ var previousTeam *model.Team
+ if result := <-a.Srv.Store.Team().Get(channel.TeamId); result.Err != nil {
+ return result.Err
+ } else {
+ previousTeam = result.Data.(*model.Team)
+ }
channel.TeamId = team.Id
if result := <-a.Srv.Store.Channel().Update(channel); result.Err != nil {
return result.Err
}
+ a.postChannelMoveMessage(user, channel, previousTeam)
+
+ return nil
+}
+
+func (a *App) postChannelMoveMessage(user *model.User, channel *model.Channel, previousTeam *model.Team) *model.AppError {
+
+ post := &model.Post{
+ ChannelId: channel.Id,
+ Message: fmt.Sprintf(utils.T("api.team.move_channel.success"), previousTeam.Name),
+ Type: model.POST_MOVE_CHANNEL,
+ UserId: user.Id,
+ Props: model.StringInterface{
+ "username": user.Username,
+ },
+ }
+
+ if _, err := a.CreatePost(post, channel, false); err != nil {
+ return model.NewAppError("postChannelMoveMessage", "api.team.move_channel.post.error", nil, err.Error(), http.StatusInternalServerError)
+ }
return nil
}
diff --git a/app/channel_test.go b/app/channel_test.go
index d83590a27..e4a0e4320 100644
--- a/app/channel_test.go
+++ b/app/channel_test.go
@@ -97,7 +97,7 @@ func TestMoveChannel(t *testing.T) {
t.Fatal(err)
}
- if err := th.App.MoveChannel(targetTeam, channel1); err == nil {
+ if err := th.App.MoveChannel(targetTeam, channel1, th.BasicUser); err == nil {
t.Fatal("Should have failed due to mismatched members.")
}
@@ -105,7 +105,7 @@ func TestMoveChannel(t *testing.T) {
t.Fatal(err)
}
- if err := th.App.MoveChannel(targetTeam, channel1); err != nil {
+ if err := th.App.MoveChannel(targetTeam, channel1, th.BasicUser); err != nil {
t.Fatal(err)
}
}
diff --git a/app/config.go b/app/config.go
index a2398f9e9..35a0c9a3f 100644
--- a/app/config.go
+++ b/app/config.go
@@ -4,10 +4,17 @@
package app
import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
"crypto/md5"
+ "crypto/rand"
+ "crypto/x509"
+ "encoding/base64"
"encoding/json"
"fmt"
+ "net/url"
"runtime/debug"
+ "strings"
l4g "github.com/alecthomas/log4go"
@@ -48,7 +55,7 @@ func (a *App) LoadConfig(configFile string) *model.AppError {
a.config.Store(cfg)
- utils.SetSiteURL(*cfg.ServiceSettings.SiteURL)
+ a.siteURL = strings.TrimRight(*cfg.ServiceSettings.SiteURL, "/")
a.InvokeConfigListeners(old, cfg)
return nil
@@ -116,8 +123,91 @@ func (a *App) InvokeConfigListeners(old, current *model.Config) {
}
}
+// EnsureAsymmetricSigningKey ensures that an asymmetric signing key exists and future calls to
+// AsymmetricSigningKey will always return a valid signing key.
+func (a *App) ensureAsymmetricSigningKey() error {
+ if a.asymmetricSigningKey != nil {
+ return nil
+ }
+
+ var key *model.SystemAsymmetricSigningKey
+
+ result := <-a.Srv.Store.System().GetByName(model.SYSTEM_ASYMMETRIC_SIGNING_KEY)
+ if result.Err == nil {
+ if err := json.Unmarshal([]byte(result.Data.(*model.System).Value), &key); err != nil {
+ return err
+ }
+ }
+
+ // If we don't already have a key, try to generate one.
+ if key == nil {
+ newECDSAKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return err
+ }
+ newKey := &model.SystemAsymmetricSigningKey{
+ ECDSAKey: &model.SystemECDSAKey{
+ Curve: "P-256",
+ X: newECDSAKey.X,
+ Y: newECDSAKey.Y,
+ D: newECDSAKey.D,
+ },
+ }
+ system := &model.System{
+ Name: model.SYSTEM_ASYMMETRIC_SIGNING_KEY,
+ }
+ v, err := json.Marshal(newKey)
+ if err != nil {
+ return err
+ }
+ system.Value = string(v)
+ if result = <-a.Srv.Store.System().Save(system); result.Err == nil {
+ // If we were able to save the key, use it, otherwise ignore the error.
+ key = newKey
+ }
+ }
+
+ // If we weren't able to save a new key above, another server must have beat us to it. Get the
+ // key from the database, and if that fails, error out.
+ if key == nil {
+ result := <-a.Srv.Store.System().GetByName(model.SYSTEM_ASYMMETRIC_SIGNING_KEY)
+ if result.Err != nil {
+ return result.Err
+ } else if err := json.Unmarshal([]byte(result.Data.(*model.System).Value), &key); err != nil {
+ return err
+ }
+ }
+
+ var curve elliptic.Curve
+ switch key.ECDSAKey.Curve {
+ case "P-256":
+ curve = elliptic.P256()
+ default:
+ return fmt.Errorf("unknown curve: " + key.ECDSAKey.Curve)
+ }
+ a.asymmetricSigningKey = &ecdsa.PrivateKey{
+ PublicKey: ecdsa.PublicKey{
+ Curve: curve,
+ X: key.ECDSAKey.X,
+ Y: key.ECDSAKey.Y,
+ },
+ D: key.ECDSAKey.D,
+ }
+ a.regenerateClientConfig()
+ return nil
+}
+
+// AsymmetricSigningKey will return a private key that can be used for asymmetric signing.
+func (a *App) AsymmetricSigningKey() *ecdsa.PrivateKey {
+ return a.asymmetricSigningKey
+}
+
func (a *App) regenerateClientConfig() {
- a.clientConfig = utils.GenerateClientConfig(a.Config(), a.DiagnosticId())
+ a.clientConfig = utils.GenerateClientConfig(a.Config(), a.DiagnosticId(), a.License())
+ if key := a.AsymmetricSigningKey(); key != nil {
+ der, _ := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ a.clientConfig["AsymmetricSigningPublicKey"] = base64.StdEncoding.EncodeToString(der)
+ }
clientConfigJSON, _ := json.Marshal(a.clientConfig)
a.clientConfigHash = fmt.Sprintf("%x", md5.Sum(clientConfigJSON))
}
@@ -167,10 +257,15 @@ func (a *App) Desanitize(cfg *model.Config) {
}
}
-// License returns the currently active license or nil if the application is unlicensed.
-func (a *App) License() *model.License {
- if utils.IsLicensed() {
- return utils.License()
+func (a *App) GetCookieDomain() string {
+ if *a.Config().ServiceSettings.AllowCookiesForSubdomains {
+ if siteURL, err := url.Parse(*a.Config().ServiceSettings.SiteURL); err == nil {
+ return siteURL.Hostname()
+ }
}
- return nil
+ return ""
+}
+
+func (a *App) GetSiteURL() string {
+ return a.siteURL
}
diff --git a/app/config_test.go b/app/config_test.go
index e3d50b958..5ee999f0f 100644
--- a/app/config_test.go
+++ b/app/config_test.go
@@ -6,6 +6,8 @@ package app
import (
"testing"
+ "github.com/stretchr/testify/assert"
+
"github.com/mattermost/mattermost-server/model"
)
@@ -54,3 +56,10 @@ func TestConfigListener(t *testing.T) {
t.Fatal("listener 2 should've been called")
}
}
+
+func TestAsymmetricSigningKey(t *testing.T) {
+ th := Setup().InitBasic()
+ defer th.TearDown()
+ assert.NotNil(t, th.App.AsymmetricSigningKey())
+ assert.NotEmpty(t, th.App.ClientConfig()["AsymmetricSigningPublicKey"])
+}
diff --git a/app/diagnostics.go b/app/diagnostics.go
index 809d9ff1e..12553afc8 100644
--- a/app/diagnostics.go
+++ b/app/diagnostics.go
@@ -243,6 +243,8 @@ func (a *App) trackConfig() {
"isdefault_image_proxy_type": isDefault(*cfg.ServiceSettings.ImageProxyType, ""),
"isdefault_image_proxy_url": isDefault(*cfg.ServiceSettings.ImageProxyURL, ""),
"isdefault_image_proxy_options": isDefault(*cfg.ServiceSettings.ImageProxyOptions, ""),
+ "websocket_url": isDefault(*cfg.ServiceSettings.WebsocketURL, ""),
+ "allow_cookies_for_subdomains": *cfg.ServiceSettings.AllowCookiesForSubdomains,
})
a.SendDiagnostic(TRACK_CONFIG_TEAM, map[string]interface{}{
@@ -501,6 +503,7 @@ func (a *App) trackConfig() {
a.SendDiagnostic(TRACK_CONFIG_MESSAGE_EXPORT, map[string]interface{}{
"enable_message_export": *cfg.MessageExportSettings.EnableExport,
+ "export_format": *cfg.MessageExportSettings.ExportFormat,
"daily_run_time": *cfg.MessageExportSettings.DailyRunTime,
"default_export_from_timestamp": *cfg.MessageExportSettings.ExportFromTimestamp,
"batch_size": *cfg.MessageExportSettings.BatchSize,
diff --git a/app/email.go b/app/email.go
index 764dc017a..8ee3e79e2 100644
--- a/app/email.go
+++ b/app/email.go
@@ -191,7 +191,7 @@ func (a *App) SendUserAccessTokenAddedEmail(email, locale string) *model.AppErro
bodyPage := a.NewEmailTemplate("password_change_body", locale)
bodyPage.Props["Title"] = T("api.templates.user_access_token_body.title")
bodyPage.Html["Info"] = utils.TranslateAsHtml(T, "api.templates.user_access_token_body.info",
- map[string]interface{}{"SiteName": a.ClientConfig()["SiteName"], "SiteURL": utils.GetSiteURL()})
+ map[string]interface{}{"SiteName": a.ClientConfig()["SiteName"], "SiteURL": a.GetSiteURL()})
if err := a.SendMail(email, subject, bodyPage.Render()); err != nil {
return model.NewAppError("SendUserAccessTokenAddedEmail", "api.user.send_user_access_token.error", nil, err.Error(), http.StatusInternalServerError)
@@ -317,5 +317,6 @@ func (a *App) NewEmailTemplate(name, locale string) *utils.HTMLTemplate {
}
func (a *App) SendMail(to, subject, htmlBody string) *model.AppError {
- return utils.SendMailUsingConfig(to, subject, htmlBody, a.Config())
+ license := a.License()
+ return utils.SendMailUsingConfig(to, subject, htmlBody, a.Config(), license != nil && *license.Features.Compliance)
}
diff --git a/app/email_batching.go b/app/email_batching.go
index 2a33d7d3e..07adda674 100644
--- a/app/email_batching.go
+++ b/app/email_batching.go
@@ -7,6 +7,7 @@ import (
"fmt"
"html/template"
"strconv"
+ "sync"
"time"
"github.com/mattermost/mattermost-server/model"
@@ -57,6 +58,8 @@ type EmailBatchingJob struct {
app *App
newNotifications chan *batchedNotification
pendingNotifications map[string][]*batchedNotification
+ task *model.ScheduledTask
+ taskMutex sync.Mutex
}
func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob {
@@ -68,12 +71,17 @@ func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob {
}
func (job *EmailBatchingJob) Start() {
- if task := model.GetTaskByName(EMAIL_BATCHING_TASK_NAME); task != nil {
- task.Cancel()
- }
-
l4g.Debug(utils.T("api.email_batching.start.starting"), *job.app.Config().EmailSettings.EmailBatchingInterval)
- model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second)
+ newTask := model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second)
+
+ job.taskMutex.Lock()
+ oldTask := job.task
+ job.task = newTask
+ job.taskMutex.Unlock()
+
+ if oldTask != nil {
+ oldTask.Cancel()
+ }
}
func (job *EmailBatchingJob) Add(user *model.User, post *model.Post, team *model.Team) bool {
diff --git a/app/file.go b/app/file.go
index d66c64adb..06ee61c92 100644
--- a/app/file.go
+++ b/app/file.go
@@ -58,7 +58,8 @@ const (
)
func (a *App) FileBackend() (utils.FileBackend, *model.AppError) {
- return utils.NewFileBackend(&a.Config().FileSettings)
+ license := a.License()
+ return utils.NewFileBackend(&a.Config().FileSettings, license != nil && *license.Features.Compliance)
}
func (a *App) ReadFile(path string) ([]byte, *model.AppError) {
@@ -279,11 +280,38 @@ func GeneratePublicLinkHash(fileId, salt string) string {
return base64.RawURLEncoding.EncodeToString(hash.Sum(nil))
}
-func (a *App) UploadFiles(teamId string, channelId string, userId string, fileHeaders []*multipart.FileHeader, clientIds []string) (*model.FileUploadResponse, *model.AppError) {
+func (a *App) UploadMultipartFiles(teamId string, channelId string, userId string, fileHeaders []*multipart.FileHeader, clientIds []string) (*model.FileUploadResponse, *model.AppError) {
+ files := make([]io.ReadCloser, len(fileHeaders))
+ filenames := make([]string, len(fileHeaders))
+
+ for i, fileHeader := range fileHeaders {
+ file, fileErr := fileHeader.Open()
+ if fileErr != nil {
+ return nil, model.NewAppError("UploadFiles", "api.file.upload_file.bad_parse.app_error", nil, fileErr.Error(), http.StatusBadRequest)
+ }
+
+ // Will be closed after UploadFiles returns
+ defer file.Close()
+
+ files[i] = file
+ filenames[i] = fileHeader.Filename
+ }
+
+ return a.UploadFiles(teamId, channelId, userId, files, filenames, clientIds)
+}
+
+// Uploads some files to the given team and channel as the given user. files and filenames should have
+// the same length. clientIds should either not be provided or have the same length as files and filenames.
+// The provided files should be closed by the caller so that they are not leaked.
+func (a *App) UploadFiles(teamId string, channelId string, userId string, files []io.ReadCloser, filenames []string, clientIds []string) (*model.FileUploadResponse, *model.AppError) {
if len(*a.Config().FileSettings.DriverName) == 0 {
return nil, model.NewAppError("uploadFile", "api.file.upload_file.storage.app_error", nil, "", http.StatusNotImplemented)
}
+ if len(filenames) != len(files) || (len(clientIds) > 0 && len(clientIds) != len(files)) {
+ return nil, model.NewAppError("UploadFiles", "api.file.upload_file.incorrect_number_of_files.app_error", nil, "", http.StatusBadRequest)
+ }
+
resStruct := &model.FileUploadResponse{
FileInfos: []*model.FileInfo{},
ClientIds: []string{},
@@ -293,18 +321,12 @@ func (a *App) UploadFiles(teamId string, channelId string, userId string, fileHe
thumbnailPathList := []string{}
imageDataList := [][]byte{}
- for i, fileHeader := range fileHeaders {
- file, fileErr := fileHeader.Open()
- if fileErr != nil {
- return nil, model.NewAppError("UploadFiles", "api.file.upload_file.bad_parse.app_error", nil, fileErr.Error(), http.StatusBadRequest)
- }
- defer file.Close()
-
+ for i, file := range files {
buf := bytes.NewBuffer(nil)
io.Copy(buf, file)
data := buf.Bytes()
- info, err := a.DoUploadFile(time.Now(), teamId, channelId, userId, fileHeader.Filename, data)
+ info, err := a.DoUploadFile(time.Now(), teamId, channelId, userId, filenames[i], data)
if err != nil {
return nil, err
}
diff --git a/app/import.go b/app/import.go
index 6291794b0..5a3158fab 100644
--- a/app/import.go
+++ b/app/import.go
@@ -817,6 +817,12 @@ func (a *App) ImportUserTeams(user *model.User, data *[]UserTeamImportData) *mod
}
}
+ if defaultChannel, err := a.GetChannelByName(model.DEFAULT_CHANNEL, team.Id); err != nil {
+ return err
+ } else if _, err = a.addUserToChannel(user, defaultChannel, member); err != nil {
+ return err
+ }
+
if err := a.ImportUserChannels(user, team, member, tdata.Channels); err != nil {
return err
}
diff --git a/app/ldap.go b/app/ldap.go
index 179529c52..ff7a5ed21 100644
--- a/app/ldap.go
+++ b/app/ldap.go
@@ -67,7 +67,7 @@ func (a *App) SwitchEmailToLdap(email, password, code, ldapId, ldapPassword stri
}
a.Go(func() {
- if err := a.SendSignInChangeEmail(user.Email, "AD/LDAP", user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendSignInChangeEmail(user.Email, "AD/LDAP", user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -113,7 +113,7 @@ func (a *App) SwitchLdapToEmail(ldapPassword, code, email, newPassword string) (
T := utils.GetUserTranslations(user.Locale)
a.Go(func() {
- if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
diff --git a/app/license.go b/app/license.go
index c7fd07197..c12f23d1d 100644
--- a/app/license.go
+++ b/app/license.go
@@ -4,16 +4,19 @@
package app
import (
+ "crypto/md5"
+ "fmt"
"net/http"
"strings"
l4g "github.com/alecthomas/log4go"
+
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/utils"
)
func (a *App) LoadLicense() {
- utils.RemoveLicense()
+ a.SetLicense(nil)
licenseId := ""
if result := <-a.Srv.Store.System().Get(); result.Err == nil {
@@ -36,7 +39,7 @@ func (a *App) LoadLicense() {
if result := <-a.Srv.Store.License().Get(licenseId); result.Err == nil {
record := result.Data.(*model.LicenseRecord)
- utils.LoadLicense([]byte(record.Bytes))
+ a.ValidateAndSetLicenseBytes([]byte(record.Bytes))
l4g.Info("License key valid unlocking enterprise features.")
} else {
l4g.Info(utils.T("mattermost.load_license.find.warn"))
@@ -59,7 +62,7 @@ func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError)
}
}
- if ok := utils.SetLicense(license); !ok {
+ if ok := a.SetLicense(license); !ok {
return nil, model.NewAppError("addLicense", model.EXPIRED_LICENSE_ERROR, nil, "", http.StatusBadRequest)
}
@@ -102,21 +105,117 @@ func (a *App) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError)
return license, nil
}
+// License returns the currently active license or nil if the application is unlicensed.
+func (a *App) License() *model.License {
+ license, _ := a.licenseValue.Load().(*model.License)
+ return license
+}
+
+func (a *App) SetLicense(license *model.License) bool {
+ defer func() {
+ a.setDefaultRolesBasedOnConfig()
+ for _, listener := range a.licenseListeners {
+ listener()
+ }
+ }()
+
+ if license != nil {
+ license.Features.SetDefaults()
+
+ if !license.IsExpired() {
+ a.licenseValue.Store(license)
+ a.clientLicenseValue.Store(utils.GetClientLicense(license))
+ return true
+ }
+ }
+
+ a.licenseValue.Store((*model.License)(nil))
+ a.clientLicenseValue.Store(map[string]string(nil))
+ return false
+}
+
+func (a *App) ValidateAndSetLicenseBytes(b []byte) {
+ if success, licenseStr := utils.ValidateLicense(b); success {
+ license := model.LicenseFromJson(strings.NewReader(licenseStr))
+ a.SetLicense(license)
+ return
+ }
+
+ l4g.Warn(utils.T("utils.license.load_license.invalid.warn"))
+}
+
+func (a *App) SetClientLicense(m map[string]string) {
+ a.clientLicenseValue.Store(m)
+}
+
+func (a *App) ClientLicense() map[string]string {
+ if clientLicense, _ := a.clientLicenseValue.Load().(map[string]string); clientLicense != nil {
+ return clientLicense
+ }
+ return map[string]string{"IsLicensed": "false"}
+}
+
func (a *App) RemoveLicense() *model.AppError {
- utils.RemoveLicense()
+ if license, _ := a.licenseValue.Load().(*model.License); license == nil {
+ return nil
+ }
sysVar := &model.System{}
sysVar.Name = model.SYSTEM_ACTIVE_LICENSE_ID
sysVar.Value = ""
if result := <-a.Srv.Store.System().SaveOrUpdate(sysVar); result.Err != nil {
- utils.RemoveLicense()
return result.Err
}
+ a.SetLicense(nil)
a.ReloadConfig()
a.InvalidateAllCaches()
return nil
}
+
+func (a *App) AddLicenseListener(listener func()) string {
+ id := model.NewId()
+ a.licenseListeners[id] = listener
+ return id
+}
+
+func (a *App) RemoveLicenseListener(id string) {
+ delete(a.licenseListeners, id)
+}
+
+func (a *App) GetClientLicenseEtag(useSanitized bool) string {
+ value := ""
+
+ lic := a.ClientLicense()
+
+ if useSanitized {
+ lic = a.GetSanitizedClientLicense()
+ }
+
+ for k, v := range lic {
+ value += fmt.Sprintf("%s:%s;", k, v)
+ }
+
+ return model.Etag(fmt.Sprintf("%x", md5.Sum([]byte(value))))
+}
+
+func (a *App) GetSanitizedClientLicense() map[string]string {
+ sanitizedLicense := make(map[string]string)
+
+ for k, v := range a.ClientLicense() {
+ sanitizedLicense[k] = v
+ }
+
+ delete(sanitizedLicense, "Id")
+ delete(sanitizedLicense, "Name")
+ delete(sanitizedLicense, "Email")
+ delete(sanitizedLicense, "PhoneNumber")
+ delete(sanitizedLicense, "IssuedAt")
+ delete(sanitizedLicense, "StartsAt")
+ delete(sanitizedLicense, "ExpiresAt")
+
+ return sanitizedLicense
+}
diff --git a/app/license_test.go b/app/license_test.go
index 5b73d9d18..f86d604d1 100644
--- a/app/license_test.go
+++ b/app/license_test.go
@@ -4,8 +4,9 @@
package app
import (
- //"github.com/mattermost/mattermost-server/model"
"testing"
+
+ "github.com/mattermost/mattermost-server/model"
)
func TestLoadLicense(t *testing.T) {
@@ -37,3 +38,75 @@ func TestRemoveLicense(t *testing.T) {
t.Fatal("should have removed license")
}
}
+
+func TestSetLicense(t *testing.T) {
+ th := Setup()
+ defer th.TearDown()
+
+ l1 := &model.License{}
+ l1.Features = &model.Features{}
+ l1.Customer = &model.Customer{}
+ l1.StartsAt = model.GetMillis() - 1000
+ l1.ExpiresAt = model.GetMillis() + 100000
+ if ok := th.App.SetLicense(l1); !ok {
+ t.Fatal("license should have worked")
+ }
+
+ l2 := &model.License{}
+ l2.Features = &model.Features{}
+ l2.Customer = &model.Customer{}
+ l2.StartsAt = model.GetMillis() - 1000
+ l2.ExpiresAt = model.GetMillis() - 100
+ if ok := th.App.SetLicense(l2); ok {
+ t.Fatal("license should have failed")
+ }
+
+ l3 := &model.License{}
+ l3.Features = &model.Features{}
+ l3.Customer = &model.Customer{}
+ l3.StartsAt = model.GetMillis() + 10000
+ l3.ExpiresAt = model.GetMillis() + 100000
+ if ok := th.App.SetLicense(l3); !ok {
+ t.Fatal("license should have passed")
+ }
+}
+
+func TestClientLicenseEtag(t *testing.T) {
+ th := Setup()
+ defer th.TearDown()
+
+ etag1 := th.App.GetClientLicenseEtag(false)
+
+ th.App.SetClientLicense(map[string]string{"SomeFeature": "true", "IsLicensed": "true"})
+
+ etag2 := th.App.GetClientLicenseEtag(false)
+ if etag1 == etag2 {
+ t.Fatal("etags should not match")
+ }
+
+ th.App.SetClientLicense(map[string]string{"SomeFeature": "true", "IsLicensed": "false"})
+
+ etag3 := th.App.GetClientLicenseEtag(false)
+ if etag2 == etag3 {
+ t.Fatal("etags should not match")
+ }
+}
+
+func TestGetSanitizedClientLicense(t *testing.T) {
+ th := Setup()
+ defer th.TearDown()
+
+ l1 := &model.License{}
+ l1.Features = &model.Features{}
+ l1.Customer = &model.Customer{}
+ l1.Customer.Name = "TestName"
+ l1.StartsAt = model.GetMillis() - 1000
+ l1.ExpiresAt = model.GetMillis() + 100000
+ th.App.SetLicense(l1)
+
+ m := th.App.GetSanitizedClientLicense()
+
+ if _, ok := m["Name"]; ok {
+ t.Fatal("should have been sanatized")
+ }
+}
diff --git a/app/login.go b/app/login.go
index ecc0f0163..e01566bcd 100644
--- a/app/login.go
+++ b/app/login.go
@@ -113,6 +113,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
secure = true
}
+ domain := a.GetCookieDomain()
expiresAt := time.Unix(model.GetMillis()/1000+int64(maxAge), 0)
sessionCookie := &http.Cookie{
Name: model.SESSION_COOKIE_TOKEN,
@@ -121,6 +122,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
MaxAge: maxAge,
Expires: expiresAt,
HttpOnly: true,
+ Domain: domain,
Secure: secure,
}
@@ -130,6 +132,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
Path: "/",
MaxAge: maxAge,
Expires: expiresAt,
+ Domain: domain,
Secure: secure,
}
diff --git a/app/notification.go b/app/notification.go
index 24e84500b..8cb63fbaf 100644
--- a/app/notification.go
+++ b/app/notification.go
@@ -362,7 +362,7 @@ func (a *App) sendNotificationEmail(post *model.Post, user *model.User, channel
emailNotificationContentsType = *a.Config().EmailSettings.EmailNotificationContentsType
}
- teamURL := utils.GetSiteURL() + "/" + team.Name
+ teamURL := a.GetSiteURL() + "/" + team.Name
var bodyText = a.getNotificationEmailBody(user, post, channel, senderName, team.Name, teamURL, emailNotificationContentsType, translateFunc)
a.Go(func() {
@@ -421,7 +421,7 @@ func (a *App) getNotificationEmailBody(recipient *model.User, post *model.Post,
bodyPage = a.NewEmailTemplate("post_body_generic", recipient.Locale)
}
- bodyPage.Props["SiteURL"] = utils.GetSiteURL()
+ bodyPage.Props["SiteURL"] = a.GetSiteURL()
if teamName != "select_team" {
bodyPage.Props["TeamLink"] = teamURL + "/pl/" + post.Id
} else {
@@ -623,7 +623,9 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool,
message := ""
category := ""
- if *a.Config().EmailSettings.PushNotificationContents == model.FULL_NOTIFICATION {
+ contentsConfig := *a.Config().EmailSettings.PushNotificationContents
+
+ if contentsConfig == model.FULL_NOTIFICATION {
category = model.CATEGORY_CAN_REPLY
if channelType == model.CHANNEL_DIRECT {
@@ -631,7 +633,7 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool,
} else {
message = senderName + userLocale("api.post.send_notifications_and_forget.push_in") + channelName + ": " + model.ClearMentionTags(postMessage)
}
- } else if *a.Config().EmailSettings.PushNotificationContents == model.GENERIC_NO_CHANNEL_NOTIFICATION {
+ } else if contentsConfig == model.GENERIC_NO_CHANNEL_NOTIFICATION {
if channelType == model.CHANNEL_DIRECT {
category = model.CATEGORY_CAN_REPLY
@@ -659,6 +661,8 @@ func (a *App) getPushNotificationMessage(postMessage string, wasMentioned bool,
if len(postMessage) == 0 && hasFiles {
if channelType == model.CHANNEL_DIRECT {
message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only_dm")
+ } else if contentsConfig == model.GENERIC_NO_CHANNEL_NOTIFICATION {
+ message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only_no_channel")
} else {
message = senderName + userLocale("api.post.send_notifications_and_forget.push_image_only") + channelName
}
diff --git a/app/notification_test.go b/app/notification_test.go
index 43703c019..5fc1d152c 100644
--- a/app/notification_test.go
+++ b/app/notification_test.go
@@ -1414,6 +1414,12 @@ func TestGetPushNotificationMessage(t *testing.T) {
ExpectedMessage: "user uploaded one or more files in a direct message",
ExpectedCategory: model.CATEGORY_CAN_REPLY,
},
+ "only files without channel, public channel": {
+ HasFiles: true,
+ PushNotificationContents: model.GENERIC_NO_CHANNEL_NOTIFICATION,
+ ChannelType: model.CHANNEL_OPEN,
+ ExpectedMessage: "user uploaded one or more files",
+ },
} {
t.Run(name, func(t *testing.T) {
locale := tc.Locale
diff --git a/app/oauth.go b/app/oauth.go
index 5a66f542e..630fd3e2d 100644
--- a/app/oauth.go
+++ b/app/oauth.go
@@ -527,7 +527,7 @@ func (a *App) CompleteSwitchWithOAuth(service string, userData io.ReadCloser, em
}
a.Go(func() {
- if err := a.SendSignInChangeEmail(user.Email, strings.Title(service)+" SSO", user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendSignInChangeEmail(user.Email, strings.Title(service)+" SSO", user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -600,7 +600,12 @@ func (a *App) GetAuthorizationCode(w http.ResponseWriter, r *http.Request, servi
props["token"] = stateToken.Token
state := b64.StdEncoding.EncodeToString([]byte(model.MapToJson(props)))
- redirectUri := utils.GetSiteURL() + "/signup/" + service + "/complete"
+ siteUrl := a.GetSiteURL()
+ if strings.TrimSpace(siteUrl) == "" {
+ siteUrl = GetProtocol(r) + "://" + r.Host
+ }
+
+ redirectUri := siteUrl + "/signup/" + service + "/complete"
authUrl := endpoint + "?response_type=code&client_id=" + clientId + "&redirect_uri=" + url.QueryEscape(redirectUri) + "&state=" + url.QueryEscape(state)
@@ -736,7 +741,7 @@ func (a *App) SwitchEmailToOAuth(w http.ResponseWriter, r *http.Request, email,
stateProps["email"] = email
if service == model.USER_AUTH_SERVICE_SAML {
- return utils.GetSiteURL() + "/login/sso/saml?action=" + model.OAUTH_ACTION_EMAIL_TO_SSO + "&email=" + utils.UrlEncode(email), nil
+ return a.GetSiteURL() + "/login/sso/saml?action=" + model.OAUTH_ACTION_EMAIL_TO_SSO + "&email=" + utils.UrlEncode(email), nil
} else {
if authUrl, err := a.GetAuthorizationCode(w, r, service, stateProps, ""); err != nil {
return "", err
@@ -768,7 +773,7 @@ func (a *App) SwitchOAuthToEmail(email, password, requesterId string) (string, *
T := utils.GetUserTranslations(user.Locale)
a.Go(func() {
- if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendSignInChangeEmail(user.Email, T("api.templates.signin_change_email.body.method_email"), user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
diff --git a/app/plugin.go b/app/plugin.go
index 3f06a000f..fe671d26a 100644
--- a/app/plugin.go
+++ b/app/plugin.go
@@ -565,6 +565,7 @@ func (a *App) RegisterPluginCommand(pluginId string, command *model.Command) err
TeamId: command.TeamId,
AutoComplete: command.AutoComplete,
AutoCompleteDesc: command.AutoCompleteDesc,
+ AutoCompleteHint: command.AutoCompleteHint,
DisplayName: command.DisplayName,
}
diff --git a/app/post_test.go b/app/post_test.go
index ebe973270..2472e40c6 100644
--- a/app/post_test.go
+++ b/app/post_test.go
@@ -345,4 +345,4 @@ func TestMakeOpenGraphURLsAbsolute(t *testing.T) {
}
})
}
-} \ No newline at end of file
+}
diff --git a/app/role.go b/app/role.go
index 5f39dd623..9f271ea7a 100644
--- a/app/role.go
+++ b/app/role.go
@@ -12,8 +12,6 @@ func (a *App) Role(id string) *model.Role {
return a.roles[id]
}
-// Updates the roles based on the app config and the global license check. You may need to invoke
-// this when license changes are made.
-func (a *App) SetDefaultRolesBasedOnConfig() {
- a.roles = utils.DefaultRolesBasedOnConfig(a.Config())
+func (a *App) setDefaultRolesBasedOnConfig() {
+ a.roles = utils.DefaultRolesBasedOnConfig(a.Config(), a.License() != nil)
}
diff --git a/app/server.go b/app/server.go
index 1659908b6..afa282ad6 100644
--- a/app/server.go
+++ b/app/server.go
@@ -17,6 +17,7 @@ import (
l4g "github.com/alecthomas/log4go"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
+ "github.com/pkg/errors"
"golang.org/x/crypto/acme/autocert"
"github.com/mattermost/mattermost-server/model"
@@ -116,7 +117,7 @@ func redirectHTTPToHTTPS(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, url.String(), http.StatusFound)
}
-func (a *App) StartServer() {
+func (a *App) StartServer() error {
l4g.Info(utils.T("api.server.start_server.starting.info"))
var handler http.Handler = &CorsWrapper{a.Config, a.Srv.Router}
@@ -126,8 +127,7 @@ func (a *App) StartServer() {
rateLimiter, err := NewRateLimiter(&a.Config().RateLimitSettings)
if err != nil {
- l4g.Critical(err.Error())
- return
+ return err
}
a.Srv.RateLimiter = rateLimiter
@@ -151,8 +151,8 @@ func (a *App) StartServer() {
listener, err := net.Listen("tcp", addr)
if err != nil {
- l4g.Critical(utils.T("api.server.start_server.starting.critical"), err)
- return
+ errors.Wrapf(err, utils.T("api.server.start_server.starting.critical"), err)
+ return err
}
a.Srv.ListenAddr = listener.Addr().(*net.TCPAddr)
@@ -214,6 +214,8 @@ func (a *App) StartServer() {
}
close(a.Srv.didFinishListen)
}()
+
+ return nil
}
type tcpKeepAliveListener struct {
diff --git a/app/server_test.go b/app/server_test.go
new file mode 100644
index 000000000..de358b976
--- /dev/null
+++ b/app/server_test.go
@@ -0,0 +1,50 @@
+// Copyright (c) 2017-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package app
+
+import (
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+ "github.com/stretchr/testify/require"
+)
+
+func TestStartServerSuccess(t *testing.T) {
+ a, err := New()
+ require.NoError(t, err)
+
+ a.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" })
+ serverErr := a.StartServer()
+ a.Shutdown()
+ require.NoError(t, serverErr)
+}
+
+func TestStartServerRateLimiterCriticalError(t *testing.T) {
+ a, err := New()
+ require.NoError(t, err)
+
+ // Attempt to use Rate Limiter with an invalid config
+ a.UpdateConfig(func(cfg *model.Config) {
+ *cfg.RateLimitSettings.Enable = true
+ *cfg.RateLimitSettings.MaxBurst = -100
+ })
+
+ serverErr := a.StartServer()
+ a.Shutdown()
+ require.Error(t, serverErr)
+}
+
+func TestStartServerPortUnavailable(t *testing.T) {
+ a, err := New()
+ require.NoError(t, err)
+
+ // Attempt to listen on a system-reserved port
+ a.UpdateConfig(func(cfg *model.Config) {
+ *cfg.ServiceSettings.ListenAddress = ":21"
+ })
+
+ serverErr := a.StartServer()
+ a.Shutdown()
+ require.Error(t, serverErr)
+}
diff --git a/app/session_test.go b/app/session_test.go
index bca3b59b7..bf8198a4e 100644
--- a/app/session_test.go
+++ b/app/session_test.go
@@ -6,11 +6,10 @@ package app
import (
"testing"
- "github.com/mattermost/mattermost-server/model"
- "github.com/mattermost/mattermost-server/utils"
-
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+
+ "github.com/mattermost/mattermost-server/model"
)
func TestCache(t *testing.T) {
@@ -48,18 +47,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) {
session, _ = th.App.CreateSession(session)
- isLicensed := utils.IsLicensed()
- license := utils.License()
- timeout := *th.App.Config().ServiceSettings.SessionIdleTimeoutInMinutes
- defer func() {
- utils.SetIsLicensed(isLicensed)
- utils.SetLicense(license)
- th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = timeout })
- }()
- utils.SetIsLicensed(true)
- utils.SetLicense(&model.License{Features: &model.Features{}})
- utils.License().Features.SetDefaults()
- *utils.License().Features.Compliance = true
+ th.App.SetLicense(model.NewTestLicense("compliance"))
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = 5 })
rsession, err := th.App.GetSession(session.Token)
@@ -122,7 +110,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) {
assert.Nil(t, err)
// Test regular session with license off, should not timeout
- *utils.License().Features.Compliance = false
+ th.App.SetLicense(nil)
session = &model.Session{
UserId: model.NewId(),
@@ -136,7 +124,7 @@ func TestGetSessionIdleTimeoutInMinutes(t *testing.T) {
_, err = th.App.GetSession(session.Token)
assert.Nil(t, err)
- *utils.License().Features.Compliance = true
+ th.App.SetLicense(model.NewTestLicense("compliance"))
// Test regular session with timeout set to 0, should not timeout
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SessionIdleTimeoutInMinutes = 0 })
diff --git a/app/team.go b/app/team.go
index a15c64c3f..d8750bfbb 100644
--- a/app/team.go
+++ b/app/team.go
@@ -741,7 +741,7 @@ func (a *App) InviteNewUsersToTeam(emailList []string, teamId, senderId string)
}
nameFormat := *a.Config().TeamSettings.TeammateNameDisplay
- a.SendInviteEmails(team, user.GetDisplayName(nameFormat), emailList, utils.GetSiteURL())
+ a.SendInviteEmails(team, user.GetDisplayName(nameFormat), emailList, a.GetSiteURL())
return nil
}
diff --git a/app/team_test.go b/app/team_test.go
index a2bf44a57..cdfec12da 100644
--- a/app/team_test.go
+++ b/app/team_test.go
@@ -481,7 +481,6 @@ func TestJoinUserToTeam(t *testing.T) {
maxUsersPerTeam := th.App.Config().TeamSettings.MaxUsersPerTeam
defer func() {
th.App.UpdateConfig(func(cfg *model.Config) { cfg.TeamSettings.MaxUsersPerTeam = maxUsersPerTeam })
- th.App.SetDefaultRolesBasedOnConfig()
th.App.PermanentDeleteTeam(team)
}()
one := 1
diff --git a/app/user.go b/app/user.go
index 69c6d072b..f915f35cb 100644
--- a/app/user.go
+++ b/app/user.go
@@ -106,7 +106,7 @@ func (a *App) CreateUserWithInviteId(user *model.User, inviteId string) (*model.
a.AddDirectChannels(team.Id, ruser)
- if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
@@ -119,7 +119,7 @@ func (a *App) CreateUserAsAdmin(user *model.User) (*model.User, *model.AppError)
return nil, err
}
- if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
@@ -143,7 +143,7 @@ func (a *App) CreateUserFromSignup(user *model.User) (*model.User, *model.AppErr
return nil, err
}
- if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendWelcomeEmail(ruser.Id, ruser.Email, ruser.EmailVerified, ruser.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
@@ -1027,7 +1027,7 @@ func (a *App) UpdateUser(user *model.User, sendNotifications bool) (*model.User,
if sendNotifications {
if rusers[0].Email != rusers[1].Email {
a.Go(func() {
- if err := a.SendEmailChangeEmail(rusers[1].Email, rusers[0].Email, rusers[0].Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendEmailChangeEmail(rusers[1].Email, rusers[0].Email, rusers[0].Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -1041,7 +1041,7 @@ func (a *App) UpdateUser(user *model.User, sendNotifications bool) (*model.User,
if rusers[0].Username != rusers[1].Username {
a.Go(func() {
- if err := a.SendChangeUsernameEmail(rusers[1].Username, rusers[0].Username, rusers[0].Email, rusers[0].Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendChangeUsernameEmail(rusers[1].Username, rusers[0].Username, rusers[0].Email, rusers[0].Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -1091,7 +1091,7 @@ func (a *App) UpdateMfa(activate bool, userId, token string) *model.AppError {
return
}
- if err := a.SendMfaChangeEmail(user.Email, activate, user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendMfaChangeEmail(user.Email, activate, user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -1129,7 +1129,7 @@ func (a *App) UpdatePasswordSendEmail(user *model.User, newPassword, method stri
}
a.Go(func() {
- if err := a.SendPasswordChangeEmail(user.Email, method, user.Locale, utils.GetSiteURL()); err != nil {
+ if err := a.SendPasswordChangeEmail(user.Email, method, user.Locale, a.GetSiteURL()); err != nil {
l4g.Error(err.Error())
}
})
@@ -1342,9 +1342,9 @@ func (a *App) SendEmailVerification(user *model.User) *model.AppError {
}
if _, err := a.GetStatus(user.Id); err != nil {
- return a.SendVerifyEmail(user.Email, user.Locale, utils.GetSiteURL(), token.Token)
+ return a.SendVerifyEmail(user.Email, user.Locale, a.GetSiteURL(), token.Token)
} else {
- return a.SendEmailChangeVerifyEmail(user.Email, user.Locale, utils.GetSiteURL(), token.Token)
+ return a.SendEmailChangeVerifyEmail(user.Email, user.Locale, a.GetSiteURL(), token.Token)
}
}
diff --git a/app/web_hub.go b/app/web_hub.go
index eeae13e09..c1c8cb7bb 100644
--- a/app/web_hub.go
+++ b/app/web_hub.go
@@ -30,7 +30,6 @@ type Hub struct {
// See https://github.com/mattermost/mattermost-server/pull/7281
connectionCount int64
app *App
- connections []*WebConn
connectionIndex int
register chan *WebConn
unregister chan *WebConn
@@ -47,7 +46,6 @@ func (a *App) NewWebHub() *Hub {
app: a,
register: make(chan *WebConn, 1),
unregister: make(chan *WebConn, 1),
- connections: make([]*WebConn, 0, model.SESSION_CACHE_SIZE),
broadcast: make(chan *model.WebSocketEvent, BROADCAST_QUEUE_SIZE),
stop: make(chan struct{}),
didStop: make(chan struct{}),
@@ -170,8 +168,14 @@ func (a *App) Publish(message *model.WebSocketEvent) {
}
func (a *App) PublishSkipClusterSend(message *model.WebSocketEvent) {
- for _, hub := range a.Hubs {
- hub.Broadcast(message)
+ if message.Broadcast.UserId != "" {
+ if len(a.Hubs) != 0 {
+ a.GetHubForUserId(message.Broadcast.UserId).Broadcast(message)
+ }
+ } else {
+ for _, hub := range a.Hubs {
+ hub.Broadcast(message)
+ }
}
}
@@ -362,80 +366,53 @@ func (h *Hub) Start() {
var doRecover func()
doStart = func() {
-
h.goroutineId = getGoroutineId()
l4g.Debug("Hub for index %v is starting with goroutine %v", h.connectionIndex, h.goroutineId)
+ connections := newHubConnectionIndex()
+
for {
select {
case webCon := <-h.register:
- h.connections = append(h.connections, webCon)
- atomic.StoreInt64(&h.connectionCount, int64(len(h.connections)))
-
+ connections.Add(webCon)
+ atomic.StoreInt64(&h.connectionCount, int64(len(connections.All())))
case webCon := <-h.unregister:
- userId := webCon.UserId
-
- found := false
- indexToDel := -1
- for i, webConCandidate := range h.connections {
- if webConCandidate == webCon {
- indexToDel = i
- continue
- }
- if userId == webConCandidate.UserId {
- found = true
- if indexToDel != -1 {
- break
- }
- }
- }
+ connections.Remove(webCon)
- if indexToDel != -1 {
- // Delete the webcon we are unregistering
- h.connections[indexToDel] = h.connections[len(h.connections)-1]
- h.connections = h.connections[:len(h.connections)-1]
- }
-
- if len(userId) == 0 {
+ if len(webCon.UserId) == 0 {
continue
}
- if !found {
+ if len(connections.ForUser(webCon.UserId)) == 0 {
h.app.Go(func() {
- h.app.SetStatusOffline(userId, false)
+ h.app.SetStatusOffline(webCon.UserId, false)
})
}
-
case userId := <-h.invalidateUser:
- for _, webCon := range h.connections {
- if webCon.UserId == userId {
- webCon.InvalidateCache()
- }
+ for _, webCon := range connections.ForUser(userId) {
+ webCon.InvalidateCache()
}
-
case msg := <-h.broadcast:
- for _, webCon := range h.connections {
+ candidates := connections.All()
+ if msg.Broadcast.UserId != "" {
+ candidates = connections.ForUser(msg.Broadcast.UserId)
+ }
+ msg.PrecomputeJSON()
+ for _, webCon := range candidates {
if webCon.ShouldSendEvent(msg) {
select {
case webCon.Send <- msg:
default:
l4g.Error(fmt.Sprintf("webhub.broadcast: cannot send, closing websocket for userId=%v", webCon.UserId))
close(webCon.Send)
- for i, webConCandidate := range h.connections {
- if webConCandidate == webCon {
- h.connections[i] = h.connections[len(h.connections)-1]
- h.connections = h.connections[:len(h.connections)-1]
- break
- }
- }
+ connections.Remove(webCon)
}
}
}
-
case <-h.stop:
userIds := make(map[string]bool)
- for _, webCon := range h.connections {
+ for _, webCon := range connections.All() {
userIds[webCon.UserId] = true
webCon.Close()
}
@@ -444,7 +421,6 @@ func (h *Hub) Start() {
h.app.SetStatusOffline(userId, false)
}
- h.connections = make([]*WebConn, 0, model.SESSION_CACHE_SIZE)
h.ExplicitStop = true
close(h.didStop)
@@ -474,3 +450,60 @@ func (h *Hub) Start() {
go doRecoverableStart()
}
+
+type hubConnectionIndexIndexes struct {
+ connections int
+ connectionsByUserId int
+}
+
+// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
+type hubConnectionIndex struct {
+ connections []*WebConn
+ connectionsByUserId map[string][]*WebConn
+ connectionIndexes map[*WebConn]*hubConnectionIndexIndexes
+}
+
+func newHubConnectionIndex() *hubConnectionIndex {
+ return &hubConnectionIndex{
+ connections: make([]*WebConn, 0, model.SESSION_CACHE_SIZE),
+ connectionsByUserId: make(map[string][]*WebConn),
+ connectionIndexes: make(map[*WebConn]*hubConnectionIndexIndexes),
+ }
+}
+
+func (i *hubConnectionIndex) Add(wc *WebConn) {
+ i.connections = append(i.connections, wc)
+ i.connectionsByUserId[wc.UserId] = append(i.connectionsByUserId[wc.UserId], wc)
+ i.connectionIndexes[wc] = &hubConnectionIndexIndexes{
+ connections: len(i.connections) - 1,
+ connectionsByUserId: len(i.connectionsByUserId[wc.UserId]) - 1,
+ }
+}
+
+func (i *hubConnectionIndex) Remove(wc *WebConn) {
+ indexes, ok := i.connectionIndexes[wc]
+ if !ok {
+ return
+ }
+
+ last := i.connections[len(i.connections)-1]
+ i.connections[indexes.connections] = last
+ i.connections = i.connections[:len(i.connections)-1]
+ i.connectionIndexes[last].connections = indexes.connections
+
+ userConnections := i.connectionsByUserId[wc.UserId]
+ last = userConnections[len(userConnections)-1]
+ userConnections[indexes.connectionsByUserId] = last
+ i.connectionsByUserId[wc.UserId] = userConnections[:len(userConnections)-1]
+ i.connectionIndexes[last].connectionsByUserId = indexes.connectionsByUserId
+
+ delete(i.connectionIndexes, wc)
+}
+
+func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
+ return i.connectionsByUserId[id]
+}
+
+func (i *hubConnectionIndex) All() []*WebConn {
+ return i.connections
+}