362 lines
11 KiB
Python

import re
from django.contrib.gis.db.models.fields import BaseSpatialField
from django.contrib.gis.measure import Distance
from django.db import NotSupportedError
from django.db.models.expressions import Expression
from django.db.models.lookups import Lookup, Transform
from django.db.models.sql.query import Query
class RasterBandTransform(Transform):
    """
    Transform whose SQL is exactly the SQL of the expression it wraps.

    GISLookup reads a ``band_index`` attribute from instances of this class;
    that attribute is not set here, so it is presumably assigned by whatever
    code builds the transform.
    """

    def as_sql(self, compiler, connection):
        """
        Return the result of compiling ``self.lhs`` with ``compiler``.

        ``connection`` is accepted but not used.
        """
        wrapped = self.lhs
        return compiler.compile(wrapped)
class GISLookup(Lookup):
    """
    Base class for the spatial lookups defined in this module.

    The right-hand side may be a single value or a list/tuple. For a
    list/tuple, the first item becomes ``self.rhs`` and the remaining items
    are kept in ``self.rhs_params`` for this class and its subclasses to
    interpret (raster band index, relate pattern, distance, ...).

    The final SQL is produced by the backend operator stored under
    ``connection.ops.gis_operators[self.lookup_name]``.
    """
    # Overridden by some subclasses (the relate and distance lookups). It is
    # not read in this module; presumably the backend operator consumes it.
    sql_template = None
    # Not referenced anywhere else in this module.
    transform_func = None
    # Set to True by DistanceLookupBase.
    distance = False
    # Band indices filled in by process_band_indices(); None until then.
    band_rhs = None
    band_lhs = None

    def __init__(self, lhs, rhs):
        """
        Split ``rhs`` into the main value and the extra parameters.

        A non-sequence ``rhs`` is treated as a one-item list, so
        ``self.rhs_params`` is then an empty list.
        """
        rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
        super().__init__(lhs, rhs)
        # Extra substitutions merged into the operator template in as_sql().
        self.template_params = {}
        self.process_rhs_params()

    def process_rhs_params(self):
        """
        Interpret the extra right-hand-side items.

        With extras present, exactly one extra (two for the 'relate' lookup)
        means a band index was passed; any other count above one is rejected.
        With no extras, a band index is still taken from the left-hand side
        when it is a RasterBandTransform.

        Raises:
            ValueError: if the tuple has too many items for this lookup.
        """
        if self.rhs_params:
            # Check if a band index was passed in the query argument.
            if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1):
                self.process_band_indices()
            elif len(self.rhs_params) > 1:
                raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
        elif isinstance(self.lhs, RasterBandTransform):
            self.process_band_indices(only_lhs=True)

    def process_band_indices(self, only_lhs=False):
        """
        Extract the lhs band index from the band transform class and the rhs
        band index from the input tuple.

        When ``only_lhs`` is true, the rhs band defaults to 1 and
        ``self.rhs_params`` is left untouched. Otherwise the first item of
        ``self.rhs_params`` is removed and stored as ``self.band_rhs``.
        """
        # PostGIS band indices are 1-based, so the band index needs to be
        # increased to be consistent with the GDALRaster band indices.
        if only_lhs:
            self.band_rhs = 1
            self.band_lhs = self.lhs.band_index + 1
            return

        if isinstance(self.lhs, RasterBandTransform):
            self.band_lhs = self.lhs.band_index + 1
        else:
            self.band_lhs = 1

        self.band_rhs, *self.rhs_params = self.rhs_params

    def get_db_prep_lookup(self, value, connection):
        """
        Return a ``'%s'`` placeholder and ``value`` wrapped in the backend's
        ``connection.ops.Adapter`` as its single parameter.
        """
        # get_db_prep_lookup is called by process_rhs from super class
        return ('%s', [connection.ops.Adapter(value)])

    def process_rhs(self, compiler, connection):
        """
        Return the ``(sql, params)`` pair for the right-hand side.

        A Query rhs is delegated to the parent class unchanged. An Expression
        rhs is first resolved against ``compiler.query``. In the non-Query
        case, the parent's SQL is substituted into the placeholder returned
        by ``connection.ops.get_geom_placeholder()``.
        """
        if isinstance(self.rhs, Query):
            # If rhs is some Query, don't touch it.
            return super().process_rhs(compiler, connection)
        if isinstance(self.rhs, Expression):
            self.rhs = self.rhs.resolve_expression(compiler.query)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        placeholder = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
        return placeholder % rhs, rhs_params

    def get_rhs_op(self, connection, rhs):
        """
        Return the backend operator registered for ``self.lookup_name``.

        ``rhs`` is accepted but not used.
        """
        # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
        # an object (SpatialOperator) with an as_sql() method to allow for more
        # complex computations (where the lhs part can be mixed in).
        return connection.ops.gis_operators[self.lookup_name]

    def as_sql(self, compiler, connection):
        """
        Compile both sides and hand them to the backend operator.

        The operator receives the template parameters ``lhs``, ``rhs`` and
        ``value`` (overridable through ``self.template_params``) together
        with the combined lhs + rhs SQL parameters.
        """
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Build a new list instead of extending lhs_params in place: the
        # sequence returned by process_lhs() is not owned by this method and
        # may be a tuple.
        sql_params = [*lhs_params, *rhs_params]

        template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s', **self.template_params}
        rhs_op = self.get_rhs_op(connection, rhs_sql)
        return rhs_op.as_sql(connection, self, template_params, sql_params)
