1 | import gc |
---|
2 | import json |
---|
3 | import os |
---|
4 | import re |
---|
5 | import warnings |
---|
6 | |
---|
7 | import unittest2 |
---|
8 | |
---|
9 | import pymysql |
---|
10 | from .._compat import CPYTHON |
---|
11 | |
---|
12 | |
---|
class PyMySQLTestCase(unittest2.TestCase):
    """Base class for PyMySQL tests that need live server connections.

    ``setUp`` opens one connection per entry in ``databases`` and registers a
    cleanup that closes them; helper methods create/drop scratch tables and
    compare the server version.
    """

    # You can specify your test environment by creating a file named
    # "databases.json" next to this module, or by editing the `databases`
    # variable below.  Each entry is a dict of keyword arguments that setUp
    # passes to pymysql.connect().
    fname = os.path.join(os.path.dirname(__file__), "databases.json")
    if os.path.exists(fname):
        with open(fname) as f:
            databases = json.load(f)
    else:
        databases = [
            {"host": "localhost", "user": "root", "passwd": "",
             "db": "test_pymysql", "use_unicode": True, "local_infile": True},
            {"host": "localhost", "user": "root", "passwd": "",
             "db": "test_pymysql2"},
        ]

    def mysql_server_is(self, conn, version_tuple):
        """Return True if the given connection is on the version given or
        greater.

        The leading ``major.minor[.patch]`` number of
        ``conn.get_server_info()`` is compared against *version_tuple*; a
        missing patch component counts as 0.

        e.g.::

            if self.mysql_server_is(conn, (5, 6, 4)):
                # do something for MySQL 5.6.4 and above

        :raises ValueError: if the server version string does not start with
            a ``major.minor`` version number.
        """
        server_version = conn.get_server_info()
        match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', server_version)
        if match is None:
            raise ValueError(
                "Unrecognized MySQL server version: %r" % (server_version,))
        # The patch group is optional and may be None; treat it as 0.
        server_version_tuple = tuple(
            int(dig) if dig is not None else 0 for dig in match.groups()
        )
        return server_version_tuple >= version_tuple

    def setUp(self):
        """Open one connection per entry in ``self.databases``."""
        self.connections = []
        # Register the cleanup *before* connecting: unittest runs cleanups
        # even when setUp raises, so connections opened before a failing
        # pymysql.connect() call are still closed.
        self.addCleanup(self._teardown_connections)
        for params in self.databases:
            self.connections.append(pymysql.connect(**params))

    def _teardown_connections(self):
        """Close every connection opened in setUp.

        Every connection gets a close attempt even if an earlier one fails;
        the first error is re-raised afterwards so it is still reported.
        """
        errors = []
        for connection in self.connections:
            try:
                connection.close()
            except Exception as exc:
                errors.append(exc)
        if errors:
            raise errors[0]

    def _drop_if_exists(self, cursor, tablename):
        """Execute ``drop table if exists`` for *tablename* on *cursor*."""
        with warnings.catch_warnings():
            # NOTE(review): presumably silences the warning raised when the
            # table does not exist yet -- confirm against the driver.
            warnings.simplefilter("ignore")
            cursor.execute("drop table if exists `%s`" % (tablename,))

    def safe_create_table(self, connection, tablename, ddl, cleanup=True):
        """create a table.

        Ensures any existing version of that table is first dropped.

        Also adds a cleanup rule to drop the table after the test
        completes (only when *cleanup* is true and *ddl* succeeded).
        """
        cursor = connection.cursor()
        try:
            self._drop_if_exists(cursor, tablename)
            cursor.execute(ddl)
        finally:
            # Close the cursor even if the DROP or the DDL raises.
            cursor.close()
        if cleanup:
            self.addCleanup(self.drop_table, connection, tablename)

    def drop_table(self, connection, tablename):
        """Drop *tablename* if it exists, always closing the cursor used."""
        cursor = connection.cursor()
        try:
            self._drop_if_exists(cursor, tablename)
        finally:
            cursor.close()

    def safe_gc_collect(self):
        """Ensure cycles are collected via gc.

        Runs an additional collection pass on non-CPython platforms.
        """
        gc.collect()
        if not CPYTHON:
            gc.collect()
---|