from decimal import Decimal
from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.sql import AreaField
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.core.exceptions import FieldError
from django.db.models import FloatField, IntegerField, TextField
from django.db.models.expressions import Func, Value
from django.utils import six
# Plain Python numeric types: the integer types reported by six, plus float
# and Decimal. Suitable as the second argument of isinstance().
NUMERIC_TYPES = six.integer_types + (float, Decimal)
class GeoFunc(Func):
    """
    Base class for the geographic database functions defined in this module.

    Class attributes:
      * function: SQL function name. When left as None it is looked up from
        the connection's spatial operations (keyed by the class name) the
        first time as_sql() runs, and then stored on the instance.
      * output_field_class: when set, an instance of it is used as the
        default ``output_field``.
      * geom_param_pos: index, within the source expressions, of the
        expression whose SRID is reported by the ``srid`` property.
    """
    function = None
    output_field_class = None
    geom_param_pos = 0

    def __init__(self, *expressions, **extra):
        # An output_field given explicitly by the caller takes precedence
        # over the class-level default.
        if self.output_field_class and 'output_field' not in extra:
            extra['output_field'] = self.output_field_class()
        super(GeoFunc, self).__init__(*expressions, **extra)

    @property
    def name(self):
        """Class name; used as the key for the backend function lookup."""
        return type(self).__name__

    @property
    def srid(self):
        """
        SRID of the expression at ``geom_param_pos``.

        The expression's own ``srid`` attribute is used when present;
        otherwise the SRID of its ``field``. Returns None when neither can
        be obtained.
        """
        geom_expr = self.source_expressions[self.geom_param_pos]
        if not hasattr(geom_expr, 'srid'):
            try:
                return geom_expr.field.srid
            except (AttributeError, FieldError):
                return None
        return geom_expr.srid

    def as_sql(self, compiler, connection):
        """Compile the function, resolving the SQL function name if needed."""
        if self.function is None:
            # Looked up once; the result is kept on the instance.
            backend_name = connection.ops.spatial_function_name(self.name)
            self.function = backend_name
        return super(GeoFunc, self).as_sql(compiler, connection)

    def resolve_expression(self, *args, **kwargs):
        """
        Resolve the expression and align the SRID of geometry values.

        Every GeomValue argument after the first source expression whose
        SRID differs from the resolved function's SRID is wrapped in a
        Transform to that SRID.

        Raises TypeError when the resolved function has no SRID.
        """
        resolved = super(GeoFunc, self).resolve_expression(*args, **kwargs)
        target_srid = resolved.srid
        if not target_srid:
            raise TypeError("Geometry functions can only operate on geometric content.")

        for index in range(1, len(resolved.source_expressions)):
            candidate = resolved.source_expressions[index]
            if not isinstance(candidate, GeomValue):
                continue
            if candidate.srid != target_srid:
                # Automatic SRID conversion so objects are comparable
                transformed = Transform(candidate, target_srid)
                resolved.source_expressions[index] = transformed.resolve_expression(*args, **kwargs)
        return resolved

    def _handle_param(self, value, param_name='', check_types=None):
        """
        Return ``value`` unchanged after an optional type check.

        Objects exposing ``resolve_expression`` are never checked. For any
        other value, a TypeError is raised when ``check_types`` is given and
        the value is not an instance of it.
        """
        is_expression = hasattr(value, 'resolve_expression')
        if not is_expression and check_types and not isinstance(value, check_types):
            raise TypeError(
                "The %s parameter has the wrong type: should be %s." % (
                    param_name, str(check_types))
            )
        return value
class GeomValue(Value):
    """
    Value expression wrapping a geometry object (anything exposing ``srid``).

    ``geography`` can be switched on per instance; as_sql() then passes it
    on to the backend adapter.
    """
    geography = False

    @property
    def srid(self):
        """SRID of the wrapped geometry."""
        return self.value.srid

    def as_sql(self, compiler, connection):
        """Wrap the value in the backend adapter, then compile as a Value."""
        adapter_cls = connection.ops.Adapter
        if self.geography:
            adapted = adapter_cls(self.value, geography=self.geography)
        else:
            adapted = adapter_cls(self.value)
        # The adapted object replaces the stored value before compilation.
        self.value = adapted
        return super(GeomValue, self).as_sql(compiler, connection)

    def as_mysql(self, compiler, connection):
        """Emit a GeomFromText() call with the SRID inlined in the SQL."""
        sql = 'GeomFromText(%%s, %s)' % self.srid
        params = [connection.ops.Adapter(self.value)]
        return sql, params

    def as_sqlite(self, compiler, connection):
        """Emit a GeomFromText() call with the SRID inlined in the SQL."""
        sql = 'GeomFromText(%%s, %s)' % self.srid
        params = [connection.ops.Adapter(self.value)]
        return sql, params

    def as_oracle(self, compiler, connection):
        """Emit an SDO_GEOMETRY() call with the SRID inlined in the SQL."""
        sql = 'SDO_GEOMETRY(%%s, %s)' % self.srid
        params = [connection.ops.Adapter(self.value)]
        return sql, params
