1 | import collections |
---|
2 | import copy |
---|
3 | import datetime |
---|
4 | import fnmatch |
---|
5 | import functools |
---|
6 | import re |
---|
7 | |
---|
# Version string copied into the "api_version" field of every response.
__version__ = "0.1"

# Names exported by a star-import of this module.
__all__ = [
    "RestAPI",
    "Policy",
    "ALLOW_ALL_POLICY",
    "DENY_ALL_POLICY",
]

# Default value of the GET "limit" policy attribute, and the cap applied to
# "@limit" when the API runs without a policy.
MAX_LIMIT = 1000
---|
13 | |
---|
14 | |
---|
class PolicyViolation(ValueError):
    """Raised when a request is not permitted by the active policy."""
---|
17 | |
---|
18 | |
---|
class InvalidFormat(ValueError):
    """Raised when a request or a policy definition is malformed."""
---|
21 | |
---|
22 | |
---|
class NotFound(ValueError):
    """Raised when the record targeted by a request does not exist."""
---|
25 | |
---|
26 | |
---|
def maybe_call(value):
    """Return ``value()`` when ``value`` is callable, otherwise ``value`` itself."""
    if callable(value):
        return value()
    return value
---|
29 | |
---|
30 | |
---|
def error_wrapper(func):
    """Decorate ``func`` so its result is returned as a response envelope.

    ``func`` must return a dict. The wrapper adds these keys to it:

    * ``status`` / ``code``: ``"success"`` / 200, or ``"error"`` / 422 (with
      ``message`` set to ``"Validation Errors"``) when the dict has a truthy
      ``"errors"`` entry.
    * ``timestamp``: the current UTC time in ISO format.
    * ``api_version``: the module ``__version__``.

    Exceptions raised by ``func`` are translated into an error envelope:
    ``PolicyViolation`` -> 401, ``NotFound`` -> 404, and ``InvalidFormat``,
    ``KeyError`` or any other ``ValueError`` -> 400. Every other exception
    propagates to the caller unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        data = {}
        try:
            data = func(*args, **kwargs)
            if not data.get("errors"):
                data["status"] = "success"
                data["code"] = 200
            else:
                data["status"] = "error"
                data["message"] = "Validation Errors"
                data["code"] = 422
        except PolicyViolation as e:
            data.update(status="error", message=str(e), code=401)
        except NotFound as e:
            data.update(status="error", message=str(e), code=404)
        except (InvalidFormat, KeyError, ValueError) as e:
            data.update(status="error", message=str(e), code=400)
        # The envelope is stamped and returned after the try statement rather
        # than from a ``finally`` clause: returning inside ``finally`` discards
        # any in-flight exception, which used to hide unexpected failures
        # behind a response that had neither "status" nor "code".
        data["timestamp"] = datetime.datetime.utcnow().isoformat()
        data["api_version"] = __version__
        return data

    return wrapper
---|
62 | |
---|
63 | |
---|
class Policy(object):
    """Per-table, per-method access rules.

    Rules live in ``self.info``, a dict keyed by table name (``"*"`` is the
    fallback entry used when a table has no entry of its own). Each value is
    a deep copy of ``model`` updated through :meth:`set`.
    """

    # Template of the attributes accepted for each method; also the set of
    # keys that set() will accept.
    model = {
        "POST": {"authorize": False, "fields": None},
        "PUT": {"authorize": False, "fields": None},
        "DELETE": {"authorize": False},
        "GET": {
            "authorize": False,
            "fields": None,
            "query": None,
            "allowed_patterns": [],
            "denied_patterns": [],
            "limit": MAX_LIMIT,
            "allow_lookup": False,
        },
    }

    def __init__(self):
        self.info = {}

    def _table_policy(self, tablename):
        """Return the rules for ``tablename``, else those for "*", else None."""
        return self.info.get(tablename) or self.info.get("*")

    @staticmethod
    def _deny(message, exceptions):
        """Raise PolicyViolation(message) if ``exceptions``, else return False."""
        if exceptions:
            raise PolicyViolation(message)
        return False

    def set(self, tablename, method, **attributes):
        """Update the rules of ``method`` (case-insensitive) for ``tablename``.

        Raises InvalidFormat when the method is unknown or when an attribute
        is not a key of ``model[method]``.
        """
        method = method.upper()
        if method not in self.model:
            raise InvalidFormat("Invalid policy method: %s" % method)
        invalid_keys = [key for key in attributes if key not in self.model[method]]
        if invalid_keys:
            raise InvalidFormat("Invalid keys: %s" % ",".join(invalid_keys))
        if tablename not in self.info:
            self.info[tablename] = copy.deepcopy(self.model)
        self.info[tablename][method].update(attributes)

    def get(self, tablename, method, name):
        """Return attribute ``name`` of ``method`` for ``tablename``.

        A callable attribute is called and its result returned. Raises
        PolicyViolation when neither the table nor "*" has rules.
        """
        policy = self._table_policy(tablename)
        if not policy:
            raise PolicyViolation("No policy for this object")
        # Rules are stored under upper-case method names (see set()).
        return maybe_call(policy[method.upper()][name])

    def check_if_allowed(
        self, method, tablename, id=None, get_vars=None, post_vars=None, exceptions=True
    ):
        """Tell whether ``method`` on ``tablename`` is permitted.

        The request is denied when no rules exist for the table or method,
        when ``authorize`` is False, or when ``authorize`` is a callable that
        returns a falsy value for ``(tablename, id, get_vars, post_vars)``.
        For methods whose rules define patterns, every key of ``get_vars``
        must match no ``denied_patterns`` entry and must match an
        ``allowed_patterns`` entry (or ``"**"`` must be listed).

        Returns True when allowed. On denial, raises PolicyViolation if
        ``exceptions`` is true, otherwise returns False.
        """
        get_vars = get_vars or {}
        post_vars = post_vars or {}
        policy = self._table_policy(tablename)
        if not policy:
            return self._deny("No policy for this object", exceptions)
        policy = policy.get(method.upper())
        if not policy:
            return self._deny("No policy for this method", exceptions)
        authorize = policy.get("authorize")
        if authorize is False or (
            callable(authorize) and not authorize(tablename, id, get_vars, post_vars)
        ):
            return self._deny("Not authorized", exceptions)
        # Only the GET model defines patterns. Indexing them unconditionally
        # raised KeyError for POST/PUT/DELETE whenever get_vars was non-empty.
        if "allowed_patterns" not in policy and "denied_patterns" not in policy:
            return True
        denied_patterns = policy.get("denied_patterns") or []
        allowed_patterns = policy.get("allowed_patterns") or []
        allow_any = "**" in allowed_patterns
        for key in get_vars:
            if any(fnmatch.fnmatch(key, p) for p in denied_patterns):
                return self._deny("Pattern is not allowed", exceptions)
            if not allow_any and not any(
                fnmatch.fnmatch(key, p) for p in allowed_patterns
            ):
                return self._deny("Pattern is not explicitely allowed", exceptions)
        return True

    def check_if_lookup_allowed(self, tablename, exceptions=True):
        """Tell whether ``tablename`` may be reached through a lookup.

        Allowed only when the GET rules have a truthy ``allow_lookup``.
        Returns True when allowed. On denial, raises PolicyViolation if
        ``exceptions`` is true, otherwise returns False.
        """
        policy = self._table_policy(tablename)
        if not policy:
            return self._deny("No policy for this object", exceptions)
        policy = policy.get("GET")
        if not policy:
            return self._deny("No policy for this method", exceptions)
        if not policy.get("allow_lookup"):
            # Previously this returned False even with exceptions=True, so a
            # caller relying on the exception never saw the denial.
            return self._deny("Lookup is not allowed", exceptions)
        return True

    def allowed_fieldnames(self, table, method="GET"):
        """Return the field names of ``table`` usable with ``method``.

        Uses the ``fields`` rule when it is set and non-empty. Otherwise
        returns the names of the fields whose ``readable`` (for GET) or
        ``writable`` (any other method) attribute is truthy, calling the
        attribute first when it is callable.

        Raises PolicyViolation when neither the table nor "*" has rules.
        """
        method = method.upper()
        policy = self._table_policy(table._tablename)
        if not policy:
            raise PolicyViolation("No policy for this object")
        # DELETE rules carry no "fields" key, hence .get().
        allowed_fieldnames = policy[method].get("fields")
        if not allowed_fieldnames:
            allowed_fieldnames = [
                f.name
                for f in table
                if (method == "GET" and maybe_call(f.readable))
                or (method != "GET" and maybe_call(f.writable))
            ]
        return allowed_fieldnames

    def check_fieldnames(self, table, fieldnames, method="GET"):
        """Raise InvalidFormat if any of ``fieldnames`` is not allowed."""
        allowed_fieldnames = self.allowed_fieldnames(table, method)
        invalid_fieldnames = set(fieldnames) - set(allowed_fieldnames)
        if invalid_fieldnames:
            raise InvalidFormat("Invalid fields: %s" % list(invalid_fieldnames))
---|
171 | |
---|
172 | |
---|
# A policy with no rules registered at all.
DENY_ALL_POLICY = Policy()

# A policy whose "*" fallback entry authorizes every method; GET additionally
# accepts any query key ("**") and permits lookups.
ALLOW_ALL_POLICY = Policy()
ALLOW_ALL_POLICY.set(
    "*", "GET", authorize=True, allowed_patterns=["**"], allow_lookup=True
)
ALLOW_ALL_POLICY.set("*", "POST", authorize=True)
ALLOW_ALL_POLICY.set("*", "PUT", authorize=True)
ALLOW_ALL_POLICY.set("*", "DELETE", authorize=True)
---|
185 | |
---|
186 | |
---|
class RestAPI(object):
    """Callable REST facade over a database object, optionally policy-guarded.

    ``db`` is used as ``db.tables``, ``db[tablename]`` and ``db(query)``;
    ``policy`` is a Policy instance or a falsy value to disable checks.
    """

    # "table" or "table[f1,f2,...]". Anchored at the end so that trailing
    # garbage or an unbalanced bracket list is rejected.
    re_table_and_fields = re.compile(r"\w+(\[\w+(,\w+)*\])?\Z")
    # One "@lookup" item: optional "name:" / "name!:" prefix followed by a
    # dotted path whose segments may each carry a "[f1,f2]" field list.
    re_lookups = re.compile(
        r"((\w*\!?\:)?(\w+(\[\w+(,\w+)*\])?)(\.\w+(\[\w+(,\w+)*\])?)*)"
    )
    re_no_brackets = re.compile(r"\[.*?\]")

    def __init__(self, db, policy):
        self.db = db
        self.policy = policy

    @error_wrapper
    def __call__(self, method, tablename, id=None, get_vars=None, post_vars=None):
        """Run one API request and return its result dict.

        ``method`` is GET, POST, PUT or DELETE (case-insensitive).
        ``tablename`` may carry a field list (``"thing[name,color]"``), which
        only GET uses. ``id`` selects a record; PUT and DELETE fall back to
        ``post_vars["id"]``.

        Raises InvalidFormat for an unknown table or method, or a missing id;
        NotFound when PUT/DELETE match no record; PolicyViolation when the
        policy denies the request.
        """
        method = method.upper()
        get_vars = get_vars or {}
        post_vars = post_vars or {}
        # validate incoming request
        tname, tfieldnames = RestAPI.parse_table_and_fields(tablename)
        if tname not in self.db.tables:
            raise InvalidFormat("Invalid table name: %s" % tname)
        if self.policy:
            # Policies and tables are keyed by the bare table name, so the
            # parsed name is used here, not the raw "table[fields]" string.
            self.policy.check_if_allowed(method, tname, id, get_vars, post_vars)
            if method in ("POST", "PUT"):
                self.policy.check_fieldnames(self.db[tname], post_vars.keys(), method)
        # apply rules
        if method == "GET":
            if id:
                # Copy first so the caller's mapping is not modified.
                get_vars = dict(get_vars)
                get_vars["id.eq"] = id
            return self.search(tablename, get_vars)
        elif method == "POST":
            table = self.db[tname]
            return table.validate_and_insert(**post_vars).as_dict()
        elif method == "PUT":
            # .get() so that a missing id reaches the explicit error below
            # instead of surfacing as a bare KeyError.
            id = id or post_vars.get("id")
            if not id:
                raise InvalidFormat("No item id specified")
            table = self.db[tname]
            data = table.validate_and_update(id, **post_vars).as_dict()
            if not data.get("errors") and not data.get("updated"):
                raise NotFound("Item not found")
            return data
        elif method == "DELETE":
            id = id or post_vars.get("id")
            if not id:
                raise InvalidFormat("No item id specified")
            table = self.db[tname]
            deleted = self.db(table._id == id).delete()
            if not deleted:
                raise NotFound("Item not found")
            return {"deleted": deleted}
        raise InvalidFormat("Invalid method: %s" % method)

    def table_model(self, table, fieldnames):
        """Return a list of dicts describing the fields of ``table``.

        Only fields allowed for GET by the policy (all fields when there is
        no policy) are described; a non-empty ``fieldnames`` narrows the
        list further.
        """
        items = []
        fields = post_fields = put_fields = table.fields
        if self.policy:
            fields = self.policy.allowed_fieldnames(table, method="GET")
            put_fields = self.policy.allowed_fieldnames(table, method="PUT")
            post_fields = self.policy.allowed_fieldnames(table, method="POST")
        for fieldname in fields:
            if fieldnames and fieldname not in fieldnames:
                continue
            field = table[fieldname]
            item = {"name": field.name, "label": field.label}
            # https://github.com/collection-json/extensions/blob/master/template-validation.md
            item["default"] = maybe_call(field.default)
            # The first word of the type is the type name (minus any
            # "(...)" suffix); a second word, if present, is reported as
            # the referenced target.
            parts = field.type.split()
            item["type"] = parts[0].split("(")[0]
            if len(parts) > 1:
                item["references"] = parts[1]
            if hasattr(field, "regex"):
                item["regex"] = field.regex
            item["required"] = field.required
            item["unique"] = field.unique
            item["post_writable"] = field.name in post_fields
            item["put_writable"] = field.name in put_fields
            item["options"] = field.options
            if field.type == "id":
                # Only referencing tables that the policy allows for GET.
                item["referenced_by"] = [
                    "%s.%s" % (f._tablename, f.name)
                    for f in table._referenced_by
                    if self.policy
                    and self.policy.check_if_allowed(
                        "GET", f._tablename, exceptions=False
                    )
                ]
            items.append(item)
        return items

    @staticmethod
    def make_query(field, condition, value):
        """Build the expression ``field <condition> value``.

        ``condition`` is one of eq, ne, lt, gt, le, ge, startswith, in,
        contains; any other value raises KeyError. For "in", a string value
        is split on commas and any other value is converted to a list.
        """
        expression = {
            "eq": lambda: field == value,
            # "ne" previously built an equality test.
            "ne": lambda: field != value,
            "lt": lambda: field < value,
            "gt": lambda: field > value,
            "le": lambda: field <= value,
            "ge": lambda: field >= value,
            "startswith": lambda: field.startswith(str(value)),
            "in": lambda: field.belongs(
                value.split(",") if isinstance(value, str) else list(value)
            ),
            "contains": lambda: field.contains(value),
        }
        return expression[condition]()

    @staticmethod
    def parse_table_and_fields(text):
        """Split ``"table"`` or ``"table[f1,f2]"`` into ``(name, fieldnames)``.

        ``fieldnames`` is an empty list when no bracket list is given.
        Raises InvalidFormat (a ValueError) when ``text`` has another shape.
        """
        if not RestAPI.re_table_and_fields.match(text):
            raise InvalidFormat("Invalid table and fields: %s" % text)
        name, bracket, rest = text.partition("[")
        if not bracket:
            return name, []
        # ``rest`` is "f1,f2]"; drop the closing bracket.
        return name, rest[:-1].split(",")

    def search(self, tname, vars):
        """Select rows of ``tname`` according to the request variables.

        ``tname`` may carry a field list. Keys of ``vars`` starting with "@"
        are options: ``@offset``, ``@limit`` (capped by the policy limit, or
        MAX_LIMIT without a policy), ``@order`` (comma-separated field names,
        "~" prefix for descending), ``@lookup``, ``@model`` and
        ``@options_list`` (both true when the value starts with "t").
        Every other key is a filter ``[not.]path[.condition]`` whose path has
        one to four dot-separated segments; the condition defaults to "eq".

        Returns a dict with "items", plus "count" when the offset is 0 and
        "model" when requested.
        """

        def check_table_permission(tablename):
            if self.policy:
                self.policy.check_if_allowed("GET", tablename)

        def check_table_lookup_permission(tablename):
            # The permission check can report denial through its return
            # value, so the result is inspected rather than discarded.
            if self.policy and not self.policy.check_if_lookup_allowed(tablename):
                raise PolicyViolation("Lookup is not allowed")

        def filter_fieldnames(table, fieldnames):
            if self.policy:
                if fieldnames:
                    self.policy.check_fieldnames(table, fieldnames)
                else:
                    fieldnames = self.policy.allowed_fieldnames(table)
            elif not fieldnames:
                fieldnames = table.fields
            return fieldnames

        conditions = (
            "eq",
            "ne",
            "gt",
            "lt",
            "ge",
            "le",
            "startswith",
            "contains",
            "in",
        )

        db = self.db
        tname, tfieldnames = RestAPI.parse_table_and_fields(tname)
        check_table_permission(tname)
        table = db[tname]
        tfieldnames = filter_fieldnames(table, tfieldnames)
        offset = 0
        model = False
        options_list = False
        queries = []
        max_limit = MAX_LIMIT
        if self.policy:
            common_query = self.policy.get(tname, "GET", "query")
            if common_query:
                queries.append(common_query)
            policy_limit = self.policy.get(tname, "GET", "limit")
            if policy_limit is not None:
                max_limit = policy_limit
        # The default page size also honours the policy limit; previously
        # the limit was enforced only when the client sent "@limit".
        limit = min(100, max_limit)
        hop1 = collections.defaultdict(list)
        hop2 = collections.defaultdict(list)
        hop3 = collections.defaultdict(list)
        model_fieldnames = tfieldnames
        lookup = {}
        orderby = None
        for key, value in vars.items():
            if key == "@offset":
                offset = int(value)
            elif key == "@limit":
                limit = min(int(value), max_limit)
            elif key == "@order":
                # Names that are not fields of the table are ignored.
                orderby = [
                    ~table[f[1:]] if f[:1] == "~" else table[f]
                    for f in value.split(",")
                    if (f[1:] if f[:1] == "~" else f) in table.fields
                ] or None
            elif key == "@lookup":
                lookup = {item[0]: {} for item in RestAPI.re_lookups.findall(value)}
            elif key == "@model":
                model = str(value).lower()[:1] == "t"
            elif key == "@options_list":
                options_list = str(value).lower()[:1] == "t"
            else:
                key_parts = key.split(".")
                if key_parts[-1] not in conditions:
                    key_parts.append("eq")
                is_negated = key_parts[0] == "not"
                if is_negated:
                    key_parts = key_parts[1:]
                key, condition = key_parts[:-1], key_parts[-1]
                if len(key) == 1:  # example: name.eq=='Chair'
                    query = self.make_query(table[key[0]], condition, value)
                    queries.append(query if not is_negated else ~query)
                elif len(key) == 2:  # example: color.name.eq=='red'
                    hop1[is_negated, key[0]].append((key[1], condition, value))
                elif len(key) == 3:  # example: a.rel.desc.eq=='above'
                    hop2[is_negated, key[0], key[1]].append((key[2], condition, value))
                elif len(key) == 4:  # example: a.rel.b.name.eq == 'Table'
                    hop3[is_negated, key[0], key[1], key[2]].append(
                        (key[3], condition, value)
                    )

        # Filters on the table referenced by one of this table's fields.
        for item in hop1:
            is_negated, fieldname = item
            ref_tablename = table[fieldname].type.split(" ")[1]
            ref_table = db[ref_tablename]
            subqueries = [self.make_query(ref_table[k], c, v) for k, c, v in hop1[item]]
            subquery = functools.reduce(lambda a, b: a & b, subqueries)
            query = table[fieldname].belongs(db(subquery)._select(ref_table._id))
            queries.append(query if not is_negated else ~query)

        # Filters on a table that points at this one through ``linkfield``.
        for item in hop2:
            is_negated, linkfield, linktable = item
            ref_table = db[linktable]
            subqueries = [self.make_query(ref_table[k], c, v) for k, c, v in hop2[item]]
            subquery = functools.reduce(lambda a, b: a & b, subqueries)
            query = table._id.belongs(db(subquery)._select(ref_table[linkfield]))
            queries.append(query if not is_negated else ~query)

        # Filters on the table reached through a link table's ``otherfield``.
        for item in hop3:
            is_negated, linkfield, linktable, otherfield = item
            ref_table = db[linktable]
            ref_ref_tablename = ref_table[otherfield].type.split(" ")[1]
            ref_ref_table = db[ref_ref_tablename]
            subqueries = [
                self.make_query(ref_ref_table[k], c, v) for k, c, v in hop3[item]
            ]
            subquery = functools.reduce(lambda a, b: a & b, subqueries)
            subquery &= ref_ref_table._id == ref_table[otherfield]
            query = table._id.belongs(
                db(subquery)._select(ref_table[linkfield], groupby=ref_table[linkfield])
            )
            queries.append(query if not is_negated else ~query)

        if not queries:
            queries.append(table)

        query = functools.reduce(lambda a, b: a & b, queries)
        tfields = [table[tfieldname] for tfieldname in tfieldnames]
        rows = db(query).select(
            *tfields, limitby=(offset, limit + offset), orderby=orderby
        )

        lookup_map = {}
        for key in list(lookup.keys()):
            # "name:path" renames the output key; "name!:path" also collapses.
            name, key = key.split(":") if ":" in key else ("", key)
            clean_key = RestAPI.re_no_brackets.sub("", key)
            lookup_map[clean_key] = {
                "name": name.rstrip("!") or clean_key,
                "collapsed": name.endswith("!"),
            }
            key = key.split(".")

            if len(key) == 1:
                # Expand a reference field of this table into its record.
                key, tfieldnames = RestAPI.parse_table_and_fields(key[0])
                ref_tablename = table[key].type.split(" ")[1]
                ref_table = db[ref_tablename]
                check_table_lookup_permission(ref_tablename)
                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
                ids = [row[key] for row in rows]
                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
                # The id is needed to key the result; it is removed again
                # below when the caller did not ask for it.
                if "id" not in tfieldnames:
                    tfields.append(ref_table["id"])
                drows = db(ref_table._id.belongs(ids)).select(*tfields).as_dict()
                if tfieldnames and "id" not in tfieldnames:
                    for row in drows.values():
                        del row["id"]
                lkey, collapsed = lookup_map[key]["name"], lookup_map[key]["collapsed"]
                for row in rows:
                    new_row = drows.get(row[key])
                    if new_row and collapsed:
                        del row[key]
                        for rkey in new_row:
                            row[lkey + "_" + rkey] = new_row[rkey]
                    else:
                        row[lkey] = new_row

            elif len(key) == 2:
                # Attach the rows of another table whose ``lfield`` holds
                # the id of a selected row.
                lfield, key = key
                key, tfieldnames = RestAPI.parse_table_and_fields(key)
                check_table_lookup_permission(key)
                ref_table = db[key]
                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
                ids = [row["id"] for row in rows]
                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
                if lfield not in tfieldnames:
                    tfields.append(ref_table[lfield])
                lrows = db(ref_table[lfield].belongs(ids)).select(*tfields)
                drows = collections.defaultdict(list)
                for row in lrows:
                    row = row.as_dict()
                    drows[row[lfield]].append(row)
                    if lfield not in tfieldnames:
                        del row[lfield]
                lkey = lookup_map[lfield + "." + key]["name"]
                for row in rows:
                    row[lkey] = drows.get(row.id, [])

            elif len(key) == 3:
                # As above, additionally joining the table referenced by
                # the link table's ``rfield``.
                lfield, key, rfield = key
                key, tfieldnames = RestAPI.parse_table_and_fields(key)
                rfield, tfieldnames2 = RestAPI.parse_table_and_fields(rfield)
                check_table_lookup_permission(key)
                ref_table = db[key]
                ref_ref_tablename = ref_table[rfield].type.split(" ")[1]
                check_table_lookup_permission(ref_ref_tablename)
                ref_ref_table = db[ref_ref_tablename]
                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
                tfieldnames2 = filter_fieldnames(ref_ref_table, tfieldnames2)
                ids = [row["id"] for row in rows]
                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
                if lfield not in tfieldnames:
                    tfields.append(ref_table[lfield])
                if rfield not in tfieldnames:
                    tfields.append(ref_table[rfield])
                tfields += [ref_ref_table[tfieldname] for tfieldname in tfieldnames2]
                left = ref_ref_table.on(ref_table[rfield] == ref_ref_table["id"])
                lrows = db(ref_table[lfield].belongs(ids)).select(*tfields, left=left)
                drows = collections.defaultdict(list)
                lkey = lfield + "." + key + "." + rfield
                lkey, collapsed = (
                    lookup_map[lkey]["name"],
                    lookup_map[lkey]["collapsed"],
                )
                for row in lrows:
                    row = row.as_dict()
                    new_row = row[key]
                    lfield_value, rfield_value = new_row[lfield], new_row[rfield]
                    if lfield not in tfieldnames:
                        del new_row[lfield]
                    if rfield not in tfieldnames:
                        del new_row[rfield]
                    if collapsed:
                        new_row.update(row[ref_ref_tablename])
                    else:
                        new_row[rfield] = row[ref_ref_tablename]
                    drows[lfield_value].append(new_row)
                for row in rows:
                    row[lkey] = drows.get(row.id, [])

        response = {}
        if not options_list:
            response["items"] = rows.as_list()
        else:
            if table._format:
                response["items"] = [
                    dict(value=row.id, text=(table._format % row)) for row in rows
                ]
            else:
                response["items"] = [dict(value=row.id, text=row.id) for row in rows]
        if offset == 0:
            response["count"] = db(query).count()
        if model:
            response["model"] = self.table_model(table, model_fieldnames)
        return response
---|