import re
import os.path

from .._compat import PY2, with_metaclass, iterkeys, to_unicode, long
from .._globals import IDENTITY, THREAD_LOCAL
from ..drivers import psycopg2_adapt
from .base import SQLAdapter
from ..utils import split_uri_args
from . import AdapterMeta, adapters, with_connection, with_connection_or_raise


def _version_tuple(version, width=3):
    """Return the first dotted-integer run found in *version* as int tuple.

    Versions must be compared numerically: as plain strings "10.0" sorts
    before "9.2.0" and "2.10.0" before "2.5.0".

    Examples: ``"2.9.9 (dt dec pq3 ext lo64)"`` -> ``(2, 9, 9)``,
    ``"9.2"`` -> ``(9, 2, 0)``; a value containing no digits -> ``(0, 0, 0)``.

    :param version: any value; it is converted with ``str()`` first.
    :param width: the result is right-padded with zeros to at least this many
        components, so that "2.5" compares equal to "2.5.0".
    :returns: tuple of ints.
    """
    match = re.search(r"(\d+(?:\.\d+)*)", str(version))
    parts = [int(p) for p in match.group(1).split(".")] if match else []
    parts.extend([0] * (width - len(parts)))
    return tuple(parts)


def _sql_string_literal(value):
    """Render *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled. No backslash escaping is done: the
    ``Postgre`` adapter turns ``standard_conforming_strings`` on in
    ``after_connection`` (it is also the server default since PostgreSQL 9.1).
    """
    return "'%s'" % ("%s" % (value,)).replace("'", "''")


class PostgreMeta(AdapterMeta):
    """Metaclass that resolves the generic postgres adapters to a driver.

    Instantiating one of the generic classes (``Postgre``, ``PostgreNew``,
    ``PostgreBoolean``) actually instantiates the adapter registered under
    ``"<scheme>:<driver>"``. Every other class is instantiated as-is.
    """

    def __call__(cls, *args, **kwargs):
        if cls not in [Postgre, PostgreNew, PostgreBoolean]:
            return AdapterMeta.__call__(cls, *args, **kwargs)
        # Choose the driver:
        # 1. a driver named in the URI scheme ("postgres:psycopg2://...")
        #    wins, if it is installed;
        # 2. otherwise the first of this class's drivers that is installed;
        # 3. otherwise the first driver listed by the class.
        installed = set(iterkeys(kwargs["db"]._drivers_available))
        available_drivers = [d for d in cls.drivers if d in installed]
        uri_items = kwargs["uri"].split("://", 1)[0].split(":")
        uri_driver = uri_items[1] if len(uri_items) > 1 else None
        if uri_driver and uri_driver in available_drivers:
            driver = uri_driver
        else:
            driver = available_drivers[0] if available_drivers else cls.drivers[0]
        cls = adapters._registry_[uri_items[0] + ":" + driver]
        return AdapterMeta.__call__(cls, *args, **kwargs)


