source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/packages/dal/pydal/adapters/mongo.py

import re
import copy
import random
from datetime import datetime
from .._compat import basestring, long
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import FakeCursor, Reference, SQLALL
from ..helpers.methods import use_common_filters, xorify
from ..objects import Field, Row, Query, Expression
from .base import NoSQLAdapter
from . import adapters

try:
    from bson import Binary
    from bson.binary import USER_DEFINED_SUBTYPE
except ImportError:

    class Binary(object):
        pass

    USER_DEFINED_SUBTYPE = 0


@adapters.register_for("mongodb")
class Mongo(NoSQLAdapter):
    dbengine = "mongodb"
    drivers = ("pymongo",)

    def find_driver(self):
        super(Mongo, self).find_driver()
        #: ensure pymongo version >= 3.0
        if "fake_version" in self.driver_args:
            version = self.driver_args["fake_version"]
        else:
            from pymongo import version
        if int(version.split(".")[0]) < 3:
            raise RuntimeError(
                "pydal requires pymongo version >= 3.0, found '%s'" % version
            )
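        # Illustrative gate check (hypothetical value): constructing the
        # adapter with driver_args={"fake_version": "2.9"} trips this test
        # and raises RuntimeError without consulting the installed pymongo.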

    def _initialize_(self):
        super(Mongo, self)._initialize_()
        #: uri parse
        from pymongo import uri_parser

        m = uri_parser.parse_uri(self.uri)
        if isinstance(m, tuple):
            m = {"database": m[1]}
        if m.get("database") is None:
            raise SyntaxError("Database is required!")
        self._driver_db = m["database"]
        #: mongodb imports and utils
        from bson.objectid import ObjectId
        from bson.son import SON
        from pymongo.write_concern import WriteConcern

        self.epoch = datetime.fromtimestamp(0)
        self.SON = SON
        self.ObjectId = ObjectId
        self.WriteConcern = WriteConcern
        #: options
        self.db_codec = "UTF-8"
        # minimum number of replicas to wait for on insert/update
        self.minimumreplication = self.adapter_args.get("minimumreplication", 0)
        # inserts and updates are performed synchronously ("safe") by
        # default, unless overridden by this adapter argument or by a
        # per-call parameter
        self.safe = 1 if self.adapter_args.get("safe", True) else 0
        self.get_connection()

    def connector(self):
        conn = self.driver.MongoClient(self.uri, w=self.safe)[self._driver_db]
        conn.cursor = lambda: FakeCursor()
        conn.close = lambda: None
        conn.commit = lambda: None
        return conn

    def _after_first_connection(self):
        # server version
        self._server_version = self.connection.command("serverStatus")["version"]
        self.server_version = tuple([int(x) for x in self._server_version.split(".")])
        self.server_version_major = (
            self.server_version[0] + self.server_version[1] / 10.0
        )
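        # e.g. a reported server version of "3.6.8" yields server_version
        # (3, 6, 8) and server_version_major 3.6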

    def object_id(self, arg=None):
        """Convert input to a valid MongoDB ObjectId instance.

        self.object_id("<random>") -> pseudo-random ObjectId (not guaranteed unique)
        """
        if not arg:
            arg = 0
        if isinstance(arg, basestring):
            # we assume an integer as default input
            rawhex = len(arg.replace("0x", "").replace("L", "")) == 24
            if arg.isdigit() and (not rawhex):
                arg = int(arg)
            elif arg == "<random>":
                arg = int(
                    "0x%s"
                    % "".join([random.choice("0123456789abcdef") for x in range(24)]),
                    0,
                )
            elif arg.isalnum():
                if not arg.startswith("0x"):
                    arg = "0x%s" % arg
                try:
                    arg = int(arg, 0)
                except ValueError as e:
                    raise ValueError("invalid objectid argument string: %s" % e)
            else:
                raise ValueError(
                    "Invalid objectid argument string. "
                    + "Requires an integer or base 16 value"
                )
        elif isinstance(arg, self.ObjectId):
            return arg
        elif isinstance(arg, (Row, Reference)):
            return self.object_id(long(arg["id"]))
        elif not isinstance(arg, (int, long)):
            raise TypeError(
                "object_id argument must be of type ObjectId or an objectid "
                + "representable integer (type %s)" % type(arg)
            )
        hexvalue = hex(arg)[2:].rstrip("L").zfill(24)
        return self.ObjectId(hexvalue)
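        # Worked examples (derived from the rules above; "<random>" output is
        # pseudo-random and NOT guaranteed unique):
        #   object_id(1)     -> ObjectId('000000000000000000000001')
        #   object_id("0x1") -> ObjectId('000000000000000000000001')
        #   object_id(row)   -> object_id(long(row["id"]))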

    def _get_collection(self, tablename, safe=None):
        ctable = self.connection[tablename]
        if safe is not None and safe != self.safe:
            wc = self.WriteConcern(w=self._get_safe(safe))
            ctable = ctable.with_options(write_concern=wc)
        return ctable

    def _get_safe(self, val=None):
        if val is None:
            return self.safe
        return 1 if val else 0

    def _regex_select_as_parser(self, colname):
        return re.search(self.dialect.REGEX_SELECT_AS_PARSER, colname)

    @staticmethod
    def _parse_data(expression, attribute, value=None):
        if isinstance(expression, (list, tuple)):
            ret = False
            for e in expression:
                ret = Mongo._parse_data(e, attribute, value) or ret
            return ret
        if value is not None:
            try:
                expression._parse_data[attribute] = value
            except AttributeError:
                return None
        try:
            return expression._parse_data[attribute]
        except (AttributeError, TypeError):
            return None

    def _expand(self, expression, field_type=None, query_env={}):
        if isinstance(expression, Field):
            if expression.type == "id":
                result = "_id"
            else:
                result = expression.name
            if self._parse_data(expression, "pipeline"):
                # field names as part of expressions need to start with '$'
                result = "$" + result
        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second
            if isinstance(first, Field) and "reference" in first.type:
                # cast to Mongo ObjectId
                if isinstance(second, (tuple, list, set)):
                    second = [self.object_id(item) for item in expression.second]
                else:
                    second = self.object_id(expression.second)
            op = expression.op
            optional_args = expression.optional_args or {}
            optional_args["query_env"] = query_env
            if second is not None:
                result = op(first, second, **optional_args)
            elif first is not None:
                result = op(first, **optional_args)
            elif isinstance(op, str):
                result = op
            else:
                result = op(**optional_args)
        elif isinstance(expression, Expansion):
            expression.query = self.expand(
                expression.query, field_type, query_env=query_env
            )
            result = expression
        elif isinstance(expression, (list, tuple)):
            result = [self.represent(item, field_type) for item in expression]
        elif field_type:
            result = self.represent(expression, field_type)
        else:
            result = expression
        return result

    def represent(self, obj, field_type):
        if isinstance(obj, self.ObjectId):
            return obj
        return super(Mongo, self).represent(obj, field_type)

    def truncate(self, table, mode, safe=None):
        ctable = self.connection[table._tablename]
        ctable.delete_many({})

    def count(self, query, distinct=None, snapshot=True):
        if not isinstance(query, Query):
            raise SyntaxError("Type '%s' not supported in count" % type(query))
        distinct_fields = []
        if distinct is True:
            distinct_fields = [x for x in query.first.table if x.name != "id"]
        elif distinct:
            if isinstance(distinct, Field):
                distinct_fields = [distinct]
            else:
                while isinstance(distinct, Expression) and isinstance(
                    distinct.second, Field
                ):
                    distinct_fields += [distinct.second]
                    distinct = distinct.first
                if isinstance(distinct, Field):
                    distinct_fields += [distinct]
            distinct = True
        expanded = Expansion(
            self, "count", query, fields=distinct_fields, distinct=distinct
        )
        ctable = expanded.get_collection()
        if not expanded.pipeline:
            return ctable.count(filter=expanded.query_dict)
        for record in ctable.aggregate(expanded.pipeline):
            return record["count"]
        return 0
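        # Hedged usage sketch: a simple filter stays on the fast path, e.g.
        #   db(db.t.a > 0).count()                 -> ctable.count(filter=...)
        # while a distinct count routes through the aggregation pipeline,
        # which ends in {"$group": {"_id": None, "count": {"$sum": 1}}}:
        #   db(db.t.a > 0).count(distinct=db.t.b)  -> ctable.aggregate(...)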

    def select(self, query, fields, attributes, snapshot=False):
        attributes["snapshot"] = snapshot
        return self.__select(query, fields, **attributes)

    def __select(
        self,
        query,
        fields,
        left=False,
        join=False,
        distinct=False,
        orderby=False,
        groupby=False,
        having=False,
        limitby=False,
        orderby_on_limitby=True,
        for_update=False,
        outer_scoped=[],
        required=None,
        cache=None,
        cacheable=None,
        processor=None,
        snapshot=False,
    ):
        new_fields = []
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            else:
                new_fields.append(item)
        fields = new_fields
        tablename = self.get_table(query, *fields)._tablename

        if for_update:
            self.db.logger.warning("Attribute 'for_update' unsupported by MongoDB")
        if join or left:
            raise NotOnNOSQLError("Joins not supported on NoSQL databases")
        if required or cache or cacheable:
            self.db.logger.warning(
                "Attributes 'required', 'cache' and 'cacheable' are"
                + " unsupported by MongoDB"
            )

        if limitby and orderby_on_limitby and not orderby:
            if groupby:
                orderby = groupby
            else:
                table = self.db[tablename]
                orderby = [
                    table[x]
                    for x in (
                        hasattr(table, "_primarykey") and table._primarykey or ["_id"]
                    )
                ]

        if not orderby:
            mongosort_list = []
        else:
            if snapshot:
                raise RuntimeError("snapshot and orderby are mutually exclusive")
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)

            if str(orderby) == "<random>":
                # !!!! need to add 'random'
                mongosort_list = self.dialect.random
            else:
                mongosort_list = []
                for f in self.expand(orderby).split(","):
                    include = 1
                    if f.startswith("-"):
                        include = -1
                        f = f[1:]
                    if f.startswith("$"):
                        f = f[1:]
                    mongosort_list.append((f, include))

        expanded = Expansion(
            self,
            "select",
            query,
            fields or self.db[tablename],
            groupby=groupby,
            distinct=distinct,
            having=having,
        )
        ctable = self.connection[tablename]
        modifiers = {"snapshot": snapshot}
        if int("".join(self.driver.version.split("."))) > 370:
            modifiers = {}

        if not expanded.pipeline:
            if limitby:
                limitby_skip, limitby_limit = limitby[0], int(limitby[1]) - 1
            else:
                limitby_skip = limitby_limit = 0
            mongo_list_dicts = ctable.find(
                expanded.query_dict,
                expanded.field_dicts,
                skip=limitby_skip,
                limit=limitby_limit,
                sort=mongosort_list,
                modifiers=modifiers,
            )
            null_rows = []
        else:
            if mongosort_list:
                sortby_dict = self.SON()
                for f in mongosort_list:
                    sortby_dict[f[0]] = f[1]
                expanded.pipeline.append({"$sort": sortby_dict})
            if limitby and limitby[1]:
                expanded.pipeline.append({"$limit": limitby[1]})
            if limitby and limitby[0]:
                expanded.pipeline.append({"$skip": limitby[0]})

            mongo_list_dicts = ctable.aggregate(expanded.pipeline)
            null_rows = [(None,)]

        rows = []
        # populate rows in proper order, replacing ._id with .id to follow
        # the pydal naming convention
        colnames = []
        newnames = []
        for field in expanded.fields:
            if hasattr(field, "tablename"):
                if field.name in ("id", "_id"):
                    # '_id' is MongoDB's reserved primary-key name
                    colname = (tablename + "." + "id", "_id")
                else:
                    colname = (field.longname, field.name)
            elif not isinstance(query, Expression):
                colname = (field.name, field.name)
            colnames.append(colname[1])
            newnames.append(colname[0])

        for record in mongo_list_dicts:
            row = []
            for colname in colnames:
                try:
                    value = record[colname]
                except KeyError:
                    value = None
                if self.server_version_major < 2.6:
                    # '$size' not present in server versions < 2.6
                    if isinstance(value, list) and "$addToSet" in colname:
                        value = len(value)

                row.append(value)
            rows.append(row)
        if not rows:
            rows = null_rows

        processor = processor or self.parse
        result = processor(rows, fields, newnames, blob_decode=True)
        return result

    def check_notnull(self, table, values):
        for fieldname in table._notnulls:
            if fieldname not in values or values[fieldname] is None:
                raise Exception("NOT NULL constraint failed: %s" % fieldname)

    def check_unique(self, table, values):
        if len(table._uniques) > 0:
            db = table._db
            unique_queries = []
            for fieldname in table._uniques:
                if fieldname in values:
                    value = values[fieldname]
                else:
                    value = table[fieldname].default
                unique_queries.append(
                    Query(db, self.dialect.eq, table[fieldname], value)
                )

            if len(unique_queries) > 0:
                unique_query = unique_queries[0]

                # if more than one field, build a query of ORs
                for query in unique_queries[1:]:
                    unique_query = Query(db, self.dialect._or, unique_query, query)

                if self.count(unique_query, distinct=False) != 0:
                    for query in unique_queries:
                        if self.count(query, distinct=False) != 0:
                            # one of the 'OR' queries failed, see which one
                            raise Exception(
                                "NOT UNIQUE constraint failed: %s" % query.first.name
                            )

    def insert(self, table, fields, safe=None):
        """'safe' determines whether an asynchronous request or a
        synchronous action is performed; for safety, synchronous
        requests are used by default."""

        values = {}
        safe = self._get_safe(safe)
        ctable = self._get_collection(table._tablename, safe)

        for k, v in fields:
            if k.name not in ["id", "safe"]:
                fieldname = k.name
                fieldtype = table[k.name].type
                values[fieldname] = self.represent(v, fieldtype)

        # validate notnulls
        try:
            self.check_notnull(table, values)
        except Exception as e:
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            raise e

        # validate uniques
        try:
            self.check_unique(table, values)
        except Exception as e:
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            raise e

        # perform the insert
        result = ctable.insert_one(values)

        if result.acknowledged:
            Oid = result.inserted_id
            rid = Reference(long(str(Oid), 16))
            (rid._table, rid._record) = (table, None)
            return rid
        else:
            return None
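        # For example (hypothetical id): an acknowledged insert whose
        # inserted_id is ObjectId('000000000000000000000001') returns
        # Reference(1), the 24-digit hex string parsed as base 16.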

    def update(self, table, query, fields, safe=None):
        # return the number of adjusted rows or zero, raising no exception
        # when no matching rows are found
        if not isinstance(query, Query):
            raise RuntimeError("Not implemented")

        safe = self._get_safe(safe)
        if safe:
            amount = 0
        else:
            amount = self.count(query, distinct=False)
            if amount == 0:
                return amount

        expanded = Expansion(self, "update", query, fields)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            try:
                for doc in ctable.aggregate(expanded.pipeline):
                    result = ctable.replace_one({"_id": doc["_id"]}, doc)
                    if safe and result.acknowledged:
                        amount += result.matched_count
                return amount
            except Exception as e:
                # TODO Reverse update query to verify that the query succeeded
                raise RuntimeError("uncaught exception when updating rows: %s" % e)
        try:
            result = ctable.update_many(
                filter=expanded.query_dict, update={"$set": expanded.field_dicts}
            )
            if safe and result.acknowledged:
                amount = result.matched_count
            return amount
        except Exception as e:
            # TODO Reverse update query to verify that the query succeeded
            raise RuntimeError("uncaught exception when updating rows: %s" % e)

    def delete(self, table, query, safe=None):
        if not isinstance(query, Query):
            raise RuntimeError("query type %s is not supported" % type(query))

        safe = self._get_safe(safe)
        expanded = Expansion(self, "delete", query)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            deleted = [x["_id"] for x in ctable.aggregate(expanded.pipeline)]
        else:
            deleted = [x["_id"] for x in ctable.find(expanded.query_dict)]

        # find references to deleted items
        db = self.db
        cascade = []
        set_null = []
        for field in table._referenced_by:
            if field.type == "reference " + table._tablename:
                if field.ondelete == "CASCADE":
                    cascade.append(field)
                if field.ondelete == "SET NULL":
                    set_null.append(field)
        cascade_list = []
        set_null_list = []
        for field in table._referenced_by_list:
            if field.type == "list:reference " + table._tablename:
                if field.ondelete == "CASCADE":
                    cascade_list.append(field)
                if field.ondelete == "SET NULL":
                    set_null_list.append(field)

        # perform delete
        result = ctable.delete_many({"_id": {"$in": deleted}})
        if result.acknowledged:
            amount = result.deleted_count
        else:
            amount = len(deleted)

        # clean up any references
        if amount and deleted:
            # ::TODO:: test if deleted references cascade
            def remove_from_list(field, deleted, safe):
                for delete in deleted:
                    modify = {field.name: delete}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.update_many(filter=modify, update={"$pull": modify})

            # for cascaded items, if the reference is the only item in the
            # list, then remove the entire record, else delete reference
            # from the list
            for field in cascade_list:
                for delete in deleted:
                    modify = {field.name: [delete]}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.delete_many(filter=modify)
                remove_from_list(field, deleted, safe)
            for field in set_null_list:
                remove_from_list(field, deleted, safe)
            for field in cascade:
                db(field.belongs(deleted)).delete()
            for field in set_null:
                db(field.belongs(deleted)).update(**{field.name: None})
        return amount

    def bulk_insert(self, table, items):
        return [self.insert(table, item) for item in items]
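        # Note: each item results in a separate insert() call (one
        # insert_one round trip per record); pymongo's insert_many bulk
        # API is not used here.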


class Expansion(object):
    """
    Class to encapsulate a pydal expression and track the parse
    expansion and its results.

    Two different MongoDB mechanisms are targeted here.  If the query
    is sufficiently simple, then simple queries are generated.  The
    bulk of the complexity here is however to support more complex
    queries that are targeted to the MongoDB Aggregation Pipeline.

    This class supports four operations: 'count', 'select', 'update'
    and 'delete'.

    Behavior varies somewhat for each operation type.  However,
    building each pipeline stage is shared where the behavior is the
    same (or similar) for the different operations.

    In general an attempt is made to build the query without using the
    pipeline, and if that fails then the query is rebuilt with the
    pipeline.

    QUERY constructed in _build_pipeline_query():
      $project : used to calculate expressions if needed
      $match: filters out records

    FIELDS constructed in _expand_fields():
        FIELDS:COUNT
          $group : filter for distinct if needed
          $group: count the records remaining

        FIELDS:SELECT
          $group : implement aggregations if needed
          $project: implement expressions (etc) for select

        FIELDS:UPDATE
          $project: implement expressions (etc) for update

    HAVING constructed in _add_having():
      $project : used to calculate expressions
      $match: filters out records
      $project : used to filter out previous expression fields

    """
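
    # Hedged construction sketch (mirroring Mongo.count() above):
    #
    #     expanded = Expansion(adapter, "count", query, fields=[], distinct=False)
    #     ctable = expanded.get_collection()
    #     if expanded.pipeline:
    #         for record in ctable.aggregate(expanded.pipeline):
    #             ...                                      # aggregation path
    #     else:
    #         ctable.count(filter=expanded.query_dict)     # simple-query path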

    def __init__(
        self,
        adapter,
        crud,
        query,
        fields=(),
        tablename=None,
        groupby=None,
        distinct=False,
        having=None,
    ):
        self.adapter = adapter
        self.NULL_QUERY = {
            "_id": {"$gt": self.adapter.ObjectId("000000000000000000000000")}
        }
        self._parse_data = {
            "pipeline": False,
            "need_group": bool(groupby or distinct or having),
        }
        self.crud = crud
        self.having = having
        self.distinct = distinct
        if not groupby and distinct:
            if distinct is True:
                # groupby gets all fields
                self.groupby = fields
            else:
                self.groupby = distinct
        else:
            self.groupby = groupby

        if crud == "update":
            self.values = [
                (f[0], self.annotate_expression(f[1])) for f in (fields or [])
            ]
            self.fields = [f[0] for f in self.values]
        else:
            self.fields = [self.annotate_expression(f) for f in (fields or [])]

        self.tablename = tablename or adapter.get_table(query, *self.fields)._tablename
        if use_common_filters(query):
            query = adapter.common_filter(query, [self.tablename])
        self.query = self.annotate_expression(query)

        # expand the query
        self.pipeline = []
        self.query_dict = adapter.expand(self.query)
        self.field_dicts = adapter.SON()
        self.field_groups = adapter.SON()
        self.field_groups["_id"] = adapter.SON()

        if self._parse_data["pipeline"]:
            # if the query needs the aggregation engine, set that up
            self._build_pipeline_query()

            # expand the fields for the aggregation engine
            self._expand_fields(None)
        else:
            # expand the fields
            try:
                if not self._parse_data["need_group"]:
                    self._expand_fields(self._fields_loop_abort)
                else:
                    self._parse_data["pipeline"] = True
                    raise StopIteration
            except StopIteration:
                # if the fields need the aggregation engine, set that up
                self.field_dicts = adapter.SON()
                if self.query_dict:
                    if self.query_dict != self.NULL_QUERY:
                        self.pipeline = [{"$match": self.query_dict}]
                    self.query_dict = {}
                # expand the fields for the aggregation engine
                self._expand_fields(None)

        if not self._parse_data["pipeline"]:
            if crud == "update":
                # do not update id fields
                for fieldname in ("_id", "id"):
                    if fieldname in self.field_dicts:
                        del self.field_dicts[fieldname]
        else:
            if crud == "update":
                self._add_all_fields_projection(self.field_dicts)
                self.field_dicts = adapter.SON()

            elif crud == "select":
                if self._parse_data["need_group"]:
                    if not self.groupby:
                        # no groupby, aggregate all records
                        self.field_groups["_id"] = None
                    # id has no value after aggregations
                    self.field_dicts["_id"] = False
                    self.pipeline.append({"$group": self.field_groups})
                if self.field_dicts:
                    self.pipeline.append({"$project": self.field_dicts})
                    self.field_dicts = adapter.SON()
                self._add_having()

            elif crud == "count":
                if self._parse_data["need_group"]:
                    self.pipeline.append({"$group": self.field_groups})
                self.pipeline.append({"$group": {"_id": None, "count": {"$sum": 1}}})

            # elif crud == 'delete':
            #    pass

    @property
    def dialect(self):
        return self.adapter.dialect

    def _build_pipeline_query(self):
        # search for anything needing the $match stage.
        #   currently only '$regex' requires the match stage
        def parse_need_match_stage(items, parent, parent_key):
            need_match = False
            non_matched_indices = []
            if isinstance(items, list):
                indices = range(len(items))
            elif isinstance(items, dict):
                indices = items.keys()
            else:
                return

            for i in indices:
                if parse_need_match_stage(items[i], items, i):
                    need_match = True

                elif i not in [self.dialect.REGEXP_MARK1, self.dialect.REGEXP_MARK2]:
                    non_matched_indices.append(i)

                if i == self.dialect.REGEXP_MARK1:
                    need_match = True
                    self.query_dict["project"].update(items[i])
                    parent[parent_key] = items[self.dialect.REGEXP_MARK2]

            if need_match:
                for i in non_matched_indices:
                    name = str(items[i])
                    self.query_dict["project"][name] = items[i]
                    items[i] = {name: True}

            if parent is None and self.query_dict["project"]:
                self.query_dict["match"] = items
            return need_match

        expanded = self.adapter.expand(self.query)

        if self.dialect.REGEXP_MARK1 in expanded:
            # the REGEXP_MARK is at the top of the tree, so we can just split
            # the regex over a '$project' and a '$match'
            self.query_dict = None
            match = expanded[self.dialect.REGEXP_MARK2]
            project = expanded[self.dialect.REGEXP_MARK1]

        else:
            self.query_dict = {"project": {}, "match": {}}
            if parse_need_match_stage(expanded, None, None):
                project = self.query_dict["project"]
                match = self.query_dict["match"]
            else:
                project = {"__query__": expanded}
                match = {"__query__": True}

        if self.crud in ["select", "update"]:
            self._add_all_fields_projection(project)
        else:
            self.pipeline.append({"$project": project})
        self.pipeline.append({"$match": match})
        self.query_dict = None

    def _expand_fields(self, mid_loop):
        if self.crud == "update":
            mid_loop = mid_loop or self._fields_loop_update_pipeline
            for field, value in self.values:
                self._expand_field(field, value, mid_loop)
        elif self.crud in ["select", "count"]:
            mid_loop = mid_loop or self._fields_loop_select_pipeline
            for field in self.fields:
                self._expand_field(field, field, mid_loop)
        elif self.fields:
            raise RuntimeError(self.crud + " not supported with fields")

    def _expand_field(self, field, value, mid_loop):
        expanded = {}
        if isinstance(field, Field):
            expanded = self.adapter.expand(value, field.type)
        elif isinstance(field, (Expression, Query)):
            expanded = self.adapter.expand(field)
            field.name = str(expanded)
        else:
            raise RuntimeError("%s not supported with fields" % type(field))

        if mid_loop:
            expanded = mid_loop(expanded, field, value)
        self.field_dicts[field.name] = expanded

    def _fields_loop_abort(self, expanded, *args):
        # if we need the aggregation engine, then start over
        if self._parse_data["pipeline"]:
            raise StopIteration()
        return expanded

    def _fields_loop_update_pipeline(self, expanded, field, value):
        if not isinstance(value, Expression):
            if self.adapter.server_version_major >= 2.6:
                expanded = {"$literal": expanded}

            # '$literal' not present in server versions < 2.6
            elif field.type in ["string", "text", "password"]:
                expanded = {"$concat": [expanded]}
            elif field.type in ["integer", "bigint", "float", "double"]:
                expanded = {"$add": [expanded]}
            elif field.type == "boolean":
                expanded = {"$and": [expanded]}
            elif field.type in ["date", "time", "datetime"]:
                expanded = {"$add": [expanded]}
            else:
                raise RuntimeError(
                    "updating with expressions not supported for field type "
                    + "'%s' in MongoDB version < 2.6" % field.type
                )
        return expanded
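        # For example (server_version_major 2.4 assumed): assigning the
        # constant "x" to a 'string' field expands to {"$concat": ["x"]},
        # which the $project stage evaluates back to "x"; on servers >= 2.6
        # the same constant is wrapped as {"$literal": "x"}.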

    def _fields_loop_select_pipeline(self, expanded, field, value):
        # search for anything needing $group
        def parse_groups(items, parent, parent_key):
            for item in items:
                if isinstance(items[item], list):
                    for list_item in items[item]:
                        if isinstance(list_item, dict):
                            parse_groups(
                                list_item, items[item], items[item].index(list_item)
                            )

                elif isinstance(items[item], dict):
                    parse_groups(items[item], items, item)

                if item == self.dialect.GROUP_MARK:
                    name = str(items)
                    self.field_groups[name] = items[item]
                    parent[parent_key] = "$" + name
            return items

        if self.dialect.AS_MARK in field.name:
            # The AS_MARK in the field name is used by base to alias the
            # result; we don't actually need the AS_MARK in the parse tree,
            # so we remove it here.
            if isinstance(expanded, list):
                # AS mark is the first element in the list, drop it
                expanded = expanded[1]

            elif self.dialect.AS_MARK in expanded:
                # AS mark is an element in the dict, drop it
                del expanded[self.dialect.AS_MARK]

            else:
                # ::TODO:: should be possible to do this...
                raise SyntaxError("AS() not at top of parse tree")

        if self.dialect.GROUP_MARK in expanded:
            # the GROUP_MARK is at the top of the tree, so we can just pass
            # the group result straight through the '$project' stage
            self.field_groups[field.name] = expanded[self.dialect.GROUP_MARK]
            expanded = 1

        elif self.dialect.GROUP_MARK in field.name:
            # the GROUP_MARK is not at the top of the tree, so we need to
            # pass the group results through to a '$project' stage
            expanded = parse_groups(expanded, None, None)

        elif self._parse_data["need_group"]:
            if field in self.groupby:
                # this is a 'groupby' field
                self.field_groups["_id"][field.name] = expanded
                expanded = "$_id." + field.name
            else:
                raise SyntaxError("field '%s' not in groupby" % field)

        return expanded

    def _add_all_fields_projection(self, fields):
        for fieldname in self.adapter.db[self.tablename].fields:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1
        self.pipeline.append({"$project": fields})

    def _add_having(self):
        if not self.having:
            return
        self._expand_field(self.having, None, self._fields_loop_select_pipeline)
        fields = {"__having__": self.field_dicts[self.having.name]}
        for fieldname in self.pipeline[-1]["$project"]:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1

        self.pipeline.append({"$project": copy.copy(fields)})
        self.pipeline.append({"$match": {"__having__": True}})
        del fields["__having__"]
        self.pipeline.append({"$project": fields})

    def annotate_expression(self, expression):
        def mark_has_field(expression):
            if not isinstance(expression, (Expression, Query)):
                return False
            first_has_field = mark_has_field(expression.first)
            second_has_field = mark_has_field(expression.second)
            expression.has_field = (
                isinstance(expression, Field) or first_has_field or second_has_field
            )
            return expression.has_field

        def add_parse_data(child, parent):
            if isinstance(child, (Expression, Query)):
                child.parse_root = parent.parse_root
                child.parse_parent = parent
                child.parse_depth = parent.parse_depth + 1
                child._parse_data = parent._parse_data
                add_parse_data(child.first, child)
                add_parse_data(child.second, child)
            elif isinstance(child, (list, tuple)):
                for c in child:
                    add_parse_data(c, parent)

        if isinstance(expression, (Expression, Query)):
            expression.parse_root = expression
            expression.parse_depth = -1
            expression._parse_data = self._parse_data
            add_parse_data(expression, expression)
        mark_has_field(expression)
        return expression

    def get_collection(self, safe=None):
        return self.adapter._get_collection(self.tablename, safe)


class MongoBlob(Binary):
    MONGO_BLOB_BYTES = USER_DEFINED_SUBTYPE
    MONGO_BLOB_NON_UTF8_STR = USER_DEFINED_SUBTYPE + 1

    def __new__(cls, value):
        # return None and Binary() instances unchanged
        if value is None or isinstance(value, Binary):
            return value

        # bytearray is marked as MONGO_BLOB_BYTES
        if isinstance(value, bytearray):
            return Binary.__new__(cls, bytes(value), MongoBlob.MONGO_BLOB_BYTES)

        # return non-strings as Binary(), e.g. PY3 bytes()
        if not isinstance(value, basestring):
            return Binary(value)

        # if the string is encodable as UTF-8, then return it as a string
        try:
            value.encode("utf-8")
            return value
        except UnicodeDecodeError:
            # a string which cannot be UTF-8 encoded, e.g. pickle strings
            return Binary.__new__(cls, value, MongoBlob.MONGO_BLOB_NON_UTF8_STR)

    def __repr__(self):
        return repr(MongoBlob.decode(self))

    @staticmethod
    def decode(value):
        if isinstance(value, Binary):
            if value.subtype == MongoBlob.MONGO_BLOB_BYTES:
                return bytearray(value)
            if value.subtype == MongoBlob.MONGO_BLOB_NON_UTF8_STR:
                return str(value)
        return value
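    # Hedged round-trip sketch (Python 2 semantics assumed by basestring):
    #     blob = MongoBlob(bytearray(b"\x00\x01"))  # Binary, subtype MONGO_BLOB_BYTES
    #     MongoBlob.decode(blob)                    # -> bytearray(b'\x00\x01')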