Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
cql / cql / thrifteries.py
Size: Mime:
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zlib
import cql
from cql.cursor import Cursor, _VOID_DESCRIPTION, _COUNT_DESCRIPTION
from cql.query import cql_quote, cql_quote_name, prepare_query, PreparedQuery
from cql.connection import Connection
from cql.cassandra import Cassandra
from thrift.Thrift import TApplicationException
from thrift.transport import TTransport, TSocket
from thrift.protocol import TBinaryProtocol
from cql.cassandra.ttypes import (AuthenticationRequest, Compression,
        ConsistencyLevel, CqlResultType, InvalidRequestException,
        UnavailableException, TimedOutException, SchemaDisagreementException)

# Minimum remote Thrift API version (compared against the integer tuple
# parsed from the server's describe_version string) at which ThriftCursor
# enables prepared statements, provided the client bindings also expose
# execute_prepared_cql_query.
MIN_THRIFT_FOR_PREPARED_QUERIES = (19, 27, 0)
# Minimum remote Thrift API version at which ThriftCursor switches to the
# *_cql3_* Thrift calls, which take the consistency level as an explicit
# argument. Those calls are only used when the cursor speaks CQL 3.
MIN_THRIFT_FOR_CL_IN_PROTOCOL = (19, 35, 0)

class ThriftCursor(Cursor):
    """Cursor that executes CQL through the parent connection's Thrift client.

    Which Thrift calls are used depends on the remote Thrift API version
    reported by the parent connection and on the cursor's CQL major version.
    CQL 3 on a new-enough server uses the ``*_cql3_*`` calls, which carry a
    consistency level. Otherwise the older calls, which take none, are used.
    """

    def __init__(self, parent_connection):
        """Initialize the cursor and detect which Thrift calls are usable.

        :param parent_connection: the owning connection. Its ``client`` and
            ``remote_thrift_version`` attributes are read here.
        """
        Cursor.__init__(self, parent_connection)

        # Prepared statements need the method in the generated client
        # bindings AND a remote Thrift API version that is new enough.
        if hasattr(parent_connection.client, 'execute_prepared_cql_query') \
                and parent_connection.remote_thrift_version >= MIN_THRIFT_FOR_PREPARED_QUERIES:
            self.supports_prepared_queries = True

        cl_in_protocol = parent_connection.remote_thrift_version >= MIN_THRIFT_FOR_CL_IN_PROTOCOL
        # The *_cql3_* calls are only used for CQL 3 and only when the remote
        # side accepts a consistency level as a call argument.
        self.use_cql3_methods = cl_in_protocol and self.cql_major_version == 3

    def compress_query_text(self, querytext):
        """Prepare ``querytext`` for the wire according to ``self.compression``.

        :param querytext: the CQL query text as a byte string.
        :returns: ``(payload, compression)``. ``payload`` is the
            zlib-compressed query when ``self.compression == 'GZIP'`` and the
            unchanged query otherwise. ``compression`` is the ``Compression``
            attribute named by ``self.compression``, or ``Compression.NONE``
            when that is unset/empty.
        """
        if self.compression == 'GZIP':
            compressed_q = zlib.compress(querytext)
        else:
            compressed_q = querytext
        req_compression = getattr(Compression, self.compression or 'NONE')
        return compressed_q, req_compression

    def prepare_query(self, query):
        """Prepare ``query`` on the server and wrap the result.

        :param query: CQL text as a byte string, possibly with named
            parameters understood by ``cql.query.prepare_query``.
        :returns: a ``PreparedQuery`` holding the original query, the
            server-side item id, the bound-variable types and the parameter
            names.
        :raises ValueError: if ``query`` is a unicode string.
        :raises cql.ProgrammingError: if the server returned no variable types
            while the statement has bound parameters.
        """
        if isinstance(query, unicode):
            raise ValueError("CQL query must be bytes, not unicode")
        prepared_q_text, paramnames = prepare_query(query)
        compressed_q, compression = self.compress_query_text(prepared_q_text)

        if self.use_cql3_methods:
            doquery = self._connection.client.prepare_cql3_query
        else:
            doquery = self._connection.client.prepare_cql_query

        presult = doquery(compressed_q, compression)

        # Sanity check that the server saw as many markers as we parsed.
        # NOTE: this is an assert, so it is skipped under ``python -O``.
        assert presult.count == len(paramnames)
        if presult.variable_types is None and presult.count > 0:
            raise cql.ProgrammingError("Cassandra did not provide types for bound"
                                       " parameters. Prepared statements are only"
                                       " supported with cql3.")
        return PreparedQuery(query, presult.itemId, presult.variable_types, paramnames)

    def get_response(self, cql_query, consistency_level):
        """Execute raw CQL text and return the Thrift result object.

        :param cql_query: the CQL query as a byte string.
        :param consistency_level: name of a ``ConsistencyLevel`` attribute.
            It is resolved in every case, but only sent when the cql3 calls
            are in use.
        """
        compressed_q, compress = self.compress_query_text(cql_query)
        cl = getattr(ConsistencyLevel, consistency_level)
        if self.use_cql3_methods:
            doquery = self._connection.client.execute_cql3_query
            return self.handle_cql_execution_errors(doquery, compressed_q, compress, cl)
        else:
            doquery = self._connection.client.execute_cql_query
            return self.handle_cql_execution_errors(doquery, compressed_q, compress)

    def get_response_prepared(self, prepared_query, params, consistency_level):
        """Execute a previously prepared statement and return the Thrift result.

        :param prepared_query: a ``PreparedQuery`` from ``prepare_query``.
        :param params: values passed to ``prepared_query.encode_params``.
        :param consistency_level: name of a ``ConsistencyLevel`` attribute.
            It is resolved in every case, but only sent when the cql3 calls
            are in use.
        """
        paramvals = prepared_query.encode_params(params)
        cl = getattr(ConsistencyLevel, consistency_level)
        if self.use_cql3_methods:
            doquery = self._connection.client.execute_prepared_cql3_query
            return self.handle_cql_execution_errors(doquery, prepared_query.itemid,
                                                    paramvals, cl)
        else:
            doquery = self._connection.client.execute_prepared_cql_query
            return self.handle_cql_execution_errors(doquery, prepared_query.itemid,
                                                    paramvals)

    def handle_cql_execution_errors(self, executor, *args, **kwargs):
        """Call ``executor(*args, **kwargs)``, mapping Cassandra/Thrift errors
        onto the ``cql`` exception classes.

        :returns: whatever ``executor`` returns.
        :raises cql.ProgrammingError: on ``InvalidRequestException``.
        :raises cql.IntegrityError: on ``SchemaDisagreementException``.
        :raises cql.OperationalError: on ``UnavailableException`` or
            ``TimedOutException``.
        :raises cql.InternalError: on ``TApplicationException``.
        """
        # ``except E as v`` works on Python 2.6+ and 3, unlike ``except E, v``.
        try:
            return executor(*args, **kwargs)
        except InvalidRequestException as ire:
            raise cql.ProgrammingError("Bad Request: %s" % ire.why)
        except SchemaDisagreementException:
            raise cql.IntegrityError("Schema versions disagree, (try again later).")
        except UnavailableException:
            raise cql.OperationalError("Unable to complete request: one or "
                                       "more nodes were unavailable.")
        except TimedOutException:
            raise cql.OperationalError("Request did not complete within rpc_timeout.")
        except TApplicationException as tapp:
            # Keep the Thrift-side detail in the message; without it these
            # failures are very hard to diagnose.
            raise cql.InternalError("Internal application error: %s" % (tapp,))

    def process_execution_results(self, response, decoder=None):
        """Load a Thrift ``CqlResult`` into the cursor's result state.

        Sets ``result``, ``rs_idx`` and ``rowcount`` for every result type.
        ``ROWS`` also sets ``decoder`` from ``response.schema`` and, when at
        least one row came back, lets ``get_metadata_info`` inspect the first
        row. ``INT`` yields a single one-column row. ``VOID`` yields no rows.

        :param response: the Thrift result object.
        :param decoder: optional decoder factory used instead of
            ``self.default_decoder`` for ``ROWS`` results.
        :raises Exception: for an unrecognized ``response.type``.
        """
        if response.type == CqlResultType.ROWS:
            self.decoder = (decoder or self.default_decoder)(response.schema)
            self.result = [r.columns for r in response.rows]
            self.rs_idx = 0
            self.rowcount = len(self.result)
            if self.result:
                self.get_metadata_info(self.result[0])
        elif response.type == CqlResultType.INT:
            self.result = [(response.num,)]
            self.rs_idx = 0
            self.rowcount = 1
            # TODO: name could be the COUNT expression
            self.description = _COUNT_DESCRIPTION
            self.name_info = None
        elif response.type == CqlResultType.VOID:
            self.result = []
            self.rs_idx = 0
            self.rowcount = 0
            self.description = _VOID_DESCRIPTION
            self.name_info = ()
        else:
            raise Exception('unknown result type %s' % response.type)

        # 'Return values are not defined.'
        return True

    def columnvalues(self, row):
        """Return a list of the ``value`` attribute of each column in ``row``."""
        return [column.value for column in row]

    def columninfo(self, row):
        """Return a generator over the ``name`` attribute of each column in ``row``."""
        return (column.name for column in row)

