source: OpenRLabs-Git/deploy/rlabs-docker/web2py-rlabs/gluon/contrib/pymysql/connections.py

main
Last change on this file was 42bd667, checked in by David Fuertes <dfuertes@…>, 4 years ago

Historial Limpio

  • Property mode set to 100755
File size: 54.4 KB
Line 
1# Python implementation of the MySQL client-server protocol
2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
3# Error codes:
4# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
5from __future__ import print_function
6from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
7
8import errno
9from functools import partial
10import hashlib
11import io
12import os
13import socket
14import struct
15import sys
16import traceback
17import warnings
18
19from .charset import MBLENGTH, charset_by_name, charset_by_id
20from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
21from .converters import escape_item, escape_string, through, conversions as _conv
22from .cursors import Cursor
23from .optionfile import Parser
24from .util import byte2int, int2byte
25from . import err
26
27try:
28    import ssl
29    SSL_ENABLED = True
30except ImportError:
31    ssl = None
32    SSL_ENABLED = False
33
try:
    import getpass
    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # No entry in the OS user database for the current user:
    # KeyError is raised on Python <= 3.12, OSError on Python 3.13+.
    DEFAULT_USER = None
41
42
43DEBUG = False
44
45_py_version = sys.version_info[:2]
46
47
48# socket.makefile() in Python 2 is not usable because very inefficient and
49# bad behavior about timeout.
50# XXX: ._socketio doesn't work under IronPython.
51if _py_version == (2, 7) and not IRONPYTHON:
52    # read method of file-like returned by sock.makefile() is very slow.
53    # So we copy io-based one from Python 3.
54    from ._socketio import SocketIO
55
56    def _makefile(sock, mode):
57        return io.BufferedReader(SocketIO(sock, mode))
58elif _py_version == (2, 6):
59    # Python 2.6 doesn't have fast io module.
60    # So we make original one.
61    class SockFile(object):
62        def __init__(self, sock):
63            self._sock = sock
64
65        def read(self, n):
66            read = self._sock.recv(n)
67            if len(read) == n:
68                return read
69            while True:
70                data = self._sock.recv(n-len(read))
71                if not data:
72                    return read
73                read += data
74                if len(read) == n:
75                    return read
76
77    def _makefile(sock, mode):
78        assert mode == 'rb'
79        return SockFile(sock)
80else:
81    # socket.makefile in Python 3 is nice.
82    def _makefile(sock, mode):
83        return sock.makefile(mode)
84
85
# Field types this module groups as "text" types: string/blob column types
# plus BIT and GEOMETRY. (How the set is consumed is outside this chunk.)
TEXT_TYPES = set((
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
))
96
# SHA-1 constructor used by the password scramblers: sha_new(data=b'').
sha_new = partial(hashlib.new, 'sha1')

# First-byte markers of a length-encoded integer, as interpreted by
# MysqlPacket.read_length_encoded_integer(): 0xFB (251) means NULL, any
# smaller value is the integer itself, and 0xFC / 0xFD / 0xFE introduce a
# 2-, 3- and 8-byte little-endian integer respectively.
NULL_COLUMN = 251
UNSIGNED_CHAR_COLUMN = 251
UNSIGNED_SHORT_COLUMN = 252
UNSIGNED_INT24_COLUMN = 253
UNSIGNED_INT64_COLUMN = 254

# Charset used when the caller does not pass one.
DEFAULT_CHARSET = 'latin1'

# Largest payload a single packet can carry (3-byte length field).
MAX_PACKET_LEN = (1 << 24) - 1
108
109
def dump_packet(data):  # pragma: no cover
    """Print a debugging hex/ASCII dump of *data*.

    The header shows the packet length and up to five calling frames. Only
    the first 256 bytes are dumped, 16 per row. Bytes whose value lies in
    65..122 are shown as characters; all others are shown as '.'.
    """
    def printable(ch):
        if not 65 <= byte2int(ch) <= 122:
            return '.'
        return chr(ch) if isinstance(ch, int) else ch

    try:
        print("packet length:", len(data))
        # _getframe must be called from this frame so depth 1 is our caller.
        for depth in range(1, 6):
            frame = sys._getframe(depth)
            print("call[%d]: %s (line %d)" % (depth, frame.f_code.co_name, frame.f_lineno))
        print("-" * 66)
    except ValueError:
        # Call stack shallower than five frames: skip the rest of the header.
        pass
    for start in range_type(0, min(len(data), 256), 16):
        row = data[start:start + 16]
        hex_part = ' '.join("{:02X}".format(byte2int(b)) for b in row)
        text_part = ''.join("{}".format(printable(b)) for b in row)
        print(hex_part + '   ' * (16 - len(row)) + ' ' * 2 + text_part)
    print("-" * 66)
    print()
133
134
def _scramble(password, message):
    """Scramble *password* with the server-supplied *message*.

    Computes SHA1(password) XOR SHA1(message + SHA1(SHA1(password))), or
    returns b'' when *password* is empty.
    """
    if not password:
        return b''
    # The password is deliberately never printed, not even when DEBUG is on:
    # debug output routinely ends up in logs.
    stage1 = sha_new(password).digest()
    stage2 = sha_new(stage1).digest()
    s = sha_new()
    s.update(message)
    s.update(stage2)
    return _my_crypt(s.digest(), stage1)
146
147
def _my_crypt(message1, message2):
    """XOR each byte of *message1* with the byte at the same index in *message2*.

    The result has the length of *message1*. Raises struct.error if
    *message2* is shorter.
    """
    out = io.BytesIO()
    for idx in range_type(len(message1)):
        left, = struct.unpack('B', message1[idx:idx + 1])
        right, = struct.unpack('B', message2[idx:idx + 1])
        out.write(struct.pack('B', left ^ right))
    return out.getvalue()
156
# old_passwords support ported from libmysql/password.c
SCRAMBLE_LENGTH_323 = 8


