Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

hemamaps / Django   python

Repository URL to install this package:

/ contrib / gis / db / models / functions.py

from decimal import Decimal

from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.sql import AreaField
from django.contrib.gis.measure import (
    Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.core.exceptions import FieldError
from django.db.models import FloatField, IntegerField, TextField
from django.db.models.expressions import Func, Value
from django.utils import six

NUMERIC_TYPES = six.integer_types + (float, Decimal)


class GeoFunc(Func):
    function = None
    output_field_class = None
    geom_param_pos = 0

    def __init__(self, *expressions, **extra):
        if 'output_field' not in extra and self.output_field_class:
            extra['output_field'] = self.output_field_class()
        super(GeoFunc, self).__init__(*expressions, **extra)

    @property
    def name(self):
        return self.__class__.__name__

    @property
    def srid(self):
        expr = self.source_expressions[self.geom_param_pos]
        if hasattr(expr, 'srid'):
            return expr.srid
        try:
            return expr.field.srid
        except (AttributeError, FieldError):
            return None

    def as_sql(self, compiler, connection):
        if self.function is None:
            self.function = connection.ops.spatial_function_name(self.name)
        return super(GeoFunc, self).as_sql(compiler, connection)

    def resolve_expression(self, *args, **kwargs):
        res = super(GeoFunc, self).resolve_expression(*args, **kwargs)
        base_srid = res.srid
        if not base_srid:
            raise TypeError("Geometry functions can only operate on geometric content.")

        for pos, expr in enumerate(res.source_expressions[1:], start=1):
            if isinstance(expr, GeomValue) and expr.srid != base_srid:
                # Automatic SRID conversion so objects are comparable
                res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs)
        return res

    def _handle_param(self, value, param_name='', check_types=None):
        if not hasattr(value, 'resolve_expression'):
            if check_types and not isinstance(value, check_types):
                raise TypeError(
                    "The %s parameter has the wrong type: should be %s." % (
                        param_name, str(check_types))
                )
        return value


class GeomValue(Value):
    geography = False

    @property
    def srid(self):
        return self.value.srid

    def as_sql(self, compiler, connection):
        if self.geography:
            self.value = connection.ops.Adapter(self.value, geography=self.geography)
        else:
            self.value = connection.ops.Adapter(self.value)
        return super(GeomValue, self).as_sql(compiler, connection)

    def as_mysql(self, compiler, connection):
        return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]

    def as_sqlite(self, compiler, connection):
        return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]

    def as_oracle(self, compiler, connection):
        return 'SDO_GEOMETRY(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]


class GeoFuncWithGeoParam(GeoFunc):
    def __init__(self, expression, geom, *expressions, **extra):
        if not hasattr(geom, 'srid') or not geom.srid:
            raise ValueError("Please provide a geometry attribute with a defined SRID.")
        super(GeoFuncWithGeoParam, self).__init__(expression, GeomValue(geom), *expressions, **extra)


class SQLiteDecimalToFloatMixin(object):
    """
    By default, Decimal values are converted to str by the SQLite backend, which
    is not acceptable by the GIS functions expecting numeric values.
    """
    def as_sqlite(self, compiler, connection):
        for expr in self.get_source_expressions():
            if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
                expr.value = float(expr.value)
        return super(SQLiteDecimalToFloatMixin, self).as_sql(compiler, connection)


class OracleToleranceMixin(object):
    tolerance = 0.05

    def as_oracle(self, compiler, connection):
        tol = self.extra.get('tolerance', self.tolerance)
        self.template = "%%(function)s(%%(expressions)s, %s)" % tol
        return super(OracleToleranceMixin, self).as_sql(compiler, connection)


class Area(OracleToleranceMixin, GeoFunc):
    def as_sql(self, compiler, connection):
        if connection.ops.geography:
            # Geography fields support area calculation, returns square meters.
            self.output_field = AreaField('sq_m')
        elif not self.output_field.geodetic(connection):
            # Getting the area units of the geographic field.
            units = self.output_field.units_name(connection)
            if units:
                self.output_field = AreaField(
                    AreaMeasure.unit_attname(self.output_field.units_name(connection))
                )
            else:
                self.output_field = FloatField()
        else:
            # TODO: Do we want to support raw number areas for geodetic fields?
            raise NotImplementedError('Area on geodetic coordinate systems not supported.')
        return super(Area, self).as_sql(compiler, connection)

    def as_oracle(self, compiler, connection):
        self.output_field = AreaField('sq_m')  # Oracle returns area in units of meters.
        return super(Area, self).as_oracle(compiler, connection)