class GeoFuncWithGeoParam(GeoFunc):
    """
    GeoFunc whose second argument is a geometry object.

    The geometry must expose a truthy ``srid``; it is wrapped in a GeomValue
    before being handed to GeoFunc.
    """

    def __init__(self, expression, geom, *expressions, **extra):
        has_srid = hasattr(geom, 'srid') and geom.srid
        if not has_srid:
            raise ValueError("Please provide a geometry attribute with a defined SRID.")
        geom_value = GeomValue(geom)
        super(GeoFuncWithGeoParam, self).__init__(expression, geom_value, *expressions, **extra)
class SQLiteDecimalToFloatMixin(object):
    """
    The SQLite backend converts Decimal values to str by default, which the
    GIS functions expecting numeric values do not accept. This mixin turns
    Decimal values of the source expressions into floats before compiling.
    """

    def as_sqlite(self, compiler, connection):
        for source in self.get_source_expressions():
            if not hasattr(source, 'value'):
                continue
            if isinstance(source.value, Decimal):
                # Converted in place on the source expression.
                source.value = float(source.value)
        return super(SQLiteDecimalToFloatMixin, self).as_sql(compiler, connection)
class OracleToleranceMixin(object):
    """
    Add a tolerance argument to the function call when compiling for Oracle.

    The tolerance comes from the ``tolerance`` entry of ``self.extra`` when
    present, otherwise from the class-level default.
    """
    tolerance = 0.05

    def as_oracle(self, compiler, connection):
        tolerance = self.extra.get('tolerance', self.tolerance)
        # The tolerance is written literally as the last argument of the
        # template; the doubled percent signs keep the placeholders intact.
        template = "%%(function)s(%%(expressions)s, %s)" % tolerance
        self.template = template
        return super(OracleToleranceMixin, self).as_sql(compiler, connection)
class Area(OracleToleranceMixin, GeoFunc):
    """
    Area of a geometry expression.

    The output field is chosen at compile time:
      * AreaField('sq_m') when the backend reports geography support or when
        compiling for Oracle;
      * an AreaField in the area unit derived from the field's units when
        the field is not geodetic and has named units;
      * a plain FloatField when the field is not geodetic and has no units.

    Raises NotImplementedError for geodetic fields on backends without
    geography support.
    """

    def as_sql(self, compiler, connection):
        # NOTE(review): this overwrites self.output_field, so compiling the
        # same instance a second time would call geodetic() on the
        # replacement field -- confirm whether that situation can occur.
        if connection.ops.geography:
            # Geography fields support area calculation, returns square meters.
            self.output_field = AreaField('sq_m')
        elif not self.output_field.geodetic(connection):
            # Getting the area units of the geographic field.
            units = self.output_field.units_name(connection)
            if units:
                # Reuse the name fetched above instead of asking the field again.
                self.output_field = AreaField(AreaMeasure.unit_attname(units))
            else:
                self.output_field = FloatField()
        else:
            # TODO: Do we want to support raw number areas for geodetic fields?
            raise NotImplementedError('Area on geodetic coordinate systems not supported.')
        return super(Area, self).as_sql(compiler, connection)

    def as_oracle(self, compiler, connection):
        self.output_field = AreaField('sq_m')  # Oracle returns area in units of meters.
        return super(Area, self).as_oracle(compiler, connection)
class AsGeoJSON(GeoFunc):
    """
    GeoJSON representation of a geometry expression, returned as text.

    ``precision`` (an integer or an expression) is passed as the second
    argument unless it is None. ``bbox`` and ``crs`` are combined into a
    numeric options argument that is only added when at least one is set.
    """
    output_field_class = TextField

    def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
        args = [expression]
        if precision is not None:
            args.append(self._handle_param(precision, 'precision', six.integer_types))
        # Option flags: 1 for bbox, 2 for crs, 3 for both.
        options = (1 if bbox else 0) | (2 if crs else 0)
        if options:
            args.append(options)
        super(AsGeoJSON, self).__init__(*args, **extra)
