source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/contrib/pymysql/cursors.py

main
Last change on this file was 42bd667, checked in by David Fuertes <dfuertes@…>, 4 years ago

Historial Limpio

  • Property mode set to 100755
File size: 16.4 KB
Line 
1# -*- coding: utf-8 -*-
2from __future__ import print_function, absolute_import
3from functools import partial
4import re
5import warnings
6
7from ._compat import range_type, text_type, PY2
8from . import err
9
10
#: Regular expression used by :meth:`Cursor.executemany`.
#:
#: executemany() only supports simple bulk INSERT/REPLACE statements; for those
#: it batches many rows into one statement, which helps when loading large
#: datasets. Capture groups:
#:   1. the statement prefix, up to and including ``VALUES`` and the
#:      whitespace after it;
#:   2. the parenthesised placeholder tuple (``%s`` or ``%(name)s`` items);
#:   3. an optional trailing ``ON DUPLICATE ...`` clause (may be empty).
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?)\Z",
    re.IGNORECASE | re.DOTALL,
)
19
20
class Cursor(object):
    """
    This is the object you use to interact with the database.

    Results are buffered: after each query the whole result set is stored
    in ``self._rows`` and the fetch*/scroll methods move ``self.rownumber``
    through it. Do not instantiate directly; call
    ``connections.Connection.cursor()``.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    #: When true, :meth:`_do_get_result` does not report warnings immediately;
    #: subclasses then call :meth:`_show_warnings` themselves (e.g. once an
    #: unbuffered result has been fully read).
    _defer_warnings = False

    def __init__(self, connection):
        """
        Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().

        :param connection: the connection this cursor sends queries through.
        """
        self.connection = connection
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        # PEP 249: lastrowid is None until an operation sets it, so reading
        # it before the first execute() must not raise AttributeError.
        self.lastrowid = None
        self._executed = None
        self._result = None
        self._rows = None
        self._warnings_handled = False

    def close(self):
        """
        Closing a cursor just exhausts all remaining data.

        All pending result sets are consumed with nextset(), then the cursor
        is detached from its connection (even if consuming them raised).
        Closing an already closed cursor does nothing.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            self.connection = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        """Return the connection, or raise ProgrammingError if closed."""
        if not self.connection:
            raise err.ProgrammingError("Cursor closed")
        return self.connection

    def _check_executed(self):
        """Raise ProgrammingError unless a query has been executed."""
        if not self._executed:
            raise err.ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Hook for converting a raw row; the base cursor returns it as-is."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    def _nextset(self, unbuffered=False):
        """Get the next query set

        :param unbuffered: read the next result in unbuffered mode.
        :return: True if the connection advanced to another result set and
            it was loaded into this cursor; None if this cursor's result is
            not the connection's current one or no further result exists.
        """
        conn = self._get_db()
        current_result = self._result
        # for unbuffered queries warnings are only available once whole result has been read
        if unbuffered:
            self._show_warnings()
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        """Advance to the next (buffered) result set; see :meth:`_nextset`."""
        return self._nextset(False)

    def _ensure_bytes(self, x, encoding=None):
        """Encode text to bytes, recursing into tuples and lists.

        Values of any other type are returned unchanged.
        """
        if isinstance(x, text_type):
            x = x.encode(encoding)
        elif isinstance(x, (tuple, list)):
            x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
        return x

    def _escape_args(self, args, conn):
        """Escape query parameters for %-interpolation into a query.

        :return: a tuple of ``conn.literal`` values for a tuple/list, a dict
            of them for a dict, otherwise ``conn.escape(args)``. On Python 2
            text is encoded to bytes with the connection encoding first.
        """
        ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)

        if isinstance(args, (tuple, list)):
            if PY2:
                args = tuple(map(ensure_bytes, args))
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            if PY2:
                args = dict((ensure_bytes(key), ensure_bytes(val)) for
                            (key, val) in args.items())
            return dict((key, conn.literal(val)) for (key, val) in args.items())
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            if PY2:
                args = ensure_bytes(args)
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string that is sent to the database by calling the
        execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.

        :param query: query with ``%s`` / ``%(name)s`` placeholders.
        :param args: parameters to escape and interpolate; when None the
            query is returned without %-formatting.
        """
        conn = self._get_db()
        if PY2:  # Use bytes on Python 2 always
            query = self._ensure_bytes(query, encoding=conn.encoding)

        if args is not None:
            query = query % self._escape_args(args, conn)

        return query

    def execute(self, query, args=None):
        """Execute a query

        :param str query: Query to execute.

        :param args: parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: Number of affected rows
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        # Drain result sets left over from a previous statement first.
        while self.nextset():
            pass

        query = self.mogrify(query, args)

        result = self._query(query)
        self._executed = query
        return result

    def executemany(self, query, args):
        # type: (str, list) -> int
        """Run several data against one query

        :param query: query to execute on server
        :param args:  Sequence of sequences or mappings.  It is used as parameter.
        :return: Number of rows affected, if any (None when ``args`` is empty).

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # "% ()" turns "%%" escapes in the prefix into literal "%".
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return self._do_execute_many(q_prefix, q_values, q_postfix, args,
                                         self.max_stmt_length,
                                         self._get_db().encoding)

        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
        """Send ``prefix + values,values,... + postfix`` in as few statements as fit.

        Rows are appended to the current statement until adding one more
        would exceed ``max_stmt_length`` bytes; then the statement is executed
        and a new one is started.

        :return: total affected rows (also stored in ``self.rowcount``);
            0 if ``args`` yields no parameter sets at all.
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, text_type):
            prefix = prefix.encode(encoding)
        if PY2 and isinstance(values, text_type):
            values = values.encode(encoding)
        if isinstance(postfix, text_type):
            postfix = postfix.encode(encoding)

        def render(arg):
            # Interpolate one escaped parameter set into the VALUES template
            # and return it as bytes.
            v = values % escape(arg, conn)
            if isinstance(v, text_type):
                if PY2:
                    v = v.encode(encoding)
                else:
                    v = v.encode(encoding, 'surrogateescape')
            return v

        args = iter(args)
        try:
            first = next(args)
        except StopIteration:
            # An empty iterator (e.g. a generator, which is truthy and so
            # passes executemany's "not args" check): nothing to insert.
            # Mirrors the fallback path, where sum() over no rows gives 0.
            self.rowcount = 0
            return 0

        sql = bytearray(prefix)
        sql += render(first)
        rows = 0
        for arg in args:
            v = render(arg)
            # +1 accounts for the separating comma.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        procname -- string, name of procedure to execute on server

        args -- Sequence of parameters to use with procedure

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        # Store each argument in a server variable @_<procname>_<index>.
        for index, arg in enumerate(args):
            q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
            self._query(q)
            self.nextset()

        q = "CALL %s(%s)" % (procname,
                             ','.join(['@_%s_%d' % (procname, i)
                                       for i in range_type(len(args))]))
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row, or None when no rows remain."""
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows

        :param size: number of rows; a falsy value means ``self.arraysize``.
        :return: up to ``size`` rows, or ``()`` if there is no result set.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        end = self.rownumber + (size or self.arraysize)
        result = self._rows[self.rownumber:end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the remaining rows (``()`` if there is no result set)."""
        self._check_executed()
        if self._rows is None:
            return ()
        if self.rownumber:
            result = self._rows[self.rownumber:]
        else:
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode='relative'):
        """Move the read position within the buffered result set.

        :param value: offset (``mode='relative'``) or target row index
            (``mode='absolute'``).
        :param mode: ``'relative'`` or ``'absolute'``.
        :raise ProgrammingError: for an unknown mode, or before execute().
        :raise IndexError: if the target lies outside the result set,
            including when the last statement produced no result set.
        """
        self._check_executed()
        if mode == 'relative':
            r = self.rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        # _rows is None when the statement returned no result set; treat it
        # as empty so callers get IndexError rather than TypeError.
        row_count = len(self._rows) if self._rows is not None else 0
        if not (0 <= r < row_count):
            raise IndexError("out of range")
        self.rownumber = r

    def _query(self, q):
        """Send an already-escaped query and load its result into the cursor."""
        conn = self._get_db()
        self._last_executed = q
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _do_get_result(self):
        """Copy the connection's current result into the cursor attributes."""
        conn = self._get_db()

        self.rownumber = 0
        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows
        self._warnings_handled = False

        if not self._defer_warnings:
            self._show_warnings()

    def _show_warnings(self):
        """Report the server's warnings for the current result, at most once.

        Nothing is reported while further result sets are pending or when
        the result carries no warnings. Each row returned by
        ``show_warnings()`` is emitted as ``err.Warning(code, message)``
        through :func:`warnings.warn`.
        """
        if self._warnings_handled:
            return
        self._warnings_handled = True
        if self._result and (self._result.has_next or not self._result.warning_count):
            return
        ws = self._get_db().show_warnings()
        if ws is None:
            return
        for w in ws:
            # Rows are expected to be SHOW WARNINGS rows: (Level, Code, Message).
            code, msg = w[1], w[2]
            if PY2 and isinstance(msg, text_type):
                # Python 2: pass a UTF-8 byte string, presumably to avoid
                # unicode/bytes mixing when the warning is formatted. (This
                # converted value used to be computed and then discarded.)
                msg = msg.encode('utf-8', 'replace')
            warnings.warn(err.Warning(code, msg), stacklevel=4)

    def __iter__(self):
        """Iterate over the remaining rows by calling fetchone() until None."""
        return iter(self.fetchone, None)

    # DB-API optional extension: exception classes exposed on the cursor.
    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError
