1 | # -*- coding: utf-8 -*- |
---|
2 | # pylint: disable=no-member,not-an-iterable |
---|
3 | |
---|
4 | import base64 |
---|
5 | import binascii |
---|
6 | import cgi |
---|
7 | import copy |
---|
8 | import csv |
---|
9 | import datetime |
---|
10 | import decimal |
---|
11 | import os |
---|
12 | import shutil |
---|
13 | import sys |
---|
14 | import types |
---|
15 | import re |
---|
16 | from collections import OrderedDict |
---|
17 | from ._compat import ( |
---|
18 | PY2, |
---|
19 | StringIO, |
---|
20 | BytesIO, |
---|
21 | pjoin, |
---|
22 | exists, |
---|
23 | hashlib_md5, |
---|
24 | basestring, |
---|
25 | iteritems, |
---|
26 | xrange, |
---|
27 | implements_iterator, |
---|
28 | implements_bool, |
---|
29 | copyreg, |
---|
30 | reduce, |
---|
31 | to_bytes, |
---|
32 | to_native, |
---|
33 | to_unicode, |
---|
34 | long, |
---|
35 | text_type, |
---|
36 | ) |
---|
37 | from ._globals import DEFAULT, IDENTITY, AND, OR |
---|
38 | from ._gae import Key |
---|
39 | from .exceptions import NotFoundException, NotAuthorizedException |
---|
40 | from .helpers.regex import ( |
---|
41 | REGEX_TABLE_DOT_FIELD, |
---|
42 | REGEX_ALPHANUMERIC, |
---|
43 | REGEX_PYTHON_KEYWORDS, |
---|
44 | REGEX_UPLOAD_EXTENSION, |
---|
45 | REGEX_UPLOAD_PATTERN, |
---|
46 | REGEX_UPLOAD_CLEANUP, |
---|
47 | REGEX_VALID_TB_FLD, |
---|
48 | REGEX_TYPE, |
---|
49 | REGEX_TABLE_DOT_FIELD_OPTIONAL_QUOTES, |
---|
50 | ) |
---|
51 | from .helpers.classes import ( |
---|
52 | Reference, |
---|
53 | MethodAdder, |
---|
54 | SQLCallableList, |
---|
55 | SQLALL, |
---|
56 | Serializable, |
---|
57 | BasicStorage, |
---|
58 | SQLCustomType, |
---|
59 | OpRow, |
---|
60 | cachedprop, |
---|
61 | ) |
---|
62 | from .helpers.methods import ( |
---|
63 | list_represent, |
---|
64 | bar_decode_integer, |
---|
65 | bar_decode_string, |
---|
66 | bar_encode, |
---|
67 | archive_record, |
---|
68 | cleanup, |
---|
69 | use_common_filters, |
---|
70 | attempt_upload_on_insert, |
---|
71 | attempt_upload_on_update, |
---|
72 | delete_uploaded_files, |
---|
73 | uuidstr |
---|
74 | ) |
---|
75 | from .helpers.serializers import serializers |
---|
76 | from .utils import deprecated |
---|
77 | |
---|
# Python 3 has no ``unicode`` builtin; alias it to ``str`` so the PY2-era
# ``unicode`` references in this module keep working.
if PY2:
    pass  # ``unicode`` is already a builtin on Python 2
else:
    unicode = str
---|
80 | |
---|
# Default length for each field type that has one.
DEFAULTLENGTH = {
    "string": 512,
    "password": 512,
    "upload": 512,
    "text": 2 ** 15,
    "blob": 2 ** 31,
}

# Default validation regex for each field type.  These are raw strings so
# that sequences such as ``\d`` and ``\.`` are not parsed as (invalid) string
# escapes, which Python 3 reports with DeprecationWarning/SyntaxWarning.
# The resulting pattern values are unchanged.
DEFAULT_REGEX = {
    "id": r"[1-9]\d*",
    "decimal": r"\d{1,10}\.\d{2}",
    "integer": r"[+-]?\d*",
    "float": r"[+-]?\d*(\.\d*)?",
    "double": r"[+-]?\d*(\.\d*)?",
    "date": r"\d{4}\-\d{2}\-\d{2}",
    "time": r"\d{2}\:\d{2}(\:\d{2}(\.\d*)?)?",
    "datetime": r"\d{4}\-\d{2}\-\d{2} \d{2}\:\d{2}(\:\d{2}(\.\d*)?)?",
}
---|
99 | |
---|
100 | |
---|
def csv_reader(utf8_data, dialect=csv.excel, encoding="utf-8", **kwargs):
    """Yield the rows of ``utf8_data`` like :func:`csv.reader`, decoding cells.

    Every cell is passed through ``to_unicode(cell, encoding)`` before the
    row (a list) is yielded, so the caller picks the text encoding
    (``"utf-8"`` by default).  ``dialect`` and any extra keyword arguments
    are forwarded to :func:`csv.reader`.
    """
    rows = csv.reader(utf8_data, dialect=dialect, **kwargs)
    for cells in rows:
        decoded = [to_unicode(value, encoding) for value in cells]
        yield decoded
---|
106 | |
---|
107 | |
---|
class Row(BasicStorage):

    """
    A dict-like record that supports both ``row['a']`` and ``row.a`` access.

    Besides plain keys, ``row['table.field']`` looks into the nested
    ``row['table']`` value first and then falls back to ``row['field']``.
    """

    def __getitem__(self, k):
        """Look up ``k`` (converted with ``str``) in the row.

        Lookup order:

        1. the ``_extra`` mapping, when the row has one;
        2. a regular attribute of the row;
        3. for keys matching ``REGEX_TABLE_DOT_FIELD`` (``"table.field"``):
           ``self[table][field]``, then ``self[field]``;
        4. the ``__get_lazy_reference__`` callable, when present; its result
           is stored in the row under ``key`` and returned.

        Raises:
            KeyError: when none of the above yields a value.
        """
        key = str(k)

        _extra = BasicStorage.get(self, "_extra", None)
        if _extra is not None:
            v = _extra.get(key, DEFAULT)
            if v is not DEFAULT:
                return v

        try:
            return BasicStorage.__getattribute__(self, key)
        except AttributeError:
            pass

        m = REGEX_TABLE_DOT_FIELD.match(key)
        if m:
            key2 = m.group(2)
            try:
                return BasicStorage.__getitem__(self, m.group(1))[key2]
            except (KeyError, TypeError):
                pass
            try:
                return BasicStorage.__getitem__(self, key2)
            except KeyError:
                pass

        lg = BasicStorage.get(self, "__get_lazy_reference__", None)
        if callable(lg):
            v = self[key] = lg(key)
            return v

        raise KeyError(key)

    def __repr__(self):
        return "<Row %s>" % self.as_dict(custom_types=[LazySet])

    def __int__(self):
        # NOTE(review): returns None when the row has no "id", in which case
        # int(row) raises TypeError.
        return self.get("id")

    def __long__(self):
        return long(int(self))

    def __hash__(self):
        # Identity-based: rows are mutable, so equal rows may hash differently.
        return id(self)

    __str__ = __repr__

    __call__ = __getitem__

    def __getattr__(self, k):
        """Fall back to item lookup for attributes not found normally."""
        try:
            return self.__getitem__(k)
        except KeyError:
            raise AttributeError(k)

    def __copy__(self):
        return Row(self)

    def __eq__(self, other):
        """Rows are equal when their ``as_dict()`` forms are equal.

        Objects without an ``as_dict`` method never compare equal.
        """
        try:
            return self.as_dict() == other.as_dict()
        except AttributeError:
            return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them consistent.
        return not self.__eq__(other)

    def get(self, key, default=None):
        """Return ``self[key]``, or ``default`` on KeyError/AttributeError/TypeError."""
        try:
            return self.__getitem__(key)
        except (KeyError, AttributeError, TypeError):
            return default

    def as_dict(self, datetime_to_str=False, custom_types=None):
        """Return a plain ``dict`` copy of the row suitable for serialization.

        Conversion rules, per value:

        - ``None`` is kept;
        - nested ``Row`` values are converted recursively with the same
          ``datetime_to_str`` and ``custom_types`` options;
        - ``Reference`` values are converted with ``long()``;
        - ``decimal.Decimal`` values become ``float``;
        - date/datetime/time values are kept, or, when ``datetime_to_str``
          is true, replaced by their ISO format with ``"T"`` turned into a
          space and truncated to 19 characters;
        - any other value is kept only if it is an instance of str, int,
          float, bool, list, dict (plus unicode/long on PY2, plus
          ``custom_types``), otherwise it is dropped.

        Args:
            datetime_to_str: convert date/time values to strings.
            custom_types: a type, or a list/tuple/set of types, whose
                instances are kept as-is.
        """
        serializable = [str, int, float, bool, list, dict]
        if PY2:
            serializable += [unicode, long]
        if isinstance(custom_types, (list, tuple, set)):
            serializable += custom_types
        elif custom_types:
            serializable.append(custom_types)
        # Built once instead of once per key.
        serializable = tuple(serializable)
        dt_types = (datetime.date, datetime.datetime, datetime.time)

        d = dict(self)
        for k in list(d.keys()):
            v = d[k]
            if v is None:
                continue
            elif isinstance(v, Row):
                # Propagate the options so nested rows (e.g. one per table in
                # a join) are converted the same way as the outer row.
                d[k] = v.as_dict(
                    datetime_to_str=datetime_to_str, custom_types=custom_types
                )
            elif isinstance(v, Reference):
                d[k] = long(v)
            elif isinstance(v, decimal.Decimal):
                d[k] = float(v)
            elif isinstance(v, dt_types):
                if datetime_to_str:
                    d[k] = v.isoformat().replace("T", " ")[:19]
            elif not isinstance(v, serializable):
                del d[k]
        return d

    def as_xml(self, row_name="row", colnames=None, indent=" "):
        """Render the row as an XML fragment.

        Nested ``Row`` values become nested elements.  Keys matching
        ``REGEX_ALPHANUMERIC`` become ``<key>value</key>``, other keys
        ``<extra name="key">value</extra>``; callable values are skipped.
        ``colnames`` is accepted for API compatibility and is not used.
        """
        # NOTE(review): names and values are interpolated without XML
        # escaping, so "<", "&" or quotes in the data yield malformed XML.

        def f(row, field, indent=" "):
            if isinstance(row, Row):
                spc = indent + " \n"
                items = [f(row[x], x, indent + " ") for x in row]
                return "%s<%s>\n%s\n%s</%s>" % (
                    indent,
                    field,
                    spc.join(item for item in items if item),
                    indent,
                    field,
                )
            elif not callable(row):
                if re.match(REGEX_ALPHANUMERIC, field):
                    return "%s<%s>%s</%s>" % (indent, field, row, field)
                else:
                    return '%s<extra name="%s">%s</extra>' % (indent, field, row)
            else:
                return None

        return f(self, row_name, indent=indent)

    def as_json(
        self, mode="object", default=None, colnames=None, serialize=True, **kwargs
    ):
        """
        serializes the row to a JSON object
        kwargs are passed to .as_dict method
        only "object" mode supported

        `serialize = False` used by Rows.as_json

        TODO: return array mode with query column order

        mode and colnames are not implemented
        """

        item = self.as_dict(**kwargs)
        if serialize:
            return serializers.json(item)
        else:
            return item
---|
254 | |
---|
255 | |
---|
def pickle_row(s):
    """Reduce a ``Row`` for pickling: rebuild it from a plain dict copy."""
    state = dict(s)
    return Row, (state,)


# Register the reducer so pickling a ``Row`` goes through ``pickle_row``.
copyreg.pickle(Row, pickle_row)
---|
261 | |
---|
262 | |
---|
263 | class Table(Serializable, BasicStorage): |
---|
264 | |
---|
265 | """ |
---|
266 | Represents a database table |
---|
267 | |
---|
268 | Example:: |
---|
269 | You can create a table as:: |
---|
270 | db = DAL(...) |
---|
271 | db.define_table('users', Field('name')) |
---|
272 | |
---|
273 | And then:: |
---|
274 | |
---|
275 | db.users.insert(name='me') # print db.users._insert(...) to see SQL |
---|
276 | db.users.drop() |
---|
277 | |
---|
278 | """ |
---|
279 | |
---|
def __init__(self, db, tablename, *fields, **args):
    """
    Initializes the table and performs checking on the provided fields.

    An 'id' field is added automatically unless ``primarykey`` is given or
    an 'id'-typed field is supplied.

    If a field is of type Table, the fields (excluding 'id') from that table
    will be used instead.

    Args:
        db: the owning database object.  It may be falsy, in which case the
            ``db``-derived attributes (``_rname``, ``_sequence_name`` ...)
            are left falsy too.
        tablename: the table name; must be a ``str`` matching
            ``REGEX_VALID_TB_FLD``, not a Python keyword and not an attribute
            of ``DAL``.
        *fields: ``Field``/``FieldVirtual``/``FieldMethod`` objects, lists or
            tuples of fields, other ``Table`` objects, or dicts of ``Field``
            keyword arguments.
        **args: options read here: ``rname``, ``sequence_name``,
            ``trigger_name``, ``common_filter``, ``format``, ``singular``,
            ``plural`` and ``primarykey`` (a list of field names).

    Raises:
        SyntaxError: when the table name is invalid, a supplied field is of
            incorrect type, a field name is reserved or duplicated, or
            ``primarykey`` is malformed.
    """
    # import DAL here to avoid circular imports
    from .base import DAL

    super(Table, self).__init__()
    self._actual = False  # set to True by define_table()
    self._db = db
    self._migrate = None
    self._tablename = self._dalname = tablename
    if (
        not isinstance(tablename, str)
        or hasattr(DAL, tablename)
        or not REGEX_VALID_TB_FLD.match(tablename)
        or REGEX_PYTHON_KEYWORDS.match(tablename)
    ):
        raise SyntaxError(
            "Field: invalid table name: %s, "
            'use rname for "funny" names' % tablename
        )
    self._rname = args.get("rname") or db and db._adapter.dialect.quote(tablename)
    self._raw_rname = args.get("rname") or db and tablename
    self._sequence_name = (
        args.get("sequence_name")
        or db
        and db._adapter.dialect.sequence_name(self._raw_rname)
    )
    self._trigger_name = (
        args.get("trigger_name")
        or db
        and db._adapter.dialect.trigger_name(tablename)
    )
    self._common_filter = args.get("common_filter")
    self._format = args.get("format")
    self._singular = args.get("singular", tablename.replace("_", " ").capitalize())
    self._plural = args.get("plural")
    # horrible but for backward compatibility of appadmin
    if "primarykey" in args and args["primarykey"] is not None:
        self._primarykey = args.get("primarykey")

    self._before_insert = [attempt_upload_on_insert(self)]
    self._before_update = [delete_uploaded_files, attempt_upload_on_update(self)]
    self._before_delete = [delete_uploaded_files]
    self._after_insert = []
    self._after_update = []
    self._after_delete = []

    self._virtual_fields = []
    self._virtual_methods = []

    self.add_method = MethodAdder(self)

    fieldnames = set()
    newfields = []
    _primarykey = getattr(self, "_primarykey", None)
    if _primarykey is not None:
        if not isinstance(_primarykey, list):
            raise SyntaxError(
                "primarykey must be a list of fields from table '%s'" % tablename
            )
        if len(_primarykey) == 1:
            # A single-column primary key doubles as the table's ``_id``.
            # Only ``Field`` objects passed directly in ``fields`` are searched.
            pk_field = next(
                (
                    f
                    for f in fields
                    if isinstance(f, Field) and f.name == _primarykey[0]
                ),
                None,
            )
            if pk_field is None:
                # Previously surfaced as an opaque IndexError.
                raise SyntaxError(
                    "primarykey must be a list of fields from table '%s'" % tablename
                )
            self._id = pk_field
    elif not any(
        (isinstance(f, Field) and f.type == "id")
        or (isinstance(f, dict) and f.get("type", None) == "id")
        for f in fields
    ):
        # No primary key and no explicit 'id' field: add one.
        field = Field("id", "id")
        newfields.append(field)
        fieldnames.add("id")
        self._id = field

    virtual_fields = []

    def include_new(field):
        newfields.append(field)
        fieldnames.add(field.name)
        if field.type == "id":
            self._id = field

    for field in fields:
        if isinstance(field, (FieldVirtual, FieldMethod)):
            virtual_fields.append(field)
        elif isinstance(field, Field) and field.name not in fieldnames:
            # A field already bound to a db is copied rather than shared.
            if field.db is not None:
                field = copy.copy(field)
            include_new(field)
        elif isinstance(field, (list, tuple)):
            for other in field:
                include_new(other)
        elif isinstance(field, Table):
            table = field
            for tfield in table:
                if tfield.name not in fieldnames and tfield.type != "id":
                    t2 = not table._actual and self._tablename
                    include_new(tfield.clone(point_self_references_to=t2))
        elif isinstance(field, dict) and field["fieldname"] not in fieldnames:
            include_new(Field(**field))
        elif not isinstance(field, (Field, Table)):
            raise SyntaxError(
                "define_table argument is not a Field, Table or list: %s" % field
            )
    fields = newfields
    self._fields = SQLCallableList()
    self.virtualfields = []

    if db and db._adapter.uploads_in_blob is True:
        uploadfields = [f.name for f in fields if f.type == "blob"]
        # Iterate over a snapshot: blob fields are appended to ``fields`` below.
        for field in list(fields):
            fn = field.uploadfield
            if (
                isinstance(field, Field)
                and field.type == "upload"
                and fn is True
                and not field.uploadfs
            ):
                fn = field.uploadfield = "%s_blob" % field.name
            if isinstance(fn, str) and fn not in uploadfields and not field.uploadfs:
                fields.append(
                    Field(fn, "blob", default="", writable=False, readable=False)
                )

    fieldnames_set = set()
    reserved = dir(Table) + ["fields"]
    if db and db._check_reserved:
        check_reserved_keyword = db.check_reserved_keyword
    else:

        def check_reserved_keyword(field_name):
            if field_name in reserved:
                raise SyntaxError("field name %s not allowed" % field_name)

    for field in fields:
        field_name = field.name
        check_reserved_keyword(field_name)
        if db and db._ignore_field_case:
            fname_item = field_name.lower()
        else:
            fname_item = field_name
        if fname_item in fieldnames_set:
            raise SyntaxError(
                "duplicate field %s in table %s" % (field_name, tablename)
            )
        fieldnames_set.add(fname_item)

        self.fields.append(field_name)
        self[field_name] = field
        if field.type == "id":
            self["id"] = field
        field.bind(self)
    self.ALL = SQLALL(self)

    if _primarykey is not None:
        for k in _primarykey:
            if k not in self.fields:
                raise SyntaxError(
                    "primarykey must be a list of fields from table '%s'"
                    % tablename
                )
            else:
                self[k].notnull = True
    for field in virtual_fields:
        self[field.name] = field
---|
463 | |
---|
@property
def fields(self):
    """The ``SQLCallableList`` of field names, in the order they were added."""
    names = self._fields
    return names
---|
467 | |
---|
def _structure(self):
    """Describe the readable, non-password fields as a list of dicts.

    Each dict maps the attribute names listed below to the field's value for
    that attribute, with callable values replaced by ``None``.
    """
    attribute_names = (
        "name",
        "type",
        "writable",
        "listable",
        "searchable",
        "regex",
        "options",
        "default",
        "label",
        "unique",
        "notnull",
        "required",
    )
    description = []
    for field in self:
        if not field.readable or field.type == "password":
            continue
        entry = {}
        for attr in attribute_names:
            value = getattr(field, attr)
            entry[attr] = None if callable(value) else value
        description.append(entry)
    return description
