summaryrefslogtreecommitdiffstats
path: root/src/lib/Component.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/Component.py')
-rw-r--r--src/lib/Component.py49
1 files changed, 42 insertions, 7 deletions
diff --git a/src/lib/Component.py b/src/lib/Component.py
index 7ebdf8f86..5bf61452c 100644
--- a/src/lib/Component.py
+++ b/src/lib/Component.py
@@ -1,7 +1,7 @@
'''Cobalt component base classes'''
__revision__ = '$Revision$'
-import atexit, logging, select, signal, socket, sys, time, urlparse, xmlrpclib, cPickle, ConfigParser
+import atexit, logging, select, signal, socket, sys, time, urlparse, xmlrpclib, cPickle, ConfigParser, os
from base64 import decodestring
import BaseHTTPServer, SimpleXMLRPCServer
import Bcfg2.tlslite.errors
@@ -20,15 +20,17 @@ class ComponentKeyError(Exception):
'''raised in case of key parse fails'''
pass
+class ForkedChild(Exception):
+ '''raised after child has been forked'''
+ pass
+
class CobaltXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
'''CobaltXMLRPCRequestHandler takes care of ssl xmlrpc requests'''
- def finish(self):
- '''Finish HTTPS connections properly'''
- self.request.close()
def do_POST(self):
'''Overload do_POST to pass through client address information'''
try:
+ self.cleanup = True
# get arguments
data = self.rfile.read(int(self.headers["content-length"]))
@@ -48,6 +50,9 @@ class CobaltXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
authenticated = True
response = self.server._cobalt_marshalled_dispatch(data, self.client_address, authenticated)
+ except ForkedChild:
+ self.cleanup = False
+ return
except: # This should only happen if the module is buggy
# internal error, report as HTTP server error
log.error("Unexcepted handler failure in do_POST", exc_info=1)
@@ -81,6 +86,7 @@ class TLSServer(Bcfg2.tlslite.api.TLSSocketServerMixIn,
reqCert=False):
self.sc = Bcfg2.tlslite.api.SessionCache()
self.rc = reqCert
+ self.master = os.getpid()
x509 = Bcfg2.tlslite.api.X509()
s = open(keyfile).read()
x509.parse(s)
@@ -92,12 +98,15 @@ class TLSServer(Bcfg2.tlslite.api.TLSSocketServerMixIn,
self.chain = Bcfg2.tlslite.api.X509CertChain([x509])
BaseHTTPServer.HTTPServer.__init__(self, address, handler)
- def finish_request(self, sock, client_address):
+ def finish_request(self, sock, address):
sock.settimeout(90)
tlsConnection = TLSConnection(sock)
if self.handshake(tlsConnection) == True:
- self.RequestHandlerClass(tlsConnection, client_address, self)
- tlsConnection.close()
+ req = self.RequestHandlerClass(tlsConnection, address, self)
+ if req.cleanup:
+ tlsConnection.close()
+ if os.getpid() != self.master:
+ os._exit(0)
def handshake(self, tlsConnection):
try:
@@ -125,6 +134,8 @@ class Component(TLSServer,
__implementation__ = 'Generic'
__statefields__ = []
async_funcs = ['assert_location']
+ fork_funcs = []
+ child_limit = 32
def __init__(self, setup):
# need to get addr
@@ -134,6 +145,7 @@ class Component(TLSServer,
signal.signal(signal.SIGTERM, self.start_shutdown)
self.logger = logging.getLogger('Component')
self.cfile = ConfigParser.ConfigParser()
+ self.children = []
if setup['configfile']:
cfilename = setup['configfile']
else:
@@ -211,6 +223,13 @@ class Component(TLSServer,
params = rawparams[0:]
# generate response
try:
+ # need to add waitpid code here to enforce maxchild
+ if method in self.fork_funcs:
+ self.clean_up_children()
+ pid = os.fork()
+ if pid:
+ self.children.append(pid)
+ raise ForkedChild
# all handlers must take address as the first argument
response = self._dispatch(method, (address, ) + params)
# wrap response in a singleton tuple
@@ -222,6 +241,8 @@ class Component(TLSServer,
self.logger.error("Client %s called function %s with wrong argument count" %
(address[0], method), exc_info=1)
response = xmlrpclib.dumps(xmlrpclib.Fault(4, terror.args[0]))
+ except ForkedChild:
+ raise
except:
self.logger.error("Unexpected handler failure", exc_info=1)
# report exception back to server
@@ -229,6 +250,20 @@ class Component(TLSServer,
"%s:%s" % (sys.exc_type, sys.exc_value)))
return response
+ def clean_up_children(self):
+ while True:
+ try:
+ pid = os.waitpid(0, os.WNOHANG)[0]
+ self.children.remove(pid)
+ self.logger.debug("process %d exited" % pid)
+ except OSError:
+ break
+ if len(self.children) >= self.child_limit:
+ self.logger.info("Reached child_limit; waiting for child exit")
+ pid = os.waitpid(0, 0)[0]
+ self.children.remove(pid)
+ self.logger.debug("process %d exited" % pid)
+
def _authenticate_connection(self, method, user, password, address):
'''Authenticate new connection'''
(user, address, method)