diff options
Diffstat (limited to 'src/lib/Component.py')
-rw-r--r-- | src/lib/Component.py | 49 |
1 files changed, 42 insertions, 7 deletions
diff --git a/src/lib/Component.py b/src/lib/Component.py index 7ebdf8f86..5bf61452c 100644 --- a/src/lib/Component.py +++ b/src/lib/Component.py @@ -1,7 +1,7 @@ '''Cobalt component base classes''' __revision__ = '$Revision$' -import atexit, logging, select, signal, socket, sys, time, urlparse, xmlrpclib, cPickle, ConfigParser +import atexit, logging, select, signal, socket, sys, time, urlparse, xmlrpclib, cPickle, ConfigParser, os from base64 import decodestring import BaseHTTPServer, SimpleXMLRPCServer import Bcfg2.tlslite.errors @@ -20,15 +20,17 @@ class ComponentKeyError(Exception): '''raised in case of key parse fails''' pass +class ForkedChild(Exception): + '''raised after child has been forked''' + pass + class CobaltXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): '''CobaltXMLRPCRequestHandler takes care of ssl xmlrpc requests''' - def finish(self): - '''Finish HTTPS connections properly''' - self.request.close() def do_POST(self): '''Overload do_POST to pass through client address information''' try: + self.cleanup = True # get arguments data = self.rfile.read(int(self.headers["content-length"])) @@ -48,6 +50,9 @@ class CobaltXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): authenticated = True response = self.server._cobalt_marshalled_dispatch(data, self.client_address, authenticated) + except ForkedChild: + self.cleanup = False + return except: # This should only happen if the module is buggy # internal error, report as HTTP server error log.error("Unexcepted handler failure in do_POST", exc_info=1) @@ -81,6 +86,7 @@ class TLSServer(Bcfg2.tlslite.api.TLSSocketServerMixIn, reqCert=False): self.sc = Bcfg2.tlslite.api.SessionCache() self.rc = reqCert + self.master = os.getpid() x509 = Bcfg2.tlslite.api.X509() s = open(keyfile).read() x509.parse(s) @@ -92,12 +98,15 @@ class TLSServer(Bcfg2.tlslite.api.TLSSocketServerMixIn, self.chain = Bcfg2.tlslite.api.X509CertChain([x509]) BaseHTTPServer.HTTPServer.__init__(self, address, handler) - def finish_request(self, sock, client_address): + def finish_request(self, sock, address): sock.settimeout(90) tlsConnection = TLSConnection(sock) if self.handshake(tlsConnection) == True: - self.RequestHandlerClass(tlsConnection, client_address, self) - tlsConnection.close() + req = self.RequestHandlerClass(tlsConnection, address, self) + if req.cleanup: + tlsConnection.close() + if os.getpid() != self.master: + os._exit(0) def handshake(self, tlsConnection): try: @@ -125,6 +134,8 @@ class Component(TLSServer, __implementation__ = 'Generic' __statefields__ = [] async_funcs = ['assert_location'] + fork_funcs = [] + child_limit = 32 def __init__(self, setup): # need to get addr @@ -134,6 +145,7 @@ class Component(TLSServer, signal.signal(signal.SIGTERM, self.start_shutdown) self.logger = logging.getLogger('Component') self.cfile = ConfigParser.ConfigParser() + self.children = [] if setup['configfile']: cfilename = setup['configfile'] else: @@ -211,6 +223,13 @@ class Component(TLSServer, params = rawparams[0:] # generate response try: + # need to add waitpid code here to enforce maxchild + if method in self.fork_funcs: + self.clean_up_children() + pid = os.fork() + if pid: + self.children.append(pid) + raise ForkedChild # all handlers must take address as the first argument response = self._dispatch(method, (address, ) + params) # wrap response in a singleton tuple @@ -222,6 +241,8 @@ class Component(TLSServer, self.logger.error("Client %s called function %s with wrong argument count" % (address[0], method), exc_info=1) response = xmlrpclib.dumps(xmlrpclib.Fault(4, terror.args[0])) + except ForkedChild: + raise except: self.logger.error("Unexpected handler failure", exc_info=1) # report exception back to server @@ -229,6 +250,20 @@ class Component(TLSServer, "%s:%s" % (sys.exc_type, sys.exc_value))) return response + def clean_up_children(self): + while True: + try: + pid = os.waitpid(0, os.WNOHANG)[0] + self.children.remove(pid) + self.logger.debug("process %d exited" % pid) + except OSError: + break + if len(self.children) >= self.child_limit: + self.logger.info("Reached child_limit; waiting for child exit") + pid = os.waitpid(0, 0)[0] + self.children.remove(pid) + self.logger.debug("process %d exited" % pid) + def _authenticate_connection(self, method, user, password, address): '''Authenticate new connection''' (user, address, method) |