Example psycopg2_pool.py

  1from __future__ import print_function
  2# gevent-test-requires-resource: psycopg2
  3# pylint:disable=import-error,broad-except,bare-except
  4import sys
  5import contextlib
  6
  7import gevent
  8from gevent.queue import Queue
  9from gevent.socket import wait_read, wait_write
 10from psycopg2 import extensions, OperationalError, connect
 11
 12
 13if sys.version_info[0] >= 3:
 14    integer_types = (int,)
 15else:
 16    import __builtin__
 17    integer_types = (int, __builtin__.long)
 18
 19
 20def gevent_wait_callback(conn, timeout=None):
 21    """A wait callback useful to allow gevent to work with Psycopg."""
 22    while 1:
 23        state = conn.poll()
 24        if state == extensions.POLL_OK:
 25            break
 26        elif state == extensions.POLL_READ:
 27            wait_read(conn.fileno(), timeout=timeout)
 28        elif state == extensions.POLL_WRITE:
 29            wait_write(conn.fileno(), timeout=timeout)
 30        else:
 31            raise OperationalError(
 32                "Bad result from poll: %r" % state)
 33
 34
 35extensions.set_wait_callback(gevent_wait_callback)
 36
 37
 38class AbstractDatabaseConnectionPool(object):
 39
 40    def __init__(self, maxsize=100):
 41        if not isinstance(maxsize, integer_types):
 42            raise TypeError('Expected integer, got %r' % (maxsize, ))
 43        self.maxsize = maxsize
 44        self.pool = Queue()
 45        self.size = 0
 46
 47    def create_connection(self):
 48        raise NotImplementedError()
 49
 50    def get(self):
 51        pool = self.pool
 52        if self.size >= self.maxsize or pool.qsize():
 53            return pool.get()
 54
 55        self.size += 1
 56        try:
 57            new_item = self.create_connection()
 58        except:
 59            self.size -= 1
 60            raise
 61        return new_item
 62
 63    def put(self, item):
 64        self.pool.put(item)
 65
 66    def closeall(self):
 67        while not self.pool.empty():
 68            conn = self.pool.get_nowait()
 69            try:
 70                conn.close()
 71            except Exception:
 72                pass
 73
 74    @contextlib.contextmanager
 75    def connection(self, isolation_level=None):
 76        conn = self.get()
 77        try:
 78            if isolation_level is not None:
 79                if conn.isolation_level == isolation_level:
 80                    isolation_level = None
 81                else:
 82                    conn.set_isolation_level(isolation_level)
 83            yield conn
 84        except:
 85            if conn.closed:
 86                conn = None
 87                self.closeall()
 88            else:
 89                conn = self._rollback(conn)
 90            raise
 91        else:
 92            if conn.closed:
 93                raise OperationalError("Cannot commit because connection was closed: %r" % (conn, ))
 94            conn.commit()
 95        finally:
 96            if conn is not None and not conn.closed:
 97                if isolation_level is not None:
 98                    conn.set_isolation_level(isolation_level)
 99                self.put(conn)
100            else:
101                self.size -= 1
102
103    @contextlib.contextmanager
104    def cursor(self, *args, **kwargs):
105        isolation_level = kwargs.pop('isolation_level', None)
106        with self.connection(isolation_level) as conn:
107            yield conn.cursor(*args, **kwargs)
108
109    def _rollback(self, conn):
110        try:
111            conn.rollback()
112        except:
113            gevent.get_hub().handle_error(conn, *sys.exc_info())
114            return
115        return conn
116
117    def execute(self, *args, **kwargs):
118        with self.cursor(**kwargs) as cursor:
119            cursor.execute(*args)
120            return cursor.rowcount
121
122    def fetchone(self, *args, **kwargs):
123        with self.cursor(**kwargs) as cursor:
124            cursor.execute(*args)
125            return cursor.fetchone()
126
127    def fetchall(self, *args, **kwargs):
128        with self.cursor(**kwargs) as cursor:
129            cursor.execute(*args)
130            return cursor.fetchall()
131
132    def fetchiter(self, *args, **kwargs):
133        with self.cursor(**kwargs) as cursor:
134            cursor.execute(*args)
135            while True:
136                items = cursor.fetchmany()
137                if not items:
138                    break
139                for item in items:
140                    yield item
141
142
class PostgresConnectionPool(AbstractDatabaseConnectionPool):
    """Connection pool whose connections are opened with psycopg2's ``connect``.

    Positional and keyword arguments are stored and passed unchanged to the
    connect function each time a new connection is needed. The exceptions are
    two keywords that the pool consumes itself:

    ``connect``
        Callable used to open connections (defaults to ``psycopg2.connect``).
    ``maxsize``
        Maximum number of live connections. When omitted or None, the base
        class default is used.
    """

    def __init__(self, *args, **kwargs):
        self.connect = kwargs.pop('connect', connect)
        maxsize = kwargs.pop('maxsize', None)
        self.args = args
        self.kwargs = kwargs
        if maxsize is None:
            # The base class requires an integer, so rather than forwarding
            # None (a TypeError) let it apply its own default.
            AbstractDatabaseConnectionPool.__init__(self)
        else:
            AbstractDatabaseConnectionPool.__init__(self, maxsize)

    def create_connection(self):
        """Open and return a new connection using the stored arguments."""
        return self.connect(*self.args, **self.kwargs)
154
155
def main():
    """Demo: run a 1-second query 4 times concurrently over 3 connections.

    Expects a PostgreSQL server reachable with the libpq connection string
    ``dbname=postgres``. With only 3 connections, the fourth query has to wait
    for a free one, so the total should be about two seconds.
    """
    import time
    pool = PostgresConnectionPool("dbname=postgres", maxsize=3)
    start = time.time()
    try:
        for _ in range(4):
            gevent.spawn(pool.execute, 'select pg_sleep(1);')
        gevent.wait()
        delay = time.time() - start
        print('Running "select pg_sleep(1);" 4 times with 3 connections. Should take about 2 seconds: %.2fs' % delay)
    finally:
        # Close the idle connections explicitly instead of leaving them open
        # until interpreter shutdown.
        pool.closeall()


if __name__ == '__main__':
    main()

Current source