Repository URL to install this package:
|
Version: 0.8.1
|
try:
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union
except ImportError:
pass
import re
import wrapt
from supertenant import consts
from supertenant.supermeter.data import db_data
from supertenant.supermeter.logger import log_integration_module_error, log_integration_module_exception
from supertenant.supermeter.managers.actions import SyncActions
from supertenant.supermeter.managers.db_manager import DBManager
from supertenant.supermeter.scope_manager import Span
# Matches candidate SQL literal values: single-quoted strings, decimal numbers
# (e.g. ".5", "1.25"), integers, and the NULL keyword.
# NOTE(review): not referenced anywhere in this chunk of the module; presumably
# used elsewhere (e.g. to mask literal values in statements) - confirm whether it
# is still needed before removing.
# NOTE(review): the quoted-string alternative requires at least one character
# between the quotes, so an empty literal '' is not matched on its own. There are
# also no word boundaries, so digits inside identifiers (e.g. "table1") and "NULL"
# inside longer words also match.
regexp_sql_values = re.compile(r"('[\s\S][^']*'|\d*\.\d+|\d+|NULL)")
def get_sql_error(action_desc):
    # type:(Optional[Dict[str, Any]]) -> Tuple[int, str]
    """Return the ``(rc, error_message)`` pair to raise when a query is rejected.

    If ``action_desc`` contains a ``"SqlReject"`` entry, the values are read from
    its ``"sql_rc"`` and ``"error_message"`` keys. Otherwise they are read from
    the top-level ``consts.REJECT_ATTRIBUTE_SQL_RC`` and
    ``consts.REJECT_ATTRIBUTE_ERROR_MESSAGE`` keys.

    A missing or non-integer return code falls back to 1317 (MySQL's
    ER_QUERY_INTERRUPTED). A missing message falls back to
    "Query execution was interrupted". This function never raises because of a
    malformed return code, so the caller always gets a usable SQL error.

    :param action_desc: action description returned alongside the reject
        action, or None.
    :returns: ``(rc, error_message)``.
    """
    default_rc = 1317
    default_message = "Query execution was interrupted"
    if action_desc is None:
        return default_rc, default_message

    sql_reject = action_desc.get("SqlReject")
    if sql_reject is not None:
        raw_rc = sql_reject.get("sql_rc")
        error_message = sql_reject.get("error_message")
    else:
        raw_rc = action_desc.get(consts.REJECT_ATTRIBUTE_SQL_RC)
        error_message = action_desc.get(consts.REJECT_ATTRIBUTE_ERROR_MESSAGE)

    # Convert separately from the lookups: a bad/missing code must not lose the
    # message, and must not turn the reject into a TypeError/ValueError.
    try:
        rc = int(raw_rc)  # type: ignore
    except (TypeError, ValueError):
        rc = default_rc

    if error_message is None:
        error_message = default_message
    return rc, error_message