class RandStruct_323(object):
    """Two-seed pseudo-random generator used by the pre-4.1 scramble."""

    def __init__(self, seed1, seed2):
        self.max_value = 0x3FFFFFFF
        self.seed1, self.seed2 = seed1 % self.max_value, seed2 % self.max_value

    def my_rnd(self):
        """Advance both seeds and return a float in [0, 1)."""
        modulus = self.max_value
        # seed2 is derived from the *new* seed1 and the *old* seed2.
        next1 = (self.seed1 * 3 + self.seed2) % modulus
        next2 = (next1 + self.seed2 + 33) % modulus
        self.seed1, self.seed2 = next1, next2
        return float(next1) / float(modulus)
171
172
def _scramble_323(password, message):
    """Scramble *password* with *message* using the pre-4.1 algorithm.

    At most SCRAMBLE_LENGTH_323 bytes of *message* are used. The result has
    one byte per used message byte.
    """
    pw_a, pw_b = struct.unpack(">LL", _hash_password_323(password))
    msg_a, msg_b = struct.unpack(">LL", _hash_password_323(message[:SCRAMBLE_LENGTH_323]))
    rng = RandStruct_323(pw_a ^ msg_a, pw_b ^ msg_b)
    count = min(SCRAMBLE_LENGTH_323, len(message))
    chars = [int(rng.my_rnd() * 31) + 64 for _ in range_type(count)]
    # One further draw yields the mask XORed into every output byte.
    mask = int(rng.my_rnd() * 31)
    return b''.join(int2byte(c ^ mask) for c in chars)
190
191
def _hash_password_323(password):
    """Hash *password* with the pre-4.1 (old_passwords) algorithm.

    Spaces and tabs are ignored. Returns 8 bytes: two big-endian words,
    each masked to 31 bits.
    """
    nr, nr2, add = 1345345333, 0x12345671, 7
    mask32 = 0xFFFFFFFF
    # Iterating bytes yields ints on Python 3 and 1-char strs on Python 2,
    # so both spellings of space and tab are skipped.
    for ch in password:
        if ch in (' ', '\t', 32, 9):
            continue
        code = byte2int(ch)
        nr ^= ((((nr & 63) + add) * code) + (nr << 8)) & mask32
        nr2 = (nr2 + ((nr2 << 8) ^ nr)) & mask32
        add = (add + code) & mask32

    low31 = (1 << 31) - 1  # kill sign bits
    return struct.pack(">LL", nr & low31, nr2 & low31)
206
207
def pack_int24(n):
    """Return the low 3 bytes of *n* as a little-endian unsigned integer."""
    packed = struct.pack('<I', n)
    return packed[0:3]
210
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def lenenc_int(i):
    """Encode *i* as a MySQL length-encoded integer.

    Values below 0xFB use one byte. Larger values use a 0xFC, 0xFD or 0xFE
    prefix followed by 2, 3 or 8 little-endian bytes.

    :raise ValueError: if *i* is negative or does not fit in 64 bits.
    """
    if i < 0:
        raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
    if i < 0xfb:
        return int2byte(i)
    if i < 1 << 16:
        return b'\xfc' + struct.pack('<H', i)
    if i < 1 << 24:
        return b'\xfd' + struct.pack('<I', i)[:3]
    if i < 1 << 64:
        return b'\xfe' + struct.pack('<Q', i)
    raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