---|
492 | |
---|
@cachedprop
def _upload_fieldnames(self):
    """Set of the names of this table's ``"upload"`` fields (via ``cachedprop``)."""
    return {field.name for field in self if field.type == "upload"}
---|
496 | |
---|
def update(self, *args, **kwargs):
    """Block ``update()`` on tables.

    Raises:
        RuntimeError: always, whatever the arguments.
    """
    message = "Syntax Not Supported"
    raise RuntimeError(message)
---|
499 | |
---|
def _enable_record_versioning(
    self,
    archive_db=None,
    archive_name="%(tablename)s_archive",
    is_active="is_active",
    current_record="current_record",
    current_record_label=None,
    migrate=None,
    redefine=None,
):
    """Archive every record before it is updated (record versioning).

    If ``archive_db`` does not already have a table named
    ``archive_name % {"tablename": self._dalname}``, this method does two
    things:

    - defines that table with a ``current_record`` field plus a clone of
      every field of this table;
    - appends a ``_before_update`` hook that calls ``archive_record``.

    The ``current_record`` field references this table when both live in the
    same db, and is ``"bigint"`` otherwise.

    If ``is_active`` names a field of this table, deletes become updates
    setting that field to ``False``.  A common filter restricting queries to
    rows where it is ``True`` is also installed, combined with any existing
    filter.

    Args:
        archive_db: db holding the archive table (default: this table's db).
        archive_name: %-format string with a ``tablename`` key.
        is_active: name of the soft-delete field; falsy disables soft deletes.
        current_record: name of the archive field pointing at the live record.
        current_record_label: label for the ``current_record`` field.
        migrate: ``migrate`` option for the archive table (otherwise derived
            from ``self._migrate``).
        redefine: ``redefine`` option for the archive table, if truthy.
    """
    db = self._db
    archive_db = archive_db or db
    archive_name = archive_name % dict(tablename=self._dalname)
    if archive_name in archive_db.tables():
        return  # do not try define the archive if already exists
    fieldnames = self.fields()
    same_db = archive_db is db
    field_type = self if same_db else "bigint"
    clones = []
    for field in self:
        # Cross-db references cannot be kept as foreign keys.
        nfk = same_db or not field.type.startswith("reference")
        clones.append(
            field.clone(unique=False, type=field.type if nfk else "bigint")
        )

    d = dict(format=self._format)
    if migrate:
        d["migrate"] = migrate
    elif isinstance(self._migrate, basestring):
        d["migrate"] = self._migrate + "_archive"
    elif self._migrate:
        d["migrate"] = self._migrate
    if redefine:
        d["redefine"] = redefine
    archive_db.define_table(
        archive_name,
        Field(current_record, field_type, label=current_record_label),
        *clones,
        **d
    )

    self._before_update.append(
        lambda qset, fs, db=archive_db, an=archive_name, cn=current_record: archive_record(
            qset, fs, db[an], cn
        )
    )
    if is_active and is_active in fieldnames:
        # Use the configured field name: the previous code hard-coded
        # ``is_active`` and broke for any other name.
        def soft_delete(qset, fieldname=is_active):
            return qset.update(**{fieldname: False})

        self._before_delete.append(soft_delete)

        def newquery(query):
            return reduce(
                AND,
                [
                    getattr(tab, is_active) == True  # noqa: E712 (builds a Query)
                    for tab in db._adapter.tables(query).values()
                    if tab._raw_rname == self._raw_rname
                ],
            )

        previous_filter = self._common_filter
        if previous_filter:
            self._common_filter = lambda q: reduce(
                AND, [previous_filter(q), newquery(q)]
            )
        else:
            self._common_filter = newquery
---|
561 | |
---|
def _validate(self, **values):
    """Validate each keyword argument with the same-named field's validator.

    ``values.get("id")`` is passed to every ``validate`` call as the record id.

    Returns:
        Row: field name -> error for every value that failed validation
        (empty when everything is valid).
    """
    errors = Row()
    record_id = values.get("id")
    for name in values:
        _, error = getattr(self, name).validate(values[name], record_id)
        if error:
            errors[name] = error
    return errors
---|
569 | |
---|
def _create_references(self):
    """Resolve this table's ``reference`` / ``list:reference`` fields.

    For each field whose type is ``"reference <table>[.<field>]"`` or
    ``"list:reference <table>[.<field>]"``:

    - if ``<table>`` is not yet defined in ``db``, the field is queued in
      ``db._pending_references[<table>]``;
    - otherwise ``field.referent`` is set to the referenced field (the
      target table's ``_id`` when no field is named), and the field is
      appended to the target's ``_referenced_by`` (or
      ``_referenced_by_list``) and to this table's ``_references``.

    Non-reference fields get ``field.referent = None``.  Finally, fields of
    other tables queued while waiting for this table are moved into
    ``self._referenced_by`` / ``self._referenced_by_list``.

    Raises:
        SyntaxError: for a reference with no target, a ``table.field``
            reference to a table without ``_primarykey``, or a reference to
            a field the target table does not have.
    """
    db = self._db
    pr = db._pending_references
    self._referenced_by_list = []
    self._referenced_by = []
    self._references = []
    for field in self:
        field_type = field.type
        if isinstance(field_type, str) and (
            field_type.startswith("reference ")
            or field_type.startswith("list:reference ")
        ):
            is_list = field_type[:15] == "list:reference "
            if is_list:
                ref = field_type[15:].strip()
            else:
                ref = field_type[10:].strip()

            if not ref:
                # The exception used to be constructed but never raised,
                # silently queueing the field under an empty table name.
                raise SyntaxError("Table: reference to nothing: %s" % field.name)
            # rfieldname is "" (falsy) when no ".field" part is given.
            rtablename, _, rfieldname = ref.partition(".")
            if rtablename not in db:
                pr[rtablename] = pr.get(rtablename, []) + [field]
                continue
            rtable = db[rtablename]
            if rfieldname:
                if not hasattr(rtable, "_primarykey"):
                    raise SyntaxError(
                        "keyed tables can only reference other keyed tables (for now)"
                    )
                if rfieldname not in rtable.fields:
                    raise SyntaxError(
                        "invalid field '%s' for referenced table '%s'"
                        " in table '%s'" % (rfieldname, rtablename, self._tablename)
                    )
                rfield = rtable[rfieldname]
            else:
                rfield = rtable._id
            if is_list:
                rtable._referenced_by_list.append(field)
            else:
                rtable._referenced_by.append(field)
            field.referent = rfield
            self._references.append(field)
        else:
            field.referent = None
    if self._tablename in pr:
        referees = pr.pop(self._tablename)
        for referee in referees:
            if referee.type.startswith("list:reference "):
                self._referenced_by_list.append(referee)
            else:
                self._referenced_by.append(referee)
---|
628 | |
---|
def _filter_fields(self, record, id=False):
    """Return the items of ``record`` whose keys are fields of this table.

    Fields of type ``"id"`` are left out unless ``id`` is true.
    """
    kept = {}
    for name, value in iteritems(record):
        if name not in self.fields:
            continue
        if getattr(self, name).type == "id" and not id:
            continue
        kept[name] = value
    return kept
---|
637 | |
---|
def _build_query(self, key):
    """Build an AND-ed equality query from a ``{field_name: value}`` dict.

    For keyed tables only: every key must be part of ``self._primarykey``.
    Returns ``None`` when ``key`` is empty.

    Raises:
        SyntaxError: if a key is not part of the primary key.
    """
    query = None
    for name, value in iteritems(key):
        if name not in self._primarykey:
            raise SyntaxError(
                "Field %s is not part of the primary key of %s"
                % (name, self._tablename)
            )
        condition = getattr(self, name) == value
        query = query & condition if query else condition
    return query
---|
653 | |
---|
def __getitem__(self, key):
    """Fetch a record or a table attribute, depending on ``key``.

    - a non-negative integer-like key (or a GAE ``Key``) selects the record
      with that id (first row of the select, via ``.first()``);
    - a dict selects a record of a keyed table by primary-key values;
    - ``None`` returns ``None``;
    - any other key returns the table attribute of that name.

    Raises:
        KeyError: when ``key`` is not an attribute of the table.
    """
    if str(key).isdigit() or (Key is not None and isinstance(key, Key)):
        # non negative key or gae
        return (
            self._db(self._id == str(key))
            .select(limitby=(0, 1), orderby_on_limitby=False)
            .first()
        )
    if isinstance(key, dict):
        # keyed table
        query = self._build_query(key)
        return (
            self._db(query).select(limitby=(0, 1), orderby_on_limitby=False).first()
        )
    if key is None:
        return None
    try:
        return getattr(self, key)
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit.
        raise KeyError(key)
---|
673 | |
---|
def __call__(self, key=DEFAULT, **kwargs):
    """Return the first record matching ``key`` and/or field values.

    ``key`` may be a ``Query`` or a non-negative integer-like id; any other
    key gives ``None``.  Remaining keyword arguments are ``field=value``
    constraints.  The special ``_for_update`` and ``_orderby`` keywords are
    forwarded to ``select`` instead.

    With a ``key``, a fetched record that violates any constraint yields
    ``None``.  Without one, the constraints are AND-ed into the query, and
    no constraints at all yields ``None``.
    """
    for_update = kwargs.pop("_for_update", False)
    orderby = kwargs.pop("_orderby", None)
    select_args = dict(
        limitby=(0, 1),
        for_update=for_update,
        orderby=orderby,
        orderby_on_limitby=False,
    )

    if key is DEFAULT:
        if not kwargs:
            return None
        query = reduce(
            lambda a, b: a & b,
            [getattr(self, k) == v for k, v in iteritems(kwargs)],
        )
        return self._db(query).select(**select_args).first()

    if isinstance(key, Query):
        record = self._db(key).select(**select_args).first()
    elif str(key).isdigit():
        record = self._db(self._id == key).select(**select_args).first()
    else:
        record = None

    if record:
        for k, v in iteritems(kwargs):
            if record[k] != v:
                return None
    return record
---|
730 | |
---|
def __setitem__(self, key, value):
    """Dict-style writes, overloaded on the key.

    - ``table[None] = values``: insert a record;
    - ``table[id] = values`` (non-negative integer-like id): update that
      record;
    - ``table[{pk: ...}] = values`` (keyed tables): try an insert and, if it
      returns a falsy result, update the record matching the key;
    - anything else stores ``value`` as attribute ``str(key)``, binding
      ``FieldVirtual`` / ``FieldMethod`` objects to the table first.

    Record values are filtered through ``_filter_fields``.

    Raises:
        SyntaxError: when the id matches no record, ``value`` is not a dict
            for a keyed write, or the key does not cover the primary key.
    """
    if key is None:
        # table[None] = value (shortcut for insert)
        self.insert(**self._filter_fields(value))
        return
    if str(key).isdigit():
        # table[non negative key] = value (shortcut for update)
        if not self._db(self._id == key).update(**self._filter_fields(value)):
            raise SyntaxError("No such record: %s" % key)
        return
    if isinstance(key, dict):
        # keyed table
        if not isinstance(value, dict):
            raise SyntaxError("value must be a dictionary: %s" % value)
        if set(key.keys()) != set(self._primarykey):
            raise SyntaxError(
                "key must have all fields from primary key: %s" % self._primarykey
            )
        value = self._filter_fields(value)
        kv = dict(value)
        kv.update(key)
        if not self.insert(**kv):
            query = self._build_query(key)
            self._db(query).update(**self._filter_fields(value))
        return
    name = str(key)
    if isinstance(value, FieldVirtual):
        value.bind(self, name)
        self._virtual_fields.append(value)
    elif isinstance(value, FieldMethod):
        value.bind(self, name)
        self._virtual_methods.append(value)
    self.__dict__[name] = value
---|
763 | |
---|
def __setattr__(self, key, value):
    """Route attribute assignment through ``self[key] = value``.

    Public (non-underscore) names that already exist in the table cannot be
    rebound.

    Raises:
        SyntaxError: on an attempt to redefine an existing public name.
    """
    is_public = not key.startswith("_")
    if is_public and key in self:
        raise SyntaxError("Object exists and cannot be redefined: %s" % key)
    self[key] = value
---|
768 | |
---|
def __delitem__(self, key):
    """Delete a record by id, or by primary-key dict for keyed tables.

    Raises:
        SyntaxError: when nothing was deleted, or the key is neither a dict
            nor a non-negative integer-like id.
    """
    if isinstance(key, dict):
        deleted = self._db(self._build_query(key)).delete()
    elif str(key).isdigit():
        deleted = self._db(self._id == key).delete()
    else:
        deleted = False
    if not deleted:
        raise SyntaxError("No such record: %s" % key)
---|
776 | |
---|
def __iter__(self):
    """Yield the table's field objects, in ``self.fields`` order."""
    for name in self.fields:
        field = getattr(self, name)
        yield field
---|
780 | |
---|
def __repr__(self):
    """``<Table name (field1, field2, ...)>``."""
    field_list = ", ".join(self.fields())
    return "<Table %s (%s)>" % (self._tablename, field_list)
---|
783 | |
---|
def __str__(self):
    """The table name, or the dialect's alias expression when aliased.

    The alias expression is ``dialect._as(self._dalname, self._tablename)``.
    """
    if self._tablename != self._dalname:
        return self._db._adapter.dialect._as(self._dalname, self._tablename)
    return self._tablename
---|
788 | |
---|
@property
@deprecated("sqlsafe", "sql_shortref", "Table")
def sqlsafe(self):
    """Deprecated alias of ``sql_shortref``."""
    shortref = self.sql_shortref
    return shortref
---|
793 | |
---|
@property
@deprecated("sqlsafe_alias", "sql_fullref", "Table")
def sqlsafe_alias(self):
    """Deprecated alias of ``sql_fullref``."""
    fullref = self.sql_fullref
    return fullref
---|
798 | |
---|
@property
def sql_shortref(self):
    """SQL reference to the table.

    This is ``self._rname`` for a non-aliased table.  For an aliased one it
    is the adapter's ``sqlsafe_table`` applied to the alias name.
    """
    if self._tablename != self._dalname:
        return self._db._adapter.sqlsafe_table(self._tablename)
    return self._rname
---|
804 | |
---|
@property
def sql_fullref(self):
    """SQL reference to the table including its real name.

    There are three cases:

    - aliased table: the adapter's ``sqlsafe_table(alias, rname)``;
    - non-aliased table on Oracle: ``self._rname`` quoted again by the
      dialect;
    - otherwise: ``self._rname`` as-is.
    """
    adapter = self._db._adapter
    if self._tablename != self._dalname:
        return adapter.sqlsafe_table(self._tablename, self._rname)
    if adapter.dbengine == "oracle":
        return adapter.dialect.quote(self._rname)
    return self._rname
---|
812 | |
---|
def query_name(self, *args, **kwargs):
    """One-element tuple holding ``self.sql_fullref``; arguments are ignored."""
    fullref = self.sql_fullref
    return (fullref,)
---|
815 | |
---|
def _drop(self, mode=""):
    """Return the dialect's ``drop_table(self, mode)`` result.

    NOTE(review): presumably the DROP SQL without executing it (compare
    ``drop``); confirm against the dialect.
    """
    dialect = self._db._adapter.dialect
    return dialect.drop_table(self, mode)
---|
818 | |
---|
def drop(self, mode=""):
    """Drop the table through the adapter's ``drop_table``; returns its result."""
    adapter = self._db._adapter
    return adapter.drop_table(self, mode)
---|
821 | |
---|
def _filter_fields_for_operation(self, fields):
    """Split the table's fields into supplied and missing ones.

    Returns:
        (missing, given): ``missing`` lists the names of table fields absent
        from ``fields``, in table order.  ``given`` maps each supplied table
        field name to ``(field, value)``, where ``value`` has been passed
        through the field's ``filter_in`` when it has one.  Keys of
        ``fields`` that are not table fields are ignored.
    """
    supplied = set(fields)
    missing = []
    given = {}
    for name in self.fields:
        if name not in supplied:
            missing.append(name)
            continue
        field = getattr(self, name)
        raw = fields[name]
        given[name] = (field, field.filter_in(raw) if field.filter_in else raw)
    return missing, given
---|
833 | |
---|
def _compute_fields_for_operation(self, fields, to_compute):
    """Build the ``OpRow`` for an insert or update.

    ``fields`` maps name -> ``(field, value)``.  A value that is a plain
    function, method or builtin is called with no arguments to obtain the
    actual value.  Then each ``(name, field)`` of ``to_compute`` is set to
    ``field.compute(row)``.  A ``KeyError``/``AttributeError`` while
    computing or setting is ignored, unless the field is required and was
    not supplied.

    Raises:
        RuntimeError: when a required computed field cannot be computed.
    """
    callable_types = (
        types.LambdaType,
        types.FunctionType,
        types.MethodType,
        types.BuiltinFunctionType,
        types.BuiltinMethodType,
    )
    row = OpRow(self)
    for name, (field, value) in iteritems(fields):
        if isinstance(value, callable_types):
            value = value()
        row.set_value(name, value, field)
    for name, field in to_compute:
        try:
            row.set_value(name, field.compute(row), field)
        except (KeyError, AttributeError):
            # error silently unless field is required!
            if field.required and name not in fields:
                raise RuntimeError("unable to compute required field: %s" % name)
    return row
---|
858 | |
---|
def _fields_and_values_for_insert(self, fields):
    """Prepare the ``OpRow`` for an insert.

    For each table field missing from ``fields``: a computed field is queued
    for computation; otherwise a non-``None`` ``default`` is used.

    Raises:
        RuntimeError: when a required field is missing and has neither
            ``compute`` nor a non-``None`` default.
    """
    missing, values = self._filter_fields_for_operation(fields)
    pending = []
    for name in missing:
        field = getattr(self, name)
        if field.compute:
            pending.append((name, field))
            continue
        if field.default is not None:
            values[name] = (field, field.default)
            continue
        if field.required:
            raise RuntimeError("Table: missing required field: %s" % name)
    return self._compute_fields_for_operation(values, pending)
---|
871 | |
---|
def _fields_and_values_for_update(self, fields):
    """Prepare the ``OpRow`` for an update.

    For each table field missing from ``fields``:

    - a non-``None`` ``update`` value is used;
    - a computed field is also queued for computation.

    A field may match both rules.
    """
    missing, values = self._filter_fields_for_operation(fields)
    pending = []
    for name in missing:
        field = getattr(self, name)
        if field.update is not None:
            values[name] = (field, field.update)
        if field.compute:
            pending.append((name, field))
    return self._compute_fields_for_operation(values, pending)
---|
882 | |
---|
def _insert(self, **fields):
    """Return the adapter's ``_insert`` result for the prepared values.

    NOTE(review): presumably the INSERT SQL without executing it; confirm
    against the adapter.
    """
    op_values = self._fields_and_values_for_insert(fields).op_values()
    return self._db._adapter._insert(self, op_values)
---|
886 | |
---|
def insert(self, **fields):
    """Insert one record built from *fields*; return the adapter's result.

    Any ``_before_insert`` callback returning a truthy value aborts the
    insert and 0 is returned. When the adapter returns a truthy result,
    each ``_after_insert`` callback is called as ``f(row, result)``.
    """
    op_row = self._fields_and_values_for_insert(fields)
    for hook in self._before_insert:
        if hook(op_row):
            return 0
    new_id = self._db._adapter.insert(self, op_row.op_values())
    if new_id:
        for hook in self._after_insert:
            hook(op_row, new_id)
    return new_id
