from sqlalchemy import event, Column, String
from sqlalchemy.orm.session import Session


class APIMixin(object):
    """Mixin for mapped classes whose rows each own an "API object".

    Subclasses implement :meth:`create_api_object` (which must populate
    ``self.api_id`` -- :meth:`get_api_id` returns it right after the call)
    and :meth:`remove_api_object`.  The module-level ``SessionHelper``
    invokes the hooks below from SQLAlchemy session events.
    """

    # Identifier of the associated API object.  Pending instances get it
    # from create_api_object() during before_flush, ahead of the INSERT.
    api_id = Column(String(255), nullable=False)

    def get_api_id(self):
        """Return ``api_id``, creating the API object first if it is unset."""
        if self.api_id is None:
            self.create_api_object()
        return self.api_id

    def before_created(self):
        """Ensure the API object exists; called before a pending row is flushed."""
        if self.api_id is None:
            self.create_api_object()

    def after_deleted(self):
        """Remove the API object, if any, and clear ``api_id``.

        A second call is a no-op because ``api_id`` is reset to None.
        """
        if self.api_id is not None:
            self.remove_api_object()
            self.api_id = None

    def create_api_object(self):
        """Create the API object and set ``self.api_id``; must be overridden."""
        raise NotImplementedError

    def remove_api_object(self):
        """Remove the API object for ``self.api_id``; must be overridden."""
        raise NotImplementedError

    def after_commit(self):
        """Hook run after a commit in which this row was inserted or updated."""
        pass


class SessionHelper(object):
    """Drive :class:`APIMixin` hooks from SQLAlchemy session events.

    Listeners are registered on the ``Session`` class, so they fire for
    every session.  ``APIMixin`` instances touched by flushes are collected
    in ``new``, ``dirty`` and ``deleted`` until the transaction commits or
    rolls back; the matching hooks then run and the sets are emptied.

    NOTE(review): the pending sets are shared by all sessions, so ending a
    transaction in one session also fires hooks for objects flushed in
    another.  If several sessions can be active at once, consider keeping
    this state per session (e.g. in ``session.info``).
    """

    def __init__(self):
        self.new = set()
        self.dirty = set()
        self.deleted = set()
        event.listen(Session, 'before_flush', self.before_flush)
        event.listen(Session, 'after_flush', self.after_flush)
        event.listen(Session, 'after_commit', self.after_commit)
        event.listen(Session, 'after_rollback', self.after_rollback)

    @staticmethod
    def _drain(pending):
        """Return the contents of *pending* as a list and empty it in place."""
        # Snapshot first: a hook may flush/commit and re-enter this helper,
        # which would otherwise mutate the set while it is being iterated.
        items = list(pending)
        pending.clear()
        return items

    def before_flush(self, session, flush_context, instances):
        """Create API objects for pending ``APIMixin`` instances.

        Each instance is recorded in ``self.new`` *before* its API object
        is created.  If the flush then fails, ``after_flush`` never runs,
        but the rollback handler can still remove the API objects that were
        already created.
        """
        for obj in session.new:
            if isinstance(obj, APIMixin):
                self.new.add(obj)
                obj.before_created()

    def after_flush(self, session, flush_context):
        """Record the ``APIMixin`` instances written by the flush that just ran.

        Relies on SQLAlchemy still exposing the pre-flush ``new``/``dirty``/
        ``deleted`` collections during ``after_flush``.
        """
        self.new.update(
            obj for obj in session.new if isinstance(obj, APIMixin))
        self.dirty.update(
            obj for obj in session.dirty if isinstance(obj, APIMixin))
        self.deleted.update(
            obj for obj in session.deleted if isinstance(obj, APIMixin))

    def after_commit(self, session):
        """Run post-commit hooks for everything recorded so far.

        Inserted and updated rows get ``after_commit``; deleted rows get
        ``after_deleted`` (which removes their API object).
        """
        for obj in self._drain(self.new):
            obj.after_commit()
        for obj in self._drain(self.dirty):
            obj.after_commit()
        for obj in self._drain(self.deleted):
            obj.after_deleted()

    def after_rollback(self, session):
        """Undo API side effects of a rolled-back transaction.

        API objects created for rows that were being inserted are removed.
        Recorded updates and deletions are simply forgotten.
        """
        self.dirty.clear()
        self.deleted.clear()
        for obj in self._drain(self.new):
            obj.after_deleted()


# Instantiating registers the listeners on the Session class at import time.
helper = SessionHelper()