225
class MysqlPacket(object):
    """Representation of a MySQL response packet.

    Holds the packet payload plus a read cursor. The ``read_*`` helpers
    consume bytes at the cursor and move it forward. The ``is_*`` predicates
    only inspect the first payload byte (and, for some, the length).
    """
    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        # *encoding* is not used by the base class; subclasses such as
        # FieldDescriptorPacket share this constructor signature.
        self._data = data
        self._position = 0

    def get_all_data(self):
        """Return the whole payload, regardless of the cursor."""
        return self._data

    def read(self, size):
        """Return the next *size* bytes and move the cursor past them.

        :raise AssertionError: if fewer than *size* bytes remain.
        """
        start = self._position
        chunk = self._data[start:(start + size)]
        if len(chunk) == size:
            self._position = start + size
            return chunk
        error = ('Result length not requested length:\n'
                 'Expected=%s.  Actual=%s.  Position: %s.  Data Length: %s'
                 % (size, len(chunk), start, len(self._data)))
        if DEBUG:
            print(error)
            self.dump()
        raise AssertionError(error)

    def read_all(self):
        """Return everything after the cursor.

        The cursor is then set to None, so any later read() fails.
        """
        remainder = self._data[self._position:]
        self._position = None
        return remainder

    def advance(self, length):
        """Move the cursor by *length* bytes (may be negative), staying in bounds."""
        target = self._position + length
        if not 0 <= target <= len(self._data):
            raise Exception('Invalid advance amount (%s) for cursor.  '
                            'Position=%s' % (length, target))
        self._position = target

    def rewind(self, position=0):
        """Put the cursor at absolute offset *position* (default: the start)."""
        if not 0 <= position <= len(self._data):
            raise Exception("Invalid position to rewind cursor to: %s." % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """Return *length* payload bytes starting at *position*; cursor untouched.

        *position* is relative to the payload (the 4-byte packet header is not
        included). No bounds checking is done, so a request past the end
        yields a short or empty result.
        """
        return self._data[position:(position + length)]

    if PY2:
        def read_uint8(self):
            """Read one unsigned byte (payload items are 1-char strs on Py2)."""
            value = ord(self._data[self._position])
            self._position += 1
            return value
    else:
        def read_uint8(self):
            """Read one unsigned byte (indexing bytes yields an int on Py3)."""
            value = self._data[self._position]
            self._position += 1
            return value

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer."""
        value, = struct.unpack_from('<H', self._data, self._position)
        self._position += 2
        return value

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer."""
        low16, high8 = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return (high8 << 16) | low16

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer."""
        value, = struct.unpack_from('<I', self._data, self._position)
        self._position += 4
        return value

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer."""
        value, = struct.unpack_from('<Q', self._data, self._position)
        self._position += 8
        return value

    def read_string(self):
        """Read a NUL-terminated string and skip the terminator.

        Returns None, leaving the cursor where it was, if no NUL follows.
        """
        start = self._position
        nul = self._data.find(b'\0', start)
        if nul < 0:
            return None
        self._position = nul + 1
        return self._data[start:nul]

    def read_length_encoded_integer(self):
        """Read a 'Length Coded Binary' number (1 to 9 bytes).

        Returns None for the NULL marker (and for an unrecognised marker).
        """
        first = self.read_uint8()
        if first == NULL_COLUMN:
            return None
        if first < UNSIGNED_CHAR_COLUMN:
            return first
        if first == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        if first == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        if first == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        return None

    def read_length_coded_string(self):
        """Read a 'Length Coded String': a length-encoded integer, then that many bytes.

        For example "cat" is encoded as "3cat". Returns None for NULL.
        """
        length = self.read_length_encoded_integer()
        return None if length is None else self.read(length)

    def read_struct(self, fmt):
        """Unpack *fmt* at the cursor, advance past it, and return the tuple."""
        packer = struct.Struct(fmt)
        values = packer.unpack_from(self._data, self._position)
        self._position += packer.size
        return values

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return len(self._data) >= 7 and self._data[0:1] == b'\0'

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # A leading 0xFE can also start an 8-byte LengthEncodedInteger, so
        # only short packets count as EOF.
        return len(self._data) < 9 and self._data[0:1] == b'\xfe'

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0:1] == b'\xfe'

    def is_resultset_packet(self):
        # The first byte is the column count for a result-set header.
        return 1 <= ord(self._data[0:1]) <= 250

    def is_load_local_packet(self):
        return self._data[0:1] == b'\xfb'

    def is_error_packet(self):
        return self._data[0:1] == b'\xff'

    def check_error(self):
        """Raise the matching MySQL exception if this is an error packet.

        Non-error packets are left untouched.
        """
        if not self.is_error_packet():
            return
        self.rewind()
        self.advance(1)  # skip the 0xFF marker (already checked)
        error_code = self.read_uint16()
        if DEBUG: print("errno =", error_code)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a hex dump of the payload (debug helper)."""
        dump_packet(self._data)
397
398
class FieldDescriptorPacket(MysqlPacket):
    """Column-metadata packet, parsed as soon as it is constructed.

    Exposes catalog and db (undecoded), table_name, org_table, name and
    org_name (decoded with *encoding*), plus charsetnr, length, type_code,
    flags and scale.
    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """Parse the column definition (MySQL 4.1+ layout; not 4.0)."""
        read_str = self.read_length_coded_string
        self.catalog = read_str()
        self.db = read_str()
        self.table_name = read_str().decode(encoding)
        self.org_table = read_str().decode(encoding)
        self.name = read_str().decode(encoding)
        self.org_name = read_str().decode(encoding)
        # Fixed tail: 1 skipped byte, charset (uint16), length (uint32),
        # type (uint8), flags (uint16), scale (uint8), 2 skipped bytes.
        fixed = self.read_struct('<xHIBHBxx')
        (self.charsetnr, self.length, self.type_code,
         self.flags, self.scale) = fixed
        # Any trailing 'default' value is left unread; it is not used for
        # normal result sets.

    def description(self):
        """Return the 7-item PEP 249 description tuple for this column.

        (name, type_code, display_size=None, internal_size, precision,
        scale, null_ok). Both internal_size and precision report
        get_column_length().
        """
        column_length = self.get_column_length()
        # null_ok is True when the lowest flag bit is clear (presumably the
        # NOT_NULL flag -- confirm against constants.FLAG).
        return (self.name, self.type_code, None, column_length,
                column_length, self.scale, self.flags % 2 == 0)

    def get_column_length(self):
        """Return ``length``, divided by MBLENGTH for VAR_STRING columns.

        MBLENGTH is looked up by charset number (default 1); presumably the
        charset's maximum bytes per character.
        """
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        return self.length // MBLENGTH.get(self.charsetnr, 1)

    def __str__(self):
        return ('%s %r.%r.%r, type=%s, flags=%x'
                % (self.__class__, self.db, self.table_name, self.name,
                   self.type_code, self.flags))
