# Copyright: 2005 Gentoo Foundation
# Author(s): Brian Harring (ferringb@gentoo.org)
# License: GPL2

import sys

from portage.cache import template, cache_errors
from portage.cache.template import reconstruct_eclasses


class SQLDatabase(template.database):
    """template class for RDBM based caches

    This class is designed such that derivatives don't have to change much
    code, mostly constant strings.

    _BaseError must be an exception class that all Exceptions thrown from the
    derived RDBMS are derived from.

    SCHEMA_INSERT_CPV_INTO_PACKAGE should be modified dependent on the RDBMS,
    as should SCHEMA_PACKAGE_CREATE - basically you need to deal with creation
    of a unique pkgid.  If the dbapi2 rdbms class has a method of recovering
    that id, then modify _insert_cpv to remove the extra select.

    Creation of a derived class involves supplying _dbClass (used by
    _dbconnect) and _table_exists.  Additionally, the default schemas may
    have to be modified.

    NOTE(review): every statement is assembled with string interpolation and
    values are escaped by _sfilter rather than bound as driver parameters.
    Values that originate from untrusted input should be treated with care.
    """

    SCHEMA_PACKAGE_NAME = "package_cache"
    SCHEMA_PACKAGE_CREATE = (
        "CREATE TABLE %s ( "
        "pkgid INTEGER PRIMARY KEY, label VARCHAR(255), cpv VARCHAR(255), "
        "UNIQUE(label, cpv))" % SCHEMA_PACKAGE_NAME
    )
    SCHEMA_PACKAGE_DROP = "DROP TABLE %s" % SCHEMA_PACKAGE_NAME

    SCHEMA_VALUES_NAME = "values_cache"
    SCHEMA_VALUES_CREATE = (
        "CREATE TABLE %s ( "
        "pkgid integer references %s (pkgid) on delete cascade, "
        "key varchar(255), value text, UNIQUE(pkgid, key))"
        % (SCHEMA_VALUES_NAME, SCHEMA_PACKAGE_NAME)
    )
    SCHEMA_VALUES_DROP = "DROP TABLE %s" % SCHEMA_VALUES_NAME
    SCHEMA_INSERT_CPV_INTO_PACKAGE = (
        "INSERT INTO %s (label, cpv) VALUES(%%s, %%s)" % SCHEMA_PACKAGE_NAME
    )

    # Exception class (or tuple of classes) raised by the backing RDBMS module.
    # The empty tuple default means the "except self._BaseError" clauses match
    # nothing until a derived class overrides it.
    _BaseError = ()
    # Callable invoked with the config keywords to open the connection.
    _dbClass = None

    autocommits = False
    # cleanse_keys = True

    # boolean indicating if the derived RDBMS class supports replace syntax
    _supports_replace = False

    def __init__(self, location, label, auxdbkeys, *args, **config):
        """initialize the instance.

        derived classes shouldn't need to override this

        After the parent constructor has run, "host" and "autocommit" defaults
        are added to config, the connection and tables are set up through
        _initdb_con, and self.label is replaced by its SQL-escaped form so it
        can be interpolated directly into statements.
        """
        super(SQLDatabase, self).__init__(location, label, auxdbkeys, *args, **config)

        config.setdefault("host", "127.0.0.1")
        config.setdefault("autocommit", self.autocommits)
        self._initdb_con(config)

        self.label = self._sfilter(self.label)

    def _dbconnect(self, config):
        """should be overridden if the derived class needs special parameters
        for initializing the db connection, or cursor

        Stores the connection as self.db and a cursor on it as self.con.
        """
        self.db = self._dbClass(**config)
        self.con = self.db.cursor()

    def _initdb_con(self, config):
        """ensure needed tables are in place.

        If the derived class needs a different set of table creation commands,
        overload the appropriate SCHEMA_ attributes.  If it needs additional
        execution beyond, override.

        Raises cache_errors.ReadOnlyRestriction when a table is missing and
        the cache is read-only, and cache_errors.InitializationError when the
        CREATE statement fails with a _BaseError.
        """
        self._dbconnect(config)

        # The package table is handled first because the values table
        # references it.
        schema = (
            (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_PACKAGE_CREATE),
            (self.SCHEMA_VALUES_NAME, self.SCHEMA_VALUES_CREATE),
        )
        for table, create_stmt in schema:
            if self._table_exists(table):
                continue
            if self.readonly:
                raise cache_errors.ReadOnlyRestriction(
                    "table %s doesn't exist" % table)
            try:
                self.con.execute(create_stmt)
            except self._BaseError as e:
                raise cache_errors.InitializationError(self.__class__, e)

    def _table_exists(self, tbl):
        """return true if a table exists

        derived classes must override this
        """
        raise NotImplementedError

    def _sfilter(self, s):
        """meta escaping, returns quoted string for use in sql statements

        Backslashes and double quotes are backslash-escaped and the result is
        wrapped in double quotes.
        """
        return "\"%s\"" % s.replace("\\", "\\\\").replace("\"", "\\\"")

    def _getitem(self, cpv):
        """return the stored key/value mapping for cpv.

        Every key in self._known_keys is present in the result; keys without a
        stored row map to the empty string.

        Raises KeyError when no rows exist for cpv and
        cache_errors.CacheCorruption when the query fails.
        """
        try:
            self.con.execute(
                "SELECT key, value FROM %s NATURAL JOIN %s "
                "WHERE label=%s AND cpv=%s"
                % (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME,
                   self.label, self._sfilter(cpv)))
        except self._BaseError as e:
            # Two-argument form, matching _setitem and _insert_cpv.
            raise cache_errors.CacheCorruption(cpv, e)

        rows = self.con.fetchall()
        if not rows:
            raise KeyError(cpv)

        vals = dict.fromkeys(self._known_keys, "")
        vals.update(dict(rows))
        return vals

    def _delitem(self, cpv):
        """delete a cpv cache entry

        derived RDBM classes for this *must* either support cascaded deletes,
        or override this method

        Raises KeyError when no row was deleted and
        cache_errors.CacheCorruption when the statement fails.
        """
        try:
            try:
                self.con.execute(
                    "DELETE FROM %s WHERE label=%s AND cpv=%s"
                    % (self.SCHEMA_PACKAGE_NAME, self.label, self._sfilter(cpv)))
                if self.autocommits:
                    self.commit()
            except self._BaseError as e:
                raise cache_errors.CacheCorruption(cpv, e)
            if self.con.rowcount <= 0:
                raise KeyError(cpv)
        except Exception:
            if not self.autocommits:
                # yes, this can roll back a lot more than just the delete. deal.
                try:
                    self.db.rollback()
                except self._BaseError:
                    # Do not let a failed rollback mask the original error.
                    pass
            raise

    def __del__(self):
        """commit pending work and close the connection, just to be safe."""
        # "db" is absent when connection setup failed before it was assigned.
        if self.__dict__.get("db") is not None:
            try:
                self.commit()
            finally:
                self.db.close()

    def _setitem(self, cpv, values):
        """store values for cpv, replacing any previous entry.

        Only keys listed in self._known_keys that carry a truthy value are
        written.  On any failure the transaction is rolled back (unless the
        backend autocommits) and the exception is re-raised; _BaseError
        failures are raised as cache_errors.CacheCorruption.
        """
        try:
            # insert.
            try:
                pkgid = self._insert_cpv(cpv)
            except self._BaseError as e:
                raise cache_errors.CacheCorruption(cpv, e)

            # __getitem__ fills out missing values,
            # so we store only what's handed to us and is a known key
            db_values = [
                {"key": key, "value": values[key]}
                for key in self._known_keys
                if key in values and values[key]
            ]

            if db_values:
                try:
                    self.con.executemany(
                        "INSERT INTO %s (pkgid, key, value) VALUES(\"%s\", %%(key)s, %%(value)s)"
                        % (self.SCHEMA_VALUES_NAME, str(pkgid)), db_values)
                except self._BaseError as e:
                    raise cache_errors.CacheCorruption(cpv, e)
            if self.autocommits:
                self.commit()
        except Exception:
            if not self.autocommits:
                try:
                    self.db.rollback()
                except self._BaseError:
                    pass
            raise

    def _insert_cpv(self, cpv):
        """uses SCHEMA_INSERT_CPV_INTO_PACKAGE, which must be overloaded if the
        table definition doesn't support auto-increment columns for pkgid.

        returns the cpvs new pkgid

        note this doesn't commit the transaction.  The caller is expected to.

        Raises cache_errors.CacheCorruption when the follow-up select does not
        report exactly one row.
        """
        raw_cpv = cpv
        cpv = self._sfilter(cpv)
        if self._supports_replace:
            query_str = self.SCHEMA_INSERT_CPV_INTO_PACKAGE.replace("INSERT", "REPLACE", 1)
        else:
            # just delete it.
            # The unescaped cpv is used here because _delitem escapes its
            # argument itself; passing the escaped form would quote it twice
            # and never match.
            # NOTE(review): assumes the inherited __delitem__ forwards to
            # _delitem - confirm against template.database.
            try:
                del self[raw_cpv]
            except (cache_errors.CacheCorruption, KeyError):
                pass
            query_str = self.SCHEMA_INSERT_CPV_INTO_PACKAGE

        try:
            self.con.execute(query_str % (self.label, cpv))
        except self._BaseError:
            self.db.rollback()
            raise

        self.con.execute(
            "SELECT pkgid FROM %s WHERE label=%s AND cpv=%s"
            % (self.SCHEMA_PACKAGE_NAME, self.label, cpv))

        # NOTE(review): relies on the driver reporting rowcount for SELECT.
        if self.con.rowcount != 1:
            raise cache_errors.CacheCorruption(
                raw_cpv, "Tried to insert the cpv, but found "
                " %i matches upon the following select!" % self.con.rowcount)
        return self.con.fetchone()[0]

    def __contains__(self, cpv):
        """return True if an entry for cpv exists under this label.

        Pending work is committed first when the backend does not autocommit.
        Database failures are raised as cache_errors.GeneralCacheCorruption.
        """
        if not self.autocommits:
            try:
                self.commit()
            except self._BaseError as e:
                raise cache_errors.GeneralCacheCorruption(e)

        try:
            self.con.execute(
                "SELECT cpv FROM %s WHERE label=%s AND cpv=%s"
                % (self.SCHEMA_PACKAGE_NAME, self.label, self._sfilter(cpv)))
        except self._BaseError as e:
            raise cache_errors.GeneralCacheCorruption(e)
        # NOTE(review): relies on the driver reporting rowcount for SELECT.
        return self.con.rowcount > 0

    def __iter__(self):
        """yield every cpv stored under this label.

        Pending work is committed first when the backend does not autocommit.
        Database failures are raised as cache_errors.GeneralCacheCorruption.
        """
        if not self.autocommits:
            try:
                self.commit()
            except self._BaseError as e:
                raise cache_errors.GeneralCacheCorruption(e)

        try:
            self.con.execute(
                "SELECT cpv FROM %s WHERE label=%s"
                % (self.SCHEMA_PACKAGE_NAME, self.label))
        except self._BaseError as e:
            raise cache_errors.GeneralCacheCorruption(e)
        for row in self.con.fetchall():
            yield row[0]

    def iteritems(self):
        """yield (cpv, values) pairs for every entry under this label.

        Each values dict always has an "_eclasses_" key: the stored value
        passed through reconstruct_eclasses, or an empty dict when no such row
        exists.

        Database failures are raised as cache_errors.GeneralCacheCorruption.
        """
        try:
            self.con.execute(
                "SELECT cpv, key, value FROM %s NATURAL JOIN %s "
                "WHERE label=%s"
                % (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME, self.label))
        except self._BaseError as e:
            # No single cpv is involved here, so report it the way __iter__ does.
            raise cache_errors.GeneralCacheCorruption(e)

        def build_entry(entry_cpv, pairs):
            # Turn the collected (key, value) rows of one cpv into its dict.
            d = dict(pairs)
            if "_eclasses_" in d:
                d["_eclasses_"] = reconstruct_eclasses(entry_cpv, d["_eclasses_"])
            else:
                d["_eclasses_"] = {}
            return d

        # Rows are grouped by adjacency.
        # NOTE(review): the query has no ORDER BY, so this presumes the
        # backend returns the rows of one cpv contiguously - confirm.
        oldcpv = None
        pairs = []
        for row_cpv, key, value in self.con.fetchall():
            if oldcpv != row_cpv:
                if oldcpv is not None:
                    yield oldcpv, build_entry(oldcpv, pairs)
                pairs = []
                oldcpv = row_cpv
            pairs.append((key, value))
        if oldcpv is not None:
            yield oldcpv, build_entry(oldcpv, pairs)

    def commit(self):
        """commit the current transaction on the underlying connection."""
        self.db.commit()

    def get_matches(self, match_dict):
        """return the cpvs under this label whose rows satisfy match_dict.

        match_dict maps known keys to patterns.  In each pattern a literal "%"
        is backslash-escaped and ".*" is turned into the SQL LIKE wildcard.

        Raises cache_errors.InvalidRestriction for a key that is not in
        self._known_keys and cache_errors.GeneralCacheCorruption when the
        query fails.

        NOTE(review): the conditions are ANDed together on a single joined
        row, so more than one key in match_dict looks like it can never
        match - confirm the intended semantics.
        """
        query_list = []
        for k, v in match_dict.items():
            if k not in self._known_keys:
                raise cache_errors.InvalidRestriction(
                    k, v, "key isn't known to this cache instance")
            v = v.replace("%", "\\%")
            v = v.replace(".*", "%")
            query_list.append(
                "(key=%s AND value LIKE %s)" % (self._sfilter(k), self._sfilter(v)))

        if query_list:
            query = " AND " + " AND ".join(query_list)
        else:
            query = ""

        try:
            # Use the schema constants so derived classes that rename the
            # tables are honoured here as everywhere else.
            self.con.execute(
                "SELECT cpv from %s natural join %s WHERE label=%s %s"
                % (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME,
                   self.label, query))
        except self._BaseError as e:
            raise cache_errors.GeneralCacheCorruption(e)

        return [row[0] for row in self.con.fetchall()]

    if sys.hexversion >= 0x3000000:
        # Python 3 spellings of the dict-style iteration methods.
        items = iteritems
        keys = __iter__