From 6215c9159acb85033616d2937edf3d87ef7ca79b Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 28 Aug 2017 11:27:18 -0500 Subject: add plugin http handler (#7289) --- plugin/hooks.go | 11 ++++++ plugin/plugintest/hooks.go | 6 +++ plugin/rpcplugin/api.go | 10 ++--- plugin/rpcplugin/hooks.go | 72 ++++++++++++++++++++++++++++++++++ plugin/rpcplugin/hooks_test.go | 62 +++++++++++++++++++++++++++-- plugin/rpcplugin/http.go | 88 ++++++++++++++++++++++++++++++++++++++++++ plugin/rpcplugin/http_test.go | 61 +++++++++++++++++++++++++++++ plugin/rpcplugin/io.go | 54 +++++++++++++++++++++++--- plugin/rpcplugin/muxer.go | 41 ++++++++++++-------- plugin/rpcplugin/muxer_test.go | 28 ++++++++++++++ 10 files changed, 403 insertions(+), 30 deletions(-) create mode 100644 plugin/rpcplugin/http.go create mode 100644 plugin/rpcplugin/http_test.go (limited to 'plugin') diff --git a/plugin/hooks.go b/plugin/hooks.go index 28a762a1a..336e56ccb 100644 --- a/plugin/hooks.go +++ b/plugin/hooks.go @@ -1,5 +1,9 @@ package plugin +import ( + "net/http" +) + type Hooks interface { // OnActivate is invoked when the plugin is activated. OnActivate(API) error @@ -7,4 +11,11 @@ type Hooks interface { // OnDeactivate is invoked when the plugin is deactivated. This is the plugin's last chance to // use the API, and the plugin will be terminated shortly after this invocation. OnDeactivate() error + + // ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for + // the /plugins/{id} path will be routed to the plugin. + // + // The Mattermost-User-Id header will be present if (and only if) the request is by an + // authenticated user. + ServeHTTP(http.ResponseWriter, *http.Request) } diff --git a/plugin/plugintest/hooks.go b/plugin/plugintest/hooks.go index 057c705c9..4cac515b4 100644 --- a/plugin/plugintest/hooks.go +++ b/plugin/plugintest/hooks.go @@ -1,6 +1,8 @@ package plugintest import ( + "net/http" + "github.com/stretchr/testify/mock" "github.com/mattermost/platform/plugin" @@ -19,3 +21,7 @@ func (m *Hooks) OnActivate(api plugin.API) error { func (m *Hooks) OnDeactivate() error { return m.Called().Error(0) } + +func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.Called(w, r) +} diff --git a/plugin/rpcplugin/api.go b/plugin/rpcplugin/api.go index a807d0837..84eb3baae 100644 --- a/plugin/rpcplugin/api.go +++ b/plugin/rpcplugin/api.go @@ -26,11 +26,6 @@ func (h *LocalAPI) LoadPluginConfiguration(args struct{}, reply *[]byte) error { return nil } -type RemoteAPI struct { - client *rpc.Client - muxer *Muxer -} - func ServeAPI(api plugin.API, conn io.ReadWriteCloser, muxer *Muxer) { server := rpc.NewServer() server.Register(&LocalAPI{ @@ -40,6 +35,11 @@ func ServeAPI(api plugin.API, conn io.ReadWriteCloser, muxer *Muxer) { server.ServeConn(conn) } +type RemoteAPI struct { + client *rpc.Client + muxer *Muxer +} + var _ plugin.API = (*RemoteAPI)(nil) func (h *RemoteAPI) LoadPluginConfiguration(dest interface{}) error { diff --git a/plugin/rpcplugin/hooks.go b/plugin/rpcplugin/hooks.go index 5b97742aa..995f4ae1a 100644 --- a/plugin/rpcplugin/hooks.go +++ b/plugin/rpcplugin/hooks.go @@ -1,7 +1,10 @@ package rpcplugin import ( + "bytes" "io" + "io/ioutil" + "net/http" "net/rpc" "reflect" @@ -83,6 +86,33 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) { return } +type ServeHTTPArgs struct { + ResponseWriterStream int64 + Request *http.Request + RequestBodyStream int64 +} + +func (h *LocalHooks) ServeHTTP(args ServeHTTPArgs, reply *struct{}) error { + w := ConnectHTTPResponseWriter(h.muxer.Connect(args.ResponseWriterStream)) + defer w.Close() + + r := args.Request + if args.RequestBodyStream != 0 { + r.Body = ConnectIOReader(h.muxer.Connect(args.RequestBodyStream)) + } else { + r.Body = ioutil.NopCloser(&bytes.Buffer{}) + } + defer r.Body.Close() + + if hook, ok := h.hooks.(http.Handler); ok { + hook.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } + + return nil +} + func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) { server := rpc.NewServer() server.Register(&LocalHooks{ @@ -95,6 +125,7 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) { const ( remoteOnActivate = iota remoteOnDeactivate + remoteServeHTTP maxRemoteHookCount ) @@ -133,6 +164,45 @@ func (h *RemoteHooks) OnDeactivate() error { return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil) } +func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !h.implemented[remoteServeHTTP] { + http.NotFound(w, r) + return + } + + responseWriterStream, stream := h.muxer.Serve() + defer stream.Close() + go ServeHTTPResponseWriter(w, stream) + + requestBodyStream := int64(0) + if r.Body != nil { + rid, rstream := h.muxer.Serve() + defer rstream.Close() + go ServeIOReader(r.Body, rstream) + requestBodyStream = rid + } + + forwardedRequest := &http.Request{ + Method: r.Method, + URL: r.URL, + Proto: r.Proto, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Header: r.Header, + Host: r.Host, + RemoteAddr: r.RemoteAddr, + RequestURI: r.RequestURI, + } + + if err := h.client.Call("LocalHooks.ServeHTTP", ServeHTTPArgs{ + ResponseWriterStream: responseWriterStream, + Request: forwardedRequest, + RequestBodyStream: requestBodyStream, + }, nil); err != nil { + http.Error(w, "500 internal server error", http.StatusInternalServerError) + } +} + func (h *RemoteHooks) Close() error { if h.apiCloser != nil { h.apiCloser.Close() @@ -157,6 +227,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) { remote.implemented[remoteOnActivate] = true case "OnDeactivate": remote.implemented[remoteOnDeactivate] = true + case "ServeHTTP": + remote.implemented[remoteServeHTTP] = true } } return remote, nil diff --git a/plugin/rpcplugin/hooks_test.go b/plugin/rpcplugin/hooks_test.go index 6cd7ff547..eb684956b 100644 --- a/plugin/rpcplugin/hooks_test.go +++ b/plugin/rpcplugin/hooks_test.go @@ -2,6 +2,10 @@ package rpcplugin import ( "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -45,6 +49,31 @@ func TestHooks(t *testing.T) { hooks.On("OnDeactivate").Return(nil) assert.NoError(t, remote.OnDeactivate()) + + hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) { + w := args.Get(0).(http.ResponseWriter) + r := args.Get(1).(*http.Request) + assert.Equal(t, "/foo", r.URL.Path) + assert.Equal(t, "POST", r.Method) + body, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, "asdf", string(body)) + assert.Equal(t, "header", r.Header.Get("Test-Header")) + w.Write([]byte("bar")) + }) + + w := httptest.NewRecorder() + r, err := http.NewRequest("POST", "/foo", strings.NewReader("asdf")) + r.Header.Set("Test-Header", "header") + assert.NoError(t, err) + remote.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "bar", string(body)) })) } @@ -73,9 +102,18 @@ func TestHooks_PartiallyImplemented(t *testing.T) { })) } -func BenchmarkOnDeactivate(b *testing.B) { - var hooks plugintest.Hooks - hooks.On("OnDeactivate").Return(nil) +type benchmarkHooks struct{} + +func (*benchmarkHooks) OnDeactivate() error { return nil } + +func (*benchmarkHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ioutil.ReadAll(r.Body) + w.Header().Set("Foo-Header", "foo") + http.Error(w, "foo", http.StatusBadRequest) +} + +func BenchmarkHooks_OnDeactivate(b *testing.B) { + var hooks benchmarkHooks if err := testHooksRPC(&hooks, func(remote *RemoteHooks) { b.ResetTimer() @@ -88,7 +126,23 @@ func BenchmarkOnDeactivate(b *testing.B) { } } -func BenchmarkOnDeactivate_Unimplemented(b *testing.B) { +func BenchmarkHooks_ServeHTTP(b *testing.B) { + var hooks benchmarkHooks + + if err := testHooksRPC(&hooks, func(remote *RemoteHooks) { + b.ResetTimer() + for n := 0; n < b.N; n++ { + w := httptest.NewRecorder() + r, _ := http.NewRequest("POST", "/foo", strings.NewReader("12345678901234567890")) + remote.ServeHTTP(w, r) + } + b.StopTimer() + }); err != nil { + b.Fatal(err.Error()) + } +} + +func BenchmarkHooks_Unimplemented(b *testing.B) { var hooks testHooks if err := testHooksRPC(&hooks, func(remote *RemoteHooks) { diff --git a/plugin/rpcplugin/http.go b/plugin/rpcplugin/http.go new file mode 100644 index 000000000..cfcb7419d --- /dev/null +++ b/plugin/rpcplugin/http.go @@ -0,0 +1,88 @@ +package rpcplugin + +import ( + "io" + "net/http" + "net/rpc" +) + +type LocalHTTPResponseWriter struct { + w http.ResponseWriter +} + +func (w *LocalHTTPResponseWriter) Header(args struct{}, reply *http.Header) error { + *reply = w.w.Header() + return nil +} + +func (w *LocalHTTPResponseWriter) Write(args []byte, reply *struct{}) error { + _, err := w.w.Write(args) + return err +} + +func (w *LocalHTTPResponseWriter) WriteHeader(args int, reply *struct{}) error { + w.w.WriteHeader(args) + return nil +} + +func (w *LocalHTTPResponseWriter) SyncHeader(args http.Header, reply *struct{}) error { + dest := w.w.Header() + for k := range dest { + if _, ok := args[k]; !ok { + delete(dest, k) + } + } + for k, v := range args { + dest[k] = v + } + return nil +} + +func ServeHTTPResponseWriter(w http.ResponseWriter, conn io.ReadWriteCloser) { + server := rpc.NewServer() + server.Register(&LocalHTTPResponseWriter{ + w: w, + }) + server.ServeConn(conn) +} + +type RemoteHTTPResponseWriter struct { + client *rpc.Client + header http.Header +} + +var _ http.ResponseWriter = (*RemoteHTTPResponseWriter)(nil) + +func (w *RemoteHTTPResponseWriter) Header() http.Header { + if w.header == nil { + w.client.Call("LocalHTTPResponseWriter.Header", struct{}{}, &w.header) + } + return w.header +} + +func (w *RemoteHTTPResponseWriter) Write(b []byte) (int, error) { + if err := w.client.Call("LocalHTTPResponseWriter.SyncHeader", w.header, nil); err != nil { + return 0, err + } + if err := w.client.Call("LocalHTTPResponseWriter.Write", b, nil); err != nil { + return 0, err + } + return len(b), nil +} + +func (w *RemoteHTTPResponseWriter) WriteHeader(statusCode int) { + if err := w.client.Call("LocalHTTPResponseWriter.SyncHeader", w.header, nil); err != nil { + return + } + w.client.Call("LocalHTTPResponseWriter.WriteHeader", statusCode, nil) +} + +func (h *RemoteHTTPResponseWriter) Close() error { + return h.client.Close() +} + +func ConnectHTTPResponseWriter(conn io.ReadWriteCloser) *RemoteHTTPResponseWriter { + return &RemoteHTTPResponseWriter{ + client: rpc.NewClient(conn), + } +} diff --git a/plugin/rpcplugin/http_test.go b/plugin/rpcplugin/http_test.go new file mode 100644 index 000000000..afaaf7756 --- /dev/null +++ b/plugin/rpcplugin/http_test.go @@ -0,0 +1,61 @@ +package rpcplugin + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testHTTPResponseWriterRPC(w http.ResponseWriter, f func(w http.ResponseWriter)) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + c1 := NewMuxer(NewReadWriteCloser(r1, w2), false) + defer c1.Close() + + c2 := NewMuxer(NewReadWriteCloser(r2, w1), true) + defer c2.Close() + + id, server := c1.Serve() + go ServeHTTPResponseWriter(w, server) + + remote := ConnectHTTPResponseWriter(c2.Connect(id)) + defer remote.Close() + + f(remote) +} + +func TestHTTP(t *testing.T) { + w := httptest.NewRecorder() + + testHTTPResponseWriterRPC(w, func(w http.ResponseWriter) { + headers := w.Header() + headers.Set("Test-Header-A", "a") + headers.Set("Test-Header-B", "b") + w.Header().Set("Test-Header-C", "c") + w.WriteHeader(http.StatusPaymentRequired) + n, err := w.Write([]byte("this is ")) + assert.Equal(t, 8, n) + assert.NoError(t, err) + n, err = w.Write([]byte("a test")) + assert.Equal(t, 6, n) + assert.NoError(t, err) + }) + + r := w.Result() + defer r.Body.Close() + + assert.Equal(t, http.StatusPaymentRequired, r.StatusCode) + + body, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.EqualValues(t, "this is a test", body) + + assert.Equal(t, "a", r.Header.Get("Test-Header-A")) + assert.Equal(t, "b", r.Header.Get("Test-Header-B")) + assert.Equal(t, "c", r.Header.Get("Test-Header-C")) +} diff --git a/plugin/rpcplugin/io.go b/plugin/rpcplugin/io.go index f1b2f3c35..38229d868 100644 --- a/plugin/rpcplugin/io.go +++ b/plugin/rpcplugin/io.go @@ -1,7 +1,10 @@ package rpcplugin import ( + "bufio" + "encoding/binary" "io" + "os" ) type rwc struct { @@ -9,15 +12,56 @@ type rwc struct { io.WriteCloser } -func (rwc *rwc) Close() error { - rerr := rwc.ReadCloser.Close() +func (rwc *rwc) Close() (err error) { + if f, ok := rwc.ReadCloser.(*os.File); ok { + // https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion + err = os.NewFile(f.Fd(), "").Close() + } else { + err = rwc.ReadCloser.Close() + } werr := rwc.WriteCloser.Close() - if rerr != nil { - return rerr + if err == nil { + err = werr } - return werr + return } func NewReadWriteCloser(r io.ReadCloser, w io.WriteCloser) io.ReadWriteCloser { return &rwc{r, w} } + +type RemoteIOReader struct { + conn io.ReadWriteCloser +} + +func (r *RemoteIOReader) Read(b []byte) (int, error) { + var buf [10]byte + n := binary.PutVarint(buf[:], int64(len(b))) + if _, err := r.conn.Write(buf[:n]); err != nil { + return 0, err + } + return r.conn.Read(b) +} + +func (r *RemoteIOReader) Close() error { + return r.conn.Close() +} + +func ConnectIOReader(conn io.ReadWriteCloser) io.ReadCloser { + return &RemoteIOReader{conn} +} + +func ServeIOReader(r io.Reader, conn io.ReadWriteCloser) { + cr := bufio.NewReader(conn) + defer conn.Close() + buf := make([]byte, 32*1024) + for { + n, err := binary.ReadVarint(cr) + if err != nil { + break + } + if written, err := io.CopyBuffer(conn, io.LimitReader(r, n), buf); err != nil || written < n { + break + } + } +} diff --git a/plugin/rpcplugin/muxer.go b/plugin/rpcplugin/muxer.go index a2bfbf8b6..393a122c4 100644 --- a/plugin/rpcplugin/muxer.go +++ b/plugin/rpcplugin/muxer.go @@ -114,20 +114,22 @@ func (m *Muxer) write(p []byte, sid int64) (int, error) { if m.IsClosed() { return 0, fmt.Errorf("muxer closed") } - buf := make([]byte, 10) - n := binary.PutVarint(buf, sid) + var buf [10]byte + n := binary.PutVarint(buf[:], sid) if _, err := m.conn.Write(buf[:n]); err != nil { m.shutdown(err) return 0, err } - n = binary.PutVarint(buf, int64(len(p))) + n = binary.PutVarint(buf[:], int64(len(p))) if _, err := m.conn.Write(buf[:n]); err != nil { m.shutdown(err) return 0, err } - if _, err := m.conn.Write(p); err != nil { - m.shutdown(err) - return 0, err + if len(p) > 0 { + if _, err := m.conn.Write(p); err != nil { + m.shutdown(err) + return 0, err + } } return len(p), nil } @@ -180,7 +182,11 @@ func (m *Muxer) loop() error { } continue } - _, err = io.CopyN(&stream.readBuf, reader, len) + if len == 0 { + stream.remoteClosed = true + } else { + _, err = io.CopyN(&stream.readBuf, reader, len) + } stream.mutex.Unlock() if err != nil { return err @@ -207,13 +213,14 @@ func (m *Muxer) shutdown(err error) { } type muxerStream struct { - id int64 - muxer *Muxer - readBuf bytes.Buffer - mutex *sync.Mutex - readWake *sync.Cond - isClosed bool - closeErr error + id int64 + muxer *Muxer + readBuf bytes.Buffer + mutex *sync.Mutex + readWake *sync.Cond + isClosed bool + remoteClosed bool + closeErr error } func (s *muxerStream) Read(p []byte) (int, error) { @@ -225,8 +232,9 @@ func (s *muxerStream) Read(p []byte) (int, error) { } else if s.isClosed { return 0, io.EOF } else if s.readBuf.Len() > 0 { - n, err := s.readBuf.Read(p) - return n, err + return s.readBuf.Read(p) + } else if s.remoteClosed { + return 0, io.EOF } s.readWake.Wait() } @@ -245,6 +253,7 @@ func (s *muxerStream) Close() error { s.mutex.Lock() defer s.mutex.Unlock() if !s.isClosed { + s.muxer.write(nil, s.id) s.isClosed = true s.muxer.rm(s.id) } diff --git a/plugin/rpcplugin/muxer_test.go b/plugin/rpcplugin/muxer_test.go index 7bb63d4f8..795a4fb1d 100644 --- a/plugin/rpcplugin/muxer_test.go +++ b/plugin/rpcplugin/muxer_test.go @@ -129,6 +129,34 @@ func TestMuxer_StreamCloseDuringRead(t *testing.T) { assert.Equal(t, io.EOF, err) } +// Closing a stream during a read should unblock and return io.EOF since this is the way for the +// remote to gracefully close a connection. +func TestMuxer_RemoteStreamCloseDuringRead(t *testing.T) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + alice := NewMuxer(NewReadWriteCloser(r1, w2), false) + defer func() { assert.NoError(t, alice.Close()) }() + + bob := NewMuxer(NewReadWriteCloser(r2, w1), true) + defer func() { assert.NoError(t, bob.Close()) }() + + id, as := alice.Serve() + bs := bob.Connect(id) + + go func() { + as.Write([]byte("foo")) + as.Close() + }() + buf := make([]byte, 20) + n, err := bs.Read(buf) + assert.Equal(t, 3, n) + assert.Equal(t, "foo", string(buf[:n])) + n, err = bs.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) +} + // Closing a muxer during a write should unblock, but return an error. func TestMuxer_CloseDuringWrite(t *testing.T) { r1, w1 := io.Pipe() -- cgit v1.2.3-1-g7c22