source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/packages/dal/pydal/adapters/base.py

main
Last change on this file was 42bd667, checked in by David Fuertes <dfuertes@…>, 4 years ago

Historial Limpio

  • Property mode set to 100755
File size: 37.4 KB
Line 
1import re
2import sys
3import types
4from collections import defaultdict
5from contextlib import contextmanager
6from .._compat import (
7    PY2,
8    with_metaclass,
9    iterkeys,
10    iteritems,
11    hashlib_md5,
12    integer_types,
13    basestring,
14)
15from .._globals import IDENTITY
16from ..connection import ConnectionPool
17from ..exceptions import NotOnNOSQLError
18from ..helpers.classes import (
19    Reference,
20    ExecutionHandler,
21    SQLCustomType,
22    SQLALL,
23    NullDriver,
24)
25from ..helpers.methods import use_common_filters, xorify, merge_tablemaps
26from ..helpers.regex import REGEX_SELECT_AS_PARSER, REGEX_TABLE_DOT_FIELD
27from ..migrator import Migrator
28from ..objects import (
29    Table,
30    Field,
31    Expression,
32    Query,
33    Rows,
34    IterRows,
35    LazySet,
36    LazyReferenceGetter,
37    VirtualCommand,
38    Select,
39)
40from ..utils import deprecated
41from . import AdapterMeta, with_connection, with_connection_or_raise
42
43
# Value types that BaseAdapter.represent treats as "call me first": an object
# of one of these types is invoked and its return value is represented.
CALLABLETYPES = tuple(
    [
        types.LambdaType,
        types.FunctionType,
        types.BuiltinFunctionType,
        types.MethodType,
        types.BuiltinMethodType,
    ]
)
51
52
class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Backend-independent core shared by the pyDAL adapters.

    It selects a driver, collects the tables referenced by queries, applies
    user-defined and multi-tenant common filters, and converts raw driver
    rows into ``Row``/``Rows`` objects. SQL generation and execution are
    provided by subclasses.
    """

    dbengine = "None"
    drivers = ()
    uploads_in_blob = False
    support_distributed_transaction = False

    def __init__(
        self,
        db,
        uri,
        pool_size=0,
        folder=None,
        db_codec="UTF-8",
        credential_decoder=IDENTITY,
        driver_args=None,
        adapter_args=None,
        after_connection=None,
        entity_quoting=False,
    ):
        """Store connection settings, load helpers and resolve the driver.

        :param db: the owning DAL instance.
        :param uri: connection URI; a ``scheme:driver://...`` prefix requests
            a specific driver.
        :param driver_args: passed to ``driver.connect``; a fresh empty dict
            is used when omitted or None.
        :param adapter_args: adapter options (e.g. ``"driver"``); a fresh
            empty dict is used when omitted or None.
        :param after_connection: stored as ``self._after_connection``.
        :param entity_quoting: accepted for signature compatibility; not read
            by this method.
        :raises RuntimeError: when no usable driver can be found.
        """
        super(BaseAdapter, self).__init__()
        self._load_dependencies()
        self.db = db
        self.uri = uri
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self.credential_decoder = credential_decoder
        # Fresh dicts per instance: the former ``{}`` defaults were shared by
        # every adapter, so any in-place mutation would leak between them.
        self.driver_args = {} if driver_args is None else driver_args
        self.adapter_args = {} if adapter_args is None else adapter_args
        self.expand = self._expand
        self._after_connection = after_connection
        self.set_connection(None)
        self.find_driver()
        self._initialize_()

    def _load_dependencies(self):
        """Look up the dialect, parser and representer for this adapter."""
        from ..dialects import dialects
        from ..parsers import parsers
        from ..representers import representers

        self.dialect = dialects.get_for(self)
        self.parser = parsers.get_for(self)
        self.representer = representers.get_for(self)

    def _initialize_(self):
        """Final step of ``__init__``: locate the work folder."""
        self._find_work_folder()

    @property
    def types(self):
        """Type mapping exposed by the dialect."""
        return self.dialect.types

    @property
    def _available_drivers(self):
        """Names from ``self.drivers`` present in ``db._drivers_available``.

        The order of ``self.drivers`` is preserved.
        """
        return [
            driver
            for driver in self.drivers
            if driver in iterkeys(self.db._drivers_available)
        ]

    def _driver_from_uri(self):
        """Return the driver named in ``scheme:driver://`` or None."""
        rv = None
        if self.uri:
            items = self.uri.split("://", 1)[0].split(":")
            rv = items[1] if len(items) > 1 else None
        return rv

    def find_driver(self):
        """Set ``self.driver_name``/``self.driver`` unless a driver is set.

        A driver requested via the URI or ``adapter_args["driver"]`` must be
        available; otherwise the first available driver is used.

        :raises RuntimeError: if the requested driver is unavailable, or if
            none of ``self.drivers`` is available.
        """
        if getattr(self, "driver", None) is not None:
            return
        requested_driver = self._driver_from_uri() or self.adapter_args.get("driver")
        # Computed once: the property rebuilds its list on every access.
        available = self._available_drivers
        if requested_driver:
            if requested_driver not in available:
                raise RuntimeError("Driver %s is not available" % requested_driver)
            self.driver_name = requested_driver
        elif available:
            self.driver_name = available[0]
        else:
            raise RuntimeError(
                "No driver of supported ones %s is available" % str(self.drivers)
            )
        self.driver = self.db._drivers_available[self.driver_name]

    def connector(self):
        """Open a connection through the selected driver."""
        return self.driver.connect(self.driver_args)

    def test_connection(self):
        """Connection probe; a no-op here, overridden by subclasses."""
        pass

    @with_connection
    def close_connection(self):
        """Close the current connection, forget it, and return close()'s result."""
        rv = self.connection.close()
        self.set_connection(None)
        return rv

    def tables(self, *queries):
        """Map tablename -> table for every table referenced by *queries*.

        Fields contribute their own table; Expressions/Queries are walked
        recursively through their ``first``/``second`` operands.

        :raises ValueError: if two distinct tables share the same name.
        """
        tables = dict()
        for query in queries:
            if isinstance(query, Field):
                key = query.tablename
                if tables.get(key, query.table) is not query.table:
                    raise ValueError("Name conflict in table list: %s" % key)
                tables[key] = query.table
            elif isinstance(query, (Expression, Query)):
                tmp = [x for x in (query.first, query.second) if x is not None]
                tables = merge_tablemaps(tables, self.tables(*tmp))
        return tables

    def get_table(self, *queries):
        """Return the single table referenced by *queries*.

        :raises RuntimeError: if zero or more than one table is referenced.
        """
        tablemap = self.tables(*queries)
        if len(tablemap) == 1:
            return tablemap.popitem()[1]
        elif len(tablemap) < 1:
            raise RuntimeError("No table selected")
        else:
            raise RuntimeError("Too many tables selected (%s)" % str(list(tablemap)))

    def common_filter(self, query, tablist):
        """AND *query* with each table's common filter and tenant condition.

        :param query: the query to extend; may be None.
        :param tablist: Table objects or table names.
        :returns: the combined query (or the lone filter when *query* was None).
        """
        tenant_fieldname = self.db._request_tenant
        for table in tablist:
            if isinstance(table, basestring):
                table = self.db[table]
            # deal with user provided filters
            if table._common_filter is not None:
                newquery = table._common_filter(query)
                # Same None handling as the tenant branch below; previously
                # ``None & newquery`` raised a TypeError.
                query = newquery if query is None else query & newquery
            # deal with multi_tenant filters
            if tenant_fieldname in table:
                default = table[tenant_fieldname].default
                if default is not None:
                    newquery = table[tenant_fieldname] == default
                    if query is None:
                        query = newquery
                    else:
                        query = query & newquery
        return query

    def _expand(self, expression, field_type=None, colnames=False, query_env=None):
        """Default expander: plain ``str(expression)``; arguments are ignored."""
        return str(expression)

    def expand_all(self, fields, tabledict):
        """Normalize a select field list into Field/Expression objects.

        ``SQLALL`` items expand to all fields of their table, ``"table.field"``
        strings resolve to the Field, and other strings become raw
        Expressions. An empty result falls back to every field of the tables
        in *tabledict*.
        """
        new_fields = []
        append = new_fields.append
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            elif isinstance(item, str):
                m = REGEX_TABLE_DOT_FIELD.match(item)
                if m:
                    tablename, fieldname = m.groups()
                    append(self.db[tablename][fieldname])
                else:
                    append(Expression(self.db, lambda item=item: item))
            else:
                append(item)
        # ## if no fields specified take them all from the requested tables
        if not new_fields:
            for table in tabledict.values():
                for field in table:
                    append(field)
        return new_fields

    def parse_value(self, value, field_itype, field_type, blob_decode=True):
        """Convert a raw driver value into its Python value.

        :param value: raw value from the driver row.
        :param field_itype: the field's ``_itype``, forwarded to the parser.
        :param field_type: declared field type string, or an ``SQLCustomType``.
        :param blob_decode: when False, ``"blob"`` values are returned as-is.
        :returns: the parsed value. A None value (after any custom decoding)
            and values with a non-string *field_type* are returned unparsed.
        """
        # [Note - gi0baro] I think next if block can be (should be?) avoided
        # Only meaningful on Python 2, where ``str`` is a byte string. On
        # Python 3 ``str`` has no ``decode``, so the attempt always raised and
        # was swallowed; skipping it avoids raising once per value.
        if PY2 and field_type != "blob" and isinstance(value, str):
            try:
                value = value.decode(self.db._db_codec)
            except Exception:
                pass
        if PY2 and isinstance(value, unicode):  # noqa: F821 - Python 2 only
            value = value.encode("utf-8")
        if isinstance(field_type, SQLCustomType):
            value = field_type.decoder(value)
        if not isinstance(field_type, str) or value is None:
            return value
        elif field_type == "blob" and not blob_decode:
            return value
        else:
            return self.parser.parse(value, field_itype, field_type)

    def _add_operators_to_parsed_row(self, rid, table, row):
        """Attach ``db.record_operators`` to *row*, plus a lazy reference
        getter when the DAL uses lazy tables."""
        for key, record_operator in iteritems(self.db.record_operators):
            setattr(row, key, record_operator(row, table, rid))
        if table._db._lazy_tables:
            row["__get_lazy_reference__"] = LazyReferenceGetter(table, rid)

    def _add_reference_sets_to_parsed_row(self, rid, table, tablename, row):
        """Attach a ``LazySet`` for each field referencing *table*.

        The key is built from ``db._referee_name``. Existing keys and the
        table's own name are never overwritten.
        """
        for rfield in table._referenced_by:
            referee_link = self.db._referee_name and self.db._referee_name % dict(
                table=rfield.tablename, field=rfield.name
            )
            if referee_link and referee_link not in row and referee_link != tablename:
                row[referee_link] = LazySet(rfield, rid)

    def _regex_select_as_parser(self, colname):
        """Search *colname* with ``REGEX_SELECT_AS_PARSER`` (alias detection)."""
        return re.search(REGEX_SELECT_AS_PARSER, colname)

    def _parse(
        self,
        row,
        tmps,
        fields,
        colnames,
        blob_decode,
        cacheable,
        fields_virtual,
        fields_lazy,
    ):
        """Turn one raw driver row into a ``db.Row``.

        Real table columns (truthy entries of *tmps*) are grouped per table.
        Other columns go into ``_extra``; aliased or raw columns also go at
        top level. For ``id`` columns record operators and back-reference
        sets are attached unless *cacheable*. Finally new-style virtual and
        lazy fields are computed; they are skipped when inputs are missing.
        """
        new_row = defaultdict(self.db.Row)
        extras = self.db.Row()
        #: let's loop over columns
        for (j, colname) in enumerate(colnames):
            value = row[j]
            tmp = tmps[j]
            #: do we have a real column?
            if tmp:
                (tablename, fieldname, table, field, ft, fit) = tmp
                colset = new_row[tablename]
                #: parse value
                value = self.parse_value(value, fit, ft, blob_decode)
                if field.filter_out:
                    value = field.filter_out(value)
                colset[fieldname] = value
                #! backward compatibility
                if ft == "id" and fieldname != "id" and "id" not in table.fields:
                    colset["id"] = value
                #: additional parsing for 'id' fields
                if ft == "id" and not cacheable:
                    self._add_operators_to_parsed_row(value, table, colset)
                    #: table may be a 'nested_select' lacking the attribute;
                    #: check the same name _add_reference_sets_to_parsed_row
                    #: reads ('_referenced_by'; it was misspelled
                    #: '_reference_by', which skipped back references).
                    if hasattr(table, "_referenced_by"):
                        self._add_reference_sets_to_parsed_row(
                            value, table, tablename, colset
                        )
            #: otherwise we set the value in extras
            else:
                #: fields[j] may be None if only 'colnames' was specified in
                #: db.executesql()
                fobj = fields[j]
                if fobj:
                    f_itype, ftype = fobj._itype, fobj.type
                else:
                    f_itype, ftype = None, None
                value = self.parse_value(value, f_itype, ftype, blob_decode)
                extras[colname] = value
                if not fobj:
                    new_row[colname] = value
                else:
                    new_column_match = self._regex_select_as_parser(colname)
                    if new_column_match is not None:
                        new_column_name = new_column_match.group(1)
                        new_row[new_column_name] = value
        #: add extras if needed (eg. operations results)
        if extras:
            new_row["_extra"] = extras
        #: add virtuals
        new_row = self.db.Row(**new_row)
        for tablename in fields_virtual:
            for f, v in fields_virtual[tablename][1]:
                try:
                    new_row[tablename][f] = v.f(new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
            for f, v in fields_lazy[tablename][1]:
                try:
                    new_row[tablename][f] = v.handler(v.f, new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
        return new_row

    def _parse_expand_colnames(self, fieldlist):
        """
        - Expand a list of fields into a list of
          (tablename, fieldname, table_obj, field_obj, field_type, field_itype)
          tuples, with None for entries that are not Fields
        - Create per-table lists of virtual fields and virtual methods
        """
        fields_virtual = {}
        fields_lazy = {}
        tmps = []
        for field in fieldlist:
            if not isinstance(field, Field):
                tmps.append(None)
                continue
            table = field.table
            tablename, fieldname = table._tablename, field.name
            ft = field.type
            fit = field._itype
            tmps.append((tablename, fieldname, table, field, ft, fit))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_fields],
                )
                fields_lazy[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_methods],
                )
        return (fields_virtual, fields_lazy, tmps)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse all raw *rows* into a ``db.Rows`` object.

        Old-style ``table.virtualfields`` are applied afterwards; a failure
        with KeyError/AttributeError (partial select) is ignored.
        """
        (fields_virtual, fields_lazy, tmps) = self._parse_expand_colnames(fields)
        new_rows = [
            self._parse(
                row,
                tmps,
                fields,
                colnames,
                blob_decode,
                cacheable,
                fields_virtual,
                fields_lazy,
            )
            for row in rows
        ]
        rowsobj = self.db.Rows(self.db, new_rows, colnames, rawrows=rows, fields=fields)
        # Old style virtual fields
        for tablename, tmp in fields_virtual.items():
            table = tmp[0]
            # ## old style virtual fields
            for item in table.virtualfields:
                try:
                    rowsobj = rowsobj.setvirtualfields(**{tablename: item})
                except (KeyError, AttributeError):
                    # to avoid breaking virtualfields when partial select
                    pass
        return rowsobj

    def iterparse(self, sql, fields, colnames, blob_decode=True, cacheable=False):
        """
        Iterator to parse one row at a time.
        It doesn't support the old style virtual fields
        """
        return IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)

    def adapt(self, value):
        """Identity adaptation; SQL adapters quote values instead."""
        return value

    def represent(self, obj, field_type):
        """Represent *obj* for *field_type*; callables are invoked first."""
        if isinstance(obj, CALLABLETYPES):
            obj = obj()
        return self.representer.represent(obj, field_type)

    def _drop_table_cleanup(self, table):
        """Remove *table* from the DAL namespace, table list and references."""
        del self.db[table._tablename]
        del self.db.tables[self.db.tables.index(table._tablename)]
        self.db._remove_references_to(table)

    def drop_table(self, table, mode=""):
        """Drop *table* from the DAL's bookkeeping (no backend statement here)."""
        self._drop_table_cleanup(table)

    def rowslice(self, rows, minimum=0, maximum=None):
        """Return *rows* unchanged; adapters may slice them for limitby."""
        return rows

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Return *tablename* unchanged (no quoting at this level)."""
        return tablename

    def sqlsafe_field(self, fieldname):
        """Return *fieldname* unchanged (no quoting at this level)."""
        return fieldname
409
410
class DebugHandler(ExecutionHandler):
    """Execution handler that logs each SQL command before it runs."""

    def before_execute(self, command):
        """Send *command* to the DAL logger at DEBUG level."""
        log = self.adapter.db.logger
        log.debug("SQL: %s" % command)
414
415
416class SQLAdapter(BaseAdapter):
417    commit_on_alter_table = False
418    # [Note - gi0baro] can_select_for_update should be deprecated and removed
419    can_select_for_update = True
420    execution_handlers = []
421    migrator_cls = Migrator
422
423    def __init__(self, *args, **kwargs):
424        super(SQLAdapter, self).__init__(*args, **kwargs)
425        migrator_cls = self.adapter_args.get("migrator", self.migrator_cls)
426        self.migrator = migrator_cls(self)
427        self.execution_handlers = list(self.db.execution_handlers)
428        if self.db._debug:
429            self.execution_handlers.insert(0, DebugHandler)
430
431    def test_connection(self):
432        self.execute("SELECT 1;")
433
434    def represent(self, obj, field_type):
435        if isinstance(obj, (Expression, Field)):
436            return str(obj)
437        return super(SQLAdapter, self).represent(obj, field_type)
438
439    def adapt(self, obj):
440        return "'%s'" % obj.replace("'", "''")
441
442    def smart_adapt(self, obj):
443        if isinstance(obj, (int, float)):
444            return str(obj)
445        return self.adapt(str(obj))
446
447    def fetchall(self):
448        return self.cursor.fetchall()
449
450    def fetchone(self):
451        return self.cursor.fetchone()
452
453    def _build_handlers_for_execution(self):
454        rv = []
455        for handler_class in self.execution_handlers:
456            rv.append(handler_class(self))
457        return rv
458
459    def filter_sql_command(self, command):
460        return command
461
462    @with_connection_or_raise
463    def execute(self, *args, **kwargs):
464        command = self.filter_sql_command(args[0])
465        handlers = self._build_handlers_for_execution()
466        for handler in handlers:
467            handler.before_execute(command)
468        rv = self.cursor.execute(command, *args[1:], **kwargs)
469        for handler in handlers:
470            handler.after_execute(command)
471        return rv
472
473    def _expand(self, expression, field_type=None, colnames=False, query_env={}):
474        if isinstance(expression, Field):
475            if not colnames:
476                rv = expression.sqlsafe
477            else:
478                rv = expression.longname
479            if field_type == "string" and expression.type not in (
480                "string",
481                "text",
482                "json",
483                "jsonb",
484                "password",
485            ):
486                rv = self.dialect.cast(rv, self.types["text"], query_env)
487        elif isinstance(expression, (Expression, Query)):
488            first = expression.first
489            second = expression.second
490            op = expression.op
491            optional_args = expression.optional_args or {}
492            optional_args["query_env"] = query_env
493            if second is not None:
494                rv = op(first, second, **optional_args)
495            elif first is not None:
496                rv = op(first, **optional_args)
497            elif isinstance(op, str):
498                if op.endswith(";"):
499                    op = op[:-1]
500                rv = "(%s)" % op
501            else:
502                rv = op()
503        elif field_type:
504            rv = self.represent(expression, field_type)
505        elif isinstance(expression, (list, tuple)):
506            rv = ",".join(self.represent(item, field_type) for item in expression)
507        elif isinstance(expression, bool):
508            rv = self.dialect.true_exp if expression else self.dialect.false_exp
509        else:
510            rv = expression
511        return str(rv)
512
513    def _expand_for_index(
514        self, expression, field_type=None, colnames=False, query_env={}
515    ):
516        if isinstance(expression, Field):
517            return expression._rname
518        return self._expand(expression, field_type, colnames, query_env)
519
520    @contextmanager
521    def index_expander(self):
522        self.expand = self._expand_for_index
523        yield
524        self.expand = self._expand
525
526    def lastrowid(self, table):
527        return self.cursor.lastrowid
528
529    def _insert(self, table, fields):
530        if fields:
531            return self.dialect.insert(
532                table._rname,
533                ",".join(el[0]._rname for el in fields),
534                ",".join(self.expand(v, f.type) for f, v in fields),
535            )
536        return self.dialect.insert_empty(table._rname)
537
538    def insert(self, table, fields):
539        query = self._insert(table, fields)
540        try:
541            self.execute(query)
542        except:
543            e = sys.exc_info()[1]
544            if hasattr(table, "_on_insert_error"):
545                return table._on_insert_error(table, fields, e)
546            raise e
547        if hasattr(table, "_primarykey"):
548            pkdict = dict(
549                [(k[0].name, k[1]) for k in fields if k[0].name in table._primarykey]
550            )
551            if pkdict:
552                return pkdict
553        id = self.lastrowid(table)
554        if hasattr(table, "_primarykey") and len(table._primarykey) == 1:
555            id = {table._primarykey[0]: id}
556        if not isinstance(id, integer_types):
557            return id
558        rid = Reference(id)
559        (rid._table, rid._record) = (table, None)
560        return rid
561
562    def _update(self, table, query, fields):
563        sql_q = ""
564        query_env = dict(current_scope=[table._tablename])
565        if query:
566            if use_common_filters(query):
567                query = self.common_filter(query, [table])
568            sql_q = self.expand(query, query_env=query_env)
569        sql_v = ",".join(
570            [
571                "%s=%s"
572                % (field._rname, self.expand(value, field.type, query_env=query_env))
573                for (field, value) in fields
574            ]
575        )
576        return self.dialect.update(table, sql_v, sql_q)
577
578    def update(self, table, query, fields):
579        sql = self._update(table, query, fields)
580        try:
581            self.execute(sql)
582        except:
583            e = sys.exc_info()[1]
584            if hasattr(table, "_on_update_error"):
585                return table._on_update_error(table, query, fields, e)
586            raise e
587        try:
588            return self.cursor.rowcount
589        except:
590            return None
591
592    def _delete(self, table, query):
593        sql_q = ""
594        query_env = dict(current_scope=[table._tablename])
595        if query:
596            if use_common_filters(query):
597                query = self.common_filter(query, [table])
598            sql_q = self.expand(query, query_env=query_env)
599        return self.dialect.delete(table, sql_q)
600
601    def delete(self, table, query):
602        sql = self._delete(table, query)
603        self.execute(sql)
604        try:
605            return self.cursor.rowcount
606        except:
607            return None
608
609    def _colexpand(self, field, query_env):
610        return self.expand(field, colnames=True, query_env=query_env)
611
612    def _geoexpand(self, field, query_env):
613        if (
614            isinstance(field.type, str)
615            and field.type.startswith("geo")
616            and isinstance(field, Field)
617        ):
618            field = field.st_astext()
619        return self.expand(field, query_env=query_env)
620
621    def _build_joins_for_select(self, tablenames, param):
622        if not isinstance(param, (tuple, list)):
623            param = [param]
624        tablemap = {}
625        for item in param:
626            if isinstance(item, Expression):
627                item = item.first
628            key = item._tablename
629            if tablemap.get(key, item) is not item:
630                raise ValueError("Name conflict in table list: %s" % key)
631            tablemap[key] = item
632        join_tables = [t._tablename for t in param if not isinstance(t, Expression)]
633        join_on = [t for t in param if isinstance(t, Expression)]
634        tables_to_merge = {}
635        for t in join_on:
636            tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(t))
637        join_on_tables = [t.first._tablename for t in join_on]
638        for t in join_on_tables:
639            if t in tables_to_merge:
640                tables_to_merge.pop(t)
641        important_tablenames = join_tables + join_on_tables + list(tables_to_merge)
642        excluded = [t for t in tablenames if t not in important_tablenames]
643        return (
644            join_tables,
645            join_on,
646            tables_to_merge,
647            join_on_tables,
648            important_tablenames,
649            excluded,
650            tablemap,
651        )
652
653    def _select_wcols(
654        self,
655        query,
656        fields,
657        left=False,
658        join=False,
659        distinct=False,
660        orderby=False,
661        groupby=False,
662        having=False,
663        limitby=False,
664        orderby_on_limitby=True,
665        for_update=False,
666        outer_scoped=[],
667        required=None,
668        cache=None,
669        cacheable=None,
670        processor=None,
671    ):
672        #: parse tablemap
673        tablemap = self.tables(query)
674        #: apply common filters if needed
675        if use_common_filters(query):
676            query = self.common_filter(query, list(tablemap.values()))
677        #: auto-adjust tables
678        tablemap = merge_tablemaps(tablemap, self.tables(*fields))
679        #: remove outer scoped tables if needed
680        for item in outer_scoped:
681            # FIXME: check for name conflicts
682            tablemap.pop(item, None)
683        if len(tablemap) < 1:
684            raise SyntaxError("Set: no tables selected")
685        query_tables = list(tablemap)
686        #: check for_update argument
687        # [Note - gi0baro] I think this should be removed since useless?
688        #                  should affect only NoSQL?
689        if self.can_select_for_update is False and for_update is True:
690            raise SyntaxError("invalid select attribute: for_update")
691        #: build joins (inner, left outer) and table names
692        if join:
693            (
694                # FIXME? ijoin_tables is never used
695                ijoin_tables,
696                ijoin_on,
697                itables_to_merge,
698                ijoin_on_tables,
699                iimportant_tablenames,
700                iexcluded,
701                itablemap,
702            ) = self._build_joins_for_select(tablemap, join)
703            tablemap = merge_tablemaps(tablemap, itables_to_merge)
704            tablemap = merge_tablemaps(tablemap, itablemap)
705        if left:
706            (
707                join_tables,
708                join_on,
709                tables_to_merge,
710                join_on_tables,
711                important_tablenames,
712                excluded,
713                jtablemap,
714            ) = self._build_joins_for_select(tablemap, left)
715            tablemap = merge_tablemaps(tablemap, tables_to_merge)
716            tablemap = merge_tablemaps(tablemap, jtablemap)
717        current_scope = outer_scoped + list(tablemap)
718        query_env = dict(current_scope=current_scope, parent_scope=outer_scoped)
719        #: prepare columns and expand fields
720        colnames = [self._colexpand(x, query_env) for x in fields]
721        sql_fields = ", ".join(self._geoexpand(x, query_env) for x in fields)
722        table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
723        if join and not left:
724            cross_joins = iexcluded + list(itables_to_merge)
725            tokens = [table_alias(cross_joins[0])]
726            tokens += [
727                self.dialect.cross_join(table_alias(t), query_env)
728                for t in cross_joins[1:]
729            ]
730            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
731            sql_t = " ".join(tokens)
732        elif not join and left:
733            cross_joins = excluded + list(tables_to_merge)
734            tokens = [table_alias(cross_joins[0])]
735            tokens += [
736                self.dialect.cross_join(table_alias(t), query_env)
737                for t in cross_joins[1:]
738            ]
739            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
740            if join_tables:
741                tokens.append(
742                    self.dialect.left_join(
743                        ",".join([table_alias(t) for t in join_tables]), query_env
744                    )
745                )
746            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
747            sql_t = " ".join(tokens)
748        elif join and left:
749            all_tables_in_query = set(
750                important_tablenames + iimportant_tablenames + query_tables
751            )
752            tables_in_joinon = set(join_on_tables + ijoin_on_tables)
753            tables_not_in_joinon = list(
754                all_tables_in_query.difference(tables_in_joinon)
755            )
756            tokens = [table_alias(tables_not_in_joinon[0])]
757            tokens += [
758                self.dialect.cross_join(table_alias(t), query_env)
759                for t in tables_not_in_joinon[1:]
760            ]
761            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
762            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
763            if join_tables:
764                tokens.append(
765                    self.dialect.left_join(
766                        ",".join([table_alias(t) for t in join_tables]), query_env
767                    )
768                )
769            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
770            sql_t = " ".join(tokens)
771        else:
772            sql_t = ", ".join(table_alias(t) for t in query_tables)
773        #: expand query if needed
774        if query:
775            query = self.expand(query, query_env=query_env)
776        if having:
777            having = self.expand(having, query_env=query_env)
778        #: groupby
779        sql_grp = groupby
780        if groupby:
781            if isinstance(groupby, (list, tuple)):
782                groupby = xorify(groupby)
783            sql_grp = self.expand(groupby, query_env=query_env)
784        #: orderby
785        sql_ord = False
786        if orderby:
787            if isinstance(orderby, (list, tuple)):
788                orderby = xorify(orderby)
789            if str(orderby) == "<random>":
790                sql_ord = self.dialect.random
791            else:
792                sql_ord = self.expand(orderby, query_env=query_env)
793        #: set default orderby if missing
794        if (
795            limitby
796            and not groupby
797            and query_tables
798            and orderby_on_limitby
799            and not orderby
800        ):
801            sql_ord = ", ".join(
802                [
803                    tablemap[t][x].sqlsafe
804                    for t in query_tables
805                    if not isinstance(tablemap[t], Select)
806                    for x in (
807                        hasattr(tablemap[t], "_primarykey")
808                        and tablemap[t]._primarykey
809                        or ["_id"]
810                    )
811                ]
812            )
813        #: build sql using dialect
814        return (
815            colnames,
816            self.dialect.select(
817                sql_fields,
818                sql_t,
819                query,
820                sql_grp,
821                having,
822                sql_ord,
823                limitby,
824                distinct,
825                for_update and self.can_select_for_update,
826            ),
827        )
828
829    def _select(self, query, fields, attributes):
830        return self._select_wcols(query, fields, **attributes)[1]
831
832    def nested_select(self, query, fields, attributes):
833        return Select(self.db, query, fields, attributes)
834
835    def _select_aux_execute(self, sql):
836        self.execute(sql)
837        return self.cursor.fetchall()
838
839    def _select_aux(self, sql, fields, attributes, colnames):
840        cache = attributes.get("cache", None)
841        if not cache:
842            rows = self._select_aux_execute(sql)
843        else:
844            if isinstance(cache, dict):
845                cache_model = cache["model"]
846                time_expire = cache["expiration"]
847                key = cache.get("key")
848                if not key:
849                    key = self.uri + "/" + sql + "/rows"
850                    key = hashlib_md5(key).hexdigest()
851            else:
852                (cache_model, time_expire) = cache
853                key = self.uri + "/" + sql + "/rows"
854                key = hashlib_md5(key).hexdigest()
855            rows = cache_model(
856                key,
857                lambda self=self, sql=sql: self._select_aux_execute(sql),
858                time_expire,
859            )
860        if isinstance(rows, tuple):
861            rows = list(rows)
862        limitby = attributes.get("limitby", None) or (0,)
863        rows = self.rowslice(rows, limitby[0], None)
864        processor = attributes.get("processor", self.parse)
865        cacheable = attributes.get("cacheable", False)
866        return processor(rows, fields, colnames, cacheable=cacheable)
867
868    def _cached_select(self, cache, sql, fields, attributes, colnames):
869        del attributes["cache"]
870        (cache_model, time_expire) = cache
871        key = self.uri + "/" + sql
872        key = hashlib_md5(key).hexdigest()
873        args = (sql, fields, attributes, colnames)
874        ret = cache_model(
875            key, lambda self=self, args=args: self._select_aux(*args), time_expire
876        )
877        ret._restore_fields(fields)
878        return ret
879
880    def select(self, query, fields, attributes):
881        colnames, sql = self._select_wcols(query, fields, **attributes)
882        cache = attributes.get("cache", None)
883        if cache and attributes.get("cacheable", False):
884            return self._cached_select(cache, sql, fields, attributes, colnames)
885        return self._select_aux(sql, fields, attributes, colnames)
886
887    def iterselect(self, query, fields, attributes):
888        colnames, sql = self._select_wcols(query, fields, **attributes)
889        cacheable = attributes.get("cacheable", False)
890        return self.iterparse(sql, fields, colnames, cacheable=cacheable)
891
892    def _count(self, query, distinct=None):
893        tablemap = self.tables(query)
894        tablenames = list(tablemap)
895        tables = list(tablemap.values())
896        query_env = dict(current_scope=tablenames)
897        sql_q = ""
898        if query:
899            if use_common_filters(query):
900                query = self.common_filter(query, tables)
901            sql_q = self.expand(query, query_env=query_env)
902        sql_t = ",".join(self.table_alias(t, []) for t in tables)
903        sql_fields = "*"
904        if distinct:
905            if isinstance(distinct, (list, tuple)):
906                distinct = xorify(distinct)
907            sql_fields = self.expand(distinct, query_env=query_env)
908        return self.dialect.select(
909            self.dialect.count(sql_fields, distinct), sql_t, sql_q
910        )
911
912    def count(self, query, distinct=None):
913        self.execute(self._count(query, distinct))
914        return self.cursor.fetchone()[0]
915
916    def bulk_insert(self, table, items):
917        return [self.insert(table, item) for item in items]
918
919    def create_table(self, *args, **kwargs):
920        return self.migrator.create_table(*args, **kwargs)
921
922    def _drop_table_cleanup(self, table):
923        super(SQLAdapter, self)._drop_table_cleanup(table)
924        if table._dbt:
925            self.migrator.file_delete(table._dbt)
926            self.migrator.log("success!\n", table)
927
928    def drop_table(self, table, mode=""):
929        queries = self.dialect.drop_table(table, mode)
930        for query in queries:
931            if table._dbt:
932                self.migrator.log(query + "\n", table)
933            self.execute(query)
934        self.commit()
935        self._drop_table_cleanup(table)
936
937    @deprecated("drop", "drop_table", "SQLAdapter")
938    def drop(self, table, mode=""):
939        return self.drop_table(table, mode="")
940
941    def truncate(self, table, mode=""):
942        # Prepare functions "write_to_logfile" and "close_logfile"
943        try:
944            queries = self.dialect.truncate(table, mode)
945            for query in queries:
946                self.migrator.log(query + "\n", table)
947                self.execute(query)
948            self.migrator.log("success!\n", table)
949        finally:
950            pass
951
952    def create_index(self, table, index_name, *fields, **kwargs):
953        expressions = [
954            field._rname if isinstance(field, Field) else field for field in fields
955        ]
956        sql = self.dialect.create_index(index_name, table, expressions, **kwargs)
957        try:
958            self.execute(sql)
959            self.commit()
960        except Exception as e:
961            self.rollback()
962            err = (
963                "Error creating index %s\n  Driver error: %s\n"
964                + "  SQL instruction: %s"
965            )
966            raise RuntimeError(err % (index_name, str(e), sql))
967        return True
968
969    def drop_index(self, table, index_name):
970        sql = self.dialect.drop_index(index_name, table)
971        try:
972            self.execute(sql)
973            self.commit()
974        except Exception as e:
975            self.rollback()
976            err = "Error dropping index %s\n  Driver error: %s"
977            raise RuntimeError(err % (index_name, str(e)))
978        return True
979
980    def distributed_transaction_begin(self, key):
981        pass
982
983    @with_connection
984    def commit(self):
985        return self.connection.commit()
986
987    @with_connection
988    def rollback(self):
989        return self.connection.rollback()
990
991    @with_connection
992    def prepare(self, key):
993        self.connection.prepare()
994
995    @with_connection
996    def commit_prepared(self, key):
997        self.connection.commit()
998
999    @with_connection
1000    def rollback_prepared(self, key):
1001        self.connection.rollback()
1002
1003    def create_sequence_and_triggers(self, query, table, **args):
1004        self.execute(query)
1005
1006    def sqlsafe_table(self, tablename, original_tablename=None):
1007        if original_tablename is not None:
1008            return self.dialect.alias(original_tablename, tablename)
1009        return self.dialect.quote(tablename)
1010
1011    def sqlsafe_field(self, fieldname):
1012        return self.dialect.quote(fieldname)
1013
1014    def table_alias(self, tbl, current_scope=[]):
1015        if isinstance(tbl, basestring):
1016            tbl = self.db[tbl]
1017        return tbl.query_name(current_scope)[0]
1018
1019    def id_query(self, table):
1020        pkeys = getattr(table, "_primarykey", None)
1021        if pkeys:
1022            return table[pkeys[0]] != None
1023        return table._id != None
1024
1025
class NoSQLAdapter(BaseAdapter):
    """Common base for the non-SQL adapters.

    Transaction hooks are no-ops, ``create_table`` only records which fields
    are ``notnull``/``unique``, and nested selects raise ``NotOnNOSQLError``.
    """

    can_select_for_update = False

    def commit(self):
        """No-op on this adapter family."""

    def rollback(self):
        """No-op on this adapter family."""

    def prepare(self, key=None):
        """No-op prepare step for distributed transactions.

        ``key`` is accepted and ignored so that calls like ``prepare(key)``,
        matching the SQL adapters' ``prepare(self, key)``, no longer raise
        TypeError. It defaults to ``None`` so zero-argument calls keep working.
        """

    def commit_prepared(self, key):
        """No-op; ``key`` is ignored."""

    def rollback_prepared(self, key):
        """No-op; ``key`` is ignored."""

    def id_query(self, table):
        """Return the query ``table._id > 0``.

        Presumably this matches every record, assuming ids are positive.
        """
        return table._id > 0

    def create_table(self, table, migrate=True, fake_migrate=False, polymodel=None):
        """Record table metadata without issuing any statements.

        Sets ``table._dbt`` to ``None``. Stores the names of ``notnull`` fields
        in ``table._notnulls`` and of ``unique`` fields in ``table._uniques``.
        ``migrate``, ``fake_migrate`` and ``polymodel`` are not used here;
        presumably they are kept for signature compatibility.
        """
        table._dbt = None
        table._notnulls = [name for name in table.fields if table[name].notnull]
        # Tracking uniques is unnecessary if the fields are indexed and unique.
        table._uniques = [name for name in table.fields if table[name].unique]

    def drop_table(self, table, mode=""):
        """Drop ``self.connection[table._tablename]`` and run the cleanup hook.

        ``mode`` is accepted for interface compatibility but not used here.
        """
        ctable = self.connection[table._tablename]
        ctable.drop()
        self._drop_table_cleanup(table)

    @deprecated("drop", "drop_table", "SQLAdapter")
    def drop(self, table, mode=""):
        """Deprecated alias of ``drop_table``; ``mode`` is forwarded unchanged."""
        return self.drop_table(table, mode=mode)

    def _select(self, *args, **kwargs):
        """Always raises ``NotOnNOSQLError``: nested queries are unsupported."""
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")

    def nested_select(self, *args, **kwargs):
        """Always raises ``NotOnNOSQLError``: nested queries are unsupported."""
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")
1073
1074
class NullAdapter(BaseAdapter):
    """Adapter without a real backend: ``CommonDialect`` plus a ``NullDriver``."""

    def _load_dependencies(self):
        """Attach a ``CommonDialect`` bound to this adapter as ``self.dialect``."""
        # Imported locally (presumably to avoid an import cycle at module load).
        from ..dialects.base import CommonDialect

        dialect = CommonDialect(self)
        self.dialect = dialect

    def find_driver(self):
        """No driver lookup is performed; intentionally a no-op."""

    def connector(self):
        """Return a new ``NullDriver`` in place of a real connection."""
        driver = NullDriver()
        return driver
Note: See TracBrowser for help on using the repository browser.