source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/packages/dal/pydal/adapters/oracle.py

main
Last change on this file was 42bd667, checked in by David Fuertes <dfuertes@…>, 4 years ago

Historial Limpio

  • Property mode set to 100755
File size: 8.9 KB
Line 
1import re
2import sys
3import os
4from .._compat import integer_types, long
5from ..helpers.classes import Reference
6from ..helpers.methods import use_common_filters
7from .base import SQLAdapter
8from ..objects import Table, Field, Expression, Query
9from . import adapters, with_connection, with_connection_or_raise
10
11
@adapters.register_for("oracle")
class Oracle(SQLAdapter):
    """pyDAL adapter for Oracle databases, driven through cx_Oracle."""

    dbengine = "oracle"
    drivers = ("cx_Oracle",)

    # Matches the first ``:CLOB('...')`` / ``:BLOB('...')`` literal that is
    # preceded only by balanced single-quoted segments; the named group
    # ``clob`` spans the literal without its leading colon. Raw string so the
    # regex escapes (``\:``, ``\(``) are not treated as Python string escapes.
    cmd_fix = re.compile(
        r"[^']*('[^']*'[^']*)*\:(?P<clob>(C|B)LOB\('([^']+|'')*'\))"
    )

    def _initialize_(self, do_connect):
        """Derive the driver connect string and fill in default driver args."""
        super(Oracle, self)._initialize_(do_connect)
        # The cx_Oracle connect string is everything after the scheme's "://".
        self.ruri = self.uri.split("://", 1)[1]
        # Caller-supplied driver args take precedence over these defaults.
        self.driver_args.setdefault("threaded", True)
        # Character-encoding defaults for CHAR and NCHAR data.
        self.driver_args.setdefault("encoding", "UTF-8")
        self.driver_args.setdefault("nencoding", "UTF-8")

    def connector(self):
        """Open a new driver connection using ``ruri`` and ``driver_args``."""
        return self.driver.connect(self.ruri, **self.driver_args)

    def after_connection(self):
        """Pin the session's DATE and TIMESTAMP text formats."""
        self.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS';")
        self.execute(
            "ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS';"
        )

    def test_connection(self):
        """Run a trivial query against DUAL to check the connection."""
        self.execute("SELECT 1 FROM DUAL;")

    @with_connection
    def close_connection(self):
        """Drop the adapter's reference to the current connection.

        NOTE(review): this only sets ``self.connection`` to None; whether the
        driver connection is actually closed depends on ``with_connection``
        and the base adapter -- confirm.
        """
        self.connection = None

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Execute a SQL command on the adapter's cursor.

        ``args[0]`` is the SQL text. Any further positional arguments are bind
        values. Inline ``:CLOB('...')`` / ``:BLOB('...')`` literals are
        rewritten into numbered bind variables (``:1``, ``:2``, ...). Their
        unescaped contents are appended to the positional bind values. A
        trailing ``;`` is removed, and the registered execution handlers are
        notified before and after the cursor call.

        Returns whatever ``cursor.execute`` returns.
        """
        command = self.filter_sql_command(args[0])
        params = args[1:]
        position = 1
        while True:
            match = self.cmd_fix.match(command)
            if not match:
                break
            literal = match.group("clob")
            command = (
                command[: match.start("clob")]
                + str(position)
                + command[match.end("clob") :]
            )
            # Drop the "CLOB('" / "BLOB('" prefix and the "')" suffix, then
            # undo SQL quote doubling.
            params += (literal[6:-2].replace("''", "'"),)
            position += 1
        if command.endswith(";"):
            command = command[:-1]
        handlers = self._build_handlers_for_execution()
        for handler in handlers:
            handler.before_execute(command)
        if params:
            rv = self.cursor.execute(command, params, **kwargs)
        else:
            rv = self.cursor.execute(command, **kwargs)
        for handler in handlers:
            handler.after_execute(command)
        return rv

    def lastrowid(self, table):
        """Return the current value of ``table``'s id sequence as an integer."""
        sequence_name = table._sequence_name
        self.execute("SELECT %s.currval FROM dual;" % sequence_name)
        return long(self.cursor.fetchone()[0])

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Return the quoted table reference for use in FROM clauses.

        When ``original_tablename`` is given, return ``"original" "alias"``
        separated by a single space. Oracle rejects ``AS`` before a table
        alias.
        """
        if original_tablename is None:
            return self.dialect.quote(tablename)
        return "%s %s" % (
            self.dialect.quote(original_tablename),
            self.dialect.quote(tablename),
        )

    def create_sequence_and_triggers(self, query, table, **args):
        """Run the CREATE TABLE ``query``, plus id sequence and trigger.

        Tables without an ``_id`` field only get ``query`` executed (and its
        result returned). Otherwise a sequence is created, together with a
        BEFORE INSERT trigger (see the module-level ``_trigger_sql``) that
        assigns ids from it.
        """
        if "_id" not in table:
            return self.execute(query)
        tablename = table._rname
        id_name = table._id._rname
        sequence_name = table._sequence_name
        trigger_name = table._trigger_name
        self.execute(query)
        self.execute(
            """
            CREATE SEQUENCE %s START WITH 1 INCREMENT BY 1 NOMAXVALUE
            MINVALUE -1;"""
            % sequence_name
        )
        self.execute(
            _trigger_sql
            % dict(
                trigger_name=self.dialect.quote(trigger_name),
                tablename=self.dialect.quote(tablename),
                sequence_name=self.dialect.quote(sequence_name),
                id=self.dialect.quote(id_name),
            )
        )

    def _select_aux_execute(self, sql):
        """Execute ``sql`` and return all rows via :meth:`fetchall`."""
        self.execute(sql)
        return self.fetchall()

    def fetchall(self):
        """Fetch all remaining rows, reading LOB values eagerly.

        If any result column is a LOB type, rows are built one at a time
        while iterating the cursor, and every LOB value is replaced by its
        ``read()`` contents. cx_Oracle LOB locators stop being valid after
        further fetches, so they cannot be read after a plain ``fetchall()``.
        Otherwise the cursor's own ``fetchall()`` result is returned.
        """
        from ..drivers import cx_Oracle

        # NCLOB and BLOB columns yield LOB handles just like CLOB, so they
        # need the same eager read.
        lob_types = tuple(
            getattr(cx_Oracle, name)
            for name in ("LOB", "CLOB", "NCLOB", "BLOB")
            if hasattr(cx_Oracle, name)
        )
        if not any(col[1] in lob_types for col in self.cursor.description):
            return self.cursor.fetchall()
        return [
            tuple(
                value.read() if isinstance(value, cx_Oracle.LOB) else value
                for value in row
            )
            for row in self.cursor
        ]

    def _expand(self, expression, field_type=None, colnames=False, query_env=None):
        """Expand ``expression`` to SQL, always quoting Field references.

        Fields are rendered via ``dialect.sqlsafe``, or ``dialect.longname``
        when ``colnames`` is true. A field whose type is not string-like is
        CAST to the adapter's ``text`` type when ``field_type`` is "string".
        All other expressions are delegated to the base implementation.
        """
        # Fresh dict per call instead of a shared mutable default.
        if query_env is None:
            query_env = {}
        if not isinstance(expression, Field):
            return super(Oracle, self)._expand(
                expression, field_type, colnames, query_env
            )
        if colnames:
            rv = self.dialect.longname(expression)
        else:
            rv = self.dialect.sqlsafe(expression)
        if field_type == "string" and expression.type not in (
            "string",
            "text",
            "json",
            "password",
        ):
            rv = self.dialect.cast(rv, self.types["text"], query_env)
        return str(rv)

    def expand(self, expression, field_type=None, colnames=False, query_env=None):
        """Public entry point; same as :meth:`_expand`."""
        return self._expand(expression, field_type, colnames, query_env)

    def _build_value_for_insert(self, field, value, r_values):
        """Return the SQL fragment for ``value`` inside an INSERT.

        ``text`` fields are sent as named bind variables instead of inlined
        literals. The value is stored in ``r_values`` under the bind name, and
        ``":<name>"`` is returned. Other field types are expanded to SQL.
        """
        if field.type == "text":
            bind_name = field._rname
            # Strip surrounding double quotes so the bind name is a bare
            # identifier (the old ``_rname[1]`` check tested the wrong index
            # and raised IndexError for one-character names).
            if len(bind_name) > 2 and bind_name[0] == '"' and bind_name[-1] == '"':
                bind_name = bind_name[1:-1]
            r_values[bind_name] = value
            return ":" + bind_name
        return self.expand(value, field.type)

    def _update(self, table, query, fields):
        """Build the UPDATE statement for ``fields`` on rows matching ``query``."""
        sql_q = ""
        query_env = dict(current_scope=[table._tablename])
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            sql_q = self.expand(query, query_env=query_env)
        sql_v = ",".join(
            "%s=%s"
            % (
                self.dialect.quote(field._rname),
                self.expand(value, field.type, query_env=query_env),
            )
            for field, value in fields
        )
        return self.dialect.update(table, sql_v, sql_q)

    def _insert(self, table, fields):
        """Build an INSERT statement and its named bind values.

        Returns ``(sql, r_values)``. ``r_values`` maps bind names to values
        for ``text`` fields. An empty ``fields`` yields the dialect's empty
        insert and ``None``.
        """
        if not fields:
            return self.dialect.insert_empty(table._rname), None
        r_values = {}
        columns = ",".join(self.dialect.quote(f._rname) for f, _ in fields)
        values = ",".join(
            self._build_value_for_insert(f, v, r_values) for f, v in fields
        )
        return self.dialect.insert(table._rname, columns, values), r_values

    def insert(self, table, fields):
        """Insert one row and return its identity.

        Returns:
          * ``table._on_insert_error(...)``'s result if execution fails and
            the table defines that hook;
          * a dict of the supplied primary-key values for keyed tables;
          * otherwise the sequence's current value, wrapped in a
            ``Reference`` when it is an integer, or in a one-item dict for
            single-primary-key tables.
        """
        query, values = self._insert(table, fields)
        try:
            if not values:
                self.execute(query)
            elif isinstance(values, dict):
                # Named bind variables collected for ``text`` fields.
                self.execute(query, **values)
            else:
                self.execute(query, values)
        except Exception as e:
            if hasattr(table, "_on_insert_error"):
                return table._on_insert_error(table, fields, e)
            # Bare ``raise`` keeps the original traceback intact.
            raise
        if hasattr(table, "_primarykey"):
            pkdict = {
                field.name: value
                for field, value in fields
                if field.name in table._primarykey
            }
            if pkdict:
                return pkdict
        new_id = self.lastrowid(table)
        if hasattr(table, "_primarykey") and len(table._primarykey) == 1:
            new_id = {table._primarykey[0]: new_id}
        if not isinstance(new_id, integer_types):
            return new_id
        rid = Reference(new_id)
        (rid._table, rid._record) = (table, None)
        return rid

    def _regex_select_as_parser(self, colname):
        """Search ``colname`` for a whitespace-preceded ``"alias"``.

        Returns the match object, or None.
        """
        return re.search(r'\s+"(\S+)"', colname)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse ``rows``, discarding a trailing paging ROWNUM column if present."""
        if len(rows) and len(rows[0]) == len(fields) + 1 and type(rows[0][-1]) == int:
            # paging has added a trailing rownum column to be discarded
            rows = [row[:-1] for row in rows]
        return super(Oracle, self).parse(rows, fields, colnames, blob_decode, cacheable)
