diff options
author | Chris <ccbrown112@gmail.com> | 2017-08-09 15:49:07 -0500 |
---|---|---|
committer | Christopher Speller <crspeller@gmail.com> | 2017-08-09 13:49:07 -0700 |
commit | ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1 (patch) | |
tree | 5f29ad6b3ae1c26a05a827406e9fe8c0385d26d6 /utils | |
parent | 504582b824d07946c7fb43eb2a8f0aadb15a3677 (diff) | |
download | chat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.tar.gz chat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.tar.bz2 chat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.zip |
PLT-6358: Server HTTP client improvements (#6980)
* restrict untrusted, internal http connections by default
* command test fix
* more test fixes
* change setting from toggle to whitelist
* requested ui changes
* add isdefault diagnostic
* fix tests
Diffstat (limited to 'utils')
-rw-r--r-- | utils/httpclient.go | 156 | ||||
-rw-r--r-- | utils/httpclient_test.go | 86 |
2 files changed, 224 insertions, 18 deletions
diff --git a/utils/httpclient.go b/utils/httpclient.go index a2355d827..afa717637 100644 --- a/utils/httpclient.go +++ b/utils/httpclient.go @@ -4,9 +4,12 @@ package utils import ( + "context" "crypto/tls" + "errors" "net" "net/http" + "strings" "time" ) @@ -15,6 +18,11 @@ const ( requestTimeout = 30 * time.Second ) +var secureHttpClient *http.Client +var secureUntrustedHttpClient *http.Client +var insecureHttpClient *http.Client +var insecureUntrustedHttpClient *http.Client + // HttpClient returns a variation the default implementation of Client. // It uses a Transport with the same settings as the default Transport // but with the following modifications: @@ -24,27 +32,147 @@ const ( // "requestTimeout") // - skipping server certificate check if specified in "config.json" // via "ServiceSettings.EnableInsecureOutgoingConnections" -func HttpClient() *http.Client { - if Cfg.ServiceSettings.EnableInsecureOutgoingConnections != nil && *Cfg.ServiceSettings.EnableInsecureOutgoingConnections { +func HttpClient(trustURLs bool) *http.Client { + insecure := Cfg.ServiceSettings.EnableInsecureOutgoingConnections != nil && *Cfg.ServiceSettings.EnableInsecureOutgoingConnections + switch { + case insecure && trustURLs: return insecureHttpClient + case insecure: + return insecureUntrustedHttpClient + case trustURLs: + return secureHttpClient + default: + return secureUntrustedHttpClient } - return secureHttpClient } -var ( - secureHttpClient = createHttpClient(false) - insecureHttpClient = createHttpClient(true) -) +var reservedIPRanges []*net.IPNet + +func isReserved(ip net.IP) bool { + for _, ipRange := range reservedIPRanges { + if ipRange.Contains(ip) { + return true + } + } + return false +} + +func init() { + for _, cidr := range []string{ + // See https://tools.ietf.org/html/rfc6890 + "0.0.0.0/8", // This host on this network + "10.0.0.0/8", // Private-Use + "127.0.0.0/8", // Loopback + "169.254.0.0/16", // Link Local + "172.16.0.0/12", // Private-Use Networks + "192.168.0.0/16", // Private-Use Networks + "::/128", // Unspecified Address + "::1/128", // Loopback Address + "fc00::/7", // Unique-Local + "fe80::/10", // Linked-Scoped Unicast + } { + _, parsed, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + reservedIPRanges = append(reservedIPRanges, parsed) + } + + allowHost := func(host string) bool { + if Cfg.ServiceSettings.AllowedUntrustedInternalConnections == nil { + return false + } + for _, allowed := range strings.Fields(*Cfg.ServiceSettings.AllowedUntrustedInternalConnections) { + if host == allowed { + return true + } + } + return false + } + + allowIP := func(ip net.IP) bool { + if !isReserved(ip) { + return true + } + if Cfg.ServiceSettings.AllowedUntrustedInternalConnections == nil { + return false + } + for _, allowed := range strings.Fields(*Cfg.ServiceSettings.AllowedUntrustedInternalConnections) { + if _, ipRange, err := net.ParseCIDR(allowed); err == nil && ipRange.Contains(ip) { + return true + } + } + return false + } + + secureHttpClient = createHttpClient(false, nil, nil) + insecureHttpClient = createHttpClient(true, nil, nil) + + secureUntrustedHttpClient = createHttpClient(false, allowHost, allowIP) + insecureUntrustedHttpClient = createHttpClient(true, allowHost, allowIP) +} + +type DialContextFunction func(ctx context.Context, network, addr string) (net.Conn, error) + +var AddressForbidden error = errors.New("address forbidden") + +func dialContextFilter(dial DialContextFunction, allowHost func(host string) bool, allowIP func(ip net.IP) bool) DialContextFunction { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if allowHost != nil && allowHost(host) { + return dial(ctx, network, addr) + } + + ips, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + var firstErr error + for _, ip := range ips { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if allowIP == nil || !allowIP(ip) { + continue + } + + conn, err := dial(ctx, network, net.JoinHostPort(ip.String(), port)) + if err == nil { + return conn, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + return nil, AddressForbidden + } + return nil, firstErr + } +} + +func createHttpClient(enableInsecureConnections bool, allowHost func(host string) bool, allowIP func(ip net.IP) bool) *http.Client { + dialContext := (&net.Dialer{ + Timeout: connectTimeout, + KeepAlive: 30 * time.Second, + }).DialContext + + if allowHost != nil || allowIP != nil { + dialContext = dialContextFilter(dialContext, allowHost, allowIP) + } -func createHttpClient(enableInsecureConnections bool) *http.Client { client := &http.Client{ Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: connectTimeout, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, + Proxy: http.ProxyFromEnvironment, + DialContext: dialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: connectTimeout, diff --git a/utils/httpclient_test.go b/utils/httpclient_test.go index 17353a4e7..1878b58b4 100644 --- a/utils/httpclient_test.go +++ b/utils/httpclient_test.go @@ -4,21 +4,63 @@ package utils import ( + "context" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" - "os" + "net/url" "testing" ) +func TestHttpClient(t *testing.T) { + for _, allowInternal := range []bool{true, false} { + c := HttpClient(allowInternal) + for _, tc := range []struct { + URL string + IsInternal bool + }{ + { + URL: "https://google.com", + IsInternal: false, + }, + { + URL: "https://127.0.0.1", + IsInternal: true, + }, + } { + _, err := c.Get(tc.URL) + if !tc.IsInternal { + if err != nil { + t.Fatal("google is down?") + } + } else { + allowed := !tc.IsInternal || allowInternal + success := err == nil + switch e := err.(type) { + case *net.OpError: + success = e.Err != AddressForbidden + case *url.Error: + success = e.Err != AddressForbidden + } + if success != allowed { + t.Fatalf("failed for %v. allowed: %v, success %v", tc.URL, allowed, success) + } + } + } + } +} + func TestHttpClientWithProxy(t *testing.T) { proxy := createProxyServer() defer proxy.Close() - os.Setenv("HTTP_PROXY", proxy.URL) - client := HttpClient() - resp, err := client.Get("http://acme.com") + c := createHttpClient(true, nil, nil) + purl, _ := url.Parse(proxy.URL) + c.Transport.(*http.Transport).Proxy = http.ProxyURL(purl) + + resp, err := c.Get("http://acme.com") if err != nil { t.Fatal(err) } @@ -40,3 +82,39 @@ func createProxyServer() *httptest.Server { fmt.Fprint(w, "proxy") })) } + +func TestDialContextFilter(t *testing.T) { + for _, tc := range []struct { + Addr string + IsValid bool + }{ + { + Addr: "google.com:80", + IsValid: true, + }, + { + Addr: "8.8.8.8:53", + IsValid: true, + }, + { + Addr: "127.0.0.1:80", + }, + { + Addr: "10.0.0.1:80", + IsValid: true, + }, + } { + didDial := false + filter := dialContextFilter(func(ctx context.Context, network, addr string) (net.Conn, error) { + didDial = true + return nil, nil + }, func(host string) bool { return host == "10.0.0.1" }, func(ip net.IP) bool { return !isReserved(ip) }) + _, err := filter(context.Background(), "", tc.Addr) + switch { + case tc.IsValid == (err == AddressForbidden) || (err != nil && err != AddressForbidden): + t.Errorf("unexpected err for %v (%v)", tc.Addr, err) + case tc.IsValid != didDial: + t.Errorf("unexpected didDial for %v", tc.Addr) + } + } +} |