class AsGML(GeoFunc):
    """
    GML representation of a geometry expression, returned as text.

    The version is the first function argument, so the geometry sits at
    position 1. ``precision`` (an integer or an expression) is appended
    unless it is None.
    """
    geom_param_pos = 1
    output_field_class = TextField

    def __init__(self, expression, version=2, precision=8, **extra):
        trailing = []
        if precision is not None:
            trailing.append(self._handle_param(precision, 'precision', six.integer_types))
        super(AsGML, self).__init__(version, expression, *trailing, **extra)
class AsKML(AsGML):
    """
    KML representation of a geometry expression, returned as text.

    Takes the same arguments as AsGML. When compiling for SQLite the leading
    version argument is left out of the generated call.
    """

    def as_sqlite(self, compiler, connection):
        # No version parameter. The argument is dropped on a copy rather
        # than popped from self, so compiling the same expression again does
        # not remove the geometry argument as well.
        clone = self.copy()
        clone.source_expressions = self.source_expressions[1:]
        return super(AsKML, clone).as_sql(compiler, connection)
class AsSVG(GeoFunc):
    """
    SVG representation of a geometry expression, returned as text.

    ``relative`` is converted with int() unless it is an expression;
    ``precision`` must be an integer or an expression.
    """
    output_field_class = TextField

    def __init__(self, expression, relative=False, precision=8, **extra):
        if not hasattr(relative, 'resolve_expression'):
            relative = int(relative)
        checked_precision = self._handle_param(precision, 'precision', six.integer_types)
        super(AsSVG, self).__init__(expression, relative, checked_precision, **extra)
class BoundingCircle(GeoFunc):
    """
    Bounding circle of a geometry expression; ``num_seg`` is passed through
    as the second function argument.
    """

    def __init__(self, expression, num_seg=48, **extra):
        super(BoundingCircle, self).__init__(expression, num_seg, **extra)
class Centroid(OracleToleranceMixin, GeoFunc):
    """
    Centroid function of a geometry expression.

    Adds no behavior of its own: the SQL function name is looked up from the
    backend by class name, and Oracle compilation goes through
    OracleToleranceMixin.
    """
class Difference(OracleToleranceMixin, GeoFuncWithGeoParam):
    """
    Difference function taking a geometry expression and a geometry object.

    Adds no behavior of its own: argument handling comes from
    GeoFuncWithGeoParam and Oracle compilation from OracleToleranceMixin.
    """
class DistanceResultMixin(object):
    """
    Convert raw numeric results into DistanceMeasure objects when the unit
    can be determined from the SRID of the function.
    """

    def convert_value(self, value, expression, connection, context):
        """
        Return None for None, a DistanceMeasure when a unit attribute name
        is known, and the raw value otherwise.
        """
        if value is None:
            return None
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection):
            # Geodetic fields are reported in meters.
            dist_att = 'm'
        else:
            units = geo_field.units_name(connection)
            dist_att = DistanceMeasure.unit_attname(units) if units else None
        if not dist_att:
            return value
        return DistanceMeasure(**{dist_att: value})
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam):
    """
    Distance between a geometry expression and a geometry object.

    ``expr2`` must be a geometry with a defined SRID (enforced by
    GeoFuncWithGeoParam). When ``spheroid`` is not None it must be a bool or
    an expression; it is remembered on the instance and added as a third
    function argument.

    The backend-specific compilers work on a copy of the expression, so
    compiling the same instance more than once, or for different backends,
    always starts from the arguments given to the constructor.
    """
    output_field_class = FloatField
    spheroid = None

    def __init__(self, expr1, expr2, spheroid=None, **extra):
        expressions = [expr1, expr2]
        if spheroid is not None:
            self.spheroid = spheroid
            expressions.append(self._handle_param(spheroid, 'spheroid', bool))
        super(Distance, self).__init__(*expressions, **extra)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326

        # Function name and argument changes are applied to a copy only.
        clone = self.copy()
        clone.source_expressions = list(self.source_expressions)
        if geography:
            # Set parameters as geography if base field is geography
            for expr in clone.source_expressions[self.geom_param_pos + 1:]:
                if isinstance(expr, GeomValue):
                    expr.geography = True
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
            if self.spheroid:
                clone.function = 'ST_Distance_Spheroid'  # More accurate, resource intensive
                # Replace boolean param by the real spheroid of the base field
                clone.source_expressions[2] = Value(geo_field._spheroid)
            else:
                clone.function = 'ST_Distance_Sphere'
        return super(Distance, clone).as_sql(compiler, connection)

    def as_oracle(self, compiler, connection):
        clone = self.copy()
        clone.source_expressions = list(self.source_expressions)
        if self.spheroid:
            # The spheroid argument is not passed on Oracle. Removing it
            # from a copy keeps self intact; popping from self would raise
            # IndexError the second time the expression is compiled.
            clone.source_expressions.pop(2)
        return super(Distance, clone).as_oracle(compiler, connection)
