1 | import re |
---|
2 | import sys |
---|
3 | import types |
---|
4 | from collections import defaultdict |
---|
5 | from contextlib import contextmanager |
---|
6 | from .._compat import ( |
---|
7 | PY2, |
---|
8 | with_metaclass, |
---|
9 | iterkeys, |
---|
10 | iteritems, |
---|
11 | hashlib_md5, |
---|
12 | integer_types, |
---|
13 | basestring, |
---|
14 | ) |
---|
15 | from .._globals import IDENTITY |
---|
16 | from ..connection import ConnectionPool |
---|
17 | from ..exceptions import NotOnNOSQLError |
---|
18 | from ..helpers.classes import ( |
---|
19 | Reference, |
---|
20 | ExecutionHandler, |
---|
21 | SQLCustomType, |
---|
22 | SQLALL, |
---|
23 | NullDriver, |
---|
24 | ) |
---|
25 | from ..helpers.methods import use_common_filters, xorify, merge_tablemaps |
---|
26 | from ..helpers.regex import REGEX_SELECT_AS_PARSER, REGEX_TABLE_DOT_FIELD |
---|
27 | from ..migrator import Migrator |
---|
28 | from ..objects import ( |
---|
29 | Table, |
---|
30 | Field, |
---|
31 | Expression, |
---|
32 | Query, |
---|
33 | Rows, |
---|
34 | IterRows, |
---|
35 | LazySet, |
---|
36 | LazyReferenceGetter, |
---|
37 | VirtualCommand, |
---|
38 | Select, |
---|
39 | ) |
---|
40 | from ..utils import deprecated |
---|
41 | from . import AdapterMeta, with_connection, with_connection_or_raise |
---|
42 | |
---|
43 | |
---|
# All callable object types whose value must be obtained by *calling* them
# before representation (see BaseAdapter.represent): lambdas, plain and
# built-in functions, and bound/built-in methods.
CALLABLETYPES = (
    types.LambdaType,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
)
---|
51 | |
---|
52 | |
---|
class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Abstract base class for pyDAL adapters.

    Provides the machinery shared by all concrete (SQL and NoSQL) adapters:
    driver discovery, connection pooling hooks, table-map extraction from
    queries, raw-row parsing into Rows objects and value representation.
    """

    # engine name, overridden by concrete adapters
    dbengine = "None"
    # python driver module names this adapter can use, in preference order
    drivers = ()
    uploads_in_blob = False
    support_distributed_transaction = False

    def __init__(
        self,
        db,
        uri,
        pool_size=0,
        folder=None,
        db_codec="UTF-8",
        credential_decoder=IDENTITY,
        driver_args=None,
        adapter_args=None,
        after_connection=None,
        entity_quoting=False,
    ):
        """Store configuration, resolve the driver and initialize the adapter.

        ``driver_args`` and ``adapter_args`` previously used mutable ``{}``
        defaults, which are shared between all calls; ``None`` now maps to a
        fresh dict per instance (backward compatible for all callers).
        """
        super(BaseAdapter, self).__init__()
        self._load_dependencies()
        self.db = db
        self.uri = uri
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self.credential_decoder = credential_decoder
        self.driver_args = driver_args if driver_args is not None else {}
        self.adapter_args = adapter_args if adapter_args is not None else {}
        self.expand = self._expand
        self._after_connection = after_connection
        self.set_connection(None)
        self.find_driver()
        self._initialize_()

    def _load_dependencies(self):
        """Bind the dialect, parser and representer registered for this adapter."""
        # local imports to avoid circular dependencies at module load time
        from ..dialects import dialects
        from ..parsers import parsers
        from ..representers import representers

        self.dialect = dialects.get_for(self)
        self.parser = parsers.get_for(self)
        self.representer = representers.get_for(self)

    def _initialize_(self):
        """Adapter-specific initialization hook; base just locates the work folder."""
        self._find_work_folder()

    @property
    def types(self):
        """Mapping of DAL field types to database column types (from the dialect)."""
        return self.dialect.types

    @property
    def _available_drivers(self):
        """Drivers supported by this adapter that are actually importable."""
        return [
            driver
            for driver in self.drivers
            if driver in iterkeys(self.db._drivers_available)
        ]

    def _driver_from_uri(self):
        """Extract an explicit driver name from the URI scheme, if any.

        e.g. ``"postgres:psycopg2://..."`` yields ``"psycopg2"``.
        """
        rv = None
        if self.uri:
            items = self.uri.split("://", 1)[0].split(":")
            rv = items[1] if len(items) > 1 else None
        return rv

    def find_driver(self):
        """Resolve ``self.driver``/``self.driver_name``.

        Priority: driver named in the URI, then ``adapter_args["driver"]``,
        then the first available driver. Raises ``RuntimeError`` when the
        requested (or any) driver is unavailable.
        """
        if getattr(self, "driver", None) is not None:
            return
        requested_driver = self._driver_from_uri() or self.adapter_args.get("driver")
        if requested_driver:
            if requested_driver in self._available_drivers:
                self.driver_name = requested_driver
                self.driver = self.db._drivers_available[requested_driver]
            else:
                raise RuntimeError("Driver %s is not available" % requested_driver)
        elif self._available_drivers:
            self.driver_name = self._available_drivers[0]
            self.driver = self.db._drivers_available[self.driver_name]
        else:
            raise RuntimeError(
                "No driver of supported ones %s is available" % str(self.drivers)
            )

    def connector(self):
        """Open a new low-level connection using the stored driver args."""
        return self.driver.connect(self.driver_args)

    def test_connection(self):
        """Hook to verify a fresh connection works; no-op in the base adapter."""
        pass

    @with_connection
    def close_connection(self):
        """Close the current connection and clear the adapter's reference to it."""
        rv = self.connection.close()
        self.set_connection(None)
        return rv

    def tables(self, *queries):
        """Return a ``{tablename: table}`` map for every table in ``queries``.

        Recurses through Expression/Query trees; raises ``ValueError`` when
        two different table objects claim the same name.
        """
        tables = dict()
        for query in queries:
            if isinstance(query, Field):
                key = query.tablename
                if tables.get(key, query.table) is not query.table:
                    raise ValueError("Name conflict in table list: %s" % key)
                tables[key] = query.table
            elif isinstance(query, (Expression, Query)):
                tmp = [x for x in (query.first, query.second) if x is not None]
                tables = merge_tablemaps(tables, self.tables(*tmp))
        return tables

    def get_table(self, *queries):
        """Return the single table referenced by ``queries``.

        Raises ``RuntimeError`` when zero or more than one table is involved.
        """
        tablemap = self.tables(*queries)
        if len(tablemap) == 1:
            return tablemap.popitem()[1]
        elif len(tablemap) < 1:
            raise RuntimeError("No table selected")
        else:
            raise RuntimeError("Too many tables selected (%s)" % str(list(tablemap)))

    def common_filter(self, query, tablist):
        """AND per-table common filters and multi-tenant filters into ``query``."""
        tenant_fieldname = self.db._request_tenant
        for table in tablist:
            if isinstance(table, basestring):
                table = self.db[table]
            # deal with user provided filters
            if table._common_filter is not None:
                query = query & table._common_filter(query)
            # deal with multi_tenant filters
            if tenant_fieldname in table:
                default = table[tenant_fieldname].default
                if default is not None:
                    newquery = table[tenant_fieldname] == default
                    if query is None:
                        query = newquery
                    else:
                        query = query & newquery
        return query

    def _expand(self, expression, field_type=None, colnames=False, query_env=None):
        """NoSQL fallback expansion: just stringify (SQLAdapter overrides this).

        ``query_env`` default changed from a shared mutable ``{}`` to ``None``;
        the base implementation never uses it.
        """
        return str(expression)

    def expand_all(self, fields, tabledict):
        """Expand SQLALL and ``"table.field"`` strings into Field objects.

        If ``fields`` is empty, every field of every table in ``tabledict``
        is selected.
        """
        new_fields = []
        append = new_fields.append
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            elif isinstance(item, str):
                m = REGEX_TABLE_DOT_FIELD.match(item)
                if m:
                    tablename, fieldname = m.groups()
                    append(self.db[tablename][fieldname])
                else:
                    # opaque string: wrap in an Expression that yields it as-is
                    append(Expression(self.db, lambda item=item: item))
            else:
                append(item)
        # ## if no fields specified take them all from the requested tables
        if not new_fields:
            for table in tabledict.values():
                for field in table:
                    append(field)
        return new_fields

    def parse_value(self, value, field_itype, field_type, blob_decode=True):
        """Convert a raw driver value to its DAL representation.

        ``blob_decode=False`` returns blob values untouched; ``None`` and
        non-string field types pass through unparsed.
        """
        # [Note - gi0baro] I think next if block can be (should be?) avoided
        if field_type != "blob" and isinstance(value, str):
            try:
                # best-effort decode; str has no .decode on PY3 so this
                # deliberately swallows the failure
                value = value.decode(self.db._db_codec)
            except Exception:
                pass
        if PY2 and isinstance(value, unicode):
            value = value.encode("utf-8")
        if isinstance(field_type, SQLCustomType):
            value = field_type.decoder(value)
        if not isinstance(field_type, str) or value is None:
            return value
        elif field_type == "blob" and not blob_decode:
            return value
        else:
            return self.parser.parse(value, field_itype, field_type)

    def _add_operators_to_parsed_row(self, rid, table, row):
        """Attach record operators (e.g. update_record/delete_record) to ``row``."""
        for key, record_operator in iteritems(self.db.record_operators):
            setattr(row, key, record_operator(row, table, rid))
        if table._db._lazy_tables:
            row["__get_lazy_reference__"] = LazyReferenceGetter(table, rid)

    def _add_reference_sets_to_parsed_row(self, rid, table, tablename, row):
        """Attach LazySet attributes for every table referencing this row."""
        for rfield in table._referenced_by:
            referee_link = self.db._referee_name and self.db._referee_name % dict(
                table=rfield.tablename, field=rfield.name
            )
            if referee_link and referee_link not in row and referee_link != tablename:
                row[referee_link] = LazySet(rfield, rid)

    def _regex_select_as_parser(self, colname):
        """Match a column alias produced by ``... AS name`` in ``colname``."""
        return re.search(REGEX_SELECT_AS_PARSER, colname)

    def _parse(
        self,
        row,
        tmps,
        fields,
        colnames,
        blob_decode,
        cacheable,
        fields_virtual,
        fields_lazy,
    ):
        """Convert one raw driver row into a Row keyed by table then field.

        ``tmps`` comes from ``_parse_expand_colnames`` and aligns 1:1 with
        ``colnames``; entries that are None are treated as computed columns
        and collected under ``_extra``.
        """
        new_row = defaultdict(self.db.Row)
        extras = self.db.Row()
        #: let's loop over columns
        for (j, colname) in enumerate(colnames):
            value = row[j]
            tmp = tmps[j]
            tablename = None
            #: do we have a real column?
            if tmp:
                (tablename, fieldname, table, field, ft, fit) = tmp
                colset = new_row[tablename]
                #: parse value
                value = self.parse_value(value, fit, ft, blob_decode)
                if field.filter_out:
                    value = field.filter_out(value)
                colset[fieldname] = value
                #! backward compatibility
                if ft == "id" and fieldname != "id" and "id" not in table.fields:
                    colset["id"] = value
                #: additional parsing for 'id' fields
                if ft == "id" and not cacheable:
                    self._add_operators_to_parsed_row(value, table, colset)
                    #: table may be 'nested_select' which doesn't have '_reference_by'
                    if hasattr(table, '_reference_by'):
                        self._add_reference_sets_to_parsed_row(
                            value, table, tablename, colset
                        )
            #: otherwise we set the value in extras
            else:
                #: fields[j] may be None if only 'colnames' was specified in db.executesql()
                f_itype, ftype = fields[j] and [fields[j]._itype, fields[j].type] or [None, None]
                value = self.parse_value(
                    value, f_itype, ftype, blob_decode
                )
                extras[colname] = value
                if not fields[j]:
                    new_row[colname] = value
                else:
                    new_column_match = self._regex_select_as_parser(colname)
                    if new_column_match is not None:
                        new_column_name = new_column_match.group(1)
                        new_row[new_column_name] = value
        #: add extras if needed (eg. operations results)
        if extras:
            new_row["_extra"] = extras
        #: add virtuals
        new_row = self.db.Row(**new_row)
        for tablename in fields_virtual.keys():
            for f, v in fields_virtual[tablename][1]:
                try:
                    new_row[tablename][f] = v.f(new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
            for f, v in fields_lazy[tablename][1]:
                try:
                    new_row[tablename][f] = v.handler(v.f, new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
        return new_row

    def _parse_expand_colnames(self, fieldlist):
        """
        - Expand a list of colnames into a list of
          (tablename, fieldname, table_obj, field_obj, field_type)
        - Create a list of table for virtual/lazy fields
        """
        fields_virtual = {}
        fields_lazy = {}
        tmps = []
        for field in fieldlist:
            if not isinstance(field, Field):
                # computed column / expression: no table metadata available
                tmps.append(None)
                continue
            table = field.table
            tablename, fieldname = table._tablename, field.name
            ft = field.type
            fit = field._itype
            tmps.append((tablename, fieldname, table, field, ft, fit))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_fields],
                )
                fields_lazy[tablename] = (
                    table,
                    [(f.name, f) for f in table._virtual_methods],
                )
        return (fields_virtual, fields_lazy, tmps)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse all raw ``rows`` into a Rows object, applying virtual fields."""
        (fields_virtual, fields_lazy, tmps) = self._parse_expand_colnames(fields)
        new_rows = [
            self._parse(
                row,
                tmps,
                fields,
                colnames,
                blob_decode,
                cacheable,
                fields_virtual,
                fields_lazy,
            )
            for row in rows
        ]
        rowsobj = self.db.Rows(self.db, new_rows, colnames, rawrows=rows, fields=fields)
        # Old style virtual fields
        for tablename, tmp in fields_virtual.items():
            table = tmp[0]
            # ## old style virtual fields
            for item in table.virtualfields:
                try:
                    rowsobj = rowsobj.setvirtualfields(**{tablename: item})
                except (KeyError, AttributeError):
                    # to avoid breaking virtualfields when partial select
                    pass
        return rowsobj

    def iterparse(self, sql, fields, colnames, blob_decode=True, cacheable=False):
        """
        Iterator to parse one row at a time.
        It doesn't support the old style virtual fields
        """
        return IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)

    def adapt(self, value):
        """Quote/escape a value for the backend; identity in the base adapter."""
        return value

    def represent(self, obj, field_type):
        """Render ``obj`` for storage; callables are invoked to get their value."""
        if isinstance(obj, CALLABLETYPES):
            obj = obj()
        return self.representer.represent(obj, field_type)

    def _drop_table_cleanup(self, table):
        """Remove ``table`` from the DAL instance and drop references to it."""
        del self.db[table._tablename]
        del self.db.tables[self.db.tables.index(table._tablename)]
        self.db._remove_references_to(table)

    def drop_table(self, table, mode=""):
        """Drop ``table``; the base adapter only performs in-memory cleanup."""
        self._drop_table_cleanup(table)

    def rowslice(self, rows, minimum=0, maximum=None):
        """Slice hook for adapters without native LIMIT/OFFSET; base is identity."""
        return rows

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Return the SQL-safe form of a table name; identity in the base adapter."""
        return tablename

    def sqlsafe_field(self, fieldname):
        """Return the SQL-safe form of a field name; identity in the base adapter."""
        return fieldname