---|
896 | |
---|
def _validate_fields(self, fields, defattr="default", id=None):
    """Validate *fields* against every field of the table.

    :param fields: mapping of field name -> value supplied by the caller.
    :param defattr: name of the Field attribute ("default" or "update")
        used as the value of non-required, non-computed fields that were not
        supplied; a callable is called to obtain it.
    :param id: record id forwarded to ``field.validate`` (set by updates).
    :returns: ``(response, new_fields)``; ``response.errors`` maps field
        names to error strings, ``response.id`` is None, and ``new_fields``
        holds the validated values of the fields that were actually supplied.
    """
    response = Row()
    response.id, response.errors, new_fields = None, Row(), Row()
    for field in self:
        # we validate even if not passed in case it is required
        error = default = None
        if not field.required and not field.compute:
            default = getattr(field, defattr)
            if callable(default):
                default = default()
        # Bind ``value`` on every iteration: a supplied value for a computed
        # field used to be written with whatever ``value`` the previous field
        # left behind (or raised NameError when it came first).
        value = fields.get(field.name, default)
        if not field.compute:
            value, error = field.validate(value, id)
        if error:
            response.errors[field.name] = "%s" % error
        elif field.name in fields:
            # only write if the field was passed and no error
            new_fields[field.name] = value
    return response, new_fields
---|
916 | |
---|
def validate_and_insert(self, **fields):
    """Validate *fields* and insert them only if no validator failed.

    Returns the validation response. On success its ``id`` is set to the
    value returned by ``insert``; on failure ``errors`` is non-empty and
    nothing is inserted.
    """
    response, clean = self._validate_fields(fields, "default")
    if response.errors:
        return response
    response.id = self.insert(**clean)
    return response
---|
922 | |
---|
def validate_and_update(self, _key, **fields):
    """Validate *fields* and update the record identified by *_key*.

    :param _key: a dict of field values, or a single key accepted by
        ``self(...)``, identifying the record.
    :returns: the validation response. ``response.updated`` holds the result
        of the update when it ran; ``response.id`` is the record id when the
        record exists and None otherwise.
    """
    record = self(**_key) if isinstance(_key, dict) else self(_key)
    # A missing record used to crash here on ``record.id`` even though the
    # code below explicitly handles a falsy record.
    record_id = record.id if record else None
    response, new_fields = self._validate_fields(fields, "update", record_id)
    #: do the update
    if not response.errors and record:
        if "_id" in self:
            myset = self._db(self._id == record[self._id.name])
        else:
            # keyed table: rebuild the lookup condition from the key fields
            query = None
            for key, value in iteritems(_key):
                if query is None:
                    query = getattr(self, key) == value
                else:
                    query = query & (getattr(self, key) == value)
            myset = self._db(query)
        response.updated = myset.update(**new_fields)
    if record:
        response.id = record_id
    return response
---|
942 | |
---|
def update_or_insert(self, _key=DEFAULT, **values):
    """Update the record matching the key with *values*, else insert them.

    The record is looked up with ``self(**values)`` when *_key* is omitted,
    ``self(**_key)`` when *_key* is a dict, and ``self(_key)`` otherwise.
    Returns the result of ``insert`` after an insert, None after an update.
    """
    if _key is DEFAULT:
        lookup_args, lookup_kwargs = (), values
    elif isinstance(_key, dict):
        lookup_args, lookup_kwargs = (), _key
    else:
        lookup_args, lookup_kwargs = (_key,), {}
    record = self(*lookup_args, **lookup_kwargs)
    if not record:
        return self.insert(**values)
    record.update_record(**values)
    return None
---|
956 | |
---|
def validate_and_update_or_insert(self, _key=DEFAULT, **fields):
    """Validate *fields*, then update the matching record or insert a new one.

    When *_key* is omitted (or ""), the record is located by the supplied
    primary-key fields if the table declares ``_primarykey``, otherwise by
    the supplied values of required fields. A dict *_key* is used as
    ``self(**_key)``, anything else as ``self(_key)``.

    Returns the response of ``validate_and_update`` or
    ``validate_and_insert``. For a table with ``_primarykey``, an updated
    record's ``response.id`` is a dict of its primary-key values.
    """
    # ``_primarykey`` may be absent on ordinary id tables (hence the former
    # hasattr check further down); reading it unguarded raised AttributeError.
    primarykey = getattr(self, "_primarykey", None) or []
    if _key is DEFAULT or _key == "":
        primary_keys = {
            key: value for key, value in iteritems(fields) if key in primarykey
        }
        if primary_keys:
            _key = primary_keys
        else:
            _key = {
                key: value
                for key, value in iteritems(fields)
                if getattr(self, key).required
            }
        record = self(**_key)
    elif isinstance(_key, dict):
        record = self(**_key)
    else:
        record = self(_key)

    if record:
        response = self.validate_and_update(_key, **fields)
        if primarykey:
            response.id = {key: getattr(record, key) for key in primarykey}
    else:
        response = self.validate_and_insert(**fields)
    return response
---|
988 | |
---|
def bulk_insert(self, items):
    """
    Insert several records in one adapter call.

    :param items: list of dictionaries, one per record.
    :returns: the adapter's ``bulk_insert`` result, or 0 when any
        ``_before_insert`` callback returns a truthy value for any row.
        When the result is truthy, every ``_after_insert`` callback is called
        as ``f(row, ret[k])`` for the k-th row.
    """
    data = [self._fields_and_values_for_insert(item) for item in items]
    if any(f(el) for el in data for f in self._before_insert):
        return 0
    ret = self._db._adapter.bulk_insert(self, [el.op_values() for el in data])
    if ret:
        # Plain loops instead of a list comprehension evaluated only for its
        # side effects; the call order (callback-major) is unchanged.
        for f in self._after_insert:
            for k, el in enumerate(data):
                f(el, ret[k])
    return ret
---|
1001 | |
---|
1002 | def _truncate(self, mode=""): |
---|
1003 | return self._db._adapter.dialect.truncate(self, mode) |
---|
1004 | |
---|
def truncate(self, mode=""):
    """Truncate this table through the adapter, passing *mode* unchanged."""
    adapter = self._db._adapter
    return adapter.truncate(self, mode)
---|
1007 | |
---|
def import_from_csv_file(
    self,
    csvfile,
    id_map=None,
    null="<NULL>",
    unique="uuid",
    id_offset=None,  # id_offset used only when id_map is None
    transform=None,
    validate=False,
    encoding="utf-8",
    **kwargs
):
    """
    Import records from csv file.
    Column headers must have same names as table fields.
    The 'id' column is not inserted; its value is only used to map ids
    (see 'id_map') or to reproduce them (see 'id_offset').
    If column names read 'table.file' the 'table.' prefix is ignored.

    - 'unique' argument is a field which must be unique (typically a
      uuid field); a row whose value matches an existing record updates
      that record instead of inserting a new one
    - 'restore' argument is default False; if set True will remove old values
      in table first.
    - 'id_map' if set to None will not map ids; when it is a dict, the
      mapping csv id -> new id of this table is stored in
      ``id_map[tablename]`` and used to translate reference fields
    - 'id_offset' (used only when id_map is None and there is no unique
      column) is a dict tablename -> offset; records are then re-created
      until the new id equals the csv id plus the offset, so the csv id
      numbers are kept. This assumes an integer id field in incrementing
      order.
    - 'transform' optional callable applied to each row dict before import
    - 'validate' if True rows are inserted with validate_and_insert; rows
      failing validation are not inserted
    - 'delimiter', 'quotechar', 'quoting' keyword arguments are passed to
      the csv reader

    The import stops at the first empty line; the db is committed every
    1000 lines.
    """
    if validate:

        def inserting(**values):
            # validate_and_insert returns a response Row, not the new id;
            # unwrap it so id_map / id_offset bookkeeping gets the id (None
            # when validation failed and nothing was inserted).
            return self.validate_and_insert(**values).id

    else:
        inserting = self.insert

    delimiter = kwargs.get("delimiter", ",")
    quotechar = kwargs.get("quotechar", '"')
    quoting = kwargs.get("quoting", csv.QUOTE_MINIMAL)
    restore = kwargs.get("restore", False)
    if restore:
        self._db[self].truncate()

    reader = csv_reader(
        csvfile,
        delimiter=delimiter,
        encoding=encoding,
        quotechar=quotechar,
        quoting=quoting,
    )
    colnames = None
    if isinstance(id_map, dict):
        if self._tablename not in id_map:
            id_map[self._tablename] = {}
        id_map_self = id_map[self._tablename]

    def fix(field, value, id_map, id_offset):
        # convert one csv string into the python value for *field*
        list_reference_s = "list:reference"
        if value == null:
            value = None
        elif field.type == "blob":
            value = base64.b64decode(value)
        elif field.type == "double" or field.type == "float":
            if not value.strip():
                value = None
            else:
                value = float(value)
        elif field.type in ("integer", "bigint"):
            if not value.strip():
                value = None
            else:
                value = long(value)
        elif field.type.startswith("list:string"):
            value = bar_decode_string(value)
        elif field.type.startswith(list_reference_s):
            ref_table = field.type[len(list_reference_s) :].strip()
            if id_map is not None:
                value = [
                    id_map[ref_table][long(v)] for v in bar_decode_string(value)
                ]
            else:
                value = [v for v in bar_decode_string(value)]
        elif field.type.startswith("list:"):
            value = bar_decode_integer(value)
        elif id_map and field.type.startswith("reference"):
            try:
                value = id_map[field.type[9:].strip()][long(value)]
            except KeyError:
                pass
        elif id_offset and field.type.startswith("reference"):
            try:
                value = id_offset[field.type[9:].strip()] + long(value)
            except KeyError:
                pass
        return value

    def is_id(colname):
        if colname in self:
            return getattr(self, colname).type == "id"
        else:
            return False

    first = True
    unique_idx = None
    for lineno, line in enumerate(reader):
        if not line:
            return
        if not colnames:
            # assume this is the first line of the input, contains colnames
            colnames = [x.split(".", 1)[-1] for x in line]

            cid = None
            for i, colname in enumerate(colnames):
                if is_id(colname):
                    cid = colname
                # only a real field can be queried for uniqueness below
                if colname == unique and colname in self.fields:
                    unique_idx = i
        elif len(line) == len(colnames):
            # every other line contains instead data
            items = dict(zip(colnames, line))
            if transform:
                items = transform(items)

            ditems = dict()
            csv_id = None
            for field in self:
                fieldname = field.name
                if fieldname in items:
                    try:
                        value = fix(field, items[fieldname], id_map, id_offset)
                        if field.type != "id":
                            ditems[fieldname] = value
                        else:
                            csv_id = long(value)
                    except ValueError:
                        raise RuntimeError("Unable to parse line:%s" % (lineno + 1))
            # unique_idx may legitimately be 0: compare with None, not by
            # truthiness, or a first-column unique field would be ignored
            if not (
                id_map
                or csv_id is None
                or id_offset is None
                or unique_idx is not None
            ):
                curr_id = inserting(**ditems)
                if first:
                    first = False
                    # First curr_id is bigger than csv_id,
                    # then we are not restoring but
                    # extending db table with csv db table
                    id_offset[self._tablename] = (
                        (curr_id - csv_id) if curr_id > csv_id else 0
                    )
                # create new id until we get the same as old_id+offset
                while curr_id < csv_id + id_offset[self._tablename]:
                    self._db(getattr(self, cid) == curr_id).delete()
                    curr_id = inserting(**ditems)
            # Validation. Check for duplicate of 'unique' &,
            # if present, update instead of insert.
            elif unique_idx is None:
                new_id = inserting(**ditems)
            else:
                unique_value = line[unique_idx]
                query = getattr(self, unique) == unique_value
                record = self._db(query).select().first()
                if record:
                    record.update_record(**ditems)
                    new_id = record[self._id.name]
                else:
                    new_id = inserting(**ditems)
            if id_map and csv_id is not None:
                id_map_self[csv_id] = new_id
        if lineno % 1000 == 999:
            self._db.commit()
---|
1175 | |
---|
def as_dict(self, flat=False, sanitize=True):
    """Describe this table as a plain dict.

    The ``fields`` entry lists ``field.as_dict(flat=..., sanitize=...)`` for
    every field; with *sanitize* (the default) fields that are neither
    readable nor writable are left out.
    """
    description = dict(
        tablename=str(self),
        fields=[],
        sequence_name=self._sequence_name,
        trigger_name=self._trigger_name,
        common_filter=self._common_filter,
        format=self._format,
        singular=self._singular,
        plural=self._plural,
    )
    description["fields"].extend(
        field.as_dict(flat=flat, sanitize=sanitize)
        for field in self
        if field.readable or field.writable or not sanitize
    )
    return description
---|
1194 | |
---|
def with_alias(self, alias):
    """Return this table under the name *alias*.

    If the db already resolves *alias* to a table with the same ``_rname``,
    that table is returned. Otherwise a shallow copy is made, every field is
    cloned and bound to the copy, and the copy is registered as
    ``db._aliased_tables.<alias>``.
    """
    try:
        existing = self._db[alias]
        if existing._rname == self._rname:
            return existing
    except AttributeError:  # we never used this alias
        pass
    aliased = copy.copy(self)
    aliased["ALL"] = SQLALL(aliased)
    aliased["_tablename"] = alias
    for fieldname in aliased.fields:
        cloned = getattr(self, fieldname).clone()
        cloned.bind(aliased)
        aliased[fieldname] = cloned
    # keep an "id" shortcut when the id field has another name
    if "id" in self and "id" not in aliased.fields:
        aliased["id"] = aliased[self.id.name]
    aliased._id = aliased[self._id.name]
    setattr(self._db._aliased_tables, alias, aliased)
    return aliased
---|
1213 | |
---|
def on(self, query):
    """Return an expression joining this table ON *query* (dialect ``on``)."""
    join_op = self._db._adapter.dialect.on
    return Expression(self._db, join_op, self, query)
---|
1216 | |
---|
def create_index(self, name, *fields, **kwargs):
    """Create index *name* over *fields* on this table via the adapter."""
    adapter = self._db._adapter
    return adapter.create_index(self, name, *fields, **kwargs)
---|
1219 | |
---|
def drop_index(self, name):
    """Drop index *name* of this table via the adapter."""
    adapter = self._db._adapter
    return adapter.drop_index(self, name)
---|
1222 | |
---|
1223 | |
---|
class Select(BasicStorage):
    """A SELECT statement that can be used like a table (a subquery).

    Selected Fields, and Expressions aliased with ``with_alias``, are cloned
    and bound to this object, which then exposes them as attributes the way
    a Table does. It must be aliased (``with_alias``) before it can be used
    in a JOIN; calling the object executes the query.
    """

    def __init__(self, db, query, fields, attributes):
        self._db = db
        self._tablename = None  # alias will be stored here
        self._rname = self._raw_rname = self._dalname = None
        self._common_filter = None
        self._query = query
        # if false, the subquery will never reference tables from parent scope
        self._correlated = attributes.pop("correlated", True)
        self._attributes = attributes
        self._qfields = list(fields)
        self._fields = SQLCallableList()
        self._virtual_fields = []
        self._virtual_methods = []
        self.virtualfields = []
        self._sql_cache = None
        self._colnames_cache = None
        fieldcheck = set()

        # iterate the stored list, not ``fields``: if ``fields`` is a one-shot
        # iterable it has already been consumed by ``list(fields)`` above
        for item in self._qfields:
            if isinstance(item, Field):
                checkname = item.name
                field = item.clone()
            elif isinstance(item, Expression):
                # only aliased expressions ("expr AS name") become columns
                if item.op != item._dialect._as:
                    continue
                checkname = item.second
                field = Field(item.second, type=item.type)
            else:
                raise SyntaxError("Invalid field in Select")
            if db and db._ignore_field_case:
                checkname = checkname.lower()
            if checkname in fieldcheck:
                raise SyntaxError("duplicate field %s in select query" % field.name)
            fieldcheck.add(checkname)
            field.bind(self)
            self.fields.append(field.name)
            self[field.name] = field
        self.ALL = SQLALL(self)

    @property
    def fields(self):
        # names of the columns exposed by this select
        return self._fields

    def update(self, *args, **kwargs):
        raise RuntimeError("update() method not supported")

    def __getitem__(self, key):
        try:
            return getattr(self, key)
        except AttributeError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        self.__dict__[str(key)] = value

    def __call__(self):
        """Execute the select and return the adapter's result."""
        adapter = self._db._adapter
        colnames, sql = self._compile()
        # give the adapter its own copy so it cannot alter the options kept
        # on this Select for later calls
        attributes = self._attributes.copy()
        cache = attributes.get("cache", None)
        if cache and attributes.get("cacheable", False):
            # pass the field objects, as the uncached path does; ``_fields``
            # only holds field names
            return adapter._cached_select(
                cache, sql, self._qfields, attributes, colnames
            )
        return adapter._select_aux(sql, self._qfields, attributes, colnames)

    def __setattr__(self, key, value):
        # public names that already exist (e.g. fields) cannot be rebound
        if key[:1] != "_" and key in self:
            raise SyntaxError("Object exists and cannot be redefined: %s" % key)
        self[key] = value

    def __iter__(self):
        for fieldname in self.fields:
            yield self[fieldname]

    def __repr__(self):
        return "<Select (%s)>" % ", ".join(map(str, self._qfields))

    def __str__(self):
        return self._compile(with_alias=(self._tablename is not None))[1]

    def with_alias(self, alias):
        """Return a copy of this select named *alias*, fields re-bound to it."""
        other = copy.copy(self)
        other["ALL"] = SQLALL(other)
        other["_tablename"] = alias
        for fieldname in other.fields:
            tmp = self[fieldname].clone()
            tmp.bind(other)
            other[fieldname] = tmp
        return other

    def on(self, query):
        if not self._tablename:
            raise SyntaxError("Subselect must be aliased for use in a JOIN")
        return Expression(self._db, self._db._adapter.dialect.on, self, query)

    def _compile(self, outer_scoped=None, with_alias=False):
        """Return ``(colnames, sql)`` for this select.

        :param outer_scoped: tables of the enclosing query that this subquery
            may reference (ignored when the select is not correlated).
        :param with_alias: when true and the select is aliased, return the
            SQL wrapped in parentheses and aliased.
        """
        # a fresh list per call: the former ``[]`` default was one shared
        # list object that was handed on to the adapter
        if outer_scoped is None or not self._correlated:
            outer_scoped = []
        if outer_scoped or not self._sql_cache:
            adapter = self._db._adapter
            attributes = self._attributes.copy()
            attributes["outer_scoped"] = outer_scoped
            colnames, sql = adapter._select_wcols(
                self._query, self._qfields, **attributes
            )
            # Do not cache when the query may depend on external tables
            if not outer_scoped:
                self._colnames_cache, self._sql_cache = colnames, sql
        else:
            colnames, sql = self._colnames_cache, self._sql_cache
        if with_alias and self._tablename is not None:
            # drop the final character (presumably the trailing ";") before
            # wrapping the statement as a subquery
            sql = "(%s)" % sql[:-1]
            sql = self._db._adapter.dialect.alias(sql, self._tablename)
        return colnames, sql

    def query_name(self, outer_scoped=None):
        if self._tablename is None:
            raise SyntaxError("Subselect must be aliased for use in a JOIN")
        sql = self._compile(outer_scoped, True)[1]
        # This method should also return list of placeholder values
        # in the future
        return (sql,)

    @property
    def sql_shortref(self):
        if self._tablename is None:
            raise SyntaxError("Subselect must be aliased for use in a JOIN")
        return self._db._adapter.dialect.quote(self._tablename)

    def _filter_fields(self, record, id=False):
        # keep only this select's fields; "id"-typed ones only when *id*
        return {
            k: v
            for (k, v) in iteritems(record)
            if k in self.fields and (self[k].type != "id" or id)
        }