# ------------------
# Geometry operators
# ------------------
@BaseSpatialField.register_lookup
class OverlapsLeftLookup(GISLookup):
    """
    'overlaps_left' lookup: true when A's bounding box overlaps B's bounding
    box or lies to the left of it.
    """
    lookup_name = 'overlaps_left'
@BaseSpatialField.register_lookup
class OverlapsRightLookup(GISLookup):
    """
    'overlaps_right' lookup: true when A's bounding box overlaps B's bounding
    box or lies to the right of it.
    """
    lookup_name = 'overlaps_right'
@BaseSpatialField.register_lookup
class OverlapsBelowLookup(GISLookup):
    """
    'overlaps_below' lookup: true when A's bounding box overlaps B's bounding
    box or lies below it.
    """
    lookup_name = 'overlaps_below'
@BaseSpatialField.register_lookup
class OverlapsAboveLookup(GISLookup):
    """
    'overlaps_above' lookup: true when A's bounding box overlaps B's bounding
    box or lies above it.
    """
    lookup_name = 'overlaps_above'
@BaseSpatialField.register_lookup
class LeftLookup(GISLookup):
    """
    'left' lookup: true when A's bounding box lies strictly to the left of
    B's bounding box.
    """
    lookup_name = 'left'
@BaseSpatialField.register_lookup
class RightLookup(GISLookup):
    """
    'right' lookup: true when A's bounding box lies strictly to the right of
    B's bounding box.
    """
    lookup_name = 'right'
@BaseSpatialField.register_lookup
class StrictlyBelowLookup(GISLookup):
    """
    'strictly_below' lookup: true when A's bounding box lies strictly below
    B's bounding box.
    """
    lookup_name = 'strictly_below'
@BaseSpatialField.register_lookup
class StrictlyAboveLookup(GISLookup):
    """
    'strictly_above' lookup: true when A's bounding box lies strictly above
    B's bounding box.
    """
    lookup_name = 'strictly_above'
@BaseSpatialField.register_lookup
class SameAsLookup(GISLookup):
    """
    'same_as' lookup, i.e. the "~=" operator.

    It checks actual geometric equality of two features: the result is true
    when A and B are the same feature, vertex-by-vertex.
    """
    lookup_name = 'same_as'


# Make the same lookup class available under the 'exact' name as well.
BaseSpatialField.register_lookup(SameAsLookup, 'exact')
@BaseSpatialField.register_lookup
class BBContainsLookup(GISLookup):
    """
    'bbcontains' lookup: true when A's bounding box completely contains B's
    bounding box.
    """
    lookup_name = 'bbcontains'
@BaseSpatialField.register_lookup
class BBOverlapsLookup(GISLookup):
    """
    'bboverlaps' lookup: true when A's bounding box overlaps B's bounding box.
    """
    lookup_name = 'bboverlaps'
