summaryrefslogtreecommitdiffstats
path: root/src/lib/Bcfg2/Server/SSLServer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/Bcfg2/Server/SSLServer.py')
-rw-r--r--src/lib/Bcfg2/Server/SSLServer.py453
1 files changed, 453 insertions, 0 deletions
diff --git a/src/lib/Bcfg2/Server/SSLServer.py b/src/lib/Bcfg2/Server/SSLServer.py
new file mode 100644
index 000000000..28450aa1a
--- /dev/null
+++ b/src/lib/Bcfg2/Server/SSLServer.py
@@ -0,0 +1,453 @@
+""" Bcfg2 SSL server used by the builtin server core
+(:mod:`Bcfg2.Server.BuiltinCore`). This needs to be documented
+better. """
+
+import os
+import sys
+import socket
+import select
+import signal
+import logging
+import ssl
+import threading
+import time
+from Bcfg2.Compat import xmlrpclib, SimpleXMLRPCServer, SocketServer, \
+ b64decode
+
+
+class XMLRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
+ """ An XML-RPC dispatcher. """
+
+ def __init__(self, allow_none, encoding):
+ try:
+ SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self,
+ allow_none,
+ encoding)
+ except:
+ # Python 2.4?
+ SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self)
+
+ self.logger = logging.getLogger(self.__class__.__name__)
+ self.allow_none = allow_none
+ self.encoding = encoding
+
+ def _marshaled_dispatch(self, address, data):
+ params, method = xmlrpclib.loads(data)
+ try:
+ if '.' not in method:
+ params = (address, ) + params
+ response = self.instance._dispatch(method, params, self.funcs)
+ # py3k compatibility
+ if type(response) not in [bool, str, list, dict]:
+ response = (response.decode('utf-8'), )
+ else:
+ response = (response, )
+ raw_response = xmlrpclib.dumps(response, methodresponse=1,
+ allow_none=self.allow_none,
+ encoding=self.encoding)
+ except xmlrpclib.Fault:
+ fault = sys.exc_info()[1]
+ raw_response = xmlrpclib.dumps(fault,
+ allow_none=self.allow_none,
+ encoding=self.encoding)
+ except:
+ self.logger.error("Unexpected handler error", exc_info=1)
+ # report exception back to server
+ raw_response = xmlrpclib.dumps(
+ xmlrpclib.Fault(1, "%s:%s" % (sys.exc_type, sys.exc_value)),
+ allow_none=self.allow_none, encoding=self.encoding)
+ return raw_response
+
+
+class SSLServer(SocketServer.TCPServer, object):
+ """ TCP server supporting SSL encryption. """
+ allow_reuse_address = True
+
+ def __init__(self, listen_all, server_address, RequestHandlerClass,
+ keyfile=None, certfile=None, reqCert=False, ca=None,
+ timeout=None, protocol='xmlrpc/ssl'):
+ """
+ :param listen_all: Listen on all interfaces
+ :type listen_all: bool
+ :param server_address: Address to bind to the server
+ :param RequestHandlerClass: Request handler used by TCP server
+ :param keyfile: Full path to SSL encryption key file
+ :type keyfile: string
+ :param certfile: Full path to SSL certificate file
+ :type certfile: string
+ :param reqCert: Require client to present certificate
+ :type reqCert: bool
+ :param ca: Full path to SSL CA that signed the key and cert
+ :type ca: string
+ :param timeout: Timeout for non-blocking request handling
+ :param protocol: The protocol to serve. Supported values are
+ ``xmlrpc/ssl`` and ``xmlrpc/tlsv1``.
+ :type protocol: string
+ """
+ # check whether or not we should listen on all interfaces
+ if listen_all:
+ listen_address = ('', server_address[1])
+ else:
+ listen_address = (server_address[0], server_address[1])
+
+ # check for IPv6 address
+ if ':' in server_address[0]:
+ self.address_family = socket.AF_INET6
+
+ self.logger = logging.getLogger(self.__class__.__name__)
+
+ try:
+ SocketServer.TCPServer.__init__(self, listen_address,
+ RequestHandlerClass)
+ except socket.gaierror:
+ e = sys.exc_info()[1]
+ self.logger.error("Failed to bind to socket: %s" % e)
+ raise
+ except socket.error:
+ self.logger.error("Failed to bind to socket")
+ raise
+
+ self.timeout = timeout
+ self.socket.settimeout(timeout)
+ self.keyfile = keyfile
+ if (keyfile is not None and
+ (keyfile == False or
+ not os.path.exists(keyfile) or
+ not os.access(keyfile, os.R_OK))):
+ msg = "Keyfile %s does not exist or is not readable" % keyfile
+ self.logger.error(msg)
+ raise Exception(msg)
+ self.certfile = certfile
+ if (certfile is not None and
+ (certfile == False or
+ not os.path.exists(certfile) or
+ not os.access(certfile, os.R_OK))):
+ msg = "Certfile %s does not exist or is not readable" % certfile
+ self.logger.error(msg)
+ raise Exception(msg)
+ self.ca = ca
+ if (ca is not None and
+ (ca == False or
+ not os.path.exists(ca) or
+ not os.access(ca, os.R_OK))):
+ msg = "CA %s does not exist or is not readable" % ca
+ self.logger.error(msg)
+ raise Exception(msg)
+ self.reqCert = reqCert
+ if ca and certfile:
+ self.mode = ssl.CERT_OPTIONAL
+ else:
+ self.mode = ssl.CERT_NONE
+ if protocol == 'xmlrpc/ssl':
+ self.ssl_protocol = ssl.PROTOCOL_SSLv23
+ elif protocol == 'xmlrpc/tlsv1':
+ self.ssl_protocol = ssl.PROTOCOL_TLSv1
+ else:
+ self.logger.error("Unknown protocol %s" % (protocol))
+ raise Exception("unknown protocol %s" % protocol)
+
+ def get_request(self):
+ (sock, sockinfo) = self.socket.accept()
+ sock.settimeout(self.timeout) # pylint: disable=E1101
+ sslsock = ssl.wrap_socket(sock,
+ server_side=True,
+ certfile=self.certfile,
+ keyfile=self.keyfile,
+ cert_reqs=self.mode,
+ ca_certs=self.ca,
+ ssl_version=self.ssl_protocol)
+ return sslsock, sockinfo
+
+ def close_request(self, request):
+ try:
+ request.unwrap()
+ except:
+ pass
+ try:
+ request.close()
+ except:
+ pass
+
+ def _get_url(self):
+ port = self.socket.getsockname()[1]
+ hostname = socket.gethostname()
+ protocol = "https"
+ return "%s://%s:%i" % (protocol, hostname, port)
+ url = property(_get_url)
+
+
+class XMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
+ """ XML-RPC request handler.
+
+ Adds support for HTTP authentication.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.logger = logging.getLogger(self.__class__.__name__)
+ SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.__init__(self, *args,
+ **kwargs)
+
+ def authenticate(self):
+ try:
+ header = self.headers['Authorization']
+ except KeyError:
+ self.logger.error("No authentication data presented")
+ return False
+ auth_content = b64decode(header.split()[1])
+ try:
+ # py3k compatibility
+ try:
+ username, password = auth_content.split(":")
+ except TypeError:
+ username, pw = auth_content.split(bytes(":", encoding='utf-8'))
+ password = pw.decode('utf-8')
+ except ValueError:
+ username = auth_content
+ password = ""
+ cert = self.request.getpeercert()
+ client_address = self.request.getpeername()
+ return (self.server.instance.authenticate(cert, username,
+ password, client_address) and
+ self.server.instance.check_acls(client_address[0]))
+
+ def parse_request(self):
+ """Extends parse_request.
+
+ Optionally check HTTP authentication when parsing.
+ """
+ if not SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.parse_request(self):
+ return False
+ try:
+ if not self.authenticate():
+ self.logger.error("Authentication Failure")
+ self.send_error(401, self.responses[401][0])
+ return False
+ except: # pylint: disable=W0702
+ self.logger.error("Unexpected Authentication Failure", exc_info=1)
+ self.send_error(401, self.responses[401][0])
+ return False
+ return True
+
+ ### need to override do_POST here
+ def do_POST(self):
+ try:
+ max_chunk_size = 10 * 1024 * 1024
+ size_remaining = int(self.headers["content-length"])
+ L = []
+ while size_remaining:
+ try:
+ select.select([self.rfile.fileno()], [], [], 3)
+ except select.error:
+ print("got select timeout")
+ raise
+ chunk_size = min(size_remaining, max_chunk_size)
+ L.append(self.rfile.read(chunk_size).decode('utf-8'))
+ size_remaining -= len(L[-1])
+ data = ''.join(L)
+ response = self.server._marshaled_dispatch(self.client_address,
+ data)
+ if sys.hexversion >= 0x03000000:
+ response = response.encode('utf-8')
+ except: # pylint: disable=W0702
+ try:
+ self.send_response(500)
+ self.end_headers()
+ except:
+ (etype, msg) = sys.exc_info()[:2]
+ self.logger.error("Error sending 500 response (%s): %s" %
+ (etype.__name__, msg))
+ raise
+ else:
+ # got a valid XML RPC response
+ # first, check ACLs
+ client_address = self.request.getpeername()
+ method = xmlrpclib.loads(data)[1]
+ if not self.server.instance.check_acls(client_address, method):
+ self.send_error(401, self.responses[401][0])
+ self.end_headers()
+ try:
+ self.send_response(200)
+ self.send_header("Content-type", "text/xml")
+ self.send_header("Content-length", str(len(response)))
+ self.end_headers()
+ failcount = 0
+ while True:
+ try:
+ # If we hit SSL3_WRITE_PENDING here try to resend.
+ self.wfile.write(response)
+ break
+ except ssl.SSLError:
+ e = sys.exc_info()[1]
+ if str(e).find("SSL3_WRITE_PENDING") < 0:
+ raise
+ self.logger.error("SSL3_WRITE_PENDING")
+ failcount += 1
+ if failcount < 5:
+ continue
+ raise
+ except socket.error:
+ err = sys.exc_info()[1]
+ if err[0] == 32:
+ self.logger.warning("Connection dropped from %s" %
+ self.client_address[0])
+ elif err[0] == 104:
+ self.logger.warning("Connection reset by peer: %s" %
+ self.client_address[0])
+ else:
+ self.logger.warning("Socket error sending response to %s: "
+ "%s" % (self.client_address[0], err))
+ except ssl.SSLError:
+ err = sys.exc_info()[1]
+ self.logger.warning("SSLError handling client %s: %s" %
+ (self.client_address[0], err))
+ except:
+ etype, err = sys.exc_info()[:2]
+ self.logger.error("Unknown error sending response to %s: "
+ "%s (%s)" %
+ (self.client_address[0], err,
+ etype.__name__))
+
+ def finish(self):
+ # shut down the connection
+ if not self.wfile.closed:
+ try:
+ self.wfile.flush()
+ self.wfile.close()
+ except socket.error:
+ err = sys.exc_info()[1]
+ self.logger.warning("Error closing connection: %s" % err)
+ self.rfile.close()
+
+
+class XMLRPCServer(SocketServer.ThreadingMixIn, SSLServer,
+ XMLRPCDispatcher, object):
+ """ Component XMLRPCServer. """
+
+ def __init__(self, listen_all, server_address, RequestHandlerClass=None,
+ keyfile=None, certfile=None, ca=None, protocol='xmlrpc/ssl',
+ timeout=10, logRequests=False,
+ register=True, allow_none=True, encoding=None):
+ """
+ :param listen_all: Listen on all interfaces
+ :type listen_all: bool
+ :param server_address: Address to bind to the server
+ :param RequestHandlerClass: request handler used by TCP server
+ :param keyfile: Full path to SSL encryption key file
+ :type keyfile: string
+ :param certfile: Full path to SSL certificate file
+ :type certfile: string
+ :param ca: Full path to SSL CA that signed the key and cert
+ :type ca: string
+ :param logRequests: Log all requests
+ :type logRequests: bool
+ :param register: Presence should be reported to service-location
+ :type register: bool
+ :param allow_none: Allow None values in XML-RPC
+ :type allow_non: bool
+ :param encoding: Encoding to use for XML-RPC
+ """
+
+ XMLRPCDispatcher.__init__(self, allow_none, encoding)
+
+ if not RequestHandlerClass:
+ # pylint: disable=E0102
+ class RequestHandlerClass(XMLRPCRequestHandler):
+ """A subclassed request handler to prevent
+ class-attribute conflicts."""
+ # pylint: enable=E0102
+
+ SSLServer.__init__(self,
+ listen_all,
+ server_address,
+ RequestHandlerClass,
+ ca=ca,
+ timeout=timeout,
+ keyfile=keyfile,
+ certfile=certfile,
+ protocol=protocol)
+ self.logRequests = logRequests
+ self.serve = False
+ self.register = register
+ self.register_introspection_functions()
+ self.register_function(self.ping)
+ self.logger.info("service available at %s" % self.url)
+ self.timeout = timeout
+
+ def _tasks_thread(self):
+ try:
+ while self.serve:
+ try:
+ if self.instance and hasattr(self.instance, 'do_tasks'):
+ self.instance.do_tasks()
+ except:
+ self.logger.error("Unexpected task failure", exc_info=1)
+ time.sleep(self.timeout)
+ except:
+ self.logger.error("tasks_thread failed", exc_info=1)
+
+ def server_close(self):
+ SSLServer.server_close(self)
+ self.logger.info("server_close()")
+
+ def _get_require_auth(self):
+ return getattr(self.RequestHandlerClass, "require_auth", False)
+
+ def _set_require_auth(self, value):
+ self.RequestHandlerClass.require_auth = value
+ require_auth = property(_get_require_auth, _set_require_auth)
+
+ def _get_credentials(self):
+ try:
+ return self.RequestHandlerClass.credentials
+ except AttributeError:
+ return dict()
+
+ def _set_credentials(self, value):
+ self.RequestHandlerClass.credentials = value
+ credentials = property(_get_credentials, _set_credentials)
+
+ def register_instance(self, instance, *args, **kwargs):
+ XMLRPCDispatcher.register_instance(self, instance, *args, **kwargs)
+ try:
+ name = instance.name
+ except AttributeError:
+ name = "unknown"
+ if hasattr(instance, '_get_rmi'):
+ for fname, func in instance._get_rmi().items():
+ self.register_function(func, name=fname)
+ self.logger.info("serving %s at %s" % (name, self.url))
+
+ def serve_forever(self):
+ """Serve single requests until (self.serve == False)."""
+ self.serve = True
+ self.task_thread = threading.Thread(target=self._tasks_thread)
+ self.task_thread.start()
+ self.logger.info("serve_forever() [start]")
+ signal.signal(signal.SIGINT, self._handle_shutdown_signal)
+ signal.signal(signal.SIGTERM, self._handle_shutdown_signal)
+
+ try:
+ while self.serve:
+ try:
+ self.handle_request()
+ except socket.timeout:
+ pass
+ except select.error:
+ pass
+ except:
+ self.logger.error("Got unexpected error in handle_request",
+ exc_info=1)
+ finally:
+ self.logger.info("serve_forever() [stop]")
+
+ def shutdown(self):
+ """Signal that automatic service should stop."""
+ self.serve = False
+
+ def _handle_shutdown_signal(self, *_):
+ self.shutdown()
+
+ def ping(self, *args):
+ """Echo response."""
+ self.logger.info("ping(%s)" % (", ".join([repr(arg) for arg in args])))
+ return args