class ThriftConnection(Connection):
    """Connection that talks to Cassandra through the Thrift RPC client."""

    cursorclass = ThriftCursor

    def establish_connection(self):
        """Open the Thrift transport and set up the Cassandra client.

        If ``self.transport`` was not already provided, this builds a framed
        transport over a ``TSocket`` to ``self.host``/``self.port``. It then
        opens the transport if it is not already open and creates the client
        with the accelerated binary protocol.

        After that it logs in when ``self.credentials`` is set, and stores the
        server's ``describe_version()`` string as a tuple of ints in
        ``self.remote_thrift_version``. Finally it applies
        ``self.cql_version`` if one was requested.
        """
        if self.transport is None:
            socket = TSocket.TSocket(self.host, self.port)
            self.transport = TTransport.TFramedTransport(socket)

        if not self.transport.isOpen():
            self.transport.open()

        protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
        self.client = Cassandra.Client(protocol)

        if self.credentials:
            self.client.login(AuthenticationRequest(credentials=self.credentials))

        # e.g. "19.35.0" -> (19, 35, 0), so it can be compared to the MIN_* tuples.
        self.remote_thrift_version = tuple(map(int, self.client.describe_version().split('.')))

        if self.cql_version:
            self.set_cql_version(self.cql_version)

    def set_cql_version(self, cql_version):
        """Ask the server to use ``cql_version`` and record its major number.

        :param cql_version: dotted version string sent to the server as-is.
            If its first component is not an integer, ``cql_major_version``
            is deliberately left unchanged.
        """
        self.client.set_cql_version(cql_version)
        try:
            self.cql_major_version = int(cql_version.split('.')[0])
        except ValueError:
            pass

    def set_initial_keyspace(self, keyspace):
        """Switch the session to ``keyspace`` by running ``USE`` on a temporary cursor.

        The keyspace is quoted with ``cql_quote_name`` for CQL 3 and later,
        and with ``cql_quote`` otherwise. The temporary cursor is closed even
        when the ``USE`` statement fails.
        """
        if self.cql_major_version >= 3:
            ksname = cql_quote_name(keyspace)
        else:
            ksname = cql_quote(keyspace)
        c = self.cursor()
        try:
            c.execute('USE %s' % ksname)
        finally:
            # Previously the cursor leaked if execute() raised.
            c.close()

    def terminate_connection(self):
        """Close the underlying Thrift transport."""
        self.transport.close()