---|
1362 | |
---|
1363 | |
---|
1364 | def _expression_wrap(wrapper): |
---|
1365 | def wrap(self, *args, **kwargs): |
---|
1366 | return wrapper(self, *args, **kwargs) |
---|
1367 | |
---|
1368 | return wrap |
---|
1369 | |
---|
1370 | |
---|
1371 | class Expression(object): |
---|
# Registry of extra expression methods (method name -> function); __new__
# attaches each entry to the class. Presumably filled in by dialect code
# elsewhere — not visible here.
_dialect_expressions_ = {}

def __new__(cls, *args, **kwargs):
    """Install the registered dialect expressions on *cls*, then allocate.

    The registry is re-applied on every instantiation, so entries added at
    any time before construction are available as methods. Constructor
    arguments are left for ``__init__``.
    """
    registry = cls._dialect_expressions_
    for method_name, impl in iteritems(registry):
        setattr(cls, method_name, _expression_wrap(impl))
    return super(Expression, cls).__new__(cls)
---|
1379 | |
---|
1380 | def __init__(self, db, op, first=None, second=None, type=None, **optional_args): |
---|
1381 | self.db = db |
---|
1382 | self.op = op |
---|
1383 | self.first = first |
---|
1384 | self.second = second |
---|
1385 | self._table = getattr(first, "_table", None) |
---|
1386 | if not type and first and hasattr(first, "type"): |
---|
1387 | self.type = first.type |
---|
1388 | else: |
---|
1389 | self.type = type |
---|
1390 | if isinstance(self.type, str): |
---|
1391 | self._itype = REGEX_TYPE.match(self.type).group(0) |
---|
1392 | else: |
---|
1393 | self._itype = None |
---|
1394 | self.optional_args = optional_args |
---|
1395 | |
---|
1396 | @property |
---|
1397 | def _dialect(self): |
---|
1398 | return self.db._adapter.dialect |
---|
1399 | |
---|
def sum(self):
    """SUM of this expression (dialect ``aggregate``), same type."""
    return Expression(self.db, self._dialect.aggregate, first=self, second="SUM", type=self.type)

def max(self):
    """MAX of this expression (dialect ``aggregate``), same type."""
    return Expression(self.db, self._dialect.aggregate, first=self, second="MAX", type=self.type)

def min(self):
    """MIN of this expression (dialect ``aggregate``), same type."""
    return Expression(self.db, self._dialect.aggregate, first=self, second="MIN", type=self.type)

def len(self):
    """Length of this expression (dialect ``length``) as an integer."""
    return Expression(self.db, self._dialect.length, first=self, type="integer")

def avg(self):
    """AVG of this expression (dialect ``aggregate``), same type."""
    return Expression(self.db, self._dialect.aggregate, first=self, second="AVG", type=self.type)

def abs(self):
    """ABS of this expression, same type.

    Rendered through the dialect's ``aggregate`` helper with "ABS" as the
    function name (presumably producing ``ABS(expr)``).
    """
    return Expression(self.db, self._dialect.aggregate, first=self, second="ABS", type=self.type)
---|
1417 | |
---|
def cast(self, cast_as, **kwargs):
    """CAST this expression to the DAL type *cast_as*.

    The SQL type comes from the dialect's ``types`` table, formatted with
    ``% kwargs`` (so *kwargs* fill any named placeholders of that template).
    """
    sql_type = self._dialect.types[cast_as] % kwargs
    return Expression(self.db, self._dialect.cast, first=self, second=sql_type, type=cast_as)
---|
1426 | |
---|
def lower(self):
    """Lower-cased version of this expression (dialect ``lower``)."""
    return Expression(self.db, self._dialect.lower, first=self, type=self.type)

def upper(self):
    """Upper-cased version of this expression (dialect ``upper``)."""
    return Expression(self.db, self._dialect.upper, first=self, type=self.type)

def replace(self, a, b):
    """This expression with occurrences of *a* replaced by *b*."""
    return Expression(self.db, self._dialect.replace, first=self, second=(a, b), type=self.type)
---|
1435 | |
---|
def year(self):
    """Year part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="year", type="integer")

def month(self):
    """Month part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="month", type="integer")

def day(self):
    """Day part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="day", type="integer")

def hour(self):
    """Hour part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="hour", type="integer")

def minutes(self):
    """Minute part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="minute", type="integer")

def coalesce(self, *others):
    """COALESCE of this expression with *others*, same type."""
    return Expression(self.db, self._dialect.coalesce, first=self, second=others, type=self.type)

def coalesce_zero(self):
    """This expression with NULL replaced by zero (dialect ``coalesce_zero``)."""
    return Expression(self.db, self._dialect.coalesce_zero, first=self, type=self.type)

def seconds(self):
    """Second part of this date/time expression (dialect ``extract``)."""
    return Expression(self.db, self._dialect.extract, first=self, second="second", type="integer")

def epoch(self):
    """This date/time expression as an integer via the dialect's ``epoch``."""
    return Expression(self.db, self._dialect.epoch, first=self, type="integer")
---|
1462 | |
---|
1463 | def __getitem__(self, i): |
---|
1464 | if isinstance(i, slice): |
---|
1465 | start = i.start or 0 |
---|
1466 | stop = i.stop |
---|
1467 | |
---|
1468 | db = self.db |
---|
1469 | if start < 0: |
---|
1470 | pos0 = "(%s - %d)" % (self.len(), abs(start) - 1) |
---|
1471 | else: |
---|
1472 | pos0 = start + 1 |
---|
1473 | |
---|
1474 | maxint = sys.maxint if PY2 else sys.maxsize |
---|
1475 | if stop is None or stop == maxint: |
---|
1476 | length = self.len() |
---|
1477 | elif stop < 0: |
---|
1478 | length = "(%s - %d - %s)" % (self.len(), abs(stop) - 1, pos0) |
---|
1479 | else: |
---|
1480 | length = "(%s - %s)" % (stop + 1, pos0) |
---|
1481 | |
---|
1482 | return Expression( |
---|
1483 | db, self._dialect.substring, self, (pos0, length), self.type |
---|
1484 | ) |
---|
1485 | else: |
---|
1486 | return self[i : i + 1] |
---|
1487 | |
---|
1488 | def __str__(self): |
---|
1489 | return str(self.db._adapter.expand(self, self.type)) |
---|
1490 | |
---|
def __or__(self, other):  # for use in sortby
    """Comma-join this expression with *other* (dialect ``comma``)."""
    return Expression(self.db, self._dialect.comma, first=self, second=other, type=self.type)
---|
1493 | |
---|
1494 | def __invert__(self): |
---|
1495 | if hasattr(self, "_op") and self.op == self._dialect.invert: |
---|
1496 | return self.first |
---|
1497 | return Expression(self.db, self._dialect.invert, self, type=self.type) |
---|
1498 | |
---|
def __add__(self, other):
    """``self + other`` (dialect ``add``), keeping this expression's type."""
    return Expression(self.db, self._dialect.add, first=self, second=other, type=self.type)
---|
1501 | |
---|
1502 | def __sub__(self, other): |
---|
1503 | if self.type in ("integer", "bigint"): |
---|
1504 | result_type = "integer" |
---|
1505 | elif self.type in ["date", "time", "datetime", "double", "float"]: |
---|
1506 | result_type = "double" |
---|
1507 | elif self.type.startswith("decimal("): |
---|
1508 | result_type = self.type |
---|
1509 | else: |
---|
1510 | raise SyntaxError("subtraction operation not supported for type") |
---|
1511 | return Expression(self.db, self._dialect.sub, self, other, result_type) |
---|
1512 | |
---|
def __mul__(self, other):
    """``self * other`` (dialect ``mul``), keeping this expression's type."""
    return Expression(self.db, self._dialect.mul, first=self, second=other, type=self.type)

def __div__(self, other):
    """``self / other`` (dialect ``div``), keeping this expression's type."""
    return Expression(self.db, self._dialect.div, first=self, second=other, type=self.type)

def __truediv__(self, other):
    # Python 3 ``/``: delegate so that overrides of __div__ also apply here
    return self.__div__(other)

def __mod__(self, other):
    """``self % other`` (dialect ``mod``), keeping this expression's type."""
    return Expression(self.db, self._dialect.mod, first=self, second=other, type=self.type)
---|
1524 | |
---|
# Comparison operators build Query objects rather than returning booleans.

def __eq__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.eq, self, value)

def __ne__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.ne, self, value)

def __lt__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.lt, self, value)

def __le__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.lte, self, value)

def __gt__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.gt, self, value)

def __ge__(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.gte, self, value)
---|
1542 | |
---|
def like(self, value, case_sensitive=True, escape=None):
    """LIKE query against *value*; ILIKE when *case_sensitive* is false."""
    dialect = self._dialect
    op = dialect.like if case_sensitive else dialect.ilike
    return Query(self.db, op, self, value, escape=escape)

def ilike(self, value, escape=None):
    """Case-insensitive ``like``."""
    return self.like(value, case_sensitive=False, escape=escape)

def regexp(self, value):
    """Regular-expression match query (dialect ``regexp``)."""
    dialect = self._dialect
    return Query(self.db, dialect.regexp, self, value)
---|
1552 | |
---|
def belongs(self, *value, **kwattr):
    """
    Membership (IN) query.

    Accepts the following inputs::

        field.belongs(1, 2)
        field.belongs((1, 2))
        field.belongs(query)

    Does NOT accept:

        field.belongs(1)

    A Query is turned into a subselect of its first table's ids; a Select
    or a string is used as is; anything else is collected into a set.

    If the set you want back includes `None` values, you can do::

        field.belongs((1, None), null=True)

    which ORs an ``IS NULL`` test with the membership test.
    """
    db = self.db
    if len(value) == 1:
        value = value[0]
    if isinstance(value, Query):
        value = db(value)._select(value.first._table._id)
    elif not isinstance(value, (Select, basestring)):
        value = set(value)
    if not (kwattr.get("null") and None in value):
        return Query(self.db, self._dialect.belongs, self, value)
    value.remove(None)
    # ``self == None`` is intentional: the overloaded __eq__ builds a NULL test
    return (self == None) | Query(self.db, self._dialect.belongs, self, value)
---|
1583 | |
---|
def startswith(self, value):
    """Query: this text-like expression starts with *value*."""
    if self.type in ("string", "text", "json", "jsonb", "upload"):
        return Query(self.db, self._dialect.startswith, self, value)
    raise SyntaxError("startswith used with incompatible field type")

def endswith(self, value):
    """Query: this text-like expression ends with *value*."""
    if self.type in ("string", "text", "json", "jsonb", "upload"):
        return Query(self.db, self._dialect.endswith, self, value)
    raise SyntaxError("endswith used with incompatible field type")
---|
1593 | |
---|
def contains(self, value, all=False, case_sensitive=False):
    """
    Query: this text or ``list:`` expression contains *value*.

    A list/tuple *value* becomes one ``contains`` query per item whose
    ``str()`` is non-empty, combined with AND when *all* is true and OR
    otherwise; if no item qualifies, ``contains("")`` is returned.

    For GAE contains() is always case sensitive
    """
    if isinstance(value, (list, tuple)):
        subqueries = []
        for item in value:
            if str(item):
                subqueries.append(
                    self.contains(str(item), case_sensitive=case_sensitive)
                )
        if not subqueries:
            return self.contains("")
        combine = AND if all else OR
        return reduce(combine, subqueries)
    text_types = ("string", "text", "json", "jsonb", "upload")
    if self.type not in text_types and not self.type.startswith("list:"):
        raise SyntaxError("contains used with incompatible field type")
    return Query(
        self.db, self._dialect.contains, self, value, case_sensitive=case_sensitive
    )
---|
1619 | |
---|
1620 | def with_alias(self, alias): |
---|
1621 | return Expression(self.db, self._dialect._as, self, alias, self.type) |
---|
1622 | |
---|
1623 | @property |
---|
1624 | def alias(self): |
---|
1625 | if self.op == self._dialect._as: |
---|
1626 | return self.second |
---|
1627 | |
---|
1628 | # GIS expressions |
---|
1629 | |
---|
def st_asgeojson(self, precision=15, options=0):
    """GeoJSON text of this geometry (dialect ``st_asgeojson``)."""
    geojson_args = {"precision": precision, "options": options}
    return Expression(self.db, self._dialect.st_asgeojson, first=self, second=geojson_args, type="string")

def st_astext(self):
    """Text representation of this geometry (dialect ``st_astext``)."""
    return Expression(self.db, self._dialect.st_astext, first=self, type="string")

def st_aswkb(self):
    """Binary representation of this geometry (dialect ``st_aswkb``)."""
    return Expression(self.db, self._dialect.st_aswkb, first=self, type="string")

def st_x(self):
    """X coordinate of this point (dialect ``st_x``)."""
    return Expression(self.db, self._dialect.st_x, first=self, type="string")

def st_y(self):
    """Y coordinate of this point (dialect ``st_y``)."""
    return Expression(self.db, self._dialect.st_y, first=self, type="string")

def st_distance(self, other):
    """Distance between this geometry and *other*, as a double."""
    return Expression(self.db, self._dialect.st_distance, first=self, second=other, type="double")

def st_simplify(self, value):
    """Simplified geometry with tolerance *value*, same type."""
    return Expression(self.db, self._dialect.st_simplify, first=self, second=value, type=self.type)

def st_simplifypreservetopology(self, value):
    """Topology-preserving simplification with tolerance *value*, same type."""
    return Expression(
        self.db,
        self._dialect.st_simplifypreservetopology,
        first=self,
        second=value,
        type=self.type,
    )

def st_transform(self, value):
    """This geometry transformed to reference system *value*, same type."""
    return Expression(self.db, self._dialect.st_transform, first=self, second=value, type=self.type)
---|
1664 | |
---|
1665 | # GIS queries |
---|
1666 | |
---|
# Spatial predicates: each builds a Query with the matching dialect operator.

def st_contains(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_contains, self, value)

def st_equals(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_equals, self, value)

def st_intersects(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_intersects, self, value)

def st_overlaps(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_overlaps, self, value)

def st_touches(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_touches, self, value)

def st_within(self, value):
    dialect = self._dialect
    return Query(self.db, dialect.st_within, self, value)

def st_dwithin(self, value, distance):
    # the dialect operator receives (geometry, distance) as one operand
    dialect = self._dialect
    return Query(self.db, dialect.st_dwithin, self, (value, distance))