@BaseSpatialField.register_lookup
class ContainedLookup(GISLookup):
    """
    'contained' lookup: true when A's bounding box is completely contained by
    B's bounding box.
    """
    lookup_name = 'contained'
# ------------------
# Geometry functions
# ------------------
@BaseSpatialField.register_lookup
class ContainsLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'contains'."""
    lookup_name = 'contains'
@BaseSpatialField.register_lookup
class ContainsProperlyLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'contains_properly'."""
    lookup_name = 'contains_properly'
@BaseSpatialField.register_lookup
class CoveredByLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'coveredby'."""
    lookup_name = 'coveredby'
@BaseSpatialField.register_lookup
class CoversLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'covers'."""
    lookup_name = 'covers'
@BaseSpatialField.register_lookup
class CrossesLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'crosses'."""
    lookup_name = 'crosses'
@BaseSpatialField.register_lookup
class DisjointLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'disjoint'."""
    lookup_name = 'disjoint'
@BaseSpatialField.register_lookup
class EqualsLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'equals'."""
    lookup_name = 'equals'
@BaseSpatialField.register_lookup
class IntersectsLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'intersects'."""
    lookup_name = 'intersects'
@BaseSpatialField.register_lookup
class OverlapsLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'overlaps'."""
    lookup_name = 'overlaps'
@BaseSpatialField.register_lookup
class RelateLookup(GISLookup):
    """
    'relate' lookup: compares two geometries against an intersection matrix
    pattern supplied as an extra item after the geometry.

    The pattern is appended to the SQL parameters and fills the trailing
    ``%s`` of ``sql_template``.
    """
    lookup_name = 'relate'
    sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
    # Nine characters, each one of 0, 1, 2, T, F or *.
    pattern_regex = re.compile(r'^[012TF\*]{9}$')

    def process_rhs(self, compiler, connection):
        """
        Validate the pattern and return the rhs SQL with the pattern appended
        to its parameters.

        If the backend operator for this lookup defines
        ``check_relate_argument``, validation is delegated to it; otherwise
        the pattern must be a string matching ``pattern_regex`` in full.

        Raises:
            ValueError: if no pattern was supplied or the pattern is invalid.
        """
        # Without this guard a missing pattern surfaces as a bare IndexError.
        if not self.rhs_params:
            raise ValueError(
                "A pattern argument is required for the '%s' lookup." % self.lookup_name
            )
        # Check the pattern argument
        pattern = self.rhs_params[0]
        backend_op = connection.ops.gis_operators[self.lookup_name]
        if hasattr(backend_op, 'check_relate_argument'):
            backend_op.check_relate_argument(pattern)
        # fullmatch() is used because '$' also matches just before a trailing
        # newline, which match() would let through.
        elif not isinstance(pattern, str) or not self.pattern_regex.fullmatch(pattern):
            # Wrap in a 1-tuple so that a tuple-valued pattern is formatted
            # rather than unpacked as multiple format arguments.
            raise ValueError('Invalid intersection matrix pattern "%s".' % (pattern,))
        sql, params = super().process_rhs(compiler, connection)
        return sql, params + [pattern]
@BaseSpatialField.register_lookup
class TouchesLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'touches'."""
    lookup_name = 'touches'
@BaseSpatialField.register_lookup
class WithinLookup(GISLookup):
    """Spatial lookup registered on BaseSpatialField under the name 'within'."""
    lookup_name = 'within'
