1 | #!/usr/bin/env python -O |
---|
2 | """ Script to test database capabilities and the DB-API interface |
---|
3 | for functionality and memory leaks. |
---|
4 | |
---|
5 | Adapted from a script by M-A Lemburg. |
---|
6 | |
---|
7 | """ |
---|
8 | import sys |
---|
9 | from time import time |
---|
10 | try: |
---|
11 | import unittest2 as unittest |
---|
12 | except ImportError: |
---|
13 | import unittest |
---|
14 | |
---|
# True only when running under a Python 2 interpreter.
PY2 = sys.version_info[:1] == (2,)
---|
16 | |
---|
class DatabaseTest(unittest.TestCase):
    """Generic DB-API capability and memory-leak tests.

    Abstract base class: a subclass must set ``db_module`` to a DB-API
    module and may adjust ``connect_args``, ``connect_kwargs``,
    ``create_table_extra`` and ``quote_identifier`` for its backend.
    While ``db_module`` is None every test is skipped.
    """

    db_module = None
    # Positional / keyword arguments passed to db_module.connect().
    connect_args = ()
    connect_kwargs = dict(use_unicode=True, charset="utf8")
    # Appended verbatim after the column list of every CREATE TABLE.
    create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
    # Number of rows generated by the bulk-data tests.
    rows = 10
    # When true: print diagnostic output (generated/fetched data, cursor
    # messages) and keep the tables created by check_data_integrity.
    debug = False
    # When true, tearDown asserts that releasing the cursor and the
    # connection leaves nothing for gc.collect() to reclaim.
    leak_test = True

    def setUp(self):
        """Open a connection and cursor and build the LOB test payloads."""
        if self.db_module is None:
            # The abstract base was collected directly by a test runner:
            # there is nothing to connect to, so skip instead of failing
            # with an AttributeError on None.connect.
            self.skipTest("db_module is not configured")
        db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
        self.connection = db
        self.cursor = db.cursor()
        # Code points 0..255, repeated 100 times.
        self.BLOBText = ''.join([chr(i) for i in range(256)] * 100)
        # NOTE(review): 16834 may have been meant as 16384; kept unchanged
        # because altering it would change the test payload.
        if PY2:
            self.BLOBUText = unicode().join(unichr(i) for i in range(16834))
        else:
            self.BLOBUText = "".join(chr(i) for i in range(16834))
        data = bytearray(range(256)) * 16
        self.BLOBBinary = self.db_module.Binary(data)

    def tearDown(self):
        """Release the cursor and connection; with leak_test, check for leaks.

        References are dropped rather than close() being called, presumably
        so the leak check exercises the objects' own deallocation path.
        """
        import gc
        try:
            del self.cursor
            if self.leak_test:
                orphans = gc.collect()
                self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
        finally:
            # Always drop the connection reference, even when leak checks
            # are off or the cursor check failed: the runner keeps this
            # TestCase instance alive, which would otherwise pin the
            # connection open.
            del self.connection
        if self.leak_test:
            orphans = gc.collect()
            self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)

    def table_exists(self, name):
        """Return True if ``select * from <name>`` executes successfully.

        *name* is interpolated into the SQL as-is, so pass an identifier
        that is already quoted. Any exception is taken to mean the table
        does not exist.
        """
        try:
            self.cursor.execute('select * from %s where 1=0' % name)
        except Exception:
            return False
        return True

    def quote_identifier(self, ident):
        """Return *ident* wrapped in double quotes (subclasses may override)."""
        return '"%s"' % ident

    def new_table_name(self):
        """Return a quoted table name, derived from id(self.cursor), that is unused."""
        i = id(self.cursor)
        while True:
            name = self.quote_identifier('tb%08x' % i)
            if not self.table_exists(name):
                return name
            i += 1

    def create_table(self, columndefs):
        """Create a new table from the column definitions in *columndefs*.

        The table's quoted name is stored in ``self.table``, and
        ``create_table_extra`` is appended after the column list.
        """
        self.table = self.new_table_name()
        self.cursor.execute('CREATE TABLE %s (%s) %s' %
                            (self.table,
                             ',\n'.join(columndefs),
                             self.create_table_extra))

    def _insert_statement(self, ncols):
        """Return an INSERT into ``self.table`` with *ncols* ``%s`` placeholders."""
        return 'INSERT INTO %s VALUES (%s)' % (self.table, ','.join(['%s'] * ncols))

    def _generate_rows(self, generator, ncols):
        """Return ``self.rows`` rows of ``generator(row, col)``, *ncols* values each."""
        return [[generator(i, j) for j in range(ncols)]
                for i in range(self.rows)]

    def _expect_overlong_rejected(self, action, message):
        """Call *action*; fail with *message* unless it raises a warning or DataError."""
        try:
            action()
        except (Warning, self.connection.Warning):
            # The builtin Warning keeps drivers whose warnings subclass it
            # working; connection.Warning covers DB-API Warning classes
            # that only derive from Exception.
            if self.debug:
                print(self.cursor.messages)
        except self.connection.DataError:
            pass
        else:
            self.fail(message)

    def check_data_integrity(self, columndefs, generator):
        """Round-trip generated values through a new table and compare them.

        ``generator(row, col)`` supplies the value inserted at each position
        and, called again, the value expected when reading back. The table
        is dropped afterwards -- also when insertion or verification fails --
        unless ``debug`` is set.
        """
        ncols = len(columndefs)
        self.create_table(columndefs)
        try:
            # insert
            data = self._generate_rows(generator, ncols)
            if self.debug:
                print(data)
            self.cursor.executemany(self._insert_statement(ncols), data)
            self.connection.commit()
            # verify
            self.cursor.execute('select * from %s' % self.table)
            fetched = self.cursor.fetchall()
            if self.debug:
                print(fetched)
            self.assertEqual(len(fetched), self.rows)
            for i in range(self.rows):
                for j in range(ncols):
                    self.assertEqual(fetched[i][j], generator(i, j))
        finally:
            if not self.debug:
                self.cursor.execute('drop table %s' % (self.table))

    def test_transactions(self):
        """A DELETE must be visible inside the transaction and undone by rollback()."""
        columndefs = ('col1 INT', 'col2 VARCHAR(255)')
        ncols = len(columndefs)

        def generator(row, col):
            if col == 0:
                return row
            return ('%i' % (row % 10)) * 255

        self.create_table(columndefs)
        try:
            self.cursor.executemany(self._insert_statement(ncols),
                                    self._generate_rows(generator, ncols))
            # verify
            self.connection.commit()
            self.cursor.execute('select * from %s' % self.table)
            fetched = self.cursor.fetchall()
            self.assertEqual(len(fetched), self.rows)
            for i in range(self.rows):
                for j in range(ncols):
                    self.assertEqual(fetched[i][j], generator(i, j))

            delete_statement = 'delete from %s where col1=%%s' % self.table
            select_zero = 'select col1 from %s where col1=%s' % (self.table, 0)
            self.cursor.execute(delete_statement, (0,))
            self.cursor.execute(select_zero)
            self.assertFalse(self.cursor.fetchall(), "DELETE didn't work")
            self.connection.rollback()
            self.cursor.execute(select_zero)
            self.assertEqual(len(self.cursor.fetchall()), 1, "ROLLBACK didn't work")
        finally:
            self.cursor.execute('drop table %s' % (self.table))

    def test_truncation(self):
        """Over-long VARCHAR values must raise a warning or a DataError."""
        columndefs = ('col1 INT', 'col2 VARCHAR(255)')
        ncols = len(columndefs)

        def generator(row, col):
            if col == 0:
                return row
            # With the default rows=10 the lengths run 250..259, i.e. they
            # straddle the VARCHAR(255) limit.
            return ('%i' % (row % 10)) * ((255 - self.rows // 2) + row)

        self.create_table(columndefs)
        insert_statement = self._insert_statement(ncols)

        def insert_single():
            self.cursor.execute(insert_statement, (0, '0' * 256))

        def insert_each():
            for i in range(self.rows):
                self.cursor.execute(insert_statement,
                                    tuple(generator(i, j) for j in range(ncols)))

        def insert_many():
            self.cursor.executemany(insert_statement,
                                    self._generate_rows(generator, ncols))

        try:
            self._expect_overlong_rejected(
                insert_single,
                "Over-long column did not generate warnings/exception with single insert")
            self.connection.rollback()
            self._expect_overlong_rejected(
                insert_each,
                "Over-long columns did not generate warnings/exception with execute()")
            self.connection.rollback()
            self._expect_overlong_rejected(
                insert_many,
                "Over-long columns did not generate warnings/exception with executemany()")
            self.connection.rollback()
        finally:
            self.cursor.execute('drop table %s' % (self.table))

    def test_CHAR(self):
        # Character data
        def generator(row, col):
            return ('%i' % ((row + col) % 10)) * 255
        self.check_data_integrity(
            ('col1 char(255)', 'col2 char(255)'),
            generator)

    def test_INT(self):
        # Number data
        def generator(row, col):
            return row * row
        self.check_data_integrity(
            ('col1 INT',),
            generator)

    def test_DECIMAL(self):
        from decimal import Decimal

        def generator(row, col):
            return Decimal("%d.%02d" % (row, col))
        self.check_data_integrity(
            ('col1 DECIMAL(5,2)',),
            generator)

    def test_DATE(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 DATE',),
            generator)

    def test_TIME(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 TIME',),
            generator)

    def test_DATETIME(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 DATETIME',),
            generator)

    def test_TIMESTAMP(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 TIMESTAMP',),
            generator)

    def test_fractional_TIMESTAMP(self):
        ticks = time()

        def generator(row, col):
            # Adds a sub-second component that varies with row and col.
            return self.db_module.TimestampFromTicks(
                ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0)
        self.check_data_integrity(
            ('col1 TIMESTAMP',),
            generator)

    def test_LONG(self):
        def generator(row, col):
            if col == 0:
                return row
            return self.BLOBUText  # full 16834-character unicode payload
        self.check_data_integrity(
            ('col1 INT', 'col2 LONG'),
            generator)

    def test_TEXT(self):
        def generator(row, col):
            if col == 0:
                return row
            return self.BLOBUText[:5192]  # first 5192 characters of the payload
        self.check_data_integrity(
            ('col1 INT', 'col2 TEXT'),
            generator)

    def test_LONG_BYTE(self):
        def generator(row, col):
            if col == 0:
                return row
            return self.BLOBBinary  # bytes 0..255 repeated 16 times
        self.check_data_integrity(
            ('col1 INT', 'col2 LONG BYTE'),
            generator)

    def test_BLOB(self):
        def generator(row, col):
            if col == 0:
                return row
            return self.BLOBBinary  # bytes 0..255 repeated 16 times
        self.check_data_integrity(
            ('col1 INT', 'col2 BLOB'),
            generator)