447
448
class OKPacketWrapper(object):
    """Parsed view of an OK packet.

    Exposes affected_rows, insert_id, server_status, warning_count, message
    and has_next. Any other attribute is looked up on the wrapped packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError('Cannot create ' + str(self.__class__.__name__) +
                             ' object from invalid packet type')

        packet = from_packet
        self.packet = packet
        packet.advance(1)  # skip the 0x00 header byte

        self.affected_rows = packet.read_length_encoded_integer()
        self.insert_id = packet.read_length_encoded_integer()
        self.server_status, self.warning_count = packet.read_struct('<HH')
        self.message = packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        return getattr(self.packet, key)
472
473
class EOFPacketWrapper(object):
    """Parsed view of an EOF packet.

    Exposes warning_count, server_status and has_next. Any other attribute
    is looked up on the wrapped packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                "Cannot create '{0}' object from invalid packet type".format(
                    self.__class__))

        self.packet = from_packet
        # Layout: 0xFE marker, warning count, status flags. Both fields are
        # unsigned 16-bit values (same as OKPacketWrapper's '<HH'). A signed
        # read would turn flag bit 15, or more than 32767 warnings, into
        # negative numbers.
        self.warning_count, self.server_status = self.packet.read_struct('<xHH')
        if DEBUG: print("server_status=", self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        return getattr(self.packet, key)
494
495
class LoadLocalPacketWrapper(object):
    """Parsed view of a LOAD DATA LOCAL INFILE request packet.

    Exposes ``filename``: the payload after its first byte, not decoded here.
    Any other attribute is looked up on the wrapped packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                "Cannot create '{0}' object from invalid packet type".format(
                    self.__class__))

        self.packet = from_packet
        payload = from_packet.get_all_data()
        self.filename = payload[1:]  # drop the 0xFB marker byte
        if DEBUG:
            print("filename=", self.filename)

    def __getattr__(self, key):
        return getattr(self.packet, key)
515
516
517class Connection(object):
518    """
519    Representation of a socket with a mysql server.
520
521    The proper way to get an instance of this class is to call
522    connect().
523    """
524
525    _sock = None
526    _auth_plugin_name = ''
527
528    def __init__(self, host=None, user=None, password="",
529                 database=None, port=0, unix_socket=None,
530                 charset='', sql_mode=None,
531                 read_default_file=None, conv=None, use_unicode=None,
532                 client_flag=0, cursorclass=Cursor, init_command=None,
533                 connect_timeout=None, ssl=None, read_default_group=None,
534                 compress=None, named_pipe=None, no_delay=None,
535                 autocommit=False, db=None, passwd=None, local_infile=False,
536                 max_allowed_packet=16*1024*1024, defer_connect=False,
537                 auth_plugin_map={}, read_timeout=None, write_timeout=None):
538        """
539        Establish a connection to the MySQL database. Accepts several
540        arguments:
541
542        host: Host where the database server is located
543        user: Username to log in as
544        password: Password to use.
545        database: Database to use, None to not use a particular one.
546        port: MySQL port to use, default is usually OK. (default: 3306)
547        unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
548        charset: Charset you want to use.
549        sql_mode: Default SQL_MODE to use.
550        read_default_file:
551            Specifies  my.cnf file to read these parameters from under the [client] section.
552        conv:
553            Conversion dictionary to use instead of the default one.
554            This is used to provide custom marshalling and unmarshaling of types.
555            See converters.
556        use_unicode:
557            Whether or not to default to unicode strings.
558            This option defaults to true for Py3k.
559        client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
560        cursorclass: Custom cursor class to use.
561        init_command: Initial SQL statement to run when connection is established.
562        connect_timeout: Timeout before throwing an exception when connecting.
563        ssl:
564            A dict of arguments similar to mysql_ssl_set()'s parameters.
565            For now the capath and cipher arguments are not supported.
566        read_default_group: Group to read from in the configuration file.
567        compress; Not supported
568        named_pipe: Not supported
569        autocommit: Autocommit mode. None means use server default. (default: False)
570        local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
571        max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
572            Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
573        defer_connect: Don't explicitly connect on contruction - wait for connect call.
574            (default: False)
575        auth_plugin_map: A dict of plugin names to a class that processes that plugin.
576            The class will take the Connection object as the argument to the constructor.
577            The class needs an authenticate method taking an authentication packet as
578            an argument.  For the dialog plugin, a prompt(echo, prompt) method can be used
579            (if no authenticate method) for returning a string from the user. (experimental)
580        db: Alias for database. (for compatibility to MySQLdb)
581        passwd: Alias for password. (for compatibility to MySQLdb)
582        """
583        if no_delay is not None:
584            warnings.warn("no_delay option is deprecated", DeprecationWarning)
585
586        if use_unicode is None and sys.version_info[0] > 2:
587            use_unicode = True
588
589        if db is not None and database is None:
590            database = db
591        if passwd is not None and not password:
592            password = passwd
593
594        if compress or named_pipe:
595            raise NotImplementedError("compress and named_pipe arguments are not supported")
596
597        if local_infile:
598            client_flag |= CLIENT.LOCAL_FILES
599
600        self.ssl = False
601        if ssl:
602            if not SSL_ENABLED:
603                raise NotImplementedError("ssl module not found")
604            self.ssl = True
605            client_flag |= CLIENT.SSL
606            self.ctx = self._create_ssl_ctx(ssl)
607
608        if read_default_group and not read_default_file:
609            if sys.platform.startswith("win"):
610                read_default_file = "c:\\my.ini"
611            else:
612                read_default_file = "/etc/my.cnf"
613
614        if read_default_file:
615            if not read_default_group:
616                read_default_group = "client"
617
618            cfg = Parser()
619            cfg.read(os.path.expanduser(read_default_file))
620
621            def _config(key, arg):
622                if arg:
623                    return arg
624                try:
625                    return cfg.get(read_default_group, key)
626                except Exception:
627                    return arg
628
629            user = _config("user", user)
630            password = _config("password", password)
631            host = _config("host", host)
632            database = _config("database", database)
633            unix_socket = _config("socket", unix_socket)
634            port = int(_config("port", port))
635            charset = _config("default-character-set", charset)
636
637        self.host = host or "localhost"
638        self.port = port or 3306
639        self.user = user or DEFAULT_USER
640        self.password = password or ""
641        self.db = database
642        self.unix_socket = unix_socket
643        if read_timeout is not None and read_timeout <= 0:
644            raise ValueError("read_timeout should be >= 0")
645        self._read_timeout = read_timeout
646        if write_timeout is not None and write_timeout <= 0:
647            raise ValueError("write_timeout should be >= 0")
648        self._write_timeout = write_timeout
649        if charset:
650            self.charset = charset
651            self.use_unicode = True
652        else:
653            self.charset = DEFAULT_CHARSET
654            self.use_unicode = False
655
656        if use_unicode is not None:
657            self.use_unicode = use_unicode
658
659        self.encoding = charset_by_name(self.charset).encoding
660
661        client_flag |= CLIENT.CAPABILITIES
662        if self.db:
663            client_flag |= CLIENT.CONNECT_WITH_DB
664        self.client_flag = client_flag
665
666        self.cursorclass = cursorclass
667        self.connect_timeout = connect_timeout
668
669        self._result = None
670        self._affected_rows = 0
671        self.host_info = "Not connected"
672
673        #: specified autocommit mode. None means use server default.
674        self.autocommit_mode = autocommit
675
676        if conv is None:
677            conv = _conv
678        # Need for MySQLdb compatibility.
679        self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
680        self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
681        self.sql_mode = sql_mode
682        self.init_command = init_command
683        self.max_allowed_packet = max_allowed_packet
684        self._auth_plugin_map = auth_plugin_map
685        if defer_connect:
686            self._sock = None
687        else:
688            self.connect()
689
690    def _create_ssl_ctx(self, sslp):
691        if isinstance(sslp, ssl.SSLContext):
692            return sslp
693        ca = sslp.get('ca')
694        capath = sslp.get('capath')
695        hasnoca = ca is None and capath is None
696        ctx = ssl.create_default_context(cafile=ca, capath=capath)
697        ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
698        ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
699        if 'cert' in sslp:
700            ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
701        if 'cipher' in sslp:
702            ctx.set_ciphers(sslp['cipher'])
703        ctx.options |= ssl.OP_NO_SSLv2
704        ctx.options |= ssl.OP_NO_SSLv3
705        return ctx
706
707    def close(self):
708        """Send the quit message and close the socket"""
709        if self._sock is None:
710            raise err.Error("Already closed")
711        send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
712        try:
713            self._write_bytes(send_data)
714        except Exception:
715            pass
716        finally:
717            sock = self._sock
718            self._sock = None
719            self._rfile = None
720            sock.close()
721
722    @property
723    def open(self):
724        return self._sock is not None
725
726    def __del__(self):
727        if self._sock:
728            try:
729                self._sock.close()
730            except:
731                pass
732        self._sock = None
733        self._rfile = None
734
735    def autocommit(self, value):
736        self.autocommit_mode = bool(value)
737        current = self.get_autocommit()
738        if value != current:
739            self._send_autocommit_mode()
740
741    def get_autocommit(self):
742        return bool(self.server_status &
743                    SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)
744
745    def _read_ok_packet(self):
746        pkt = self._read_packet()
747        if not pkt.is_ok_packet():
748            raise err.OperationalError(2014, "Command Out of Sync")
749        ok = OKPacketWrapper(pkt)
750        self.server_status = ok.server_status
751        return ok
752
753    def _send_autocommit_mode(self):
754        """Set whether or not to commit after every execute()"""
755        self._execute_command(COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" %
756                              self.escape(self.autocommit_mode))
757        self._read_ok_packet()
758
759    def begin(self):
760        """Begin transaction."""
761        self._execute_command(COMMAND.COM_QUERY, "BEGIN")
762        self._read_ok_packet()
763
764    def commit(self):
765        """Commit changes to stable storage"""
766        self._execute_command(COMMAND.COM_QUERY, "COMMIT")
767        self._read_ok_packet()
768
769    def rollback(self):
770        """Roll back the current transaction"""
771        self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
772        self._read_ok_packet()
773
774    def show_warnings(self):
775        """SHOW WARNINGS"""
776        self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
777        result = MySQLResult(self)
778        result.read()
779        return result.rows
780
781    def select_db(self, db):
782        """Set current db"""
783        self._execute_command(COMMAND.COM_INIT_DB, db)
784        self._read_ok_packet()
785
786    def escape(self, obj, mapping=None):
787        """Escape whatever value you pass to it.
788       
789        Non-standard, for internal use; do not use this in your applications.
790        """
791        if isinstance(obj, str_type):
792            return "'" + self.escape_string(obj) + "'"
793        return escape_item(obj, self.charset, mapping=mapping)
794
795    def literal(self, obj):
796        """Alias for escape()
797       
798        Non-standard, for internal use; do not use this in your applications.
799        """
800        return self.escape(obj, self.encoders)
801
802    def escape_string(self, s):
803        if (self.server_status &
804                SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES):
805            return s.replace("'", "''")
806        return escape_string(s)
807
808    def cursor(self, cursor=None):
809        """Create a new cursor to execute queries with"""
810        if cursor:
811            return cursor(self)
812        return self.cursorclass(self)
813
814    def __enter__(self):
815        """Context manager that returns a Cursor"""
816        return self.cursor()
817
818    def __exit__(self, exc, value, traceback):
819        """On successful exit, commit. On exception, rollback"""
820        if exc:
821            self.rollback()
822        else:
823            self.commit()
824
825    # The following methods are INTERNAL USE ONLY (called from Cursor)
826    def query(self, sql, unbuffered=False):
827        # if DEBUG:
828        #     print("DEBUG: sending query:", sql)
829        if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
830            if PY2:
831                sql = sql.encode(self.encoding)
832            else:
833                sql = sql.encode(self.encoding, 'surrogateescape')
834        self._execute_command(COMMAND.COM_QUERY, sql)
835        self._affected_rows = self._read_query_result(unbuffered=unbuffered)
836        return self._affected_rows
837
838    def next_result(self, unbuffered=False):
839        self._affected_rows = self._read_query_result(unbuffered=unbuffered)
840        return self._affected_rows
841
842    def affected_rows(self):
843        return self._affected_rows
844
845    def kill(self, thread_id):
846        arg = struct.pack('<I', thread_id)
847        self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
848        return self._read_ok_packet()
849
850    def ping(self, reconnect=True):
851        """Check if the server is alive"""
852        if self._sock is None:
853            if reconnect:
854                self.connect()
855                reconnect = False
856            else:
857                raise err.Error("Already closed")
858        try:
859            self._execute_command(COMMAND.COM_PING, "")
860            return self._read_ok_packet()
861        except Exception:
862            if reconnect:
863                self.connect()
864                return self.ping(False)
865            else:
866                raise
867
868    def set_charset(self, charset):
869        # Make sure charset is supported.
870        encoding = charset_by_name(charset).encoding
871
872        self._execute_command(COMMAND.COM_QUERY, "SET NAMES %s" % self.escape(charset))
873        self._read_packet()
874        self.charset = charset
875        self.encoding = encoding
876
877    def connect(self, sock=None):
878        try:
879            if sock is None:
880                if self.unix_socket and self.host in ('localhost', '127.0.0.1'):
881                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
882                    sock.settimeout(self.connect_timeout)
883                    sock.connect(self.unix_socket)
884                    self.host_info = "Localhost via UNIX socket"
885                    if DEBUG: print('connected using unix_socket')
886                else:
887                    while True:
888                        try:
889                            sock = socket.create_connection(
890                                (self.host, self.port), self.connect_timeout)
891                            break
892                        except (OSError, IOError) as e:
893                            if e.errno == errno.EINTR:
894                                continue
895                            raise
896                    self.host_info = "socket %s:%d" % (self.host, self.port)
897                    if DEBUG: print('connected using socket')
898                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
899                sock.settimeout(None)
900                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
901            self._sock = sock
902            self._rfile = _makefile(sock, 'rb')
903            self._next_seq_id = 0
904
905            self._get_server_information()
906            self._request_authentication()
907
908            if self.sql_mode is not None:
909                c = self.cursor()
910                c.execute("SET sql_mode=%s", (self.sql_mode,))
911
912            if self.init_command is not None:
913                c = self.cursor()
914                c.execute(self.init_command)
915                c.close()
916                self.commit()
917
918            if self.autocommit_mode is not None:
919                self.autocommit(self.autocommit_mode)
920        except BaseException as e:
921            self._rfile = None
922            if sock is not None:
923                try:
924                    sock.close()
925                except:
926                    pass
927
928            if isinstance(e, (OSError, IOError, socket.error)):
929                exc = err.OperationalError(
930                        2003,
931                        "Can't connect to MySQL server on %r (%s)" % (
932                            self.host, e))
933                # Keep original exception and traceback to investigate error.
934                exc.original_exception = e
935                exc.traceback = traceback.format_exc()
936                if DEBUG: print(exc.traceback)
937                raise exc
938
939            # If e is neither DatabaseError or IOError, It's a bug.
940            # But raising AssertionError hides original error.
941            # So just reraise it.
942            raise
943
944    def write_packet(self, payload):
945        """Writes an entire "mysql packet" in its entirety to the network
946        addings its length and sequence number.
947        """
948        # Internal note: when you build packet manualy and calls _write_bytes()
949        # directly, you should set self._next_seq_id properly.
950        data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
951        if DEBUG: dump_packet(data)
952        self._write_bytes(data)
953        self._next_seq_id = (self._next_seq_id + 1) % 256
954
955    def _read_packet(self, packet_type=MysqlPacket):
956        """Read an entire "mysql packet" in its entirety from the network
957        and return a MysqlPacket type that represents the results.
958        """
959        buff = b''
960        while True:
961            packet_header = self._read_bytes(4)
962            if DEBUG: dump_packet(packet_header)
963
964            btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
965            bytes_to_read = btrl + (btrh << 16)
966            if packet_number != self._next_seq_id:
967                raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
968                    (packet_number, self._next_seq_id))
969            self._next_seq_id = (self._next_seq_id + 1) % 256
970
971            recv_data = self._read_bytes(bytes_to_read)
972            if DEBUG: dump_packet(recv_data)
973            buff += recv_data
974            # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
975            if bytes_to_read == 0xffffff:
976                continue
977            if bytes_to_read < MAX_PACKET_LEN:
978                break
979
980        packet = packet_type(buff, self.encoding)
981        packet.check_error()
982        return packet
983
984    def _read_bytes(self, num_bytes):
985        self._sock.settimeout(self._read_timeout)
986        while True:
987            try:
988                data = self._rfile.read(num_bytes)
989                break
990            except (IOError, OSError) as e:
991                if e.errno == errno.EINTR:
992                    continue
993                raise err.OperationalError(
994                    2013,
995                    "Lost connection to MySQL server during query (%s)" % (e,))
996        if len(data) < num_bytes:
997            raise err.OperationalError(
998                2013, "Lost connection to MySQL server during query")
999        return data
1000
1001    def _write_bytes(self, data):
1002        self._sock.settimeout(self._write_timeout)
1003        try:
1004            self._sock.sendall(data)
1005        except IOError as e:
1006            raise err.OperationalError(2006, "MySQL server has gone away (%r)" % (e,))
1007
1008    def _read_query_result(self, unbuffered=False):
1009        if unbuffered:
1010            try:
1011                result = MySQLResult(self)
1012                result.init_unbuffered_query()
1013            except:
1014                result.unbuffered_active = False
1015                result.connection = None
1016                raise
1017        else:
1018            result = MySQLResult(self)
1019            result.read()
1020        self._result = result
1021        if result.server_status is not None:
1022            self.server_status = result.server_status
1023        return result.affected_rows
1024
1025    def insert_id(self):
1026        if self._result:
1027            return self._result.insert_id
1028        else:
1029            return 0
1030
    def _execute_command(self, command, sql):
        """Send a command packet (*command* byte followed by *sql*) to the server.

        Before sending, a still-active unbuffered result is drained (with a
        warning) and any remaining results of a multi-result response are read
        and discarded.  Text *sql* is encoded with the connection encoding.
        Arguments that do not fit in one packet are split across
        MAX_PACKET_LEN-sized packets, ending with a shorter (possibly empty)
        packet.

        :raises err.InterfaceError: when the connection has no socket.
        """
        if not self._sock:
            raise err.InterfaceError("(0, '')")

        # If the last query was unbuffered, make sure it finishes before
        # sending new commands
        if self._result is not None:
            if self._result.unbuffered_active:
                warnings.warn("Previous unbuffered result was left incomplete")
                self._result._finish_unbuffered_query()
            # next_result() replaces self._result with the following result.
            while self._result.has_next:
                self.next_result()
            self._result = None

        if isinstance(sql, text_type):
            sql = sql.encode(self.encoding)

        packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

        # tiny optimization: build first packet manually instead of
        # calling self.write_packet()
        # '<i' packs the length into 4 bytes; assuming MAX_PACKET_LEN fits in
        # 3 bytes, the 4th byte is 0, which doubles as sequence id 0.  'B'
        # then appends the command byte.
        prelude = struct.pack('<iB', packet_size, command)
        packet = prelude + sql[:packet_size-1]
        self._write_bytes(packet)
        if DEBUG: dump_packet(packet)
        self._next_seq_id = 1

        if packet_size < MAX_PACKET_LEN:
            return

        # The first packet was full: send the rest in follow-up packets.  A
        # trailing short (or empty) packet marks the end of the payload.
        sql = sql[packet_size-1:]
        while True:
            packet_size = min(MAX_PACKET_LEN, len(sql))
            self.write_packet(sql[:packet_size])
            sql = sql[packet_size:]
            if not sql and packet_size < MAX_PACKET_LEN:
                break
1068
    def _request_authentication(self):
        """Send the handshake response and complete authentication.

        The response carries ``client_flag``, the charset id, the user name,
        an auth response (a ``_scramble`` of the password when the advertised
        plugin is '' or 'mysql_native_password', otherwise empty), the
        database when CONNECT_WITH_DB is supported, and the plugin name when
        PLUGIN_AUTH is supported.  With ``ssl`` set and CLIENT.SSL supported,
        the fixed-size prefix is sent alone first and the socket is wrapped
        with ``self.ctx`` before the full response goes out.  If the server
        answers with an auth switch request, it is handled by
        ``_process_auth`` (or, without PLUGIN_AUTH, by a legacy
        ``_scramble_323`` reply).

        Side effect: ``self.user`` (and ``self.db`` when sent) are replaced by
        their encoded bytes if they were text.

        :raises ValueError: if no user name was given.
        """
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, text_type):
            self.user = self.user.encode(self.encoding)

        # Fixed prefix (per HandshakeResponse41): capability flags, max packet
        # size (1 here), charset id, 23 zero filler bytes.
        data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')

        if self.ssl and self.server_capabilities & CLIENT.SSL:
            # SSL request: send the prefix alone, then switch to TLS.
            self.write_packet(data_init)

            self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
            self._rfile = _makefile(self._sock, 'rb')

        data = data_init + self.user + b'\0'

        authresp = b''
        if self._auth_plugin_name in ('', 'mysql_native_password'):
            authresp = _scramble(self.password.encode('latin1'), self.salt)

        # How the auth response length is encoded depends on server capabilities.
        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            data += self.db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = self._auth_plugin_name
            if isinstance(name, text_type):
                name = name.encode('ascii')
            data += name + b'\0'

        self.write_packet(data)
        auth_packet = self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8() # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
                self.write_packet(data)
                auth_packet = self._read_packet()
1129
    def _process_auth(self, plugin_name, auth_packet):
        """Answer an auth switch request for *plugin_name*.

        A plugin class registered in ``_auth_plugin_map`` (looked up by the
        bytes name, then by its ASCII-decoded str) is instantiated with this
        connection and its ``authenticate(auth_packet)`` result is returned.
        Otherwise the built-in plugins are handled: mysql_native_password,
        mysql_old_password, mysql_clear_password and dialog (which may use a
        registered handler's ``prompt(echo, prompt)``).

        :param plugin_name: plugin name read from the packet (bytes).
        :param auth_packet: the auth switch request, already positioned after
            the plugin name.
        :return: the server's final response packet.
        :raises err.OperationalError: 2059 when a plugin cannot be used or is
            not configured; 2061 when a dialog handler's ``prompt`` raises
            TypeError (e.g. its reply cannot be joined with bytes).
        """
        # Registered plugins may be keyed by the bytes name or by a str name.
        plugin_class = self._auth_plugin_map.get(plugin_name)
        if not plugin_class:
            plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
        if plugin_class:
            try:
                handler = plugin_class(self)
                return handler.authenticate(auth_packet)
            except AttributeError:
                # A 'dialog' handler need not implement authenticate(); fall
                # through and drive the dialog below with its prompt().
                if plugin_name != b'dialog':
                    raise err.OperationalError(2059, "Authentication plugin '%s'" \
                              " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
            except TypeError:
                raise err.OperationalError(2059, "Authentication plugin '%s'" \
                    " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
        else:
            handler = None
        if plugin_name == b"mysql_native_password":
            # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
            data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
        elif plugin_name == b"mysql_old_password":
            # https://dev.mysql.com/doc/internals/en/old-password-authentication.html
            data = _scramble_323(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
        elif plugin_name == b"mysql_clear_password":
            # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
            data = self.password.encode('latin1') + b'\0'
        elif plugin_name == b"dialog":
            # Question/answer loop: each round starts with a flag byte, where
            # (flag & 0x06) == 0x02 marks an echoed answer and bit 0 marks the
            # last question, followed by the prompt text.
            pkt = auth_packet
            while True:
                flag = pkt.read_uint8()
                echo = (flag & 0x06) == 0x02
                last = (flag & 0x01) == 0x01
                prompt = pkt.read_all()

                if prompt == b"Password: ":
                    self.write_packet(self.password.encode('latin1') + b'\0')
                elif handler:
                    resp = 'no response - TypeError within plugin.prompt method'
                    try:
                        resp = handler.prompt(echo, prompt)
                        self.write_packet(resp + b'\0')
                    except AttributeError:
                        raise err.OperationalError(2059, "Authentication plugin '%s'" \
                                  " not loaded: - %r missing prompt method" % (plugin_name, handler))
                    except TypeError:
                        raise err.OperationalError(2061, "Authentication plugin '%s'" \
                                  " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
                else:
                    raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
                pkt = self._read_packet()
                pkt.check_error()
                if pkt.is_ok_packet() or last:
                    break
            return pkt
        else:
            raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)

        # Single-shot plugins: send the computed answer and return the reply.
        self.write_packet(data)
        pkt = self._read_packet()
        pkt.check_error()
        return pkt
1191
1192    # _mysql support
1193    def thread_id(self):
1194        return self.server_thread_id[0]
1195
1196    def character_set_name(self):
1197        return self.charset
1198
1199    def get_host_info(self):
1200        return self.host_info
1201
1202    def get_proto_info(self):
1203        return self.protocol_version
1204
1205    def _get_server_information(self):
1206        i = 0
1207        packet = self._read_packet()
1208        data = packet.get_all_data()
1209
1210        if DEBUG: dump_packet(data)
1211        self.protocol_version = byte2int(data[i:i+1])
1212        i += 1
1213
1214        server_end = data.find(b'\0', i)
1215        self.server_version = data[i:server_end].decode('latin1')
1216        i = server_end + 1
1217
1218        self.server_thread_id = struct.unpack('<I', data[i:i+4])
1219        i += 4
1220
1221        self.salt = data[i:i+8]
1222        i += 9  # 8 + 1(filler)
1223
1224        self.server_capabilities = struct.unpack('<H', data[i:i+2])[0]
1225        i += 2
1226
1227        if len(data) >= i + 6:
1228            lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i:i+6])
1229            i += 6
1230            self.server_language = lang
1231            self.server_charset = charset_by_id(lang).name
1232
1233            self.server_status = stat
1234            if DEBUG: print("server_status: %x" % stat)
1235
1236            self.server_capabilities |= cap_h << 16
1237            if DEBUG: print("salt_len:", salt_len)
1238            salt_len = max(12, salt_len - 9)
1239
1240        # reserved
1241        i += 10
1242
1243        if len(data) >= i + salt_len:
1244            # salt_len includes auth_plugin_data_part_1 and filler
1245            self.salt += data[i:i+salt_len]
1246            i += salt_len
1247
1248        i+=1
1249        # AUTH PLUGIN NAME may appear here.
1250        if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
1251            # Due to Bug#59453 the auth-plugin-name is missing the terminating
1252            # NUL-char in versions prior to 5.5.10 and 5.6.2.
1253            # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
1254            # didn't use version checks as mariadb is corrected and reports
1255            # earlier than those two.
1256            server_end = data.find(b'\0', i)
1257            if server_end < 0: # pragma: no cover - very specific upstream bug
1258                # not found \0 and last field so take it all
1259                self._auth_plugin_name = data[i:].decode('latin1')
1260            else:
1261                self._auth_plugin_name = data[i:server_end].decode('latin1')
1262
1263    def get_server_info(self):
1264        return self.server_version
1265
    # DB-API 2.0 (PEP 249) optional extension: expose the exception classes
    # from the ``err`` module as connection attributes, so code holding only
    # a connection can write ``except conn.Error:``.
    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError
1276
1277
class MySQLResult(object):
    """The result of one command, read from a connection.

    ``read()`` consumes a complete (buffered) response.
    ``init_unbuffered_query()`` reads only the column descriptions and leaves
    the rows on the wire, to be fetched with
    ``_read_rowdata_packet_unbuffered()`` or drained with
    ``_finish_unbuffered_query()``.  Once a response is fully consumed, the
    ``connection`` reference is dropped to break the reference cycle.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False

    def __del__(self):
        # An abandoned unbuffered result still has rows pending; drain them.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read a complete response: OK packet, LOAD DATA LOCAL, or result set.

        The connection reference is released afterwards, even on error.
        """
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """Start reading a response without buffering its rows.

        OK and LOAD DATA LOCAL responses are consumed completely and the
        result is deactivated.  For a result set only the column descriptions
        are read and ``unbuffered_active`` stays True.
        """
        self.unbuffered_active = True
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """Serve a LOAD DATA LOCAL request, then read the server's OK packet.

        If sending the file fails, the server's reply is read (and discarded)
        before the error is re-raised.

        :raises err.OperationalError: (2014) if the reply is not an OK packet.
        """
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except:
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (recording warnings/has_next) if *packet* is an EOF."""
        if not packet.is_eof_packet():
            return False
        #TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read a full result set whose header is *first_packet*."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Fetch one row of an active unbuffered result.

        :return: the row tuple, or None once the EOF packet is reached (or if
            the result is not active).
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should tuple of row for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Read and discard remaining rows until the EOF packet."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row using the ``(encoding, converter)`` pairs.

        NULL values stay None.  A row that runs out of columns early is
        returned truncated.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if DEBUG: print("DEBUG: DATA = ", data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result.

        Fills ``fields``, ``description`` and ``converters`` (one
        ``(encoding, converter)`` pair per column).

        :raises err.OperationalError: (2014) if the descriptors are not
            followed by an EOF packet.
        """
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for _ in range_type(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is through:
                converter = None
            if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        if not eof_packet.is_eof_packet():
            # Explicit check, not ``assert``: asserts are stripped under
            # ``python -O`` and a protocol desync would then go unnoticed.
            raise err.OperationalError(2014, 'Protocol error, expecting EOF')
        self.description = tuple(description)
1474
1475
class LoadLocalFile(object):
    """Streams a local file to the server for a LOAD DATA LOCAL request."""

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send the file in chunks of at most 16 KiB, then an empty packet.

        Chunks are also capped at the connection's ``max_allowed_packet``.
        The empty terminating packet is sent even if reading the file fails.

        :raises err.InterfaceError: when the connection has no socket.
        :raises err.OperationalError: (1017) when the file cannot be read.
        """
        conn = self.connection
        if not conn._sock:
            raise err.InterfaceError("(0, '')")

        try:
            with open(self.filename, 'rb') as source:
                chunk_size = min(conn.max_allowed_packet, 16*1024)  # 16KB is efficient enough
                for chunk in iter(lambda: source.read(chunk_size), b''):
                    conn.write_packet(chunk)
        except IOError:
            raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
        finally:
            # An empty packet tells the server no more data follows.
            conn.write_packet(b'')
Note: See TracBrowser for help on using the repository browser.