---|
409 | |
---|
410 | |
---|
class DebugHandler(ExecutionHandler):
    """Execution handler that logs every SQL command before it runs."""

    def before_execute(self, command):
        # lazy %-style args: the message is only formatted when DEBUG is enabled
        self.adapter.db.logger.debug("SQL: %s", command)
---|
414 | |
---|
415 | |
---|
class SQLAdapter(BaseAdapter):
    """Base adapter for SQL backends: builds and executes actual SQL."""

    # whether ALTER TABLE requires an explicit commit on this engine
    commit_on_alter_table = False
    # [Note - gi0baro] can_select_for_update should be deprecated and removed
    can_select_for_update = True
    # per-instance list is built in __init__; this class attribute is the default
    execution_handlers = []
    # migrator class, overridable via adapter_args["migrator"]
    migrator_cls = Migrator
---|
422 | |
---|
423 | def __init__(self, *args, **kwargs): |
---|
424 | super(SQLAdapter, self).__init__(*args, **kwargs) |
---|
425 | migrator_cls = self.adapter_args.get("migrator", self.migrator_cls) |
---|
426 | self.migrator = migrator_cls(self) |
---|
427 | self.execution_handlers = list(self.db.execution_handlers) |
---|
428 | if self.db._debug: |
---|
429 | self.execution_handlers.insert(0, DebugHandler) |
---|
430 | |
---|
431 | def test_connection(self): |
---|
432 | self.execute("SELECT 1;") |
---|
433 | |
---|
434 | def represent(self, obj, field_type): |
---|
435 | if isinstance(obj, (Expression, Field)): |
---|
436 | return str(obj) |
---|
437 | return super(SQLAdapter, self).represent(obj, field_type) |
---|
438 | |
---|
439 | def adapt(self, obj): |
---|
440 | return "'%s'" % obj.replace("'", "''") |
---|
441 | |
---|
442 | def smart_adapt(self, obj): |
---|
443 | if isinstance(obj, (int, float)): |
---|
444 | return str(obj) |
---|
445 | return self.adapt(str(obj)) |
---|
446 | |
---|
447 | def fetchall(self): |
---|
448 | return self.cursor.fetchall() |
---|
449 | |
---|
450 | def fetchone(self): |
---|
451 | return self.cursor.fetchone() |
---|
452 | |
---|
453 | def _build_handlers_for_execution(self): |
---|
454 | rv = [] |
---|
455 | for handler_class in self.execution_handlers: |
---|
456 | rv.append(handler_class(self)) |
---|
457 | return rv |
---|
458 | |
---|
459 | def filter_sql_command(self, command): |
---|
460 | return command |
---|
461 | |
---|
462 | @with_connection_or_raise |
---|
463 | def execute(self, *args, **kwargs): |
---|
464 | command = self.filter_sql_command(args[0]) |
---|
465 | handlers = self._build_handlers_for_execution() |
---|
466 | for handler in handlers: |
---|
467 | handler.before_execute(command) |
---|
468 | rv = self.cursor.execute(command, *args[1:], **kwargs) |
---|
469 | for handler in handlers: |
---|
470 | handler.after_execute(command) |
---|
471 | return rv |
---|
472 | |
---|
473 | def _expand(self, expression, field_type=None, colnames=False, query_env={}): |
---|
474 | if isinstance(expression, Field): |
---|
475 | if not colnames: |
---|
476 | rv = expression.sqlsafe |
---|
477 | else: |
---|
478 | rv = expression.longname |
---|
479 | if field_type == "string" and expression.type not in ( |
---|
480 | "string", |
---|
481 | "text", |
---|
482 | "json", |
---|
483 | "jsonb", |
---|
484 | "password", |
---|
485 | ): |
---|
486 | rv = self.dialect.cast(rv, self.types["text"], query_env) |
---|
487 | elif isinstance(expression, (Expression, Query)): |
---|
488 | first = expression.first |
---|
489 | second = expression.second |
---|
490 | op = expression.op |
---|
491 | optional_args = expression.optional_args or {} |
---|
492 | optional_args["query_env"] = query_env |
---|
493 | if second is not None: |
---|
494 | rv = op(first, second, **optional_args) |
---|
495 | elif first is not None: |
---|
496 | rv = op(first, **optional_args) |
---|
497 | elif isinstance(op, str): |
---|
498 | if op.endswith(";"): |
---|
499 | op = op[:-1] |
---|
500 | rv = "(%s)" % op |
---|
501 | else: |
---|
502 | rv = op() |
---|
503 | elif field_type: |
---|
504 | rv = self.represent(expression, field_type) |
---|
505 | elif isinstance(expression, (list, tuple)): |
---|
506 | rv = ",".join(self.represent(item, field_type) for item in expression) |
---|
507 | elif isinstance(expression, bool): |
---|
508 | rv = self.dialect.true_exp if expression else self.dialect.false_exp |
---|
509 | else: |
---|
510 | rv = expression |
---|
511 | return str(rv) |
---|
512 | |
---|
513 | def _expand_for_index( |
---|
514 | self, expression, field_type=None, colnames=False, query_env={} |
---|
515 | ): |
---|
516 | if isinstance(expression, Field): |
---|
517 | return expression._rname |
---|
518 | return self._expand(expression, field_type, colnames, query_env) |
---|
519 | |
---|
520 | @contextmanager |
---|
521 | def index_expander(self): |
---|
522 | self.expand = self._expand_for_index |
---|
523 | yield |
---|
524 | self.expand = self._expand |
---|
525 | |
---|
526 | def lastrowid(self, table): |
---|
527 | return self.cursor.lastrowid |
---|
528 | |
---|
529 | def _insert(self, table, fields): |
---|
530 | if fields: |
---|
531 | return self.dialect.insert( |
---|
532 | table._rname, |
---|
533 | ",".join(el[0]._rname for el in fields), |
---|
534 | ",".join(self.expand(v, f.type) for f, v in fields), |
---|
535 | ) |
---|
536 | return self.dialect.insert_empty(table._rname) |
---|
537 | |
---|
538 | def insert(self, table, fields): |
---|
539 | query = self._insert(table, fields) |
---|
540 | try: |
---|
541 | self.execute(query) |
---|
542 | except: |
---|
543 | e = sys.exc_info()[1] |
---|
544 | if hasattr(table, "_on_insert_error"): |
---|
545 | return table._on_insert_error(table, fields, e) |
---|
546 | raise e |
---|
547 | if hasattr(table, "_primarykey"): |
---|
548 | pkdict = dict( |
---|
549 | [(k[0].name, k[1]) for k in fields if k[0].name in table._primarykey] |
---|
550 | ) |
---|
551 | if pkdict: |
---|
552 | return pkdict |
---|
553 | id = self.lastrowid(table) |
---|
554 | if hasattr(table, "_primarykey") and len(table._primarykey) == 1: |
---|
555 | id = {table._primarykey[0]: id} |
---|
556 | if not isinstance(id, integer_types): |
---|
557 | return id |
---|
558 | rid = Reference(id) |
---|
559 | (rid._table, rid._record) = (table, None) |
---|
560 | return rid |
---|
561 | |
---|
562 | def _update(self, table, query, fields): |
---|
563 | sql_q = "" |
---|
564 | query_env = dict(current_scope=[table._tablename]) |
---|
565 | if query: |
---|
566 | if use_common_filters(query): |
---|
567 | query = self.common_filter(query, [table]) |
---|
568 | sql_q = self.expand(query, query_env=query_env) |
---|
569 | sql_v = ",".join( |
---|
570 | [ |
---|
571 | "%s=%s" |
---|
572 | % (field._rname, self.expand(value, field.type, query_env=query_env)) |
---|
573 | for (field, value) in fields |
---|
574 | ] |
---|
575 | ) |
---|
576 | return self.dialect.update(table, sql_v, sql_q) |
---|
577 | |
---|
578 | def update(self, table, query, fields): |
---|
579 | sql = self._update(table, query, fields) |
---|
580 | try: |
---|
581 | self.execute(sql) |
---|
582 | except: |
---|
583 | e = sys.exc_info()[1] |
---|
584 | if hasattr(table, "_on_update_error"): |
---|
585 | return table._on_update_error(table, query, fields, e) |
---|
586 | raise e |
---|
587 | try: |
---|
588 | return self.cursor.rowcount |
---|
589 | except: |
---|
590 | return None |
---|
591 | |
---|
592 | def _delete(self, table, query): |
---|
593 | sql_q = "" |
---|
594 | query_env = dict(current_scope=[table._tablename]) |
---|
595 | if query: |
---|
596 | if use_common_filters(query): |
---|
597 | query = self.common_filter(query, [table]) |
---|
598 | sql_q = self.expand(query, query_env=query_env) |
---|
599 | return self.dialect.delete(table, sql_q) |
---|
600 | |
---|
601 | def delete(self, table, query): |
---|
602 | sql = self._delete(table, query) |
---|
603 | self.execute(sql) |
---|
604 | try: |
---|
605 | return self.cursor.rowcount |
---|
606 | except: |
---|
607 | return None |
---|
608 | |
---|
609 | def _colexpand(self, field, query_env): |
---|
610 | return self.expand(field, colnames=True, query_env=query_env) |
---|
611 | |
---|
612 | def _geoexpand(self, field, query_env): |
---|
613 | if ( |
---|
614 | isinstance(field.type, str) |
---|
615 | and field.type.startswith("geo") |
---|
616 | and isinstance(field, Field) |
---|
617 | ): |
---|
618 | field = field.st_astext() |
---|
619 | return self.expand(field, query_env=query_env) |
---|
620 | |
---|
    def _build_joins_for_select(self, tablenames, param):
        """Split a join specification into its component parts.

        ``param`` is the value of a select ``join``/``left`` argument: a table,
        an Expression (``table.on(condition)``) or a list of those.

        Returns ``(join_tables, join_on, tables_to_merge, join_on_tables,
        important_tablenames, excluded, tablemap)`` where ``excluded`` lists
        the tables from ``tablenames`` not involved in any join clause.
        """
        if not isinstance(param, (tuple, list)):
            param = [param]
        tablemap = {}
        for item in param:
            if isinstance(item, Expression):
                # for `table.on(...)` the joined table is the first operand
                item = item.first
            key = item._tablename
            if tablemap.get(key, item) is not item:
                raise ValueError("Name conflict in table list: %s" % key)
            tablemap[key] = item
        # tables joined without an explicit ON condition
        join_tables = [t._tablename for t in param if not isinstance(t, Expression)]
        # explicit JOIN ... ON expressions
        join_on = [t for t in param if isinstance(t, Expression)]
        tables_to_merge = {}
        for t in join_on:
            tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(t))
        join_on_tables = [t.first._tablename for t in join_on]
        # tables already covered by an ON clause don't need merging
        for t in join_on_tables:
            if t in tables_to_merge:
                tables_to_merge.pop(t)
        important_tablenames = join_tables + join_on_tables + list(tables_to_merge)
        excluded = [t for t in tablenames if t not in important_tablenames]
        return (
            join_tables,
            join_on,
            tables_to_merge,
            join_on_tables,
            important_tablenames,
            excluded,
            tablemap,
        )
