From dab1d03d81c538966d03fb9318a4588a9e803b44 Mon Sep 17 00:00:00 2001 From: Sol Jerome Date: Sat, 24 Mar 2012 11:20:07 -0500 Subject: Allow to run directly from a git checkout (#1037) Signed-off-by: Sol Jerome --- src/lib/Bcfg2/SSLServer.py | 461 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 461 insertions(+) create mode 100644 src/lib/Bcfg2/SSLServer.py (limited to 'src/lib/Bcfg2/SSLServer.py') diff --git a/src/lib/Bcfg2/SSLServer.py b/src/lib/Bcfg2/SSLServer.py new file mode 100644 index 000000000..418e259cc --- /dev/null +++ b/src/lib/Bcfg2/SSLServer.py @@ -0,0 +1,461 @@ +"""Bcfg2 SSL server.""" + +__all__ = [ + "SSLServer", "XMLRPCRequestHandler", "XMLRPCServer", +] + +import os +import sys +import socket +import base64 +import select +import signal +import logging +import ssl +import threading +import time +import types +# Compatibility imports +from Bcfg2.Bcfg2Py3k import xmlrpclib, SimpleXMLRPCServer, SocketServer + + +class ForkedChild(Exception): + pass + + +class XMLRPCDispatcher (SimpleXMLRPCServer.SimpleXMLRPCDispatcher): + logger = logging.getLogger("Cobalt.Server.XMLRPCDispatcher") + + def __init__(self, allow_none, encoding): + try: + SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, + allow_none, + encoding) + except: + # Python 2.4? + SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self) + + self.allow_none = allow_none + self.encoding = encoding + + def _marshaled_dispatch(self, address, data): + method_func = None + params, method = xmlrpclib.loads(data) + try: + if '.' not in method: + params = (address, ) + params + response = self.instance._dispatch(method, params, self.funcs) + # py3k compatibility + if type(response) not in [bool, str, list, dict, types.NoneType]: + response = (response.decode('utf-8'), ) + else: + response = (response, ) + raw_response = xmlrpclib.dumps(response, methodresponse=1, + allow_none=self.allow_none, + encoding=self.encoding) + except xmlrpclib.Fault: + fault = sys.exc_info()[1] + raw_response = xmlrpclib.dumps(fault, + allow_none=self.allow_none, + encoding=self.encoding) + except: + self.logger.error("Unexpected handler error", exc_info=1) + # report exception back to server + raw_response = xmlrpclib.dumps( + xmlrpclib.Fault(1, "%s:%s" % (sys.exc_type, sys.exc_value)), + allow_none=self.allow_none, encoding=self.encoding) + return raw_response + + +class SSLServer (SocketServer.TCPServer, object): + + """TCP server supporting SSL encryption. + + Methods: + handshake -- perform a SSL/TLS handshake + + Properties: + url -- A url pointing to this server. + + """ + + allow_reuse_address = True + logger = logging.getLogger("Cobalt.Server.TCPServer") + + def __init__(self, listen_all, server_address, RequestHandlerClass, + keyfile=None, certfile=None, reqCert=False, ca=None, + timeout=None, protocol='xmlrpc/ssl'): + + """Initialize the SSL-TCP server. + + Arguments: + server_address -- address to bind to the server + RequestHandlerClass -- class to handle requests + + Keyword arguments: + keyfile -- private encryption key filename (enables ssl encryption) + certfile -- certificate file (enables ssl encryption) + reqCert -- client must present certificate + timeout -- timeout for non-blocking request handling + + """ + + if listen_all: + listen_address = ('', server_address[1]) + else: + listen_address = (server_address[0], server_address[1]) + try: + SocketServer.TCPServer.__init__(self, listen_address, + RequestHandlerClass) + except socket.error: + self.logger.error("Failed to bind to socket") + raise + + self.timeout = timeout + self.socket.settimeout(timeout) + self.keyfile = keyfile + if keyfile != None: + if keyfile == False or not os.path.exists(keyfile): + self.logger.error("Keyfile %s does not exist" % keyfile) + raise Exception("keyfile doesn't exist") + self.certfile = certfile + if certfile != None: + if certfile == False or not os.path.exists(certfile): + self.logger.error("Certfile %s does not exist" % certfile) + raise Exception("certfile doesn't exist") + self.ca = ca + if ca != None: + if ca == False or not os.path.exists(ca): + self.logger.error("CA %s does not exist" % ca) + raise Exception("ca doesn't exist") + self.reqCert = reqCert + if ca and certfile: + self.mode = ssl.CERT_OPTIONAL + else: + self.mode = ssl.CERT_NONE + if protocol == 'xmlrpc/ssl': + self.ssl_protocol = ssl.PROTOCOL_SSLv23 + elif protocol == 'xmlrpc/tlsv1': + self.ssl_protocol = ssl.PROTOCOL_TLSv1 + else: + self.logger.error("Unknown protocol %s" % (protocol)) + raise Exception("unknown protocol %s" % protocol) + + def get_request(self): + (sock, sockinfo) = self.socket.accept() + sock.settimeout(self.timeout) + sslsock = ssl.wrap_socket(sock, + server_side=True, + certfile=self.certfile, + keyfile=self.keyfile, + cert_reqs=self.mode, + ca_certs=self.ca, + ssl_version=self.ssl_protocol) + return sslsock, sockinfo + + def close_request(self, request): + try: + request.unwrap() + except: + pass + try: + request.close() + except: + pass + + def _get_url(self): + port = self.socket.getsockname()[1] + hostname = socket.gethostname() + protocol = "https" + return "%s://%s:%i" % (protocol, hostname, port) + url = property(_get_url) + + +class XMLRPCRequestHandler (SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): + + """Component XML-RPC request handler. + + Adds support for HTTP authentication. + + Exceptions: + CouldNotAuthenticate -- client did not present acceptable authentication information + + Methods: + authenticate -- prompt a check of a client's provided username and password + handle_one_request -- handle a single rpc (optionally authenticating) + + """ + logger = logging.getLogger("Cobalt.Server.XMLRPCRequestHandler") + + def authenticate(self): + try: + header = self.headers['Authorization'] + except KeyError: + self.logger.error("No authentication data presented") + return False + auth_type, auth_content = header.split() + try: + # py3k compatibility + auth_content = base64.standard_b64decode(auth_content) + except TypeError: + auth_content = base64.standard_b64decode(bytes(auth_content.encode('ascii'))) + try: + # py3k compatibility + try: + username, password = auth_content.split(":") + except TypeError: + username, pw = auth_content.split(bytes(":", encoding='utf-8')) + password = pw.decode('utf-8') + except ValueError: + username = auth_content + password = "" + cert = self.request.getpeercert() + client_address = self.request.getpeername() + return self.server.instance.authenticate(cert, username, + password, client_address) + + def parse_request(self): + """Extends parse_request. + + Optionally check HTTP authentication when parsing. + + """ + if not SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.parse_request(self): + return False + try: + if not self.authenticate(): + self.logger.error("Authentication Failure") + self.send_error(401, self.responses[401][0]) + return False + except: + self.logger.error("Unexpected Authentication Failure", exc_info=1) + self.send_error(401, self.responses[401][0]) + return False + return True + + ### need to override do_POST here + def do_POST(self): + try: + max_chunk_size = 10 * 1024 * 1024 + size_remaining = int(self.headers["content-length"]) + L = [] + while size_remaining: + try: + select.select([self.rfile.fileno()], [], [], 3) + except select.error: + print("got select timeout") + raise + chunk_size = min(size_remaining, max_chunk_size) + L.append(self.rfile.read(chunk_size).decode('utf-8')) + size_remaining -= len(L[-1]) + data = ''.join(L) + response = self.server._marshaled_dispatch(self.client_address, + data) + if sys.hexversion >= 0x03000000: + response = response.encode('utf-8') + except: + try: + self.send_response(500) + self.end_headers() + except: + (type, msg) = sys.exc_info()[:2] + self.logger.error("Error sending 500 response (%s): %s" % \ + (type, msg)) + raise + else: + # got a valid XML RPC response + try: + self.send_response(200) + self.send_header("Content-type", "text/xml") + self.send_header("Content-length", str(len(response))) + self.end_headers() + failcount = 0 + while True: + try: + # If we hit SSL3_WRITE_PENDING here try to resend. + self.wfile.write(response) + break + except ssl.SSLError: + e = sys.exc_info()[1] + if str(e).find("SSL3_WRITE_PENDING") < 0: + raise + self.logger.error("SSL3_WRITE_PENDING") + failcount += 1 + if failcount < 5: + continue + raise + except: + (type, msg) = sys.exc_info()[:2] + if str(type) == 'socket.error' and msg[0] == 32: + self.logger.warning("Connection dropped from %s" % + self.client_address[0]) + elif str(type) == 'socket.error' and msg[0] == 104: + self.logger.warning("Connection reset by peer: %s" % + self.client_address[0]) + elif str(type) == 'ssl.SSLError': + self.logger.warning("SSLError handling client %s: %s" % + (self.client_address[0], msg)) + else: + self.logger.error("Error sending response (%s): %s" % + (type, msg)) + + def finish(self): + # shut down the connection + if not self.wfile.closed: + try: + self.wfile.flush() + self.wfile.close() + except socket.error: + err = sys.exc_info()[1] + self.logger.warning("Error closing connection: %s" % err) + self.rfile.close() + + +class XMLRPCServer (SocketServer.ThreadingMixIn, SSLServer, + XMLRPCDispatcher, object): + + """Component XMLRPCServer. + + Methods: + serve_daemon -- serve_forever in a daemonized process + serve_forever -- handle_one_request until not self.serve + shutdown -- stop serve_forever (by setting self.serve = False) + ping -- return all arguments received + + RPC methods: + ping + + (additional system.* methods are inherited from base dispatcher) + + Properties: + require_auth -- the request handler is requiring authorization + credentials -- valid credentials being used for authentication + + """ + + def __init__(self, listen_all, server_address, RequestHandlerClass=None, + keyfile=None, certfile=None, ca=None, protocol='xmlrpc/ssl', + timeout=10, + logRequests=False, + register=True, allow_none=True, encoding=None): + + """Initialize the XML-RPC server. + + Arguments: + server_address -- address to bind to the server + RequestHandlerClass -- request handler used by TCP server (optional) + + Keyword arguments: + keyfile -- private encryption key filename + certfile -- certificate file + logRequests -- log all requests (default False) + register -- presence should be reported to service-location (default True) + allow_none -- allow None values in xml-rpc + encoding -- encoding to use for xml-rpc (default UTF-8) + + """ + + XMLRPCDispatcher.__init__(self, allow_none, encoding) + + if not RequestHandlerClass: + class RequestHandlerClass (XMLRPCRequestHandler): + """A subclassed request handler to prevent class-attribute conflicts.""" + + SSLServer.__init__(self, + listen_all, + server_address, + RequestHandlerClass, + ca=ca, + timeout=timeout, + keyfile=keyfile, + certfile=certfile, + protocol=protocol) + self.logRequests = logRequests + self.serve = False + self.register = register + self.register_introspection_functions() + self.register_function(self.ping) + self.logger.info("service available at %s" % self.url) + self.timeout = timeout + + def _tasks_thread(self): + try: + while self.serve: + try: + if self.instance and hasattr(self.instance, 'do_tasks'): + self.instance.do_tasks() + except: + self.logger.error("Unexpected task failure", exc_info=1) + time.sleep(self.timeout) + except: + self.logger.error("tasks_thread failed", exc_info=1) + + def server_close(self): + SSLServer.server_close(self) + self.logger.info("server_close()") + + def _get_require_auth(self): + return getattr(self.RequestHandlerClass, "require_auth", False) + + def _set_require_auth(self, value): + self.RequestHandlerClass.require_auth = value + require_auth = property(_get_require_auth, _set_require_auth) + + def _get_credentials(self): + try: + return self.RequestHandlerClass.credentials + except AttributeError: + return dict() + + def _set_credentials(self, value): + self.RequestHandlerClass.credentials = value + credentials = property(_get_credentials, _set_credentials) + + def register_instance(self, instance, *args, **kwargs): + XMLRPCDispatcher.register_instance(self, instance, *args, **kwargs) + try: + name = instance.name + except AttributeError: + name = "unknown" + if hasattr(instance, 'plugins'): + for pname, pinst in list(instance.plugins.items()): + for mname in pinst.__rmi__: + xmname = "%s.%s" % (pname, mname) + fn = getattr(pinst, mname) + self.register_function(fn, name=xmname) + self.logger.info("serving %s at %s" % (name, self.url)) + + def serve_forever(self): + """Serve single requests until (self.serve == False).""" + self.serve = True + self.task_thread = threading.Thread(target=self._tasks_thread) + self.task_thread.start() + self.logger.info("serve_forever() [start]") + signal.signal(signal.SIGINT, self._handle_shutdown_signal) + signal.signal(signal.SIGTERM, self._handle_shutdown_signal) + + try: + while self.serve: + try: + self.handle_request() + except socket.timeout: + pass + except select.error: + pass + except: + self.logger.error("Got unexpected error in handle_request", + exc_info=1) + finally: + self.logger.info("serve_forever() [stop]") + + def shutdown(self): + """Signal that automatic service should stop.""" + self.serve = False + + def _handle_shutdown_signal(self, *_): + self.shutdown() + + def ping(self, *args): + """Echo response.""" + self.logger.info("ping(%s)" % (", ".join([repr(arg) for arg in args]))) + return args -- cgit v1.2.3-1-g7c22