---|
1687 | |
---|
1688 | # JSON Expressions |
---|
1689 | |
---|
1690 | def json_key(self, key): |
---|
1691 | """ |
---|
1692 | Get the json in key which you can use to build queries or as one of the |
---|
1693 | fields you want to get in a select. |
---|
1694 | |
---|
1695 | Example: |
---|
1696 | Usage:: |
---|
1697 | |
---|
1698 | To use as one of the fields you want to get in a select |
---|
1699 | |
---|
1700 | >>> tj = db.define_table('tj', Field('testjson', 'json')) |
---|
1701 | >>> tj.insert(testjson={u'a': {u'a1': 2, u'a0': 1}, u'b': 3, u'c': {u'c0': {u'c01': [2, 4]}}}) |
---|
1702 | >>> row = db(db.tj).select(db.tj.testjson.json_key('a').with_alias('a')).first() |
---|
1703 | >>> row.a |
---|
1704 | {u'a1': 2, u'a0': 1} |
---|
1705 | |
---|
1706 | Using it as part of building a query |
---|
1707 | |
---|
1708 | >>> row = db(tj.testjson.json_key('a').json_key_value('a0') == 1).select().first() |
---|
1709 | >>> row |
---|
1710 | <Row {'testjson': {u'a': {u'a1': 2, u'a0': 1}, u'c': {u'c0': {u'c01': [2, 4]}}, u'b': 3}, 'id': 1L}> |
---|
1711 | |
---|
1712 | """ |
---|
1713 | return Expression(self.db, self._dialect.json_key, self, key) |
---|
1714 | |
---|
1715 | def json_key_value(self, key): |
---|
1716 | """ |
---|
1717 | Get the value int or text in key |
---|
1718 | |
---|
1719 | Example: |
---|
1720 | Usage:: |
---|
1721 | |
---|
1722 | To use as one of the fields you want to get in a select |
---|
1723 | |
---|
1724 | >>> tj = db.define_table('tj', Field('testjson', 'json')) |
---|
1725 | >>> tj.insert(testjson={u'a': {u'a1': 2, u'a0': 1}, u'b': 3, u'c': {u'c0': {u'c01': [2, 4]}}}) |
---|
1726 | >>> row = db(db.tj).select(db.tj.testjson.json_key_value('b').with_alias('b')).first() |
---|
1727 | >>> row.b |
---|
1728 | '3' |
---|
1729 | |
---|
1730 | Using it as part of building a query |
---|
1731 | |
---|
1732 | >>> row = db(db.tj.testjson.json_key('a').json_key_value('a0') == 1).select().first() |
---|
1733 | >>> row |
---|
1734 | <Row {'testjson': {u'a': {u'a1': 2, u'a0': 1}, u'c': {u'c0': {u'c01': [2, 4]}}, u'b': 3}, 'id': 1L}> |
---|
1735 | |
---|
1736 | """ |
---|
1737 | return Expression(self.db, self._dialect.json_key_value, self, key) |
---|
1738 | |
---|
1739 | def json_path(self, path): |
---|
1740 | """ |
---|
1741 | Get the json in path which you can use for more queries |
---|
1742 | |
---|
1743 | Example: |
---|
1744 | Usage:: |
---|
1745 | |
---|
1746 | >>> tj = db.define_table('tj', Field('testjson', 'json')) |
---|
1747 | >>> tj.insert(testjson={u'a': {u'a1': 2, u'a0': 1}, u'b': 3, u'c': {u'c0': {u'c01': [2, 4]}}}) |
---|
1748 | >>> row = db(db.tj.id > 0).select(db.tj.testjson.json_path('{c, c0, c01, 0}').with_alias('firstc01')).first() |
---|
1749 | >>> row.firstc01 |
---|
1750 | 2 |
---|
1751 | """ |
---|
1752 | return Expression(self.db, self._dialect.json_path, self, path) |
---|
1753 | |
---|
1754 | def json_path_value(self, path): |
---|
1755 | """ |
---|
1756 | Get the value in path which you can use for more queries |
---|
1757 | |
---|
1758 | Example: |
---|
1759 | Usage:: |
---|
1760 | |
---|
1761 | >>> tj = db.define_table('tj', Field('testjson', 'json')) |
---|
1762 | >>> tj.insert(testjson={u'a': {u'a1': 2, u'a0': 1}, u'b': 3, u'c': {u'c0': {u'c01': [2, 4]}}}) |
---|
1763 | >>> db(db.tj.testjson.json_path_value('{a, a1}') == 2).select().first() |
---|
1764 | <Row {'testjson': {u'a': {u'a1': 2, u'a0': 1}, u'c': {u'c0': {u'c01': [2, 4]}}, u'b': 3}, 'id': 1L}> |
---|
1765 | """ |
---|
1766 | return Expression(self.db, self._dialect.json_path_value, self, path) |
---|
1767 | |
---|
1768 | # JSON Queries |
---|
1769 | |
---|
1770 | def json_contains(self, jsonvalue): |
---|
1771 | """ |
---|
1772 | Containment operator, jsonvalue parameter must be a json string |
---|
1773 | e.g. '{"country": "Peru"}' |
---|
1774 | |
---|
1775 | Example: |
---|
1776 | Usage:: |
---|
1777 | |
---|
1778 | >>> tj = db.define_table('tj', Field('testjson', 'json')) |
---|
1779 | >>> tj.insert(testjson={u'a': {u'a1': 2, u'a0': 1}, u'b': 3, u'c': {u'c0': {u'c01': [2, 4]}}}) |
---|
1780 | >>> db(db.tj.testjson.json_contains('{"c": {"c0":{"c01": [2]}}}')).select().first() |
---|
1781 | <Row {'testjson': {u'a': {u'a1': 2, u'a0': 1}, u'c': {u'c0': {u'c01': [2, 4]}}, u'b': 3}, 'id': 1L}> |
---|
1782 | """ |
---|
1783 | return Query(self.db, self._dialect.json_contains, self, jsonvalue) |
---|
1784 | |
---|
1785 | |
---|
1786 | class FieldVirtual(object): |
---|
1787 | def __init__( |
---|
1788 | self, |
---|
1789 | name, |
---|
1790 | f=None, |
---|
1791 | ftype="string", |
---|
1792 | label=None, |
---|
1793 | table_name=None, |
---|
1794 | readable=True, |
---|
1795 | listable=True, |
---|
1796 | ): |
---|
1797 | # for backward compatibility |
---|
1798 | (self.name, self.f) = (name, f) if f else ("unknown", name) |
---|
1799 | self.type = ftype |
---|
1800 | self.label = label or self.name.replace("_", " ").title() |
---|
1801 | self.represent = lambda v, r=None: v |
---|
1802 | self.formatter = IDENTITY |
---|
1803 | self.comment = None |
---|
1804 | self.readable = readable |
---|
1805 | self.listable = listable |
---|
1806 | self.searchable = False |
---|
1807 | self.writable = False |
---|
1808 | self.requires = None |
---|
1809 | self.widget = None |
---|
1810 | self.tablename = table_name |
---|
1811 | self.filter_out = None |
---|
1812 | |
---|
1813 | def bind(self, table, name): |
---|
1814 | if self.tablename is not None: |
---|
1815 | raise ValueError("FieldVirtual %s is already bound to a table" % self) |
---|
1816 | if self.name == "unknown": # for backward compatibility |
---|
1817 | self.name = name |
---|
1818 | elif name != self.name: |
---|
1819 | raise ValueError("Cannot rename FieldVirtual %s to %s" % (self.name, name)) |
---|
1820 | self.tablename = table._tablename |
---|
1821 | |
---|
1822 | def __str__(self): |
---|
1823 | return "%s.%s" % (self.tablename, self.name) |
---|
1824 | |
---|
1825 | |
---|
1826 | class FieldMethod(object): |
---|
1827 | def __init__(self, name, f=None, handler=None): |
---|
1828 | # for backward compatibility |
---|
1829 | (self.name, self.f) = (name, f) if f else ("unknown", name) |
---|
1830 | self.handler = handler or VirtualCommand |
---|
1831 | |
---|
1832 | def bind(self, table, name): |
---|
1833 | if self.name == "unknown": # for backward compatibility |
---|
1834 | self.name = name |
---|
1835 | elif name != self.name: |
---|
1836 | raise ValueError("Cannot rename FieldMethod %s to %s" % (self.name, name)) |
---|
1837 | |
---|
1838 | |
---|
1839 | @implements_bool |
---|
1840 | class Field(Expression, Serializable): |
---|
1841 | |
---|
1842 | Virtual = FieldVirtual |
---|
1843 | Method = FieldMethod |
---|
1844 | Lazy = FieldMethod # for backward compatibility |
---|
1845 | |
---|
1846 | """ |
---|
1847 | Represents a database field |
---|
1848 | |
---|
1849 | Example: |
---|
1850 | Usage:: |
---|
1851 | |
---|
1852 | a = Field(name, 'string', length=32, default=None, required=False, |
---|
1853 | requires=IS_NOT_EMPTY(), ondelete='CASCADE', |
---|
1854 | notnull=False, unique=False, |
---|
1855 | regex=None, options=None, |
---|
1856 | uploadfield=True, widget=None, label=None, comment=None, |
---|
1857 | uploadfield=True, # True means store on disk, |
---|
1858 | # 'a_field_name' means store in this field in db |
---|
1859 | # False means file content will be discarded. |
---|
1860 | writable=True, readable=True, searchable=True, listable=True, |
---|
1861 | update=None, authorize=None, |
---|
1862 | autodelete=False, represent=None, uploadfolder=None, |
---|
1863 | uploadseparate=False # upload to separate directories by uuid_keys |
---|
1864 | # first 2 character and tablename.fieldname |
---|
1865 | # False - old behavior |
---|
1866 | # True - put uploaded file in |
---|
1867 | # <uploaddir>/<tablename>.<fieldname>/uuid_key[:2] |
---|
1868 | # directory) |
---|
1869 | uploadfs=None # a pyfilesystem where to store upload |
---|
1870 | ) |
---|
1871 | |
---|
1872 | to be used as argument of `DAL.define_table` |
---|
1873 | |
---|
1874 | """ |
---|
1875 | |
---|
1876 | def __init__( |
---|
1877 | self, |
---|
1878 | fieldname, |
---|
1879 | type="string", |
---|
1880 | length=None, |
---|
1881 | default=DEFAULT, |
---|
1882 | required=False, |
---|
1883 | requires=DEFAULT, |
---|
1884 | ondelete="CASCADE", |
---|
1885 | notnull=False, |
---|
1886 | unique=False, |
---|
1887 | uploadfield=True, |
---|
1888 | widget=None, |
---|
1889 | label=None, |
---|
1890 | comment=None, |
---|
1891 | writable=True, |
---|
1892 | readable=True, |
---|
1893 | searchable=True, |
---|
1894 | listable=True, |
---|
1895 | regex=None, |
---|
1896 | options=None, |
---|
1897 | update=None, |
---|
1898 | authorize=None, |
---|
1899 | autodelete=False, |
---|
1900 | represent=None, |
---|
1901 | uploadfolder=None, |
---|
1902 | uploadseparate=False, |
---|
1903 | uploadfs=None, |
---|
1904 | compute=None, |
---|
1905 | custom_store=None, |
---|
1906 | custom_retrieve=None, |
---|
1907 | custom_retrieve_file_properties=None, |
---|
1908 | custom_delete=None, |
---|
1909 | filter_in=None, |
---|
1910 | filter_out=None, |
---|
1911 | custom_qualifier=None, |
---|
1912 | map_none=None, |
---|
1913 | rname=None, |
---|
1914 | **others |
---|
1915 | ): |
---|
1916 | self._db = self.db = None # both for backward compatibility |
---|
1917 | self.table = self._table = None |
---|
1918 | self.op = None |
---|
1919 | self.first = None |
---|
1920 | self.second = None |
---|
1921 | if PY2 and isinstance(fieldname, unicode): |
---|
1922 | try: |
---|
1923 | fieldname = str(fieldname) |
---|
1924 | except UnicodeEncodeError: |
---|
1925 | raise SyntaxError("Field: invalid unicode field name") |
---|
1926 | self.name = fieldname = cleanup(fieldname) |
---|
1927 | if ( |
---|
1928 | not isinstance(fieldname, str) |
---|
1929 | or hasattr(Table, fieldname) |
---|
1930 | or not REGEX_VALID_TB_FLD.match(fieldname) |
---|
1931 | or REGEX_PYTHON_KEYWORDS.match(fieldname) |
---|
1932 | ): |
---|
1933 | raise SyntaxError( |
---|
1934 | "Field: invalid field name: %s, " |
---|
1935 | 'use rname for "funny" names' % fieldname |
---|
1936 | ) |
---|
1937 | |
---|
1938 | if not isinstance(type, (Table, Field)): |
---|
1939 | self.type = type |
---|
1940 | else: |
---|
1941 | self.type = "reference %s" % type |
---|
1942 | |
---|
1943 | self.length = ( |
---|
1944 | length if length is not None else DEFAULTLENGTH.get(self.type, 512) |
---|
1945 | ) |
---|
1946 | self.default = default if default is not DEFAULT else (update or None) |
---|
1947 | self.required = required # is this field required |
---|
1948 | self.ondelete = ondelete.upper() # this is for reference fields only |
---|
1949 | self.notnull = notnull |
---|
1950 | self.unique = unique |
---|
1951 | # split to deal with decimal(,) |
---|
1952 | self.regex = regex |
---|
1953 | if not regex and isinstance(self.type, str): |
---|
1954 | self.regex = DEFAULT_REGEX.get(self.type.split("(")[0]) |
---|
1955 | self.options = options |
---|
1956 | self.uploadfield = uploadfield |
---|
1957 | self.uploadfolder = uploadfolder |
---|
1958 | self.uploadseparate = uploadseparate |
---|
1959 | self.uploadfs = uploadfs |
---|
1960 | self.widget = widget |
---|
1961 | self.comment = comment |
---|
1962 | self.writable = writable |
---|
1963 | self.readable = readable |
---|
1964 | self.searchable = searchable |
---|
1965 | self.listable = listable |
---|
1966 | self.update = update |
---|
1967 | self.authorize = authorize |
---|
1968 | self.autodelete = autodelete |
---|
1969 | self.represent = ( |
---|
1970 | list_represent |
---|
1971 | if represent is None and type in ("list:integer", "list:string") |
---|
1972 | else represent |
---|
1973 | ) |
---|
1974 | self.compute = compute |
---|
1975 | self.isattachment = True |
---|
1976 | self.custom_store = custom_store |
---|
1977 | self.custom_retrieve = custom_retrieve |
---|
1978 | self.custom_retrieve_file_properties = custom_retrieve_file_properties |
---|
1979 | self.custom_delete = custom_delete |
---|
1980 | self.filter_in = filter_in |
---|
1981 | self.filter_out = filter_out |
---|
1982 | self.custom_qualifier = custom_qualifier |
---|
1983 | self.label = label if label is not None else fieldname.replace("_", " ").title() |
---|
1984 | self.requires = requires if requires is not None else [] |
---|
1985 | self.map_none = map_none |
---|
1986 | self._rname = self._raw_rname = rname |
---|
1987 | stype = self.type |
---|
1988 | if isinstance(self.type, SQLCustomType): |
---|
1989 | stype = self.type.type |
---|
1990 | self._itype = REGEX_TYPE.match(stype).group(0) if stype else None |
---|
1991 | for key in others: |
---|
1992 | setattr(self, key, others[key]) |
---|
1993 | |
---|
1994 | def bind(self, table): |
---|
1995 | if self._table is not None: |
---|
1996 | raise ValueError("Field %s is already bound to a table" % self.longname) |
---|
1997 | self.db = self._db = table._db |
---|
1998 | self.table = self._table = table |
---|
1999 | self.tablename = self._tablename = table._tablename |
---|
2000 | if self._db and self._rname is None: |
---|
2001 | self._rname = self._db._adapter.sqlsafe_field(self.name) |
---|
2002 | self._raw_rname = self.name |
---|
2003 | |
---|
2004 | def set_attributes(self, *args, **attributes): |
---|
2005 | self.__dict__.update(*args, **attributes) |
---|
2006 | return self |
---|
2007 | |
---|
2008 | def clone(self, point_self_references_to=False, **args): |
---|
2009 | field = copy.copy(self) |
---|
2010 | if point_self_references_to and self.type == "reference %s" % self._tablename: |
---|
2011 | field.type = "reference %s" % point_self_references_to |
---|
2012 | field.__dict__.update(args) |
---|
2013 | field.db = field._db = None |
---|
2014 | field.table = field._table = None |
---|
2015 | field.tablename = field._tablename = None |
---|
2016 | if self._db and self._rname == self._db._adapter.sqlsafe_field(self.name): |
---|
2017 | # Reset the name because it may need to be requoted by bind() |
---|
2018 | field._rname = field._raw_rname = None |
---|
2019 | return field |
---|
2020 | |
---|
2021 | def store(self, file, filename=None, path=None): |
---|
2022 | # make sure filename is a str sequence |
---|
2023 | filename = "{}".format(filename) |
---|
2024 | if self.custom_store: |
---|
2025 | return self.custom_store(file, filename, path) |
---|
2026 | if isinstance(file, cgi.FieldStorage): |
---|
2027 | filename = filename or file.filename |
---|
2028 | file = file.file |
---|
2029 | elif not filename: |
---|
2030 | filename = file.name |
---|
2031 | filename = os.path.basename(filename.replace("/", os.sep).replace("\\", os.sep)) |
---|
2032 | m = re.search(REGEX_UPLOAD_EXTENSION, filename) |
---|
2033 | extension = m and m.group(1) or "txt" |
---|
2034 | uuid_key = uuidstr().replace("-", "")[-16:] |
---|
2035 | encoded_filename = to_native(base64.b16encode(to_bytes(filename)).lower()) |
---|
2036 | # Fields that are not bound to a table use "tmp" as the table name |
---|
2037 | tablename = getattr(self, "_tablename", "tmp") |
---|
2038 | newfilename = "%s.%s.%s.%s" % ( |
---|
2039 | tablename, |
---|
2040 | self.name, |
---|
2041 | uuid_key, |
---|
2042 | encoded_filename, |
---|
2043 | ) |
---|
2044 | newfilename = ( |
---|
2045 | newfilename[: (self.length - 1 - len(extension))] + "." + extension |
---|
2046 | ) |
---|
2047 | self_uploadfield = self.uploadfield |
---|
2048 | if isinstance(self_uploadfield, Field): |
---|
2049 | blob_uploadfield_name = self_uploadfield.uploadfield |
---|
2050 | keys = { |
---|
2051 | self_uploadfield.name: newfilename, |
---|
2052 | blob_uploadfield_name: file.read(), |
---|
2053 | } |
---|
2054 | self_uploadfield.table.insert(**keys) |
---|
2055 | elif self_uploadfield is True: |
---|
2056 | if self.uploadfs: |
---|
2057 | dest_file = self.uploadfs.open(text_type(newfilename), "wb") |
---|
2058 | else: |
---|
2059 | if path: |
---|
2060 | pass |
---|
2061 | elif self.uploadfolder: |
---|
2062 | path = self.uploadfolder |
---|
2063 | elif self.db is not None and self.db._adapter.folder: |
---|
2064 | path = pjoin(self.db._adapter.folder, "..", "uploads") |
---|
2065 | else: |
---|
2066 | raise RuntimeError( |
---|
2067 | "you must specify a Field(..., uploadfolder=...)" |
---|
2068 | ) |
---|
2069 | if self.uploadseparate: |
---|
2070 | if self.uploadfs: |
---|
2071 | raise RuntimeError("not supported") |
---|
2072 | path = pjoin( |
---|
2073 | path, "%s.%s" % (tablename, self.name), uuid_key[:2] |
---|
2074 | ) |
---|
2075 | if not exists(path): |
---|
2076 | os.makedirs(path) |
---|
2077 | pathfilename = pjoin(path, newfilename) |
---|
2078 | dest_file = open(pathfilename, "wb") |
---|
2079 | try: |
---|
2080 | shutil.copyfileobj(file, dest_file) |
---|
2081 | except IOError: |
---|
2082 | raise IOError( |
---|
2083 | 'Unable to store file "%s" because invalid permissions, ' |
---|
2084 | "readonly file system, or filename too long" % pathfilename |
---|
2085 | ) |
---|
2086 | dest_file.close() |
---|
2087 | return newfilename |
---|
2088 | |
---|
2089 | def retrieve(self, name, path=None, nameonly=False): |
---|
2090 | """ |
---|
2091 | If `nameonly==True` return (filename, fullfilename) instead of |
---|
2092 | (filename, stream) |
---|
2093 | """ |
---|
2094 | self_uploadfield = self.uploadfield |
---|
2095 | if self.custom_retrieve: |
---|
2096 | return self.custom_retrieve(name, path) |
---|
2097 | if self.authorize or isinstance(self_uploadfield, str): |
---|
2098 | row = self.db(self == name).select().first() |
---|
2099 | if not row: |
---|
2100 | raise NotFoundException |
---|
2101 | if self.authorize and not self.authorize(row): |
---|
2102 | raise NotAuthorizedException |
---|
2103 | file_properties = self.retrieve_file_properties(name, path) |
---|
2104 | filename = file_properties["filename"] |
---|
2105 | if isinstance(self_uploadfield, str): # ## if file is in DB |
---|
2106 | stream = BytesIO(to_bytes(row[self_uploadfield] or "")) |
---|
2107 | elif isinstance(self_uploadfield, Field): |
---|
2108 | blob_uploadfield_name = self_uploadfield.uploadfield |
---|
2109 | query = self_uploadfield == name |
---|
2110 | data = self_uploadfield.table(query)[blob_uploadfield_name] |
---|
2111 | stream = BytesIO(to_bytes(data)) |
---|
2112 | elif self.uploadfs: |
---|
2113 | # ## if file is on pyfilesystem |
---|
2114 | stream = self.uploadfs.open(text_type(name), "rb") |
---|
2115 | else: |
---|
2116 | # ## if file is on regular filesystem |
---|
2117 | # this is intentionally a string with filename and not a stream |
---|
2118 | # this propagates and allows stream_file_or_304_or_206 to be called |
---|
2119 | fullname = pjoin(file_properties["path"], name) |
---|
2120 | if nameonly: |
---|
2121 | return (filename, fullname) |
---|
2122 | stream = open(fullname, "rb") |
---|
2123 | return (filename, stream) |
---|
2124 | |
---|
2125 | def retrieve_file_properties(self, name, path=None): |
---|
2126 | m = re.match(REGEX_UPLOAD_PATTERN, name) |
---|
2127 | if not m or not self.isattachment: |
---|
2128 | raise TypeError("Can't retrieve %s file properties" % name) |
---|
2129 | self_uploadfield = self.uploadfield |
---|
2130 | if self.custom_retrieve_file_properties: |
---|
2131 | return self.custom_retrieve_file_properties(name, path) |
---|
2132 | if m.group("name"): |
---|
2133 | try: |
---|
2134 | filename = base64.b16decode(m.group("name"), True).decode("utf-8") |
---|
2135 | filename = re.sub(REGEX_UPLOAD_CLEANUP, "_", filename) |
---|
2136 | except (TypeError, AttributeError, binascii.Error): |
---|
2137 | filename = name |
---|
2138 | else: |
---|
2139 | filename = name |
---|
2140 | # ## if file is in DB |
---|
2141 | if isinstance(self_uploadfield, (str, Field)): |
---|
2142 | return dict(path=None, filename=filename) |
---|
2143 | # ## if file is on filesystem |
---|
2144 | if not path: |
---|
2145 | if self.uploadfolder: |
---|
2146 | path = self.uploadfolder |
---|
2147 | else: |
---|
2148 | path = pjoin(self.db._adapter.folder, "..", "uploads") |
---|
2149 | if self.uploadseparate: |
---|
2150 | t = m.group("table") |
---|
2151 | f = m.group("field") |
---|
2152 | u = m.group("uuidkey") |
---|
2153 | path = pjoin(path, "%s.%s" % (t, f), u[:2]) |
---|
2154 | return dict(path=path, filename=filename) |
---|
2155 | |
---|
2156 | def formatter(self, value): |
---|
2157 | if value is None: |
---|
2158 | return self.map_none |
---|
2159 | requires = self.requires |
---|
2160 | if not requires or requires is DEFAULT: |
---|
2161 | return value |
---|
2162 | if not isinstance(requires, (list, tuple)): |
---|
2163 | requires = [requires] |
---|
2164 | elif isinstance(requires, tuple): |
---|
2165 | requires = list(requires) |
---|
2166 | else: |
---|
2167 | requires = copy.copy(requires) |
---|
2168 | requires.reverse() |
---|
2169 | for item in requires: |
---|
2170 | if hasattr(item, "formatter"): |
---|
2171 | value = item.formatter(value) |
---|
2172 | return value |
---|
2173 | |
---|
2174 | def validate(self, value, record_id=None): |
---|
2175 | requires = self.requires |
---|
2176 | if not requires or requires is DEFAULT: |
---|
2177 | return ((value if value != self.map_none else None), None) |
---|
2178 | if not isinstance(requires, (list, tuple)): |
---|
2179 | requires = [requires] |
---|
2180 | for validator in requires: |
---|
2181 | # notice that some validator may have different behavior |
---|
2182 | # depending on the record id, for example |
---|
2183 | # IS_NOT_IN_DB should exclude the current record_id from check |
---|
2184 | (value, error) = validator(value, record_id) |
---|
2185 | if error: |
---|
2186 | return (value, error) |
---|
2187 | return ((value if value != self.map_none else None), None) |
---|
2188 | |
---|
2189 | def count(self, distinct=None): |
---|
2190 | return Expression(self.db, self._dialect.count, self, distinct, "integer") |
---|
2191 | |
---|
2192 | def as_dict(self, flat=False, sanitize=True): |
---|
2193 | attrs = ( |
---|
2194 | "name", |
---|
2195 | "authorize", |
---|
2196 | "represent", |
---|
2197 | "ondelete", |
---|
2198 | "custom_store", |
---|
2199 | "autodelete", |
---|
2200 | "custom_retrieve", |
---|
2201 | "filter_out", |
---|
2202 | "uploadseparate", |
---|
2203 | "widget", |
---|
2204 | "uploadfs", |
---|
2205 | "update", |
---|
2206 | "custom_delete", |
---|
2207 | "uploadfield", |
---|
2208 | "uploadfolder", |
---|
2209 | "custom_qualifier", |
---|
2210 | "unique", |
---|
2211 | "writable", |
---|
2212 | "compute", |
---|
2213 | "map_none", |
---|
2214 | "default", |
---|
2215 | "type", |
---|
2216 | "required", |
---|
2217 | "readable", |
---|
2218 | "requires", |
---|
2219 | "comment", |
---|
2220 | "label", |
---|
2221 | "length", |
---|
2222 | "notnull", |
---|
2223 | "custom_retrieve_file_properties", |
---|
2224 | "filter_in", |
---|
2225 | ) |
---|
2226 | serializable = (int, long, basestring, float, tuple, bool, type(None)) |
---|
2227 | |
---|
2228 | def flatten(obj): |
---|
2229 | if isinstance(obj, dict): |
---|
2230 | return dict((flatten(k), flatten(v)) for k, v in obj.items()) |
---|
2231 | elif isinstance(obj, (tuple, list, set)): |
---|
2232 | return [flatten(v) for v in obj] |
---|
2233 | elif isinstance(obj, serializable): |
---|
2234 | return obj |
---|
2235 | elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): |
---|
2236 | return str(obj) |
---|
2237 | else: |
---|
2238 | return None |
---|
2239 | |
---|
2240 | d = dict() |
---|
2241 | if not (sanitize and not (self.readable or self.writable)): |
---|
2242 | for attr in attrs: |
---|
2243 | if flat: |
---|
2244 | d.update({attr: flatten(getattr(self, attr))}) |
---|
2245 | else: |
---|
2246 | d.update({attr: getattr(self, attr)}) |
---|
2247 | d["fieldname"] = d.pop("name") |
---|
2248 | return d |
---|
2249 | |
---|
2250 | def __bool__(self): |
---|
2251 | return True |
---|
2252 | |
---|
2253 | def __str__(self): |
---|
2254 | if self._table: |
---|
2255 | return "%s.%s" % (self.tablename, self.name) |
---|
2256 | return "<no table>.%s" % self.name |
---|
2257 | |
---|
2258 | def __hash__(self): |
---|
2259 | return id(self) |
---|
2260 | |
---|
2261 | @property |
---|
2262 | def sqlsafe(self): |
---|
2263 | if self._table is None: |
---|
2264 | raise SyntaxError("Field %s is not bound to any table" % self.name) |
---|
2265 | return self._table.sql_shortref + "." + self._rname |
---|
2266 | |
---|
2267 | @property |
---|
2268 | @deprecated("sqlsafe_name", "_rname", "Field") |
---|
2269 | def sqlsafe_name(self): |
---|
2270 | return self._rname |
---|
2271 | |
---|
2272 | @property |
---|
2273 | def longname(self): |
---|
2274 | if self._table is None: |
---|
2275 | raise SyntaxError("Field %s is not bound to any table" % self.name) |
---|
2276 | return self._table._tablename + "." + self.name |
---|
2277 | |
---|
2278 | |
---|
2279 | class Query(Serializable): |
---|
2280 | |
---|
2281 | """ |
---|
2282 | Necessary to define a set. |
---|
2283 | It can be stored or can be passed to `DAL.__call__()` to obtain a `Set` |
---|
2284 | |
---|
2285 | Example: |
---|
2286 | Use as:: |
---|
2287 | |
---|
2288 | query = db.users.name=='Max' |
---|
2289 | set = db(query) |
---|
2290 | records = set.select() |
---|
2291 | |
---|
2292 | """ |
---|
2293 | |
---|
2294 | def __init__( |
---|
2295 | self, |
---|
2296 | db, |
---|
2297 | op, |
---|
2298 | first=None, |
---|
2299 | second=None, |
---|
2300 | ignore_common_filters=False, |
---|
2301 | **optional_args |
---|
2302 | ): |
---|
2303 | self.db = self._db = db |
---|
2304 | self.op = op |
---|
2305 | self.first = first |
---|
2306 | self.second = second |
---|
2307 | self.ignore_common_filters = ignore_common_filters |
---|
2308 | self.optional_args = optional_args |
---|
2309 | |
---|
2310 | @property |
---|
2311 | def _dialect(self): |
---|
2312 | return self.db._adapter.dialect |
---|
2313 | |
---|
2314 | def __repr__(self): |
---|
2315 | return "<Query %s>" % str(self) |
---|
2316 | |
---|
2317 | def __str__(self): |
---|
2318 | return str(self.db._adapter.expand(self)) |
---|
2319 | |
---|
2320 | def __and__(self, other): |
---|
2321 | return Query(self.db, self._dialect._and, self, other) |
---|
2322 | |
---|
2323 | __rand__ = __and__ |
---|
2324 | |
---|
2325 | def __or__(self, other): |
---|
2326 | return Query(self.db, self._dialect._or, self, other) |
---|
2327 | |
---|
2328 | __ror__ = __or__ |
---|
2329 | |
---|
2330 | def __invert__(self): |
---|
2331 | if self.op == self._dialect._not: |
---|
2332 | return self.first |
---|
2333 | return Query(self.db, self._dialect._not, self) |
---|
2334 | |
---|
2335 | def __eq__(self, other): |
---|
2336 | return repr(self) == repr(other) |
---|
2337 | |
---|
2338 | def __ne__(self, other): |
---|
2339 | return not (self == other) |
---|
2340 | |
---|
2341 | def case(self, t=1, f=0): |
---|
2342 | return Expression(self.db, self._dialect.case, self, (t, f)) |
---|
2343 | |
---|
2344 | def as_dict(self, flat=False, sanitize=True): |
---|
2345 | """Experimental stuff |
---|
2346 | |
---|
2347 | This allows to return a plain dictionary with the basic |
---|
2348 | query representation. Can be used with json/xml services |
---|
2349 | for client-side db I/O |
---|
2350 | |
---|
2351 | Example: |
---|
2352 | Usage:: |
---|
2353 | |
---|
2354 | q = db.auth_user.id != 0 |
---|
2355 | q.as_dict(flat=True) |
---|
2356 | { |
---|
2357 | "op": "NE", |
---|
2358 | "first":{ |
---|
2359 | "tablename": "auth_user", |
---|
2360 | "fieldname": "id" |
---|
2361 | }, |
---|
2362 | "second":0 |
---|
2363 | } |
---|
2364 | """ |
---|
2365 | |
---|
2366 | SERIALIZABLE_TYPES = ( |
---|
2367 | tuple, |
---|
2368 | dict, |
---|
2369 | set, |
---|
2370 | list, |
---|
2371 | int, |
---|
2372 | long, |
---|
2373 | float, |
---|
2374 | basestring, |
---|
2375 | type(None), |
---|
2376 | bool, |
---|
2377 | ) |
---|
2378 | |
---|
2379 | def loop(d): |
---|
2380 | newd = dict() |
---|
2381 | for k, v in d.items(): |
---|
2382 | if k in ("first", "second"): |
---|
2383 | if isinstance(v, self.__class__): |
---|
2384 | newd[k] = loop(v.__dict__) |
---|
2385 | elif isinstance(v, Field): |
---|
2386 | newd[k] = {"tablename": v._tablename, "fieldname": v.name} |
---|
2387 | elif isinstance(v, Expression): |
---|
2388 | newd[k] = loop(v.__dict__) |
---|
2389 | elif isinstance(v, SERIALIZABLE_TYPES): |
---|
2390 | newd[k] = v |
---|
2391 | elif isinstance( |
---|
2392 | v, (datetime.date, datetime.time, datetime.datetime) |
---|
2393 | ): |
---|
2394 | newd[k] = text_type(v) |
---|
2395 | elif k == "op": |
---|
2396 | if callable(v): |
---|
2397 | newd[k] = v.__name__ |
---|
2398 | elif isinstance(v, basestring): |
---|
2399 | newd[k] = v |
---|
2400 | else: |
---|
2401 | pass # not callable or string |
---|
2402 | elif isinstance(v, SERIALIZABLE_TYPES): |
---|
2403 | if isinstance(v, dict): |
---|
2404 | newd[k] = loop(v) |
---|
2405 | else: |
---|
2406 | newd[k] = v |
---|
2407 | return newd |
---|
2408 | |
---|
2409 | if flat: |
---|
2410 | return loop(self.__dict__) |
---|
2411 | else: |
---|
2412 | return self.__dict__ |
---|
2413 | |
---|
2414 | |
---|
2415 | class Set(Serializable): |
---|
2416 | |
---|
2417 | """ |
---|
2418 | Represents a set of records in the database. |
---|
2419 | Records are identified by the `query=Query(...)` object. |
---|
2420 | Normally the Set is generated by `DAL.__call__(Query(...))` |
---|
2421 | |
---|
2422 | Given a set, for example:: |
---|
2423 | |
---|
2424 | myset = db(db.users.name=='Max') |
---|
2425 | |
---|
2426 | you can:: |
---|
2427 | |
---|
2428 | myset.update(db.users.name='Massimo') |
---|
2429 | myset.delete() # all elements in the set |
---|
2430 | myset.select(orderby=db.users.id, groupby=db.users.name, limitby=(0, 10)) |
---|
2431 | |
---|
2432 | and take subsets: |
---|
2433 | |
---|
2434 | subset = myset(db.users.id<5) |
---|
2435 | |
---|
2436 | """ |
---|
2437 | |
---|
2438 | def __init__(self, db, query, ignore_common_filters=None): |
---|
2439 | self.db = db |
---|
2440 | self._db = db # for backward compatibility |
---|
2441 | self.dquery = None |
---|
2442 | |
---|
2443 | # if query is a dict, parse it |
---|
2444 | if isinstance(query, dict): |
---|
2445 | query = self.parse(query) |
---|
2446 | |
---|
2447 | if ( |
---|
2448 | ignore_common_filters is not None |
---|
2449 | and use_common_filters(query) == ignore_common_filters |
---|
2450 | ): |
---|
2451 | query = copy.copy(query) |
---|
2452 | query.ignore_common_filters = ignore_common_filters |
---|
2453 | self.query = query |
---|
2454 | |
---|
2455 | def __repr__(self): |
---|
2456 | return "<Set %s>" % str(self.query) |
---|
2457 | |
---|
2458 | def __call__(self, query, ignore_common_filters=False): |
---|
2459 | return self.where(query, ignore_common_filters) |
---|
2460 | |
---|
def where(self, query, ignore_common_filters=False):
    """
    Return a new Set whose query is this set's query AND-ed with ``query``.

    ``query`` may be:
      * None      -> this set is returned unchanged;
      * a Table   -> converted with ``adapter.id_query``;
      * a str     -> wrapped in an ``Expression``;
      * a Field   -> ``field != None`` (Field's operator overload, which
                     produces a query, so it must not become ``is not None``);
      * anything else is used as-is.
    """
    if query is None:
        return self
    if isinstance(query, Table):
        query = self.db._adapter.id_query(query)
    elif isinstance(query, str):
        query = Expression(self.db, query)
    elif isinstance(query, Field):
        query = query != None  # noqa: E711 -- operator overload, see docstring
    combined = self.query & query if self.query else query
    return Set(self.db, combined, ignore_common_filters=ignore_common_filters)
