From 402491b7e52c4d836c1274976cdb387852cfd17b Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 11 Sep 2017 10:02:02 -0500 Subject: PLT-7407: Back-end plugins (#7409) * tie back-end plugins together * fix comment typo * add tests and a bit of polish * tests and polish * add test, don't let backend executable paths escape the plugin directory --- api4/api.go | 2 +- api4/plugin_test.go | 2 +- app/app.go | 5 +- app/plugins.go | 219 +++++++++++++++++++++++++++-------- app/server.go | 34 +----- cmd/platform/server.go | 24 ++-- i18n/en.json | 12 +- plugin/hooks.go | 3 + plugin/pluginenv/environment.go | 59 ++++++++-- plugin/pluginenv/environment_test.go | 71 ++++++++++++ plugin/plugintest/hooks.go | 4 + plugin/rpcplugin/hooks.go | 29 ++++- plugin/rpcplugin/hooks_test.go | 44 +++++++ plugin/rpcplugin/io.go | 163 ++++++++++++++++++++++++-- plugin/rpcplugin/io_test.go | 73 ++++++++++++ plugin/rpcplugin/ipc.go | 2 +- plugin/rpcplugin/supervisor.go | 7 +- plugin/rpcplugin/supervisor_test.go | 13 +++ utils/extract.go | 33 +++--- 19 files changed, 651 insertions(+), 148 deletions(-) create mode 100644 plugin/rpcplugin/io_test.go diff --git a/api4/api.go b/api4/api.go index 5204c0596..50c56ca0b 100644 --- a/api4/api.go +++ b/api4/api.go @@ -150,7 +150,7 @@ func InitApi(full bool) { BaseRoutes.PublicFile = BaseRoutes.Root.PathPrefix("/files/{file_id:[A-Za-z0-9]+}/public").Subrouter() BaseRoutes.Plugins = BaseRoutes.ApiRoot.PathPrefix("/plugins").Subrouter() - BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-]+}").Subrouter() + BaseRoutes.Plugin = BaseRoutes.Plugins.PathPrefix("/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}").Subrouter() BaseRoutes.Commands = BaseRoutes.ApiRoot.PathPrefix("/commands").Subrouter() BaseRoutes.Command = BaseRoutes.Commands.PathPrefix("/{command_id:[A-Za-z0-9]+}").Subrouter() diff --git a/api4/plugin_test.go b/api4/plugin_test.go index 3d8b065b9..0e8c0638c 100644 --- a/api4/plugin_test.go +++ b/api4/plugin_test.go @@ -110,5 +110,5 @@ func TestPlugin(t *testing.T) { _, resp = th.SystemAdminClient.RemovePlugin("bad.id") CheckNotFoundStatus(t, resp) - th.App.Srv.PluginEnv = nil + th.App.PluginEnv = nil } diff --git a/app/app.go b/app/app.go index ce812ed16..953ca285a 100644 --- a/app/app.go +++ b/app/app.go @@ -6,10 +6,13 @@ package app import ( "io/ioutil" "net/http" + + "github.com/mattermost/mattermost-server/plugin/pluginenv" ) type App struct { - Srv *Server + Srv *Server + PluginEnv *pluginenv.Environment } var globalApp App diff --git a/app/plugins.go b/app/plugins.go index 50f810d76..f165f7b49 100644 --- a/app/plugins.go +++ b/app/plugins.go @@ -4,6 +4,7 @@ package app import ( + "context" "encoding/json" "io" "io/ioutil" @@ -19,17 +20,50 @@ import ( "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/utils" - "github.com/mattermost/mattermost-server/app/plugin" + builtinplugin "github.com/mattermost/mattermost-server/app/plugin" "github.com/mattermost/mattermost-server/app/plugin/jira" "github.com/mattermost/mattermost-server/app/plugin/ldapextras" + + "github.com/mattermost/mattermost-server/plugin" + "github.com/mattermost/mattermost-server/plugin/pluginenv" ) type PluginAPI struct { + id string + app *App +} + +func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error { + if b, err := json.Marshal(utils.Cfg.PluginSettings.Plugins[api.id]); err != nil { + return err + } else { + return json.Unmarshal(b, dest) + } +} + +func (api *PluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) { + return api.app.GetTeamByName(name) +} + +func (api *PluginAPI) GetUserByUsername(name string) (*model.User, *model.AppError) { + return api.app.GetUserByUsername(name) +} + +func (api *PluginAPI) GetChannelByName(name, teamId string) (*model.Channel, *model.AppError) { + return api.app.GetChannelByName(name, teamId) +} + +func (api *PluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) { + return api.app.CreatePostMissingChannel(post, true) +} + +type BuiltInPluginAPI struct { id string router *mux.Router + app *App } -func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error { +func (api *BuiltInPluginAPI) LoadPluginConfiguration(dest interface{}) error { if b, err := json.Marshal(utils.Cfg.PluginSettings.Plugins[api.id]); err != nil { return err } else { @@ -37,37 +71,37 @@ func (api *PluginAPI) LoadPluginConfiguration(dest interface{}) error { } } -func (api *PluginAPI) PluginRouter() *mux.Router { +func (api *BuiltInPluginAPI) PluginRouter() *mux.Router { return api.router } -func (api *PluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) { - return Global().GetTeamByName(name) +func (api *BuiltInPluginAPI) GetTeamByName(name string) (*model.Team, *model.AppError) { + return api.app.GetTeamByName(name) } -func (api *PluginAPI) GetUserByName(name string) (*model.User, *model.AppError) { - return Global().GetUserByUsername(name) +func (api *BuiltInPluginAPI) GetUserByName(name string) (*model.User, *model.AppError) { + return api.app.GetUserByUsername(name) } -func (api *PluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) { - return Global().GetChannelByName(name, teamId) +func (api *BuiltInPluginAPI) GetChannelByName(teamId, name string) (*model.Channel, *model.AppError) { + return api.app.GetChannelByName(name, teamId) } -func (api *PluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) { - return Global().GetDirectChannel(userId1, userId2) +func (api *BuiltInPluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) { + return api.app.GetDirectChannel(userId1, userId2) } -func (api *PluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) { - return Global().CreatePostMissingChannel(post, true) +func (api *BuiltInPluginAPI) CreatePost(post *model.Post) (*model.Post, *model.AppError) { + return api.app.CreatePostMissingChannel(post, true) } -func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) { +func (api *BuiltInPluginAPI) GetLdapUserAttributes(userId string, attributes []string) (map[string]string, *model.AppError) { ldapInterface := einterfaces.GetLdapInterface() if ldapInterface == nil { return nil, model.NewAppError("GetLdapUserAttributes", "ent.ldap.disabled.app_error", nil, "", http.StatusNotImplemented) } - user, err := Global().GetUser(userId) + user, err := api.app.GetUser(userId) if err != nil { return nil, err } @@ -75,7 +109,7 @@ func (api *PluginAPI) GetLdapUserAttributes(userId string, attributes []string) return ldapInterface.GetUserAttributes(*user.AuthData, attributes) } -func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) { +func (api *BuiltInPluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *model.AppError) { token := "" isTokenFromQueryString := false @@ -111,7 +145,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized) } - session, err := Global().GetSession(token) + session, err := api.app.GetSession(token) if err != nil { return nil, model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token, http.StatusUnauthorized) @@ -122,7 +156,7 @@ func (api *PluginAPI) GetSessionFromRequest(r *http.Request) (*model.Session, *m return session, nil } -func (api *PluginAPI) I18n(id string, r *http.Request) string { +func (api *BuiltInPluginAPI) I18n(id string, r *http.Request) string { if r != nil { f, _ := utils.GetTranslationsAndLocale(nil, r) return f(id) @@ -131,16 +165,17 @@ func (api *PluginAPI) I18n(id string, r *http.Request) string { return f(id) } -func (a *App) InitPlugins() { - plugins := map[string]plugin.Plugin{ +func (a *App) InitBuiltInPlugins() { + plugins := map[string]builtinplugin.Plugin{ "jira": &jira.Plugin{}, "ldapextras": &ldapextras.Plugin{}, } for id, p := range plugins { l4g.Info("Initializing plugin: " + id) - api := &PluginAPI{ + api := &BuiltInPluginAPI{ id: id, router: a.Srv.Router.PathPrefix("/plugins/" + id).Subrouter(), + app: a, } p.Initialize(api) } @@ -155,19 +190,19 @@ func (a *App) InitPlugins() { } func (a *App) ActivatePlugins() { - if a.Srv.PluginEnv == nil { + if a.PluginEnv == nil { l4g.Error("plugin env not initialized") return } - plugins, err := a.Srv.PluginEnv.Plugins() + plugins, err := a.PluginEnv.Plugins() if err != nil { l4g.Error("failed to start up plugins: " + err.Error()) return } for _, plugin := range plugins { - err := a.Srv.PluginEnv.ActivatePlugin(plugin.Manifest.Id) + err := a.PluginEnv.ActivatePlugin(plugin.Manifest.Id) if err != nil { l4g.Error(err.Error()) } @@ -176,48 +211,43 @@ func (a *App) ActivatePlugins() { } func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *model.AppError) { - if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { + if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented) } tmpDir, err := ioutil.TempDir("", "plugintmp") if err != nil { - return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.temp_dir.app_error", nil, err.Error(), http.StatusInternalServerError) + return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError) } - defer func() { - os.RemoveAll(tmpDir) - }() + defer os.RemoveAll(tmpDir) - filenames, err := utils.ExtractTarGz(pluginFile, tmpDir) - if err != nil { + if err := utils.ExtractTarGz(pluginFile, tmpDir); err != nil { return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.extract.app_error", nil, err.Error(), http.StatusBadRequest) } - if len(filenames) == 0 { - return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.no_files.app_error", nil, err.Error(), http.StatusBadRequest) + tmpPluginDir := tmpDir + dir, err := ioutil.ReadDir(tmpDir) + if err != nil { + return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.filesystem.app_error", nil, err.Error(), http.StatusInternalServerError) } - splitPath := strings.Split(filenames[0], string(os.PathSeparator)) - - if len(splitPath) == 0 { - return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.bad_path.app_error", nil, err.Error(), http.StatusBadRequest) + if len(dir) == 1 && dir[0].IsDir() { + tmpPluginDir = filepath.Join(tmpPluginDir, dir[0].Name()) } - manifestDir := filepath.Join(tmpDir, splitPath[0]) - - manifest, _, err := model.FindManifest(manifestDir) + manifest, _, err := model.FindManifest(tmpPluginDir) if err != nil { return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.manifest.app_error", nil, err.Error(), http.StatusBadRequest) } - os.Rename(manifestDir, filepath.Join(a.Srv.PluginEnv.SearchPath(), manifest.Id)) + os.Rename(tmpPluginDir, filepath.Join(a.PluginEnv.SearchPath(), manifest.Id)) if err != nil { return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.mvdir.app_error", nil, err.Error(), http.StatusInternalServerError) } // Should add manifest validation and error handling here - err = a.Srv.PluginEnv.ActivatePlugin(manifest.Id) + err = a.PluginEnv.ActivatePlugin(manifest.Id) if err != nil { return nil, model.NewAppError("UnpackAndActivatePlugin", "app.plugin.activate.app_error", nil, err.Error(), http.StatusBadRequest) } @@ -226,11 +256,11 @@ func (a *App) UnpackAndActivatePlugin(pluginFile io.Reader) (*model.Manifest, *m } func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) { - if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { + if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented) } - plugins, err := a.Srv.PluginEnv.ActivePlugins() + plugins, err := a.PluginEnv.ActivePlugins() if err != nil { return nil, model.NewAppError("GetActivePluginManifests", "app.plugin.get_plugins.app_error", nil, err.Error(), http.StatusInternalServerError) } @@ -244,16 +274,16 @@ func (a *App) GetActivePluginManifests() ([]*model.Manifest, *model.AppError) { } func (a *App) RemovePlugin(id string) *model.AppError { - if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { + if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { return model.NewAppError("RemovePlugin", "app.plugin.disabled.app_error", nil, "", http.StatusNotImplemented) } - err := a.Srv.PluginEnv.DeactivatePlugin(id) + err := a.PluginEnv.DeactivatePlugin(id) if err != nil { return model.NewAppError("RemovePlugin", "app.plugin.deactivate.app_error", nil, err.Error(), http.StatusBadRequest) } - err = os.RemoveAll(filepath.Join(a.Srv.PluginEnv.SearchPath(), id)) + err = os.RemoveAll(filepath.Join(a.PluginEnv.SearchPath(), id)) if err != nil { return model.NewAppError("RemovePlugin", "app.plugin.remove.app_error", nil, err.Error(), http.StatusInternalServerError) } @@ -268,11 +298,11 @@ type ClientConfigPlugin struct { } func (a *App) GetPluginsForClientConfig() string { - if a.Srv.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { + if a.PluginEnv == nil || !*utils.Cfg.PluginSettings.Enable { return "" } - plugins, err := a.Srv.PluginEnv.ActivePlugins() + plugins, err := a.PluginEnv.ActivePlugins() if err != nil { return "" } @@ -292,3 +322,94 @@ func (a *App) GetPluginsForClientConfig() string { return string(b) } + +func (a *App) InitPlugins(pluginPath, webappPath string) { + a.InitBuiltInPlugins() + + if !utils.IsLicensed() || !*utils.License().Features.FutureFeatures || !*utils.Cfg.PluginSettings.Enable { + return + } + + l4g.Info("Starting up plugins") + + err := os.Mkdir(pluginPath, 0744) + if err != nil && !os.IsExist(err) { + l4g.Error("failed to start up plugins: " + err.Error()) + return + } + + a.PluginEnv, err = pluginenv.New( + pluginenv.SearchPath(pluginPath), + pluginenv.WebappPath(webappPath), + pluginenv.APIProvider(func(m *model.Manifest) (plugin.API, error) { + return &PluginAPI{ + id: m.Id, + app: a, + }, nil + }), + ) + + if err != nil { + l4g.Error("failed to start up plugins: " + err.Error()) + return + } + + utils.AddConfigListener(func(_, _ *model.Config) { + for _, err := range a.PluginEnv.Hooks().OnConfigurationChange() { + l4g.Error(err.Error()) + } + }) + + a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}", a.ServePluginRequest) + a.Srv.Router.HandleFunc("/plugins/{plugin_id:[A-Za-z0-9\\_\\-\\.]+}/{anything:.*}", a.ServePluginRequest) + + a.ActivatePlugins() +} + +func (a *App) ServePluginRequest(w http.ResponseWriter, r *http.Request) { + token := "" + + authHeader := r.Header.Get(model.HEADER_AUTH) + if strings.HasPrefix(strings.ToUpper(authHeader), model.HEADER_BEARER+":") { + token = authHeader[len(model.HEADER_BEARER)+1:] + } else if strings.HasPrefix(strings.ToLower(authHeader), model.HEADER_TOKEN+":") { + token = authHeader[len(model.HEADER_TOKEN)+1:] + } else if cookie, _ := r.Cookie(model.SESSION_COOKIE_TOKEN); cookie != nil && (r.Method == "GET" || r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML) { + token = cookie.Value + } else { + token = r.URL.Query().Get("access_token") + } + + r.Header.Del("Mattermost-User-Id") + if token != "" { + if session, err := a.GetSession(token); err != nil { + r.Header.Set("Mattermost-User-Id", session.UserId) + } + } + + cookies := r.Cookies() + r.Header.Del("Cookie") + for _, c := range cookies { + if c.Name != model.SESSION_COOKIE_TOKEN { + r.AddCookie(c) + } + } + r.Header.Del(model.HEADER_AUTH) + r.Header.Del("Referer") + + newQuery := r.URL.Query() + newQuery.Del("access_token") + r.URL.RawQuery = newQuery.Encode() + + params := mux.Vars(r) + a.PluginEnv.Hooks().ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "plugin_id", params["plugin_id"]))) +} + +func (a *App) ShutDownPlugins() { + if a.PluginEnv == nil { + return + } + for _, err := range a.PluginEnv.Shutdown() { + l4g.Error(err.Error()) + } +} diff --git a/app/server.go b/app/server.go index 21d727724..c44408d12 100644 --- a/app/server.go +++ b/app/server.go @@ -7,7 +7,6 @@ import ( "crypto/tls" "net" "net/http" - "os" "strings" "time" @@ -20,7 +19,6 @@ import ( "gopkg.in/throttled/throttled.v2/store/memstore" "github.com/mattermost/mattermost-server/model" - "github.com/mattermost/mattermost-server/plugin/pluginenv" "github.com/mattermost/mattermost-server/store" "github.com/mattermost/mattermost-server/utils" ) @@ -30,7 +28,6 @@ type Server struct { WebSocketRouter *WebSocketRouter Router *mux.Router GracefulServer *graceful.Server - PluginEnv *pluginenv.Environment } var allowedMethods []string = []string{ @@ -187,10 +184,6 @@ func (a *App) StartServer() { }() } - if utils.IsLicensed() && *utils.License().Features.FutureFeatures && *utils.Cfg.PluginSettings.Enable { - a.StartupPlugins("plugins", "webapp/dist") - } - go func() { var err error if *utils.Cfg.ServiceSettings.ConnectionSecurity == model.CONN_SECURITY_TLS { @@ -226,30 +219,7 @@ func (a *App) StopServer() { a.Srv.Store.Close() HubStop() - l4g.Info(utils.T("api.server.stop_server.stopped.info")) -} - -func (a *App) StartupPlugins(pluginPath, webappPath string) { - l4g.Info("Starting up plugins") - - err := os.Mkdir(pluginPath, 0744) - if err != nil { - if os.IsExist(err) { - err = nil - } else { - l4g.Error("failed to start up plugins: " + err.Error()) - return - } - } - - a.Srv.PluginEnv, err = pluginenv.New( - pluginenv.SearchPath(pluginPath), - pluginenv.WebappPath(webappPath), - ) + a.ShutDownPlugins() - if err != nil { - l4g.Error("failed to start up plugins: " + err.Error()) - } - - a.ActivatePlugins() + l4g.Info(utils.T("api.server.stop_server.stopped.info")) } diff --git a/cmd/platform/server.go b/cmd/platform/server.go index a11bc58b8..fe5f5272b 100644 --- a/cmd/platform/server.go +++ b/cmd/platform/server.go @@ -71,20 +71,22 @@ func runServer(configFileLocation string) { l4g.Error("Problem with file storage settings: " + err.Error()) } - app.Global().NewServer() - app.Global().InitStores() + a := app.Global() + a.NewServer() + a.InitStores() api.InitRouter() + + if model.BuildEnterpriseReady == "true" { + a.LoadLicense() + } + a.InitPlugins("plugins", "webapp/dist") + wsapi.InitRouter() api4.InitApi(false) api.InitApi() - app.Global().InitPlugins() wsapi.InitApi() web.InitWeb() - if model.BuildEnterpriseReady == "true" { - app.Global().LoadLicense() - } - if !utils.IsLicensed() && len(utils.Cfg.SqlSettings.DataSourceReplicas) > 1 { l4g.Warn(utils.T("store.sql.read_replicas_not_licensed.critical")) utils.Cfg.SqlSettings.DataSourceReplicas = utils.Cfg.SqlSettings.DataSourceReplicas[:1] @@ -98,7 +100,7 @@ func runServer(configFileLocation string) { resetStatuses() - app.Global().StartServer() + a.StartServer() // If we allow testing then listen for manual testing URL hits if utils.Cfg.ServiceSettings.EnableTesting { @@ -118,7 +120,7 @@ func runServer(configFileLocation string) { } if einterfaces.GetClusterInterface() != nil { - app.Global().RegisterAllClusterMessageHandlers() + a.RegisterAllClusterMessageHandlers() einterfaces.GetClusterInterface().StartInterNodeCommunication() } @@ -132,7 +134,7 @@ func runServer(configFileLocation string) { } } - jobs.Srv.Store = app.Global().Srv.Store + jobs.Srv.Store = a.Srv.Store if *utils.Cfg.JobSettings.RunJobs { jobs.Srv.StartWorkers() } @@ -157,7 +159,7 @@ func runServer(configFileLocation string) { jobs.Srv.StopSchedulers() jobs.Srv.StopWorkers() - app.Global().StopServer() + a.StopServer() } func runSecurityJob() { diff --git a/i18n/en.json b/i18n/en.json index 7503ca8e8..2c8bbd27b 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -3427,10 +3427,6 @@ "id": "app.plugin.activate.app_error", "translation": "Unable to activate extracted plugin. Plugin may already exist and be activated." }, - { - "id": "app.plugin.bad_path.app_error", - "translation": "Bad file path in extracted files" - }, { "id": "app.plugin.deactivate.app_error", "translation": "Unable to deactivate plugin" @@ -3443,6 +3439,10 @@ "id": "app.plugin.extract.app_error", "translation": "Encountered error extracting plugin" }, + { + "id": "app.plugin.filesystem.app_error", + "translation": "Encountered filesystem error" + }, { "id": "app.plugin.get_plugins.app_error", "translation": "Unable to get active plugins" @@ -3455,10 +3455,6 @@ "id": "app.plugin.mvdir.app_error", "translation": "Unable to move plugin from temporary directory to final destination" }, - { - "id": "app.plugin.no_files.app_error", - "translation": "No files found in the compressed folder" - }, { "id": "app.plugin.remove.app_error", "translation": "Unable to delete plugin" diff --git a/plugin/hooks.go b/plugin/hooks.go index 336e56ccb..7f0d8ae3c 100644 --- a/plugin/hooks.go +++ b/plugin/hooks.go @@ -12,6 +12,9 @@ type Hooks interface { // use the API, and the plugin will be terminated shortly after this invocation. OnDeactivate() error + // OnConfigurationChange is invoked when configuration changes may have been made. + OnConfigurationChange() error + // ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for // the /plugins/{id} path will be routed to the plugin. // diff --git a/plugin/pluginenv/environment.go b/plugin/pluginenv/environment.go index a943b24c6..e4a7f1b3b 100644 --- a/plugin/pluginenv/environment.go +++ b/plugin/pluginenv/environment.go @@ -4,6 +4,7 @@ package pluginenv import ( "fmt" "io/ioutil" + "net/http" "sync" "github.com/pkg/errors" @@ -27,7 +28,7 @@ type Environment struct { apiProvider APIProviderFunc supervisorProvider SupervisorProviderFunc activePlugins map[string]ActivePlugin - mutex sync.Mutex + mutex sync.RWMutex } type Option func(*Environment) @@ -61,15 +62,13 @@ func (env *Environment) SearchPath() string { // Returns a list of all plugins found within the environment. func (env *Environment) Plugins() ([]*model.BundleInfo, error) { - env.mutex.Lock() - defer env.mutex.Unlock() return ScanSearchPath(env.searchPath) } // Returns a list of all currently active plugins within the environment. func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) { - env.mutex.Lock() - defer env.mutex.Unlock() + env.mutex.RLock() + defer env.mutex.RUnlock() activePlugins := []*model.BundleInfo{} for _, p := range env.activePlugins { @@ -81,8 +80,8 @@ func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) { // Returns the ids of the currently active plugins. func (env *Environment) ActivePluginIds() (ids []string) { - env.mutex.Lock() - defer env.mutex.Unlock() + env.mutex.RLock() + defer env.mutex.RUnlock() for id := range env.activePlugins { ids = append(ids, id) @@ -200,13 +199,55 @@ func (env *Environment) Shutdown() (errs []error) { for _, activePlugin := range env.activePlugins { if activePlugin.Supervisor != nil { if err := activePlugin.Supervisor.Hooks().OnDeactivate(); err != nil { - errs = append(errs, err) + errs = append(errs, errors.Wrapf(err, "OnDeactivate() error for %v", activePlugin.BundleInfo.Manifest.Id)) } if err := activePlugin.Supervisor.Stop(); err != nil { - errs = append(errs, err) + errs = append(errs, errors.Wrapf(err, "error stopping supervisor for %v", activePlugin.BundleInfo.Manifest.Id)) } } } env.activePlugins = make(map[string]ActivePlugin) return } + +type EnvironmentHooks struct { + env *Environment +} + +func (env *Environment) Hooks() *EnvironmentHooks { + return &EnvironmentHooks{env} +} + +// OnConfigurationChange invokes the OnConfigurationChange hook for all plugins. Any errors +// encountered will be returned. +func (h *EnvironmentHooks) OnConfigurationChange() (errs []error) { + h.env.mutex.RLock() + defer h.env.mutex.RUnlock() + for _, activePlugin := range h.env.activePlugins { + if activePlugin.Supervisor == nil { + continue + } + if err := activePlugin.Supervisor.Hooks().OnConfigurationChange(); err != nil { + errs = append(errs, errors.Wrapf(err, "OnConfigurationChange error for %v", activePlugin.BundleInfo.Manifest.Id)) + } + } + return +} + +// ServeHTTP invokes the ServeHTTP hook for the plugin identified by the request or responds with a +// 404 not found. +// +// It expects the request's context to have a plugin_id set. +func (h *EnvironmentHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if id := r.Context().Value("plugin_id"); id != nil { + if idstr, ok := id.(string); ok { + h.env.mutex.RLock() + defer h.env.mutex.RUnlock() + if plugin, ok := h.env.activePlugins[idstr]; ok && plugin.Supervisor != nil { + plugin.Supervisor.Hooks().ServeHTTP(w, r) + return + } + } + } + http.NotFound(w, r) +} diff --git a/plugin/pluginenv/environment_test.go b/plugin/pluginenv/environment_test.go index e9d0820bb..f24ef8d3d 100644 --- a/plugin/pluginenv/environment_test.go +++ b/plugin/pluginenv/environment_test.go @@ -1,10 +1,14 @@ package pluginenv import ( + "context" "fmt" "io/ioutil" + "net/http" + "net/http/httptest" "os" "path/filepath" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -298,3 +302,70 @@ func TestEnvironment_ShutdownError(t *testing.T) { assert.Equal(t, env.ActivePluginIds(), []string{"foo"}) assert.Len(t, env.Shutdown(), 2) } + +func TestEnvironment_ConcurrentHookInvocations(t *testing.T) { + dir := initTmpDir(t, map[string]string{ + "foo/plugin.json": `{"id": "foo", "backend": {}}`, + }) + defer os.RemoveAll(dir) + + var provider MockProvider + defer provider.AssertExpectations(t) + + var api struct{ plugin.API } + var supervisor MockSupervisor + defer supervisor.AssertExpectations(t) + var hooks plugintest.Hooks + defer hooks.AssertExpectations(t) + + env, err := New( + SearchPath(dir), + APIProvider(provider.API), + SupervisorProvider(provider.Supervisor), + ) + require.NoError(t, err) + defer env.Shutdown() + + provider.On("API").Return(&api, nil) + provider.On("Supervisor").Return(&supervisor, nil) + + supervisor.On("Start").Return(nil) + supervisor.On("Stop").Return(nil) + supervisor.On("Hooks").Return(&hooks) + + ch := make(chan bool) + + hooks.On("OnActivate", &api).Return(nil) + hooks.On("OnDeactivate").Return(nil) + hooks.On("ServeHTTP", mock.AnythingOfType("*httptest.ResponseRecorder"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) { + r := args.Get(1).(*http.Request) + if r.URL.Path == "/1" { + <-ch + } else { + ch <- true + } + }) + + assert.NoError(t, env.ActivatePlugin("foo")) + + rec := httptest.NewRecorder() + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + req, err := http.NewRequest("GET", "/1", nil) + require.NoError(t, err) + env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo"))) + wg.Done() + }() + + go func() { + req, err := http.NewRequest("GET", "/2", nil) + require.NoError(t, err) + env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo"))) + wg.Done() + }() + + wg.Wait() +} diff --git a/plugin/plugintest/hooks.go b/plugin/plugintest/hooks.go index b0053a1ad..721a709ea 100644 --- a/plugin/plugintest/hooks.go +++ b/plugin/plugintest/hooks.go @@ -22,6 +22,10 @@ func (m *Hooks) OnDeactivate() error { return m.Called().Error(0) } +func (m *Hooks) OnConfigurationChange() error { + return m.Called().Error(0) +} + func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.Called(w, r) } diff --git a/plugin/rpcplugin/hooks.go b/plugin/rpcplugin/hooks.go index 68bce41eb..18e4a6672 100644 --- a/plugin/rpcplugin/hooks.go +++ b/plugin/rpcplugin/hooks.go @@ -86,6 +86,15 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) { return } +func (h *LocalHooks) OnConfigurationChange(args, reply *struct{}) error { + if hook, ok := h.hooks.(interface { + OnConfigurationChange() error + }); ok { + return hook.OnConfigurationChange() + } + return nil +} + type ServeHTTPArgs struct { ResponseWriterStream int64 Request *http.Request @@ -122,11 +131,14 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) { server.ServeConn(conn) } +// These assignments are part of the wire protocol. You can add more, but should not change existing +// assignments. const ( - remoteOnActivate = iota - remoteOnDeactivate - remoteServeHTTP - maxRemoteHookCount + remoteOnActivate = 0 + remoteOnDeactivate = 1 + remoteServeHTTP = 2 + remoteOnConfigurationChange = 3 + maxRemoteHookCount = iota ) type RemoteHooks struct { @@ -164,6 +176,13 @@ func (h *RemoteHooks) OnDeactivate() error { return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil) } +func (h *RemoteHooks) OnConfigurationChange() error { + if !h.implemented[remoteOnConfigurationChange] { + return nil + } + return h.client.Call("LocalHooks.OnConfigurationChange", struct{}{}, nil) +} + func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !h.implemented[remoteServeHTTP] { http.NotFound(w, r) @@ -227,6 +246,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) { remote.implemented[remoteOnActivate] = true case "OnDeactivate": remote.implemented[remoteOnDeactivate] = true + case "OnConfigurationChange": + remote.implemented[remoteOnConfigurationChange] = true case "ServeHTTP": remote.implemented[remoteServeHTTP] = true } diff --git a/plugin/rpcplugin/hooks_test.go b/plugin/rpcplugin/hooks_test.go index c3c6c8448..37c529510 100644 --- a/plugin/rpcplugin/hooks_test.go +++ b/plugin/rpcplugin/hooks_test.go @@ -6,10 +6,12 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/mattermost/mattermost-server/plugin" "github.com/mattermost/mattermost-server/plugin/plugintest" @@ -50,6 +52,9 @@ func TestHooks(t *testing.T) { hooks.On("OnDeactivate").Return(nil) assert.NoError(t, remote.OnDeactivate()) + hooks.On("OnConfigurationChange").Return(nil) + assert.NoError(t, remote.OnConfigurationChange()) + hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) { w := args.Get(0).(http.ResponseWriter) r := args.Get(1).(*http.Request) @@ -77,6 +82,45 @@ func TestHooks(t *testing.T) { })) } +func TestHooks_Concurrency(t *testing.T) { + var hooks plugintest.Hooks + defer hooks.AssertExpectations(t) + + assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) { + ch := make(chan bool) + + hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) { + r := args.Get(1).(*http.Request) + if r.URL.Path == "/1" { + <-ch + } else { + ch <- true + } + }) + + rec := httptest.NewRecorder() + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + req, err := http.NewRequest("GET", "/1", nil) + require.NoError(t, err) + remote.ServeHTTP(rec, req) + wg.Done() + }() + + go func() { + req, err := http.NewRequest("GET", "/2", nil) + require.NoError(t, err) + remote.ServeHTTP(rec, req) + wg.Done() + }() + + wg.Wait() + })) +} + type testHooks struct { mock.Mock } diff --git a/plugin/rpcplugin/io.go b/plugin/rpcplugin/io.go index 38229d868..21d79ab0b 100644 --- a/plugin/rpcplugin/io.go +++ b/plugin/rpcplugin/io.go @@ -2,26 +2,169 @@ package rpcplugin import ( "bufio" + "bytes" "encoding/binary" "io" - "os" + "sync" ) +type asyncRead struct { + b []byte + err error +} + +type asyncReadCloser struct { + io.ReadCloser + buffer bytes.Buffer + read chan struct{} + reads chan asyncRead + close chan struct{} + closeOnce sync.Once +} + +// NewAsyncReadCloser returns a ReadCloser that supports Close during Read. +func NewAsyncReadCloser(r io.ReadCloser) io.ReadCloser { + ret := &asyncReadCloser{ + ReadCloser: r, + read: make(chan struct{}), + reads: make(chan asyncRead), + close: make(chan struct{}), + } + go ret.loop() + return ret +} + +func (r *asyncReadCloser) loop() { + buf := make([]byte, 1024*8) + var n int + var err error + for { + select { + case <-r.read: + n = 0 + if err == nil { + n, err = r.ReadCloser.Read(buf) + } + select { + case r.reads <- asyncRead{buf[:n], err}: + case <-r.close: + } + case <-r.close: + r.ReadCloser.Close() + return + } + } +} + +func (r *asyncReadCloser) Read(b []byte) (int, error) { + if r.buffer.Len() > 0 { + return r.buffer.Read(b) + } + select { + case r.read <- struct{}{}: + case <-r.close: + } + select { + case read := <-r.reads: + if read.err != nil { + return 0, read.err + } + n := copy(b, read.b) + if n < len(read.b) { + r.buffer.Write(read.b[n:]) + } + return n, nil + case <-r.close: + return 0, io.EOF + } +} + +func (r *asyncReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.close) + }) + return nil +} + +type asyncWrite struct { + n int + err error +} + +type asyncWriteCloser struct { + io.WriteCloser + writeBuffer bytes.Buffer + write chan struct{} + writes chan asyncWrite + close chan struct{} + closeOnce sync.Once +} + +// NewAsyncWriteCloser returns a WriteCloser that supports Close during Write. +func NewAsyncWriteCloser(w io.WriteCloser) io.WriteCloser { + ret := &asyncWriteCloser{ + WriteCloser: w, + write: make(chan struct{}), + writes: make(chan asyncWrite), + close: make(chan struct{}), + } + go ret.loop() + return ret +} + +func (w *asyncWriteCloser) loop() { + var n int64 + var err error + for { + select { + case <-w.write: + n = 0 + if err == nil { + n, err = w.writeBuffer.WriteTo(w.WriteCloser) + } + select { + case w.writes <- asyncWrite{int(n), err}: + case <-w.close: + } + case <-w.close: + w.WriteCloser.Close() + return + } + } +} + +func (w *asyncWriteCloser) Write(b []byte) (int, error) { + if n, err := w.writeBuffer.Write(b); err != nil { + return n, err + } + select { + case w.write <- struct{}{}: + case <-w.close: + } + select { + case write := <-w.writes: + return write.n, write.err + case <-w.close: + return 0, io.EOF + } +} + +func (w *asyncWriteCloser) Close() error { + w.closeOnce.Do(func() { + close(w.close) + }) + return nil +} + type rwc struct { io.ReadCloser io.WriteCloser } func (rwc *rwc) Close() (err error) { - if f, ok := rwc.ReadCloser.(*os.File); ok { - // https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion - err = os.NewFile(f.Fd(), "").Close() - } else { - err = rwc.ReadCloser.Close() - } - werr := rwc.WriteCloser.Close() - if err == nil { - err = werr + err = rwc.WriteCloser.Close() + if rerr := rwc.ReadCloser.Close(); err == nil { + err = rerr } return } diff --git a/plugin/rpcplugin/io_test.go b/plugin/rpcplugin/io_test.go new file mode 100644 index 000000000..cb31b23b3 --- /dev/null +++ b/plugin/rpcplugin/io_test.go @@ -0,0 +1,73 @@ +package rpcplugin + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAsyncReadCloser(t *testing.T) { + rf, w, err := os.Pipe() + require.NoError(t, err) + r := NewAsyncReadCloser(rf) + defer r.Close() + + go func() { + w.Write([]byte("foo")) + w.Close() + }() + + foo, err := ioutil.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, "foo", string(foo)) +} + +func TestNewAsyncReadCloser_CloseDuringRead(t *testing.T) { + rf, w, err := os.Pipe() + require.NoError(t, err) + defer w.Close() + + r := NewAsyncReadCloser(rf) + + go func() { + time.Sleep(time.Millisecond * 200) + r.Close() + }() + r.Read(make([]byte, 10)) +} + +func TestNewAsyncWriteCloser(t *testing.T) { + r, wf, err := os.Pipe() + require.NoError(t, err) + w := NewAsyncWriteCloser(wf) + defer w.Close() + + go func() { + foo, err := ioutil.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, "foo", string(foo)) + r.Close() + }() + + n, err := w.Write([]byte("foo")) + require.NoError(t, err) + assert.Equal(t, 3, n) +} + +func TestNewAsyncWriteCloser_CloseDuringWrite(t *testing.T) { + r, wf, err := os.Pipe() + require.NoError(t, err) + defer r.Close() + + w := NewAsyncWriteCloser(wf) + + go func() { + time.Sleep(time.Millisecond * 200) + w.Close() + }() + w.Write(make([]byte, 10)) +} diff --git a/plugin/rpcplugin/ipc.go b/plugin/rpcplugin/ipc.go index 3e6c89c4f..bbb3db06e 100644 --- a/plugin/rpcplugin/ipc.go +++ b/plugin/rpcplugin/ipc.go @@ -19,7 +19,7 @@ func NewIPC() (io.ReadWriteCloser, []*os.File, error) { childWriter.Close() return nil, nil, err } - return NewReadWriteCloser(parentReader, parentWriter), []*os.File{childReader, childWriter}, nil + return NewReadWriteCloser(NewAsyncReadCloser(parentReader), NewAsyncWriteCloser(parentWriter)), []*os.File{childReader, childWriter}, nil } // Returns the IPC instance inherited by the process from its parent. diff --git a/plugin/rpcplugin/supervisor.go b/plugin/rpcplugin/supervisor.go index 6a00d0468..7e37e2851 100644 --- a/plugin/rpcplugin/supervisor.go +++ b/plugin/rpcplugin/supervisor.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strings" "sync/atomic" "time" @@ -123,7 +124,11 @@ func SupervisorProvider(bundle *model.BundleInfo) (plugin.Supervisor, error) { } else if bundle.Manifest.Backend == nil || bundle.Manifest.Backend.Executable == "" { return nil, fmt.Errorf("no backend executable specified") } + executable := filepath.Clean(filepath.Join(".", bundle.Manifest.Backend.Executable)) + if strings.HasPrefix(executable, "..") { + return nil, fmt.Errorf("invalid backend executable") + } return &Supervisor{ - executable: filepath.Join(bundle.Path, bundle.Manifest.Backend.Executable), + executable: filepath.Join(bundle.Path, executable), }, nil } diff --git a/plugin/rpcplugin/supervisor_test.go b/plugin/rpcplugin/supervisor_test.go index 6940adcad..bad38b2d7 100644 --- a/plugin/rpcplugin/supervisor_test.go +++ b/plugin/rpcplugin/supervisor_test.go @@ -43,6 +43,19 @@ func TestSupervisor(t *testing.T) { require.NoError(t, supervisor.Stop()) } +func TestSupervisor_InvalidExecutablePath(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(dir) + + ioutil.WriteFile(filepath.Join(dir, "plugin.json"), []byte(`{"id": "foo", "backend": {"executable": "/foo/../../backend.exe"}}`), 0600) + + bundle := model.BundleInfoForPath(dir) + supervisor, err := SupervisorProvider(bundle) + assert.Nil(t, supervisor) + assert.Error(t, err) +} + // If plugin development goes really wrong, let's make sure plugin activation won't block forever. func TestSupervisor_StartTimeout(t *testing.T) { dir, err := ioutil.TempDir("", "") diff --git a/utils/extract.go b/utils/extract.go index 0559c6ce8..bc8e07f75 100644 --- a/utils/extract.go +++ b/utils/extract.go @@ -13,19 +13,16 @@ import ( ) // ExtractTarGz takes in an io.Reader containing the bytes for a .tar.gz file and -// a destination string to extract to. A list of the file and directory names that -// were extracted is returned. -func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) { +// a destination string to extract to. +func ExtractTarGz(gzipStream io.Reader, dst string) error { uncompressedStream, err := gzip.NewReader(gzipStream) if err != nil { - return nil, fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: NewReader failed: %s", err.Error()) } defer uncompressedStream.Close() tarReader := tar.NewReader(uncompressedStream) - filenames := []string{} - for true { header, err := tarReader.Next() @@ -34,50 +31,46 @@ func ExtractTarGz(gzipStream io.Reader, dst string) ([]string, error) { } if err != nil { - return nil, fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: Next() failed: %s", err.Error()) } switch header.Typeflag { case tar.TypeDir: if PathTraversesUpward(header.Name) { - return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards") + return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards") } path := filepath.Join(dst, header.Name) if err := os.Mkdir(path, 0744); err != nil && !os.IsExist(err) { - return nil, fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: Mkdir() failed: %s", err.Error()) } - - filenames = append(filenames, header.Name) case tar.TypeReg: if PathTraversesUpward(header.Name) { - return nil, fmt.Errorf("ExtractTarGz: path attempts to traverse upwards") + return fmt.Errorf("ExtractTarGz: path attempts to traverse upwards") } path := filepath.Join(dst, header.Name) dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0744); err != nil { - return nil, fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: MkdirAll() failed: %s", err.Error()) } - outFile, err := os.Create(path) + outFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) if err != nil { - return nil, fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: Create() failed: %s", err.Error()) } defer outFile.Close() if _, err := io.Copy(outFile, tarReader); err != nil { - return nil, fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error()) + return fmt.Errorf("ExtractTarGz: Copy() failed: %s", err.Error()) } - - filenames = append(filenames, header.Name) default: - return nil, fmt.Errorf( + return fmt.Errorf( "ExtractTarGz: unknown type: %v in %v", header.Typeflag, header.Name) } } - return filenames, nil + return nil } -- cgit v1.2.3-1-g7c22