summaryrefslogtreecommitdiffstats
path: root/plugin
diff options
context:
space:
mode:
authorChris <ccbrown112@gmail.com>2017-09-11 10:02:02 -0500
committerGitHub <noreply@github.com>2017-09-11 10:02:02 -0500
commit402491b7e52c4d836c1274976cdb387852cfd17b (patch)
treee8adcbdf0af5370f8af11e3fc1021a328c971a5d /plugin
parenta69bed712d53e9a7984915fffffc8a2fd1647a7a (diff)
downloadchat-402491b7e52c4d836c1274976cdb387852cfd17b.tar.gz
chat-402491b7e52c4d836c1274976cdb387852cfd17b.tar.bz2
chat-402491b7e52c4d836c1274976cdb387852cfd17b.zip
PLT-7407: Back-end plugins (#7409)
* tie back-end plugins together * fix comment typo * add tests and a bit of polish * tests and polish * add test, don't let backend executable paths escape the plugin directory
Diffstat (limited to 'plugin')
-rw-r--r--plugin/hooks.go3
-rw-r--r--plugin/pluginenv/environment.go59
-rw-r--r--plugin/pluginenv/environment_test.go71
-rw-r--r--plugin/plugintest/hooks.go4
-rw-r--r--plugin/rpcplugin/hooks.go29
-rw-r--r--plugin/rpcplugin/hooks_test.go44
-rw-r--r--plugin/rpcplugin/io.go163
-rw-r--r--plugin/rpcplugin/io_test.go73
-rw-r--r--plugin/rpcplugin/ipc.go2
-rw-r--r--plugin/rpcplugin/supervisor.go7
-rw-r--r--plugin/rpcplugin/supervisor_test.go13
11 files changed, 443 insertions, 25 deletions
diff --git a/plugin/hooks.go b/plugin/hooks.go
index 336e56ccb..7f0d8ae3c 100644
--- a/plugin/hooks.go
+++ b/plugin/hooks.go
@@ -12,6 +12,9 @@ type Hooks interface {
// use the API, and the plugin will be terminated shortly after this invocation.
OnDeactivate() error
+ // OnConfigurationChange is invoked when configuration changes may have been made.
+ OnConfigurationChange() error
+
// ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for
// the /plugins/{id} path will be routed to the plugin.
//
diff --git a/plugin/pluginenv/environment.go b/plugin/pluginenv/environment.go
index a943b24c6..e4a7f1b3b 100644
--- a/plugin/pluginenv/environment.go
+++ b/plugin/pluginenv/environment.go
@@ -4,6 +4,7 @@ package pluginenv
import (
"fmt"
"io/ioutil"
+ "net/http"
"sync"
"github.com/pkg/errors"
@@ -27,7 +28,7 @@ type Environment struct {
apiProvider APIProviderFunc
supervisorProvider SupervisorProviderFunc
activePlugins map[string]ActivePlugin
- mutex sync.Mutex
+ mutex sync.RWMutex
}
type Option func(*Environment)
@@ -61,15 +62,13 @@ func (env *Environment) SearchPath() string {
// Returns a list of all plugins found within the environment.
func (env *Environment) Plugins() ([]*model.BundleInfo, error) {
- env.mutex.Lock()
- defer env.mutex.Unlock()
return ScanSearchPath(env.searchPath)
}
// Returns a list of all currently active plugins within the environment.
func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
- env.mutex.Lock()
- defer env.mutex.Unlock()
+ env.mutex.RLock()
+ defer env.mutex.RUnlock()
activePlugins := []*model.BundleInfo{}
for _, p := range env.activePlugins {
@@ -81,8 +80,8 @@ func (env *Environment) ActivePlugins() ([]*model.BundleInfo, error) {
// Returns the ids of the currently active plugins.
func (env *Environment) ActivePluginIds() (ids []string) {
- env.mutex.Lock()
- defer env.mutex.Unlock()
+ env.mutex.RLock()
+ defer env.mutex.RUnlock()
for id := range env.activePlugins {
ids = append(ids, id)
@@ -200,13 +199,55 @@ func (env *Environment) Shutdown() (errs []error) {
for _, activePlugin := range env.activePlugins {
if activePlugin.Supervisor != nil {
if err := activePlugin.Supervisor.Hooks().OnDeactivate(); err != nil {
- errs = append(errs, err)
+ errs = append(errs, errors.Wrapf(err, "OnDeactivate() error for %v", activePlugin.BundleInfo.Manifest.Id))
}
if err := activePlugin.Supervisor.Stop(); err != nil {
- errs = append(errs, err)
+ errs = append(errs, errors.Wrapf(err, "error stopping supervisor for %v", activePlugin.BundleInfo.Manifest.Id))
}
}
}
env.activePlugins = make(map[string]ActivePlugin)
return
}
+
+type EnvironmentHooks struct {
+ env *Environment
+}
+
+func (env *Environment) Hooks() *EnvironmentHooks {
+ return &EnvironmentHooks{env}
+}
+
+// OnConfigurationChange invokes the OnConfigurationChange hook for all plugins. Any errors
+// encountered will be returned.
+func (h *EnvironmentHooks) OnConfigurationChange() (errs []error) {
+ h.env.mutex.RLock()
+ defer h.env.mutex.RUnlock()
+ for _, activePlugin := range h.env.activePlugins {
+ if activePlugin.Supervisor == nil {
+ continue
+ }
+ if err := activePlugin.Supervisor.Hooks().OnConfigurationChange(); err != nil {
+ errs = append(errs, errors.Wrapf(err, "OnConfigurationChange error for %v", activePlugin.BundleInfo.Manifest.Id))
+ }
+ }
+ return
+}
+
+// ServeHTTP invokes the ServeHTTP hook for the plugin identified by the request or responds with a
+// 404 not found.
+//
+// It expects the request's context to have a plugin_id set.
+func (h *EnvironmentHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if id := r.Context().Value("plugin_id"); id != nil {
+ if idstr, ok := id.(string); ok {
+ h.env.mutex.RLock()
+ defer h.env.mutex.RUnlock()
+ if plugin, ok := h.env.activePlugins[idstr]; ok && plugin.Supervisor != nil {
+ plugin.Supervisor.Hooks().ServeHTTP(w, r)
+ return
+ }
+ }
+ }
+ http.NotFound(w, r)
+}
diff --git a/plugin/pluginenv/environment_test.go b/plugin/pluginenv/environment_test.go
index e9d0820bb..f24ef8d3d 100644
--- a/plugin/pluginenv/environment_test.go
+++ b/plugin/pluginenv/environment_test.go
@@ -1,10 +1,14 @@
package pluginenv
import (
+ "context"
"fmt"
"io/ioutil"
+ "net/http"
+ "net/http/httptest"
"os"
"path/filepath"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -298,3 +302,70 @@ func TestEnvironment_ShutdownError(t *testing.T) {
assert.Equal(t, env.ActivePluginIds(), []string{"foo"})
assert.Len(t, env.Shutdown(), 2)
}
+
+func TestEnvironment_ConcurrentHookInvocations(t *testing.T) {
+ dir := initTmpDir(t, map[string]string{
+ "foo/plugin.json": `{"id": "foo", "backend": {}}`,
+ })
+ defer os.RemoveAll(dir)
+
+ var provider MockProvider
+ defer provider.AssertExpectations(t)
+
+ var api struct{ plugin.API }
+ var supervisor MockSupervisor
+ defer supervisor.AssertExpectations(t)
+ var hooks plugintest.Hooks
+ defer hooks.AssertExpectations(t)
+
+ env, err := New(
+ SearchPath(dir),
+ APIProvider(provider.API),
+ SupervisorProvider(provider.Supervisor),
+ )
+ require.NoError(t, err)
+ defer env.Shutdown()
+
+ provider.On("API").Return(&api, nil)
+ provider.On("Supervisor").Return(&supervisor, nil)
+
+ supervisor.On("Start").Return(nil)
+ supervisor.On("Stop").Return(nil)
+ supervisor.On("Hooks").Return(&hooks)
+
+ ch := make(chan bool)
+
+ hooks.On("OnActivate", &api).Return(nil)
+ hooks.On("OnDeactivate").Return(nil)
+ hooks.On("ServeHTTP", mock.AnythingOfType("*httptest.ResponseRecorder"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
+ r := args.Get(1).(*http.Request)
+ if r.URL.Path == "/1" {
+ <-ch
+ } else {
+ ch <- true
+ }
+ })
+
+ assert.NoError(t, env.ActivatePlugin("foo"))
+
+ rec := httptest.NewRecorder()
+
+ wg := sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ req, err := http.NewRequest("GET", "/1", nil)
+ require.NoError(t, err)
+ env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
+ wg.Done()
+ }()
+
+ go func() {
+ req, err := http.NewRequest("GET", "/2", nil)
+ require.NoError(t, err)
+ env.Hooks().ServeHTTP(rec, req.WithContext(context.WithValue(context.Background(), "plugin_id", "foo")))
+ wg.Done()
+ }()
+
+ wg.Wait()
+}
diff --git a/plugin/plugintest/hooks.go b/plugin/plugintest/hooks.go
index b0053a1ad..721a709ea 100644
--- a/plugin/plugintest/hooks.go
+++ b/plugin/plugintest/hooks.go
@@ -22,6 +22,10 @@ func (m *Hooks) OnDeactivate() error {
return m.Called().Error(0)
}
+func (m *Hooks) OnConfigurationChange() error {
+ return m.Called().Error(0)
+}
+
func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.Called(w, r)
}
diff --git a/plugin/rpcplugin/hooks.go b/plugin/rpcplugin/hooks.go
index 68bce41eb..18e4a6672 100644
--- a/plugin/rpcplugin/hooks.go
+++ b/plugin/rpcplugin/hooks.go
@@ -86,6 +86,15 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) {
return
}
+func (h *LocalHooks) OnConfigurationChange(args, reply *struct{}) error {
+ if hook, ok := h.hooks.(interface {
+ OnConfigurationChange() error
+ }); ok {
+ return hook.OnConfigurationChange()
+ }
+ return nil
+}
+
type ServeHTTPArgs struct {
ResponseWriterStream int64
Request *http.Request
@@ -122,11 +131,14 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) {
server.ServeConn(conn)
}
+// These assignments are part of the wire protocol. You can add more, but should not change existing
+// assignments.
const (
- remoteOnActivate = iota
- remoteOnDeactivate
- remoteServeHTTP
- maxRemoteHookCount
+ remoteOnActivate = 0
+ remoteOnDeactivate = 1
+ remoteServeHTTP = 2
+ remoteOnConfigurationChange = 3
+ maxRemoteHookCount = iota
)
type RemoteHooks struct {
@@ -164,6 +176,13 @@ func (h *RemoteHooks) OnDeactivate() error {
return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil)
}
+func (h *RemoteHooks) OnConfigurationChange() error {
+ if !h.implemented[remoteOnConfigurationChange] {
+ return nil
+ }
+ return h.client.Call("LocalHooks.OnConfigurationChange", struct{}{}, nil)
+}
+
func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !h.implemented[remoteServeHTTP] {
http.NotFound(w, r)
@@ -227,6 +246,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) {
remote.implemented[remoteOnActivate] = true
case "OnDeactivate":
remote.implemented[remoteOnDeactivate] = true
+ case "OnConfigurationChange":
+ remote.implemented[remoteOnConfigurationChange] = true
case "ServeHTTP":
remote.implemented[remoteServeHTTP] = true
}
diff --git a/plugin/rpcplugin/hooks_test.go b/plugin/rpcplugin/hooks_test.go
index c3c6c8448..37c529510 100644
--- a/plugin/rpcplugin/hooks_test.go
+++ b/plugin/rpcplugin/hooks_test.go
@@ -6,10 +6,12 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
"github.com/mattermost/mattermost-server/plugin"
"github.com/mattermost/mattermost-server/plugin/plugintest"
@@ -50,6 +52,9 @@ func TestHooks(t *testing.T) {
hooks.On("OnDeactivate").Return(nil)
assert.NoError(t, remote.OnDeactivate())
+ hooks.On("OnConfigurationChange").Return(nil)
+ assert.NoError(t, remote.OnConfigurationChange())
+
hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
w := args.Get(0).(http.ResponseWriter)
r := args.Get(1).(*http.Request)
@@ -77,6 +82,45 @@ func TestHooks(t *testing.T) {
}))
}
+func TestHooks_Concurrency(t *testing.T) {
+ var hooks plugintest.Hooks
+ defer hooks.AssertExpectations(t)
+
+ assert.NoError(t, testHooksRPC(&hooks, func(remote *RemoteHooks) {
+ ch := make(chan bool)
+
+ hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
+ r := args.Get(1).(*http.Request)
+ if r.URL.Path == "/1" {
+ <-ch
+ } else {
+ ch <- true
+ }
+ })
+
+ rec := httptest.NewRecorder()
+
+ wg := sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ req, err := http.NewRequest("GET", "/1", nil)
+ require.NoError(t, err)
+ remote.ServeHTTP(rec, req)
+ wg.Done()
+ }()
+
+ go func() {
+ req, err := http.NewRequest("GET", "/2", nil)
+ require.NoError(t, err)
+ remote.ServeHTTP(rec, req)
+ wg.Done()
+ }()
+
+ wg.Wait()
+ }))
+}
+
type testHooks struct {
mock.Mock
}
diff --git a/plugin/rpcplugin/io.go b/plugin/rpcplugin/io.go
index 38229d868..21d79ab0b 100644
--- a/plugin/rpcplugin/io.go
+++ b/plugin/rpcplugin/io.go
@@ -2,26 +2,169 @@ package rpcplugin
import (
"bufio"
+ "bytes"
"encoding/binary"
"io"
- "os"
+ "sync"
)
+type asyncRead struct {
+ b []byte
+ err error
+}
+
+type asyncReadCloser struct {
+ io.ReadCloser
+ buffer bytes.Buffer
+ read chan struct{}
+ reads chan asyncRead
+ close chan struct{}
+ closeOnce sync.Once
+}
+
+// NewAsyncReadCloser returns a ReadCloser that supports Close during Read.
+func NewAsyncReadCloser(r io.ReadCloser) io.ReadCloser {
+ ret := &asyncReadCloser{
+ ReadCloser: r,
+ read: make(chan struct{}),
+ reads: make(chan asyncRead),
+ close: make(chan struct{}),
+ }
+ go ret.loop()
+ return ret
+}
+
+func (r *asyncReadCloser) loop() {
+ buf := make([]byte, 1024*8)
+ var n int
+ var err error
+ for {
+ select {
+ case <-r.read:
+ n = 0
+ if err == nil {
+ n, err = r.ReadCloser.Read(buf)
+ }
+ select {
+ case r.reads <- asyncRead{buf[:n], err}:
+ case <-r.close:
+ }
+ case <-r.close:
+ r.ReadCloser.Close()
+ return
+ }
+ }
+}
+
+func (r *asyncReadCloser) Read(b []byte) (int, error) {
+ if r.buffer.Len() > 0 {
+ return r.buffer.Read(b)
+ }
+ select {
+ case r.read <- struct{}{}:
+ case <-r.close:
+ }
+ select {
+ case read := <-r.reads:
+ if read.err != nil {
+ return 0, read.err
+ }
+ n := copy(b, read.b)
+ if n < len(read.b) {
+ r.buffer.Write(read.b[n:])
+ }
+ return n, nil
+ case <-r.close:
+ return 0, io.EOF
+ }
+}
+
+func (r *asyncReadCloser) Close() error {
+ r.closeOnce.Do(func() {
+ close(r.close)
+ })
+ return nil
+}
+
+type asyncWrite struct {
+ n int
+ err error
+}
+
+type asyncWriteCloser struct {
+ io.WriteCloser
+ writeBuffer bytes.Buffer
+ write chan struct{}
+ writes chan asyncWrite
+ close chan struct{}
+ closeOnce sync.Once
+}
+
+// NewAsyncWriteCloser returns a WriteCloser that supports Close during Write.
+func NewAsyncWriteCloser(w io.WriteCloser) io.WriteCloser {
+ ret := &asyncWriteCloser{
+ WriteCloser: w,
+ write: make(chan struct{}),
+ writes: make(chan asyncWrite),
+ close: make(chan struct{}),
+ }
+ go ret.loop()
+ return ret
+}
+
+func (w *asyncWriteCloser) loop() {
+ var n int64
+ var err error
+ for {
+ select {
+ case <-w.write:
+ n = 0
+ if err == nil {
+ n, err = w.writeBuffer.WriteTo(w.WriteCloser)
+ }
+ select {
+ case w.writes <- asyncWrite{int(n), err}:
+ case <-w.close:
+ }
+ case <-w.close:
+ w.WriteCloser.Close()
+ return
+ }
+ }
+}
+
+func (w *asyncWriteCloser) Write(b []byte) (int, error) {
+ if n, err := w.writeBuffer.Write(b); err != nil {
+ return n, err
+ }
+ select {
+ case w.write <- struct{}{}:
+ case <-w.close:
+ }
+ select {
+ case write := <-w.writes:
+ return write.n, write.err
+ case <-w.close:
+ return 0, io.EOF
+ }
+}
+
+func (w *asyncWriteCloser) Close() error {
+ w.closeOnce.Do(func() {
+ close(w.close)
+ })
+ return nil
+}
+
type rwc struct {
io.ReadCloser
io.WriteCloser
}
func (rwc *rwc) Close() (err error) {
- if f, ok := rwc.ReadCloser.(*os.File); ok {
- // https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion
- err = os.NewFile(f.Fd(), "").Close()
- } else {
- err = rwc.ReadCloser.Close()
- }
- werr := rwc.WriteCloser.Close()
- if err == nil {
- err = werr
+ err = rwc.WriteCloser.Close()
+ if rerr := rwc.ReadCloser.Close(); err == nil {
+ err = rerr
}
return
}
diff --git a/plugin/rpcplugin/io_test.go b/plugin/rpcplugin/io_test.go
new file mode 100644
index 000000000..cb31b23b3
--- /dev/null
+++ b/plugin/rpcplugin/io_test.go
@@ -0,0 +1,73 @@
+package rpcplugin
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAsyncReadCloser(t *testing.T) {
+ rf, w, err := os.Pipe()
+ require.NoError(t, err)
+ r := NewAsyncReadCloser(rf)
+ defer r.Close()
+
+ go func() {
+ w.Write([]byte("foo"))
+ w.Close()
+ }()
+
+ foo, err := ioutil.ReadAll(r)
+ require.NoError(t, err)
+ assert.Equal(t, "foo", string(foo))
+}
+
+func TestNewAsyncReadCloser_CloseDuringRead(t *testing.T) {
+ rf, w, err := os.Pipe()
+ require.NoError(t, err)
+ defer w.Close()
+
+ r := NewAsyncReadCloser(rf)
+
+ go func() {
+ time.Sleep(time.Millisecond * 200)
+ r.Close()
+ }()
+ r.Read(make([]byte, 10))
+}
+
+func TestNewAsyncWriteCloser(t *testing.T) {
+ r, wf, err := os.Pipe()
+ require.NoError(t, err)
+ w := NewAsyncWriteCloser(wf)
+ defer w.Close()
+
+ go func() {
+ foo, err := ioutil.ReadAll(r)
+ require.NoError(t, err)
+ assert.Equal(t, "foo", string(foo))
+ r.Close()
+ }()
+
+ n, err := w.Write([]byte("foo"))
+ require.NoError(t, err)
+ assert.Equal(t, 3, n)
+}
+
+func TestNewAsyncWriteCloser_CloseDuringWrite(t *testing.T) {
+ r, wf, err := os.Pipe()
+ require.NoError(t, err)
+ defer r.Close()
+
+ w := NewAsyncWriteCloser(wf)
+
+ go func() {
+ time.Sleep(time.Millisecond * 200)
+ w.Close()
+ }()
+ w.Write(make([]byte, 10))
+}
diff --git a/plugin/rpcplugin/ipc.go b/plugin/rpcplugin/ipc.go
index 3e6c89c4f..bbb3db06e 100644
--- a/plugin/rpcplugin/ipc.go
+++ b/plugin/rpcplugin/ipc.go
@@ -19,7 +19,7 @@ func NewIPC() (io.ReadWriteCloser, []*os.File, error) {
childWriter.Close()
return nil, nil, err
}
- return NewReadWriteCloser(parentReader, parentWriter), []*os.File{childReader, childWriter}, nil
+ return NewReadWriteCloser(NewAsyncReadCloser(parentReader), NewAsyncWriteCloser(parentWriter)), []*os.File{childReader, childWriter}, nil
}
// Returns the IPC instance inherited by the process from its parent.
diff --git a/plugin/rpcplugin/supervisor.go b/plugin/rpcplugin/supervisor.go
index 6a00d0468..7e37e2851 100644
--- a/plugin/rpcplugin/supervisor.go
+++ b/plugin/rpcplugin/supervisor.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"path/filepath"
+ "strings"
"sync/atomic"
"time"
@@ -123,7 +124,11 @@ func SupervisorProvider(bundle *model.BundleInfo) (plugin.Supervisor, error) {
} else if bundle.Manifest.Backend == nil || bundle.Manifest.Backend.Executable == "" {
return nil, fmt.Errorf("no backend executable specified")
}
+ executable := filepath.Clean(filepath.Join(".", bundle.Manifest.Backend.Executable))
+ if strings.HasPrefix(executable, "..") {
+ return nil, fmt.Errorf("invalid backend executable")
+ }
return &Supervisor{
- executable: filepath.Join(bundle.Path, bundle.Manifest.Backend.Executable),
+ executable: filepath.Join(bundle.Path, executable),
}, nil
}
diff --git a/plugin/rpcplugin/supervisor_test.go b/plugin/rpcplugin/supervisor_test.go
index 6940adcad..bad38b2d7 100644
--- a/plugin/rpcplugin/supervisor_test.go
+++ b/plugin/rpcplugin/supervisor_test.go
@@ -43,6 +43,19 @@ func TestSupervisor(t *testing.T) {
require.NoError(t, supervisor.Stop())
}
+func TestSupervisor_InvalidExecutablePath(t *testing.T) {
+ dir, err := ioutil.TempDir("", "")
+ require.NoError(t, err)
+ defer os.RemoveAll(dir)
+
+ ioutil.WriteFile(filepath.Join(dir, "plugin.json"), []byte(`{"id": "foo", "backend": {"executable": "/foo/../../backend.exe"}}`), 0600)
+
+ bundle := model.BundleInfoForPath(dir)
+ supervisor, err := SupervisorProvider(bundle)
+ assert.Nil(t, supervisor)
+ assert.Error(t, err)
+}
+
// If plugin development goes really wrong, let's make sure plugin activation won't block forever.
func TestSupervisor_StartTimeout(t *testing.T) {
dir, err := ioutil.TempDir("", "")