Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / asyncpg   python

Repository URL to install this package:

/ protocol / prepared_stmt.pyx

# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class PreparedStatementState:

    def __cinit__(self, str name, str query, BaseProtocol protocol):
        self.name = name
        self.query = query
        self.settings = protocol.settings
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.refs = 0

    def _get_parameters(self):
        cdef Codec codec

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        cdef:
            Codec codec
            set missing = set()

        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        self.refs += 1

    def detach(self):
        self.refs -= 1

    def mark_closed(self):
        self.closed = True

    cdef _encode_bind_msg(self, args):
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        writer = WriteBuffer.new()

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r'  Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements '
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        if self.have_text_args:
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(codec.format)
        else:
            # All arguments are in binary format
            writer.write_int32(0x00010001)

        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError:
                    # This is already a descriptive error.
                    raise
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        'invalid input for query argument'
                        ' ${n}: {v} ({msg})'.format(
                            n=idx + 1, v=value_repr, msg=e)) from e

        if self.have_text_cols:
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(oid)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(p_oid)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            if codec.type not in {}:
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl

        frb_init(&rbuf, cbuf, buf_len)

        fnum = hton.unpack_int16(frb_read(&rbuf, 2))

        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        dec_row = record.ApgRecord_New(self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))

            if flen == -1:
                val = None
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                frb_set_len(&rbuf, bl - flen)

            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))

        return dec_row


cdef _decode_parameters_desc(object desc):
    cdef:
        ReadBuffer reader
        int16_t nparams
        uint32_t p_oid
        list result = []

    reader = ReadBuffer.new_message_parser(desc)
    nparams = reader.read_int16()

    for i from 0 <= i < nparams:
        p_oid = <uint32_t>reader.read_int32()
        result.append(p_oid)

    return result


cdef _decode_row_desc(object desc):
    cdef:
        ReadBuffer reader

        int16_t nfields

        bytes f_name
        uint32_t f_table_oid
        int16_t f_column_num
        uint32_t f_dt_oid
        int16_t f_dt_size
        int32_t f_dt_mod
        int16_t f_format

        list result

    reader = ReadBuffer.new_message_parser(desc)
    nfields = reader.read_int16()
    result = []

    for i from 0 <= i < nfields:
        f_name = reader.read_null_str()
        f_table_oid = <uint32_t>reader.read_int32()
        f_column_num = reader.read_int16()
        f_dt_oid = <uint32_t>reader.read_int32()
        f_dt_size = reader.read_int16()
        f_dt_mod = reader.read_int32()
        f_format = reader.read_int16()

        result.append(
Loading ...