import copy
import datetime
import locale
import os
import pickle
import sys

from ._compat import PY2, string_types, pjoin, iteritems, to_bytes, exists
from ._load import portalocker
from .helpers.classes import SQLCustomType, DatabaseStoredFile


class Migrator(object):
    def __init__(self, adapter):
        self.adapter = adapter

    @property
    def db(self):
        return self.adapter.db

    @property
    def dialect(self):
        return self.adapter.dialect

    @property
    def dbengine(self):
        return self.adapter.dbengine

    def create_table(self, table, migrate=True, fake_migrate=False, polymodel=None):
        db = table._db
        table._migrate = migrate
        fields = []
        # PostGIS geo fields are added after the table has been created
        postcreation_fields = []
        sql_fields = {}
        sql_fields_aux = {}
        TFK = {}
        tablename = table._tablename
        types = self.adapter.types
        for sortable, field in enumerate(table, start=1):
            if self.db._ignore_field_case:
                field_name = field.name.lower()
                field_rname = field._rname.lower()
            else:
                field_name = field.name
                field_rname = field._rname
            if self.dbengine == "oracle":
                # Oracle needs all field names quoted to ensure consistent case
                field_rname = self.dialect.quote(field_rname)
            field_type = field.type
            if isinstance(field_type, SQLCustomType):
                ftype = field_type.native or field_type.type
            elif field_type.startswith(("reference", "big-reference")):
                if field_type.startswith("reference"):
                    referenced = field_type[10:].strip()
                    type_name = "reference"
                else:
                    referenced = field_type[14:].strip()
                    type_name = "big-reference"
                if referenced == ".":
                    referenced = tablename
                constraint_name = self.dialect.constraint_name(
                    table._raw_rname, field._raw_rname
                )
                # if not '.' in referenced \
                #    and referenced != tablename \
                #    and hasattr(table,'_primarykey'):
                #     ftype = types['integer']
                # else:
                try:
                    rtable = db[referenced]
                    rfield = rtable._id
                    rfieldname = rfield.name
                    rtablename = referenced
                except (KeyError, ValueError, AttributeError) as e:
                    self.db.logger.debug("Error: %s" % e)
                    try:
                        rtablename, rfieldname = referenced.split(".")
                        rtable = db[rtablename]
                        rfield = rtable[rfieldname]
                    except Exception as e:
                        self.db.logger.debug("Error: %s" % e)
                        raise KeyError(
                            "Cannot resolve reference %s in %s definition"
                            % (referenced, table._tablename)
                        )
                # must be PK reference or unique
                rfield_rname = rfield._rname
                if self.dbengine == "oracle":
                    rfield_rname = self.dialect.quote(rfield_rname)
                if (
                    not rfield.type.startswith(("reference", "big-reference"))
                    and getattr(rtable, "_primarykey", None)
                    and rfieldname in rtable._primarykey
                    or rfield.unique
                ):
                    ftype = types[rfield.type[:9]] % dict(length=rfield.length)
                    # multicolumn primary key reference?
                    if not rfield.unique and len(rtable._primarykey) > 1:
                        # then it has to be a table level FK
                        if rtablename not in TFK:
                            TFK[rtablename] = {}
                        TFK[rtablename][rfieldname] = field_name
                    else:
                        fk = rtable._rname + " (" + rfield._rname + ")"
                        if self.dbengine == "oracle":
                            fk = (
                                self.dialect.quote(rtable._rname)
                                + " ("
                                + rfield_rname
                                + ")"
                            )
                        ftype = ftype + types["reference FK"] % dict(
                            # should be quoted
                            constraint_name=constraint_name,
                            foreign_key=fk,
                            table_name=table._rname,
                            field_name=field._rname,
                            on_delete_action=field.ondelete,
                        )
                else:
                    # make a guess here for circular references
                    if referenced in db:
                        id_fieldname = db[referenced]._id._rname
                    elif referenced == tablename:
                        id_fieldname = table._id._rname
                    else:  # make a guess
                        id_fieldname = self.dialect.quote("id")
                    # gotcha: the referenced table must be defined before the
                    # referencing one for the table to be created.
                    # Also, although it is not recommended, we still support
                    # references to tablenames without an rname so that
                    # migrations and model relationships work even when tables
                    # are not defined in order
                    if referenced == tablename:
                        real_referenced = db[referenced]._rname
                    else:
                        real_referenced = (
                            referenced in db and db[referenced]._rname or referenced
                        )
                    if self.dbengine == "oracle":
                        real_referenced = self.dialect.quote(real_referenced)
                    rfield = db[referenced]._id
                    ftype_info = dict(
                        index_name=self.dialect.quote(field._raw_rname + "__idx"),
                        field_name=field_rname,
                        constraint_name=self.dialect.quote(constraint_name),
                        foreign_key="%s (%s)" % (real_referenced, rfield_rname),
                        on_delete_action=field.ondelete,
                    )
                    ftype_info["null"] = (
                        " NOT NULL" if field.notnull else self.dialect.allow_null
                    )
                    ftype_info["unique"] = " UNIQUE" if field.unique else ""
                    ftype = types[type_name] % ftype_info
            elif field_type.startswith("list:reference"):
                ftype = types[field_type[:14]]
            elif field_type.startswith("decimal"):
                precision, scale = map(int, field_type[8:-1].split(","))
                ftype = types[field_type[:7]] % dict(precision=precision, scale=scale)
            elif field_type.startswith("geo"):
                if not hasattr(self.adapter, "srid"):
                    raise RuntimeError("Adapter does not support geometry")
                srid = self.adapter.srid
                geotype, parms = field_type[:-1].split("(")
                if geotype not in types:
                    raise SyntaxError(
                        "Field: unknown field type: %s for %s"
                        % (field_type, field_name)
                    )
                ftype = types[geotype]
                if self.dbengine == "postgres" and geotype == "geometry":
                    if self.db._ignore_field_case is True:
                        field_name = field_name.lower()
                    # parameters: schema, srid, dimension
                    dimension = 2  # GIS.dimension ???
