// Copyright 2013 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "bytes" "net" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" ) const ( ok = "ok\n" notAllowed = "Method not allowed\n" ) var okHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte(ok)) }) func newRequest(method, url string) *http.Request { req, err := http.NewRequest(method, url, nil) if err != nil { panic(err) } return req } func TestMethodHandler(t *testing.T) { tests := []struct { req *http.Request handler http.Handler code int allow string // Contents of the Allow header body string }{ // No handlers {newRequest("GET", "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed}, {newRequest("OPTIONS", "/foo"), MethodHandler{}, http.StatusOK, "", ""}, // A single handler {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler}, http.StatusOK, "", ok}, {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler}, http.StatusMethodNotAllowed, "GET", notAllowed}, // Multiple handlers {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, {newRequest("DELETE", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed}, {newRequest("OPTIONS", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "GET, POST", ""}, // Override OPTIONS {newRequest("OPTIONS", "/foo"), MethodHandler{"OPTIONS": okHandler}, http.StatusOK, "", ok}, } for i, test := range tests { rec := httptest.NewRecorder() test.handler.ServeHTTP(rec, test.req) if rec.Code != test.code { t.Fatalf("%d: wrong code, got %d want %d", i, rec.Code, test.code) } if allow := rec.HeaderMap.Get("Allow"); allow != test.allow { t.Fatalf("%d: wrong Allow, got %s want %s", i, allow, test.allow) } if body := rec.Body.String(); body != test.body { t.Fatalf("%d: wrong body, got %q want %q", i, body, test.body) } } } func TestMakeLogger(t *testing.T) { rec := httptest.NewRecorder() logger := makeLogger(rec) // initial status if logger.Status() != http.StatusOK { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusOK) } // WriteHeader logger.WriteHeader(http.StatusInternalServerError) if logger.Status() != http.StatusInternalServerError { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusInternalServerError) } // Write logger.Write([]byte(ok)) if logger.Size() != len(ok) { t.Fatalf("wrong size, got %d want %d", logger.Size(), len(ok)) } // Header logger.Header().Set("key", "value") if val := logger.Header().Get("key"); val != "value" { t.Fatalf("wrong header, got %s want %s", val, "value") } } func TestWriteLog(t *testing.T) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) // A typical request with an OK response req := newRequest("GET", "http://example.com") req.RemoteAddr = "192.168.100.5" buf := new(bytes.Buffer) writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) log := buf.String() expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // CONNECT request over http/2.0 req = &http.Request{ Method: "CONNECT", Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, URL: &url.URL{Host: "www.example.com:443"}, Host: "www.example.com:443", RemoteAddr: "192.168.100.5", } buf = new(bytes.Buffer) writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) log = buf.String() expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // Request with an unauthorized user req = newRequest("GET", "http://example.com") req.RemoteAddr = "192.168.100.5" req.URL.User = url.User("kamil") buf.Reset() writeLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) log = buf.String() expected = "192.168.100.5 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // Request with url encoded parameters req = newRequest("GET", "http://example.com/test?abc=hello%20world&a=b%3F") req.RemoteAddr = "192.168.100.5" buf.Reset() writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) log = buf.String() expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET /test?abc=hello%20world&a=b%3F HTTP/1.1\" 200 100\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func TestWriteCombinedLog(t *testing.T) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) // A typical request with an OK response req := newRequest("GET", "http://example.com") req.RemoteAddr = "192.168.100.5" req.Header.Set("Referer", "http://example.com") req.Header.Set( "User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.33 "+ "(KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33", ) buf := new(bytes.Buffer) writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) log := buf.String() expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // CONNECT request over http/2.0 req1 := &http.Request{ Method: "CONNECT", Host: "www.example.com:443", Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, RemoteAddr: "192.168.100.5", Header: http.Header{}, URL: &url.URL{Host: "www.example.com:443"}, } req1.Header.Set("Referer", "http://example.com") req1.Header.Set( "User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.33 "+ "(KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33", ) buf = new(bytes.Buffer) writeCombinedLog(buf, req1, *req1.URL, ts, http.StatusOK, 100) log = buf.String() expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // Request with an unauthorized user req.URL.User = url.User("kamil") buf.Reset() writeCombinedLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) log = buf.String() expected = "192.168.100.5 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // Test with remote ipv6 address req.RemoteAddr = "::1" buf.Reset() writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) log = buf.String() expected = "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } // Test remote ipv6 addr, with port req.RemoteAddr = net.JoinHostPort("::1", "65000") buf.Reset() writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) log = buf.String() expected = "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func TestLogPathRewrites(t *testing.T) { var buf bytes.Buffer handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Path = "/" // simulate http.StripPrefix and friends w.WriteHeader(200) }) logger := LoggingHandler(&buf, handler) logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) if !strings.Contains(buf.String(), "GET /subdir/asdf HTTP") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "GET /subdir/asdf HTTP") } } func BenchmarkWriteLog(b *testing.B) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { b.Fatalf(err.Error()) } ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) req := newRequest("GET", "http://example.com") req.RemoteAddr = "192.168.100.5" b.ResetTimer() buf := &bytes.Buffer{} for i := 0; i < b.N; i++ { buf.Reset() writeLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) } } func TestContentTypeHandler(t *testing.T) { tests := []struct { Method string AllowContentTypes []string ContentType string Code int }{ {"POST", []string{"application/json"}, "application/json", http.StatusOK}, {"POST", []string{"application/json", "application/xml"}, "application/json", http.StatusOK}, {"POST", []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK}, {"POST", []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType}, {"POST", []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType}, {"GET", []string{"application/json"}, "", http.StatusOK}, {"GET", []string{}, "", http.StatusOK}, } for _, test := range tests { r, err := http.NewRequest(test.Method, "/", nil) if err != nil { t.Error(err) continue } h := ContentTypeHandler(okHandler, test.AllowContentTypes...) r.Header.Set("Content-Type", test.ContentType) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != test.Code { t.Errorf("expected %d, got %d", test.Code, w.Code) } } } func TestHTTPMethodOverride(t *testing.T) { var tests = []struct { Method string OverrideMethod string ExpectedMethod string }{ {"POST", "PUT", "PUT"}, {"POST", "PATCH", "PATCH"}, {"POST", "DELETE", "DELETE"}, {"PUT", "DELETE", "PUT"}, {"GET", "GET", "GET"}, {"HEAD", "HEAD", "HEAD"}, {"GET", "PUT", "GET"}, {"HEAD", "DELETE", "HEAD"}, } for _, test := range tests { h := HTTPMethodOverrideHandler(okHandler) reqs := make([]*http.Request, 0, 2) rHeader, err := http.NewRequest(test.Method, "/", nil) if err != nil { t.Error(err) } rHeader.Header.Set(HTTPMethodOverrideHeader, test.OverrideMethod) reqs = append(reqs, rHeader) f := url.Values{HTTPMethodOverrideFormKey: []string{test.OverrideMethod}} rForm, err := http.NewRequest(test.Method, "/", strings.NewReader(f.Encode())) if err != nil { t.Error(err) } rForm.Header.Set("Content-Type", "application/x-www-form-urlencoded") reqs = append(reqs, rForm) for _, r := range reqs { w := httptest.NewRecorder() h.ServeHTTP(w, r) if r.Method != test.ExpectedMethod { t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method) } } } }