---|
652 | |
---|
    def _select_wcols(
        self,
        query,
        fields,
        left=False,
        join=False,
        distinct=False,
        orderby=False,
        groupby=False,
        having=False,
        limitby=False,
        orderby_on_limitby=True,
        for_update=False,
        outer_scoped=[],
        required=None,
        cache=None,
        cacheable=None,
        processor=None,
    ):
        """Build the full SELECT statement for ``query``/``fields``.

        Returns ``(colnames, sql)`` where ``colnames`` are the expanded
        column names aligned with the selected fields.

        NOTE(review): ``outer_scoped`` uses a mutable ``[]`` default; it is
        only read here (iterated/concatenated), never mutated, so this is
        harmless but worth confirming before relying on it.
        ``required``/``cache``/``cacheable``/``processor`` are accepted for
        signature compatibility and not used in this method.
        """
        #: parse tablemap
        tablemap = self.tables(query)
        #: apply common filters if needed
        if use_common_filters(query):
            query = self.common_filter(query, list(tablemap.values()))
        #: auto-adjust tables
        tablemap = merge_tablemaps(tablemap, self.tables(*fields))
        #: remove outer scoped tables if needed
        for item in outer_scoped:
            # FIXME: check for name conflicts
            tablemap.pop(item, None)
        if len(tablemap) < 1:
            raise SyntaxError("Set: no tables selected")
        query_tables = list(tablemap)
        #: check for_update argument
        # [Note - gi0baro] I think this should be removed since useless?
        # should affect only NoSQL?
        if self.can_select_for_update is False and for_update is True:
            raise SyntaxError("invalid select attribute: for_update")
        #: build joins (inner, left outer) and table names
        if join:
            (
                # FIXME? ijoin_tables is never used
                ijoin_tables,
                ijoin_on,
                itables_to_merge,
                ijoin_on_tables,
                iimportant_tablenames,
                iexcluded,
                itablemap,
            ) = self._build_joins_for_select(tablemap, join)
            tablemap = merge_tablemaps(tablemap, itables_to_merge)
            tablemap = merge_tablemaps(tablemap, itablemap)
        if left:
            (
                join_tables,
                join_on,
                tables_to_merge,
                join_on_tables,
                important_tablenames,
                excluded,
                jtablemap,
            ) = self._build_joins_for_select(tablemap, left)
            tablemap = merge_tablemaps(tablemap, tables_to_merge)
            tablemap = merge_tablemaps(tablemap, jtablemap)
        current_scope = outer_scoped + list(tablemap)
        query_env = dict(current_scope=current_scope, parent_scope=outer_scoped)
        #: prepare columns and expand fields
        colnames = [self._colexpand(x, query_env) for x in fields]
        sql_fields = ", ".join(self._geoexpand(x, query_env) for x in fields)
        table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
        if join and not left:
            # inner joins only: cross-join the non-joined tables, then JOIN ... ON
            cross_joins = iexcluded + list(itables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in cross_joins[1:]
            ]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            sql_t = " ".join(tokens)
        elif not join and left:
            # left joins only
            cross_joins = excluded + list(tables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in cross_joins[1:]
            ]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(
                    self.dialect.left_join(
                        ",".join([table_alias(t) for t in join_tables]), query_env
                    )
                )
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = " ".join(tokens)
        elif join and left:
            # both inner and left joins: inner joins first, then the left ones
            all_tables_in_query = set(
                important_tablenames + iimportant_tablenames + query_tables
            )
            tables_in_joinon = set(join_on_tables + ijoin_on_tables)
            tables_not_in_joinon = list(
                all_tables_in_query.difference(tables_in_joinon)
            )
            tokens = [table_alias(tables_not_in_joinon[0])]
            tokens += [
                self.dialect.cross_join(table_alias(t), query_env)
                for t in tables_not_in_joinon[1:]
            ]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(
                    self.dialect.left_join(
                        ",".join([table_alias(t) for t in join_tables]), query_env
                    )
                )
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = " ".join(tokens)
        else:
            # no joins at all: plain comma-separated FROM list
            sql_t = ", ".join(table_alias(t) for t in query_tables)
        #: expand query if needed
        if query:
            query = self.expand(query, query_env=query_env)
        if having:
            having = self.expand(having, query_env=query_env)
        #: groupby
        sql_grp = groupby
        if groupby:
            if isinstance(groupby, (list, tuple)):
                groupby = xorify(groupby)
            sql_grp = self.expand(groupby, query_env=query_env)
        #: orderby
        sql_ord = False
        if orderby:
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)
            if str(orderby) == "<random>":
                sql_ord = self.dialect.random
            else:
                sql_ord = self.expand(orderby, query_env=query_env)
        #: set default orderby if missing
        if (
            limitby
            and not groupby
            and query_tables
            and orderby_on_limitby
            and not orderby
        ):
            # order by primary key(s) (or "_id") so LIMIT/OFFSET is deterministic
            sql_ord = ", ".join(
                [
                    tablemap[t][x].sqlsafe
                    for t in query_tables
                    if not isinstance(tablemap[t], Select)
                    for x in (
                        hasattr(tablemap[t], "_primarykey")
                        and tablemap[t]._primarykey
                        or ["_id"]
                    )
                ]
            )
        #: build sql using dialect
        return (
            colnames,
            self.dialect.select(
                sql_fields,
                sql_t,
                query,
                sql_grp,
                having,
                sql_ord,
                limitby,
                distinct,
                for_update and self.can_select_for_update,
            ),
        )
