Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aroundthecode / SQLAlchemy   python

Repository URL to install this package:

Version: 1.2.10 

/ dialects / postgresql / pygresql.py

# postgresql/pygresql.py
# Copyright (C) 2005-2018 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: postgresql+pygresql
    :name: pygresql
    :dbapi: pgdb
    :connectstring: postgresql+pygresql://user:password@host:port/dbname\
[?key=value&key=value...]
    :url: http://www.pygresql.org/
"""

import decimal
import re

from ... import exc, processors, util
from ...types import Numeric, JSON as Json
from ...sql.elements import Null
from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
    _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB


class _PGNumeric(Numeric):
    """Numeric type whose result handling is tuned for PyGreSQL."""

    def bind_processor(self, dialect):
        """Return ``None``: no conversion is applied to bound values."""
        return None

    def result_processor(self, dialect, coltype):
        """Choose the converter applied to values of a numeric column.

        :param dialect: the dialect in use (not consulted here).
        :param coltype: either an integer type OID or an object exposing
            the OID through an ``oid`` attribute.
        :return: a callable converting one result value, or ``None`` when
            the value delivered by the driver is used unchanged.
        :raises exc.InvalidRequestError: if the OID is not one of the
            known float, decimal or integer type OIDs.
        """
        type_oid = coltype if isinstance(coltype, int) else coltype.oid
        want_decimal = self.asdecimal

        if type_oid in _FLOAT_TYPES:
            if want_decimal:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal,
                    self._effective_decimal_return_scale)
            # PyGreSQL returns float natively for 701 (float8)
            return None

        if type_oid in _DECIMAL_TYPES or type_oid in _INT_TYPES:
            if want_decimal:
                # PyGreSQL returns Decimal natively for 1700 (numeric)
                return None
            return processors.to_float

        raise exc.InvalidRequestError(
            "Unknown PG numeric type: %d" % type_oid)


class _PGHStore(HSTORE):
    """HSTORE type that defers to the driver's hstore support if flagged.

    The generic ``HSTORE`` processors are used whenever
    ``dialect.has_native_hstore`` is false.
    """

    def bind_processor(self, dialect):
        """Return the converter applied to bound hstore values.

        With native support, ``dict`` values are wrapped in the DBAPI's
        ``Hstore`` callable and everything else is passed through.
        """
        if dialect.has_native_hstore:
            wrap = dialect.dbapi.Hstore

            def process(value):
                return wrap(value) if isinstance(value, dict) else value

            return process
        return super(_PGHStore, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        """Return the result converter, or ``None`` with native support."""
        if dialect.has_native_hstore:
            # Values coming back from the driver are used as they are.
            return None
        return super(_PGHStore, self).result_processor(dialect, coltype)


class _PGJSON(JSON):
    """JSON type that defers to the driver's JSON support if flagged.

    The generic ``JSON`` processors are used whenever
    ``dialect.has_native_json`` is false.
    """

    def bind_processor(self, dialect):
        """Return the converter applied to bound JSON values."""
        if not dialect.has_native_json:
            return super(_PGJSON, self).bind_processor(dialect)
        wrap_json = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                # The NULL marker is sent as a wrapped ``None``.
                return wrap_json(None)
            if isinstance(value, Null) or (
                    value is None and self.none_as_null):
                # A Null construct, or None under none_as_null: bind None.
                return None
            if value is None or isinstance(value, (dict, list)):
                return wrap_json(value)
            # Anything else is handed to the driver untouched.
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Return the result converter, or ``None`` with native support."""
        if dialect.has_native_json:
            return None
        return super(_PGJSON, self).result_processor(dialect, coltype)


