1from __future__ import print_function
2# gevent-test-requires-resource: psycopg2
3# pylint:disable=import-error,broad-except,bare-except
4import sys
5import contextlib
6
7import gevent
8from gevent.queue import Queue
9from gevent.socket import wait_read, wait_write
10from psycopg2 import extensions, OperationalError, connect
11
12
if sys.version_info[0] < 3:
    # Python 2 keeps arbitrary-precision integers in a separate ``long``
    # type, which must also be accepted wherever an integer is expected.
    import __builtin__
    integer_types = (int, __builtin__.long)
else:
    integer_types = (int,)
18
19
def gevent_wait_callback(conn, timeout=None):
    """Wait callback that makes psycopg2 cooperate with gevent.

    Polls *conn* until it reports ``POLL_OK``. Whenever ``poll()`` asks to
    read or write, the current greenlet waits on the connection's file
    descriptor through gevent's ``wait_read``/``wait_write`` (honouring
    *timeout*), so other greenlets can run meanwhile.

    Raises:
        OperationalError: if ``poll()`` returns an unexpected state.
    """
    while True:
        poll_state = conn.poll()
        if poll_state == extensions.POLL_OK:
            return
        if poll_state == extensions.POLL_READ:
            wait_read(conn.fileno(), timeout=timeout)
        elif poll_state == extensions.POLL_WRITE:
            wait_write(conn.fileno(), timeout=timeout)
        else:
            raise OperationalError(
                "Bad result from poll: %r" % poll_state)
33
34
# Register the gevent-aware callback with psycopg2 at import time. psycopg2
# uses this module-wide hook to wait on its connections, so waits go through
# gevent_wait_callback instead of blocking the whole process.
extensions.set_wait_callback(gevent_wait_callback)
36
37
class AbstractDatabaseConnectionPool(object):
    """Bounded, greenlet-friendly pool of database connections.

    Subclasses must implement :meth:`create_connection`. Connections are
    created lazily, on demand, until ``maxsize`` of them exist. After that,
    :meth:`get` blocks the calling greenlet until another one hands a
    connection back with :meth:`put`.

    Attributes:
        maxsize: upper bound on the number of connections the pool owns.
        pool: gevent ``Queue`` holding the idle connections.
        size: number of connections currently owned by the pool, counting
            both idle ones in ``pool`` and ones checked out by callers.
    """

    def __init__(self, maxsize=100):
        if not isinstance(maxsize, integer_types):
            raise TypeError('Expected integer, got %r' % (maxsize, ))
        if maxsize < 1:
            # A pool that may never create a connection would make get()
            # block forever on its empty queue; fail fast instead.
            raise ValueError('maxsize must be at least 1, got %r' % (maxsize, ))
        self.maxsize = maxsize
        self.pool = Queue()
        self.size = 0

    def create_connection(self):
        """Open and return a brand-new connection. Must be overridden."""
        raise NotImplementedError()

    def get(self):
        """Check a connection out of the pool.

        An idle connection is reused when one is queued. Otherwise a new one
        is created, unless the pool already owns ``maxsize`` connections. In
        that case this blocks until one is returned with :meth:`put`.
        """
        pool = self.pool
        if self.size >= self.maxsize or pool.qsize():
            return pool.get()

        # Reserve the slot *before* creating the connection. Creating it may
        # yield to other greenlets, and they must not see the stale size and
        # overshoot maxsize.
        self.size += 1
        try:
            new_item = self.create_connection()
        except BaseException:
            self.size -= 1
            raise
        return new_item

    def put(self, item):
        """Return a checked-out connection to the queue of idle connections."""
        self.pool.put(item)

    def closeall(self):
        """Close and discard every idle connection.

        Connections currently checked out are not affected. Each discarded
        connection is subtracted from ``size`` so that replacements can be
        created later. Without this, the pool could believe it is full while
        holding no connections, and :meth:`get` would block forever.
        """
        while not self.pool.empty():
            conn = self.pool.get_nowait()
            self.size -= 1
            try:
                conn.close()
            except Exception:
                # Best effort: the connection is being thrown away anyway.
                pass

    @contextlib.contextmanager
    def connection(self, isolation_level=None):
        """Context manager yielding a pooled connection inside a transaction.

        On normal exit the transaction is committed. If the ``with`` body
        raises, it is rolled back and the exception is re-raised.

        If *isolation_level* is given and differs from the connection's
        current level, it applies for the duration of the block. The previous
        level is restored before the connection goes back to the pool.

        A connection found closed is not returned to the pool. When that is
        discovered on the error path, all idle connections are closed too,
        presumably because they are likely to share the same failure.

        Raises:
            OperationalError: if the connection was closed by the time the
                block exits normally, so the transaction cannot be committed.
        """
        conn = self.get()
        # Track "level was changed" with a flag rather than a None check:
        # the previous level may itself be None (recent psycopg2 versions
        # report the server default that way).
        level_changed = False
        previous_level = None
        try:
            if isolation_level is not None and conn.isolation_level != isolation_level:
                previous_level = conn.isolation_level
                conn.set_isolation_level(isolation_level)
                level_changed = True
            yield conn
        except BaseException:
            if conn.closed:
                conn = None
                self.closeall()
            else:
                conn = self._rollback(conn)
            raise
        else:
            if conn.closed:
                raise OperationalError("Cannot commit because connection was closed: %r" % (conn, ))
            conn.commit()
        finally:
            if conn is not None and not conn.closed:
                if level_changed:
                    # Undo our change so the next borrower gets the
                    # connection exactly as it was before this block.
                    conn.set_isolation_level(previous_level)
                self.put(conn)
            else:
                # The connection is gone for good; free its slot.
                self.size -= 1

    @contextlib.contextmanager
    def cursor(self, *args, **kwargs):
        """Context manager yielding a cursor on a pooled connection.

        The optional ``isolation_level`` keyword is consumed here and passed
        to :meth:`connection`. All other arguments go to ``conn.cursor()``.
        The transaction is committed or rolled back as :meth:`connection`
        does.
        """
        isolation_level = kwargs.pop('isolation_level', None)
        with self.connection(isolation_level) as conn:
            yield conn.cursor(*args, **kwargs)

    def _rollback(self, conn):
        """Roll back the current transaction on *conn*.

        Returns *conn* when the rollback succeeds. If the rollback fails, the
        error is reported through the gevent hub, *conn* is closed on a
        best-effort basis, and None is returned so the caller does not reuse
        it.
        """
        try:
            conn.rollback()
        except BaseException:
            gevent.get_hub().handle_error(conn, *sys.exc_info())
            try:
                # Release the broken connection now rather than leaving it
                # open until the object is garbage-collected.
                conn.close()
            except Exception:
                pass
            return None
        return conn

    def execute(self, *args, **kwargs):
        """Execute one statement and return the cursor's ``rowcount``.

        Positional arguments go to ``cursor.execute``. Keyword arguments
        configure the cursor (see :meth:`cursor`).
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.rowcount

    def fetchone(self, *args, **kwargs):
        """Execute a query and return its first row (``cursor.fetchone()``)."""
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.fetchone()

    def fetchall(self, *args, **kwargs):
        """Execute a query and return all rows (``cursor.fetchall()``)."""
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.fetchall()

    def fetchiter(self, *args, **kwargs):
        """Execute a query and lazily yield its rows, fetched in batches.

        The connection stays checked out, and its transaction open, until
        the generator is exhausted or closed.
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            while True:
                items = cursor.fetchmany()
                if not items:
                    break
                for item in items:
                    yield item