class CursorWrapper(wrapt.ObjectProxy):
    """Proxy around a DB-API cursor that traces, and may reject, statements.

    ``execute``, ``executemany`` and ``callproc`` each do the following:

    - describe the statement in a fresh ``data_type`` instance;
    - open a span through ``DBManager.open_span``;
    - ask ``SyncActions.get_action`` what to do.

    On ``consts.ACTION_REJECT`` the call is not forwarded. Instead
    ``exception(rc, error_message)`` is raised, using the values from
    ``get_sql_error``. When no span is opened the call is forwarded untraced.
    All other attribute access is forwarded to the wrapped cursor by
    ``wrapt.ObjectProxy``.
    """

    # Proxy-local state is declared as slots so that wrapt stores it on the
    # proxy itself instead of forwarding the assignment to the wrapped cursor.
    __slots__ = ("_module_name", "_data_type", "_connect_params", "_cursor_params", "_exception")

    def __init__(self, cursor, module_name, data_type, connect_params=None, cursor_params=None, exception=Exception):
        # type: (Any, str, Union[Type[db_data.PGData], Type[db_data.MYSQLData]], Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]], Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]], Type[Exception]) -> None # noqa: E501
        """Wrap ``cursor``.

        :param cursor: the driver's cursor object.
        :param module_name: integration module name used for data and logging.
        :param data_type: ``db_data`` class instantiated to describe each call.
        :param connect_params: ``(args, kwargs)`` given to the driver's connect
            call, or None. Only the kwargs are inspected by ``_collect_kvs``.
        :param cursor_params: ``(args, kwargs)`` given to ``connection.cursor()``,
            or None. Stored but not read by this class.
        :param exception: exception type raised as ``exception(rc, message)``
            when a statement is rejected.
        """
        super(CursorWrapper, self).__init__(wrapped=cursor)
        self._module_name = module_name
        self._data_type = data_type
        self._connect_params = connect_params
        self._cursor_params = cursor_params
        self._exception = exception

    def __enter__(self):
        # type: () -> Any
        """Enter the wrapped cursor's context and return this proxy when possible."""
        r = self.__wrapped__.__enter__()
        if r is self.__wrapped__:
            # Return the proxy so statements inside the `with` stay traced.
            return self
        # Should never get here.
        if hasattr(r, "cursor"):
            # r is Connection-like.
            log_integration_module_error(
                self._module_name, "in __enter__ returning connection-like is not supported here"
            )
        elif hasattr(r, "execute"):
            # r is Cursor-like but not this __wrapped__.
            log_integration_module_error(
                self._module_name, "in __enter__ returning uninstrumented cursor - should not happen"
            )
        return r

    def _collect_kvs(self, data, sql):
        # type: (Union[db_data.PGData, db_data.MYSQLData], Union[str, bytes, object]) -> None
        """Best-effort copy of connection details and the statement into ``data``.

        Never raises. Failures are logged via
        ``log_integration_module_exception``. Connection details and the
        statement are collected independently, so a problem with one does not
        lose the other.
        """
        if self._connect_params is not None:
            try:
                connect_kwargs = self._connect_params[1]
                # "db"/"database" as before; "dbname" is psycopg2's canonical
                # keyword ("database" being its alias).
                for db_key in ("db", "database", "dbname"):
                    if db_key in connect_kwargs:
                        data.set_db(connect_kwargs[db_key])
                        break
                if "user" in connect_kwargs:
                    data.set_user(connect_kwargs["user"])
                if "port" in connect_kwargs:
                    data.set_port(connect_kwargs["port"])
                # "host" is optional for drivers (default host / unix socket);
                # a missing key must not abort collection of the statement.
                if "host" in connect_kwargs:
                    host = connect_kwargs["host"]
                    data.set_host(host)
                    data.set_integration_module_resource_id(host)
            except Exception as exc:
                log_integration_module_exception(self._module_name, "_collect_kvs", exc)
        try:
            if isinstance(sql, bytes):
                sql = sql.decode("utf-8", "ignore")
            data.set_statement(str(sql))
        except Exception as exc:
            log_integration_module_exception(self._module_name, "_collect_kvs", exc)

    def _traced_call(self, statement, method_name, *call_args):
        # type: (Union[str, bytes, object], str, *Any) -> Any
        """Call ``self.__wrapped__.<method_name>(*call_args)`` under a DB span.

        Shared implementation of ``execute``, ``executemany`` and ``callproc``.

        :param statement: SQL text or procedure name recorded on the span data.
        :param method_name: name of the wrapped cursor method to call.
        :param call_args: positional arguments forwarded to that method.
        :raises: ``self._exception(rc, error_message)`` if the action is a reject.
            Any exception from the wrapped call is re-raised after the span is
            marked as errored.
        """
        before_data = self._data_type(self._module_name)
        self._collect_kvs(before_data, statement)
        span_id, act, poll_key = DBManager.open_span(before_data)
        if span_id is None:
            return getattr(self.__wrapped__, method_name)(*call_args)
        with Span(span_id, self._data_type(self._module_name)) as span:
            action, action_desc = SyncActions.get_action(span_id, act, poll_key)
            if action == consts.ACTION_REJECT:
                rc, error_message = get_sql_error(action_desc)
                raise self._exception(rc, error_message)
            try:
                return getattr(self.__wrapped__, method_name)(*call_args)
            except Exception:
                span.finish_data.mark_error()
                raise

    def execute(self, sql, params=None):
        # type: (Union[str, bytes, object], Union[Sequence[Any], Mapping[str, Any], object, None]) -> Any
        """Traced ``cursor.execute(sql, params)``."""
        return self._traced_call(sql, "execute", sql, params)

    def executemany(self, sql, seq_of_parameters):
        # type: (Union[str, bytes, object], Iterable[Union[Sequence[Any], Mapping[str, Any], object, None]]) -> Any
        """Traced ``cursor.executemany(sql, seq_of_parameters)``."""
        return self._traced_call(sql, "executemany", sql, seq_of_parameters)

    def callproc(self, proc_name, params=None):
        # type: (str, Optional[Iterable[Any]]) -> Any
        """Traced ``cursor.callproc(proc_name[, params])``.

        ``params`` is optional, as in DB-API 2.0. When it is omitted (or None)
        the wrapped ``callproc`` is called without it, so the driver's own
        default applies.
        """
        if params is None:
            return self._traced_call(proc_name, "callproc", proc_name)
        return self._traced_call(proc_name, "callproc", proc_name, params)