---|
828 | |
---|
829 | def _select(self, query, fields, attributes): |
---|
830 | return self._select_wcols(query, fields, **attributes)[1] |
---|
831 | |
---|
    def nested_select(self, query, fields, attributes):
        # Wrap the query in a Select object so it can be embedded as a
        # sub-select within another query.
        return Select(self.db, query, fields, attributes)
---|
834 | |
---|
    def _select_aux_execute(self, sql):
        # Run the SELECT and return all result rows from the driver cursor.
        self.execute(sql)
        return self.cursor.fetchall()
---|
838 | |
---|
839 | def _select_aux(self, sql, fields, attributes, colnames): |
---|
840 | cache = attributes.get("cache", None) |
---|
841 | if not cache: |
---|
842 | rows = self._select_aux_execute(sql) |
---|
843 | else: |
---|
844 | if isinstance(cache, dict): |
---|
845 | cache_model = cache["model"] |
---|
846 | time_expire = cache["expiration"] |
---|
847 | key = cache.get("key") |
---|
848 | if not key: |
---|
849 | key = self.uri + "/" + sql + "/rows" |
---|
850 | key = hashlib_md5(key).hexdigest() |
---|
851 | else: |
---|
852 | (cache_model, time_expire) = cache |
---|
853 | key = self.uri + "/" + sql + "/rows" |
---|
854 | key = hashlib_md5(key).hexdigest() |
---|
855 | rows = cache_model( |
---|
856 | key, |
---|
857 | lambda self=self, sql=sql: self._select_aux_execute(sql), |
---|
858 | time_expire, |
---|
859 | ) |
---|
860 | if isinstance(rows, tuple): |
---|
861 | rows = list(rows) |
---|
862 | limitby = attributes.get("limitby", None) or (0,) |
---|
863 | rows = self.rowslice(rows, limitby[0], None) |
---|
864 | processor = attributes.get("processor", self.parse) |
---|
865 | cacheable = attributes.get("cacheable", False) |
---|
866 | return processor(rows, fields, colnames, cacheable=cacheable) |
---|
867 | |
---|
868 | def _cached_select(self, cache, sql, fields, attributes, colnames): |
---|
869 | del attributes["cache"] |
---|
870 | (cache_model, time_expire) = cache |
---|
871 | key = self.uri + "/" + sql |
---|
872 | key = hashlib_md5(key).hexdigest() |
---|
873 | args = (sql, fields, attributes, colnames) |
---|
874 | ret = cache_model( |
---|
875 | key, lambda self=self, args=args: self._select_aux(*args), time_expire |
---|
876 | ) |
---|
877 | ret._restore_fields(fields) |
---|
878 | return ret |
---|
879 | |
---|
880 | def select(self, query, fields, attributes): |
---|
881 | colnames, sql = self._select_wcols(query, fields, **attributes) |
---|
882 | cache = attributes.get("cache", None) |
---|
883 | if cache and attributes.get("cacheable", False): |
---|
884 | return self._cached_select(cache, sql, fields, attributes, colnames) |
---|
885 | return self._select_aux(sql, fields, attributes, colnames) |
---|
886 | |
---|
887 | def iterselect(self, query, fields, attributes): |
---|
888 | colnames, sql = self._select_wcols(query, fields, **attributes) |
---|
889 | cacheable = attributes.get("cacheable", False) |
---|
890 | return self.iterparse(sql, fields, colnames, cacheable=cacheable) |
---|
891 | |
---|
892 | def _count(self, query, distinct=None): |
---|
893 | tablemap = self.tables(query) |
---|
894 | tablenames = list(tablemap) |
---|
895 | tables = list(tablemap.values()) |
---|
896 | query_env = dict(current_scope=tablenames) |
---|
897 | sql_q = "" |
---|
898 | if query: |
---|
899 | if use_common_filters(query): |
---|
900 | query = self.common_filter(query, tables) |
---|
901 | sql_q = self.expand(query, query_env=query_env) |
---|
902 | sql_t = ",".join(self.table_alias(t, []) for t in tables) |
---|
903 | sql_fields = "*" |
---|
904 | if distinct: |
---|
905 | if isinstance(distinct, (list, tuple)): |
---|
906 | distinct = xorify(distinct) |
---|
907 | sql_fields = self.expand(distinct, query_env=query_env) |
---|
908 | return self.dialect.select( |
---|
909 | self.dialect.count(sql_fields, distinct), sql_t, sql_q |
---|
910 | ) |
---|
911 | |
---|
912 | def count(self, query, distinct=None): |
---|
913 | self.execute(self._count(query, distinct)) |
---|
914 | return self.cursor.fetchone()[0] |
---|
915 | |
---|
916 | def bulk_insert(self, table, items): |
---|
917 | return [self.insert(table, item) for item in items] |
---|
918 | |
---|
    def create_table(self, *args, **kwargs):
        # Delegate DDL generation and migration bookkeeping to the migrator.
        return self.migrator.create_table(*args, **kwargs)