class Envelope(GeoFunc):
    """
    Envelope function of a geometry expression.

    Adds no behavior of its own; the SQL function name is looked up from the
    backend by class name.
    """
class ForceRHR(GeoFunc):
    """
    ForceRHR function of a geometry expression.

    Adds no behavior of its own; the SQL function name is looked up from the
    backend by class name.
    """
class GeoHash(GeoFunc):
    """
    GeoHash of a geometry expression, returned as text.

    ``precision`` (an integer or an expression) is only passed to the
    function when it is not None.
    """
    output_field_class = TextField

    def __init__(self, expression, precision=None, **extra):
        if precision is None:
            args = (expression,)
        else:
            args = (
                expression,
                self._handle_param(precision, 'precision', six.integer_types),
            )
        super(GeoHash, self).__init__(*args, **extra)
class Intersection(OracleToleranceMixin, GeoFuncWithGeoParam):
    """
    Intersection function taking a geometry expression and a geometry object.

    Adds no behavior of its own: argument handling comes from
    GeoFuncWithGeoParam and Oracle compilation from OracleToleranceMixin.
    """
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    """
    Length of a geometry expression.

    ``spheroid`` is stored on the instance. On PostgreSQL it is passed as an
    extra argument for geography columns; on SQLite it selects between the
    GeodesicLength and GreatCircleLength functions for geodetic fields.

    The backend-specific compilers work on a copy of the expression, so
    compiling the same instance more than once, or for different backends,
    does not accumulate arguments or carry over a backend-specific function
    name.
    """
    output_field_class = FloatField

    def __init__(self, expr1, spheroid=True, **extra):
        self.spheroid = spheroid
        super(Length, self).__init__(expr1, **extra)

    def as_sql(self, compiler, connection):
        """
        Generic compilation.

        Raises NotImplementedError for geodetic fields when the backend does
        not advertise ``supports_length_geodetic``.
        """
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
            raise NotImplementedError("This backend doesn't support Length on geodetic fields")
        return super(Length, self).as_sql(compiler, connection)

    def as_postgresql(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info
        src_field = self.get_source_fields()[0]
        geography = src_field.geography and self.srid == 4326

        # Extra arguments and function overrides go on a copy; appending to
        # self would add one more argument on every compilation.
        clone = self.copy()
        clone.source_expressions = list(self.source_expressions)
        if geography:
            clone.source_expressions.append(Value(self.spheroid))
        elif geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
            clone.function = 'ST_Length_Spheroid'
            clone.source_expressions.append(Value(geo_field._spheroid))
        else:
            dim = min(f.dim for f in self.get_source_fields() if f)
            if dim > 2:
                clone.function = connection.ops.length3d
        return super(Length, clone).as_sql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        geo_field = GeometryField(srid=self.srid)
        # The function override is set on a copy so it does not leak into
        # later compilations of the same instance.
        clone = self.copy()
        if geo_field.geodetic(connection):
            if self.spheroid:
                clone.function = 'GeodesicLength'
            else:
                clone.function = 'GreatCircleLength'
        return super(Length, clone).as_sql(compiler, connection)
class MemSize(GeoFunc):
    """
    MemSize function of a geometry expression; the result is exposed through
    an IntegerField. The SQL function name is looked up from the backend by
    class name.
    """
    output_field_class = IntegerField
class NumGeometries(GeoFunc):
    """
    NumGeometries function of a geometry expression; the result is exposed
    through an IntegerField. The SQL function name is looked up from the
    backend by class name.
    """
    output_field_class = IntegerField
class NumPoints(GeoFunc):
    """
    NumPoints function of a geometry expression; the result is exposed
    through an IntegerField.

    When compiling for SQLite, a TypeError is raised unless the geometry
    expression's output field has geom_type 'LINESTRING'.
    """
    output_field_class = IntegerField

    def as_sqlite(self, compiler, connection):
        geom_expr = self.source_expressions[self.geom_param_pos]
        geom_type = geom_expr.output_field.geom_type
        if geom_type != 'LINESTRING':
            raise TypeError("Spatialite NumPoints can only operate on LineString content")
        return super(NumPoints, self).as_sql(compiler, connection)
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
output_field_class = FloatField
def as_postgresql(self, compiler, connection):
dim = min(f.dim for f in self.get_source_fields())
if dim > 2:
self.function = connection.ops.perimeter3d
Loading ...