370
371
class DictCursorMixin(object):
    """Mixin that turns every fetched row into a mapping keyed by column name.

    Combine it with a cursor class (it must come first in the bases) so that
    its ``_do_get_result`` / ``_conv_row`` override the cursor's versions.
    """

    # Mapping type built for every row; override it to use OrderedDict or
    # another dict-like type.
    dict_type = dict

    def _do_get_result(self):
        """Load the result as usual, then compute column keys and convert rows.

        A column name that was already seen is prefixed with its table name
        (``table.column``) to keep keys distinct.
        """
        super(DictCursorMixin, self)._do_get_result()

        names = []
        if self.description:
            for column in self._result.fields:
                key = column.name
                if key in names:
                    key = column.table_name + '.' + key
                names.append(key)
            self._fields = names

        if not names or not self._rows:
            return
        convert = self._conv_row
        self._rows = [convert(raw) for raw in self._rows]

    def _conv_row(self, row):
        """Zip ``self._fields`` with ``row`` into a ``dict_type``; None stays None."""
        if row is not None:
            return self.dict_type(zip(self._fields, row))
        return None
394
395
class DictCursor(DictCursorMixin, Cursor):
    """Buffered cursor that returns each row as a dictionary.

    Keys are column names (repeated names are prefixed with their table
    name); the mapping type comes from ``DictCursorMixin.dict_type``.
    """
