# orm/util.py
# Copyright (C) 2005-2018 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .. import sql, util, event, exc as sa_exc, inspection
from ..sql import expression, util as sql_util, operators
from .interfaces import PropComparator, MapperProperty
from . import attributes
import re
from .base import instance_str, state_str, state_class_str, attribute_str, \
state_attribute_str, object_mapper, object_state, _none_set, _never_set
from .base import class_mapper, _class_to_mapper
from .base import InspectionAttr
from .path_registry import PathRegistry
# The complete vocabulary of cascade option names; CascadeOptions rejects
# any name outside this set.
all_cascades = frozenset([
    "delete",
    "delete-orphan",
    "all",
    "merge",
    "expunge",
    "save-update",
    "refresh-expire",
    "none",
])
class CascadeOptions(frozenset):
    """Keeps track of the options sent to relationship().cascade

    An immutable set of cascade option names, validated against
    ``all_cascades``, exposing one boolean attribute per individual
    cascade (``save_update``, ``delete``, ``refresh_expire``, ``merge``,
    ``expunge``, ``delete_orphan``).
    """

    # Cascades implied by "all".  "delete-orphan" must be requested
    # explicitly, and "all"/"none" are markers rather than real cascades.
    _add_w_all_cascades = all_cascades.difference([
        'all', 'none', 'delete-orphan'])
    _allowed_cascades = all_cascades

    __slots__ = (
        'save_update', 'delete', 'refresh_expire', 'merge',
        'expunge', 'delete_orphan')

    def __new__(cls, value_list):
        """Build from a comma-separated string, ``None``, or an iterable
        of cascade option names.

        :raises ArgumentError: if any name is not a known cascade option.
        """
        if isinstance(value_list, util.string_types) or value_list is None:
            return cls.from_string(value_list)
        values = set(value_list)

        invalid = values.difference(cls._allowed_cascades)
        if invalid:
            raise sa_exc.ArgumentError(
                "Invalid cascade option(s): %s" %
                ", ".join([repr(x) for x in sorted(invalid)]))

        if "all" in values:
            values.update(cls._add_w_all_cascades)
        if "none" in values:
            # "none" overrides anything else that was requested.
            values.clear()
        # "all" is only a shorthand; it is never stored in the set.
        values.discard('all')

        # Use ``cls`` rather than a hard-coded CascadeOptions so that
        # subclasses construct instances of themselves.
        self = frozenset.__new__(cls, values)
        self.save_update = 'save-update' in values
        self.delete = 'delete' in values
        self.refresh_expire = 'refresh-expire' in values
        self.merge = 'merge' in values
        self.expunge = 'expunge' in values
        self.delete_orphan = "delete-orphan" in values

        if self.delete_orphan and not self.delete:
            util.warn("The 'delete-orphan' cascade "
                      "option requires 'delete'.")
        return self

    def __repr__(self):
        return "CascadeOptions(%r)" % (",".join(sorted(self)),)

    @classmethod
    def from_string(cls, arg):
        """Parse a comma-separated cascade string such as
        ``"all, delete-orphan"``.

        ``None`` or an empty string yields no options.  Whitespace around
        the commas and around the whole string is ignored.
        """
        values = [
            c for c
            in re.split(r'\s*,\s*', (arg or "").strip())
            if c
        ]
        return cls(values)