---|
921 | |
---|
    def _drop_table_cleanup(self, table):
        # Run the shared (base-class) bookkeeping first...
        super(SQLAdapter, self)._drop_table_cleanup(table)
        # ...then, when migrations were enabled for this table, remove the
        # migration file and record success in the migration log.
        if table._dbt:
            self.migrator.file_delete(table._dbt)
            self.migrator.log("success!\n", table)
---|
927 | |
---|
928 | def drop_table(self, table, mode=""): |
---|
929 | queries = self.dialect.drop_table(table, mode) |
---|
930 | for query in queries: |
---|
931 | if table._dbt: |
---|
932 | self.migrator.log(query + "\n", table) |
---|
933 | self.execute(query) |
---|
934 | self.commit() |
---|
935 | self._drop_table_cleanup(table) |
---|
936 | |
---|
937 | @deprecated("drop", "drop_table", "SQLAdapter") |
---|
938 | def drop(self, table, mode=""): |
---|
939 | return self.drop_table(table, mode="") |
---|
940 | |
---|
941 | def truncate(self, table, mode=""): |
---|
942 | # Prepare functions "write_to_logfile" and "close_logfile" |
---|
943 | try: |
---|
944 | queries = self.dialect.truncate(table, mode) |
---|
945 | for query in queries: |
---|
946 | self.migrator.log(query + "\n", table) |
---|
947 | self.execute(query) |
---|
948 | self.migrator.log("success!\n", table) |
---|
949 | finally: |
---|
950 | pass |
---|
951 | |
---|
952 | def create_index(self, table, index_name, *fields, **kwargs): |
---|
953 | expressions = [ |
---|
954 | field._rname if isinstance(field, Field) else field for field in fields |
---|
955 | ] |
---|
956 | sql = self.dialect.create_index(index_name, table, expressions, **kwargs) |
---|
957 | try: |
---|
958 | self.execute(sql) |
---|
959 | self.commit() |
---|
960 | except Exception as e: |
---|
961 | self.rollback() |
---|
962 | err = ( |
---|
963 | "Error creating index %s\n Driver error: %s\n" |
---|
964 | + " SQL instruction: %s" |
---|
965 | ) |
---|
966 | raise RuntimeError(err % (index_name, str(e), sql)) |
---|
967 | return True |
---|
968 | |
---|
969 | def drop_index(self, table, index_name): |
---|
970 | sql = self.dialect.drop_index(index_name, table) |
---|
971 | try: |
---|
972 | self.execute(sql) |
---|
973 | self.commit() |
---|
974 | except Exception as e: |
---|
975 | self.rollback() |
---|
976 | err = "Error dropping index %s\n Driver error: %s" |
---|
977 | raise RuntimeError(err % (index_name, str(e))) |
---|
978 | return True |
---|
979 | |
---|
    def distributed_transaction_begin(self, key):
        # No-op hook; adapters that support distributed transactions may
        # override it (see the class-level support_distributed_transaction).
        pass