232
233
# PL/SQL template for the BEFORE INSERT trigger created by
# Oracle.create_sequence_and_triggers(). It is filled with %-formatting using
# the keys trigger_name, tablename, sequence_name and id, each passed through
# dialect.quote() by that method.
#
# Every inserted row receives its id from the sequence's nextval. When the
# INSERT supplies an explicit id, the trigger first advances the sequence by
# temporarily changing its INCREMENT BY to (id - nextval - 1), so the final
# nextval assigned to :NEW.id equals the supplied id. The increment is then
# restored to 1.
# NOTE(review): PRAGMA autonomous_transaction is presumably needed because
# ALTER SEQUENCE is DDL and cannot run in the inserting transaction -- confirm.
_trigger_sql = """
    CREATE OR REPLACE TRIGGER %(trigger_name)s BEFORE INSERT ON %(tablename)s FOR EACH ROW
    DECLARE
        curr_val NUMBER;
        diff_val NUMBER;
        PRAGMA autonomous_transaction;
    BEGIN
        IF :NEW.%(id)s IS NOT NULL THEN
            EXECUTE IMMEDIATE 'SELECT %(sequence_name)s.nextval FROM dual' INTO curr_val;
            diff_val := :NEW.%(id)s - curr_val - 1;
            IF diff_val != 0 THEN
            EXECUTE IMMEDIATE 'alter sequence %(sequence_name)s increment by '|| diff_val;
            EXECUTE IMMEDIATE 'SELECT %(sequence_name)s.nextval FROM dual' INTO curr_val;
            EXECUTE IMMEDIATE 'alter sequence %(sequence_name)s increment by 1';
            END IF;
        END IF;
        SELECT %(sequence_name)s.nextval INTO :NEW.%(id)s FROM DUAL;
    END;
"""
Note: See TracBrowser for help on using the repository browser.