import re
import sys
import types
from collections import defaultdict
from contextlib import contextmanager
from .._compat import (
    PY2,
    with_metaclass,
    iterkeys,
    iteritems,
    hashlib_md5,
    integer_types,
    basestring,
)
from .._globals import IDENTITY
from ..connection import ConnectionPool
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import (
    Reference,
    ExecutionHandler,
    SQLCustomType,
    SQLALL,
    NullDriver,
)
from ..helpers.methods import use_common_filters, xorify, merge_tablemaps
from ..helpers.regex import REGEX_SELECT_AS_PARSER, REGEX_TABLE_DOT_FIELD
from ..migrator import Migrator
from ..objects import (
    Table,
    Field,
    Expression,
    Query,
    Rows,
    IterRows,
    LazySet,
    LazyReferenceGetter,
    VirtualCommand,
    Select,
)
from ..utils import deprecated
from . import AdapterMeta, with_connection, with_connection_or_raise


#: Value types that ``BaseAdapter.represent`` invokes (``obj()``) before
#: representing them, so a callable is replaced by the value it returns.
CALLABLETYPES = (types.LambdaType, types.FunctionType,
                 types.BuiltinFunctionType, types.MethodType,
                 types.BuiltinMethodType)


class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Backend-independent base class for database adapters.

    Holds the connection settings, selects a DB-API driver (see
    :meth:`find_driver`) and implements the generic parts of expression
    expansion and result parsing. The dialect, parser and representer
    registered for the concrete adapter class are attached in
    :meth:`_load_dependencies`.
    """

    # engine name; concrete adapters override it
    dbengine = "None"
    # candidate driver names; find_driver() falls back to the first available
    drivers = ()
    uploads_in_blob = False
    support_distributed_transaction = False

    def __init__(
        self,
        db,
        uri,
        pool_size=0,
        folder=None,
        db_codec="UTF-8",
        credential_decoder=IDENTITY,
        driver_args=None,
        adapter_args=None,
        after_connection=None,
        entity_quoting=False,
    ):
        """Configure the adapter and resolve its driver.

        :param db: the owning DAL instance.
        :param uri: connection URI; a driver name may follow the scheme,
            e.g. ``"engine:drivername://..."``.
        :param pool_size: stored as ``self.pool_size``.
        :param folder: stored as ``self.folder``.
        :param db_codec: stored as ``self.db_codec``.
        :param credential_decoder: stored as ``self.credential_decoder``.
        :param driver_args: passed to ``self.driver.connect`` by
            :meth:`connector`; defaults to a new empty dict.
        :param adapter_args: adapter options (e.g. ``"driver"``); defaults to
            a new empty dict.
        :param after_connection: stored as ``self._after_connection``.
        :param entity_quoting: accepted for signature compatibility; not used
            by this method.
        :raises RuntimeError: if no usable driver can be found.
        """
        super(BaseAdapter, self).__init__()
        self._load_dependencies()
        self.db = db
        self.uri = uri
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self.credential_decoder = credential_decoder
        # A shared ``{}`` default would be aliased by every adapter created
        # without explicit arguments; give each instance its own dict.
        self.driver_args = {} if driver_args is None else driver_args
        self.adapter_args = {} if adapter_args is None else adapter_args
        # current expansion strategy (SQLAdapter.index_expander swaps it)
        self.expand = self._expand
        self._after_connection = after_connection
        self.set_connection(None)
        self.find_driver()
        self._initialize_()

    def _load_dependencies(self):
        """Attach the dialect, parser and representer registered for this adapter."""
        # imported lazily, presumably to avoid circular imports — confirm
        from ..dialects import dialects
        from ..parsers import parsers
        from ..representers import representers

        self.dialect = dialects.get_for(self)
        self.parser = parsers.get_for(self)
        self.representer = representers.get_for(self)

    def _initialize_(self):
        """Post-configuration hook run at the end of ``__init__``."""
        self._find_work_folder()

    @property
    def types(self):
        """The dialect's field-type mapping (e.g. ``self.types["text"]``)."""
        return self.dialect.types

    @property
    def _available_drivers(self):
        """Names from :attr:`drivers` present in ``db._drivers_available``, in order."""
        available = self.db._drivers_available
        return [driver for driver in self.drivers if driver in available]

    def _driver_from_uri(self):
        """Return the driver named in the URI scheme, or None.

        ``"engine:driver://..."`` yields ``"driver"``; ``"engine://..."`` and
        an empty URI yield None.
        """
        rv = None
        if self.uri:
            items = self.uri.split("://", 1)[0].split(":")
            rv = items[1] if len(items) > 1 else None
        return rv

    def find_driver(self):
        """Set ``self.driver_name`` and ``self.driver`` unless a driver is set.

        A driver named in the URI (or in ``adapter_args["driver"]``) must be
        available; otherwise the first available entry of :attr:`drivers`
        is used.

        :raises RuntimeError: if the requested driver, or every supported
            driver, is unavailable.
        """
        if getattr(self, "driver", None) is not None:
            return
        requested_driver = self._driver_from_uri() or self.adapter_args.get("driver")
        available = self._available_drivers
        if requested_driver:
            if requested_driver not in available:
                raise RuntimeError("Driver %s is not available" % requested_driver)
            self.driver_name = requested_driver
            self.driver = self.db._drivers_available[requested_driver]
        elif available:
            self.driver_name = available[0]
            self.driver = self.db._drivers_available[self.driver_name]
        else:
            raise RuntimeError(
                "No driver of supported ones %s is available" % str(self.drivers)
            )

    def connector(self):
        """Open a connection, passing ``driver_args`` to ``driver.connect``."""
        return self.driver.connect(self.driver_args)

    def test_connection(self):
        """No-op here; SQL adapters override it with a probe query."""
        pass

    @with_connection
    def close_connection(self):
        """Close the current connection, clear it and return ``close()``'s result."""
        rv = self.connection.close()
        self.set_connection(None)
        return rv

    def tables(self, *queries):
        """Collect ``{tablename: table}`` for every table referenced by *queries*.

        Fields contribute their table; expressions and queries are walked
        recursively through their ``first``/``second`` operands.

        :raises ValueError: if two different table objects share a name.
        """
        tables = dict()
        for query in queries:
            if isinstance(query, Field):
                key = query.tablename
                if tables.get(key, query.table) is not query.table:
                    raise ValueError("Name conflict in table list: %s" % key)
                tables[key] = query.table
            elif isinstance(query, (Expression, Query)):
                tmp = [x for x in (query.first, query.second) if x is not None]
                tables = merge_tablemaps(tables, self.tables(*tmp))
        return tables

    def get_table(self, *queries):
        """Return the single table referenced by *queries*.

        :raises RuntimeError: if no table or more than one table is referenced.
        """
        tablemap = self.tables(*queries)
        if len(tablemap) == 1:
            return tablemap.popitem()[1]
        elif len(tablemap) < 1:
            raise RuntimeError("No table selected")
        else:
            raise RuntimeError("Too many tables selected (%s)" % str(list(tablemap)))

    def common_filter(self, query, tablist):
        """AND *query* with the common and multi-tenant filters of *tablist*.

        *tablist* may contain table objects or table names. *query* may be
        None, in which case the first filter found becomes the query.
        """
        tenant_fieldname = self.db._request_tenant
        for table in tablist:
            if isinstance(table, basestring):
                table = self.db[table]
            # deal with user provided filters
            if table._common_filter is not None:
                user_filter = table._common_filter(query)
                # mirror the tenant branch: ``None & x`` would raise TypeError
                query = user_filter if query is None else query & user_filter
            # deal with multi_tenant filters
            if tenant_fieldname in table:
                default = table[tenant_fieldname].default
                if default is not None:
                    newquery = table[tenant_fieldname] == default
                    if query is None:
                        query = newquery
                    else:
                        query = query & newquery
        return query

    def _expand(self, expression, field_type=None, colnames=False, query_env=None):
        """Base expansion: ``str(expression)``; SQL adapters override this."""
        return str(expression)

    def expand_all(self, fields, tabledict):
        """Expand a select field list into concrete fields/expressions.

        ``SQLALL`` items become all fields of their table, ``"table.field"``
        strings become the matching ``Field``, other strings become raw
        expressions. With no fields at all, every field of every table in
        *tabledict* is selected.
        """
        new_fields = []
        append = new_fields.append
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            elif isinstance(item, str):
                m = REGEX_TABLE_DOT_FIELD.match(item)
                if m:
                    tablename, fieldname = m.groups()
                    append(self.db[tablename][fieldname])
                else:
                    append(Expression(self.db, lambda item=item: item))
            else:
                append(item)
        # ## if no fields specified take them all from the requested tables
        if not new_fields:
            for table in tabledict.values():
                for field in table:
                    append(field)
        return new_fields

    def parse_value(self, value, field_itype, field_type, blob_decode=True):
        """Convert a raw driver value into its Python value for *field_type*.

        Non-blob ``str`` values are decoded with ``db._db_codec`` when ``str``
        has ``decode`` (Python 2); on Python 2 unicode is re-encoded as UTF-8.
        ``SQLCustomType`` decoders are applied, then values of string field
        types go through the parser, except ``None`` and undecoded blobs.
        """
        # [Note - gi0baro] I think next if block can be (should be?) avoided
        if field_type != "blob" and isinstance(value, str):
            try:
                value = value.decode(self.db._db_codec)
            except Exception:
                pass
        # ``unicode`` only exists on Python 2; PY2 short-circuits the lookup
        if PY2 and isinstance(value, unicode):
            value = value.encode("utf-8")
        if isinstance(field_type, SQLCustomType):
            value = field_type.decoder(value)
        if not isinstance(field_type, str) or value is None:
            return value
        elif field_type == "blob" and not blob_decode:
            return value
        else:
            return self.parser.parse(value, field_itype, field_type)

    def _add_operators_to_parsed_row(self, rid, table, row):
        """Attach ``db.record_operators`` (and a lazy reference getter) to *row*."""
        for key, record_operator in iteritems(self.db.record_operators):
            setattr(row, key, record_operator(row, table, rid))
        if table._db._lazy_tables:
            row["__get_lazy_reference__"] = LazyReferenceGetter(table, rid)

    def _add_reference_sets_to_parsed_row(self, rid, table, tablename, row):
        """Attach a ``LazySet`` per field referencing *table* (back-references)."""
        for rfield in table._referenced_by:
            referee_link = self.db._referee_name and self.db._referee_name % dict(
                table=rfield.tablename, field=rfield.name
            )
            if referee_link and referee_link not in row and referee_link != tablename:
                row[referee_link] = LazySet(rfield, rid)

    def _regex_select_as_parser(self, colname):
        """Match *colname* against ``REGEX_SELECT_AS_PARSER`` (``AS`` aliases)."""
        return re.search(REGEX_SELECT_AS_PARSER, colname)

    def _parse(
        self,
        row,
        tmps,
        fields,
        colnames,
        blob_decode,
        cacheable,
        fields_virtual,
        fields_lazy,
    ):
        """Build one ``Row`` from a raw result *row*.

        Columns with an entry in *tmps* (real table fields) are parsed and
        grouped under their table name; other columns go into ``_extra``
        and, when no field is known or an ``AS`` alias is recognised, also
        at the top level. Virtual and lazy (method) fields are then computed.
        """
        new_row = defaultdict(self.db.Row)
        extras = self.db.Row()
        #: let's loop over columns
        for (j, colname) in enumerate(colnames):
            value = row[j]
            tmp = tmps[j]
            tablename = None
            #: do we have a real column?
            if tmp:
                (tablename, fieldname, table, field, ft, fit) = tmp
                colset = new_row[tablename]
                #: parse value
                value = self.parse_value(value, fit, ft, blob_decode)
                if field.filter_out:
                    value = field.filter_out(value)
                colset[fieldname] = value
                #! backward compatibility
                if ft == "id" and fieldname != "id" and "id" not in table.fields:
                    colset["id"] = value
                #: additional parsing for 'id' fields
                if ft == "id" and not cacheable:
                    self._add_operators_to_parsed_row(value, table, colset)
                    # A nested select has no back-reference list. Check the
                    # same attribute the helper reads (``_referenced_by``);
                    # the old ``_reference_by`` check never matched.
                    if hasattr(table, "_referenced_by"):
                        self._add_reference_sets_to_parsed_row(
                            value, table, tablename, colset
                        )
            #: otherwise we set the value in extras
            else:
                #: fields[j] may be None if only 'colnames' was specified in db.executesql()
                field_obj = fields[j]
                if field_obj:
                    f_itype, ftype = field_obj._itype, field_obj.type
                else:
                    f_itype, ftype = None, None
                value = self.parse_value(value, f_itype, ftype, blob_decode)
                extras[colname] = value
                if not field_obj:
                    new_row[colname] = value
                else:
                    new_column_match = self._regex_select_as_parser(colname)
                    if new_column_match is not None:
                        new_column_name = new_column_match.group(1)
                        new_row[new_column_name] = value
        #: add extras if needed (eg. operations results)
        if extras:
            new_row["_extra"] = extras
        #: add virtuals
        new_row = self.db.Row(**new_row)
        for tablename in fields_virtual.keys():
            for f, v in fields_virtual[tablename][1]:
                try:
                    new_row[tablename][f] = v.f(new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
            for f, v in fields_lazy[tablename][1]:
                try:
                    new_row[tablename][f] = v.handler(v.f, new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
        return new_row

    def _parse_expand_colnames(self, fieldlist):
        """
        - Expand a list of fields into a list of
          (tablename, fieldname, table_obj, field_obj, field_type, field_itype)
          tuples, with None for entries that are not ``Field`` objects
        - Collect, per table, the virtual fields and the virtual methods
        """
        fields_virtual = {}
        fields_lazy = {}
        tmps = []
        for field in fieldlist:
            if not isinstance(field, Field):
                tmps.append(None)
                continue
            table = field.table
            tablename, fieldname = table._tablename, field.name
            ft = field.type
            fit = field._itype
            tmps.append((tablename, fieldname, table, field, ft, fit))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_fields],
                )
                fields_lazy[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_methods],
                )
        return (fields_virtual, fields_lazy, tmps)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse all raw *rows* into a ``Rows`` object (old-style virtual fields included)."""
        (fields_virtual, fields_lazy, tmps) = self._parse_expand_colnames(fields)
        new_rows = [
            self._parse(
                row,
                tmps,
                fields,
                colnames,
                blob_decode,
                cacheable,
                fields_virtual,
                fields_lazy,
            )
            for row in rows
        ]
        rowsobj = self.db.Rows(self.db, new_rows, colnames, rawrows=rows, fields=fields)
        # Old style virtual fields
        for tablename, tmp in fields_virtual.items():
            table = tmp[0]
            # ## old style virtual fields
            for item in table.virtualfields:
                try:
                    rowsobj = rowsobj.setvirtualfields(**{tablename: item})
                except (KeyError, AttributeError):
                    # to avoid breaking virtualfields when partial select
                    pass
        return rowsobj

    def iterparse(self, sql, fields, colnames, blob_decode=True, cacheable=False):
        """
        Iterator to parse one row at a time.
        It doesn't support the old style virtual fields
        """
        return IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)

    def adapt(self, value):
        """Base adaptation: return *value* unchanged."""
        return value

    def represent(self, obj, field_type):
        """Represent *obj* for *field_type*; callables are invoked first."""
        if isinstance(obj, CALLABLETYPES):
            obj = obj()
        return self.representer.represent(obj, field_type)

    def _drop_table_cleanup(self, table):
        """Remove *table* from the DAL object, from ``db.tables`` and from references."""
        del self.db[table._tablename]
        del self.db.tables[self.db.tables.index(table._tablename)]
        self.db._remove_references_to(table)

    def drop_table(self, table, mode=""):
        """Base implementation: only the in-memory cleanup; *mode* is ignored."""
        self._drop_table_cleanup(table)

    def rowslice(self, rows, minimum=0, maximum=None):
        """Hook to slice fetched rows; the base version returns *rows* unchanged."""
        return rows

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Hook to make a table name SQL-safe; identity here."""
        return tablename

    def sqlsafe_field(self, fieldname):
        """Hook to make a field name SQL-safe; identity here."""
        return fieldname