class DistanceLookupBase(GISLookup):
    """
    Shared base for lookups whose right-hand side carries a distance.

    Once the geometry has been split off, ``self.rhs_params`` must hold one
    to three extra items. If a second item exists and is not 'spheroid', the
    first item is consumed as a band index, which leaves the distance as the
    first remaining item.
    """
    distance = True
    sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'

    def process_rhs_params(self):
        """
        Validate the extra tuple items and extract band indices if present.

        Raises:
            ValueError: if the number of extra items is not between 1 and 3,
                or if there are three and the last is not 'spheroid'.
        """
        extras = self.rhs_params
        count = len(extras)
        if count < 1 or count > 3:
            raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
        if count == 3 and extras[2] != 'spheroid':
            raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
        # Check if the second parameter is a band index.
        if count > 1 and extras[1] != 'spheroid':
            self.process_band_indices()

    def process_distance(self, compiler, connection):
        """
        Return the ``(sql, params)`` pair for the distance value.

        A distance exposing ``resolve_expression`` is resolved against
        ``compiler.query`` and compiled. Anything else is rendered as a
        ``'%s'`` placeholder with the parameters returned by
        ``connection.ops.get_distance()``.
        """
        dist_param = self.rhs_params[0]
        if hasattr(dist_param, 'resolve_expression'):
            resolved = dist_param.resolve_expression(compiler.query)
            return compiler.compile(resolved)
        dist_values = connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name)
        return '%s', dist_values
@BaseSpatialField.register_lookup
class DWithinLookup(DistanceLookupBase):
    """
    'dwithin' lookup: the distance is passed as the third argument of the
    backend function instead of being compared with an operator.
    """
    lookup_name = 'dwithin'
    sql_template = '%(func)s(%(lhs)s, %(rhs)s, %(value)s)'

    def process_distance(self, compiler, connection):
        """
        Compile the distance, rejecting expressions the backend cannot handle.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_dwithin_distance_expr`` and the distance exposes
                ``resolve_expression`` without being a Distance instance.
        """
        dist_param = self.rhs_params[0]
        is_unsupported_expr = (
            not connection.features.supports_dwithin_distance_expr and
            hasattr(dist_param, 'resolve_expression') and
            not isinstance(dist_param, Distance)
        )
        if is_unsupported_expr:
            raise NotSupportedError(
                'This backend does not support expressions for specifying '
                'distance in the dwithin lookup.'
            )
        return super().process_distance(compiler, connection)

    def process_rhs(self, compiler, connection):
        """
        Return the geometry SQL with the distance parameters appended.

        The distance SQL is stored in ``self.template_params['value']`` so it
        replaces the default ``value`` placeholder of the template.
        """
        dist_sql, dist_params = self.process_distance(compiler, connection)
        self.template_params['value'] = dist_sql
        geom_sql, geom_params = super().process_rhs(compiler, connection)
        return geom_sql, geom_params + dist_params
class DistanceLookupFromFunction(DistanceLookupBase):
    """
    Distance lookup rendered as ``<distance expression> <op> <distance>``.

    Subclasses provide the comparison operator through an ``op`` attribute.
    """

    def as_sql(self, compiler, connection):
        """
        Return the comparison SQL and its parameters.

        The distance expression comes from
        ``connection.ops.distance_expr_for_lookup()``. ``spheroid`` is passed
        as True only when exactly two extra items remain and the last one is
        'spheroid'; otherwise it is None.
        """
        wants_spheroid = len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid'
        spheroid = wants_spheroid or None
        distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
        func_sql, func_params = compiler.compile(distance_expr.resolve_expression(compiler.query))
        dist_sql, dist_params = self.process_distance(compiler, connection)
        comparison_sql = '%(func)s %(op)s %(dist)s' % {'func': func_sql, 'op': self.op, 'dist': dist_sql}
        return comparison_sql, func_params + dist_params
@BaseSpatialField.register_lookup
class DistanceGTLookup(DistanceLookupFromFunction):
    """'distance_gt' lookup: compares the distance expression using '>'."""
    lookup_name = 'distance_gt'
    op = '>'
@BaseSpatialField.register_lookup
class DistanceGTELookup(DistanceLookupFromFunction):
    """'distance_gte' lookup: compares the distance expression using '>='."""
    lookup_name = 'distance_gte'
    op = '>='
@BaseSpatialField.register_lookup
class DistanceLTLookup(DistanceLookupFromFunction):
    """'distance_lt' lookup: compares the distance expression using '<'."""
    lookup_name = 'distance_lt'
    op = '<'
@BaseSpatialField.register_lookup
class DistanceLTELookup(DistanceLookupFromFunction):
    """'distance_lte' lookup: compares the distance expression using '<='."""
    lookup_name = 'distance_lte'
    op = '<='