---|
982 | |
---|
    @with_connection
    def commit(self):
        # Commit the current transaction on the live driver connection.
        return self.connection.commit()
---|
986 | |
---|
    @with_connection
    def rollback(self):
        # Roll back the current transaction on the live driver connection.
        return self.connection.rollback()
---|
990 | |
---|
    @with_connection
    def prepare(self, key):
        # First phase of a two-phase commit; `key` is unused by this base
        # implementation.
        self.connection.prepare()
---|
994 | |
---|
    @with_connection
    def commit_prepared(self, key):
        # Finalize a prepared (two-phase) transaction; `key` unused here.
        self.connection.commit()
---|
998 | |
---|
    @with_connection
    def rollback_prepared(self, key):
        # Abort a prepared (two-phase) transaction; `key` unused here.
        self.connection.rollback()
---|
1002 | |
---|
    def create_sequence_and_triggers(self, query, table, **args):
        # Default behavior: just execute the given DDL query.  Engines that
        # need explicit sequences/triggers for auto-ids override this.
        self.execute(query)
---|
1005 | |
---|
1006 | def sqlsafe_table(self, tablename, original_tablename=None): |
---|
1007 | if original_tablename is not None: |
---|
1008 | return self.dialect.alias(original_tablename, tablename) |
---|
1009 | return self.dialect.quote(tablename) |
---|
1010 | |
---|
    def sqlsafe_field(self, fieldname):
        # Quote a field name using the dialect's identifier quoting.
        return self.dialect.quote(fieldname)
