summaryrefslogtreecommitdiffstats
path: root/utils
diff options
context:
space:
mode:
authorChris <ccbrown112@gmail.com>2017-08-09 15:49:07 -0500
committerChristopher Speller <crspeller@gmail.com>2017-08-09 13:49:07 -0700
commitffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1 (patch)
tree5f29ad6b3ae1c26a05a827406e9fe8c0385d26d6 /utils
parent504582b824d07946c7fb43eb2a8f0aadb15a3677 (diff)
downloadchat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.tar.gz
chat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.tar.bz2
chat-ffbf8e51fe0b80b39fa76535f96c9179b2fcc0a1.zip
PLT-6358: Server HTTP client improvements (#6980)
* restrict untrusted, internal http connections by default * command test fix * more test fixes * change setting from toggle to whitelist * requested ui changes * add isdefault diagnostic * fix tests
Diffstat (limited to 'utils')
-rw-r--r--utils/httpclient.go156
-rw-r--r--utils/httpclient_test.go86
2 files changed, 224 insertions, 18 deletions
diff --git a/utils/httpclient.go b/utils/httpclient.go
index a2355d827..afa717637 100644
--- a/utils/httpclient.go
+++ b/utils/httpclient.go
@@ -4,9 +4,12 @@
package utils
import (
+ "context"
"crypto/tls"
+ "errors"
"net"
"net/http"
+ "strings"
"time"
)
@@ -15,6 +18,11 @@ const (
requestTimeout = 30 * time.Second
)
+var secureHttpClient *http.Client
+var secureUntrustedHttpClient *http.Client
+var insecureHttpClient *http.Client
+var insecureUntrustedHttpClient *http.Client
+
// HttpClient returns a variation the default implementation of Client.
// It uses a Transport with the same settings as the default Transport
// but with the following modifications:
@@ -24,27 +32,147 @@ const (
// "requestTimeout")
// - skipping server certificate check if specified in "config.json"
// via "ServiceSettings.EnableInsecureOutgoingConnections"
-func HttpClient() *http.Client {
- if Cfg.ServiceSettings.EnableInsecureOutgoingConnections != nil && *Cfg.ServiceSettings.EnableInsecureOutgoingConnections {
+func HttpClient(trustURLs bool) *http.Client {
+ insecure := Cfg.ServiceSettings.EnableInsecureOutgoingConnections != nil && *Cfg.ServiceSettings.EnableInsecureOutgoingConnections
+ switch {
+ case insecure && trustURLs:
return insecureHttpClient
+ case insecure:
+ return insecureUntrustedHttpClient
+ case trustURLs:
+ return secureHttpClient
+ default:
+ return secureUntrustedHttpClient
}
- return secureHttpClient
}
-var (
- secureHttpClient = createHttpClient(false)
- insecureHttpClient = createHttpClient(true)
-)
+var reservedIPRanges []*net.IPNet
+
+func isReserved(ip net.IP) bool {
+ for _, ipRange := range reservedIPRanges {
+ if ipRange.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func init() {
+ for _, cidr := range []string{
+ // See https://tools.ietf.org/html/rfc6890
+ "0.0.0.0/8", // This host on this network
+ "10.0.0.0/8", // Private-Use
+ "127.0.0.0/8", // Loopback
+ "169.254.0.0/16", // Link Local
+ "172.16.0.0/12", // Private-Use Networks
+ "192.168.0.0/16", // Private-Use Networks
+ "::/128", // Unspecified Address
+ "::1/128", // Loopback Address
+ "fc00::/7", // Unique-Local
+ "fe80::/10", // Linked-Scoped Unicast
+ } {
+ _, parsed, err := net.ParseCIDR(cidr)
+ if err != nil {
+ panic(err)
+ }
+ reservedIPRanges = append(reservedIPRanges, parsed)
+ }
+
+ allowHost := func(host string) bool {
+ if Cfg.ServiceSettings.AllowedUntrustedInternalConnections == nil {
+ return false
+ }
+ for _, allowed := range strings.Fields(*Cfg.ServiceSettings.AllowedUntrustedInternalConnections) {
+ if host == allowed {
+ return true
+ }
+ }
+ return false
+ }
+
+ allowIP := func(ip net.IP) bool {
+ if !isReserved(ip) {
+ return true
+ }
+ if Cfg.ServiceSettings.AllowedUntrustedInternalConnections == nil {
+ return false
+ }
+ for _, allowed := range strings.Fields(*Cfg.ServiceSettings.AllowedUntrustedInternalConnections) {
+ if _, ipRange, err := net.ParseCIDR(allowed); err == nil && ipRange.Contains(ip) {
+ return true
+ }
+ }
+ return false
+ }
+
+ secureHttpClient = createHttpClient(false, nil, nil)
+ insecureHttpClient = createHttpClient(true, nil, nil)
+
+ secureUntrustedHttpClient = createHttpClient(false, allowHost, allowIP)
+ insecureUntrustedHttpClient = createHttpClient(true, allowHost, allowIP)
+}
+
+type DialContextFunction func(ctx context.Context, network, addr string) (net.Conn, error)
+
+var AddressForbidden error = errors.New("address forbidden")
+
+func dialContextFilter(dial DialContextFunction, allowHost func(host string) bool, allowIP func(ip net.IP) bool) DialContextFunction {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if allowHost != nil && allowHost(host) {
+ return dial(ctx, network, addr)
+ }
+
+ ips, err := net.LookupIP(host)
+ if err != nil {
+ return nil, err
+ }
+
+ var firstErr error
+ for _, ip := range ips {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ if allowIP == nil || !allowIP(ip) {
+ continue
+ }
+
+ conn, err := dial(ctx, network, net.JoinHostPort(ip.String(), port))
+ if err == nil {
+ return conn, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ if firstErr == nil {
+ return nil, AddressForbidden
+ }
+ return nil, firstErr
+ }
+}
+
+func createHttpClient(enableInsecureConnections bool, allowHost func(host string) bool, allowIP func(ip net.IP) bool) *http.Client {
+ dialContext := (&net.Dialer{
+ Timeout: connectTimeout,
+ KeepAlive: 30 * time.Second,
+ }).DialContext
+
+ if allowHost != nil || allowIP != nil {
+ dialContext = dialContextFilter(dialContext, allowHost, allowIP)
+ }
-func createHttpClient(enableInsecureConnections bool) *http.Client {
client := &http.Client{
Transport: &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- DialContext: (&net.Dialer{
- Timeout: connectTimeout,
- KeepAlive: 30 * time.Second,
- DualStack: true,
- }).DialContext,
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: dialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: connectTimeout,
diff --git a/utils/httpclient_test.go b/utils/httpclient_test.go
index 17353a4e7..1878b58b4 100644
--- a/utils/httpclient_test.go
+++ b/utils/httpclient_test.go
@@ -4,21 +4,63 @@
package utils
import (
+ "context"
"fmt"
"io/ioutil"
+ "net"
"net/http"
"net/http/httptest"
- "os"
+ "net/url"
"testing"
)
+func TestHttpClient(t *testing.T) {
+ for _, allowInternal := range []bool{true, false} {
+ c := HttpClient(allowInternal)
+ for _, tc := range []struct {
+ URL string
+ IsInternal bool
+ }{
+ {
+ URL: "https://google.com",
+ IsInternal: false,
+ },
+ {
+ URL: "https://127.0.0.1",
+ IsInternal: true,
+ },
+ } {
+ _, err := c.Get(tc.URL)
+ if !tc.IsInternal {
+ if err != nil {
+ t.Fatal("google is down?")
+ }
+ } else {
+ allowed := !tc.IsInternal || allowInternal
+ success := err == nil
+ switch e := err.(type) {
+ case *net.OpError:
+ success = e.Err != AddressForbidden
+ case *url.Error:
+ success = e.Err != AddressForbidden
+ }
+ if success != allowed {
+ t.Fatalf("failed for %v. allowed: %v, success %v", tc.URL, allowed, success)
+ }
+ }
+ }
+ }
+}
+
func TestHttpClientWithProxy(t *testing.T) {
proxy := createProxyServer()
defer proxy.Close()
- os.Setenv("HTTP_PROXY", proxy.URL)
- client := HttpClient()
- resp, err := client.Get("http://acme.com")
+ c := createHttpClient(true, nil, nil)
+ purl, _ := url.Parse(proxy.URL)
+ c.Transport.(*http.Transport).Proxy = http.ProxyURL(purl)
+
+ resp, err := c.Get("http://acme.com")
if err != nil {
t.Fatal(err)
}
@@ -40,3 +82,39 @@ func createProxyServer() *httptest.Server {
fmt.Fprint(w, "proxy")
}))
}
+
+func TestDialContextFilter(t *testing.T) {
+ for _, tc := range []struct {
+ Addr string
+ IsValid bool
+ }{
+ {
+ Addr: "google.com:80",
+ IsValid: true,
+ },
+ {
+ Addr: "8.8.8.8:53",
+ IsValid: true,
+ },
+ {
+ Addr: "127.0.0.1:80",
+ },
+ {
+ Addr: "10.0.0.1:80",
+ IsValid: true,
+ },
+ } {
+ didDial := false
+ filter := dialContextFilter(func(ctx context.Context, network, addr string) (net.Conn, error) {
+ didDial = true
+ return nil, nil
+ }, func(host string) bool { return host == "10.0.0.1" }, func(ip net.IP) bool { return !isReserved(ip) })
+ _, err := filter(context.Background(), "", tc.Addr)
+ switch {
+ case tc.IsValid == (err == AddressForbidden) || (err != nil && err != AddressForbidden):
+ t.Errorf("unexpected err for %v (%v)", tc.Addr, err)
+ case tc.IsValid != didDial:
+ t.Errorf("unexpected didDial for %v", tc.Addr)
+ }
+ }
+}