class AsGeoJSON(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
        expressions = [expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        options = 0
        if crs and bbox:
            options = 3
        elif bbox:
            options = 1
        elif crs:
            options = 2
        if options:
            expressions.append(options)
        super(AsGeoJSON, self).__init__(*expressions, **extra)


class AsGML(GeoFunc):
    geom_param_pos = 1
    output_field_class = TextField

    def __init__(self, expression, version=2, precision=8, **extra):
        expressions = [version, expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        super(AsGML, self).__init__(*expressions, **extra)


class AsKML(AsGML):
    def as_sqlite(self, compiler, connection):
        # No version parameter
        self.source_expressions.pop(0)
        return super(AsKML, self).as_sql(compiler, connection)


class AsSVG(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, relative=False, precision=8, **extra):
        relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
        expressions = [
            expression,
            relative,
            self._handle_param(precision, 'precision', six.integer_types),
        ]
        super(AsSVG, self).__init__(*expressions, **extra)


class BoundingCircle(GeoFunc):
    def __init__(self, expression, num_seg=48, **extra):
        super(BoundingCircle, self).__init__(*[expression, num_seg], **extra)


class Centroid(OracleToleranceMixin, GeoFunc):
    pass


class Difference(OracleToleranceMixin, GeoFuncWithGeoParam):
    pass


class DistanceResultMixin(object):
    def convert_value(self, value, expression, connection, context):
        if value is None:
            return None
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection):
            dist_att = 'm'
        else:
            units = geo_field.units_name(connection)
            if units:
                dist_att = DistanceMeasure.unit_attname(units)
            else:
                dist_att = None
        if dist_att:
            return DistanceMeasure(**{dist_att: value})
        return value


class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam):
    output_field_class = FloatField
    spheroid = None

    def __init__(self, expr1, expr2, spheroid=None, **extra):
        expressions = [expr1, expr2]
        if spheroid is not None:
            self.spheroid = spheroid
            expressions += (self._handle_param(spheroid, 'spheroid', bool),)
        super(Distance, self).__init__(*expressions, **extra)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326
        if geography:
            # Set parameters as geography if base field is geography
            for pos, expr in enumerate(
                    self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1):
                if isinstance(expr, GeomValue):
                    expr.geography = True
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
            if self.spheroid:
                self.function = 'ST_Distance_Spheroid'  # More accurate, resource intensive
                # Replace boolean param by the real spheroid of the base field
                self.source_expressions[2] = Value(geo_field._spheroid)
            else:
                self.function = 'ST_Distance_Sphere'
        return super(Distance, self).as_sql(compiler, connection)

    def as_oracle(self, compiler, connection):
        if self.spheroid:
            self.source_expressions.pop(2)
        return super(Distance, self).as_oracle(compiler, connection)


class Envelope(GeoFunc):
    pass


class ForceRHR(GeoFunc):
    pass


class GeoHash(GeoFunc):
    output_field_class = TextField

    def __init__(self, expression, precision=None, **extra):
        expressions = [expression]
        if precision is not None:
            expressions.append(self._handle_param(precision, 'precision', six.integer_types))
        super(GeoHash, self).__init__(*expressions, **extra)


class Intersection(OracleToleranceMixin, GeoFuncWithGeoParam):
    pass


class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    output_field_class = FloatField

    def __init__(self, expr1, spheroid=True, **extra):
        self.spheroid = spheroid
        super(Length, self).__init__(expr1, **extra)

    def as_sql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
            raise NotImplementedError("This backend doesn't support Length on geodetic fields")
        return super(Length, self).as_sql(compiler, connection)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326
        if geography:
            self.source_expressions.append(Value(self.spheroid))
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
            self.function = 'ST_Length_Spheroid'
            self.source_expressions.append(Value(geo_field._spheroid))
        else:
            dim = min(f.dim for f in self.get_source_fields() if f)
            if dim > 2:
                self.function = connection.ops.length3d
        return super(Length, self).as_sql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)
        if geo_field.geodetic(connection):
            if self.spheroid:
                self.function = 'GeodesicLength'
            else:
                self.function = 'GreatCircleLength'
        return super(Length, self).as_sql(compiler, connection)


class MemSize(GeoFunc):
    output_field_class = IntegerField


class NumGeometries(GeoFunc):
    output_field_class = IntegerField


class NumPoints(GeoFunc):
    output_field_class = IntegerField

    def as_sqlite(self, compiler, connection):
        if self.source_expressions[self.geom_param_pos].output_field.geom_type != 'LINESTRING':
            raise TypeError("Spatialite NumPoints can only operate on LineString content")
        return super(NumPoints, self).as_sql(compiler, connection)


class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    output_field_class = FloatField

    def as_postgresql(self, compiler, connection):
        dim = min(f.dim for f in self.get_source_fields())
        if dim > 2:
            self.function = connection.ops.perimeter3d
Loading ...