@adapters.register_for("postgres")
class Postgre(with_metaclass(PostgreMeta, SQLAdapter)):
    """PostgreSQL adapter for ``postgres://`` URIs.

    Accepted URI shapes::

        postgres://user[:password]@host[:port]/db[?args]
        postgres://user[:password]@/db?unix_socket=/path/to/socket

    From ``args`` the code reads ``sslmode`` and ``unix_socket``.
    """

    dbengine = "postgres"
    drivers = ("psycopg2",)
    support_distributed_transaction = True

    REGEX_URI = (
        "^(?P<user>[^:@]+)(:(?P<password>[^@]*))?"
        r"@(?P<host>[^:/]*|\[[^\]]+\])(:(?P<port>\d+))?"
        "/(?P<db>[^?]+)"
        r"(\?(?P<uriargs>.*))?$"
    )  # sslmode, ssl (no value) and unix_socket

    def __init__(
        self,
        db,
        uri,
        pool_size=0,
        folder=None,
        db_codec="UTF-8",
        credential_decoder=IDENTITY,
        driver_args=None,
        adapter_args=None,
        srid=4326,
        after_connection=None,
    ):
        """Create the adapter and configure JSON support.

        :param srid: spatial reference id stored on ``self.srid``.
        :param driver_args: extra keyword arguments for the driver's
            ``connect()``; a fresh dict is used when omitted.
        :param adapter_args: adapter options; a fresh dict is used when
            omitted.

        All other parameters are forwarded unchanged to ``SQLAdapter``.
        """
        self.srid = srid
        # Use fresh dicts per instance. _initialize_ mutates self.driver_args
        # (presumably the very dict passed here), so a shared ``{}`` default
        # could leak settings/credentials between adapters.
        if driver_args is None:
            driver_args = {}
        if adapter_args is None:
            adapter_args = {}
        super(Postgre, self).__init__(
            db,
            uri,
            pool_size,
            folder,
            db_codec,
            credential_decoder,
            driver_args,
            adapter_args,
            after_connection,
        )
        # NOTE(review): _config_json is defined by the driver-specific
        # subclasses (PostgreMeta always instantiates one of those).
        self._config_json()

    def _initialize_(self):
        """Parse ``self.uri`` into ``self.driver_args``, then connect.

        :raises SyntaxError: the URI does not match ``REGEX_URI``, or it has
            neither a host nor a ``unix_socket`` argument.
        :raises ValueError: the ``unix_socket`` path does not exist.
        """
        super(Postgre, self)._initialize_()
        ruri = self.uri.split("://", 1)[1]
        m = re.match(self.REGEX_URI, ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group("user"))
        password = self.credential_decoder(m.group("password"))
        host = m.group("host")
        uriargs = m.group("uriargs")
        if uriargs:
            uri_args = split_uri_args(uriargs, need_equal=False)
        else:
            uri_args = dict()
        unix_socket = uri_args.get("unix_socket")
        if not host and not unix_socket:
            raise SyntaxError("Host or UNIX socket name required")
        db = m.group("db")
        self.driver_args.update(user=user, database=db)
        if password is not None:
            self.driver_args["password"] = password
        if unix_socket:
            if not os.path.exists(unix_socket):
                raise ValueError("UNIX socket %r not found" % unix_socket)
            if self.driver_name == "psycopg2":
                # psycopg2 lets you configure only the socket directory, not
                # the socket file name. The directory is passed as the host
                # and must be an absolute path; otherwise the driver tries a
                # TCP/IP connection to that host. This behaviour comes from
                # the underlying libpq used by the driver.
                socket_dir = os.path.abspath(os.path.dirname(unix_socket))
                self.driver_args["host"] = socket_dir
        else:
            port = int(m.group("port") or 5432)
            self.driver_args.update(host=host, port=port)
        sslmode = uri_args.get("sslmode")
        if sslmode and self.driver_name == "psycopg2":
            self.driver_args["sslmode"] = sslmode
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__, self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None
        self.get_connection()

    def _get_json_dialect(self):
        """Return the dialect class used when JSON support is enabled."""
        from ..dialects.postgre import PostgreDialectJSON

        return PostgreDialectJSON

    def _get_json_parser(self):
        """Return the parser class used when JSON parsing is enabled."""
        from ..parsers.postgre import PostgreAutoJSONParser

        return PostgreAutoJSONParser

    @property
    def _last_insert(self):
        # Kept in THREAD_LOCAL rather than on the instance, so every thread
        # sees its own value.
        return THREAD_LOCAL._pydal_last_insert_

    @_last_insert.setter
    def _last_insert(self, value):
        THREAD_LOCAL._pydal_last_insert_ = value

    def connector(self):
        """Open a new driver connection from ``self.driver_args``."""
        return self.driver.connect(**self.driver_args)

    def after_connection(self):
        """Set UTF8 client encoding and standard-conforming string literals."""
        self.execute("SET CLIENT_ENCODING TO 'UTF8'")
        self.execute("SET standard_conforming_strings=on;")

    def lastrowid(self, table):
        """Return the id generated by the last insert into *table* as a long.

        If the last ``_insert`` asked for the id column back
        (``_last_insert`` is set), the id is read from the pending cursor
        row. Otherwise it is fetched with ``currval()`` on the table's
        sequence.
        """
        if self._last_insert:
            return long(self.cursor.fetchone()[0])
        sequence_name = table._sequence_name
        self.execute("SELECT currval(%s);" % self.adapt(sequence_name))
        return long(self.cursor.fetchone()[0])

    def _insert(self, table, fields):
        """Build the INSERT statement for *fields* (``(field, value)`` pairs).

        When the table has an ``_id`` field, its column name is passed to the
        dialect as the value to return, and ``_last_insert`` is set so that
        ``lastrowid`` reads it from the cursor.
        """
        self._last_insert = None
        if fields:
            retval = None
            if hasattr(table, "_id"):
                self._last_insert = (table._id, 1)
                retval = table._id._rname
            return self.dialect.insert(
                table._rname,
                ",".join(el[0]._rname for el in fields),
                ",".join(self.expand(v, f.type) for f, v in fields),
                retval,
            )
        return self.dialect.insert_empty(table._rname)

    # Two-phase commit. The transaction id is emitted as an escaped string
    # literal, because ids may be built from host/thread names that can
    # contain quotes.
    @with_connection
    def prepare(self, key):
        """Run ``PREPARE TRANSACTION`` for transaction id *key*."""
        self.execute("PREPARE TRANSACTION %s;" % _sql_string_literal(key))

    @with_connection
    def commit_prepared(self, key):
        """Run ``COMMIT PREPARED`` for transaction id *key*."""
        self.execute("COMMIT PREPARED %s;" % _sql_string_literal(key))

    @with_connection
    def rollback_prepared(self, key):
        """Run ``ROLLBACK PREPARED`` for transaction id *key*."""
        self.execute("ROLLBACK PREPARED %s;" % _sql_string_literal(key))