141
142
class PostgresConnectionPool(AbstractDatabaseConnectionPool):
    """Connection pool whose connections are opened with psycopg2's ``connect``.

    Positional and keyword arguments are stored and forwarded unchanged to
    the connect function each time a new connection is needed. Two keyword
    arguments are consumed by the pool itself and are not forwarded:

    ``connect``
        Callable used to open connections (defaults to ``psycopg2.connect``).
    ``maxsize``
        Upper bound on the number of connections. When omitted or None, the
        base class default is used.
    """

    def __init__(self, *args, **kwargs):
        self.connect = kwargs.pop('connect', connect)
        maxsize = kwargs.pop('maxsize', None)
        self.args = args
        self.kwargs = kwargs
        if maxsize is None:
            # Forwarding None would fail the base class's integer check, so
            # omit the argument and let the base class pick its default.
            AbstractDatabaseConnectionPool.__init__(self)
        else:
            AbstractDatabaseConnectionPool.__init__(self, maxsize)

    def create_connection(self):
        """Open a new connection using the stored connect arguments."""
        return self.connect(*self.args, **self.kwargs)
154
155
def main(dsn="dbname=postgres", maxsize=3, count=4):
    """Demonstrate the pool by running *count* one-second queries concurrently.

    At most *maxsize* queries can run at once, so the run should take roughly
    ceil(count / maxsize) seconds. With the defaults (4 queries, 3
    connections) that is about 2 seconds.
    """
    import time
    query = 'select pg_sleep(1);'
    pool = PostgresConnectionPool(dsn, maxsize=maxsize)
    try:
        start = time.time()
        for _ in range(count):
            gevent.spawn(pool.execute, query)
        gevent.wait()
        delay = time.time() - start
    finally:
        # Release the pooled server connections instead of leaking them
        # until interpreter exit.
        pool.closeall()
    # Queries run in waves of at most `maxsize`; each wave sleeps 1 second.
    expected = -(-count // maxsize)
    print('Running "%s" %d times with %d connections. Should take about %d seconds: %.2fs'
          % (query, count, maxsize, expected, delay))
165
# Run the timing demo only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()