---|
2476 | |
---|
def _count(self, distinct=None):
    """Return the adapter's count statement for this set (used as a cache key by ``count``)."""
    adapter = self.db._adapter
    return adapter._count(self.query, distinct)
---|
2479 | |
---|
def _select(self, *fields, **attributes):
    """
    Delegate to ``adapter._select`` after expanding ``fields``.

    The tables used for expansion come from the query plus the ``join``,
    ``left``, ``orderby`` and ``groupby`` attributes; every attribute is
    forwarded unchanged.
    """
    adapter = self.db._adapter
    extra = [attributes.get(name) for name in ("join", "left", "orderby", "groupby")]
    tablenames = adapter.tables(self.query, *extra)
    return adapter._select(self.query, adapter.expand_all(fields, tablenames), attributes)
---|
2491 | |
---|
def _delete(self):
    """Delegate to ``adapter._delete`` for the table targeted by this set's query."""
    adapter = self.db._adapter
    target = adapter.get_table(self.query)
    return adapter._delete(target, self.query)
---|
2496 | |
---|
def _update(self, **update_fields):
    """Delegate to ``adapter._update`` with the op-values derived from ``update_fields``."""
    adapter = self.db._adapter
    target = adapter.get_table(self.query)
    op_values = target._fields_and_values_for_update(update_fields).op_values()
    return adapter._update(target, self.query, op_values)
---|
2502 | |
---|
def as_dict(self, flat=False, sanitize=True):
    """
    Describe this set as a dict.

    With ``flat=True`` returns
    ``{"query": query.as_dict(flat=True), "db": {uid, codec, name, uri}}``.
    The db ``name``, ``uri`` and ``uid`` are only filled in when
    ``sanitize`` is False; otherwise they are None. With ``flat=False``
    the instance ``__dict__`` itself is returned (not a copy).
    """
    if not flat:
        return self.__dict__
    codec = self.db._db_codec
    dbname = uri = uid = None
    if not sanitize:
        # "name" gets the db name and "uri" the DAL's string form; the
        # previous unpacking stored these two the other way round.
        dbname, uri, uid = (self.db._dbname, str(self.db), self.db._db_uid)
    return {
        "query": self.query.as_dict(flat=flat),
        "db": {"uid": uid, "codec": codec, "name": dbname, "uri": uri},
    }
---|
2514 | |
---|
def parse(self, dquery):
    """Experimental: remember ``dquery`` on the set and build a Query from it (see ``build``)."""
    self.dquery = dquery
    built = self.build(dquery)
    return built
---|
2519 | |
---|
2520 | def build(self, d): |
---|
2521 | """Experimental: see .parse()""" |
---|
2522 | op, first, second = (d["op"], d["first"], d.get("second", None)) |
---|
2523 | left = right = built = None |
---|
2524 | |
---|
2525 | if op in ("AND", "OR"): |
---|
2526 | if not (type(first), type(second)) == (dict, dict): |
---|
2527 | raise SyntaxError("Invalid AND/OR query") |
---|
2528 | if op == "AND": |
---|
2529 | built = self.build(first) & self.build(second) |
---|
2530 | else: |
---|
2531 | built = self.build(first) | self.build(second) |
---|
2532 | elif op == "NOT": |
---|
2533 | if first is None: |
---|
2534 | raise SyntaxError("Invalid NOT query") |
---|
2535 | built = ~self.build(first) # pylint: disable=invalid-unary-operand-type |
---|
2536 | else: |
---|
2537 | # normal operation (GT, EQ, LT, ...) |
---|
2538 | for k, v in {"left": first, "right": second}.items(): |
---|
2539 | if isinstance(v, dict) and v.get("op"): |
---|
2540 | v = self.build(v) |
---|
2541 | if isinstance(v, dict) and ("tablename" in v): |
---|
2542 | v = self.db[v["tablename"]][v["fieldname"]] |
---|
2543 | if k == "left": |
---|
2544 | left = v |
---|
2545 | else: |
---|
2546 | right = v |
---|
2547 | |
---|
2548 | if hasattr(self.db._adapter, op): |
---|
2549 | opm = getattr(self.db._adapter, op) |
---|
2550 | |
---|
2551 | if op == "EQ": |
---|
2552 | built = left == right |
---|
2553 | elif op == "NE": |
---|
2554 | built = left != right |
---|
2555 | elif op == "GT": |
---|
2556 | built = left > right |
---|
2557 | elif op == "GE": |
---|
2558 | built = left >= right |
---|
2559 | elif op == "LT": |
---|
2560 | built = left < right |
---|
2561 | elif op == "LE": |
---|
2562 | built = left <= right |
---|
2563 | elif op in ("JOIN", "LEFT_JOIN", "RANDOM", "ALLOW_NULL"): |
---|
2564 | built = Expression(self.db, opm) |
---|
2565 | elif op in ( |
---|
2566 | "LOWER", |
---|
2567 | "UPPER", |
---|
2568 | "EPOCH", |
---|
2569 | "PRIMARY_KEY", |
---|
2570 | "COALESCE_ZERO", |
---|
2571 | "RAW", |
---|
2572 | "INVERT", |
---|
2573 | ): |
---|
2574 | built = Expression(self.db, opm, left) |
---|
2575 | elif op in ( |
---|
2576 | "COUNT", |
---|
2577 | "EXTRACT", |
---|
2578 | "AGGREGATE", |
---|
2579 | "SUBSTRING", |
---|
2580 | "REGEXP", |
---|
2581 | "LIKE", |
---|
2582 | "ILIKE", |
---|
2583 | "STARTSWITH", |
---|
2584 | "ENDSWITH", |
---|
2585 | "ADD", |
---|
2586 | "SUB", |
---|
2587 | "MUL", |
---|
2588 | "DIV", |
---|
2589 | "MOD", |
---|
2590 | "AS", |
---|
2591 | "ON", |
---|
2592 | "COMMA", |
---|
2593 | "NOT_NULL", |
---|
2594 | "COALESCE", |
---|
2595 | "CONTAINS", |
---|
2596 | "BELONGS", |
---|
2597 | ): |
---|
2598 | built = Expression(self.db, opm, left, right) |
---|
2599 | # expression as string |
---|
2600 | elif not (left or right): |
---|
2601 | built = Expression(self.db, op) |
---|
2602 | else: |
---|
2603 | raise SyntaxError("Operator not supported: %s" % op) |
---|
2604 | |
---|
2605 | return built |
---|
2606 | |
---|
def isempty(self):
    """Return True when a one-row select (``limitby=(0, 1)``) comes back empty."""
    probe = self.select(limitby=(0, 1), orderby_on_limitby=False)
    return not probe