---|
1013 | |
---|
1014 | def table_alias(self, tbl, current_scope=[]): |
---|
1015 | if isinstance(tbl, basestring): |
---|
1016 | tbl = self.db[tbl] |
---|
1017 | return tbl.query_name(current_scope)[0] |
---|
1018 | |
---|
1019 | def id_query(self, table): |
---|
1020 | pkeys = getattr(table, "_primarykey", None) |
---|
1021 | if pkeys: |
---|
1022 | return table[pkeys[0]] != None |
---|
1023 | return table._id != None |
---|
1024 | |
---|
1025 | |
---|
class NoSQLAdapter(BaseAdapter):
    """Base adapter for NoSQL backends.

    These hooks are no-ops or raise, reflecting that the targeted engines
    expose no SQL transactions, no SQL DDL and no nested selects.
    """

    can_select_for_update = False

    def commit(self):
        # No transaction support: nothing to commit.
        pass

    def rollback(self):
        # No transaction support: nothing to roll back.
        pass

    def prepare(self):
        pass

    def commit_prepared(self, key):
        pass

    def rollback_prepared(self, key):
        pass

    def id_query(self, table):
        # Match every record (ids are positive).
        return table._id > 0

    def create_table(self, table, migrate=True, fake_migrate=False, polymodel=None):
        """Record schema metadata on *table*; no DDL is issued.

        Collects the not-null and unique field names so the adapter can
        enforce them itself.  (The two original loops over ``table.fields``
        were merged; the resulting lists are identical.)
        """
        table._dbt = None
        table._notnulls = []
        table._uniques = []
        for field_name in table.fields:
            field = table[field_name]
            if field.notnull:
                table._notnulls.append(field_name)
            if field.unique:
                # this is unnecessary if the fields are indexed and unique
                table._uniques.append(field_name)

    def drop_table(self, table, mode=""):
        # Drop the underlying collection, then run the shared cleanup.
        ctable = self.connection[table._tablename]
        ctable.drop()
        self._drop_table_cleanup(table)

    @deprecated("drop", "drop_table", "NoSQLAdapter")
    def drop(self, table, mode=""):
        """Deprecated alias for :meth:`drop_table`.

        Fixes two copy-paste defects: the caller's *mode* is now forwarded
        (it was always replaced with ""), and the deprecation warning names
        NoSQLAdapter instead of SQLAdapter.
        """
        return self.drop_table(table, mode=mode)

    def _select(self, *args, **kwargs):
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")

    def nested_select(self, *args, **kwargs):
        raise NotOnNOSQLError("Nested queries are not supported on NoSQL databases")
---|
1073 | |
---|
1074 | |
---|
1075 | class NullAdapter(BaseAdapter): |
---|
    def _load_dependencies(self):
        # The null adapter only needs the generic common dialect; the import
        # is local to avoid a circular import at module load time.
        from ..dialects.base import CommonDialect

        self.dialect = CommonDialect(self)
---|
1080 | |
---|
    def find_driver(self):
        # No database driver is required for the null adapter.
        pass
---|
1083 | |
---|
    def connector(self):
        # Return a stub driver object that satisfies the connection API.
        return NullDriver()
---|