class DebugHandler(ExecutionHandler):
    """Execution handler that logs every SQL command at DEBUG level.

    ``SQLAdapter.__init__`` puts it first in the handler list when
    ``db._debug`` is set.
    """

    def before_execute(self, command):
        logger = self.adapter.db.logger
        logger.debug("SQL: %s" % command)


class SQLAdapter(BaseAdapter):
    # not read in the visible code; presumably consulted by the migrator
    # after ALTER TABLE statements — confirm
    commit_on_alter_table = False
    # [Note - gi0baro] can_select_for_update should be deprecated and removed
    # when False, _select_wcols rejects for_update=True with SyntaxError
    can_select_for_update = True
    # class-level default only; __init__ assigns a per-instance list
    execution_handlers = []
    # migrator class used unless adapter_args["migrator"] overrides it
    migrator_cls = Migrator

    def __init__(self, *args, **kwargs):
        """Initialise the base adapter, then attach the migrator and this
        instance's list of execution handler classes."""
        super(SQLAdapter, self).__init__(*args, **kwargs)
        chosen_migrator = self.adapter_args.get("migrator", self.migrator_cls)
        self.migrator = chosen_migrator(self)
        handler_classes = self.execution_handlers = list(self.db.execution_handlers)
        if self.db._debug:
            # the SQL debug logger runs ahead of any user-supplied handler
            handler_classes.insert(0, DebugHandler)

    def test_connection(self):
        """Probe the connection by executing a trivial query."""
        probe_sql = "SELECT 1;"
        self.execute(probe_sql)

    def represent(self, obj, field_type):
        """Represent *obj* as SQL; fields and expressions render as their SQL text."""
        if not isinstance(obj, (Expression, Field)):
            return super(SQLAdapter, self).represent(obj, field_type)
        return str(obj)

    def adapt(self, obj):
        """Quote *obj* as an SQL string literal, doubling embedded single quotes."""
        escaped = obj.replace("'", "''")
        return "'%s'" % escaped

    def smart_adapt(self, obj):
        """Render ints/floats bare; quote anything else via :meth:`adapt`."""
        is_number = isinstance(obj, (int, float))
        return str(obj) if is_number else self.adapt(str(obj))

    def fetchall(self):
        """Return all remaining rows of the current cursor."""
        cursor = self.cursor
        return cursor.fetchall()

    def fetchone(self):
        """Return the next row of the current cursor."""
        cursor = self.cursor
        return cursor.fetchone()

    def _build_handlers_for_execution(self):
        """Instantiate each registered handler class, bound to this adapter."""
        return [handler_cls(self) for handler_cls in self.execution_handlers]

    def filter_sql_command(self, command):
        """Hook for subclasses to rewrite SQL before execution; identity here."""
        filtered = command
        return filtered

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Run an SQL command on the cursor, wrapped by the execution handlers.

        ``args[0]`` is the command; it first passes through
        :meth:`filter_sql_command`. Remaining positional and keyword
        arguments go straight to ``cursor.execute``, whose result is returned.
        ``after_execute`` hooks run only if execution did not raise.
        """
        raw_command, extra_args = args[0], args[1:]
        sql = self.filter_sql_command(raw_command)
        active = self._build_handlers_for_execution()
        for h in active:
            h.before_execute(sql)
        outcome = self.cursor.execute(sql, *extra_args, **kwargs)
        for h in active:
            h.after_execute(sql)
        return outcome

    def _expand(self, expression, field_type=None, colnames=False, query_env=None):
        """Render *expression* as SQL text.

        - ``Field``: its ``sqlsafe`` name, or ``longname`` when *colnames* is
          true; cast to the dialect's text type when *field_type* is
          ``"string"`` but the field is not already string-like.
        - ``Expression``/``Query``: call ``expression.op`` with its operands
          and ``optional_args`` plus ``query_env``; an operand-less string op
          is wrapped in parentheses after dropping one trailing ``;``.
        - any other value with a *field_type*: :meth:`represent`.
        - list/tuple: items represented and comma-joined.
        - bool: the dialect's true/false expression.
        - anything else: used as-is.

        :param query_env: scope information forwarded to dialect operators;
            a fresh empty dict when omitted.
        :returns: the rendered SQL as ``str``.
        """
        if query_env is None:
            query_env = {}
        if isinstance(expression, Field):
            rv = expression.longname if colnames else expression.sqlsafe
            if field_type == "string" and expression.type not in (
                "string",
                "text",
                "json",
                "jsonb",
                "password",
            ):
                rv = self.dialect.cast(rv, self.types["text"], query_env)
        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second
            op = expression.op
            # Copy before adding query_env. Writing into
            # expression.optional_args itself would mutate the expression
            # as a side effect of rendering it.
            optional_args = dict(expression.optional_args or {})
            optional_args["query_env"] = query_env
            if second is not None:
                rv = op(first, second, **optional_args)
            elif first is not None:
                rv = op(first, **optional_args)
            elif isinstance(op, str):
                # raw SQL text: drop one trailing ';' and parenthesize
                if op.endswith(";"):
                    op = op[:-1]
                rv = "(%s)" % op
            else:
                rv = op()
        elif field_type:
            rv = self.represent(expression, field_type)
        elif isinstance(expression, (list, tuple)):
            # field_type is falsy on this branch
            rv = ",".join(self.represent(item, field_type) for item in expression)
        elif isinstance(expression, bool):
            rv = self.dialect.true_exp if expression else self.dialect.false_exp
        else:
            rv = expression
        return str(rv)

    def _expand_for_index(
        self, expression, field_type=None, colnames=False, query_env=None
    ):
        """Expansion used while :meth:`index_expander` is active.

        Fields render as their ``_rname`` instead of the ``sqlsafe`` /
        ``longname`` forms used by :meth:`_expand`; everything else is
        delegated to :meth:`_expand`.

        :param query_env: forwarded to :meth:`_expand`; a fresh empty dict
            when omitted (avoids a shared mutable default).
        """
        if isinstance(expression, Field):
            return expression._rname
        if query_env is None:
            query_env = {}
        return self._expand(expression, field_type, colnames, query_env)

    @contextmanager
    def index_expander(self):
        """Temporarily route :attr:`expand` through :meth:`_expand_for_index`.

        The normal expander is restored even if the ``with`` body raises.
        Otherwise an error during index creation would leave every later
        expansion using the index-specific rendering.
        """
        self.expand = self._expand_for_index
        try:
            yield
        finally:
            self.expand = self._expand

    def lastrowid(self, table):
        """Return ``cursor.lastrowid``; *table* is unused in this implementation."""
        cursor = self.cursor
        return cursor.lastrowid

    def _insert(self, table, fields):
        """Build the INSERT statement for *fields*, a sequence of ``(Field, value)`` pairs.

        An empty *fields* yields the dialect's empty-row insert.
        """
        if not fields:
            return self.dialect.insert_empty(table._rname)
        target = table._rname
        column_sql = ",".join(pair[0]._rname for pair in fields)
        value_sql = ",".join(self.expand(val, fld.type) for fld, val in fields)
        return self.dialect.insert(target, column_sql, value_sql)

    def insert(self, table, fields):
        """Insert one record and return its identifier.

        :param fields: sequence of ``(Field, value)`` pairs.
        :returns: the table's ``_on_insert_error`` result if execution fails
            and the table defines that hook. Otherwise, for tables with
            ``_primarykey``, a dict of the supplied primary-key values when
            any were supplied. Otherwise the cursor's last row id: wrapped as
            ``{pk: id}`` for single-key tables, as a ``Reference`` when it is
            an integer, or returned as-is.
        """
        query = self._insert(table, fields)
        try:
            self.execute(query)
        except Exception as e:
            # Only real errors reach the hook; KeyboardInterrupt/SystemExit
            # propagate untouched.
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            # bare raise keeps the original traceback (also on Python 2)
            raise
        if hasattr(table, "_primarykey"):
            pkdict = {
                k[0].name: k[1] for k in fields if k[0].name in table._primarykey
            }
            if pkdict:
                return pkdict
        new_id = self.lastrowid(table)
        if hasattr(table, "_primarykey") and len(table._primarykey) == 1:
            new_id = {table._primarykey[0]: new_id}
        if not isinstance(new_id, integer_types):
            return new_id
        rid = Reference(new_id)
        (rid._table, rid._record) = (table, None)
        return rid

    def _update(self, table, query, fields):
        """Build the UPDATE statement setting *fields* on rows matching *query*.

        Common filters are applied to a non-empty *query* when enabled.
        """
        query_env = dict(current_scope=[table._tablename])
        where_sql = ""
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            where_sql = self.expand(query, query_env=query_env)
        assignments = []
        for field, value in fields:
            lhs = field._rname
            rhs = self.expand(value, field.type, query_env=query_env)
            assignments.append("%s=%s" % (lhs, rhs))
        return self.dialect.update(table, ",".join(assignments), where_sql)

    def update(self, table, query, fields):
        """Execute an UPDATE and return the affected row count.

        :returns: the table's ``_on_update_error`` result if execution fails
            and the table defines that hook; otherwise ``cursor.rowcount``,
            or None when it cannot be read.
        """
        sql = self._update(table, query, fields)
        try:
            self.execute(sql)
        except Exception as e:
            # KeyboardInterrupt/SystemExit must not be routed into the hook
            if hasattr(table, "_on_update_error"):
                return table._on_update_error(table, query, fields, e)
            raise
        try:
            return self.cursor.rowcount
        except Exception:
            # best effort: an unreadable rowcount is reported as unknown
            return None

    def _delete(self, table, query):
        """Build the DELETE statement for rows of *table* matching *query*.

        An empty *query* produces an unconditional delete.
        """
        env = dict(current_scope=[table._tablename])
        if not query:
            return self.dialect.delete(table, "")
        if use_common_filters(query):
            query = self.common_filter(query, [table])
        return self.dialect.delete(table, self.expand(query, query_env=env))

    def delete(self, table, query):
        """Execute a DELETE and return ``cursor.rowcount`` (None if unreadable)."""
        sql = self._delete(table, query)
        self.execute(sql)
        try:
            return self.cursor.rowcount
        except Exception:
            # best effort: an unreadable rowcount is reported as unknown
            return None

    def _colexpand(self, field, query_env):
        """Expand *field* in column-name mode (``colnames=True``)."""
        return self.expand(field, query_env=query_env, colnames=True)

    def _geoexpand(self, field, query_env):
        """Expand a select column, wrapping ``geo*`` fields in ``st_astext()``."""
        ftype = field.type
        is_geo_field = (
            isinstance(ftype, str) and ftype.startswith("geo") and isinstance(field, Field)
        )
        if is_geo_field:
            field = field.st_astext()
        return self.expand(field, query_env=query_env)

    def _build_joins_for_select(self, tablenames, param):
        """Analyse a ``join``/``left`` argument for :meth:`_select_wcols`.

        *param* is a table, an ON-expression, or a list/tuple of them;
        iterating *tablenames* yields the names already in the select.

        :returns: ``(join_tables, join_on, tables_to_merge, join_on_tables,
            important_tablenames, excluded, tablemap)``:
            names of bare joined tables; the ON-expressions; other tables
            referenced by the ON-expressions; the tables the ON-expressions
            join; all of the above names; names from *tablenames* not among
            them; and ``{name: table}`` for the joined tables.
        :raises ValueError: if two different tables share a name.
        """
        items = param if isinstance(param, (tuple, list)) else [param]
        tablemap = {}
        for entry in items:
            tbl = entry.first if isinstance(entry, Expression) else entry
            name = tbl._tablename
            if tablemap.get(name, tbl) is not tbl:
                raise ValueError("Name conflict in table list: %s" % name)
            tablemap[name] = tbl
        join_tables = []
        join_on = []
        for entry in items:
            if isinstance(entry, Expression):
                join_on.append(entry)
            else:
                join_tables.append(entry._tablename)
        tables_to_merge = {}
        for on_expr in join_on:
            tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(on_expr))
        join_on_tables = [on_expr.first._tablename for on_expr in join_on]
        for name in join_on_tables:
            tables_to_merge.pop(name, None)
        important_tablenames = join_tables + join_on_tables + list(tables_to_merge)
        excluded = [name for name in tablenames if name not in important_tablenames]
        return (
            join_tables,
            join_on,
            tables_to_merge,
            join_on_tables,
            important_tablenames,
            excluded,
            tablemap,
        )

    def _select_wcols(
        self,
        query,
        fields,
        left=False,
        join=False,
        distinct=False,
        orderby=False,
        groupby=False,
        having=False,
        limitby=False,
        orderby_on_limitby=True,
        for_update=False,
        outer_scoped=[],
        required=None,
        cache=None,
        cacheable=None,
        processor=None,
    ):
        """Build a SELECT statement together with its column names.

        :param query: the WHERE condition (may be empty).
        :param fields: fields/expressions to select.
        :param left: table(s) or ON-expression(s) to LEFT JOIN.
        :param join: table(s) or ON-expression(s) to INNER JOIN.
        :param orderby: ordering; ``"<random>"`` maps to the dialect's
            random ordering. With *limitby* and no orderby/groupby (and
            *orderby_on_limitby* true), the primary keys are used instead.
        :param outer_scoped: names of tables belonging to an enclosing query.
            They are dropped from this FROM clause and passed to expansion as
            ``parent_scope``. The list is never mutated here, so the shared
            ``[]`` default is harmless.
        :param required, cache, cacheable, processor: not used here; accepted
            so callers can pass a whole attributes dict as ``**attributes``.
        :returns: ``(colnames, sql)``.
        :raises SyntaxError: if no table is involved, or if ``for_update`` is
            requested while ``can_select_for_update`` is False.
        """
        #: parse tablemap
        tablemap = self.tables(query)
        #: apply common filters if needed
        # (only tables referenced by the query itself are filtered)
        if use_common_filters(query):
            query = self.common_filter(query, list(tablemap.values()))
        #: auto-adjust tables
        # add tables that appear only in the selected fields
        tablemap = merge_tablemaps(tablemap, self.tables(*fields))
        #: remove outer scoped tables if needed
        for item in outer_scoped:
            # FIXME: check for name conflicts
            tablemap.pop(item, None)
        if len(tablemap) < 1:
            raise SyntaxError("Set: no tables selected")
        # names from query + fields, captured before join tables are merged in
        query_tables = list(tablemap)
        #: check for_update argument
        # [Note - gi0baro] I think this should be removed since useless?
        #                  should affect only NoSQL?
        if self.can_select_for_update is False and for_update is True:
            raise SyntaxError("invalid select attribute: for_update")
        #: build joins (inner, left outer) and table names
        if join:
            (
                # FIXME? ijoin_tables is never used
                ijoin_tables,
                ijoin_on,
                itables_to_merge,
                ijoin_on_tables,
                iimportant_tablenames,
                iexcluded,
                itablemap,
            ) = self._build_joins_for_select(tablemap, join)
            tablemap = merge_tablemaps(tablemap, itables_to_merge)
            tablemap = merge_tablemaps(tablemap, itablemap)
        if left:
            (
                join_tables,
                join_on,
                tables_to_merge,
                join_on_tables,
                important_tablenames,
                excluded,
                jtablemap,
            ) = self._build_joins_for_select(tablemap, left)
            tablemap = merge_tablemaps(tablemap, tables_to_merge)
            tablemap = merge_tablemaps(tablemap, jtablemap)
        # names visible during expansion: enclosing query's tables + ours
        current_scope = outer_scoped + list(tablemap)
        query_env = dict(current_scope=current_scope, parent_scope=outer_scoped)
        #: prepare columns and expand fields
        colnames = [self._colexpand(x, query_env) for x in fields]
        sql_fields = ", ".join(self._geoexpand(x, query_env) for x in fields)
        # FROM-clause text for a table name, as rendered by its query_name()
        table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
        if join and not left:
            # inner joins only: first table, CROSS JOIN the other non-ON
            # tables, then one JOIN per ON-expression
            cross_joins = iexcluded + list(itables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in cross_joins[1:]
            ]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            sql_t = " ".join(tokens)
        elif not join and left:
            # left joins only: same layout, with LEFT JOINs at the end
            cross_joins = excluded + list(tables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in cross_joins[1:]
            ]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(
                    self.dialect.left_join(
                        ",".join([table_alias(t) for t in join_tables]), query_env
                    )
                )
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = " ".join(tokens)
        elif join and left:
            # both: every involved table not targeted by an ON-expression is
            # cross-joined, then inner JOINs, then LEFT JOINs
            all_tables_in_query = set(
                important_tablenames + iimportant_tablenames + query_tables
            )
            tables_in_joinon = set(join_on_tables + ijoin_on_tables)
            # NOTE(review): set ordering is not stable across interpreter runs,
            # so the FROM order (and any cache key built from the SQL text)
            # can vary; consider a deterministic ordering.
            tables_not_in_joinon = list(
                all_tables_in_query.difference(tables_in_joinon)
            )
            tokens = [table_alias(tables_not_in_joinon[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in tables_not_in_joinon[1:]
            ]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(
                    self.dialect.left_join(
                        ",".join([table_alias(t) for t in join_tables]), query_env
                    )
                )
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = " ".join(tokens)
        else:
            # no joins: plain comma-separated table list
            sql_t = ", ".join(table_alias(t) for t in query_tables)
        #: expand query if needed
        if query:
            query = self.expand(query, query_env=query_env)
        if having:
            having = self.expand(having, query_env=query_env)
        #: groupby
        sql_grp = groupby
        if groupby:
            if isinstance(groupby, (list, tuple)):
                groupby = xorify(groupby)
            sql_grp = self.expand(groupby, query_env=query_env)
        #: orderby
        sql_ord = False
        if orderby:
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)
            if str(orderby) == "<random>":
                sql_ord = self.dialect.random
            else:
                sql_ord = self.expand(orderby, query_env=query_env)
        #: set default orderby if missing
        # order by the primary key(s) (or "_id") of every non-Select table,
        # presumably so that LIMIT/OFFSET pages are stable — confirm
        if (
            limitby
            and not groupby
            and query_tables
            and orderby_on_limitby
            and not orderby
        ):
            sql_ord = ", ".join(
                [
                    tablemap[t][x].sqlsafe
                    for t in query_tables
                    if not isinstance(tablemap[t], Select)
                    for x in (
                        hasattr(tablemap[t], "_primarykey")
                        and tablemap[t]._primarykey
                        or ["_id"]
                    )
                ]
            )
        #: build sql using dialect
        # FOR UPDATE is only requested when the adapter supports it
        return (
            colnames,
            self.dialect.select(
                sql_fields,
                sql_t,
                query,
                sql_grp,
                having,
                sql_ord,
                limitby,
                distinct,
                for_update and self.can_select_for_update,
            ),
        )

    def _select(self, query, fields, attributes):
        """Return only the SQL text of the described select."""
        _colnames, sql = self._select_wcols(query, fields, **attributes)
        return sql

    def nested_select(self, query, fields, attributes):
        """Wrap the select in a :class:`Select`, usable as a table in other selects."""
        nested = Select(self.db, query, fields, attributes)
        return nested

    def _select_aux_execute(self, sql):
        """Execute *sql* and return every row fetched from the cursor."""
        self.execute(sql)
        fetched = self.cursor.fetchall()
        return fetched

    def _select_aux(self, sql, fields, attributes, colnames):
        """Fetch raw rows (optionally through a cache) and hand them to a processor.

        ``attributes["cache"]`` may be ``(cache_model, time_expire)`` or a
        dict with ``model``, ``expiration`` and optional ``key``. Without an
        explicit key, the raw rows are cached under an md5 of
        URI + SQL + ``"/rows"``. Rows are sliced from ``limitby[0]`` via
        :meth:`rowslice`, then passed to ``attributes["processor"]``
        (default :meth:`parse`).
        """
        cache = attributes.get("cache", None)
        if cache:
            if isinstance(cache, dict):
                cache_model = cache["model"]
                time_expire = cache["expiration"]
                key = cache.get("key")
            else:
                (cache_model, time_expire) = cache
                key = None
            if not key:
                key = hashlib_md5(self.uri + "/" + sql + "/rows").hexdigest()
            rows = cache_model(
                key,
                lambda self=self, sql=sql: self._select_aux_execute(sql),
                time_expire,
            )
        else:
            rows = self._select_aux_execute(sql)
        if isinstance(rows, tuple):
            rows = list(rows)
        start = (attributes.get("limitby", None) or (0,))[0]
        rows = self.rowslice(rows, start, None)
        processor = attributes.get("processor", self.parse)
        return processor(
            rows, fields, colnames, cacheable=attributes.get("cacheable", False)
        )

    def _cached_select(self, cache, sql, fields, attributes, colnames):
        """Return the parsed ``Rows`` for *sql*, cached as a whole.

        :param cache: ``(cache_model, time_expire)`` or, as accepted by
            :meth:`_select_aux`, a dict with ``model``, ``expiration`` and an
            optional ``key``. Without a key, an md5 of URI + SQL is used.
        :param attributes: select attributes; not modified. The inner
            :meth:`_select_aux` call runs on a copy without ``cache``.
        """
        # Work on a copy. Deleting "cache" from the caller's dict would
        # silently disable caching for any later reuse of that dict.
        attributes = dict(attributes)
        attributes.pop("cache", None)
        if isinstance(cache, dict):
            # the dict form previously crashed here on tuple unpacking
            cache_model = cache["model"]
            time_expire = cache["expiration"]
            key = cache.get("key")
        else:
            (cache_model, time_expire) = cache
            key = None
        if not key:
            key = hashlib_md5(self.uri + "/" + sql).hexdigest()
        args = (sql, fields, attributes, colnames)
        ret = cache_model(
            key, lambda self=self, args=args: self._select_aux(*args), time_expire
        )
        # presumably re-binds Field objects on a Rows restored from cache — confirm
        ret._restore_fields(fields)
        return ret

    def select(self, query, fields, attributes):
        """Run a select and return its parsed rows.

        With both ``cache`` and ``cacheable`` set, the parsed result is cached
        as a whole (:meth:`_cached_select`). Otherwise :meth:`_select_aux`
        handles it, caching raw rows when ``cache`` is set.
        """
        colnames, sql = self._select_wcols(query, fields, **attributes)
        cache = attributes.get("cache", None)
        if not (cache and attributes.get("cacheable", False)):
            return self._select_aux(sql, fields, attributes, colnames)
        return self._cached_select(cache, sql, fields, attributes, colnames)

    def iterselect(self, query, fields, attributes):
        """Like :meth:`select` but returns an iterator parsing one row at a time."""
        colnames, sql = self._select_wcols(query, fields, **attributes)
        return self.iterparse(
            sql, fields, colnames, cacheable=attributes.get("cacheable", False)
        )

    def _count(self, query, distinct=None):
        """Build (but do not execute) the COUNT statement for ``query``.

        :param query: optional filter; common filters are applied to it
            when ``use_common_filters(query)`` says so.
        :param distinct: optional expression (or list/tuple, combined with
            ``xorify``) to count distinct values of; ``"*"`` otherwise.
        :returns: the SQL string produced by ``self.dialect.select``.
        """
        tablemap = self.tables(query)
        scope_names = list(tablemap)
        table_objs = list(tablemap.values())
        env = {"current_scope": scope_names}

        where_sql = ""
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, table_objs)
            where_sql = self.expand(query, query_env=env)

        from_sql = ",".join(self.table_alias(tbl, []) for tbl in table_objs)

        count_target = "*"
        if distinct:
            if isinstance(distinct, (list, tuple)):
                distinct = xorify(distinct)
            count_target = self.expand(distinct, query_env=env)

        count_sql = self.dialect.count(count_target, distinct)
        return self.dialect.select(count_sql, from_sql, where_sql)

    def count(self, query, distinct=None):
        """Execute the COUNT statement for ``query`` and return its value.

        The value is the first column of the first row fetched from
        ``self.cursor``.
        """
        statement = self._count(query, distinct)
        self.execute(statement)
        first_row = self.cursor.fetchone()
        return first_row[0]

    def bulk_insert(self, table, items):
        """Insert each item of ``items`` one at a time via ``self.insert``.

        Returns a list with the result of every ``insert`` call, in order.
        """
        results = []
        for row in items:
            results.append(self.insert(table, row))
        return results

    def create_table(self, *args, **kwargs):
        """Delegate table creation to ``self.migrator.create_table``."""
        migrator = self.migrator
        return migrator.create_table(*args, **kwargs)

    def _drop_table_cleanup(self, table):
        """Run the parent cleanup, then remove the file named by ``table._dbt``.

        When ``table._dbt`` is falsy nothing beyond the parent cleanup is
        done; otherwise the migrator deletes that file and logs "success!".
        """
        super(SQLAdapter, self)._drop_table_cleanup(table)
        dbt_file = table._dbt
        if not dbt_file:
            return
        self.migrator.file_delete(dbt_file)
        self.migrator.log("success!\n", table)

    def drop_table(self, table, mode=""):
        """Execute the dialect's DROP statements for ``table``, then commit.

        Each statement is logged through the migrator first when
        ``table._dbt`` is set. After committing, ``_drop_table_cleanup``
        runs for the table.
        """
        for stmt in self.dialect.drop_table(table, mode):
            if table._dbt:
                self.migrator.log(stmt + "\n", table)
            self.execute(stmt)
        self.commit()
        self._drop_table_cleanup(table)

    @deprecated("drop", "drop_table", "SQLAdapter")
    def drop(self, table, mode=""):
        """Deprecated alias of ``drop_table``.

        ``mode`` is now forwarded to ``drop_table``; previously it was
        always replaced by an empty string, silently discarding the
        caller's choice.
        """
        return self.drop_table(table, mode=mode)

    def truncate(self, table, mode=""):
        """Empty ``table`` using the statements from ``self.dialect.truncate``.

        Each statement is logged through the migrator and then executed.
        A final "success!" line is logged once all of them have run. No
        commit is issued here, and any error from ``execute`` propagates
        without the success line being logged.
        """
        queries = self.dialect.truncate(table, mode)
        for query in queries:
            self.migrator.log(query + "\n", table)
            self.execute(query)
        self.migrator.log("success!\n", table)

    def create_index(self, table, index_name, *fields, **kwargs):
        """Create index ``index_name`` on ``table`` over ``fields`` and commit.

        Field instances are replaced by their ``_rname``; any other item is
        passed to the dialect unchanged. Extra keyword arguments go to
        ``self.dialect.create_index``. On any failure the transaction is
        rolled back and RuntimeError is raised with the driver error and
        the SQL text. Returns True on success.
        """
        expressions = []
        for item in fields:
            if isinstance(item, Field):
                expressions.append(item._rname)
            else:
                expressions.append(item)
        statement = self.dialect.create_index(index_name, table, expressions, **kwargs)
        try:
            self.execute(statement)
            self.commit()
        except Exception as exc:
            self.rollback()
            template = (
                "Error creating index %s\n  Driver error: %s\n"
                "  SQL instruction: %s"
            )
            raise RuntimeError(template % (index_name, str(exc), statement))
        return True

    def drop_index(self, table, index_name):
        """Drop index ``index_name`` on ``table`` and commit.

        On any failure the transaction is rolled back and RuntimeError is
        raised with the driver's message. Returns True on success.
        """
        statement = self.dialect.drop_index(index_name, table)
        try:
            self.execute(statement)
            self.commit()
        except Exception as exc:
            self.rollback()
            message = "Error dropping index %s\n  Driver error: %s" % (
                index_name,
                str(exc),
            )
            raise RuntimeError(message)
        return True

    def distributed_transaction_begin(self, key):
        """No-op in this adapter; ``key`` is ignored."""
        return None

    @with_connection
    def commit(self):
        """Commit on the driver connection and return the driver's result."""
        connection = self.connection
        return connection.commit()

    @with_connection
    def rollback(self):
        """Roll back on the driver connection and return the driver's result."""
        connection = self.connection
        return connection.rollback()

    @with_connection
    def prepare(self, key):
        """Call ``prepare()`` on the driver connection.

        ``key`` is part of the adapter interface but is not used here.
        """
        connection = self.connection
        connection.prepare()

    @with_connection
    def commit_prepared(self, key):
        """Finish a prepared transaction with a plain connection ``commit()``.

        ``key`` is not used by this implementation; nothing is returned.
        """
        connection = self.connection
        connection.commit()

    @with_connection
    def rollback_prepared(self, key):
        """Abort a prepared transaction with a plain connection ``rollback()``.

        ``key`` is not used by this implementation; nothing is returned.
        """
        connection = self.connection
        connection.rollback()

    def create_sequence_and_triggers(self, query, table, **args):
        """Execute ``query``; ``table`` and extra keyword arguments are unused.

        Nothing is returned.
        """
        run = self.execute
        run(query)

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Return the SQL form of a table reference.

        Without ``original_tablename`` the name is quoted by the dialect.
        With it (any value other than None, including ""), the dialect
        builds an alias of ``original_tablename`` as ``tablename``.
        """
        if original_tablename is None:
            return self.dialect.quote(tablename)
        return self.dialect.alias(original_tablename, tablename)

    def sqlsafe_field(self, fieldname):
        """Return ``fieldname`` quoted by the dialect."""
        quote = self.dialect.quote
        return quote(fieldname)

    def table_alias(self, tbl, current_scope=None):
        """Return the SQL name under which ``tbl`` is referenced in a query.

        :param tbl: a table object, or a table name looked up in ``self.db``.
        :param current_scope: passed to ``tbl.query_name``; when omitted a
            fresh empty list is used on every call.
        :returns: the first element of ``tbl.query_name(current_scope)``.
        """
        if current_scope is None:
            # A fresh list per call: the old ``current_scope=[]`` default was
            # one list shared by every call and handed on to query_name.
            current_scope = []
        if isinstance(tbl, basestring):
            tbl = self.db[tbl]
        return tbl.query_name(current_scope)[0]

    def id_query(self, table):
        """Build a "key field is not NULL" query for ``table``.

        The first entry of ``table._primarykey`` is used when that attribute
        is present and non-empty; otherwise ``table._id`` is used.
        """
        pkeys = getattr(table, "_primarykey", None)
        key_field = table[pkeys[0]] if pkeys else table._id
        # "!= None" is intentional: the field's overloaded operator builds a
        # query expression, which "is not None" would not do.
        return key_field != None  # noqa: E711


class NoSQLAdapter(BaseAdapter):
    """Adapter base for NoSQL backends.

    Transaction hooks are no-ops, ``id_query`` compares ``_id > 0``,
    ``create_table`` runs no DDL, and nested selects raise NotOnNOSQLError.
    """

    can_select_for_update = False

    def commit(self):
        pass

    def rollback(self):
        pass

    def prepare(self, key=None):
        # ``key`` is accepted (and ignored) so the signature matches the SQL
        # adapter's ``prepare(self, key)``; previously passing a key raised
        # TypeError.
        pass

    def commit_prepared(self, key):
        pass

    def rollback_prepared(self, key):
        pass

    def id_query(self, table):
        return table._id > 0

    def create_table(self, table, migrate=True, fake_migrate=False, polymodel=None):
        """Record NOT NULL / UNIQUE field names on ``table``; no DDL is run.

        Sets ``table._dbt`` to None and fills ``table._notnulls`` and
        ``table._uniques`` with the names of fields whose ``notnull`` /
        ``unique`` flag is truthy, in ``table.fields`` order. ``migrate``,
        ``fake_migrate`` and ``polymodel`` are accepted for interface
        compatibility and not used here.
        """
        table._dbt = None
        table._notnulls = [name for name in table.fields if table[name].notnull]
        # this is unnecessary if the fields are indexed and unique
        table._uniques = [name for name in table.fields if table[name].unique]

    def drop_table(self, table, mode=""):
        ctable = self.connection[table._tablename]
        ctable.drop()
        self._drop_table_cleanup(table)

    @deprecated("drop", "drop_table", "NoSQLAdapter")
    def drop(self, table, mode=""):
        # Forward ``mode`` instead of discarding it; the deprecation notice
        # now names this class rather than SQLAdapter.
        return self.drop_table(table, mode=mode)

    def _select(self, *args, **kwargs):
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")

    def nested_select(self, *args, **kwargs):
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")


class NullAdapter(BaseAdapter):
    """Adapter whose connection is a ``NullDriver`` and dialect a ``CommonDialect``.

    No driver lookup is performed.
    """

    def _load_dependencies(self):
        # Imported lazily, at dependency-loading time rather than module import.
        from ..dialects.base import CommonDialect as _Dialect

        self.dialect = _Dialect(self)

    def find_driver(self):
        return None

    def connector(self):
        driver = NullDriver()
        return driver