---|
2609 | |
---|
def count(self, distinct=None, cache=None):
    """
    Count the records of the set, optionally through a cache.

    ``cache`` is either a ``(cache_model, time_expire)`` pair or a dict with
    ``"model"``, ``"expiration"`` and an optional ``"key"``. Without an
    explicit key, the key is ``hashlib_md5(db._uri + "/" + count_sql)``.
    """
    db = self.db
    if not cache:
        return db._adapter.count(self.query, distinct)

    sql = self._count(distinct=distinct)
    if isinstance(cache, dict):
        cache_model, time_expire = cache["model"], cache["expiration"]
        key = cache.get("key")
    else:
        cache_model, time_expire = cache
        key = None
    if not key:
        key = hashlib_md5(db._uri + "/" + sql).hexdigest()

    def _compute(self=self, distinct=distinct):
        return db._adapter.count(self.query, distinct)

    return cache_model(key, _compute, time_expire)
---|
2633 | |
---|
def select(self, *fields, **attributes):
    """
    Select this set's records through ``adapter.select``.

    ``fields`` are expanded against the tables referenced by the query and
    by the ``join``/``left``/``orderby``/``groupby`` attributes; all
    ``attributes`` are forwarded unchanged.
    """
    adapter = self.db._adapter
    involved = adapter.tables(
        self.query,
        attributes.get("join"),
        attributes.get("left"),
        attributes.get("orderby"),
        attributes.get("groupby"),
    )
    expanded = adapter.expand_all(fields, involved)
    return adapter.select(self.query, expanded, attributes)
---|
2645 | |
---|
def iterselect(self, *fields, **attributes):
    """Like ``select`` but delegates to ``adapter.iterselect``."""
    adapter = self.db._adapter
    related = [attributes.get(k) for k in ("join", "left", "orderby", "groupby")]
    tablenames = adapter.tables(self.query, *related)
    expanded = adapter.expand_all(fields, tablenames)
    return adapter.iterselect(self.query, expanded, attributes)
---|
2657 | |
---|
def nested_select(self, *fields, **attributes):
    """Like ``select`` but delegates to ``adapter.nested_select``."""
    adapter = self.db._adapter
    get = attributes.get
    tablenames = adapter.tables(
        self.query, get("join"), get("left"), get("orderby"), get("groupby")
    )
    return adapter.nested_select(
        self.query, adapter.expand_all(fields, tablenames), attributes
    )
---|
2669 | |
---|
def delete(self):
    """
    Delete the records of the set, running the table's delete hooks.

    Every ``table._before_delete`` callback is called with this set; if
    any returns a truthy value the delete is cancelled and 0 is returned.
    Otherwise the adapter deletes and, when its result is truthy, every
    ``table._after_delete`` callback is called with this set.

    Returns:
        The adapter's delete result, or 0 when a hook cancelled the delete.
    """
    db = self.db
    table = db._adapter.get_table(self.query)
    if any(f(self) for f in table._before_delete):
        return 0
    ret = db._adapter.delete(table, self.query)
    if ret:
        # explicit loop instead of a list comprehension used for side effects
        for f in table._after_delete:
            f(self)
    return ret
---|
2678 | |
---|
def delete_naive(self):
    """
    Same as delete but does not call table._before_delete and _after_delete
    """
    adapter = self.db._adapter
    target = adapter.get_table(self.query)
    return adapter.delete(target, self.query)
---|
2687 | |
---|
def update(self, **update_fields):
    """
    Update the records of the set, running the table's update hooks.

    Every ``table._before_update`` callback is called with ``(set, row)``;
    a truthy result cancels the update. ``_after_update`` callbacks run
    only when the adapter's update result is truthy.

    Returns:
        The adapter's update result, or 0 when a hook cancelled the update.

    Raises:
        ValueError: when ``update_fields`` yields no values to update.
    """
    db = self.db
    table = db._adapter.get_table(self.query)
    row = table._fields_and_values_for_update(update_fields)
    if not row._values:
        raise ValueError("No fields to update")
    if any(f(self, row) for f in table._before_update):
        return 0
    ret = db._adapter.update(table, self.query, row.op_values())
    if ret:
        # explicit loop instead of a list comprehension used for side effects
        for f in table._after_update:
            f(self, row)
    return ret
---|
2699 | |
---|
def update_naive(self, **update_fields):
    """
    Same as update but does not call table._before_update and _after_update
    """
    adapter = self.db._adapter
    target = adapter.get_table(self.query)
    row = target._fields_and_values_for_update(update_fields)
    if row._values:
        return adapter.update(target, self.query, row.op_values())
    raise ValueError("No fields to update")
---|
2710 | |
---|
def validate_and_update(self, **update_fields):
    """
    Validate ``update_fields`` with each field's ``validate`` and then update.

    Returns a ``Row`` with:
      * ``errors``: a ``Row`` mapping field name -> error message;
      * ``updated``: None when validation failed, 0 when a
        ``_before_update`` hook cancelled the update, otherwise the
        adapter's update result (``_after_update`` hooks run only when
        that result is truthy).

    Raises:
        ValueError: if validation passes but there is nothing to update.
    """
    table = self.db._adapter.get_table(self.query)
    response = Row()
    response.errors = Row()
    new_fields = copy.copy(update_fields)
    record_id = update_fields.get("id")  # loop-invariant
    for key, value in iteritems(update_fields):
        value, error = table[key].validate(value, record_id)
        if error:
            response.errors[key] = "%s" % error
        else:
            new_fields[key] = value
    if response.errors:
        response.updated = None
        return response
    row = table._fields_and_values_for_update(new_fields)
    if not row._values:
        raise ValueError("No fields to update")
    if any(f(self, row) for f in table._before_update):
        ret = 0
    else:
        ret = self.db._adapter.update(table, self.query, row.op_values())
        if ret:
            # explicit loop instead of a list comprehension used for side effects
            for f in table._after_update:
                f(self, row)
    response.updated = ret
    return response
---|
2735 | |
---|
2736 | |
---|
class LazyReferenceGetter(object):
    """
    Callable bound to one record (``table`` + ``id``) that returns a LazySet
    of the records of another table referencing it.
    """

    def __init__(self, table, id):
        self.db = table._db
        self.tablename = table._tablename
        self.id = id

    def __call__(self, other_tablename):
        """
        Return ``LazySet(rfield, id)`` for the first field in
        ``table._referenced_by`` whose table is ``other_tablename``.

        Raises AttributeError when ``db._lazy_tables`` is False or when no
        such referencing field exists.
        """
        if self.db._lazy_tables is False:
            raise AttributeError()
        table = self.db[self.tablename]
        other_table = self.db[other_tablename]
        for rfield in table._referenced_by:
            if rfield.table == other_table:
                break
        else:
            raise AttributeError()
        return LazySet(rfield, self.id)
---|
2752 | |
---|
2753 | |
---|
class LazySet(object):
    """
    Set of the records whose ``field`` equals ``id``, built on demand.

    Every query method creates the concrete ``Set`` through ``_getset`` and
    delegates to the ``Set`` method of the same name.
    """

    def __init__(self, field, id):
        self.db, self.tablename, self.fieldname, self.id = (
            field.db,
            field._tablename,
            field.name,
            id,
        )

    def _getset(self):
        """Return ``Set(db, db[tablename][fieldname] == id)``."""
        query = self.db[self.tablename][self.fieldname] == self.id
        return Set(self.db, query)

    def __repr__(self):
        return repr(self._getset())

    def __call__(self, query, ignore_common_filters=False):
        return self.where(query, ignore_common_filters)

    def where(self, query, ignore_common_filters=False):
        return self._getset()(query, ignore_common_filters)

    def _count(self, distinct=None):
        return self._getset()._count(distinct)

    def _select(self, *fields, **attributes):
        return self._getset()._select(*fields, **attributes)

    def _delete(self):
        return self._getset()._delete()

    def _update(self, **update_fields):
        return self._getset()._update(**update_fields)

    def isempty(self):
        return self._getset().isempty()

    def count(self, distinct=None, cache=None):
        return self._getset().count(distinct, cache)

    def select(self, *fields, **attributes):
        return self._getset().select(*fields, **attributes)

    def iterselect(self, *fields, **attributes):
        # Set provides iterselect; expose it here like the other selects.
        return self._getset().iterselect(*fields, **attributes)

    def nested_select(self, *fields, **attributes):
        return self._getset().nested_select(*fields, **attributes)

    def delete(self):
        return self._getset().delete()

    def delete_naive(self):
        return self._getset().delete_naive()

    def update(self, **update_fields):
        return self._getset().update(**update_fields)

    def update_naive(self, **update_fields):
        return self._getset().update_naive(**update_fields)

    def validate_and_update(self, **update_fields):
        return self._getset().validate_and_update(**update_fields)
---|
2814 | |
---|
2815 | |
---|
class VirtualCommand(object):
    """
    Callable that invokes ``method`` with ``row`` prepended to the call's
    arguments (used by ``Rows.setvirtualfields`` for lazy virtual fields).
    """

    def __init__(self, method, row):
        self.method, self.row = method, row

    def __call__(self, *args, **kwargs):
        call_args = (self.row,) + args
        return self.method(*call_args, **kwargs)
---|
2823 | |
---|
2824 | |
---|
@implements_bool
class BasicRows(object):
    """
    Abstract class for Rows and IterRows.

    The methods below rely on the subclass for iteration, ``first()``,
    ``render()`` (only for ``as_trees(render=...)``) and the ``db``,
    ``colnames``, ``fields`` and ``compact`` attributes.
    """

    def __bool__(self):
        # truthy as soon as there is a first row
        return True if self.first() is not None else False

    def __str__(self):
        """
        Serializes the rows as CSV text (see export_to_csv_file)
        """
        s = StringIO()
        self.export_to_csv_file(s)
        return s.getvalue()

    def as_trees(self, parent_name="parent_id", children_name="children", render=False):
        """
        Returns the data as list of trees.

        Each row gets a ``children_name`` list with the rows whose
        ``parent_name`` equals its ``id``; rows whose parent is None are
        returned as roots. A parent id not found among the rows raises
        KeyError.

        :param parent_name: the name of the field holding the reference to the
                            parent (default parent_id).
        :param children_name: the name where the children of each row will be
                              stored as a list (default children).
        :param render: whether we will render the fields using their represent
                       (default False) can be a list of fields to render or
                       True to render all.
        """
        roots = []
        drows = {}
        # Materialize once: both passes must see the same row objects, and a
        # one-shot iterator subclass (presumably IterRows) would be exhausted
        # on a second iteration of ``self``.
        if render:
            rows = list(self.render(fields=None if render is True else render))
        else:
            rows = list(self)
        for row in rows:
            drows[row.id] = row
            row[children_name] = []
        for row in rows:
            parent = row[parent_name]
            if parent is None:
                roots.append(row)
            else:
                drows[parent][children_name].append(row)
        return roots

    def as_list(
        self,
        compact=True,
        storage_to_dict=True,
        datetime_to_str=False,
        custom_types=None,
    ):
        """
        Returns the data as a list.

        Args:
            compact: value of ``self.compact`` used while iterating; the
                previous value is restored afterwards, even on error
            storage_to_dict: when True each row is converted with its
                ``as_dict``, otherwise the row objects are returned
            datetime_to_str: convert datetime fields as strings
            custom_types: forwarded to each row's ``as_dict``
        """
        (oc, self.compact) = (self.compact, compact)
        try:
            if storage_to_dict:
                items = [item.as_dict(datetime_to_str, custom_types) for item in self]
            else:
                items = list(self)
        finally:
            self.compact = oc
        return items

    def as_dict(
        self,
        key="id",
        compact=True,
        storage_to_dict=True,
        datetime_to_str=False,
        custom_types=None,
    ):
        """
        Returns the data as a dictionary of dictionaries (storage_to_dict=True)
        or records (False)

        Args:
            key: the name of the field to be used as dict key, normally the id;
                may be ``"table.field"`` or a callable taking each row. For
                joined rows with a key that has no dot, rows are keyed by
                position 0, 1, 2, ...
            compact: see ``as_list`` (default True)
            storage_to_dict: when True values are dicts, otherwise row
                objects (default True)
            datetime_to_str: convert datetime fields as strings (default False)
        """
        # test for multiple rows (each row holding one sub-row per table)
        multi = False
        f = self.first()
        if f and isinstance(key, basestring):
            multi = any(isinstance(v, f.__class__) for v in f.values())
            if ("." not in key) and multi:
                # No key provided, default to int indices
                import itertools

                key_generator = itertools.count()
                key = lambda r: next(key_generator)

        rows = self.as_list(compact, storage_to_dict, datetime_to_str, custom_types)
        if isinstance(key, basestring) and key.count(".") == 1:
            (table, field) = key.split(".")
            return {r[table][field]: r for r in rows}
        elif isinstance(key, basestring):
            return {r[key]: r for r in rows}
        return {key(r): r for r in rows}

    def xml(self, strict=False, row_name="row", rows_name="rows"):
        """
        Serializes the table using sqlhtml.SQLTABLE (if present)

        Falls back to the plain ``strict`` XML form when the db has no
        ``rows_xml`` representer.
        """
        if not strict and not self.db.has_representer("rows_xml"):
            strict = True

        if strict:
            return "<%s>\n%s\n</%s>" % (
                rows_name,
                "\n".join(
                    row.as_xml(row_name=row_name, colnames=self.colnames)
                    for row in self
                ),
                rows_name,
            )

        rv = self.db.represent("rows_xml", self)
        if hasattr(rv, "xml") and callable(getattr(rv, "xml")):
            return rv.xml()
        return rv

    def as_xml(self, row_name="row", rows_name="rows"):
        return self.xml(strict=True, row_name=row_name, rows_name=rows_name)

    def as_json(self, mode="object", default=None):
        """
        Serializes the rows to a JSON list or object with objects
        mode='object' is not implemented (should return a nested
        object structure)
        """
        items = [
            record.as_json(
                mode=mode, default=default, serialize=False, colnames=self.colnames
            )
            for record in self
        ]

        return serializers.json(items)

    @property
    def colnames_fields(self):
        """
        Returns the list of fields matching colnames, possibly
        including virtual fields (i.e. Field.Virtual and
        Field.Method instances).
        Use this property instead of plain fields attribute
        whenever you have an entry in colnames which references
        a virtual field, and you still need a correspondence
        between column names and fields.

        NOTE that references to the virtual fields must have been
        **forced** in some way within colnames, because in the general
        case it is not possible to have them as a result of a select.
        """
        colnames = self.colnames
        # instances of Field or Expression only are allowed in fields
        plain_fields = self.fields
        if len(colnames) > len(plain_fields):
            # correspondence between colnames and fields is broken,
            # search for missing virtual fields
            fields = []
            fi = 0
            for col in colnames:
                m = re.match(REGEX_TABLE_DOT_FIELD_OPTIONAL_QUOTES, col)
                if m:
                    t, f = m.groups()
                    table = self.db[t]
                    field = table[f]
                    if field in table._virtual_fields + table._virtual_methods:
                        fields.append(field)
                        continue
                fields.append(plain_fields[fi])
                fi += 1
            assert len(colnames) == len(fields)
            return fields
        return plain_fields

    def export_to_csv_file(self, ofile, null="<NULL>", *args, **kwargs):
        """
        Exports data to csv, the first line contains the column names

        Args:
            ofile: where the csv must be exported to
            null: how null values must be represented (default '<NULL>')
            delimiter: delimiter to separate values (default ',')
            quotechar: character to use to quote string values (default '"')
            quoting: quote system, use csv.QUOTE_*** (default csv.QUOTE_MINIMAL)
            represent: use the fields .represent value (default False)
            colnames: list of column names to use (default self.colnames)
            write_colnames: write the header line (default True)

        This will only work when exporting rows objects!!!!
        DO NOT use this with db.export_to_csv()
        """
        delimiter = kwargs.get("delimiter", ",")
        quotechar = kwargs.get("quotechar", '"')
        quoting = kwargs.get("quoting", csv.QUOTE_MINIMAL)
        represent = kwargs.get("represent", False)
        writer = csv.writer(
            ofile, delimiter=delimiter, quotechar=quotechar, quoting=quoting
        )

        def unquote_colnames(colnames):
            unq_colnames = []
            for col in colnames:
                m = self.db._adapter.REGEX_TABLE_DOT_FIELD.match(col)
                if not m:
                    unq_colnames.append(col)
                else:
                    unq_colnames.append(".".join(m.groups()))
            return unq_colnames

        colnames = kwargs.get("colnames", self.colnames)
        write_colnames = kwargs.get("write_colnames", True)
        # a proper csv starting with the column names
        if write_colnames:
            writer.writerow(unquote_colnames(colnames))

        def none_exception(value):
            """
            Returns a cleaned up value that can be used for csv export:

            - unicode text is encoded as such
            - None values are replaced with the given representation (default <NULL>)
            """
            if value is None:
                return null
            elif PY2 and isinstance(value, unicode):
                return value.encode("utf8")
            elif isinstance(value, Reference):
                return long(value)
            elif hasattr(value, "isoformat"):
                return value.isoformat()[:19].replace("T", " ")
            elif isinstance(value, (list, tuple)):  # for type='list:..'
                return bar_encode(value)
            return value

        repr_cache = {}
        fieldlist = self.colnames_fields
        fieldmap = dict(zip(self.colnames, fieldlist))
        for record in self:
            row = []
            for col in colnames:
                field = fieldmap[col]
                if isinstance(field, (Field, FieldVirtual)):
                    t = field.tablename
                    f = field.name
                    if isinstance(record.get(t, None), (Row, dict)):
                        value = record[t][f]
                    else:
                        value = record[f]
                    if field.type == "blob" and value is not None:
                        # b64encode returns bytes on Python 3; decode it so the
                        # csv writer emits the base64 text instead of "b'...'".
                        value = base64.b64encode(to_bytes(value)).decode("ascii")
                    elif represent and field.represent:
                        if field.type.startswith("reference"):
                            if field not in repr_cache:
                                repr_cache[field] = {}
                            if value not in repr_cache[field]:
                                repr_cache[field][value] = field.represent(
                                    value, record
                                )
                            value = repr_cache[field][value]
                        else:
                            value = field.represent(value, record)
                    row.append(none_exception(value))
                else:
                    row.append(record._extra[col])
            writer.writerow(row)

    # for consistent naming yet backwards compatible
    as_csv = __str__
    json = as_json
