import re
import copy
import random
from datetime import datetime
from .._compat import basestring, long
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import FakeCursor, Reference, SQLALL
from ..helpers.methods import use_common_filters, xorify
from ..objects import Field, Row, Query, Expression
from .base import NoSQLAdapter
from . import adapters

try:
    from bson import Binary
    from bson.binary import USER_DEFINED_SUBTYPE
except ImportError:

    class Binary(object):
        pass

    USER_DEFINED_SUBTYPE = 0

@adapters.register_for("mongodb")
class Mongo(NoSQLAdapter):
    dbengine = "mongodb"
    drivers = ("pymongo",)

    def find_driver(self):
        super(Mongo, self).find_driver()
        #: ensure pymongo version >= 3.0
        if "fake_version" in self.driver_args:
            version = self.driver_args["fake_version"]
        else:
            from pymongo import version
        if int(version.split(".")[0]) < 3:
            raise RuntimeError(
                "pydal requires pymongo version >= 3.0, found '%s'" % version
            )

    def _initialize_(self):
        super(Mongo, self)._initialize_()
        #: uri parse
        from pymongo import uri_parser

        m = uri_parser.parse_uri(self.uri)
        if isinstance(m, tuple):
            m = {"database": m[1]}
        if m.get("database") is None:
            raise SyntaxError("Database is required!")
        self._driver_db = m["database"]
        #: mongodb imports and utils
        from bson.objectid import ObjectId
        from bson.son import SON
        from pymongo.write_concern import WriteConcern

        self.epoch = datetime.fromtimestamp(0)
        self.SON = SON
        self.ObjectId = ObjectId
        self.WriteConcern = WriteConcern
        #: options
        self.db_codec = "UTF-8"
        # the minimum number of replicas to wait for on insert/update
        self.minimumreplication = self.adapter_args.get("minimumreplication", 0)
        # inserts and selects are synchronous ("safe") by default, unless
        # overruled by this adapter argument or by a per-call parameter
        self.safe = 1 if self.adapter_args.get("safe", True) else 0
        self.get_connection()

    def connector(self):
        conn = self.driver.MongoClient(self.uri, w=self.safe)[self._driver_db]
        conn.cursor = lambda: FakeCursor()
        conn.close = lambda: None
        conn.commit = lambda: None
        return conn

    def _after_first_connection(self):
        # server version
        self._server_version = self.connection.command("serverStatus")["version"]
        self.server_version = tuple([int(x) for x in self._server_version.split(".")])
        self.server_version_major = (
            self.server_version[0] + self.server_version[1] / 10.0
        )
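
    # Illustrative example (values assumed, not from a live server): a
    # reported version string of "4.2.8" yields server_version == (4, 2, 8)
    # and server_version_major == 4.2.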

    def object_id(self, arg=None):
        """Convert input to a valid MongoDB ObjectId instance

        self.object_id("<random>") -> ObjectId (not unique) instance"""
        if not arg:
            arg = 0
        if isinstance(arg, basestring):
            # we assume an integer as default input
            rawhex = len(arg.replace("0x", "").replace("L", "")) == 24
            if arg.isdigit() and (not rawhex):
                arg = int(arg)
            elif arg == "<random>":
                arg = int(
                    "0x%s"
                    % "".join([random.choice("0123456789abcdef") for x in range(24)]),
                    0,
                )
            elif arg.isalnum():
                if not arg.startswith("0x"):
                    arg = "0x%s" % arg
                try:
                    arg = int(arg, 0)
                except ValueError as e:
                    raise ValueError("invalid objectid argument string: %s" % e)
            else:
                raise ValueError(
                    "Invalid objectid argument string. "
                    + "Requires an integer or base 16 value"
                )
        elif isinstance(arg, self.ObjectId):
            return arg
        elif isinstance(arg, (Row, Reference)):
            return self.object_id(long(arg["id"]))
        elif not isinstance(arg, (int, long)):
            raise TypeError(
                "object_id argument must be of type ObjectId or an objectid "
                + "representable integer (type %s)" % type(arg)
            )
        hexvalue = hex(arg)[2:].rstrip("L").zfill(24)
        return self.ObjectId(hexvalue)
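
    # Sketch of object_id() conversions (adapter instance and values assumed):
    #   adapter.object_id(1)          -> ObjectId("000000000000000000000001")
    #   adapter.object_id("0x1f")     -> ObjectId("00000000000000000000001f")
    #   adapter.object_id("<random>") -> ObjectId from 24 random hex digits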

    def _get_collection(self, tablename, safe=None):
        ctable = self.connection[tablename]
        if safe is not None and safe != self.safe:
            wc = self.WriteConcern(w=self._get_safe(safe))
            ctable = ctable.with_options(write_concern=wc)
        return ctable

    def _get_safe(self, val=None):
        if val is None:
            return self.safe
        return 1 if val else 0

    def _regex_select_as_parser(self, colname):
        return re.search(self.dialect.REGEX_SELECT_AS_PARSER, colname)

    @staticmethod
    def _parse_data(expression, attribute, value=None):
        if isinstance(expression, (list, tuple)):
            ret = False
            for e in expression:
                ret = Mongo._parse_data(e, attribute, value) or ret
            return ret
        if value is not None:
            try:
                expression._parse_data[attribute] = value
            except AttributeError:
                return None
        try:
            return expression._parse_data[attribute]
        except (AttributeError, TypeError):
            return None

    def _expand(self, expression, field_type=None, query_env={}):
        if isinstance(expression, Field):
            if expression.type == "id":
                result = "_id"
            else:
                result = expression.name
            if self._parse_data(expression, "pipeline"):
                # field names as part of expressions need to start with '$'
                result = "$" + result
        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second
            if isinstance(first, Field) and "reference" in first.type:
                # cast to Mongo ObjectId
                if isinstance(second, (tuple, list, set)):
                    second = [self.object_id(item) for item in expression.second]
                else:
                    second = self.object_id(expression.second)
            op = expression.op
            optional_args = expression.optional_args or {}
            optional_args["query_env"] = query_env
            if second is not None:
                result = op(first, second, **optional_args)
            elif first is not None:
                result = op(first, **optional_args)
            elif isinstance(op, str):
                result = op
            else:
                result = op(**optional_args)
        elif isinstance(expression, Expansion):
            expression.query = self.expand(
                expression.query, field_type, query_env=query_env
            )
            result = expression
        elif isinstance(expression, (list, tuple)):
            result = [self.represent(item, field_type) for item in expression]
        elif field_type:
            result = self.represent(expression, field_type)
        else:
            result = expression
        return result

    def represent(self, obj, field_type):
        if isinstance(obj, self.ObjectId):
            return obj
        return super(Mongo, self).represent(obj, field_type)

    def truncate(self, table, mode, safe=None):
        ctable = self.connection[table._tablename]
        ctable.delete_many({})

    def count(self, query, distinct=None, snapshot=True):
        if not isinstance(query, Query):
            raise SyntaxError("Type '%s' not supported in count" % type(query))
        distinct_fields = []
        if distinct is True:
            distinct_fields = [x for x in query.first.table if x.name != "id"]
        elif distinct:
            if isinstance(distinct, Field):
                distinct_fields = [distinct]
            else:
                while isinstance(distinct, Expression) and isinstance(
                    distinct.second, Field
                ):
                    distinct_fields += [distinct.second]
                    distinct = distinct.first
                if isinstance(distinct, Field):
                    distinct_fields += [distinct]
            distinct = True
        expanded = Expansion(
            self, "count", query, fields=distinct_fields, distinct=distinct
        )
        ctable = expanded.get_collection()
        if not expanded.pipeline:
            return ctable.count(filter=expanded.query_dict)
        for record in ctable.aggregate(expanded.pipeline):
            return record["count"]
        return 0
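
    # Minimal usage sketch (model and field names assumed):
    #   db(db.things.qty > 0).count()          -> plain filter, no pipeline
    #   db(q).count(distinct=db.things.color)  -> aggregation pipeline with
    #                                             $group stages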

    def select(self, query, fields, attributes, snapshot=False):
        attributes["snapshot"] = snapshot
        return self.__select(query, fields, **attributes)

    def __select(
        self,
        query,
        fields,
        left=False,
        join=False,
        distinct=False,
        orderby=False,
        groupby=False,
        having=False,
        limitby=False,
        orderby_on_limitby=True,
        for_update=False,
        outer_scoped=[],
        required=None,
        cache=None,
        cacheable=None,
        processor=None,
        snapshot=False,
    ):
        new_fields = []
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            else:
                new_fields.append(item)
        fields = new_fields
        tablename = self.get_table(query, *fields)._tablename

        if for_update:
            self.db.logger.warning("Attribute 'for_update' unsupported by MongoDB")
        if join or left:
            raise NotOnNOSQLError("Joins not supported on NoSQL databases")
        if required or cache or cacheable:
            self.db.logger.warning(
                "Attributes 'required', 'cache' and 'cacheable' are"
                + " unsupported by MongoDB"
            )

        if limitby and orderby_on_limitby and not orderby:
            if groupby:
                orderby = groupby
            else:
                table = self.db[tablename]
                orderby = [
                    table[x]
                    for x in (
                        hasattr(table, "_primarykey") and table._primarykey or ["_id"]
                    )
                ]

        if not orderby:
            mongosort_list = []
        else:
            if snapshot:
                raise RuntimeError("snapshot and orderby are mutually exclusive")
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)

            if str(orderby) == "<random>":
                # !!!! need to add 'random'
                mongosort_list = self.dialect.random
            else:
                mongosort_list = []
                for f in self.expand(orderby).split(","):
                    include = 1
                    if f.startswith("-"):
                        include = -1
                        f = f[1:]
                    if f.startswith("$"):
                        f = f[1:]
                    mongosort_list.append((f, include))
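
        # Illustrative example (field names and expansion format assumed): an
        # orderby expanding to "-name,qty" yields
        # mongosort_list = [("name", -1), ("qty", 1)].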

        expanded = Expansion(
            self,
            "select",
            query,
            fields or self.db[tablename],
            groupby=groupby,
            distinct=distinct,
            having=having,
        )
        ctable = self.connection[tablename]
        modifiers = {"snapshot": snapshot}
        if int("".join(self.driver.version.split("."))) > 370:
            modifiers = {}

        if not expanded.pipeline:
            if limitby:
                # limit is the number of rows to return: stop - start
                limitby_skip, limitby_limit = limitby[0], int(limitby[1]) - limitby[0]
            else:
                limitby_skip = limitby_limit = 0
            mongo_list_dicts = ctable.find(
                expanded.query_dict,
                expanded.field_dicts,
                skip=limitby_skip,
                limit=limitby_limit,
                sort=mongosort_list,
                modifiers=modifiers,
            )
            null_rows = []
        else:
            if mongosort_list:
                sortby_dict = self.SON()
                for f in mongosort_list:
                    sortby_dict[f[0]] = f[1]
                expanded.pipeline.append({"$sort": sortby_dict})
            if limitby and limitby[1]:
                expanded.pipeline.append({"$limit": limitby[1]})
            if limitby and limitby[0]:
                expanded.pipeline.append({"$skip": limitby[0]})

            mongo_list_dicts = ctable.aggregate(expanded.pipeline)
            null_rows = [(None,)]

        rows = []
        # populate row in proper order
        # Here we replace ._id with .id to follow the standard naming
        colnames = []
        newnames = []
        for field in expanded.fields:
            if hasattr(field, "tablename"):
                if field.name in ("id", "_id"):
                    # Mongodb reserved uuid key
                    colname = (tablename + "." + "id", "_id")
                else:
                    colname = (field.longname, field.name)
            elif not isinstance(query, Expression):
                colname = (field.name, field.name)
            colnames.append(colname[1])
            newnames.append(colname[0])

        for record in mongo_list_dicts:
            row = []
            for colname in colnames:
                try:
                    value = record[colname]
                except KeyError:
                    value = None
                if self.server_version_major < 2.6:
                    # '$size' not present in server versions < 2.6
                    if isinstance(value, list) and "$addToSet" in colname:
                        value = len(value)

                row.append(value)
            rows.append(row)
        if not rows:
            rows = null_rows

        processor = processor or self.parse
        result = processor(rows, fields, newnames, blob_decode=True)
        return result

    def check_notnull(self, table, values):
        for fieldname in table._notnulls:
            if fieldname not in values or values[fieldname] is None:
                raise Exception("NOT NULL constraint failed: %s" % fieldname)

    def check_unique(self, table, values):
        if len(table._uniques) > 0:
            db = table._db
            unique_queries = []
            for fieldname in table._uniques:
                if fieldname in values:
                    value = values[fieldname]
                else:
                    value = table[fieldname].default
                unique_queries.append(
                    Query(db, self.dialect.eq, table[fieldname], value)
                )

            if len(unique_queries) > 0:
                unique_query = unique_queries[0]

                # if more than one field, build a query of ORs
                for query in unique_queries[1:]:
                    unique_query = Query(db, self.dialect._or, unique_query, query)

                if self.count(unique_query, distinct=False) != 0:
                    for query in unique_queries:
                        if self.count(query, distinct=False) != 0:
                            # one of the 'OR' queries failed, see which one
                            raise Exception(
                                "NOT UNIQUE constraint failed: %s" % query.first.name
                            )

    def insert(self, table, fields, safe=None):
        """`safe` determines whether the request is performed
        asynchronously or synchronously; for safety, requests are
        synchronous by default."""

        values = {}
        safe = self._get_safe(safe)
        ctable = self._get_collection(table._tablename, safe)

        for k, v in fields:
            if k.name not in ["id", "safe"]:
                fieldname = k.name
                fieldtype = table[k.name].type
                values[fieldname] = self.represent(v, fieldtype)

        # validate notnulls
        try:
            self.check_notnull(table, values)
        except Exception as e:
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            raise e

        # validate uniques
        try:
            self.check_unique(table, values)
        except Exception as e:
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            raise e

        # perform the insert
        result = ctable.insert_one(values)

        if result.acknowledged:
            Oid = result.inserted_id
            rid = Reference(long(str(Oid), 16))
            (rid._table, rid._record) = (table, None)
            return rid
        else:
            return None
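
    # Sketch (values assumed): an acknowledged insert whose inserted_id is
    # ObjectId("000000000000000000000010") yields Reference(16), and
    # object_id(16) maps back to the same ObjectId.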

    def update(self, table, query, fields, safe=None):
        # returns the number of adjusted rows (zero when nothing matches);
        # does not raise when the query finds no result
        if not isinstance(query, Query):
            raise RuntimeError("Not implemented")

        safe = self._get_safe(safe)
        if safe:
            amount = 0
        else:
            amount = self.count(query, distinct=False)
            if amount == 0:
                return amount

        expanded = Expansion(self, "update", query, fields)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            try:
                for doc in ctable.aggregate(expanded.pipeline):
                    result = ctable.replace_one({"_id": doc["_id"]}, doc)
                    if safe and result.acknowledged:
                        amount += result.matched_count
                return amount
            except Exception as e:
                # TODO Reverse update query to verify that the query succeeded
                raise RuntimeError("uncaught exception when updating rows: %s" % e)
        try:
            result = ctable.update_many(
                filter=expanded.query_dict, update={"$set": expanded.field_dicts}
            )
            if safe and result.acknowledged:
                amount = result.matched_count
            return amount
        except Exception as e:
            # TODO Reverse update query to verify that the query succeeded
            raise RuntimeError("uncaught exception when updating rows: %s" % e)

    def delete(self, table, query, safe=None):
        if not isinstance(query, Query):
            raise RuntimeError("query type %s is not supported" % type(query))

        safe = self._get_safe(safe)
        expanded = Expansion(self, "delete", query)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            deleted = [x["_id"] for x in ctable.aggregate(expanded.pipeline)]
        else:
            deleted = [x["_id"] for x in ctable.find(expanded.query_dict)]

        # find references to deleted items
        db = self.db
        cascade = []
        set_null = []
        for field in table._referenced_by:
            if field.type == "reference " + table._tablename:
                if field.ondelete == "CASCADE":
                    cascade.append(field)
                if field.ondelete == "SET NULL":
                    set_null.append(field)
        cascade_list = []
        set_null_list = []
        for field in table._referenced_by_list:
            if field.type == "list:reference " + table._tablename:
                if field.ondelete == "CASCADE":
                    cascade_list.append(field)
                if field.ondelete == "SET NULL":
                    set_null_list.append(field)

        # perform delete
        result = ctable.delete_many({"_id": {"$in": deleted}})
        if result.acknowledged:
            amount = result.deleted_count
        else:
            amount = len(deleted)

        # clean up any references
        if amount and deleted:
            # ::TODO:: test if deleted references cascade
            def remove_from_list(field, deleted, safe):
                for delete in deleted:
                    modify = {field.name: delete}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.update_many(filter=modify, update={"$pull": modify})

            # for cascaded items, if the reference is the only item in the
            # list, then remove the entire record, else delete reference
            # from the list
            for field in cascade_list:
                for delete in deleted:
                    modify = {field.name: [delete]}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.delete_many(filter=modify)
                remove_from_list(field, deleted, safe)
            for field in set_null_list:
                remove_from_list(field, deleted, safe)
            for field in cascade:
                db(field.belongs(deleted)).delete()
            for field in set_null:
                db(field.belongs(deleted)).update(**{field.name: None})
        return amount

    def bulk_insert(self, table, items):
        return [self.insert(table, item) for item in items]


class Expansion(object):
    """
    Class to encapsulate a pydal expression and track the parse
    expansion and its results.

    Two different MongoDB mechanisms are targeted here. If the query
    is sufficiently simple, then simple queries are generated. The
    bulk of the complexity here is in supporting more complex
    queries that are targeted to the MongoDB Aggregation Pipeline.

    This class supports four operations: 'count', 'select', 'update'
    and 'delete'.

    Behavior varies somewhat for each operation type, but building
    each pipeline stage is shared where the behavior is the same
    (or similar) for the different operations.

    In general an attempt is made to build the query without using the
    pipeline, and if that fails then the query is rebuilt with the
    pipeline.

    QUERY constructed in _build_pipeline_query():
        $project : used to calculate expressions if needed
        $match : filters out records

    FIELDS constructed in _expand_fields():
        FIELDS:COUNT
            $group : filter for distinct if needed
            $group : count the records remaining

        FIELDS:SELECT
            $group : implement aggregations if needed
            $project : implement expressions (etc) for select

        FIELDS:UPDATE
            $project : implement expressions (etc) for update

    HAVING constructed in _add_having():
        $project : used to calculate expressions
        $match : filters out records
        $project : used to filter out previous expression fields
    """

    def __init__(
        self,
        adapter,
        crud,
        query,
        fields=(),
        tablename=None,
        groupby=None,
        distinct=False,
        having=None,
    ):
        self.adapter = adapter
        self.NULL_QUERY = {
            "_id": {"$gt": self.adapter.ObjectId("000000000000000000000000")}
        }
        self._parse_data = {
            "pipeline": False,
            "need_group": bool(groupby or distinct or having),
        }
        self.crud = crud
        self.having = having
        self.distinct = distinct
        if not groupby and distinct:
            if distinct is True:
                # groupby gets all fields
                self.groupby = fields
            else:
                self.groupby = distinct
        else:
            self.groupby = groupby

        if crud == "update":
            self.values = [
                (f[0], self.annotate_expression(f[1])) for f in (fields or [])
            ]
            self.fields = [f[0] for f in self.values]
        else:
            self.fields = [self.annotate_expression(f) for f in (fields or [])]

        self.tablename = tablename or adapter.get_table(query, *self.fields)._tablename
        if use_common_filters(query):
            query = adapter.common_filter(query, [self.tablename])
        self.query = self.annotate_expression(query)

        # expand the query
        self.pipeline = []
        self.query_dict = adapter.expand(self.query)
        self.field_dicts = adapter.SON()
        self.field_groups = adapter.SON()
        self.field_groups["_id"] = adapter.SON()

        if self._parse_data["pipeline"]:
            # if the query needs the aggregation engine, set that up
            self._build_pipeline_query()

            # expand the fields for the aggregation engine
            self._expand_fields(None)
        else:
            # expand the fields
            try:
                if not self._parse_data["need_group"]:
                    self._expand_fields(self._fields_loop_abort)
                else:
                    self._parse_data["pipeline"] = True
                    raise StopIteration
            except StopIteration:
                # if the fields need the aggregation engine, set that up
                self.field_dicts = adapter.SON()
                if self.query_dict:
                    if self.query_dict != self.NULL_QUERY:
                        self.pipeline = [{"$match": self.query_dict}]
                    self.query_dict = {}
                # expand the fields for the aggregation engine
                self._expand_fields(None)

        if not self._parse_data["pipeline"]:
            if crud == "update":
                # do not update id fields
                for fieldname in ("_id", "id"):
                    if fieldname in self.field_dicts:
                        del self.field_dicts[fieldname]
        else:
            if crud == "update":
                self._add_all_fields_projection(self.field_dicts)
                self.field_dicts = adapter.SON()

            elif crud == "select":
                if self._parse_data["need_group"]:
                    if not self.groupby:
                        # no groupby, aggregate all records
                        self.field_groups["_id"] = None
                    # id has no value after aggregations
                    self.field_dicts["_id"] = False
                    self.pipeline.append({"$group": self.field_groups})
                if self.field_dicts:
                    self.pipeline.append({"$project": self.field_dicts})
                    self.field_dicts = adapter.SON()
                self._add_having()

            elif crud == "count":
                if self._parse_data["need_group"]:
                    self.pipeline.append({"$group": self.field_groups})
                self.pipeline.append({"$group": {"_id": None, "count": {"$sum": 1}}})

            # elif crud == 'delete':
            #     pass

    @property
    def dialect(self):
        return self.adapter.dialect

    def _build_pipeline_query(self):
        # search for anything needing the $match stage.
        # currently only '$regex' requires the match stage
        def parse_need_match_stage(items, parent, parent_key):
            need_match = False
            non_matched_indices = []
            if isinstance(items, list):
                indices = range(len(items))
            elif isinstance(items, dict):
                indices = items.keys()
            else:
                return

            for i in indices:
                if parse_need_match_stage(items[i], items, i):
                    need_match = True

                elif i not in [self.dialect.REGEXP_MARK1, self.dialect.REGEXP_MARK2]:
                    non_matched_indices.append(i)

                if i == self.dialect.REGEXP_MARK1:
                    need_match = True
                    self.query_dict["project"].update(items[i])
                    parent[parent_key] = items[self.dialect.REGEXP_MARK2]

            if need_match:
                for i in non_matched_indices:
                    name = str(items[i])
                    self.query_dict["project"][name] = items[i]
                    items[i] = {name: True}

            if parent is None and self.query_dict["project"]:
                self.query_dict["match"] = items
            return need_match

        expanded = self.adapter.expand(self.query)

        if self.dialect.REGEXP_MARK1 in expanded:
            # the REGEXP_MARK is at the top of the tree, so can just split
            # the regex over a '$project' and a '$match'
            self.query_dict = None
            match = expanded[self.dialect.REGEXP_MARK2]
            project = expanded[self.dialect.REGEXP_MARK1]

        else:
            self.query_dict = {"project": {}, "match": {}}
            if parse_need_match_stage(expanded, None, None):
                project = self.query_dict["project"]
                match = self.query_dict["match"]
            else:
                project = {"__query__": expanded}
                match = {"__query__": True}

        if self.crud in ["select", "update"]:
            self._add_all_fields_projection(project)
        else:
            self.pipeline.append({"$project": project})
        self.pipeline.append({"$match": match})
        self.query_dict = None

    def _expand_fields(self, mid_loop):
        if self.crud == "update":
            mid_loop = mid_loop or self._fields_loop_update_pipeline
            for field, value in self.values:
                self._expand_field(field, value, mid_loop)
        elif self.crud in ["select", "count"]:
            mid_loop = mid_loop or self._fields_loop_select_pipeline
            for field in self.fields:
                self._expand_field(field, field, mid_loop)
        elif self.fields:
            raise RuntimeError(self.crud + " not supported with fields")

    def _expand_field(self, field, value, mid_loop):
        expanded = {}
        if isinstance(field, Field):
            expanded = self.adapter.expand(value, field.type)
        elif isinstance(field, (Expression, Query)):
            expanded = self.adapter.expand(field)
            field.name = str(expanded)
        else:
            raise RuntimeError("%s not supported with fields" % type(field))

        if mid_loop:
            expanded = mid_loop(expanded, field, value)
        self.field_dicts[field.name] = expanded

    def _fields_loop_abort(self, expanded, *args):
        # if we need the aggregation engine, then start over
        if self._parse_data["pipeline"]:
            raise StopIteration()
        return expanded

    def _fields_loop_update_pipeline(self, expanded, field, value):
        if not isinstance(value, Expression):
            if self.adapter.server_version_major >= 2.6:
                expanded = {"$literal": expanded}

            # '$literal' not present in server versions < 2.6
            elif field.type in ["string", "text", "password"]:
                expanded = {"$concat": [expanded]}
            elif field.type in ["integer", "bigint", "float", "double"]:
                expanded = {"$add": [expanded]}
            elif field.type == "boolean":
                expanded = {"$and": [expanded]}
            elif field.type in ["date", "time", "datetime"]:
                expanded = {"$add": [expanded]}
            else:
                raise RuntimeError(
                    "updating with expressions not supported for field type "
                    + "'%s' in MongoDB version < 2.6" % field.type
                )
        return expanded
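
    # Sketch (assumes a server >= 2.6 and an update that needs the
    # aggregation pipeline): a constant such as "x" for a string field is
    # wrapped as {"$literal": "x"} so the $project stage treats it as a
    # value rather than a field path.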

    def _fields_loop_select_pipeline(self, expanded, field, value):
        # search for anything needing $group
        def parse_groups(items, parent, parent_key):
            for item in items:
                if isinstance(items[item], list):
                    for list_item in items[item]:
                        if isinstance(list_item, dict):
                            parse_groups(
                                list_item, items[item], items[item].index(list_item)
                            )

                elif isinstance(items[item], dict):
                    parse_groups(items[item], items, item)

                if item == self.dialect.GROUP_MARK:
                    name = str(items)
                    self.field_groups[name] = items[item]
                    parent[parent_key] = "$" + name
            return items

        if self.dialect.AS_MARK in field.name:
            # The AS_MARK in the field name is used by base to alias the
            # result, we don't actually need the AS_MARK in the parse tree
            # so we remove it here.
            if isinstance(expanded, list):
                # AS mark is first element in list, drop it
                expanded = expanded[1]

            elif self.dialect.AS_MARK in expanded:
                # AS mark is element in dict, drop it
                del expanded[self.dialect.AS_MARK]

            else:
                # ::TODO:: should be possible to do this...
                raise SyntaxError("AS() not at top of parse tree")

        if self.dialect.GROUP_MARK in expanded:
            # the GROUP_MARK is at the top of the tree, so can just pass
            # the group result straight through the '$project' stage
            self.field_groups[field.name] = expanded[self.dialect.GROUP_MARK]
            expanded = 1

        elif self.dialect.GROUP_MARK in field.name:
            # the GROUP_MARK is not at the top of the tree, so we need to
            # pass the group results through to a '$project' stage.
            expanded = parse_groups(expanded, None, None)

        elif self._parse_data["need_group"]:
            if field in self.groupby:
                # this is a 'groupby' field
                self.field_groups["_id"][field.name] = expanded
                expanded = "$_id." + field.name
            else:
                raise SyntaxError("field '%s' not in groupby" % field)

        return expanded

    def _add_all_fields_projection(self, fields):
        for fieldname in self.adapter.db[self.tablename].fields:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1
        self.pipeline.append({"$project": fields})

    def _add_having(self):
        if not self.having:
            return
        self._expand_field(self.having, None, self._fields_loop_select_pipeline)
        fields = {"__having__": self.field_dicts[self.having.name]}
        for fieldname in self.pipeline[-1]["$project"]:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1

        self.pipeline.append({"$project": copy.copy(fields)})
        self.pipeline.append({"$match": {"__having__": True}})
        del fields["__having__"]
        self.pipeline.append({"$project": fields})

    def annotate_expression(self, expression):
        def mark_has_field(expression):
            if not isinstance(expression, (Expression, Query)):
                return False
            first_has_field = mark_has_field(expression.first)
            second_has_field = mark_has_field(expression.second)
            expression.has_field = (
                isinstance(expression, Field) or first_has_field or second_has_field
            )
            return expression.has_field

        def add_parse_data(child, parent):
            if isinstance(child, (Expression, Query)):
                child.parse_root = parent.parse_root
                child.parse_parent = parent
                child.parse_depth = parent.parse_depth + 1
                child._parse_data = parent._parse_data
                add_parse_data(child.first, child)
                add_parse_data(child.second, child)
            elif isinstance(child, (list, tuple)):
                for c in child:
                    add_parse_data(c, parent)

        if isinstance(expression, (Expression, Query)):
            expression.parse_root = expression
            expression.parse_depth = -1
            expression._parse_data = self._parse_data
            add_parse_data(expression, expression)
        mark_has_field(expression)
        return expression

    def get_collection(self, safe=None):
        return self.adapter._get_collection(self.tablename, safe)


class MongoBlob(Binary):
    MONGO_BLOB_BYTES = USER_DEFINED_SUBTYPE
    MONGO_BLOB_NON_UTF8_STR = USER_DEFINED_SUBTYPE + 1

    def __new__(cls, value):
        # return None and Binary() unmolested
        if value is None or isinstance(value, Binary):
            return value

        # bytearray is marked as MONGO_BLOB_BYTES
        if isinstance(value, bytearray):
            return Binary.__new__(cls, bytes(value), MongoBlob.MONGO_BLOB_BYTES)

        # return non-strings as Binary(), eg: PY3 bytes()
        if not isinstance(value, basestring):
            return Binary(value)

        # if string is encodable as UTF-8, then return as string
        try:
            value.encode("utf-8")
            return value
        except UnicodeDecodeError:
            # string which can not be UTF-8 encoded, eg: pickle strings
            return Binary.__new__(cls, value, MongoBlob.MONGO_BLOB_NON_UTF8_STR)

    def __repr__(self):
        return repr(MongoBlob.decode(self))

    @staticmethod
    def decode(value):
        if isinstance(value, Binary):
            if value.subtype == MongoBlob.MONGO_BLOB_BYTES:
                return bytearray(value)
            if value.subtype == MongoBlob.MONGO_BLOB_NON_UTF8_STR:
                return str(value)
        return value
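
# Round-trip sketch (values assumed):
#   blob = MongoBlob(bytearray(b"\x00\xff"))  # stored with USER_DEFINED_SUBTYPE
#   MongoBlob.decode(blob) == bytearray(b"\x00\xff")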