class ConnectionWrapper(wrapt.ObjectProxy):
    """Proxy around a DB-API connection that hands out traced cursors.

    Cursors obtained via ``cursor()``, and bare cursors returned by the wrapped
    connection's ``__enter__``, are wrapped in ``CursorWrapper``. Each wrapper
    is configured with this connection's module name, data type, connect
    parameters and reject exception. Attributes not defined here are forwarded
    to the wrapped connection by ``wrapt.ObjectProxy``.
    """

    # Proxy-local state is declared as slots so that wrapt stores it on the
    # proxy itself instead of forwarding the assignment to the wrapped connection.
    __slots__ = ("_data_type", "_module_name", "_connect_params", "_exception")

    def __init__(self, connection, data_type, module_name, connect_params, exception):
        # type: (object, Union[Type[db_data.PGData], Type[db_data.MYSQLData]], str, Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]], Type[Exception]) -> None # noqa: E501
        """Wrap ``connection``.

        :param connection: the driver's connection object.
        :param data_type: ``db_data`` class passed on to created cursors.
        :param module_name: integration module name used for data and logging.
        :param connect_params: ``(args, kwargs)`` given to the connect call, or None.
        :param exception: exception type cursors raise on rejected statements.
        """
        super(ConnectionWrapper, self).__init__(wrapped=connection)
        self._data_type = data_type
        self._module_name = module_name
        self._connect_params = connect_params
        self._exception = exception

    def __enter__(self):
        # type: () -> Any
        """Enter the wrapped connection's context without leaking untraced objects."""
        r = self.__wrapped__.__enter__()
        if hasattr(r, "cursor"):
            # r is Connection-like.
            if r is self.__wrapped__:
                # Return the reference to this proxy object. Returning r would
                # return the untraced reference.
                return self
            # r is a different connection object. This should not happen in
            # practice, but play it safe so the original functionality is kept.
            return r
        if hasattr(r, "execute"):
            # r is Cursor-like.
            if hasattr(r, "__wrapped__") and isinstance(r, wrapt.ObjectProxy):
                # Already a proxy; don't wrap it twice.
                return r
            # The driver returned a bare cursor from __enter__. Wrap it exactly as
            # cursor() would, so its statements are traced and rejects enforced.
            return CursorWrapper(
                cursor=r,
                module_name=self._module_name,
                data_type=self._data_type,
                connect_params=self._connect_params,
                exception=self._exception,
            )
        # Otherwise r is some other object, so maintain the functionality of
        # the original.
        return r

    def cursor(self, *args, **kwargs):
        # type: (*Any, **Any) -> CursorWrapper
        """Create a cursor on the wrapped connection, wrapped in ``CursorWrapper``."""
        return CursorWrapper(
            cursor=self.__wrapped__.cursor(*args, **kwargs),
            module_name=self._module_name,
            data_type=self._data_type,
            connect_params=self._connect_params,
            cursor_params=(args, kwargs) if args or kwargs else None,
            exception=self._exception,
        )

    def begin(self):
        # type: () -> Any
        """Forward to the wrapped connection's ``begin()``."""
        return self.__wrapped__.begin()

    def commit(self):
        # type: () -> Any
        """Forward to the wrapped connection's ``commit()``."""
        return self.__wrapped__.commit()

    def rollback(self):
        # type: () -> Any
        """Forward to the wrapped connection's ``rollback()``."""
        return self.__wrapped__.rollback()
class ConnectionFactory(object):
    """Drop-in replacement for a driver's ``connect`` callable.

    Calling the factory calls the original ``connect_func`` with the same
    arguments and returns the resulting connection wrapped by
    ``self._wrapper_ctor`` (``ConnectionWrapper``). The call arguments are
    passed along as ``connect_params``: ``(args, kwargs)``, or None when the
    call had no arguments.
    """

    def __init__(self, connect_func, data_type, module_name, exception):
        # type: (Union[Callable[..., Any], Type[object]], Union[Type[db_data.PGData], Type[db_data.MYSQLData]], str, Type[Exception]) -> None # noqa: E501
        """Store the original connect callable and the wrapper configuration.

        :param connect_func: the driver's original connect function or
            connection class. It is called with arbitrary positional and
            keyword arguments, hence ``Callable[..., Any]``.
        :param data_type: ``db_data`` class passed on to the connection wrapper.
        :param module_name: integration module name.
        :param exception: exception type raised on rejected statements.
        """
        self._connect_func = connect_func
        self._module_name = module_name
        self._wrapper_ctor = ConnectionWrapper
        self._data_type = data_type
        self._exception = exception

    def __call__(self, *args, **kwargs):
        # type: (*Any, **Any) -> Any
        """Connect via the original callable and return the wrapped connection."""
        connect_params = (args, kwargs) if args or kwargs else None
        return self._wrapper_ctor(
            connection=self._connect_func(*args, **kwargs),
            data_type=self._data_type,
            module_name=self._module_name,
            connect_params=connect_params,
            exception=self._exception,
        )