@adapters.register_for("postgres:psycopg2")
class PostgrePsyco(Postgre):
    """``Postgre`` adapter bound to the psycopg2 driver."""

    drivers = ("psycopg2",)

    def _config_json(self):
        """Enable the JSON dialect/parser when driver and server support it.

        Versions are compared numerically via ``_version_tuple``.
        """
        driver_version = _version_tuple(self.driver.__version__)
        use_json = (
            driver_version >= (2, 0, 12)
            and self.connection.server_version >= 90200
        )
        if use_json:
            self.dialect = self._get_json_dialect()(self)
            if driver_version >= (2, 5, 0):
                self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Quote *obj* as an SQL literal using psycopg2's adaptation.

        On Python 3, a ``bytes`` result is decoded as UTF-8.
        """
        adapted = psycopg2_adapt(obj)
        # deal with new relic Connection Wrapper (newrelic>=2.10.0.8)
        cxn = getattr(self.connection, "__wrapped__", self.connection)
        adapted.prepare(cxn)
        rv = adapted.getquoted()
        if not PY2 and isinstance(rv, bytes):
            return rv.decode("utf-8")
        return rv


@adapters.register_for("postgres2")
class PostgreNew(Postgre):
    """``postgres2://`` adapter: uses the arrays-aware JSON dialect/parser."""

    def _get_json_dialect(self):
        from ..dialects.postgre import PostgreDialectArraysJSON

        return PostgreDialectArraysJSON

    def _get_json_parser(self):
        from ..parsers.postgre import PostgreNewAutoJSONParser

        return PostgreNewAutoJSONParser


@adapters.register_for("postgres2:psycopg2")
class PostgrePsycoNew(PostgrePsyco, PostgreNew):
    """psycopg2 flavour of ``PostgreNew``."""


@adapters.register_for("postgres3")
class PostgreBoolean(PostgreNew):
    """``postgres3://`` adapter: uses the boolean-aware JSON dialect/parser."""

    def _get_json_dialect(self):
        from ..dialects.postgre import PostgreDialectBooleanJSON

        return PostgreDialectBooleanJSON

    def _get_json_parser(self):
        from ..parsers.postgre import PostgreBooleanAutoJSONParser

        return PostgreBooleanAutoJSONParser


@adapters.register_for("postgres3:psycopg2")
class PostgrePsycoBoolean(PostgrePsycoNew, PostgreBoolean):
    """psycopg2 flavour of ``PostgreBoolean``."""


@adapters.register_for("jdbc:postgres")
class JDBCPostgre(Postgre):
    """PostgreSQL adapter using the zxJDBC driver (``jdbc:postgres://``)."""

    drivers = ("zxJDBC",)

    REGEX_URI = (
        "^(?P<user>[^:@]+)(:(?P<password>[^@]*))?"
        r"@(?P<host>[^:/]+|\[[^\]]+\])(:(?P<port>\d+))?"
        "/(?P<db>[^?]+)$"
    )

    def _initialize_(self):
        """Build the JDBC ``self.dsn`` from ``self.uri``, then connect.

        :raises SyntaxError: the URI does not match ``REGEX_URI``.
        """
        # Deliberately skip Postgre._initialize_ (psycopg2-specific parsing).
        super(Postgre, self)._initialize_()
        ruri = self.uri.split("://", 1)[1]
        m = re.match(self.REGEX_URI, ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group("user"))
        password = self.credential_decoder(m.group("password"))
        if password is None:
            password = ""
        host = m.group("host")
        db = m.group("db")
        port = m.group("port") or "5432"
        self.dsn = ("jdbc:postgresql://%s:%s/%s" % (host, port, db), user, password)
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__, self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None
        self.get_connection()

    def connector(self):
        """Open a zxJDBC connection from ``self.dsn`` and ``self.driver_args``."""
        return self.driver.connect(*self.dsn, **self.driver_args)

    def after_connection(self):
        self.connection.set_client_encoding("UTF8")
        self.execute("BEGIN;")
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")

    def _config_json(self):
        """Enable the JSON dialect on servers >= 9.2 (numeric comparison).

        NOTE(review): assumes ``connection.dbversion`` is a version string
        such as "9.2.4"; the first numeric run found in it is used.
        """
        use_json = _version_tuple(self.connection.dbversion) >= (9, 2, 0)
        if use_json:
            self.dialect = self._get_json_dialect()(self)