parms = parms.split(",") if len(parms) == 3: schema, srid, dimension = parms elif len(parms) == 2: schema, srid = parms else: schema = parms[0] ftype = ( "SELECT AddGeometryColumn ('%%(schema)s', '%%(tablename)s', '%%(fieldname)s', %%(srid)s, '%s', %%(dimension)s);" % types[geotype] ) ftype = ftype % dict( schema=schema, tablename=table._raw_rname, fieldname=field._raw_rname, srid=srid, dimension=dimension, ) postcreation_fields.append(ftype) elif field_type not in types: raise SyntaxError( "Field: unknown field type: %s for %s" % (field_type, field_name) ) else: ftype = types[field_type] % {"length": field.length} if not field_type.startswith(("id", "reference", "big-reference")): if field.notnull: ftype += " NOT NULL" else: ftype += self.dialect.allow_null if field.unique: ftype += " UNIQUE" if field.custom_qualifier: ftype += " %s" % field.custom_qualifier # add to list of fields sql_fields[field_name] = dict( length=field.length, unique=field.unique, notnull=field.notnull, sortable=sortable, type=str(field_type), sql=ftype, rname=field_rname, raw_rname=field._raw_rname, ) if field.notnull and field.default is not None: # Caveat: sql_fields and sql_fields_aux # differ for default values. # sql_fields is used to trigger migrations and sql_fields_aux # is used for create tables. # The reason is that we do not want to trigger # a migration simply because a default value changes. not_null = self.dialect.not_null(field.default, field_type) ftype = ftype.replace("NOT NULL", not_null) sql_fields_aux[field_name] = dict(sql=ftype) # Postgres - PostGIS: # geometry fields are added after the table has been created, not now if not (self.dbengine == "postgres" and field_type.startswith("geom")): fields.append("%s %s" % (field_rname, ftype)) other = ";" # backend-specific extensions to fields if self.dbengine == "mysql": if not hasattr(table, "_primarykey"): fields.append("PRIMARY KEY (%s)" % (table._id._rname)) engine = self.adapter.adapter_args.get("engine", "InnoDB") other = " ENGINE=%s CHARACTER SET utf8;" % engine fields = ",\n ".join(fields) for rtablename in TFK: rtable = db[rtablename] rfields = TFK[rtablename] pkeys = [rtable[pk]._rname for pk in rtable._primarykey] fk_fields = [table[rfields[k]] for k in rtable._primarykey] fkeys = [f._rname for f in fk_fields] constraint_name = self.dialect.constraint_name( table._raw_rname, "_".join(f._raw_rname for f in fk_fields) ) on_delete = list(set(f.ondelete for f in fk_fields)) if len(on_delete) > 1: raise SyntaxError( "Table %s has incompatible ON DELETE actions in multi-field foreign key." 
                    % table._dalname
                )
            tfk_field_name = ", ".join(fkeys)
            tfk_foreign_key = ", ".join(pkeys)
            tfk_foreign_table = rtable._rname
            if self.dbengine == "oracle":
                tfk_field_name = ", ".join([self.dialect.quote(fkey) for fkey in fkeys])
                tfk_foreign_key = ", ".join(
                    [self.dialect.quote(pkey) for pkey in pkeys]
                )
                tfk_foreign_table = self.dialect.quote(rtable._rname)
            fields = (
                fields
                + ",\n    "
                + types["reference TFK"]
                % dict(
                    constraint_name=constraint_name,
                    table_name=table._rname,
                    field_name=tfk_field_name,
                    foreign_table=tfk_foreign_table,
                    foreign_key=tfk_foreign_key,
                    on_delete_action=on_delete[0],
                )
            )

        table_rname = table._rname
        if self.dbengine == "oracle":
            # must be explicitly quoted to preserve case
            table_rname = self.dialect.quote(table_rname)

        if getattr(table, "_primarykey", None):
            query = "CREATE TABLE %s(\n    %s,\n    %s) %s" % (
                table_rname,
                fields,
                self.dialect.primary_key(
                    ", ".join([table[pk]._rname for pk in table._primarykey])
                ),
                other,
            )
        else:
            query = "CREATE TABLE %s(\n    %s\n)%s" % (table_rname, fields, other)

        uri = self.adapter.uri
        if uri.startswith("sqlite:///") or uri.startswith("spatialite:///"):
            if PY2:
                path_encoding = (
                    sys.getfilesystemencoding()
                    or locale.getdefaultlocale()[1]
                    or "utf8"
                )
                dbpath = uri[9 : uri.rfind("/")].decode("utf8").encode(path_encoding)
            else:
                dbpath = uri[9 : uri.rfind("/")]
        else:
            dbpath = self.adapter.folder

        if not migrate:
            return query
        elif uri.startswith("sqlite:memory") or uri.startswith("spatialite:memory"):
            table._dbt = None
        elif isinstance(migrate, string_types):
            table._dbt = pjoin(dbpath, migrate)
        else:
            table._dbt = pjoin(dbpath, "%s_%s.table" % (db._uri_hash, tablename))

        if not table._dbt or not self.file_exists(table._dbt):
            if table._dbt:
                self.log(
                    "timestamp: %s\n%s\n"
                    % (datetime.datetime.today().isoformat(), query),
                    table,
                )
            if not fake_migrate:
                self.adapter.create_sequence_and_triggers(query, table)
                db.commit()
                # Postgres geom fields are added now,
                # after the table has been created
                for query in postcreation_fields:
                    self.adapter.execute(query)
                    db.commit()
            if table._dbt:
                tfile = self.file_open(table._dbt, "wb")
                pickle.dump(sql_fields, tfile)
                self.file_close(tfile)
                if fake_migrate:
                    self.log("faked!\n", table)
                else:
                    self.log("success!\n", table)
        else:
            tfile = self.file_open(table._dbt, "rb")
            try:
                sql_fields_old = pickle.load(tfile)
            except EOFError:
                self.file_close(tfile)
                raise RuntimeError("File %s appears corrupted" % table._dbt)
            self.file_close(tfile)
            # add missing rnames
            for key, item in sql_fields_old.items():
                tmp = sql_fields.get(key)
                if tmp:
                    item.setdefault("rname", tmp["rname"])
                    item.setdefault("raw_rname", tmp["raw_rname"])
                else:
                    item.setdefault("rname", self.dialect.quote(key))
                    item.setdefault("raw_rname", key)
            if sql_fields != sql_fields_old:
                self.migrate_table(
                    table,
                    sql_fields,
                    sql_fields_old,
                    sql_fields_aux,
                    None,
                    fake_migrate=fake_migrate,
                )
        return query

    def _fix(self, item):
        k, v = item
        if self.dbengine == "oracle" and "rname" in v:
            v["rname"] = self.dialect.quote(v["rname"])
        if not isinstance(v, dict):
            v = dict(type="unknown", sql=v)
        if self.db._ignore_field_case is not True:
            return k, v
        return k.lower(), v

    def migrate_table(
        self,
        table,
        sql_fields,
        sql_fields_old,
        sql_fields_aux,
        logfile,
        fake_migrate=False,
    ):
        # logfile is deprecated (moved to adapter.log method)
        db = table._db
        db._migrated.append(table._tablename)
        tablename = table._tablename
        if self.dbengine in ("firebird",):
            drop_expr = "ALTER TABLE %s DROP %s;"
        else:
            drop_expr = "ALTER TABLE %s DROP COLUMN %s;"
        field_types = dict(
            (x.lower(), table[x].type) for x in sql_fields.keys() if x in table
        )
        # make sure all field names are lower case to avoid
        # migrations because of case change
        sql_fields = dict(map(self._fix, iteritems(sql_fields)))
        sql_fields_old = dict(map(self._fix, iteritems(sql_fields_old)))
        sql_fields_aux = dict(map(self._fix, iteritems(sql_fields_aux)))

        table_rname = table._rname
        if self.dbengine == "oracle":
            table_rname = self.dialect.quote(table_rname)

        if db._debug:
            db.logger.debug("migrating %s to %s" % (sql_fields_old, sql_fields))

        keys = list(sql_fields.keys())
        for key in sql_fields_old:
            if key not in keys:
                keys.append(key)
        new_add = self.dialect.concat_add(table_rname)

        metadata_change = False
        sql_fields_current = copy.copy(sql_fields_old)
        for key in keys:
            query = None
            if key not in sql_fields_old:
                sql_fields_current[key] = sql_fields[key]
                if self.dbengine in ("postgres",) and sql_fields[key][
                    "type"
                ].startswith("geometry"):
                    # 'sql' == ftype in sql
                    query = [sql_fields[key]["sql"]]
                else:
                    query = [
                        "ALTER TABLE %s ADD %s %s;"
                        % (
                            table_rname,
                            sql_fields[key]["rname"],
                            sql_fields_aux[key]["sql"].replace(", ", new_add),
                        )
                    ]
                metadata_change = True
            elif self.dbengine in ("sqlite", "spatialite"):
                if key in sql_fields:
                    sql_fields_current[key] = sql_fields[key]
                    # the field rname has changed, add a new column
                    if (
                        sql_fields[key]["raw_rname"].lower()
                        != sql_fields_old[key]["raw_rname"].lower()
                    ):
                        tt = sql_fields_aux[key]["sql"].replace(", ", new_add)
                        query = [
                            "ALTER TABLE %s ADD %s %s;"
                            % (table_rname, sql_fields[key]["rname"], tt),
                            "UPDATE %s SET %s=%s;"
                            % (
                                table_rname,
                                sql_fields[key]["rname"],
                                sql_fields_old[key]["rname"],
                            ),
                        ]
                        metadata_change = True
            elif key not in sql_fields:
                del sql_fields_current[key]
                ftype = sql_fields_old[key]["type"]
                if self.dbengine == "postgres" and ftype.startswith("geometry"):
                    geotype, parms = ftype[:-1].split("(")
                    schema = parms.split(",")[0]
                    query = [
                        "SELECT DropGeometryColumn ('%(schema)s', '%(table)s', '%(field)s');"
                        % dict(
                            schema=schema,
                            table=table._raw_rname,
                            field=sql_fields_old[key]["raw_rname"],
                        )
                    ]
                else:
                    query = [drop_expr % (table._rname, sql_fields_old[key]["rname"])]
                metadata_change = True
            # the field has a new rname, so a temp column is not needed
            elif (
                sql_fields[key]["raw_rname"].lower()
                != sql_fields_old[key]["raw_rname"].lower()
            ):
                sql_fields_current[key] = sql_fields[key]
                tt = sql_fields_aux[key]["sql"].replace(", ", new_add)
                query = [
                    "ALTER TABLE %s ADD %s %s;"
                    % (table_rname, sql_fields[key]["rname"], tt),
                    "UPDATE %s SET %s=%s;"
                    % (
                        table_rname,
                        sql_fields[key]["rname"],
                        sql_fields_old[key]["rname"],
                    ),
                    drop_expr % (table_rname, sql_fields_old[key]["rname"]),
                ]
                metadata_change = True
            elif (
                sql_fields[key]["sql"] != sql_fields_old[key]["sql"]
                and not isinstance(field_types.get(key), SQLCustomType)
                and not sql_fields[key]["type"].startswith("reference")
                and not sql_fields[key]["type"].startswith("double")
                and not sql_fields[key]["type"].startswith("id")
            ):
                sql_fields_current[key] = sql_fields[key]
                tt = sql_fields_aux[key]["sql"].replace(", ", new_add)
                key_tmp = self.dialect.quote(key + "__tmp")
                query = [
                    "ALTER TABLE %s ADD %s %s;" % (table_rname, key_tmp, tt),
                    "UPDATE %s SET %s=%s;"
                    % (table_rname, key_tmp, sql_fields_old[key]["rname"]),
                    drop_expr % (table_rname, sql_fields_old[key]["rname"]),
                    "ALTER TABLE %s ADD %s %s;"
                    % (table_rname, sql_fields[key]["rname"], tt),
                    "UPDATE %s SET %s=%s;"
                    % (table_rname, sql_fields[key]["rname"], key_tmp),
                    drop_expr % (table_rname, key_tmp),
                ]
                metadata_change = True
            elif sql_fields[key] != sql_fields_old[key]:
                sql_fields_current[key] = sql_fields[key]
                metadata_change = True

            if query:
                self.log(
                    "timestamp: %s\n" % datetime.datetime.today().isoformat(), table
                )
                for sub_query in query:
                    self.log(sub_query + "\n", table)
                    if fake_migrate:
                        if db._adapter.commit_on_alter_table:
                            self.save_dbt(table, sql_fields_current)
                        self.log("faked!\n", table)
                    else:
                        self.adapter.execute(sub_query)
                        # Caveat: mysql, oracle and firebird
                        # do not allow multiple alter table
                        # in one transaction so we must commit
                        # partial transactions and
                        # update table._dbt after alter table.
                        if db._adapter.commit_on_alter_table:
                            db.commit()
                            self.save_dbt(table, sql_fields_current)
                            self.log("success!\n", table)
            elif metadata_change:
                self.save_dbt(table, sql_fields_current)

        if metadata_change and not (query and db._adapter.commit_on_alter_table):
            db.commit()
            self.save_dbt(table, sql_fields_current)
            self.log("success!\n", table)

    def save_dbt(self, table, sql_fields_current):
        tfile = self.file_open(table._dbt, "wb")
        pickle.dump(sql_fields_current, tfile)
        self.file_close(tfile)

    def log(self, message, table=None):
        isabs = None
        logfilename = self.adapter.adapter_args.get("logfile", "sql.log")
        writelog = bool(logfilename)
        if writelog:
            isabs = os.path.isabs(logfilename)
        if table and table._dbt and writelog and self.adapter.folder:
            if isabs:
                table._loggername = logfilename
            else:
                table._loggername = pjoin(self.adapter.folder, logfilename)
            logfile = self.file_open(table._loggername, "ab")
            logfile.write(to_bytes(message))
            self.file_close(logfile)

    @staticmethod
    def file_open(filename, mode="rb", lock=True):
        # to be used ONLY for files that on GAE may not be on filesystem
        if lock:
            fileobj = portalocker.LockedFile(filename, mode)
        else:
            fileobj = open(filename, mode)
        return fileobj

    @staticmethod
    def file_close(fileobj):
        # to be used ONLY for files that on GAE may not be on filesystem
        if fileobj:
            fileobj.close()

    @staticmethod
    def file_delete(filename):
        os.unlink(filename)

    @staticmethod
    def file_exists(filename):
        # to be used ONLY for files that on GAE may not be on filesystem
        return exists(filename)


class InDBMigrator(Migrator):
    def file_exists(self, filename):
        return DatabaseStoredFile.exists(self.db, filename)

    def file_open(self, filename, mode="rb", lock=True):
        return DatabaseStoredFile(self.db, filename, mode)

    @staticmethod
    def file_close(fileobj):
        fileobj.close_connection()

    def file_delete(self, filename):
        query = "DELETE FROM web2py_filesystem WHERE path='%s'" % filename
        self.db.executesql(query)
        self.db.commit()