1 | from pymysql.tests import base |
---|
2 | import pymysql.cursors |
---|
3 | |
---|
4 | import datetime |
---|
5 | import warnings |
---|
6 | |
---|
7 | |
---|
class TestDictCursor(base.PyMySQLTestCase):
    """Check that DictCursor returns rows as dicts keyed by column name.

    Subclasses can run the same tests against another cursor class by
    overriding ``cursor_type``. They override ``_ensure_cursor_expired`` when
    the cursor must be drained between queries.
    """

    # Expected rows, keyed by column name; they match the data inserted in setUp().
    bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)}
    jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)}
    fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)}

    # Cursor class under test; subclasses replace it.
    cursor_type = pymysql.cursors.DictCursor

    def setUp(self):
        """(Re)create the ``dictcursor`` table and insert the bob/jim/fred rows."""
        super(TestDictCursor, self).setUp()
        self.conn = conn = self.connections[0]
        c = conn.cursor(self.cursor_type)

        # create a table and some data to query
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            c.execute("drop table if exists dictcursor")
            # CREATE TABLE is inside the filter too: with an unbuffered dict
            # cursor, the warning about the missing table is only propagated
            # at the start of the next execute() call.
            c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
        data = [("bob", 21, "1990-02-06 23:04:56"),
                ("jim", 56, "1955-05-09 13:12:45"),
                ("fred", 100, "1911-09-12 01:01:01")]
        c.executemany("insert into dictcursor values (%s,%s,%s)", data)
        # Release the setup cursor; each test opens its own.
        c.close()

    def tearDown(self):
        """Drop the fixture table, then run the base-class cleanup."""
        try:
            c = self.conn.cursor()
            c.execute("drop table dictcursor")
            c.close()
        finally:
            # Run the base-class cleanup even if the DROP fails.
            super(TestDictCursor, self).tearDown()

    def _ensure_cursor_expired(self, cursor):
        """Hook called after a partial fetch; a no-op for this cursor type.

        Subclasses using unbuffered cursors override it to drain pending rows.
        """
        pass

    def test_DictCursor(self):
        """fetchone, fetchall, fetchmany and iteration all yield dict rows."""
        bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
        # all asserts compare against the structure MySQLdb would return
        conn = self.conn
        c = conn.cursor(self.cursor_type)

        # try an update which should return no rows
        c.execute("update dictcursor set age=20 where name='bob'")
        bob['age'] = 20
        # pull back the single row dict for bob and check
        c.execute("SELECT * from dictcursor where name='bob'")
        r = c.fetchone()
        self.assertEqual(bob, r, "fetchone via DictCursor failed")
        self._ensure_cursor_expired(c)

        # same again, but via fetchall => list of dicts
        c.execute("SELECT * from dictcursor where name='bob'")
        r = c.fetchall()
        self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor")
        # same test again but iterate over the cursor; collect the rows so an
        # empty result fails the test instead of silently skipping the assert
        c.execute("SELECT * from dictcursor where name='bob'")
        rows = [row for row in c]
        self.assertEqual([bob], rows, "fetch a 1 row result via iteration failed via DictCursor")
        # get all 3 rows via fetchall
        c.execute("SELECT * from dictcursor")
        r = c.fetchall()
        self.assertEqual([bob, jim, fred], r, "fetchall failed via DictCursor")
        # same test again but materialize the cursor with list()
        c.execute("SELECT * from dictcursor")
        r = list(c)
        self.assertEqual([bob, jim, fred], r, "DictCursor should be iterable")
        # get the first 2 rows via fetchmany
        c.execute("SELECT * from dictcursor")
        r = c.fetchmany(2)
        self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
        self._ensure_cursor_expired(c)

    def test_custom_dict(self):
        """A cursor subclass with a custom ``dict_type`` returns rows of that type."""
        class MyDict(dict): pass

        class MyDictCursor(self.cursor_type):
            dict_type = MyDict

        bob = MyDict(self.bob)
        jim = MyDict(self.jim)
        fred = MyDict(self.fred)

        cur = self.conn.cursor(MyDictCursor)
        cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
        r = cur.fetchone()
        self.assertEqual(bob, r, "fetchone failed via MyDictCursor")
        # assertEqual alone cannot tell MyDict from dict (MyDict subclasses
        # dict), so check the row type explicitly.
        self.assertIsInstance(r, MyDict)
        self._ensure_cursor_expired(cur)

        cur.execute("SELECT * FROM dictcursor")
        r = cur.fetchall()
        self.assertEqual([bob, jim, fred], r,
                         "fetchall failed via MyDictCursor")
        for row in r:
            self.assertIsInstance(row, MyDict)

        cur.execute("SELECT * FROM dictcursor")
        r = list(cur)
        self.assertEqual([bob, jim, fred], r,
                         "list failed via MyDictCursor")
        for row in r:
            self.assertIsInstance(row, MyDict)

        cur.execute("SELECT * FROM dictcursor")
        r = cur.fetchmany(2)
        self.assertEqual([bob, jim], r,
                         "fetchmany failed via MyDictCursor")
        for row in r:
            self.assertIsInstance(row, MyDict)
        self._ensure_cursor_expired(cur)
---|
109 | |
---|
110 | |
---|
class TestSSDictCursor(TestDictCursor):
    """Re-run every TestDictCursor test using the unbuffered SSDictCursor."""

    cursor_type = pymysql.cursors.SSDictCursor

    def _ensure_cursor_expired(self, cursor):
        """Consume whatever rows remain on the unbuffered cursor.

        The drained rows are discarded; only exhausting the stream matters.
        """
        for _ in cursor.fetchall_unbuffered():
            pass
---|
116 | |
---|
if __name__ == "__main__":
    # Allow running this module directly as a test script.
    from unittest import main as run_tests
    run_tests()
---|