class _PGJSONB(JSONB):
    """JSONB type that defers to the driver's JSON support if flagged.

    The generic ``JSONB`` processors are used whenever
    ``dialect.has_native_json`` is false.
    """

    def bind_processor(self, dialect):
        """Return the converter applied to bound JSONB values."""
        if not dialect.has_native_json:
            return super(_PGJSONB, self).bind_processor(dialect)
        wrap_json = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                # The NULL marker is sent as a wrapped ``None``.
                return wrap_json(None)
            if isinstance(value, Null) or (
                    value is None and self.none_as_null):
                # A Null construct, or None under none_as_null: bind None.
                return None
            if value is None or isinstance(value, (dict, list)):
                return wrap_json(value)
            # Anything else is handed to the driver untouched.
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Return the result converter, or ``None`` with native support."""
        if dialect.has_native_json:
            return None
        return super(_PGJSONB, self).result_processor(dialect, coltype)


class _PGUUID(UUID):
    """UUID type that defers to the driver's UUID support if flagged.

    The generic ``UUID`` processors are used whenever
    ``dialect.has_native_uuid`` is false.
    """

    def bind_processor(self, dialect):
        """Return the converter applied to bound UUID values.

        With native support, strings, byte strings and integers are turned
        into the DBAPI's ``Uuid`` objects; other values pass through.
        """
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).bind_processor(dialect)
        make_uuid = dialect.dbapi.Uuid

        def process(value):
            if value is None:
                return None
            if isinstance(value, (str, bytes)):
                # A length of 16 is taken to be the raw byte form.
                return (make_uuid(bytes=value) if len(value) == 16
                        else make_uuid(value))
            if isinstance(value, int):
                return make_uuid(int=value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        """Return the converter applied to UUID result values.

        With native support and ``as_uuid`` set, no converter is needed;
        otherwise non-``None`` values are rendered with ``str()``.
        """
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).result_processor(dialect, coltype)
        if self.as_uuid:
            return None

        def process(value):
            return None if value is None else str(value)

        return process


class _PGCompiler(PGCompiler):
    """Statement compiler that emits doubled percent signs.

    NOTE(review): presumably needed because the driver treats a single
    ``%`` as the start of a parameter marker -- confirm.
    """

    def visit_mod_binary(self, binary, operator, **kw):
        """Render a modulo expression with ``%%`` as the operator."""
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return lhs + " %% " + rhs

    def post_process_text(self, text):
        """Return *text* with every ``%`` doubled."""
        doubled = text.replace('%', '%%')
        return doubled


class _PGIdentifierPreparer(PGIdentifierPreparer):
    """Identifier preparer that also doubles percent signs."""

    def _escape_identifier(self, value):
        """Escape quote characters in *value*, then double every ``%``."""
        return value.replace(
            self.escape_quote, self.escape_to_quote
        ).replace('%', '%%')


class PGDialect_pygresql(PGDialect):
    """PostgreSQL dialect that talks to the database through ``pgdb``.

    On construction the dialect inspects the DBAPI's version string and
    records whether the native hstore, JSON and UUID handling of the
    driver should be used (version 5.0 and later).
    """

    driver = 'pygresql'

    statement_compiler = _PGCompiler
    preparer = _PGIdentifierPreparer

    @classmethod
    def dbapi(cls):
        """Import and return the ``pgdb`` module."""
        import pgdb
        return pgdb

    # Type overrides layered on top of the base PostgreSQL dialect's map.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            Numeric: _PGNumeric,
            HSTORE: _PGHStore,
            Json: _PGJSON,
            JSON: _PGJSON,
            JSONB: _PGJSONB,
            UUID: _PGUUID,
        }
    )

    def __init__(self, **kwargs):
        """Initialise the dialect and detect the driver's capabilities.

        Sets ``dbapi_version`` to a ``(major, minor)`` tuple, or to
        ``(0, 0)`` when the version cannot be determined, and sets the
        ``has_native_hstore``, ``has_native_json`` and ``has_native_uuid``
        flags accordingly.
        """
        super(PGDialect_pygresql, self).__init__(**kwargs)
        try:
            found = re.match(r'(\d+)\.(\d+)', self.dbapi.version)
            parsed = tuple(int(part) for part in found.group(1, 2))
        except (AttributeError, ValueError, TypeError):
            # No DBAPI, no version attribute, or an unparsable value.
            parsed = (0, 0)
        self.dbapi_version = parsed

        is_modern = parsed >= (5, 0)
        if is_modern:
            self.supports_unicode_statements = True
            self.supports_unicode_binds = True
        elif parsed != (0, 0):
            # Only warn when an old version was positively identified.
            util.warn(
                "PyGreSQL is only fully supported by SQLAlchemy"
                " since version 5.0.")
        self.has_native_hstore = is_modern
        self.has_native_json = is_modern
        self.has_native_uuid = is_modern

    def create_connect_args(self, url):
        """Build the ``(args, kwargs)`` pair used to open a connection.

        A port, when present, is folded into the ``host`` entry as
        ``host:port``; the URL's query options are merged in last.
        """
        connect_kwargs = url.translate_connect_args(username='user')
        if 'port' in connect_kwargs:
            # Strip anything after the last colon of the host first.
            bare_host = connect_kwargs.get('host', '').rsplit(':', 1)[0]
            port = connect_kwargs.pop('port')
            connect_kwargs['host'] = '%s:%s' % (bare_host, port)
        connect_kwargs.update(url.query)
        return [], connect_kwargs

    def is_disconnect(self, e, connection, cursor):
        """Report whether error *e* goes along with a closed connection.

        Only DBAPI errors raised while a connection object is available
        are considered; everything else yields ``False``.
        """
        if not isinstance(e, self.dbapi.Error) or not connection:
            return False
        try:
            raw = connection.connection
        except AttributeError:
            # No wrapped connection attribute: inspect the object itself.
            raw = connection
        else:
            if not raw:
                return False
        try:
            return raw.closed
        except AttributeError:  # PyGreSQL < 5.0
            return raw._cnx is None


# Module-level alias exposing the dialect class under the generic name
# ``dialect``.
# NOTE(review): presumably the name SQLAlchemy's dialect loader looks up
# in this module -- confirm.
dialect = PGDialect_pygresql