def _validator_events(
        desc, key, validator, include_removes, include_backrefs):
    """Runs a validation method on an attribute value to be set or
    appended.

    Raw listeners are registered on ``desc`` for the ``append``,
    ``bulk_replace`` and ``set`` events, plus ``remove`` when
    ``include_removes`` is true.

    ``validator`` is called as ``validator(obj, key, value)``.  When
    ``include_removes`` is true, it is called as
    ``validator(obj, key, value, is_remove)`` instead.

    Unless ``include_backrefs`` is true, events whose initiator belongs to
    a different attribute implementation (backref events) pass through
    unvalidated.
    """

    def _wants_validation(state, initiator):
        if include_backrefs:
            return True
        # A backref-originated event carries a foreign impl.
        own_impl = state.manager[key].impl
        return initiator.impl is own_impl

    if include_removes:
        def _run(obj, value, is_remove=False):
            return validator(obj, key, value, is_remove)
    else:
        def _run(obj, value, is_remove=False):
            return validator(obj, key, value)

    def append(state, value, initiator):
        # Bulk replacements are validated by the bulk_replace listener.
        if initiator.op is attributes.OP_BULK_REPLACE:
            return value
        if not _wants_validation(state, initiator):
            return value
        return _run(state.obj(), value)

    def bulk_set(state, values, initiator):
        if not _wants_validation(state, initiator):
            return
        target = state.obj()
        # Replace the collection contents in place with validated items.
        values[:] = [_run(target, item) for item in values]

    def set_(state, value, oldvalue, initiator):
        if not _wants_validation(state, initiator):
            return value
        return _run(state.obj(), value)

    def remove(state, value, initiator):
        if _wants_validation(state, initiator):
            _run(state.obj(), value, True)

    event.listen(desc, 'append', append, raw=True, retval=True)
    event.listen(desc, 'bulk_replace', bulk_set, raw=True)
    event.listen(desc, 'set', set_, raw=True, retval=True)
    if include_removes:
        event.listen(desc, "remove", remove, raw=True, retval=True)
