diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | api4/user.go | 34 | ||||
-rw-r--r-- | api4/user_test.go | 87 | ||||
-rw-r--r-- | app/import.go | 10 | ||||
-rw-r--r-- | app/login.go | 25 | ||||
-rw-r--r-- | app/login_test.go | 37 | ||||
-rw-r--r-- | app/notification.go | 5 | ||||
-rw-r--r-- | app/slackimport.go | 20 | ||||
-rw-r--r-- | app/user.go | 11 | ||||
-rw-r--r-- | cmd/mattermost/commands/import.go | 11 | ||||
-rw-r--r-- | cmd/platform/main.go | 25 | ||||
-rw-r--r-- | config/default.json | 4 | ||||
-rw-r--r-- | i18n/en.json | 16 | ||||
-rw-r--r-- | model/client4.go | 36 | ||||
-rw-r--r-- | model/config.go | 20 | ||||
-rw-r--r-- | model/users_stats.go | 24 | ||||
-rw-r--r-- | store/sqlstore/upgrade.go | 10 | ||||
-rw-r--r-- | utils/config.go | 106 | ||||
-rw-r--r-- | utils/config_test.go | 284 |
19 files changed, 658 insertions, 109 deletions
@@ -127,7 +127,7 @@ start-docker: ## Starts the docker containers for local development. @if [ $(shell docker ps -a | grep -ci mattermost-minio) -eq 0 ]; then \ echo starting mattermost-minio; \ docker run --name mattermost-minio -p 9001:9000 -e "MINIO_ACCESS_KEY=minioaccesskey" \ - -e "MINIO_SECRET_KEY=miniosecretkey" -d minio/minio:latest server /data > /dev/null; \ + -e "MINIO_SECRET_KEY=miniosecretkey" -d minio/minio:RELEASE.2018-05-25T19-49-13Z server /data > /dev/null; \ docker exec -it mattermost-minio /bin/sh -c "mkdir -p /data/mattermost-test" > /dev/null; \ elif [ $(shell docker ps | grep -ci mattermost-minio) -eq 0 ]; then \ echo restarting mattermost-minio; \ diff --git a/api4/user.go b/api4/user.go index 2292544c4..2b79b19f1 100644 --- a/api4/user.go +++ b/api4/user.go @@ -22,6 +22,7 @@ func (api *API) InitUser() { api.BaseRoutes.Users.Handle("/usernames", api.ApiSessionRequired(getUsersByNames)).Methods("POST") api.BaseRoutes.Users.Handle("/search", api.ApiSessionRequired(searchUsers)).Methods("POST") api.BaseRoutes.Users.Handle("/autocomplete", api.ApiSessionRequired(autocompleteUsers)).Methods("GET") + api.BaseRoutes.Users.Handle("/stats", api.ApiSessionRequired(getTotalUsersStats)).Methods("GET") api.BaseRoutes.User.Handle("", api.ApiSessionRequired(getUser)).Methods("GET") api.BaseRoutes.User.Handle("/image", api.ApiSessionRequiredTrustRequester(getProfileImage)).Methods("GET") @@ -278,6 +279,20 @@ func setProfileImage(c *Context, w http.ResponseWriter, r *http.Request) { ReturnStatusOK(w) } +func getTotalUsersStats(c *Context, w http.ResponseWriter, r *http.Request) { + if c.Err != nil { + return + } + + if stats, err := c.App.GetTotalUsersStats(); err != nil { + c.Err = err + return + } else { + w.Write([]byte(stats.ToJson())) + return + } +} + func getUsers(c *Context, w http.ResponseWriter, r *http.Request) { inTeamId := r.URL.Query().Get("in_team") notInTeamId := r.URL.Query().Get("not_in_team") @@ -968,8 +983,27 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) { deviceId := props["device_id"] ldapOnly := props["ldap_only"] == "true" + if *c.App.Config().ExperimentalSettings.ClientSideCertEnable { + if license := c.App.License(); license == nil || !*license.Features.SAML { + c.Err = model.NewAppError("ClientSideCertNotAllowed", "Attempt to use the experimental feature ClientSideCertEnable without a valid enterprise license", nil, "", http.StatusBadRequest) + return + } else { + certPem, certSubject, certEmail := c.App.CheckForClienSideCert(r) + mlog.Debug("Client Cert", mlog.String("cert_subject", certSubject), mlog.String("cert_email", certEmail)) + + if len(certPem) == 0 || len(certEmail) == 0 { + c.Err = model.NewAppError("ClientSideCertMissing", "Attempted to sign in using the experimental feature ClientSideCert without providing a valid certificate", nil, "", http.StatusBadRequest) + return + } else if *c.App.Config().ExperimentalSettings.ClientSideCertCheck == model.CLIENT_SIDE_CERT_CHECK_PRIMARY_AUTH { + loginId = certEmail + password = "certificate" + } + } + } + c.LogAuditWithUserId(id, "attempt - login_id="+loginId) user, err := c.App.AuthenticateUserForLogin(id, loginId, password, mfaToken, ldapOnly) + if err != nil { c.LogAuditWithUserId(id, "failure - login_id="+loginId) c.Err = err diff --git a/api4/user_test.go b/api4/user_test.go index 1044e6162..96aa55d5f 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -909,6 +909,21 @@ func TestGetUsersByUsernames(t *testing.T) { CheckUnauthorizedStatus(t, resp) } +func TestGetTotalUsersStat(t *testing.T) { + th := Setup().InitBasic().InitSystemAdmin() + defer th.TearDown() + Client := th.Client + + total := <-th.App.Srv.Store.User().GetTotalUsersCount() + + rstats, resp := Client.GetTotalUsersStats("") + CheckNoError(t, resp) + + if rstats.TotalUsersCount != total.Data.(int64) { + t.Fatal("wrong count") + } +} + func TestUpdateUser(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() defer th.TearDown() @@ -1837,30 +1852,23 @@ func TestUpdateUserPassword(t *testing.T) { /*func TestResetPassword(t *testing.T) { th := Setup().InitBasic() Client := th.Client - Client.Logout() - user := th.BasicUser - // Delete all the messages before check the reset password utils.DeleteMailBox(user.Email) - success, resp := Client.SendPasswordResetEmail(user.Email) CheckNoError(t, resp) if !success { t.Fatal("should have succeeded") } - _, resp = Client.SendPasswordResetEmail("") CheckBadRequestStatus(t, resp) - // Should not leak whether the email is attached to an account or not success, resp = Client.SendPasswordResetEmail("notreal@example.com") CheckNoError(t, resp) if !success { t.Fatal("should have succeeded") } - // Check if the email was send to the right email address and the recovery key match var resultsMailbox utils.JSONMessageHeaderInbucket err := utils.RetryInbucket(5, func() error { @@ -1872,7 +1880,6 @@ func TestUpdateUserPassword(t *testing.T) { t.Log(err) t.Log("No email was received, maybe due load on the server. Disabling this verification") } - var recoveryTokenString string if err == nil && len(resultsMailbox) > 0 { if !strings.ContainsAny(resultsMailbox[0].To[0], user.Email) { @@ -1889,7 +1896,6 @@ func TestUpdateUserPassword(t *testing.T) { } } } - var recoveryToken *model.Token if result := <-th.App.Srv.Store.Token().GetByToken(recoveryTokenString); result.Err != nil { t.Log(recoveryTokenString) @@ -1897,44 +1903,33 @@ func TestUpdateUserPassword(t *testing.T) { } else { recoveryToken = result.Data.(*model.Token) } - _, resp = Client.ResetPassword(recoveryToken.Token, "") CheckBadRequestStatus(t, resp) - _, resp = Client.ResetPassword(recoveryToken.Token, "newp") CheckBadRequestStatus(t, resp) - _, resp = Client.ResetPassword("", "newpwd") CheckBadRequestStatus(t, resp) - _, resp = Client.ResetPassword("junk", "newpwd") CheckBadRequestStatus(t, resp) - code := "" for i := 0; i < model.TOKEN_SIZE; i++ { code += "a" } - _, resp = Client.ResetPassword(code, "newpwd") CheckBadRequestStatus(t, resp) - success, resp = Client.ResetPassword(recoveryToken.Token, "newpwd") CheckNoError(t, resp) if !success { t.Fatal("should have succeeded") } - Client.Login(user.Email, "newpwd") Client.Logout() - _, resp = Client.ResetPassword(recoveryToken.Token, "newpwd") CheckBadRequestStatus(t, resp) - authData := model.NewId() if result := <-app.Srv.Store.User().UpdateAuthData(user.Id, "random", &authData, "", true); result.Err != nil { t.Fatal(result.Err) } - _, resp = Client.SendPasswordResetEmail(user.Email) CheckBadRequestStatus(t, resp) }*/ @@ -2240,6 +2235,58 @@ func TestSetProfileImage(t *testing.T) { t.Fatal(err) } } +func TestCBALogin(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + Client := th.Client + Client.Logout() + + th.App.SetLicense(model.NewTestLicense("saml")) + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ExperimentalSettings.ClientSideCertEnable = true + *cfg.ExperimentalSettings.ClientSideCertCheck = model.CLIENT_SIDE_CERT_CHECK_PRIMARY_AUTH + }) + + user, resp := Client.Login(th.BasicUser.Email, th.BasicUser.Password) + if resp.Error.StatusCode != 400 && user == nil { + t.Fatal("Should have failed because it's missing the cert header") + } + + Client.HttpHeader["X-SSL-Client-Cert"] = "valid_cert_fake" + user, resp = Client.Login(th.BasicUser.Email, th.BasicUser.Password) + if resp.Error.StatusCode != 400 && user == nil { + t.Fatal("Should have failed because it's missing the cert subject") + } + + Client.HttpHeader["X-SSL-Client-Cert-Subject-DN"] = "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/emailAddress=mis_match" + th.BasicUser.Email + user, resp = Client.Login(th.BasicUser.Email, "") + if resp.Error.StatusCode != 400 && user == nil { + t.Fatal("Should have failed because the emails mismatch") + } + + Client.HttpHeader["X-SSL-Client-Cert-Subject-DN"] = "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/emailAddress=" + th.BasicUser.Email + user, resp = Client.Login(th.BasicUser.Email, "") + if !(user != nil && user.Email == th.BasicUser.Email) { + t.Fatal("Should have been able to login") + } + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ExperimentalSettings.ClientSideCertEnable = true + *cfg.ExperimentalSettings.ClientSideCertCheck = model.CLIENT_SIDE_CERT_CHECK_SECONDARY_AUTH + }) + + Client.HttpHeader["X-SSL-Client-Cert-Subject-DN"] = "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/emailAddress=" + th.BasicUser.Email + user, resp = Client.Login(th.BasicUser.Email, "") + if resp.Error.StatusCode != 400 && user == nil { + t.Fatal("Should have failed because password is required") + } + + Client.HttpHeader["X-SSL-Client-Cert-Subject-DN"] = "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/emailAddress=" + th.BasicUser.Email + user, resp = Client.Login(th.BasicUser.Email, th.BasicUser.Password) + if !(user != nil && user.Email == th.BasicUser.Email) { + t.Fatal("Should have been able to login") + } +} func TestSwitchAccount(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() diff --git a/app/import.go b/app/import.go index 8075497a0..5364b1026 100644 --- a/app/import.go +++ b/app/import.go @@ -1699,10 +1699,12 @@ func (a *App) OldImportFile(timestamp time.Time, file io.Reader, teamId string, return nil, err } - img, width, height := prepareImage(data) - if img != nil { - a.generateThumbnailImage(*img, fileInfo.ThumbnailPath, width, height) - a.generatePreviewImage(*img, fileInfo.PreviewPath, width) + if fileInfo.IsImage() && fileInfo.MimeType != "image/svg+xml" { + img, width, height := prepareImage(data) + if img != nil { + a.generateThumbnailImage(*img, fileInfo.ThumbnailPath, width, height) + a.generatePreviewImage(*img, fileInfo.PreviewPath, width) + } } return fileInfo, nil diff --git a/app/login.go b/app/login.go index 3001e1f4d..d3d2a423e 100644 --- a/app/login.go +++ b/app/login.go @@ -6,6 +6,7 @@ package app import ( "fmt" "net/http" + "strings" "time" "github.com/avct/uasurfer" @@ -13,6 +14,23 @@ import ( "github.com/mattermost/mattermost-server/store" ) +func (a *App) CheckForClienSideCert(r *http.Request) (string, string, string) { + pem := r.Header.Get("X-SSL-Client-Cert") // mapped to $ssl_client_cert from nginx + subject := r.Header.Get("X-SSL-Client-Cert-Subject-DN") // mapped to $ssl_client_s_dn from nginx + email := "" + + if len(subject) > 0 { + for _, v := range strings.Split(subject, "/") { + kv := strings.Split(v, "=") + if len(kv) == 2 && kv[0] == "emailAddress" { + email = kv[1] + } + } + } + + return pem, subject, email +} + func (a *App) AuthenticateUserForLogin(id, loginId, password, mfaToken string, ldapOnly bool) (user *model.User, err *model.AppError) { // Do statistics defer func() { @@ -35,6 +53,13 @@ func (a *App) AuthenticateUserForLogin(id, loginId, password, mfaToken string, l return nil, err } + // If client side cert is enable and it's checking as a primary source + // then trust the proxy and cert that the correct user is supplied and allow + // them access + if *a.Config().ExperimentalSettings.ClientSideCertEnable && *a.Config().ExperimentalSettings.ClientSideCertCheck == model.CLIENT_SIDE_CERT_CHECK_PRIMARY_AUTH { + return user, nil + } + // and then authenticate them if user, err = a.authenticateUser(user, password, mfaToken); err != nil { return nil, err diff --git a/app/login_test.go b/app/login_test.go new file mode 100644 index 000000000..db92f1d7d --- /dev/null +++ b/app/login_test.go @@ -0,0 +1,37 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "net/http" + "testing" +) + +func TestCheckForClienSideCert(t *testing.T) { + th := Setup() + defer th.TearDown() + + var tests = []struct { + pem string + subject string + expectedEmail string + }{ + {"blah", "blah", ""}, + {"blah", "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/emailAddress=test@test.com", "test@test.com"}, + {"blah", "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, CN=www.freesoft.org/EmailAddress=test@test.com", ""}, + {"blah", "CN=www.freesoft.org/EmailAddress=test@test.com, C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft", ""}, + } + + for _, tt := range tests { + r := &http.Request{Header: http.Header{}} + r.Header.Add("X-SSL-Client-Cert", tt.pem) + r.Header.Add("X-SSL-Client-Cert-Subject-DN", tt.subject) + + _, _, actualEmail := th.App.CheckForClienSideCert(r) + + if actualEmail != tt.expectedEmail { + t.Fatalf("CheckForClienSideCert(%v): expected %v, actual %v", tt.subject, tt.expectedEmail, actualEmail) + } + } +} diff --git a/app/notification.go b/app/notification.go index a3c1857d5..dbd37c7f2 100644 --- a/app/notification.go +++ b/app/notification.go @@ -752,6 +752,11 @@ func (a *App) sendPushNotification(post *model.Post, user *model.User, channel * msg.Message = a.getPushNotificationMessage(post.Message, explicitMention, channelWideMention, hasFiles, senderName, channelName, channel.Type, replyToThreadType, userLocale) for _, session := range sessions { + + if session.IsExpired() { + continue + } + tmpMessage := *model.PushNotificationFromJson(strings.NewReader(msg.ToJson())) tmpMessage.SetDeviceIdAndPlatform(session.DeviceId) diff --git a/app/slackimport.go b/app/slackimport.go index 3333af604..f9e2ac4ab 100644 --- a/app/slackimport.go +++ b/app/slackimport.go @@ -157,7 +157,7 @@ func (a *App) SlackAddUsers(teamId string, slackusers []SlackUser, importerLog * if email == "" { email = sUser.Username + "@example.com" importerLog.WriteString(utils.T("api.slackimport.slack_add_users.missing_email_address", map[string]interface{}{"Email": email, "Username": sUser.Username})) - mlog.Warn("Slack Import: User {{.Username}} does not have an email address in the Slack export. Used {{.Email}} as a placeholder. The user should update their email address once logged in to the system.") + mlog.Warn(fmt.Sprintf("Slack Import: User %v does not have an email address in the Slack export. Used %v as a placeholder. The user should update their email address once logged in to the system.", email, sUser.Username)) } password := model.NewId() @@ -396,7 +396,7 @@ func (a *App) SlackUploadFile(sPost SlackPost, uploads map[string]*zip.File, tea if file, ok := uploads[sPost.File.Id]; ok { openFile, err := file.Open() if err != nil { - mlog.Warn("Slack Import: Unable to open the file {{.FileId}} from the Slack export: {{.Error}}.") + mlog.Warn(fmt.Sprintf("Slack Import: Unable to open the file %v from the Slack export: %v.", sPost.File.Id, err.Error())) return nil, false } defer openFile.Close() @@ -404,13 +404,13 @@ func (a *App) SlackUploadFile(sPost SlackPost, uploads map[string]*zip.File, tea timestamp := utils.TimeFromMillis(SlackConvertTimeStamp(sPost.TimeStamp)) uploadedFile, err := a.OldImportFile(timestamp, openFile, teamId, channelId, userId, filepath.Base(file.Name)) if err != nil { - mlog.Warn("Slack Import: An error occurred when uploading file {{.FileId}}: {{.Error}}.") + mlog.Warn(fmt.Sprintf("Slack Import: An error occurred when uploading file %v: %v.", sPost.File.Id, err.Error())) return nil, false } return uploadedFile, true } else { - mlog.Warn("Slack Import: Unable to import file {{.FileId}} as the file is missing from the Slack export zip file.") + mlog.Warn(fmt.Sprintf("Slack Import: Unable to import file %v as the file is missing from the Slack export zip file.", sPost.File.Id)) return nil, false } } else { @@ -440,22 +440,22 @@ func (a *App) addSlackUsersToChannel(members []string, users map[string]*model.U func SlackSanitiseChannelProperties(channel model.Channel) model.Channel { if utf8.RuneCountInString(channel.DisplayName) > model.CHANNEL_DISPLAY_NAME_MAX_RUNES { - mlog.Warn(fmt.Sprint("api.slackimport.slack_sanitise_channel_properties.display_name_too_long.warn", map[string]interface{}{"ChannelName": channel.DisplayName})) + mlog.Warn(fmt.Sprintf("Slack Import: Channel %v display name exceeds the maximum length. It will be truncated when imported.", channel.DisplayName)) channel.DisplayName = truncateRunes(channel.DisplayName, model.CHANNEL_DISPLAY_NAME_MAX_RUNES) } if len(channel.Name) > model.CHANNEL_NAME_MAX_LENGTH { - mlog.Warn(fmt.Sprint("api.slackimport.slack_sanitise_channel_properties.name_too_long.warn", map[string]interface{}{"ChannelName": channel.DisplayName})) + mlog.Warn(fmt.Sprintf("Slack Import: Channel %v handle exceeds the maximum length. It will be truncated when imported.", channel.DisplayName)) channel.Name = channel.Name[0:model.CHANNEL_NAME_MAX_LENGTH] } if utf8.RuneCountInString(channel.Purpose) > model.CHANNEL_PURPOSE_MAX_RUNES { - mlog.Warn(fmt.Sprint("api.slackimport.slack_sanitise_channel_properties.purpose_too_long.warn", map[string]interface{}{"ChannelName": channel.DisplayName})) + mlog.Warn(fmt.Sprintf("Slack Import: Channel %v purpose exceeds the maximum length. It will be truncated when imported.", channel.DisplayName)) channel.Purpose = truncateRunes(channel.Purpose, model.CHANNEL_PURPOSE_MAX_RUNES) } if utf8.RuneCountInString(channel.Header) > model.CHANNEL_HEADER_MAX_RUNES { - mlog.Warn(fmt.Sprint("api.slackimport.slack_sanitise_channel_properties.header_too_long.warn", map[string]interface{}{"ChannelName": channel.DisplayName})) + mlog.Warn(fmt.Sprintf("Slack Import: Channel %v header exceeds the maximum length. It will be truncated when imported.", channel.DisplayName)) channel.Header = truncateRunes(channel.Header, model.CHANNEL_HEADER_MAX_RUNES) } @@ -514,7 +514,7 @@ func SlackConvertUserMentions(users []SlackUser, posts map[string][]SlackPost) m for _, user := range users { r, err := regexp.Compile("<@" + user.Id + `(\|` + user.Username + ")?>") if err != nil { - mlog.Warn(fmt.Sprint("Slack Import: Unable to compile the @mention, matching regular expression for the Slack user {{.Username}} (id={{.UserID}}).", user.Id, user.Username), mlog.String("user_id", user.Id)) + mlog.Warn(fmt.Sprintf("Slack Import: Unable to compile the @mention, matching regular expression for the Slack user %v (id=%v).", user.Id, user.Username), mlog.String("user_id", user.Id)) continue } regexes["@"+user.Username] = r @@ -542,7 +542,7 @@ func SlackConvertChannelMentions(channels []SlackChannel, posts map[string][]Sla for _, channel := range channels { r, err := regexp.Compile("<#" + channel.Id + `(\|` + channel.Name + ")?>") if err != nil { - mlog.Warn(fmt.Sprint("Slack Import: Unable to compile the !channel, matching regular expression for the Slack channel {{.ChannelName}} (id={{.ChannelID}}).", channel.Id, channel.Name)) + mlog.Warn(fmt.Sprintf("Slack Import: Unable to compile the !channel, matching regular expression for the Slack channel %v (id=%v).", channel.Id, channel.Name)) continue } regexes["~"+channel.Name] = r diff --git a/app/user.go b/app/user.go index ccf8dd40e..27e6f347d 100644 --- a/app/user.go +++ b/app/user.go @@ -1373,6 +1373,17 @@ func (a *App) GetVerifyEmailToken(token string) (*model.Token, *model.AppError) } } +func (a *App) GetTotalUsersStats() (*model.UsersStats, *model.AppError) { + stats := &model.UsersStats{} + + if result := <-a.Srv.Store.User().GetTotalUsersCount(); result.Err != nil { + return nil, result.Err + } else { + stats.TotalUsersCount = result.Data.(int64) + } + return stats, nil +} + func (a *App) VerifyUserEmail(userId string) *model.AppError { return (<-a.Srv.Store.User().VerifyEmail(userId)).Err } diff --git a/cmd/mattermost/commands/import.go b/cmd/mattermost/commands/import.go index 91cfaf997..8526ba6f8 100644 --- a/cmd/mattermost/commands/import.go +++ b/cmd/mattermost/commands/import.go @@ -74,9 +74,18 @@ func slackImportCmdF(command *cobra.Command, args []string) error { CommandPrettyPrintln("Running Slack Import. This may take a long time for large teams or teams with many messages.") - a.SlackImport(fileReader, fileInfo.Size(), team.Id) + importErr, log := a.SlackImport(fileReader, fileInfo.Size(), team.Id) + + if importErr != nil { + return err + } + + CommandPrettyPrintln("") + CommandPrintln(log.String()) + CommandPrettyPrintln("") CommandPrettyPrintln("Finished Slack Import.") + CommandPrettyPrintln("") return nil } diff --git a/cmd/platform/main.go b/cmd/platform/main.go index b5ea51920..25e091a84 100644 --- a/cmd/platform/main.go +++ b/cmd/platform/main.go @@ -6,19 +6,10 @@ package main import ( "fmt" "os" - "path/filepath" "syscall" -) -func findMattermostBinary() string { - for _, file := range []string{"./mattermost", "../mattermost", "./bin/mattermost"} { - path, _ := filepath.Abs(file) - if stat, err := os.Stat(path); err == nil && !stat.IsDir() { - return path - } - } - return "./mattermost" -} + "github.com/mattermost/mattermost-server/utils" +) func main() { // Print angry message to use mattermost command directly @@ -33,7 +24,15 @@ The platform binary will be removed in a future version. args := os.Args args[0] = "mattermost" args = append(args, "--platform") - if err := syscall.Exec(findMattermostBinary(), args, nil); err != nil { - fmt.Println("Could not start Mattermost, use the mattermost command directly.") + + realMattermost := utils.FindFile("mattermost") + if realMattermost == "" { + realMattermost = utils.FindFile("bin/mattermost") + } + + if realMattermost == "" { + fmt.Println("Could not start Mattermost, use the mattermost command directly: failed to find mattermost") + } else if err := syscall.Exec(realMattermost, args, nil); err != nil { + fmt.Printf("Could not start Mattermost, use the mattermost command directly: %s\n", err.Error()) } } diff --git a/config/default.json b/config/default.json index 67c1220bb..30c8f282f 100644 --- a/config/default.json +++ b/config/default.json @@ -337,6 +337,10 @@ "BlockProfileRate": 0, "ListenAddress": ":8067" }, + "ExperimentalSettings": { + "ClientSideCertEnable": false, + "ClientSideCertCheck": "secondary" + }, "AnalyticsSettings": { "MaxUsersForStatistics": 2500 }, diff --git a/i18n/en.json b/i18n/en.json index 4df698294..2f7aa47fc 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -2275,22 +2275,6 @@ "translation": "Slack Import: Error occurred when parsing some Slack posts. Import may work anyway." }, { - "id": "api.slackimport.slack_sanitise_channel_properties.display_name_too_long.warn", - "translation": "Slack Import: Channel {{.ChannelName}} display name exceeds the maximum length. It will be truncated when imported." - }, - { - "id": "api.slackimport.slack_sanitise_channel_properties.header_too_long.warn", - "translation": "Slack Import: Channel {{.ChannelName}} header exceeds the maximum length. It will be truncated when imported." - }, - { - "id": "api.slackimport.slack_sanitise_channel_properties.name_too_long.warn", - "translation": "Slack Import: Channel {{.ChannelName}} handle exceeds the maximum length. It will be truncated when imported." - }, - { - "id": "api.slackimport.slack_sanitise_channel_properties.purpose_too_long.warn", - "translation": "Slack Import: Channel {{.ChannelName}} purpose exceeds the maximum length. It will be truncated when imported." - }, - { "id": "api.status.init.debug", "translation": "Initializing status API routes" }, diff --git a/model/client4.go b/model/client4.go index fb4d1375c..f5a856835 100644 --- a/model/client4.go +++ b/model/client4.go @@ -57,6 +57,7 @@ type Client4 struct { HttpClient *http.Client // The http client AuthToken string AuthType string + HttpHeader map[string]string // Headers to be copied over for each request } func closeBody(r *http.Response) { @@ -78,7 +79,7 @@ func (c *Client4) Must(result interface{}, resp *Response) interface{} { } func NewAPIv4Client(url string) *Client4 { - return &Client4{url, url + API_URL_SUFFIX, &http.Client{}, "", ""} + return &Client4{url, url + API_URL_SUFFIX, &http.Client{}, "", "", map[string]string{}} } func BuildErrorResponse(r *http.Response, err *AppError) *Response { @@ -392,6 +393,10 @@ func (c *Client4) GetTeamSchemeRoute(teamId string) string { return fmt.Sprintf(c.GetTeamsRoute()+"/%v/scheme", teamId) } +func (c *Client4) GetTotalUsersStatsRoute() string { + return fmt.Sprintf(c.GetUsersRoute() + "/stats") +} + func (c *Client4) DoApiGet(url string, etag string) (*http.Response, *AppError) { return c.DoApiRequest(http.MethodGet, c.ApiUrl+url, "", etag) } @@ -410,7 +415,6 @@ func (c *Client4) DoApiDelete(url string) (*http.Response, *AppError) { func (c *Client4) DoApiRequest(method, url, data, etag string) (*http.Response, *AppError) { rq, _ := http.NewRequest(method, url, strings.NewReader(data)) - rq.Close = true if len(etag) > 0 { rq.Header.Set(HEADER_ETAG_CLIENT, etag) @@ -420,6 +424,13 @@ func (c *Client4) DoApiRequest(method, url, data, etag string) (*http.Response, rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) } + if c.HttpHeader != nil && len(c.HttpHeader) > 0 { + + for k, v := range c.HttpHeader { + rq.Header.Set(k, v) + } + } + if rp, err := c.HttpClient.Do(rq); err != nil || rp == nil { return nil, NewAppError(url, "model.client.connecting.app_error", nil, err.Error(), 0) } else if rp.StatusCode == 304 { @@ -435,7 +446,6 @@ func (c *Client4) DoApiRequest(method, url, data, etag string) (*http.Response, func (c *Client4) DoUploadFile(url string, data []byte, contentType string) (*FileUploadResponse, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+url, bytes.NewReader(data)) rq.Header.Set("Content-Type", contentType) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -457,7 +467,6 @@ func (c *Client4) DoUploadFile(url string, data []byte, contentType string) (*Fi func (c *Client4) DoEmojiUploadFile(url string, data []byte, contentType string) (*Emoji, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+url, bytes.NewReader(data)) rq.Header.Set("Content-Type", contentType) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -479,7 +488,6 @@ func (c *Client4) DoEmojiUploadFile(url string, data []byte, contentType string) func (c *Client4) DoUploadImportTeam(url string, data []byte, contentType string) (map[string]string, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+url, bytes.NewReader(data)) rq.Header.Set("Content-Type", contentType) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -1102,7 +1110,6 @@ func (c *Client4) SetProfileImage(userId string, data []byte) (bool, *Response) rq, _ := http.NewRequest("POST", c.ApiUrl+c.GetUserRoute(userId)+"/image", bytes.NewReader(body.Bytes())) rq.Header.Set("Content-Type", writer.FormDataContentType()) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -1473,6 +1480,17 @@ func (c *Client4) GetTeamStats(teamId, etag string) (*TeamStats, *Response) { } } +// GetTotalUsersStats returns a total system user stats. +// Must be authenticated. +func (c *Client4) GetTotalUsersStats(etag string) (*UsersStats, *Response) { + if r, err := c.DoApiGet(c.GetTotalUsersStatsRoute(), etag); err != nil { + return nil, BuildErrorResponse(r, err) + } else { + defer closeBody(r) + return UsersStatsFromJson(r.Body), BuildResponse(r) + } +} + // GetTeamUnread will return a TeamUnread object that contains the amount of // unread messages and mentions the user has for the specified team. // Must be authenticated. @@ -1553,7 +1571,6 @@ func (c *Client4) SetTeamIcon(teamId string, data []byte) (bool, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+c.GetTeamRoute(teamId)+"/image", bytes.NewReader(body.Bytes())) rq.Header.Set("Content-Type", writer.FormDataContentType()) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -2410,7 +2427,6 @@ func (c *Client4) UploadLicenseFile(data []byte) (bool, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+c.GetLicenseRoute(), bytes.NewReader(body.Bytes())) rq.Header.Set("Content-Type", writer.FormDataContentType()) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -2798,7 +2814,6 @@ func (c *Client4) GetComplianceReport(reportId string) (*Compliance, *Response) func (c *Client4) DownloadComplianceReport(reportId string) ([]byte, *Response) { var rq *http.Request rq, _ = http.NewRequest("GET", c.ApiUrl+c.GetComplianceReportRoute(reportId), nil) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, "BEARER "+c.AuthToken) @@ -2903,7 +2918,6 @@ func (c *Client4) UploadBrandImage(data []byte) (bool, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+c.GetBrandRoute()+"/image", bytes.NewReader(body.Bytes())) rq.Header.Set("Content-Type", writer.FormDataContentType()) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -3056,7 +3070,6 @@ func (c *Client4) DeauthorizeOAuthApp(appId string) (bool, *Response) { func (c *Client4) GetOAuthAccessToken(data url.Values) (*AccessResponse, *Response) { rq, _ := http.NewRequest(http.MethodPost, c.Url+"/oauth/access_token", strings.NewReader(data.Encode())) rq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) @@ -3612,7 +3625,6 @@ func (c *Client4) UploadPlugin(file io.Reader) (*Manifest, *Response) { rq, _ := http.NewRequest("POST", c.ApiUrl+c.GetPluginsRoute(), body) rq.Header.Set("Content-Type", writer.FormDataContentType()) - rq.Close = true if len(c.AuthToken) > 0 { rq.Header.Set(HEADER_AUTH, c.AuthType+" "+c.AuthToken) diff --git a/model/config.go b/model/config.go index f2bebf03b..47e2f68a4 100644 --- a/model/config.go +++ b/model/config.go @@ -160,6 +160,9 @@ const ( COMPLIANCE_EXPORT_TYPE_GLOBALRELAY = "globalrelay" GLOBALRELAY_CUSTOMER_TYPE_A9 = "A9" GLOBALRELAY_CUSTOMER_TYPE_A10 = "A10" + + CLIENT_SIDE_CERT_CHECK_PRIMARY_AUTH = "primary" + CLIENT_SIDE_CERT_CHECK_SECONDARY_AUTH = "secondary" ) type ServiceSettings struct { @@ -545,6 +548,21 @@ func (s *MetricsSettings) SetDefaults() { } } +type ExperimentalSettings struct { + ClientSideCertEnable *bool + ClientSideCertCheck *string +} + +func (s *ExperimentalSettings) SetDefaults() { + if s.ClientSideCertEnable == nil { + s.ClientSideCertEnable = NewBool(false) + } + + if s.ClientSideCertCheck == nil { + s.ClientSideCertCheck = NewString(CLIENT_SIDE_CERT_CHECK_SECONDARY_AUTH) + } +} + type AnalyticsSettings struct { MaxUsersForStatistics *int } @@ -1829,6 +1847,7 @@ type Config struct { NativeAppSettings NativeAppSettings ClusterSettings ClusterSettings MetricsSettings MetricsSettings + ExperimentalSettings ExperimentalSettings AnalyticsSettings AnalyticsSettings WebrtcSettings WebrtcSettings ElasticsearchSettings ElasticsearchSettings @@ -1891,6 +1910,7 @@ func (o *Config) SetDefaults() { o.PasswordSettings.SetDefaults() o.TeamSettings.SetDefaults() o.MetricsSettings.SetDefaults() + o.ExperimentalSettings.SetDefaults() o.SupportSettings.SetDefaults() o.AnnouncementSettings.SetDefaults() o.ThemeSettings.SetDefaults() diff --git a/model/users_stats.go b/model/users_stats.go new file mode 100644 index 000000000..49c882e34 --- /dev/null +++ b/model/users_stats.go @@ -0,0 +1,24 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "encoding/json" + "io" +) + +type UsersStats struct { + TotalUsersCount int64 `json:"total_users_count"` +} + +func (o *UsersStats) ToJson() string { + b, _ := json.Marshal(o) + return string(b) +} + +func UsersStatsFromJson(data io.Reader) *UsersStats { + var o *UsersStats + json.NewDecoder(data).Decode(&o) + return o +} diff --git a/store/sqlstore/upgrade.go b/store/sqlstore/upgrade.go index 8356ef17f..65c2c11e2 100644 --- a/store/sqlstore/upgrade.go +++ b/store/sqlstore/upgrade.go @@ -15,6 +15,7 @@ import ( ) const ( + VERSION_5_1_0 = "5.1.0" VERSION_5_0_0 = "5.0.0" VERSION_4_10_0 = "4.10.0" VERSION_4_9_0 = "4.9.0" @@ -78,6 +79,7 @@ func UpgradeDatabase(sqlStore SqlStore) { UpgradeDatabaseToVersion49(sqlStore) UpgradeDatabaseToVersion410(sqlStore) UpgradeDatabaseToVersion50(sqlStore) + UpgradeDatabaseToVersion51(sqlStore) // If the SchemaVersion is empty this this is the first time it has ran // so lets set it to the current version. @@ -447,3 +449,11 @@ func UpgradeDatabaseToVersion50(sqlStore SqlStore) { saveSchemaVersion(sqlStore, VERSION_5_0_0) } } + +func UpgradeDatabaseToVersion51(sqlStore SqlStore) { + // TODO: Uncomment following condition when version 5.1.0 is released + // if shouldPerformUpgrade(sqlStore, VERSION_5_0_0, VERSION_5_1_0) { + + // saveSchemaVersion(sqlStore, VERSION_5_1_0) + // } +} diff --git a/utils/config.go b/utils/config.go index 64085fcff..80673cba6 100644 --- a/utils/config.go +++ b/utils/config.go @@ -32,36 +32,96 @@ const ( LOG_FILENAME = "mattermost.log" ) -// FindConfigFile attempts to find an existing configuration file. fileName can be an absolute or -// relative path or name such as "/opt/mattermost/config.json" or simply "config.json". An empty -// string is returned if no configuration is found. -func FindConfigFile(fileName string) (path string) { - if filepath.IsAbs(fileName) { - if _, err := os.Stat(fileName); err == nil { - return fileName +var ( + commonBaseSearchPaths = []string{ + ".", + "..", + "../..", + "../../..", + } +) + +func FindPath(path string, baseSearchPaths []string, filter func(os.FileInfo) bool) string { + if filepath.IsAbs(path) { + if _, err := os.Stat(path); err == nil { + return path } - } else { - for _, dir := range []string{"./config", "../config", "../../config", "../../../config", "."} { - path, _ := filepath.Abs(filepath.Join(dir, fileName)) - if _, err := os.Stat(path); err == nil { - return path + + return "" + } + + searchPaths := []string{} + for _, baseSearchPath := range baseSearchPaths { + searchPaths = append(searchPaths, baseSearchPath) + } + + // Additionally attempt to search relative to the location of the running binary. + var binaryDir string + if exe, err := os.Executable(); err == nil { + if exe, err = filepath.EvalSymlinks(exe); err == nil { + if exe, err = filepath.Abs(exe); err == nil { + binaryDir = filepath.Dir(exe) } } } - return "" -} + if binaryDir != "" { + for _, baseSearchPath := range baseSearchPaths { + searchPaths = append( + searchPaths, + filepath.Join(binaryDir, baseSearchPath), + ) + } + } -// FindDir looks for the given directory in nearby ancestors, falling back to `./` if not found. -func FindDir(dir string) (string, bool) { - for _, parent := range []string{".", "..", "../..", "../../.."} { - foundDir, err := filepath.Abs(filepath.Join(parent, dir)) + for _, parent := range searchPaths { + found, err := filepath.Abs(filepath.Join(parent, path)) if err != nil { continue - } else if _, err := os.Stat(foundDir); err == nil { - return foundDir, true + } else if fileInfo, err := os.Stat(found); err == nil { + if filter != nil { + if filter(fileInfo) { + return found + } + } else { + return found + } } } - return "./", false + + return "" +} + +// FindConfigFile attempts to find an existing configuration file. fileName can be an absolute or +// relative path or name such as "/opt/mattermost/config.json" or simply "config.json". An empty +// string is returned if no configuration is found. +func FindConfigFile(fileName string) (path string) { + found := FindFile(filepath.Join("config", fileName)) + if found == "" { + found = FindPath(fileName, []string{"."}, nil) + } + + return found +} + +// FindFile looks for the given file in nearby ancestors relative to the current working +// directory as well as the directory of the executable. +func FindFile(path string) string { + return FindPath(path, commonBaseSearchPaths, func(fileInfo os.FileInfo) bool { + return !fileInfo.IsDir() + }) +} + +// FindDir looks for the given directory in nearby ancestors relative to the current working +// directory as well as the directory of the executable, falling back to `./` if not found. +func FindDir(dir string) (string, bool) { + found := FindPath(dir, commonBaseSearchPaths, func(fileInfo os.FileInfo) bool { + return fileInfo.IsDir() + }) + if found == "" { + return "./", false + } + + return found, true } func MloggerConfigFromLoggerConfig(s *model.LogSettings) *mlog.LoggerConfiguration { @@ -652,6 +712,10 @@ func GenerateClientConfig(c *model.Config, diagnosticId string, license *model.L props["SamlLoginButtonColor"] = *c.SamlSettings.LoginButtonColor props["SamlLoginButtonBorderColor"] = *c.SamlSettings.LoginButtonBorderColor props["SamlLoginButtonTextColor"] = *c.SamlSettings.LoginButtonTextColor + + // do this under the correct licensed feature + props["ExperimentalClientSideCertEnable"] = strconv.FormatBool(*c.ExperimentalSettings.ClientSideCertEnable) + props["ExperimentalClientSideCertCheck"] = *c.ExperimentalSettings.ClientSideCertCheck } if *license.Features.Cluster { diff --git a/utils/config_test.go b/utils/config_test.go index 75bbc420f..63b283584 100644 --- a/utils/config_test.go +++ b/utils/config_test.go @@ -5,6 +5,7 @@ package utils import ( "bytes" + "fmt" "io/ioutil" "os" "path/filepath" @@ -46,20 +47,281 @@ func TestTimezoneConfig(t *testing.T) { } func TestFindConfigFile(t *testing.T) { - dir, err := ioutil.TempDir("", "") - require.NoError(t, err) - defer os.RemoveAll(dir) + t.Run("config.json in current working directory, not inside config/", func(t *testing.T) { + // Force a unique working directory + cwd, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(cwd) + + prevDir, err := os.Getwd() + require.NoError(t, err) + defer os.Chdir(prevDir) + os.Chdir(cwd) + + configJson, err := filepath.Abs("config.json") + require.NoError(t, err) + require.NoError(t, ioutil.WriteFile(configJson, []byte("{}"), 0600)) + + // Relative paths end up getting symlinks fully resolved. + configJsonResolved, err := filepath.EvalSymlinks(configJson) + require.NoError(t, err) + + assert.Equal(t, configJsonResolved, FindConfigFile("config.json")) + }) - path := filepath.Join(dir, "config.json") - require.NoError(t, ioutil.WriteFile(path, []byte("{}"), 0600)) + t.Run("config/config.json from various paths", func(t *testing.T) { + // Create the following directory structure: + // tmpDir1/ + // config/ + // config.json + // tmpDir2/ + // tmpDir3/ + // tmpDir4/ + // tmpDir5/ + tmpDir1, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(tmpDir1) + + err = os.Mkdir(filepath.Join(tmpDir1, "config"), 0700) + require.NoError(t, err) + + tmpDir2, err := ioutil.TempDir(tmpDir1, "") + require.NoError(t, err) + + tmpDir3, err := ioutil.TempDir(tmpDir2, "") + require.NoError(t, err) + + tmpDir4, err := ioutil.TempDir(tmpDir3, "") + require.NoError(t, err) + + tmpDir5, err := ioutil.TempDir(tmpDir4, "") + require.NoError(t, err) + + configJson := filepath.Join(tmpDir1, "config", "config.json") + require.NoError(t, ioutil.WriteFile(configJson, []byte("{}"), 0600)) + + // Relative paths end up getting symlinks fully resolved, so use this below as necessary. + configJsonResolved, err := filepath.EvalSymlinks(configJson) + require.NoError(t, err) + + testCases := []struct { + Description string + Cwd *string + FileName string + Expected string + }{ + { + "absolute path to config.json", + nil, + configJson, + configJson, + }, + { + "absolute path to config.json from directory containing config.json", + &tmpDir1, + configJson, + configJson, + }, + { + "relative path to config.json from directory containing config.json", + &tmpDir1, + "config.json", + configJsonResolved, + }, + { + "subdirectory of directory containing config.json", + &tmpDir2, + "config.json", + configJsonResolved, + }, + { + "twice-nested subdirectory of directory containing config.json", + &tmpDir3, + "config.json", + configJsonResolved, + }, + { + "thrice-nested subdirectory of directory containing config.json", + &tmpDir4, + "config.json", + configJsonResolved, + }, + { + "can't find from four nesting levels deep", + &tmpDir5, + "config.json", + "", + }, + } - assert.Equal(t, path, FindConfigFile(path)) + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + if testCase.Cwd != nil { + prevDir, err := os.Getwd() + require.NoError(t, err) + defer os.Chdir(prevDir) + os.Chdir(*testCase.Cwd) + } + + assert.Equal(t, testCase.Expected, FindConfigFile(testCase.FileName)) + }) + } + }) + + t.Run("config/config.json relative to executable", func(t *testing.T) { + osExecutable, err := os.Executable() + require.NoError(t, err) + osExecutableDir := filepath.Dir(osExecutable) + + // Force a working directory different than the executable. + cwd, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(cwd) + + prevDir, err := os.Getwd() + require.NoError(t, err) + defer os.Chdir(prevDir) + os.Chdir(cwd) + + testCases := []struct { + Description string + RelativePath string + }{ + { + "config/config.json", + ".", + }, + { + "../config/config.json", + "../", + }, + } - prevDir, err := os.Getwd() - require.NoError(t, err) - defer os.Chdir(prevDir) - os.Chdir(dir) - assert.Equal(t, path, FindConfigFile(path)) + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + // Install the config in config/config.json relative to the executable + configJson := filepath.Join(osExecutableDir, testCase.RelativePath, "config", "config.json") + require.NoError(t, os.Mkdir(filepath.Dir(configJson), 0700)) + require.NoError(t, ioutil.WriteFile(configJson, []byte("{}"), 0600)) + defer os.RemoveAll(filepath.Dir(configJson)) + + // Relative paths end up getting symlinks fully resolved. + configJsonResolved, err := filepath.EvalSymlinks(configJson) + require.NoError(t, err) + + assert.Equal(t, configJsonResolved, FindConfigFile("config.json")) + }) + } + }) +} + +func TestFindFile(t *testing.T) { + t.Run("files from various paths", func(t *testing.T) { + // Create the following directory structure: + // tmpDir1/ + // file1.json + // file2.xml + // other.txt + // tmpDir2/ + // other.txt/ [directory] + // tmpDir3/ + // tmpDir4/ + // tmpDir5/ + tmpDir1, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(tmpDir1) + + tmpDir2, err := ioutil.TempDir(tmpDir1, "") + require.NoError(t, err) + + err = os.Mkdir(filepath.Join(tmpDir2, "other.txt"), 0700) + require.NoError(t, err) + + tmpDir3, err := ioutil.TempDir(tmpDir2, "") + require.NoError(t, err) + + tmpDir4, err := ioutil.TempDir(tmpDir3, "") + require.NoError(t, err) + + tmpDir5, err := ioutil.TempDir(tmpDir4, "") + require.NoError(t, err) + + type testCase struct { + Description string + Cwd *string + FileName string + Expected string + } + + testCases := []testCase{} + + for _, fileName := range []string{"file1.json", "file2.xml", "other.txt"} { + filePath := filepath.Join(tmpDir1, fileName) + require.NoError(t, ioutil.WriteFile(filePath, []byte("{}"), 0600)) + + // Relative paths end up getting symlinks fully resolved, so use this below as necessary. + filePathResolved, err := filepath.EvalSymlinks(filePath) + require.NoError(t, err) + + testCases = append(testCases, []testCase{ + { + fmt.Sprintf("absolute path to %s", fileName), + nil, + filePath, + filePath, + }, + { + fmt.Sprintf("absolute path to %s from containing directory", fileName), + &tmpDir1, + filePath, + filePath, + }, + { + fmt.Sprintf("relative path to %s from containing directory", fileName), + &tmpDir1, + fileName, + filePathResolved, + }, + { + fmt.Sprintf("%s: subdirectory of containing directory", fileName), + &tmpDir2, + fileName, + filePathResolved, + }, + { + fmt.Sprintf("%s: twice-nested subdirectory of containing directory", fileName), + &tmpDir3, + fileName, + filePathResolved, + }, + { + fmt.Sprintf("%s: thrice-nested subdirectory of containing directory", fileName), + &tmpDir4, + fileName, + filePathResolved, + }, + { + fmt.Sprintf("%s: can't find from four nesting levels deep", fileName), + &tmpDir5, + fileName, + "", + }, + }...) + } + + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + if testCase.Cwd != nil { + prevDir, err := os.Getwd() + require.NoError(t, err) + defer os.Chdir(prevDir) + os.Chdir(*testCase.Cwd) + } + + assert.Equal(t, testCase.Expected, FindFile(testCase.FileName)) + }) + } + }) } func TestConfigFromEnviroVars(t *testing.T) { |