summaryrefslogtreecommitdiffstats
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/config.go3
-rw-r--r--utils/file_backend_s3.go18
-rw-r--r--utils/file_backend_s3_test.go32
-rw-r--r--utils/i18n.go5
-rw-r--r--utils/log.go33
-rw-r--r--utils/logger/log4go_json_writer.go30
-rw-r--r--utils/logger/logger.go222
-rw-r--r--utils/lru.go113
-rw-r--r--utils/lru_test.go33
-rw-r--r--utils/mail.go64
-rw-r--r--utils/mail_test.go64
11 files changed, 211 insertions, 406 deletions
diff --git a/utils/config.go b/utils/config.go
index c4d3d0d96..93a870743 100644
--- a/utils/config.go
+++ b/utils/config.go
@@ -26,9 +26,6 @@ import (
)
const (
- MODE_DEV = "dev"
- MODE_BETA = "beta"
- MODE_PROD = "prod"
LOG_ROTATE_SIZE = 10000
LOG_FILENAME = "mattermost.log"
)
diff --git a/utils/file_backend_s3.go b/utils/file_backend_s3.go
index 8e72272a1..75282897f 100644
--- a/utils/file_backend_s3.go
+++ b/utils/file_backend_s3.go
@@ -37,7 +37,10 @@ type S3FileBackend struct {
// disables automatic region lookup.
func (b *S3FileBackend) s3New() (*s3.Client, error) {
var creds *credentials.Credentials
- if b.signV2 {
+
+ if b.accessKey == "" && b.secretKey == "" {
+ creds = credentials.NewIAM("")
+ } else if b.signV2 {
creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV2)
} else {
creds = credentials.NewStatic(b.accessKey, b.secretKey, "", credentials.SignatureV4)
@@ -244,3 +247,16 @@ func s3CopyMetadata(encrypt bool) map[string]string {
metaData["x-amz-server-side-encryption"] = "AES256"
return metaData
}
+
+func CheckMandatoryS3Fields(settings *model.FileSettings) *model.AppError {
+ if len(settings.AmazonS3Bucket) == 0 {
+ return model.NewAppError("S3File", "api.admin.test_s3.missing_s3_bucket", nil, "", http.StatusBadRequest)
+ }
+
+ // if S3 endpoint is not set call the set defaults to set that
+ if len(settings.AmazonS3Endpoint) == 0 {
+ settings.SetDefaults()
+ }
+
+ return nil
+}
diff --git a/utils/file_backend_s3_test.go b/utils/file_backend_s3_test.go
new file mode 100644
index 000000000..a8834f226
--- /dev/null
+++ b/utils/file_backend_s3_test.go
@@ -0,0 +1,32 @@
+// Copyright (c) 2018-present Mattermost, Inc. All Rights Reserved.
+// See License.txt for license information.
+
+package utils
+
+import (
+ "testing"
+
+ "github.com/mattermost/mattermost-server/model"
+)
+
+func TestCheckMandatoryS3Fields(t *testing.T) {
+ cfg := model.FileSettings{}
+
+ err := CheckMandatoryS3Fields(&cfg)
+ if err == nil || err.Message != "api.admin.test_s3.missing_s3_bucket" {
+ t.Fatal("should've failed with missing s3 bucket")
+ }
+
+ cfg.AmazonS3Bucket = "test-mm"
+ err = CheckMandatoryS3Fields(&cfg)
+ if err != nil {
+ t.Fatal("should've not failed")
+ }
+
+ cfg.AmazonS3Endpoint = ""
+ err = CheckMandatoryS3Fields(&cfg)
+ if err != nil || cfg.AmazonS3Endpoint != "s3.amazonaws.com" {
+ t.Fatal("should've not failed because it should set the endpoint to the default")
+ }
+
+}
diff --git a/utils/i18n.go b/utils/i18n.go
index 71e1aaee1..8ed82d19f 100644
--- a/utils/i18n.go
+++ b/utils/i18n.go
@@ -91,11 +91,6 @@ func GetUserTranslations(locale string) i18n.TranslateFunc {
return translations
}
-func SetTranslations(locale string) i18n.TranslateFunc {
- translations := TfuncWithFallback(locale)
- return translations
-}
-
func GetTranslationsAndLocale(w http.ResponseWriter, r *http.Request) (i18n.TranslateFunc, string) {
// This is for checking against locales like pt_BR or zn_CN
headerLocaleFull := strings.Split(r.Header.Get("Accept-Language"), ",")[0]
diff --git a/utils/log.go b/utils/log.go
deleted file mode 100644
index c1f579e9d..000000000
--- a/utils/log.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
-// See License.txt for license information.
-
-package utils
-
-import (
- "bytes"
- "io"
- "io/ioutil"
-
- l4g "github.com/alecthomas/log4go"
-)
-
-// InfoReader logs the content of the io.Reader and returns a new io.Reader
-// with the same content as the received io.Reader.
-// If you pass reader by reference, it won't be re-created unless the loglevel
-// includes Debug.
-// If an error is returned, the reader is consumed an cannot be read again.
-func InfoReader(reader io.Reader, message string) (io.Reader, error) {
- var err error
- l4g.Info(func() string {
- content, err := ioutil.ReadAll(reader)
- if err != nil {
- return ""
- }
-
- reader = bytes.NewReader(content)
-
- return message + string(content)
- })
-
- return reader, err
-}
diff --git a/utils/logger/log4go_json_writer.go b/utils/logger/log4go_json_writer.go
deleted file mode 100644
index ede541b2b..000000000
--- a/utils/logger/log4go_json_writer.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
-// See License.txt for license information.
-
-// glue functions that allow logger.go to leverage log4Go to write JSON-formatted log records to a file
-
-package logger
-
-import (
- l4g "github.com/alecthomas/log4go"
- "github.com/mattermost/mattermost-server/utils"
-)
-
-// newJSONLogWriter is a utility method for creating a FileLogWriter set up to
-// output JSON record log messages instead of line-based ones.
-func newJSONLogWriter(fname string, rotate bool) *l4g.FileLogWriter {
- return l4g.NewFileLogWriter(fname, rotate).SetFormat(
- `{"level": "%L",
- "timestamp": "%D %T",
- "source": "%S",
- "message": %M
- }`).SetRotateLines(utils.LOG_ROTATE_SIZE)
-}
-
-// NewJSONFileLogger - Create a new logger with a "file" filter configured to send JSON-formatted log messages at
-// or above lvl to a file with the specified filename.
-func NewJSONFileLogger(lvl l4g.Level, filename string) l4g.Logger {
- return l4g.Logger{
- "file": &l4g.Filter{Level: lvl, LogWriter: newJSONLogWriter(filename, false)},
- }
-}
diff --git a/utils/logger/logger.go b/utils/logger/logger.go
deleted file mode 100644
index 558f3fe47..000000000
--- a/utils/logger/logger.go
+++ /dev/null
@@ -1,222 +0,0 @@
-// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
-// See License.txt for license information.
-
-// this is a new logger interface for mattermost
-
-package logger
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "path/filepath"
- "runtime"
-
- l4g "github.com/alecthomas/log4go"
-
- "strings"
-
- "github.com/mattermost/mattermost-server/model"
- "github.com/mattermost/mattermost-server/utils"
- "github.com/pkg/errors"
-)
-
-// this pattern allows us to "mock" the underlying l4g code when unit testing
-var logger l4g.Logger
-var debugLog = l4g.Debug
-var infoLog = l4g.Info
-var errorLog = l4g.Error
-
-// assumes that ../config.go::configureLog has already been called, and has in turn called l4g.close() to clean up
-// any old filters that we might have previously created
-func initL4g(logSettings model.LogSettings) {
- // TODO: add support for newConfig.LogSettings.EnableConsole. Right now, ../config.go sets it up in its configureLog
- // method. If we also set it up here, messages will be written to the console twice. Eventually, when all instances
- // of l4g have been replaced by this logger, we can move that code to here
- if logSettings.EnableFile {
- level := l4g.DEBUG
- if logSettings.FileLevel == "INFO" {
- level = l4g.INFO
- } else if logSettings.FileLevel == "WARN" {
- level = l4g.WARNING
- } else if logSettings.FileLevel == "ERROR" {
- level = l4g.ERROR
- }
-
- // create a logger that writes JSON objects to a file, and override our log methods to use it
- if logger != nil {
- logger.Close()
- }
- logger = NewJSONFileLogger(level, utils.GetLogFileLocation(logSettings.FileLocation)+".jsonl")
- debugLog = logger.Debug
- infoLog = logger.Info
- errorLog = logger.Error
- }
-}
-
-// contextKey lets us add contextual information to log messages
-type contextKey string
-
-func (c contextKey) String() string {
- return string(c)
-}
-
-const contextKeyUserID contextKey = contextKey("user_id")
-const contextKeyRequestID contextKey = contextKey("request_id")
-
-// any contextKeys added to this array will be serialized in every log message
-var contextKeys = [2]contextKey{contextKeyUserID, contextKeyRequestID}
-
-// WithUserId adds a user id to the specified context. If the returned Context is subsequently passed to a logging
-// method, the user id will automatically be included in the logged message
-func WithUserId(ctx context.Context, userID string) context.Context {
- return context.WithValue(ctx, contextKeyUserID, userID)
-}
-
-// WithRequestId adds a request id to the specified context. If the returned Context is subsequently passed to a logging
-// method, the request id will automatically be included in the logged message
-func WithRequestId(ctx context.Context, requestID string) context.Context {
- return context.WithValue(ctx, contextKeyRequestID, requestID)
-}
-
-// extracts known contextKey values from the specified Context and assembles them into the returned map
-func serializeContext(ctx context.Context) map[string]string {
- serialized := make(map[string]string)
- for _, key := range contextKeys {
- value, ok := ctx.Value(key).(string)
- if ok {
- serialized[string(key)] = value
- }
- }
- return serialized
-}
-
-// Returns the path to the next file up the callstack that has a different name than this file
-// in other words, finds the path to the file that is doing the logging.
-// Removes machine-specific prefix, so returned path starts with /mattermost-server.
-// Looks a maximum of 10 frames up the call stack to find a file that has a different name than this one.
-func getCallerFilename() (string, error) {
- _, currentFilename, _, ok := runtime.Caller(0)
- if !ok {
- return "", errors.New("Failed to traverse stack frame")
- }
-
- platformDirectory := currentFilename
- for filepath.Base(platformDirectory) != "platform" {
- platformDirectory = filepath.Dir(platformDirectory)
- if platformDirectory == "." || platformDirectory == string(filepath.Separator) {
- break
- }
- }
-
- for i := 1; i < 10; i++ {
- _, parentFilename, _, ok := runtime.Caller(i)
- if !ok {
- return "", errors.New("Failed to traverse stack frame")
- } else if parentFilename != currentFilename && strings.Contains(parentFilename, platformDirectory) {
- // trim parentFilename such that we return the path to parentFilename, relative to platformDirectory
- return parentFilename[strings.LastIndex(parentFilename, platformDirectory)+len(platformDirectory)+1:], nil
- }
- }
- return "", errors.New("Failed to traverse stack frame")
-}
-
-// creates a JSON representation of a log message
-func serializeLogMessage(ctx context.Context, message string) string {
- callerFilename, err := getCallerFilename()
- if err != nil {
- callerFilename = "Unknown"
- }
-
- bytes, err := json.Marshal(&struct {
- Context map[string]string `json:"context"`
- File string `json:"file"`
- Message string `json:"message"`
- }{
- serializeContext(ctx),
- callerFilename,
- message,
- })
- if err != nil {
- errorLog("Failed to serialize log message %v", message)
- }
- return string(bytes)
-}
-
-func formatMessage(args ...interface{}) string {
- msg, ok := args[0].(string)
- if !ok {
- panic("Second argument is not of type string")
- }
- if len(args) > 1 {
- variables := args[1:]
- msg = fmt.Sprintf(msg, variables...)
- }
- return msg
-}
-
-// Debugc logs a debugLog level message, including context information that is stored in the first parameter.
-// If two parameters are supplied, the second must be a message string, and will be logged directly.
-// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Debugc(ctx context.Context, args ...interface{}) {
- debugLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(ctx, msg)
- })
-}
-
-// Debugf logs a debugLog level message.
-// If one parameter is supplied, it must be a message string, and will be logged directly.
-// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Debugf(args ...interface{}) {
- debugLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(context.Background(), msg)
- })
-}
-
-// Infoc logs an infoLog level message, including context information that is stored in the first parameter.
-// If two parameters are supplied, the second must be a message string, and will be logged directly.
-// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Infoc(ctx context.Context, args ...interface{}) {
- infoLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(ctx, msg)
- })
-}
-
-// Infof logs an infoLog level message.
-// If one parameter is supplied, it must be a message string, and will be logged directly.
-// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Infof(args ...interface{}) {
- infoLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(context.Background(), msg)
- })
-}
-
-// Errorc logs an error level message, including context information that is stored in the first parameter.
-// If two parameters are supplied, the second must be a message string, and will be logged directly.
-// If more than two parameters are supplied, the second parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Errorc(ctx context.Context, args ...interface{}) {
- errorLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(ctx, msg)
- })
-}
-
-// Errorf logs an error level message.
-// If one parameter is supplied, it must be a message string, and will be logged directly.
-// If two or more parameters are specified, the first parameter must be a format string, and the remaining parameters
-// must be the variables to substitute into the format string, following the convention of the fmt.Sprintf(...) function.
-func Errorf(args ...interface{}) {
- errorLog(func() string {
- msg := formatMessage(args...)
- return serializeLogMessage(context.Background(), msg)
- })
-}
diff --git a/utils/lru.go b/utils/lru.go
index 576331563..8e896a6dc 100644
--- a/utils/lru.go
+++ b/utils/lru.go
@@ -9,15 +9,14 @@ package utils
import (
"container/list"
- "errors"
"sync"
"time"
)
// Caching Interface
type ObjectCache interface {
- AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool
- AddWithDefaultExpires(key, value interface{}) bool
+ AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64)
+ AddWithDefaultExpires(key, value interface{})
Purge()
Get(key interface{}) (value interface{}, ok bool)
Remove(key interface{})
@@ -32,10 +31,11 @@ type Cache struct {
evictList *list.List
items map[interface{}]*list.Element
lock sync.RWMutex
- onEvicted func(key interface{}, value interface{})
name string
defaultExpiry int64
invalidateClusterEvent string
+ currentGeneration int64
+ len int
}
// entry is used to hold a value in the evictList
@@ -43,25 +43,16 @@ type entry struct {
key interface{}
value interface{}
expireAtSecs int64
+ generation int64
}
// New creates an LRU of the given size
func NewLru(size int) *Cache {
- cache, _ := NewLruWithEvict(size, nil)
- return cache
-}
-
-func NewLruWithEvict(size int, onEvicted func(key interface{}, value interface{})) (*Cache, error) {
- if size <= 0 {
- return nil, errors.New(T("utils.iru.with_evict"))
- }
- c := &Cache{
+ return &Cache{
size: size,
evictList: list.New(),
items: make(map[interface{}]*list.Element, size),
- onEvicted: onEvicted,
}
- return c, nil
}
func NewLruWithParams(size int, name string, defaultExpiry int64, invalidateClusterEvent string) *Cache {
@@ -77,26 +68,19 @@ func (c *Cache) Purge() {
c.lock.Lock()
defer c.lock.Unlock()
- if c.onEvicted != nil {
- for k, v := range c.items {
- c.onEvicted(k, v.Value)
- }
- }
-
- c.evictList = list.New()
- c.items = make(map[interface{}]*list.Element, c.size)
+ c.len = 0
+ c.currentGeneration++
}
-func (c *Cache) Add(key, value interface{}) bool {
- return c.AddWithExpiresInSecs(key, value, 0)
+func (c *Cache) Add(key, value interface{}) {
+ c.AddWithExpiresInSecs(key, value, 0)
}
-func (c *Cache) AddWithDefaultExpires(key, value interface{}) bool {
- return c.AddWithExpiresInSecs(key, value, c.defaultExpiry)
+func (c *Cache) AddWithDefaultExpires(key, value interface{}) {
+ c.AddWithExpiresInSecs(key, value, c.defaultExpiry)
}
-// Add adds a value to the cache. Returns true if an eviction occurred.
-func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) bool {
+func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64) {
c.lock.Lock()
defer c.lock.Unlock()
@@ -107,45 +91,46 @@ func (c *Cache) AddWithExpiresInSecs(key, value interface{}, expireAtSecs int64)
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
- ent.Value.(*entry).value = value
- ent.Value.(*entry).expireAtSecs = expireAtSecs
- return false
+ e := ent.Value.(*entry)
+ e.value = value
+ e.expireAtSecs = expireAtSecs
+ if e.generation != c.currentGeneration {
+ e.generation = c.currentGeneration
+ c.len++
+ }
+ return
}
// Add new item
- ent := &entry{key, value, expireAtSecs}
+ ent := &entry{key, value, expireAtSecs, c.currentGeneration}
entry := c.evictList.PushFront(ent)
c.items[key] = entry
+ c.len++
- evict := c.evictList.Len() > c.size
- // Verify size not exceeded
- if evict {
- c.removeOldest()
+ if c.evictList.Len() > c.size {
+ c.removeElement(c.evictList.Back())
}
- return evict
}
-// Get looks up a key's value from the cache.
func (c *Cache) Get(key interface{}) (value interface{}, ok bool) {
c.lock.Lock()
defer c.lock.Unlock()
if ent, ok := c.items[key]; ok {
+ e := ent.Value.(*entry)
- if ent.Value.(*entry).expireAtSecs > 0 {
- if (time.Now().UnixNano() / int64(time.Second)) > ent.Value.(*entry).expireAtSecs {
- c.removeElement(ent)
- return nil, false
- }
+ if e.generation != c.currentGeneration || (e.expireAtSecs > 0 && (time.Now().UnixNano()/int64(time.Second)) > e.expireAtSecs) {
+ c.removeElement(ent)
+ return nil, false
}
c.evictList.MoveToFront(ent)
return ent.Value.(*entry).value, true
}
- return
+
+ return nil, false
}
-// Remove removes the provided key from the cache.
func (c *Cache) Remove(key interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
@@ -155,25 +140,19 @@ func (c *Cache) Remove(key interface{}) {
}
}
-// RemoveOldest removes the oldest item from the cache.
-func (c *Cache) RemoveOldest() {
- c.lock.Lock()
- defer c.lock.Unlock()
- c.removeOldest()
-}
-
// Keys returns a slice of the keys in the cache, from oldest to newest.
func (c *Cache) Keys() []interface{} {
c.lock.RLock()
defer c.lock.RUnlock()
- keys := make([]interface{}, len(c.items))
- ent := c.evictList.Back()
+ keys := make([]interface{}, c.len)
i := 0
- for ent != nil {
- keys[i] = ent.Value.(*entry).key
- ent = ent.Prev()
- i++
+ for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() {
+ e := ent.Value.(*entry)
+ if e.generation == c.currentGeneration {
+ keys[i] = e.key
+ i++
+ }
}
return keys
@@ -183,7 +162,7 @@ func (c *Cache) Keys() []interface{} {
func (c *Cache) Len() int {
c.lock.RLock()
defer c.lock.RUnlock()
- return c.evictList.Len()
+ return c.len
}
func (c *Cache) Name() string {
@@ -194,20 +173,12 @@ func (c *Cache) GetInvalidateClusterEvent() string {
return c.invalidateClusterEvent
}
-// removeOldest removes the oldest item from the cache.
-func (c *Cache) removeOldest() {
- ent := c.evictList.Back()
- if ent != nil {
- c.removeElement(ent)
- }
-}
-
// removeElement is used to remove a given list element from the cache
func (c *Cache) removeElement(e *list.Element) {
c.evictList.Remove(e)
kv := e.Value.(*entry)
- delete(c.items, kv.key)
- if c.onEvicted != nil {
- c.onEvicted(kv.key, kv.value)
+ if kv.generation == c.currentGeneration {
+ c.len--
}
+ delete(c.items, kv.key)
}
diff --git a/utils/lru_test.go b/utils/lru_test.go
index 987163cd3..4312515b9 100644
--- a/utils/lru_test.go
+++ b/utils/lru_test.go
@@ -11,14 +11,7 @@ import "testing"
import "time"
func TestLRU(t *testing.T) {
- evictCounter := 0
- onEvicted := func(k interface{}, v interface{}) {
- evictCounter += 1
- }
- l, err := NewLruWithEvict(128, onEvicted)
- if err != nil {
- t.Fatalf("err: %v", err)
- }
+ l := NewLru(128)
for i := 0; i < 256; i++ {
l.Add(i, i)
@@ -27,10 +20,6 @@ func TestLRU(t *testing.T) {
t.Fatalf("bad len: %v", l.Len())
}
- if evictCounter != 128 {
- t.Fatalf("bad evict count: %v", evictCounter)
- }
-
for i, k := range l.Keys() {
if v, ok := l.Get(k); !ok || v != k || v != i+128 {
t.Fatalf("bad key: %v", k)
@@ -73,26 +62,6 @@ func TestLRU(t *testing.T) {
}
}
-// test that Add return true/false if an eviction occurred
-func TestLRUAdd(t *testing.T) {
- evictCounter := 0
- onEvicted := func(k interface{}, v interface{}) {
- evictCounter += 1
- }
-
- l, err := NewLruWithEvict(1, onEvicted)
- if err != nil {
- t.Fatalf("err: %v", err)
- }
-
- if l.Add(1, 1) || evictCounter != 0 {
- t.Errorf("should not have an eviction")
- }
- if !l.Add(2, 2) || evictCounter != 1 {
- t.Errorf("should have an eviction")
- }
-}
-
func TestLRUExpire(t *testing.T) {
l := NewLru(128)
diff --git a/utils/mail.go b/utils/mail.go
index 9023f7090..3b9f4bd9d 100644
--- a/utils/mail.go
+++ b/utils/mail.go
@@ -5,6 +5,8 @@ package utils
import (
"crypto/tls"
+ "errors"
+ "io"
"mime"
"net"
"net/mail"
@@ -15,8 +17,6 @@ import (
"net/http"
- "io"
-
l4g "github.com/alecthomas/log4go"
"github.com/mattermost/html2text"
"github.com/mattermost/mattermost-server/model"
@@ -26,6 +26,56 @@ func encodeRFC2047Word(s string) string {
return mime.BEncoding.Encode("utf-8", s)
}
+type authChooser struct {
+ smtp.Auth
+ Config *model.Config
+}
+
+func (a *authChooser) Start(server *smtp.ServerInfo) (string, []byte, error) {
+ a.Auth = LoginAuth(a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort)
+ for _, method := range server.Auth {
+ if method == "PLAIN" {
+ a.Auth = smtp.PlainAuth("", a.Config.EmailSettings.SMTPUsername, a.Config.EmailSettings.SMTPPassword, a.Config.EmailSettings.SMTPServer+":"+a.Config.EmailSettings.SMTPPort)
+ break
+ }
+ }
+ return a.Auth.Start(server)
+}
+
+type loginAuth struct {
+ username, password, host string
+}
+
+func LoginAuth(username, password, host string) smtp.Auth {
+ return &loginAuth{username, password, host}
+}
+
+func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
+ if !server.TLS {
+ return "", nil, errors.New("unencrypted connection")
+ }
+
+ if server.Name != a.host {
+ return "", nil, errors.New("wrong host name")
+ }
+
+ return "LOGIN", []byte{}, nil
+}
+
+func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ switch string(fromServer) {
+ case "Username:":
+ return []byte(a.username), nil
+ case "Password:":
+ return []byte(a.password), nil
+ default:
+ return nil, errors.New("Unkown fromServer")
+ }
+ }
+ return nil, nil
+}
+
func connectToSMTPServer(config *model.Config) (net.Conn, *model.AppError) {
var conn net.Conn
var err error
@@ -75,9 +125,7 @@ func newSMTPClient(conn net.Conn, config *model.Config) (*smtp.Client, *model.Ap
}
if *config.EmailSettings.EnableSMTPAuth {
- auth := smtp.PlainAuth("", config.EmailSettings.SMTPUsername, config.EmailSettings.SMTPPassword, config.EmailSettings.SMTPServer+":"+config.EmailSettings.SMTPPort)
-
- if err = c.Auth(auth); err != nil {
+ if err = c.Auth(&authChooser{Config: config}); err != nil {
return nil, model.NewAppError("SendMail", "utils.mail.new_client.auth.app_error", nil, err.Error(), http.StatusInternalServerError)
}
}
@@ -138,10 +186,8 @@ func sendMail(mimeTo, smtpTo string, from mail.Address, subject, htmlBody string
"Auto-Submitted": {"auto-generated"},
"Precedence": {"bulk"},
}
- if mimeHeaders != nil {
- for k, v := range mimeHeaders {
- headers[k] = []string{encodeRFC2047Word(v)}
- }
+ for k, v := range mimeHeaders {
+ headers[k] = []string{encodeRFC2047Word(v)}
}
m := gomail.NewMessage(gomail.SetCharset("UTF-8"))
diff --git a/utils/mail_test.go b/utils/mail_test.go
index 068c90c60..31a4f8996 100644
--- a/utils/mail_test.go
+++ b/utils/mail_test.go
@@ -7,6 +7,9 @@ import (
"strings"
"testing"
+ "net/smtp"
+
+ "github.com/mattermost/mattermost-server/model"
"github.com/stretchr/testify/require"
)
@@ -169,3 +172,64 @@ func TestSendMailUsingConfig(t *testing.T) {
}
}
}*/
+
+func TestAuthMethods(t *testing.T) {
+ config := model.Config{
+ EmailSettings: model.EmailSettings{
+ EnableSMTPAuth: model.NewBool(false),
+ SMTPUsername: "test",
+ SMTPPassword: "fakepass",
+ SMTPServer: "fakeserver",
+ SMTPPort: "25",
+ },
+ }
+
+ auth := &authChooser{Config: &config}
+ tests := []struct {
+ desc string
+ server *smtp.ServerInfo
+ err string
+ }{
+ {
+ desc: "auth PLAIN success",
+ server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: true},
+ },
+ {
+ desc: "auth PLAIN unencrypted connection fail",
+ server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"PLAIN"}, TLS: false},
+ err: "unencrypted connection",
+ },
+ {
+ desc: "auth PLAIN wrong host name",
+ server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"PLAIN"}, TLS: true},
+ err: "wrong host name",
+ },
+ {
+ desc: "auth LOGIN success",
+ server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: true},
+ },
+ {
+ desc: "auth LOGIN unencrypted connection fail",
+ server: &smtp.ServerInfo{Name: "wrongServer:999", Auth: []string{"LOGIN"}, TLS: true},
+ err: "wrong host name",
+ },
+ {
+ desc: "auth LOGIN wrong host name",
+ server: &smtp.ServerInfo{Name: "fakeserver:25", Auth: []string{"LOGIN"}, TLS: false},
+ err: "unencrypted connection",
+ },
+ }
+
+ for i, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ _, _, err := auth.Start(test.server)
+ got := ""
+ if err != nil {
+ got = err.Error()
+ }
+ if got != test.err {
+ t.Errorf("%d. got error = %q; want %q", i, got, test.err)
+ }
+ })
+ }
+}