---|
3111 | |
---|
3112 | |
---|
3113 | class Rows(BasicRows): |
---|
3114 | """ |
---|
3115 | A wrapper for the return value of a select. It basically represents a table. |
---|
3116 | It has an iterator and each row is represented as a `Row` dictionary. |
---|
3117 | """ |
---|
3118 | |
---|
3119 | # ## TODO: this class still needs some work to care for ID/OID |
---|
3120 | |
---|
def __init__(
    self, db=None, records=None, colnames=None, compact=True, rawrows=None, fields=None
):
    """
    Wrap the records returned by a select.

    Args:
        db: the DAL instance the rows come from
        records: list of record objects (a new empty list when omitted)
        colnames: list of column names (a new empty list when omitted)
        compact: when True, single-table records are unwrapped on access
        rawrows: raw adapter response, kept as ``self.response``
        fields: list of Field/Expression objects (a new empty list when omitted)
    """
    self.db = db
    # Fresh lists per instance: the former ``[]`` defaults were shared, so
    # e.g. ``Rows().append(x)`` leaked ``x`` into every later default Rows.
    self.records = [] if records is None else records
    self.fields = [] if fields is None else fields
    self.colnames = [] if colnames is None else colnames
    self.compact = compact
    self.response = rawrows
---|
3130 | |
---|
def __repr__(self):
    """Return ``<Rows (N)>`` where N is the number of records."""
    total = len(self.records)
    return "<Rows (%s)>" % total
---|
3133 | |
---|
def setvirtualfields(self, **keyed_virtualfields):
    """
    Attach old-style virtual fields to every record, keyed by table name.

    For each ``tablename=instance`` pair, every public attribute of the
    instance is inspected: attributes carrying ``__lazy__`` (for example
    via ``@lazy_virtualfield``) are stored as a ``VirtualCommand`` bound to
    the record; ordinary bound methods are called right away, after the
    record's items have been copied onto the instance (so ``self.x`` works).
    Results go into ``row[tablename]``, created as an empty ``Row`` when
    missing. Returns ``self``.

    For reference::

        db.define_table('x', Field('number', 'integer'))
        if db(db.x).isempty(): [db.x.insert(number=i) for i in range(10)]

        from gluon.dal import lazy_virtualfield

        class MyVirtualFields(object):
            # normal virtual field (backward compatible, discouraged)
            def normal_shift(self): return self.x.number+1
            # lazy virtual field (because of @lazy_virtualfield)
            @lazy_virtualfield
            def lazy_shift(instance, row, delta=4): return row.x.number+delta
        db.x.virtualfields.append(MyVirtualFields())

        for row in db(db.x).select():
            print(row.number, row.normal_shift, row.lazy_shift(delta=7))

    """
    if not keyed_virtualfields:
        return self
    for row in self.records:
        for (tablename, virtualfields) in keyed_virtualfields.items():
            attributes = dir(virtualfields)
            if tablename not in row:
                box = row[tablename] = Row()
            else:
                box = row[tablename]
            updated = False
            for attribute in attributes:
                if attribute[0] == "_":
                    continue
                method = getattr(virtualfields, attribute)
                if hasattr(method, "__lazy__"):
                    box[attribute] = VirtualCommand(method, row)
                elif isinstance(method, types.MethodType):
                    if not updated:
                        # expose the record's items as instance attributes
                        virtualfields.__dict__.update(row)
                        updated = True
                    box[attribute] = method()
    return self
---|
3176 | |
---|
def __add__(self, other):
    """
    Concatenate the records of two Rows objects with identical colnames.

    The result keeps self's db, colnames and fields; it is compact if either
    operand is.

    Raises:
        Exception: when the colnames differ.
    """
    if self.colnames != other.colnames:
        raise Exception("Cannot + incompatible Rows objects")
    records = self.records + other.records
    return self.__class__(
        self.db,
        records,
        self.colnames,
        fields=self.fields,
        compact=self.compact or other.compact,
    )
---|
3188 | |
---|
def __and__(self, other):
    """
    Records present in both objects, in self's order.

    Duplicates are matched one-for-one: each record of ``other`` can pair
    with at most one record of ``self``.
    """
    if self.colnames != other.colnames:
        raise Exception("Cannot & incompatible Rows objects")
    remaining = list(other.records)
    common = []
    for record in self.records:
        if record not in remaining:
            continue
        common.append(record)
        remaining.remove(record)
    return self.__class__(
        self.db,
        common,
        self.colnames,
        fields=self.fields,
        compact=self.compact or other.compact,
    )
---|
3205 | |
---|
def __or__(self, other):
    """
    Union: self's records followed by those of ``other`` not found in self.
    """
    if self.colnames != other.colnames:
        raise Exception("Cannot | incompatible Rows objects")
    merged = list(self.records)
    merged.extend(record for record in other.records if record not in self.records)
    return self.__class__(
        self.db,
        merged,
        self.colnames,
        fields=self.fields,
        compact=self.compact or other.compact,
    )
---|
3218 | |
---|
def __len__(self):
    """Number of records held."""
    records = self.records
    return len(records)
---|
3221 | |
---|
def __getslice__(self, a, b):
    """Return a new object of the same class holding ``records[a:b]``."""
    subset = self.records[a:b]
    return self.__class__(
        self.db, subset, self.colnames, compact=self.compact, fields=self.fields
    )
---|
3230 | |
---|
def __getitem__(self, i):
    """
    Return the i-th row, or a new object of the same class for a slice.

    With ``compact`` set, a record holding exactly one key (other than
    ``_extra``) is unwrapped and the value under that key is returned.
    """
    if isinstance(i, slice):
        if i.step is None:
            return self.__getslice__(i.start, i.stop)
        # __getslice__ only receives start/stop; honour the step here instead
        # of silently returning every row in the range.
        return self.__class__(
            self.db,
            self.records[i],
            self.colnames,
            compact=self.compact,
            fields=self.fields,
        )
    row = self.records[i]
    keys = list(row.keys())
    if self.compact and len(keys) == 1 and keys[0] != "_extra":
        return row[keys[0]]
    return row
---|
3239 | |
---|
def __iter__(self):
    """
    Iterator over records, yielding ``self[i]`` (so compact unwrapping
    applies). The length is read once, when iteration starts.
    """
    total = len(self)
    index = 0
    while index < total:
        yield self[index]
        index += 1
---|
3247 | |
---|
def __eq__(self, other):
    """Two Rows are equal when their record lists are equal; non-Rows never are."""
    if not isinstance(other, Rows):
        return False
    return self.records == other.records
---|
3253 | |
---|
def column(self, column=None):
    """Values of ``column`` (default: the first colname) for every row, in order."""
    values = []
    for row in self:
        values.append(row[str(column) if column else self.colnames[0]])
    return values
---|
3256 | |
---|
def first(self):
    """Return ``self[0]`` (compact-unwrapped), or None when there are no records."""
    return self[0] if self.records else None
---|
3261 | |
---|
def last(self):
    """Return ``self[-1]`` (compact-unwrapped), or None when there are no records."""
    return self[-1] if self.records else None
---|
3266 | |
---|
def append(self, row):
    """Add ``row`` at the end of the records (in place)."""
    records = self.records
    records.append(row)
---|
3269 | |
---|
def insert(self, position, row):
    """Insert ``row`` into the records before index ``position`` (in place)."""
    records = self.records
    records.insert(position, row)
---|
3272 | |
---|
def find(self, f, limitby=None):
    """
    Returns a new Rows object, a subset of the original object,
    filtered by the function `f`

    Args:
        f: predicate called with each row as produced by iterating self
        limitby: optional ``(a, b)``; only the matches numbered a .. b-1
            (0-based, counting matches only) are kept. ``b`` may be None
            for no upper bound. An empty window such as ``(0, 0)`` yields
            no records.
    """
    if not self:
        return self.__class__(
            self.db, [], self.colnames, compact=self.compact, fields=self.fields
        )
    records = []
    if limitby:
        a, b = limitby
    else:
        a, b = 0, len(self)
    k = 0  # number of matches seen so far
    for i, row in enumerate(self):
        # Check the upper bound before testing the next row. The old
        # "k == b after increment" test never fired for b <= 0, so
        # limitby=(0, 0) returned every match.
        if b is not None and k >= b:
            break
        if f(row):
            if a <= k:
                records.append(self.records[i])
            k += 1
    return self.__class__(
        self.db, records, self.colnames, compact=self.compact, fields=self.fields
    )
---|
3298 | |
---|
def exclude(self, f):
    """
    Remove, in place, the records selected by `f` and return them.

    `f` is called with each row as indexing presents it (compacted when
    ``self.compact`` is set). Records for which it is true are removed from
    ``self.records`` (the same list object is updated) and returned, in
    their original order, wrapped in a new Rows object of the same class;
    the others stay in self in their original order.
    """
    kept = []
    removed = []
    # Single partition pass instead of deleting inside the loop, which
    # shifted the list on every removal (O(n^2)).
    for i, record in enumerate(self.records):
        if f(self[i]):
            removed.append(record)
        else:
            kept.append(record)
    # slice-assign so callers holding a reference to the list see the change
    self.records[:] = kept
    return self.__class__(
        self.db, removed, self.colnames, compact=self.compact, fields=self.fields
    )
---|
3320 | |
---|
def sort(self, f, reverse=False):
    """
    Return a new Rows object with the records ordered by ``f(row)``.

    Self is not modified. `f` receives each row as iteration presents it
    (compacted when ``self.compact`` is set), while the returned object
    stores the raw records. Sorting uses ``sorted`` and is therefore stable.

    Args:
        f: key function applied to each row.
        reverse: sort in descending order when true.
    """
    rows = self.__class__(
        self.db, [], self.colnames, compact=self.compact, fields=self.fields
    )
    # Pair each raw record with its iterated (possibly compacted) view:
    # the key must be computed on the view, but the new Rows must hold the
    # raw record.
    pairs = zip(self.records, self)
    rows.records = [
        record
        for record, _view in sorted(pairs, key=lambda pair: f(pair[1]), reverse=reverse)
    ]
    return rows
---|
3338 | |
---|
def join(self, field, name=None, constraint=None, fields=None, orderby=None):
    """
    Attach related records to every row, in place, and return self.

    The mode is chosen from `field`:

    - "referencing" (``field.type == "id"``): the rows hold references to
      ``field``'s table. Each reference column (``name`` if given, otherwise
      every field of ``field._table._referenced_by`` whose name appears in
      the first row) is replaced by the matching record selected from that
      table, or None when no match was found.
    - "referenced" (any other field): ``field`` points at the rows' ids.
      For every row, ``row[name]`` (default ``field._tablename``) is set to
      the list of matching records (empty list when there are none).

    Args:
        field: the Field driving the join.
        name: column to replace / attribute to fill (see above).
        constraint: optional extra Query ANDed to the lookup query.
        fields: fields to select for the related records; defaults to every
            readable field of ``field._table``. The caller's list is copied,
            never modified.
        orderby: passed through to the related select.
    """
    if len(self) == 0:
        return self
    mode = "referencing" if field.type == "id" else "referenced"
    db, maps = self.db, {}
    # Work on a private copy: `field` may be appended below and must not
    # leak into the caller's list.
    if fields:
        fields = list(fields)
    else:
        fields = [f for f in field._table if f.readable]
    if mode == "referencing":
        # try all referenced field names
        if name:
            names = [name]
        else:
            names = list(
                set(f.name for f in field._table._referenced_by if f.name in self[0])
            )
        # collect the ids, keeping only those that look like integers
        ids = [
            value
            for value in (row.get(col) for row in self for col in names)
            if str(value).isdigit()
        ]
        query = field.belongs(ids)
        if constraint:
            query = query & constraint
        if field.name not in [f.name for f in fields]:
            fields.append(field)
        other = db(query).select(*fields, orderby=orderby, cacheable=True)
        for row in other:
            maps[row[field.name]] = row
        for row in self:
            for col in names:
                row[col] = maps.get(row[col])
    else:
        name = name or field._tablename
        query = field.belongs([row.id for row in self])
        if constraint:
            query = query & constraint
        # True when `field` was only added to make the join possible
        tmp = field.name not in [f.name for f in fields]
        if tmp:
            fields.append(field)
        other = db(query).select(*fields, orderby=orderby, cacheable=True)
        for row in other:
            ref_id = row[field]
            if tmp:
                # hide the join field again, since the caller did not ask for it
                try:
                    del row[field.name]
                except (KeyError, AttributeError):
                    # nested (multi-table) row: drop it from its table's sub-row
                    del row[field.tablename][field.name]
                    if not row[field.tablename] and len(row.keys()) == 2:
                        del row[field.tablename]
                        # list(): dict views are not indexable on Python 3
                        row = row[list(row.keys())[0]]
            maps.setdefault(ref_id, []).append(row)
        for row in self:
            row[name] = maps.get(row.id, [])
    return self
---|
3404 | |
---|
def group_by_value(self, *fields, **args):
    """
    Group the rows by the values of one or more fields.

    With one field, the result maps each distinct value to the list of rows
    having it. With several fields the mapping is nested, one dict level
    per field, in the order given.

    Keyword Args:
        one_result: when true, each leaf holds a single row instead of a
            list; a later row sharing the same leaf replaces the earlier one.
            NOTE(review): if the row objects are ``dict`` instances, the
            first row is kept instead (the type check below) — confirm
            which behavior is intended.

    Returns:
        self when no fields are given, ``{}`` when there are no records,
        otherwise the grouping dict.
    """
    one_result = args.get("one_result", False)

    if not fields:
        return self
    # if select returned no results
    if not self.records:
        return {}

    last_depth = len(fields) - 1
    grouped = {}
    for row in self:
        level = grouped
        for depth, key in enumerate(fields):
            value = row[key]
            if depth < last_depth:
                # intermediate level: descend, creating the bucket on demand
                if value not in level:
                    level[value] = {}
                level = level[value]
                continue
            leaf = row if one_result else [row]
            if value not in level:
                level[value] = leaf
            elif isinstance(leaf, dict):
                pass
            elif isinstance(leaf, list):
                level[value] += leaf
            else:
                level[value] = leaf
    return grouped
---|
3457 | |
---|
def render(self, i=None, fields=None):
    """
    Return a deep copy of row `i` with values passed through the
    "rows_render" representer of the DAL instance.

    Args:
        i: row index. If None, a generator is returned instead, yielding
            the rendering of every row in turn.
        fields: the fields to transform. When empty/None, every ``Field``
            in ``self.fields`` with a truthy ``represent`` attribute is used.

    Returns:
        The rendered copy; the original record is left untouched. When
        ``self.compact`` is set and the row has exactly one key other than
        "_extra", the value under that key is returned instead.

    Raises:
        RuntimeError: if the DAL instance has no "rows_render" representer.
    """
    if i is None:
        return (self.render(idx, fields=fields) for idx in range(len(self)))
    if not self.db.has_representer("rows_render"):
        # single literal: a backslash continuation inside the string used to
        # embed the next line's indentation into the message
        raise RuntimeError(
            "Rows.render() needs a `rows_render` representer in DAL instance"
        )
    row = copy.deepcopy(self.records[i])
    keys = list(row.keys())
    if not fields:
        fields = [f for f in self.fields if isinstance(f, Field) and f.represent]
    for field in fields:
        table_row = row[field._tablename]
        table_row[field.name] = self.db.represent(
            "rows_render",
            field,
            table_row[field.name],
            table_row,
        )

    if self.compact and len(keys) == 1 and keys[0] != "_extra":
        return row[keys[0]]
    return row
---|
3491 | |
---|
def __getstate__(self):
    """Pickle hook: a shallow copy of the instance ``__dict__`` minus the "fields" entry."""
    state = dict(self.__dict__)
    state.pop("fields", None)
    return state
---|
3496 | |
---|
def _restore_fields(self, fields):
    """
    Set ``self.fields`` to `fields` only when the instance has no such
    attribute (``__getstate__`` leaves it out of the pickled state); an
    existing value is kept. Returns self so the call can be chained.
    """
    already_present = hasattr(self, "fields")
    if not already_present:
        self.fields = fields
    return self
---|
3501 | |
---|
3502 | |
---|
@implements_iterator
class IterRows(BasicRows):
    """
    Forward-only view over the result of `sql`, fetched lazily.

    Rows are read one at a time from a dedicated cursor and parsed on
    demand. Positional access (``rows[k]``) can only move forward.
    """

    def __init__(self, db, sql, fields, colnames, blob_decode, cacheable):
        self.db = db
        self.fields = fields
        self.colnames = colnames
        self.blob_decode = blob_decode
        self.cacheable = cacheable
        (
            self.fields_virtual,
            self.fields_lazy,
            self.tmps,
        ) = self.db._adapter._parse_expand_colnames(fields)
        self.sql = sql
        # row fetched by first(); re-yielded at the start of iteration
        self._head = None
        # most recent positional lookup done by __getitem__
        self.last_item = None
        self.last_item_id = None
        self.compact = True
        # Keep the adapter's current cursor for this object and execute the
        # query on it, then give the adapter a new cursor since this one is
        # busy (not completely safe but better than sharing it).
        self.cursor = self.db._adapter.cursor
        self.db._adapter.execute(sql)
        self.db._adapter.reset_cursor()

    def __next__(self):
        """Fetch and parse the next row; raise StopIteration when the cursor is exhausted."""
        db_row = self.cursor.fetchone()
        if db_row is None:
            raise StopIteration
        row = self.db._adapter._parse(
            db_row,
            self.tmps,
            self.fields,
            self.colnames,
            self.blob_decode,
            self.cacheable,
            self.fields_virtual,
            self.fields_lazy,
        )
        if self.compact:
            # The following is to translate
            # <Row {'t0': {'id': 1L, 'name': 'web2py'}}>
            # in
            # <Row {'id': 1L, 'name': 'web2py'}>
            # normally accomplished by Rows.__getitem__
            keys = list(row.keys())
            if len(keys) == 1 and keys[0] != "_extra":
                row = row[keys[0]]
        return row

    def __iter__(self):
        """Yield the row cached by first() (if any), then every remaining row."""
        # `is not None`: a fetched row that happens to be falsy must still be yielded
        if self._head is not None:
            yield self._head
        while True:
            try:
                row = next(self)
            except StopIteration:
                # Iterator is over
                return
            yield row

    def first(self):
        """Return the first row (fetched once and cached), or None if there are no rows."""
        if self._head is None:
            try:
                self._head = next(self)
            except StopIteration:
                return None
        return self._head

    def __getitem__(self, key):
        """
        Return the row at position `key`, skipping rows as needed.

        Raises:
            TypeError: if `key` is not an integer.
            IndexError: if `key` is negative, lies before a position already
                read by a previous lookup (the cursor cannot rewind), or is
                past the end of the result.
        """
        if not isinstance(key, (int, long)):
            raise TypeError("IterRows indices must be integers")

        if key == self.last_item_id:
            return self.last_item

        if key < 0:
            # a forward-only cursor cannot count from the end
            raise IndexError("IterRows does not support negative indices")

        n_to_drop = key
        if self.last_item_id is not None:
            if self.last_item_id < key:
                n_to_drop -= self.last_item_id + 1
            else:
                raise IndexError("IterRows cannot move backwards")

        # fetch and drop the rows before the requested one (DB-API fetchone)
        for _ in xrange(n_to_drop):
            self.cursor.fetchone()
        try:
            row = next(self)
        except StopIteration:
            raise IndexError("IterRows index out of range")
        self.last_item_id = key
        self.last_item = row
        return row

    # NOTE: no __len__ on purpose: cursor.rowcount does not seem to be
    # reliable on all drivers.
---|