def polymorphic_union(table_map, typecolname,
                      aliasname='p_union', cast_nulls=True):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See :ref:`concrete_inheritance` for an example of how
    this is used.

    :param table_map: mapping of polymorphic identities to
     :class:`.Table` objects.  The mapping itself is not modified; any
     :class:`.Select` value is wrapped in an alias internally.
    :param typecolname: string name of a "discriminator" column, which will be
     derived from the query, producing the polymorphic identity for
     each row.  If ``None``, no polymorphic discriminator is generated.
    :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()`
     construct generated.
    :param cast_nulls: if True, non-existent columns, which are represented
     as labeled NULLs, will be passed into CAST.  This is a legacy behavior
     that is problematic on some backends such as Oracle, in which case it
     can be set to False.
    """
    colnames = util.OrderedSet()
    colnamemaps = {}
    # Column name -> type.  If a name occurs in several tables, the type
    # from the last table seen wins.
    types = {}
    # (identity, selectable) pairs in table_map's iteration order.  They
    # are kept locally so the caller's mapping is never mutated.
    selectables = []

    for identity in table_map:
        table = table_map[identity]

        # mysql doesn't like selecting from a select;
        # make it an alias of the select
        if isinstance(table, sql.Select):
            table = table.alias()
        selectables.append((identity, table))

        column_map = {}
        for c in table.c:
            colnames.add(c.key)
            column_map[c.key] = c
            types[c.key] = c.type
        colnamemaps[table] = column_map

    def col(name, table):
        try:
            return colnamemaps[table][name]
        except KeyError:
            # This table lacks the column; emit a labeled NULL of the
            # type recorded for that name so the UNION columns line up.
            if cast_nulls:
                return sql.cast(sql.null(), types[name]).label(name)
            else:
                return sql.type_coerce(sql.null(), types[name]).label(name)

    result = []
    for identity, table in selectables:
        columns = [col(name, table) for name in colnames]
        if typecolname is not None:
            # The discriminator is the polymorphic identity rendered as
            # a quoted literal.
            columns.append(
                sql.literal_column(
                    sql_util._quote_ddl_expr(identity)
                ).label(typecolname))
        result.append(sql.select(columns, from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)
def identity_key(*args, **kwargs):
    """Generate "identity key" tuples, as are used as keys in the
    :attr:`.Session.identity_map` dictionary.

    This function has several call styles:

    * ``identity_key(class, ident, identity_token=token)``

      This form receives a mapped class and a primary key scalar or
      tuple as an argument.

      E.g.::

        >>> identity_key(MyClass, (1, 2))
        (<class '__main__.MyClass'>, (1, 2), None)

      :param class: mapped class (must be a positional argument)
      :param ident: primary key, may be a scalar or tuple argument.
      :param identity_token: optional identity token

        .. versionadded:: 1.2 added identity_token

    * ``identity_key(instance=instance)``

      This form will produce the identity key for a given instance.  The
      instance need not be persistent, only that its primary key attributes
      are populated (else the key will contain ``None`` for those missing
      values).

      E.g.::

        >>> instance = MyClass(1, 2)
        >>> identity_key(instance=instance)
        (<class '__main__.MyClass'>, (1, 2), None)

      In this form, the given instance is ultimately run through
      :meth:`.Mapper.identity_key_from_instance`, which will have the
      effect of performing a database check for the corresponding row
      if the object is expired.

      :param instance: object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row, identity_token=token)``

      This form is similar to the class/tuple form, except is passed a
      database result row as a :class:`.RowProxy` object.

      E.g.::

        >>> row = engine.execute(
        ...     "select * from table where a=1 and b=2").first()
        >>> identity_key(MyClass, row=row)
        (<class '__main__.MyClass'>, (1, 2), None)

      :param class: mapped class (must be a positional argument)
      :param row: :class:`.RowProxy` row returned by a :class:`.ResultProxy`
       (must be given as a keyword arg)
      :param identity_token: optional identity token

        .. versionadded:: 1.2 added identity_token

    :raises ArgumentError: if an unsupported number of positional
     arguments, or an unrecognized keyword argument, is given.
    """
    if args:
        row = None
        largs = len(args)
        if largs == 1:
            class_ = args[0]
            try:
                row = kwargs.pop("row")
            except KeyError:
                ident = kwargs.pop("ident")
        elif largs == 2:
            class_, ident = args
        else:
            # Previously 3 positional args were accepted here and then
            # failed to unpack with a ValueError.
            raise sa_exc.ArgumentError(
                "expected one or two positional arguments, "
                "got %s" % largs)
        identity_token = kwargs.pop("identity_token", None)
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                                       % ", ".join(kwargs))
        mapper = class_mapper(class_)
        if row is None:
            return mapper.identity_key_from_primary_key(
                util.to_list(ident), identity_token=identity_token)
        else:
            return mapper.identity_key_from_row(
                row, identity_token=identity_token)
    else:
        instance = kwargs.pop("instance")
        if kwargs:
            # Iterating the dict yields the names; ``kwargs.keys`` (the
            # unbound-call method object) is not iterable.
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                                       % ", ".join(kwargs))
        mapper = object_mapper(instance)
        return mapper.identity_key_from_instance(instance)
class ORMAdapter(sql_util.ColumnAdapter):
    """ColumnAdapter subclass which excludes adaptation of entities from
    non-matching mappers.

    A column is adapted only when it has no ``parentmapper`` annotation,
    or when that mapper's ``isa()`` accepts this adapter's mapper.
    """

    def __init__(self, entity, equivalents=None, adapt_required=False,
                 chain_to=None, allow_label_resolve=True,
                 anonymize_labels=False):
        insp = inspection.inspect(entity)
        self.mapper = insp.mapper
        target = insp.selectable
        # Remember the original entity only when it is an aliased class.
        self.aliased_class = entity if insp.is_aliased_class else None

        sql_util.ColumnAdapter.__init__(
            self, target, equivalents, chain_to,
            adapt_required=adapt_required,
            allow_label_resolve=allow_label_resolve,
            anonymize_labels=anonymize_labels,
            include_fn=self._include_fn
        )

    def _include_fn(self, elem):
        """Decide whether ``elem`` should be adapted by this adapter."""
        parent = elem._annotations.get('parentmapper', None)
        if not parent:
            return True
        return parent.isa(self.mapper)
class AliasedClass(object):
r"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias`
Loading ...