398
399
class SSCursor(Cursor):
    """
    Unbuffered cursor, mainly useful for queries that return a lot of data
    or for connections to remote servers over a slow network.

    Rows are read from the connection on demand instead of being copied into
    a client-side buffer first, so the client uses much less memory and the
    first rows arrive sooner over a slow network or for very large results.

    Limitations: the MySQL protocol does not report the total number of rows,
    so the only way to count them is to iterate over all of them. Scrolling
    backwards is not possible either, since only the current row is held in
    memory.
    """

    # Warnings only become available once the whole result has been read,
    # so they are reported from the fetch methods instead.
    _defer_warnings = True

    def _conv_row(self, row):
        """Return the raw row unchanged (subclasses may convert it)."""
        return row

    def close(self):
        """Finish any unread result, drain remaining result sets and detach."""
        db = self.connection
        if db is None:
            return

        pending = self._result
        if pending is not None and pending is db._result:
            # This cursor still owns the connection's active result: let it
            # read whatever is left of the unbuffered query.
            pending._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self.connection = None

    def _query(self, q):
        """Send an already-escaped query in unbuffered mode."""
        db = self._get_db()
        self._last_executed = q
        db.query(q, unbuffered=True)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        """Advance to the next result set, read in unbuffered mode."""
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row"""
        raw = self._result._read_rowdata_packet_unbuffered()
        return self._conv_row(raw)

    def fetchone(self):
        """Fetch the next row; None (after reporting warnings) when exhausted."""
        self._check_executed()
        row = self.read_next()
        if row is not None:
            self.rownumber += 1
            return row
        self._show_warnings()
        return None

    def fetchall(self):
        """
        Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered. See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.
        """
        remaining = list(self.fetchall_unbuffered())
        return remaining

    def fetchall_unbuffered(self):
        """
        Fetch all, implemented as a generator, which isn't to standard,
        however, it doesn't make sense to return everything in a list, as that
        would use ridiculous memory for large result sets.
        """
        return iter(self.fetchone, None)

    def __iter__(self):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch up to ``size`` rows (``self.arraysize`` when size is None)."""
        self._check_executed()
        limit = self.arraysize if size is None else size

        batch = []
        for _ in range_type(limit):
            row = self.read_next()
            if row is None:
                self._show_warnings()
                break
            batch.append(row)
            self.rownumber += 1
        return batch

    def scroll(self, value, mode='relative'):
        """Skip forward by reading and discarding rows.

        :raise NotSupportedError: for any backwards movement.
        :raise ProgrammingError: for an unknown mode.
        """
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise err.NotSupportedError(
                        "Backwards scrolling not supported by this cursor")
            skip = value
            target = self.rownumber + value
        elif mode == 'absolute':
            if value < self.rownumber:
                raise err.NotSupportedError(
                    "Backwards scrolling not supported by this cursor")
            skip = value - self.rownumber
            target = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        for _ in range_type(skip):
            self.read_next()
        self.rownumber = target
516
517
class SSDictCursor(DictCursorMixin, SSCursor):
    """Unbuffered cursor that returns each row as a dictionary.

    Rows are streamed one at a time as in SSCursor; each one is converted by
    ``DictCursorMixin._conv_row`` into a mapping of column name to value.
    """
Note: See TracBrowser for help on using the repository browser.