source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/packages/dal/pydal/restapi.py

main
Last change on this file was 42bd667, checked in by David Fuertes <dfuertes@…>, 4 years ago

Historial Limpio

  • Property mode set to 100755
File size: 21.3 KB
Line 
1import collections
2import copy
3import datetime
4import fnmatch
5import functools
6import re
7
8__version__ = "0.1"
9
10__all__ = ["RestAPI", "Policy", "ALLOW_ALL_POLICY", "DENY_ALL_POLICY"]
11
12MAX_LIMIT = 1000
13
14
15class PolicyViolation(ValueError):
16    pass
17
18
19class InvalidFormat(ValueError):
20    pass
21
22
23class NotFound(ValueError):
24    pass
25
26
27def maybe_call(value):
28    return value() if callable(value) else value
29
30
31def error_wrapper(func):
32    @functools.wraps(func)
33    def wrapper(*args, **kwargs):
34        data = {}
35        try:
36            data = func(*args, **kwargs)
37            if not data.get("errors"):
38                data["status"] = "success"
39                data["code"] = 200
40            else:
41                data["status"] = "error"
42                data["message"] = "Validation Errors"
43                data["code"] = 422
44        except PolicyViolation as e:
45            data["status"] = "error"
46            data["message"] = str(e)
47            data["code"] = 401
48        except NotFound as e:
49            data["status"] = "error"
50            data["message"] = str(e)
51            data["code"] = 404
52        except (InvalidFormat, KeyError, ValueError) as e:
53            data["status"] = "error"
54            data["message"] = str(e)
55            data["code"] = 400
56        finally:
57            data["timestamp"] = datetime.datetime.utcnow().isoformat()
58            data["api_version"] = __version__
59        return data
60
61    return wrapper
62
63
64class Policy(object):
65
66    model = {
67        "POST": {"authorize": False, "fields": None},
68        "PUT": {"authorize": False, "fields": None},
69        "DELETE": {"authorize": False},
70        "GET": {
71            "authorize": False,
72            "fields": None,
73            "query": None,
74            "allowed_patterns": [],
75            "denied_patterns": [],
76            "limit": MAX_LIMIT,
77            "allow_lookup": False,
78        },
79    }
80
81    def __init__(self):
82        self.info = {}
83
84    def set(self, tablename, method, **attributes):
85        method = method.upper()
86        if not method in self.model:
87            raise InvalidFormat("Invalid policy method: %s" % method)
88        invalid_keys = [key for key in attributes if key not in self.model[method]]
89        if invalid_keys:
90            raise InvalidFormat("Invalid keys: %s" % ",".join(invalid_keys))
91        if not tablename in self.info:
92            self.info[tablename] = copy.deepcopy(self.model)
93        self.info[tablename][method].update(attributes)
94
95    def get(self, tablename, method, name):
96        policy = self.info.get(tablename) or self.info.get("*")
97        if not policy:
98            raise PolicyViolation("No policy for this object")
99        return maybe_call(policy[method][name])
100
101    def check_if_allowed(
102        self, method, tablename, id=None, get_vars=None, post_vars=None, exceptions=True
103    ):
104        get_vars = get_vars or {}
105        post_vars = post_vars or {}
106        policy = self.info.get(tablename) or self.info.get("*")
107        if not policy:
108            if exceptions:
109                raise PolicyViolation("No policy for this object")
110            return False
111        policy = policy.get(method.upper())
112        if not policy:
113            if exceptions:
114                raise PolicyViolation("No policy for this method")
115            return False
116        authorize = policy.get("authorize")
117        if authorize is False or (
118            callable(authorize) and not authorize(tablename, id, get_vars, post_vars)
119        ):
120            if exceptions:
121                raise PolicyViolation("Not authorized")
122            return False
123        for key in get_vars:
124            if any(fnmatch.fnmatch(key, p) for p in policy["denied_patterns"]):
125                if exceptions:
126                    raise PolicyViolation("Pattern is not allowed")
127                return False
128            allowed_patterns = policy["allowed_patterns"]
129            if "**" not in allowed_patterns and not any(
130                fnmatch.fnmatch(key, p) for p in allowed_patterns
131            ):
132                if exceptions:
133                    raise PolicyViolation("Pattern is not explicitely allowed")
134                return False
135        return True
136
137    def check_if_lookup_allowed(self, tablename, exceptions=True):
138        policy = self.info.get(tablename) or self.info.get("*")
139        if not policy:
140            if exceptions:
141                raise PolicyViolation("No policy for this object")
142            return False
143        policy = policy.get("GET")
144        if not policy:
145            if exceptions:
146                raise PolicyViolation("No policy for this method")
147            return False
148        if policy.get("allow_lookup"):
149            return True
150        return False
151
152    def allowed_fieldnames(self, table, method="GET"):
153        method = method.upper()
154        policy = self.info.get(table._tablename) or self.info.get("*")
155        policy = policy[method]
156        allowed_fieldnames = policy["fields"]
157        if not allowed_fieldnames:
158            allowed_fieldnames = [
159                f.name
160                for f in table
161                if (method == "GET" and maybe_call(f.readable))
162                or (method != "GET" and maybe_call(f.writable))
163            ]
164        return allowed_fieldnames
165
166    def check_fieldnames(self, table, fieldnames, method="GET"):
167        allowed_fieldnames = self.allowed_fieldnames(table, method)
168        invalid_fieldnames = set(fieldnames) - set(allowed_fieldnames)
169        if invalid_fieldnames:
170            raise InvalidFormat("Invalid fields: %s" % list(invalid_fieldnames))
171
172
173DENY_ALL_POLICY = Policy()
174ALLOW_ALL_POLICY = Policy()
175ALLOW_ALL_POLICY.set(
176    tablename="*",
177    method="GET",
178    authorize=True,
179    allowed_patterns=["**"],
180    allow_lookup=True,
181)
182ALLOW_ALL_POLICY.set(tablename="*", method="POST", authorize=True)
183ALLOW_ALL_POLICY.set(tablename="*", method="PUT", authorize=True)
184ALLOW_ALL_POLICY.set(tablename="*", method="DELETE", authorize=True)
185
186
187class RestAPI(object):
188
189    re_table_and_fields = re.compile(r"\w+([\w+(,\w+)+])?")
190    re_lookups = re.compile(
191        r"((\w*\!?\:)?(\w+(\[\w+(,\w+)*\])?)(\.\w+(\[\w+(,\w+)*\])?)*)"
192    )
193    re_no_brackets = re.compile(r"\[.*?\]")
194
195    def __init__(self, db, policy):
196        self.db = db
197        self.policy = policy
198
199    @error_wrapper
200    def __call__(self, method, tablename, id=None, get_vars=None, post_vars=None):
201        method = method.upper()
202        get_vars = get_vars or {}
203        post_vars = post_vars or {}
204        # validate incoming request
205        tname, tfieldnames = RestAPI.parse_table_and_fields(tablename)
206        if not tname in self.db.tables:
207            raise InvalidFormat("Invalid table name: %s" % tname)
208        if self.policy:
209            self.policy.check_if_allowed(method, tablename, id, get_vars, post_vars)
210            if method in ["POST", "PUT"]:
211                self.policy.check_fieldnames(
212                    self.db[tablename], post_vars.keys(), method
213                )
214        # apply rules
215        if method == "GET":
216            if id:
217                get_vars["id.eq"] = id
218            return self.search(tablename, get_vars)
219        elif method == "POST":
220            table = self.db[tablename]
221            return table.validate_and_insert(**post_vars).as_dict()
222        elif method == "PUT":
223            id = id or post_vars["id"]
224            if not id:
225                raise InvalidFormat("No item id specified")
226            table = self.db[tablename]
227            data = table.validate_and_update(id, **post_vars).as_dict()
228            if not data.get("errors") and not data.get("updated"):
229                raise NotFound("Item not found")
230            return data
231        elif method == "DELETE":
232            id = id or post_vars["id"]
233            if not id:
234                raise InvalidFormat("No item id specified")
235            table = self.db[tablename]
236            deleted = self.db(table._id == id).delete()
237            if not deleted:
238                raise NotFound("Item not found")
239            return {"deleted": deleted}
240
241    def table_model(self, table, fieldnames):
242        """ converts a table into its form template """
243        items = []
244        fields = post_fields = put_fields = table.fields
245        if self.policy:
246            fields = self.policy.allowed_fieldnames(table, method="GET")
247            put_fields = self.policy.allowed_fieldnames(table, method="PUT")
248            post_fields = self.policy.allowed_fieldnames(table, method="POST")
249        for fieldname in fields:
250            if fieldnames and not fieldname in fieldnames:
251                continue
252            field = table[fieldname]
253            item = {"name": field.name, "label": field.label}
254            # https://github.com/collection-json/extensions/blob/master/template-validation.md
255            item["default"] = (
256                field.default() if callable(field.default) else field.default
257            )
258            parts = field.type.split()
259            item["type"] = parts[0].split("(")[0]
260            if len(parts) > 1:
261                item["references"] = parts[1]
262            if hasattr(field, "regex"):
263                item["regex"] = field.regex
264            item["required"] = field.required
265            item["unique"] = field.unique
266            item["post_writable"] = field.name in post_fields
267            item["put_writable"] = field.name in put_fields
268            item["options"] = field.options
269            if field.type == "id":
270                item["referenced_by"] = [
271                    "%s.%s" % (f._tablename, f.name)
272                    for f in table._referenced_by
273                    if self.policy
274                    and self.policy.check_if_allowed(
275                        "GET", f._tablename, exceptions=False
276                    )
277                ]
278            items.append(item)
279        return items
280
281    @staticmethod
282    def make_query(field, condition, value):
283        expression = {
284            "eq": lambda: field == value,
285            "ne": lambda: field == value,
286            "lt": lambda: field < value,
287            "gt": lambda: field > value,
288            "le": lambda: field <= value,
289            "ge": lambda: field >= value,
290            "startswith": lambda: field.startswith(str(value)),
291            "in": lambda: field.belongs(
292                value.split(",") if isinstance(value, str) else list(value)
293            ),
294            "contains": lambda: field.contains(value),
295        }
296        return expression[condition]()
297
298    @staticmethod
299    def parse_table_and_fields(text):
300        if not RestAPI.re_table_and_fields.match(text):
301            raise ValueError
302        parts = text.split("[")
303        if len(parts) == 1:
304            return parts[0], []
305        elif len(parts) == 2:
306            return parts[0], parts[1][:-1].split(",")
307
308    def search(self, tname, vars):
309        def check_table_permission(tablename):
310            if self.policy:
311                self.policy.check_if_allowed("GET", tablename)
312
313        def check_table_lookup_permission(tablename):
314            if self.policy:
315                self.policy.check_if_lookup_allowed(tablename)
316
317        def filter_fieldnames(table, fieldnames):
318            if self.policy:
319                if fieldnames:
320                    self.policy.check_fieldnames(table, fieldnames)
321                else:
322                    fieldnames = self.policy.allowed_fieldnames(table)
323            elif not fieldnames:
324                fieldnames = table.fields
325            return fieldnames
326
327        db = self.db
328        tname, tfieldnames = RestAPI.parse_table_and_fields(tname)
329        check_table_permission(tname)
330        tfieldnames = filter_fieldnames(db[tname], tfieldnames)
331        query = []
332        offset = 0
333        limit = 100
334        model = False
335        options_list = False
336        table = db[tname]
337        queries = []
338        if self.policy:
339            common_query = self.policy.get(tname, "GET", "query")
340            if common_query:
341                queries.append(common_query)
342        hop1 = collections.defaultdict(list)
343        hop2 = collections.defaultdict(list)
344        hop3 = collections.defaultdict(list)
345        model_fieldnames = tfieldnames
346        lookup = {}
347        orderby = None
348        for key, value in vars.items():
349            if key == "@offset":
350                offset = int(value)
351            elif key == "@limit":
352                limit = min(
353                    int(value),
354                    self.policy.get(tname, "GET", "limit")
355                    if self.policy
356                    else MAX_LIMIT,
357                )
358            elif key == "@order":
359                orderby = [
360                    ~table[f[1:]] if f[:1] == "~" else table[f]
361                    for f in value.split(",")
362                    if (f[1:] if f[:1] == "~" else f) in table.fields
363                ] or None
364            elif key == "@lookup":
365                lookup = {item[0]: {} for item in RestAPI.re_lookups.findall(value)}
366            elif key == "@model":
367                model = str(value).lower()[:1] == "t"
368            elif key == "@options_list":
369                options_list = str(value).lower()[:1] == "t"
370            else:
371                key_parts = key.rsplit(".")
372                if not key_parts[-1] in (
373                    "eq",
374                    "ne",
375                    "gt",
376                    "lt",
377                    "ge",
378                    "le",
379                    "startswith",
380                    "contains",
381                    "in",
382                ):
383                    key_parts.append("eq")
384                is_negated = key_parts[0] == "not"
385                if is_negated:
386                    key_parts = key_parts[1:]
387                key, condition = key_parts[:-1], key_parts[-1]
388                if len(key) == 1:  # example: name.eq=='Chair'
389                    query = self.make_query(table[key[0]], condition, value)
390                    queries.append(query if not is_negated else ~query)
391                elif len(key) == 2:  # example: color.name.eq=='red'
392                    hop1[is_negated, key[0]].append((key[1], condition, value))
393                elif len(key) == 3:  # example: a.rel.desc.eq=='above'
394                    hop2[is_negated, key[0], key[1]].append((key[2], condition, value))
395                elif len(key) == 4:  # example: a.rel.b.name.eq == 'Table'
396                    hop3[is_negated, key[0], key[1], key[2]].append(
397                        (key[3], condition, value)
398                    )
399
400        for item in hop1:
401            is_negated, fieldname = item
402            ref_tablename = table[fieldname].type.split(" ")[1]
403            ref_table = db[ref_tablename]
404            subqueries = [self.make_query(ref_table[k], c, v) for k, c, v in hop1[item]]
405            subquery = functools.reduce(lambda a, b: a & b, subqueries)
406            query = table[fieldname].belongs(db(subquery)._select(ref_table._id))
407            queries.append(query if not is_negated else ~query)
408
409        for item in hop2:
410            is_negated, linkfield, linktable = item
411            ref_table = db[linktable]
412            subqueries = [self.make_query(ref_table[k], c, v) for k, c, v in hop2[item]]
413            subquery = functools.reduce(lambda a, b: a & b, subqueries)
414            query = table._id.belongs(db(subquery)._select(ref_table[linkfield]))
415            queries.append(query if not is_negated else ~query)
416
417        for item in hop3:
418            is_negated, linkfield, linktable, otherfield = item
419            ref_table = db[linktable]
420            ref_ref_tablename = ref_table[otherfield].type.split(" ")[1]
421            ref_ref_table = db[ref_ref_tablename]
422            subqueries = [
423                self.make_query(ref_ref_table[k], c, v) for k, c, v in hop3[item]
424            ]
425            subquery = functools.reduce(lambda a, b: a & b, subqueries)
426            subquery &= ref_ref_table._id == ref_table[otherfield]
427            query = table._id.belongs(
428                db(subquery)._select(ref_table[linkfield], groupby=ref_table[linkfield])
429            )
430            queries.append(query if not is_negated else ~query)
431
432        if not queries:
433            queries.append(table)
434
435        query = functools.reduce(lambda a, b: a & b, queries)
436        tfields = [table[tfieldname] for tfieldname in tfieldnames]
437        rows = db(query).select(
438            *tfields, limitby=(offset, limit + offset), orderby=orderby
439        )
440
441        lookup_map = {}
442        for key in list(lookup.keys()):
443            name, key = key.split(":") if ":" in key else ("", key)
444            clean_key = RestAPI.re_no_brackets.sub("", key)
445            lookup_map[clean_key] = {
446                "name": name.rstrip("!") or clean_key,
447                "collapsed": name.endswith("!"),
448            }
449            key = key.split(".")
450
451            if len(key) == 1:
452                key, tfieldnames = RestAPI.parse_table_and_fields(key[0])
453                ref_tablename = table[key].type.split(" ")[1]
454                ref_table = db[ref_tablename]
455                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
456                check_table_lookup_permission(ref_tablename)
457                ids = [row[key] for row in rows]
458                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
459                if not "id" in tfieldnames:
460                    tfields.append(ref_table["id"])
461                drows = db(ref_table._id.belongs(ids)).select(*tfields).as_dict()
462                if tfieldnames and not "id" in tfieldnames:
463                    for row in drows.values():
464                        del row["id"]
465                lkey, collapsed = lookup_map[key]["name"], lookup_map[key]["collapsed"]
466                for row in rows:
467                    new_row = drows.get(row[key])
468                    if new_row and collapsed:
469                        del row[key]
470                        for rkey in new_row:
471                            row[lkey + "_" + rkey] = new_row[rkey]
472                    else:
473                        row[lkey] = new_row
474
475            elif len(key) == 2:
476                lfield, key = key
477                key, tfieldnames = RestAPI.parse_table_and_fields(key)
478                check_table_lookup_permission(key)
479                ref_table = db[key]
480                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
481                ids = [row["id"] for row in rows]
482                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
483                if not lfield in tfieldnames:
484                    tfields.append(ref_table[lfield])
485                lrows = db(ref_table[lfield].belongs(ids)).select(*tfields)
486                drows = collections.defaultdict(list)
487                for row in lrows:
488                    row = row.as_dict()
489                    drows[row[lfield]].append(row)
490                    if not lfield in tfieldnames:
491                        del row[lfield]
492                lkey = lookup_map[lfield + "." + key]["name"]
493                for row in rows:
494                    row[lkey] = drows.get(row.id, [])
495
496            elif len(key) == 3:
497                lfield, key, rfield = key
498                key, tfieldnames = RestAPI.parse_table_and_fields(key)
499                rfield, tfieldnames2 = RestAPI.parse_table_and_fields(rfield)
500                check_table_lookup_permission(key)
501                ref_table = db[key]
502                ref_ref_tablename = ref_table[rfield].type.split(" ")[1]
503                check_table_lookup_permission(ref_ref_tablename)
504                ref_ref_table = db[ref_ref_tablename]
505                tfieldnames = filter_fieldnames(ref_table, tfieldnames)
506                tfieldnames2 = filter_fieldnames(ref_ref_table, tfieldnames2)
507                ids = [row["id"] for row in rows]
508                tfields = [ref_table[tfieldname] for tfieldname in tfieldnames]
509                if not lfield in tfieldnames:
510                    tfields.append(ref_table[lfield])
511                if not rfield in tfieldnames:
512                    tfields.append(ref_table[rfield])
513                tfields += [ref_ref_table[tfieldname] for tfieldname in tfieldnames2]
514                left = ref_ref_table.on(ref_table[rfield] == ref_ref_table["id"])
515                lrows = db(ref_table[lfield].belongs(ids)).select(*tfields, left=left)
516                drows = collections.defaultdict(list)
517                lkey = lfield + "." + key + "." + rfield
518                lkey, collapsed = (
519                    lookup_map[lkey]["name"],
520                    lookup_map[lkey]["collapsed"],
521                )
522                for row in lrows:
523                    row = row.as_dict()
524                    new_row = row[key]
525                    lfield_value, rfield_value = new_row[lfield], new_row[rfield]
526                    if not lfield in tfieldnames:
527                        del new_row[lfield]
528                    if not rfield in tfieldnames:
529                        del new_row[rfield]
530                    if collapsed:
531                        new_row.update(row[ref_ref_tablename])
532                    else:
533                        new_row[rfield] = row[ref_ref_tablename]
534                    drows[lfield_value].append(new_row)
535                for row in rows:
536                    row[lkey] = drows.get(row.id, [])
537
538        response = {}
539        if not options_list:
540            response["items"] = rows.as_list()
541        else:
542            if table._format:
543                response["items"] = [
544                    dict(value=row.id, text=(table._format % row)) for row in rows
545                ]
546            else:
547                response["items"] = [dict(value=row.id, text=row.id) for row in rows]
548        if offset == 0:
549            response["count"] = db(query).count()
550        if model:
551            response["model"] = self.table_model(table, model_fieldnames)
552        return response
Note: See TracBrowser for help on using the repository browser.