summaryrefslogtreecommitdiffstats
path: root/plugin
diff options
context:
space:
mode:
authorChris <ccbrown112@gmail.com>2017-08-28 11:27:18 -0500
committerGitHub <noreply@github.com>2017-08-28 11:27:18 -0500
commit6215c9159acb85033616d2937edf3d87ef7ca79b (patch)
tree54e53392804bda3a30aa9615fef0cd01ca189152 /plugin
parent510b1a18f5282981a70503c0cde474e121c9e651 (diff)
downloadchat-6215c9159acb85033616d2937edf3d87ef7ca79b.tar.gz
chat-6215c9159acb85033616d2937edf3d87ef7ca79b.tar.bz2
chat-6215c9159acb85033616d2937edf3d87ef7ca79b.zip
add plugin http handler (#7289)
Diffstat (limited to 'plugin')
-rw-r--r--plugin/hooks.go11
-rw-r--r--plugin/plugintest/hooks.go6
-rw-r--r--plugin/rpcplugin/api.go10
-rw-r--r--plugin/rpcplugin/hooks.go72
-rw-r--r--plugin/rpcplugin/hooks_test.go62
-rw-r--r--plugin/rpcplugin/http.go88
-rw-r--r--plugin/rpcplugin/http_test.go61
-rw-r--r--plugin/rpcplugin/io.go54
-rw-r--r--plugin/rpcplugin/muxer.go41
-rw-r--r--plugin/rpcplugin/muxer_test.go28
10 files changed, 403 insertions, 30 deletions
diff --git a/plugin/hooks.go b/plugin/hooks.go
index 28a762a1a..336e56ccb 100644
--- a/plugin/hooks.go
+++ b/plugin/hooks.go
@@ -1,5 +1,9 @@
package plugin
+import (
+ "net/http"
+)
+
type Hooks interface {
// OnActivate is invoked when the plugin is activated.
OnActivate(API) error
@@ -7,4 +11,11 @@ type Hooks interface {
// OnDeactivate is invoked when the plugin is deactivated. This is the plugin's last chance to
// use the API, and the plugin will be terminated shortly after this invocation.
OnDeactivate() error
+
+ // ServeHTTP allows the plugin to implement the http.Handler interface. Requests destined for
+ // the /plugins/{id} path will be routed to the plugin.
+ //
+ // The Mattermost-User-Id header will be present if (and only if) the request is by an
+ // authenticated user.
+ ServeHTTP(http.ResponseWriter, *http.Request)
}
diff --git a/plugin/plugintest/hooks.go b/plugin/plugintest/hooks.go
index 057c705c9..4cac515b4 100644
--- a/plugin/plugintest/hooks.go
+++ b/plugin/plugintest/hooks.go
@@ -1,6 +1,8 @@
package plugintest
import (
+ "net/http"
+
"github.com/stretchr/testify/mock"
"github.com/mattermost/platform/plugin"
@@ -19,3 +21,7 @@ func (m *Hooks) OnActivate(api plugin.API) error {
func (m *Hooks) OnDeactivate() error {
return m.Called().Error(0)
}
+
+func (m *Hooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ m.Called(w, r)
+}
diff --git a/plugin/rpcplugin/api.go b/plugin/rpcplugin/api.go
index a807d0837..84eb3baae 100644
--- a/plugin/rpcplugin/api.go
+++ b/plugin/rpcplugin/api.go
@@ -26,11 +26,6 @@ func (h *LocalAPI) LoadPluginConfiguration(args struct{}, reply *[]byte) error {
return nil
}
-type RemoteAPI struct {
- client *rpc.Client
- muxer *Muxer
-}
-
func ServeAPI(api plugin.API, conn io.ReadWriteCloser, muxer *Muxer) {
server := rpc.NewServer()
server.Register(&LocalAPI{
@@ -40,6 +35,11 @@ func ServeAPI(api plugin.API, conn io.ReadWriteCloser, muxer *Muxer) {
server.ServeConn(conn)
}
+type RemoteAPI struct {
+ client *rpc.Client
+ muxer *Muxer
+}
+
var _ plugin.API = (*RemoteAPI)(nil)
func (h *RemoteAPI) LoadPluginConfiguration(dest interface{}) error {
diff --git a/plugin/rpcplugin/hooks.go b/plugin/rpcplugin/hooks.go
index 5b97742aa..995f4ae1a 100644
--- a/plugin/rpcplugin/hooks.go
+++ b/plugin/rpcplugin/hooks.go
@@ -1,7 +1,10 @@
package rpcplugin
import (
+ "bytes"
"io"
+ "io/ioutil"
+ "net/http"
"net/rpc"
"reflect"
@@ -83,6 +86,33 @@ func (h *LocalHooks) OnDeactivate(args, reply *struct{}) (err error) {
return
}
+type ServeHTTPArgs struct {
+ ResponseWriterStream int64
+ Request *http.Request
+ RequestBodyStream int64
+}
+
+func (h *LocalHooks) ServeHTTP(args ServeHTTPArgs, reply *struct{}) error {
+ w := ConnectHTTPResponseWriter(h.muxer.Connect(args.ResponseWriterStream))
+ defer w.Close()
+
+ r := args.Request
+ if args.RequestBodyStream != 0 {
+ r.Body = ConnectIOReader(h.muxer.Connect(args.RequestBodyStream))
+ } else {
+ r.Body = ioutil.NopCloser(&bytes.Buffer{})
+ }
+ defer r.Body.Close()
+
+ if hook, ok := h.hooks.(http.Handler); ok {
+ hook.ServeHTTP(w, r)
+ } else {
+ http.NotFound(w, r)
+ }
+
+ return nil
+}
+
func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) {
server := rpc.NewServer()
server.Register(&LocalHooks{
@@ -95,6 +125,7 @@ func ServeHooks(hooks interface{}, conn io.ReadWriteCloser, muxer *Muxer) {
const (
remoteOnActivate = iota
remoteOnDeactivate
+ remoteServeHTTP
maxRemoteHookCount
)
@@ -133,6 +164,45 @@ func (h *RemoteHooks) OnDeactivate() error {
return h.client.Call("LocalHooks.OnDeactivate", struct{}{}, nil)
}
+func (h *RemoteHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if !h.implemented[remoteServeHTTP] {
+ http.NotFound(w, r)
+ return
+ }
+
+ responseWriterStream, stream := h.muxer.Serve()
+ defer stream.Close()
+ go ServeHTTPResponseWriter(w, stream)
+
+ requestBodyStream := int64(0)
+ if r.Body != nil {
+ rid, rstream := h.muxer.Serve()
+ defer rstream.Close()
+ go ServeIOReader(r.Body, rstream)
+ requestBodyStream = rid
+ }
+
+ forwardedRequest := &http.Request{
+ Method: r.Method,
+ URL: r.URL,
+ Proto: r.Proto,
+ ProtoMajor: r.ProtoMajor,
+ ProtoMinor: r.ProtoMinor,
+ Header: r.Header,
+ Host: r.Host,
+ RemoteAddr: r.RemoteAddr,
+ RequestURI: r.RequestURI,
+ }
+
+ if err := h.client.Call("LocalHooks.ServeHTTP", ServeHTTPArgs{
+ ResponseWriterStream: responseWriterStream,
+ Request: forwardedRequest,
+ RequestBodyStream: requestBodyStream,
+ }, nil); err != nil {
+ http.Error(w, "500 internal server error", http.StatusInternalServerError)
+ }
+}
+
func (h *RemoteHooks) Close() error {
if h.apiCloser != nil {
h.apiCloser.Close()
@@ -157,6 +227,8 @@ func ConnectHooks(conn io.ReadWriteCloser, muxer *Muxer) (*RemoteHooks, error) {
remote.implemented[remoteOnActivate] = true
case "OnDeactivate":
remote.implemented[remoteOnDeactivate] = true
+ case "ServeHTTP":
+ remote.implemented[remoteServeHTTP] = true
}
}
return remote, nil
diff --git a/plugin/rpcplugin/hooks_test.go b/plugin/rpcplugin/hooks_test.go
index 6cd7ff547..eb684956b 100644
--- a/plugin/rpcplugin/hooks_test.go
+++ b/plugin/rpcplugin/hooks_test.go
@@ -2,6 +2,10 @@ package rpcplugin
import (
"io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -45,6 +49,31 @@ func TestHooks(t *testing.T) {
hooks.On("OnDeactivate").Return(nil)
assert.NoError(t, remote.OnDeactivate())
+
+ hooks.On("ServeHTTP", mock.AnythingOfType("*rpcplugin.RemoteHTTPResponseWriter"), mock.AnythingOfType("*http.Request")).Run(func(args mock.Arguments) {
+ w := args.Get(0).(http.ResponseWriter)
+ r := args.Get(1).(*http.Request)
+ assert.Equal(t, "/foo", r.URL.Path)
+ assert.Equal(t, "POST", r.Method)
+ body, err := ioutil.ReadAll(r.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "asdf", string(body))
+ assert.Equal(t, "header", r.Header.Get("Test-Header"))
+ w.Write([]byte("bar"))
+ })
+
+ w := httptest.NewRecorder()
+ r, err := http.NewRequest("POST", "/foo", strings.NewReader("asdf"))
+ r.Header.Set("Test-Header", "header")
+ assert.NoError(t, err)
+ remote.ServeHTTP(w, r)
+
+ resp := w.Result()
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ body, err := ioutil.ReadAll(resp.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "bar", string(body))
}))
}
@@ -73,9 +102,18 @@ func TestHooks_PartiallyImplemented(t *testing.T) {
}))
}
-func BenchmarkOnDeactivate(b *testing.B) {
- var hooks plugintest.Hooks
- hooks.On("OnDeactivate").Return(nil)
+type benchmarkHooks struct{}
+
+func (*benchmarkHooks) OnDeactivate() error { return nil }
+
+func (*benchmarkHooks) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ ioutil.ReadAll(r.Body)
+ w.Header().Set("Foo-Header", "foo")
+ http.Error(w, "foo", http.StatusBadRequest)
+}
+
+func BenchmarkHooks_OnDeactivate(b *testing.B) {
+ var hooks benchmarkHooks
if err := testHooksRPC(&hooks, func(remote *RemoteHooks) {
b.ResetTimer()
@@ -88,7 +126,23 @@ func BenchmarkOnDeactivate(b *testing.B) {
}
}
-func BenchmarkOnDeactivate_Unimplemented(b *testing.B) {
+func BenchmarkHooks_ServeHTTP(b *testing.B) {
+ var hooks benchmarkHooks
+
+ if err := testHooksRPC(&hooks, func(remote *RemoteHooks) {
+ b.ResetTimer()
+ for n := 0; n < b.N; n++ {
+ w := httptest.NewRecorder()
+ r, _ := http.NewRequest("POST", "/foo", strings.NewReader("12345678901234567890"))
+ remote.ServeHTTP(w, r)
+ }
+ b.StopTimer()
+ }); err != nil {
+ b.Fatal(err.Error())
+ }
+}
+
+func BenchmarkHooks_Unimplemented(b *testing.B) {
var hooks testHooks
if err := testHooksRPC(&hooks, func(remote *RemoteHooks) {
diff --git a/plugin/rpcplugin/http.go b/plugin/rpcplugin/http.go
new file mode 100644
index 000000000..cfcb7419d
--- /dev/null
+++ b/plugin/rpcplugin/http.go
@@ -0,0 +1,88 @@
+package rpcplugin
+
+import (
+ "io"
+ "net/http"
+ "net/rpc"
+)
+
+type LocalHTTPResponseWriter struct {
+ w http.ResponseWriter
+}
+
+func (w *LocalHTTPResponseWriter) Header(args struct{}, reply *http.Header) error {
+ *reply = w.w.Header()
+ return nil
+}
+
+func (w *LocalHTTPResponseWriter) Write(args []byte, reply *struct{}) error {
+ _, err := w.w.Write(args)
+ return err
+}
+
+func (w *LocalHTTPResponseWriter) WriteHeader(args int, reply *struct{}) error {
+ w.w.WriteHeader(args)
+ return nil
+}
+
+func (w *LocalHTTPResponseWriter) SyncHeader(args http.Header, reply *struct{}) error {
+ dest := w.w.Header()
+ for k := range dest {
+ if _, ok := args[k]; !ok {
+ delete(dest, k)
+ }
+ }
+ for k, v := range args {
+ dest[k] = v
+ }
+ return nil
+}
+
+func ServeHTTPResponseWriter(w http.ResponseWriter, conn io.ReadWriteCloser) {
+ server := rpc.NewServer()
+ server.Register(&LocalHTTPResponseWriter{
+ w: w,
+ })
+ server.ServeConn(conn)
+}
+
+type RemoteHTTPResponseWriter struct {
+ client *rpc.Client
+ header http.Header
+}
+
+var _ http.ResponseWriter = (*RemoteHTTPResponseWriter)(nil)
+
+func (w *RemoteHTTPResponseWriter) Header() http.Header {
+ if w.header == nil {
+ w.client.Call("LocalHTTPResponseWriter.Header", struct{}{}, &w.header)
+ }
+ return w.header
+}
+
+func (w *RemoteHTTPResponseWriter) Write(b []byte) (int, error) {
+ if err := w.client.Call("LocalHTTPResponseWriter.SyncHeader", w.header, nil); err != nil {
+ return 0, err
+ }
+ if err := w.client.Call("LocalHTTPResponseWriter.Write", b, nil); err != nil {
+ return 0, err
+ }
+ return len(b), nil
+}
+
+func (w *RemoteHTTPResponseWriter) WriteHeader(statusCode int) {
+ if err := w.client.Call("LocalHTTPResponseWriter.SyncHeader", w.header, nil); err != nil {
+ return
+ }
+ w.client.Call("LocalHTTPResponseWriter.WriteHeader", statusCode, nil)
+}
+
+func (h *RemoteHTTPResponseWriter) Close() error {
+ return h.client.Close()
+}
+
+func ConnectHTTPResponseWriter(conn io.ReadWriteCloser) *RemoteHTTPResponseWriter {
+ return &RemoteHTTPResponseWriter{
+ client: rpc.NewClient(conn),
+ }
+}
diff --git a/plugin/rpcplugin/http_test.go b/plugin/rpcplugin/http_test.go
new file mode 100644
index 000000000..afaaf7756
--- /dev/null
+++ b/plugin/rpcplugin/http_test.go
@@ -0,0 +1,61 @@
+package rpcplugin
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func testHTTPResponseWriterRPC(w http.ResponseWriter, f func(w http.ResponseWriter)) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ c1 := NewMuxer(NewReadWriteCloser(r1, w2), false)
+ defer c1.Close()
+
+ c2 := NewMuxer(NewReadWriteCloser(r2, w1), true)
+ defer c2.Close()
+
+ id, server := c1.Serve()
+ go ServeHTTPResponseWriter(w, server)
+
+ remote := ConnectHTTPResponseWriter(c2.Connect(id))
+ defer remote.Close()
+
+ f(remote)
+}
+
+func TestHTTP(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ testHTTPResponseWriterRPC(w, func(w http.ResponseWriter) {
+ headers := w.Header()
+ headers.Set("Test-Header-A", "a")
+ headers.Set("Test-Header-B", "b")
+ w.Header().Set("Test-Header-C", "c")
+ w.WriteHeader(http.StatusPaymentRequired)
+ n, err := w.Write([]byte("this is "))
+ assert.Equal(t, 8, n)
+ assert.NoError(t, err)
+ n, err = w.Write([]byte("a test"))
+ assert.Equal(t, 6, n)
+ assert.NoError(t, err)
+ })
+
+ r := w.Result()
+ defer r.Body.Close()
+
+ assert.Equal(t, http.StatusPaymentRequired, r.StatusCode)
+
+ body, err := ioutil.ReadAll(r.Body)
+ assert.NoError(t, err)
+ assert.EqualValues(t, "this is a test", body)
+
+ assert.Equal(t, "a", r.Header.Get("Test-Header-A"))
+ assert.Equal(t, "b", r.Header.Get("Test-Header-B"))
+ assert.Equal(t, "c", r.Header.Get("Test-Header-C"))
+}
diff --git a/plugin/rpcplugin/io.go b/plugin/rpcplugin/io.go
index f1b2f3c35..38229d868 100644
--- a/plugin/rpcplugin/io.go
+++ b/plugin/rpcplugin/io.go
@@ -1,7 +1,10 @@
package rpcplugin
import (
+ "bufio"
+ "encoding/binary"
"io"
+ "os"
)
type rwc struct {
@@ -9,15 +12,56 @@ type rwc struct {
io.WriteCloser
}
-func (rwc *rwc) Close() error {
- rerr := rwc.ReadCloser.Close()
+func (rwc *rwc) Close() (err error) {
+ if f, ok := rwc.ReadCloser.(*os.File); ok {
+ // https://groups.google.com/d/topic/golang-nuts/i4w58KJ5-J8/discussion
+ err = os.NewFile(f.Fd(), "").Close()
+ } else {
+ err = rwc.ReadCloser.Close()
+ }
werr := rwc.WriteCloser.Close()
- if rerr != nil {
- return rerr
+ if err == nil {
+ err = werr
}
- return werr
+ return
}
func NewReadWriteCloser(r io.ReadCloser, w io.WriteCloser) io.ReadWriteCloser {
return &rwc{r, w}
}
+
+type RemoteIOReader struct {
+ conn io.ReadWriteCloser
+}
+
+func (r *RemoteIOReader) Read(b []byte) (int, error) {
+ var buf [10]byte
+ n := binary.PutVarint(buf[:], int64(len(b)))
+ if _, err := r.conn.Write(buf[:n]); err != nil {
+ return 0, err
+ }
+ return r.conn.Read(b)
+}
+
+func (r *RemoteIOReader) Close() error {
+ return r.conn.Close()
+}
+
+func ConnectIOReader(conn io.ReadWriteCloser) io.ReadCloser {
+ return &RemoteIOReader{conn}
+}
+
+func ServeIOReader(r io.Reader, conn io.ReadWriteCloser) {
+ cr := bufio.NewReader(conn)
+ defer conn.Close()
+ buf := make([]byte, 32*1024)
+ for {
+ n, err := binary.ReadVarint(cr)
+ if err != nil {
+ break
+ }
+ if written, err := io.CopyBuffer(conn, io.LimitReader(r, n), buf); err != nil || written < n {
+ break
+ }
+ }
+}
diff --git a/plugin/rpcplugin/muxer.go b/plugin/rpcplugin/muxer.go
index a2bfbf8b6..393a122c4 100644
--- a/plugin/rpcplugin/muxer.go
+++ b/plugin/rpcplugin/muxer.go
@@ -114,20 +114,22 @@ func (m *Muxer) write(p []byte, sid int64) (int, error) {
if m.IsClosed() {
return 0, fmt.Errorf("muxer closed")
}
- buf := make([]byte, 10)
- n := binary.PutVarint(buf, sid)
+ var buf [10]byte
+ n := binary.PutVarint(buf[:], sid)
if _, err := m.conn.Write(buf[:n]); err != nil {
m.shutdown(err)
return 0, err
}
- n = binary.PutVarint(buf, int64(len(p)))
+ n = binary.PutVarint(buf[:], int64(len(p)))
if _, err := m.conn.Write(buf[:n]); err != nil {
m.shutdown(err)
return 0, err
}
- if _, err := m.conn.Write(p); err != nil {
- m.shutdown(err)
- return 0, err
+ if len(p) > 0 {
+ if _, err := m.conn.Write(p); err != nil {
+ m.shutdown(err)
+ return 0, err
+ }
}
return len(p), nil
}
@@ -180,7 +182,11 @@ func (m *Muxer) loop() error {
}
continue
}
- _, err = io.CopyN(&stream.readBuf, reader, len)
+ if len == 0 {
+ stream.remoteClosed = true
+ } else {
+ _, err = io.CopyN(&stream.readBuf, reader, len)
+ }
stream.mutex.Unlock()
if err != nil {
return err
@@ -207,13 +213,14 @@ func (m *Muxer) shutdown(err error) {
}
type muxerStream struct {
- id int64
- muxer *Muxer
- readBuf bytes.Buffer
- mutex *sync.Mutex
- readWake *sync.Cond
- isClosed bool
- closeErr error
+ id int64
+ muxer *Muxer
+ readBuf bytes.Buffer
+ mutex *sync.Mutex
+ readWake *sync.Cond
+ isClosed bool
+ remoteClosed bool
+ closeErr error
}
func (s *muxerStream) Read(p []byte) (int, error) {
@@ -225,8 +232,9 @@ func (s *muxerStream) Read(p []byte) (int, error) {
} else if s.isClosed {
return 0, io.EOF
} else if s.readBuf.Len() > 0 {
- n, err := s.readBuf.Read(p)
- return n, err
+ return s.readBuf.Read(p)
+ } else if s.remoteClosed {
+ return 0, io.EOF
}
s.readWake.Wait()
}
@@ -245,6 +253,7 @@ func (s *muxerStream) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.isClosed {
+ s.muxer.write(nil, s.id)
s.isClosed = true
s.muxer.rm(s.id)
}
diff --git a/plugin/rpcplugin/muxer_test.go b/plugin/rpcplugin/muxer_test.go
index 7bb63d4f8..795a4fb1d 100644
--- a/plugin/rpcplugin/muxer_test.go
+++ b/plugin/rpcplugin/muxer_test.go
@@ -129,6 +129,34 @@ func TestMuxer_StreamCloseDuringRead(t *testing.T) {
assert.Equal(t, io.EOF, err)
}
+// Closing a stream during a read should unblock and return io.EOF since this is the way for the
+// remote to gracefully close a connection.
+func TestMuxer_RemoteStreamCloseDuringRead(t *testing.T) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ alice := NewMuxer(NewReadWriteCloser(r1, w2), false)
+ defer func() { assert.NoError(t, alice.Close()) }()
+
+ bob := NewMuxer(NewReadWriteCloser(r2, w1), true)
+ defer func() { assert.NoError(t, bob.Close()) }()
+
+ id, as := alice.Serve()
+ bs := bob.Connect(id)
+
+ go func() {
+ as.Write([]byte("foo"))
+ as.Close()
+ }()
+ buf := make([]byte, 20)
+ n, err := bs.Read(buf)
+ assert.Equal(t, 3, n)
+ assert.Equal(t, "foo", string(buf[:n]))
+ n, err = bs.Read(buf)
+ assert.Equal(t, 0, n)
+ assert.Equal(t, io.EOF, err)
+}
+
// Closing a muxer during a write should unblock, but return an error.
func TestMuxer_CloseDuringWrite(t *testing.T) {
r1, w1 := io.Pipe()