Skip to content

Commit 4480351

Browse files
Merge branch 'main' into drop-39
2 parents 2ea3eb7 + 1d1ba95 commit 4480351

File tree

16 files changed

+673
-164
lines changed

16 files changed

+673
-164
lines changed

devito/core/gpu.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,29 @@ def _rcompile_wrapper(cls, **kwargs0):
126126

127127
def wrapper(expressions, mode='default', options=None, **kwargs1):
128128
kwargs = {**kwargs0, **kwargs1}
129+
options = options or {}
129130

130131
if mode == 'host':
131-
options = options or {}
132132
target = {
133133
'platform': 'cpu64',
134134
'language': 'C' if options0['par-disabled'] else 'openmp',
135135
'compiler': 'custom'
136136
}
137137
else:
138-
options = {**options0, **(options or {})}
138+
# Always use the default `par-tile` for recursive compilation
139+
# unless the caller explicitly overrides it so that if the user
140+
# supplies a multi par-tile there is no need to worry about the
141+
# small kernels typically generated by recursive compilation
142+
par_tile0 = options0['par-tile']
143+
par_tile = options.get('par-tile')
144+
if par_tile0 and par_tile:
145+
options = {**options0, **options, 'par-tile': par_tile}
146+
elif par_tile0:
147+
par_tile = ParTile(par_tile0.default, default=par_tile0.default)
148+
options = {**options0, **options, 'par-tile': par_tile}
149+
else:
150+
options = {**options0, **options}
151+
139152
target = None
140153

141154
return rcompile(expressions, kwargs, options, target=target)

devito/ir/iet/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,8 @@ def arguments(self):
15161516
return self.halo_scheme.arguments
15171517

15181518
@property
1519-
def is_empty(self):
1520-
return len(self.halo_scheme) == 0
1519+
def is_void(self):
1520+
return self.halo_scheme.is_void
15211521

15221522
@property
15231523
def body(self):

devito/ir/iet/visitors.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
IndexedData, DeviceMap)
2929

3030

31-
__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols',
32-
'MapExprStmts', 'MapHaloSpots', 'MapNodes', 'IsPerfectIteration',
33-
'printAST', 'CGen', 'CInterface', 'Transformer', 'Uxreplace']
31+
__all__ = ['FindApplications', 'FindNodes', 'FindWithin', 'FindSections',
32+
'FindSymbols', 'MapExprStmts', 'MapHaloSpots', 'MapNodes',
33+
'IsPerfectIteration', 'printAST', 'CGen', 'CInterface', 'Transformer',
34+
'Uxreplace']
3435

3536

3637
class Visitor(GenericVisitor):
@@ -1112,6 +1113,49 @@ def visit_Node(self, o, ret=None):
11121113
return ret
11131114

11141115

1116+
class FindWithin(FindNodes):

    """
    Like FindNodes, but the collection of matching nodes is restricted to a
    portion of the visited tree: it starts collecting matching nodes only
    once `start` is found, and stops collecting them after `stop` is found.

    Parameters
    ----------
    match
        Forwarded as is to FindNodes.
    start
        The node at which collection begins; compared by identity (`is`).
        If `start` itself satisfies the matching rule, it is collected.
    stop : optional
        The node after which collection ends; compared by identity (`is`).
        `stop` and its children are still visited (and collected, if
        matching) before collection is switched off. With the default
        None, no visited node is expected to be identical to it, hence
        collection is never switched off once started.
    """

    @classmethod
    def default_retval(cls):
        # A pair `(found, flag)`: the list of collected nodes and whether
        # collection is currently switched on
        return [], False

    def __init__(self, match, start, stop=None):
        super().__init__(match)
        self.start = start
        self.stop = stop

    def visit(self, o, ret=None):
        # Callers only get the collected nodes; the flag is internal state
        found, _ = self._visit(o, ret=ret)
        return found

    def visit_Node(self, o, ret=None):
        if ret is None:
            ret = self.default_retval()
        found, flag = ret

        # Switch collection on before testing `o`, so that `start` itself
        # may be collected
        if o is self.start:
            flag = True

        if flag and self.rule(self.match, o):
            found.append(o)
        for i in o.children:
            found, newflag = self._visit(i, ret=(found, flag))
            if flag and not newflag:
                # Collection was on and got switched off while visiting `i`,
                # that is `stop` was found underneath: skip the remaining
                # children and propagate the switched-off state upwards
                return found, newflag
            flag = newflag

        # Switch collection off only after `o` and its children were visited
        if o is self.stop:
            flag = False

        return found, flag
1157+
1158+
11151159
class FindApplications(Visitor):
11161160

11171161
"""

devito/ir/support/basic.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
q_constant, q_comp_acc, q_affine, q_routine, search,
1212
uxreplace)
1313
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
14-
flatten, memoized_meth, memoized_generator)
14+
flatten, memoized_meth, memoized_generator, smart_gt,
15+
smart_lt)
1516
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1617
CriticalRegion, Function, Symbol, Temp, TempArray,
1718
TBArray)
@@ -364,11 +365,12 @@ def distance(self, other):
364365
# trip count. E.g. it ranges from 0 to 3; `other` performs a
365366
# constant access at 4
366367
for v in (self[n], other[n]):
367-
try:
368-
if bool(v < sit.symbolic_min or v > sit.symbolic_max):
369-
return Vector(S.ImaginaryUnit)
370-
except TypeError:
371-
pass
368+
# Note: Uses smart_ comparisons to avoid evaluating expensive
369+
# symbolic Lt or Gt operations.
370+
# Note: Boolean is split to make the conditional short
371+
# circuit more frequently for mild speedup.
372+
if smart_lt(v, sit.symbolic_min) or smart_gt(v, sit.symbolic_max):
373+
return Vector(S.ImaginaryUnit)
372374

373375
# Case 2: `sit` is an IterationInterval over a local SubDimension
374376
# and `other` performs a constant access
@@ -382,32 +384,36 @@ def distance(self, other):
382384
if disjoint_test(self[n], other[n], sai, sit):
383385
return Vector(S.ImaginaryUnit)
384386

387+
# Compute the distance along the current IterationInterval
385388
if self.function._mem_shared:
386389
# Special case: the distance between two regular, thread-shared
387-
# objects fallbacks to zero, as any other value would be nonsensical
390+
# objects falls back to zero, as any other value would be
391+
# nonsensical
392+
ret.append(S.Zero)
393+
elif degenerating_dimensions(sai, oai):
394+
# Special case: `sai` and `oai` may be different symbolic objects
395+
# but they can be proved to systematically generate the same value
388396
ret.append(S.Zero)
389-
390397
elif sai and oai and sai._defines & sit.dim._defines:
391-
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`, `ai=t`
398+
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`,
399+
# and `ai=t`
392400
if sit.direction is Backward:
393401
ret.append(other[n] - self[n])
394402
else:
395403
ret.append(self[n] - other[n])
396-
397404
elif not sai and not oai:
398405
# E.g., `self=R<a,[3]>` and `other=W<a,[4]>`
399406
if self[n] - other[n] == 0:
400407
ret.append(S.Zero)
401408
else:
402409
break
403-
404410
elif sai in self.ispace and oai in other.ispace:
405411
# E.g., `self=R<f,[x, y]>`, `sai=time`,
406412
# `self.itintervals=(time, x, y)`, `n=0`
407413
continue
408-
409414
else:
410-
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`, `n=1`
415+
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`,
416+
# and `n=1`
411417
return vinf(ret)
412418

413419
n = len(ret)
@@ -1408,3 +1414,19 @@ def disjoint_test(e0, e1, d, it):
14081414
i1 = sympy.Interval(min(p10, p11), max(p10, p11))
14091415

14101416
return not bool(i0.intersect(i1))
1417+
1418+
1419+
def degenerating_dimensions(d0, d1):
    """
    Tell whether `d0` and `d1`, two Dimensions that may well be distinct
    symbolic objects, can be proved to always degenerate to the same value.

    Return True if that is the case, False otherwise (including when either
    argument lacks the attributes inspected here).
    """
    # Case 1: ModuloDimensions of size 1
    try:
        unit_modulo = (d0.is_Modulo and d1.is_Modulo and
                       d0.modulo == d1.modulo == 1)
    except AttributeError:
        # At least one of the queried attributes is missing
        return False

    return bool(unit_modulo)

devito/mpi/halo_scheme.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def __hash__(self):
4343
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims,
4444
self.bundle))
4545

46+
@cached_property
def loc_values(self):
    """
    The values of `self.loc_indices`, as a frozenset.
    """
    return frozenset(self.loc_indices.values())
49+
4650
def union(self, other):
4751
"""
4852
Return a new HaloSchemeEntry that is the union of this and `other`.
@@ -384,6 +388,10 @@ def owned_size(self):
384388
mapper[d] = (max(maxl, s.left), max(maxr, s.right))
385389
return mapper
386390

391+
@cached_property
def functions(self):
    """
    The keys of `self.fmapper`, as a frozenset.
    """
    return frozenset(self.fmapper)
394+
387395
@cached_property
388396
def dimensions(self):
389397
retval = set()
@@ -413,6 +421,38 @@ def loc_values(self):
413421
def arguments(self):
414422
return self.dimensions | set(flatten(self.honored.values()))
415423

424+
def issubset(self, other):
    """
    Check if `self` is a subset of `other`.

    Return False if `other` is not a HaloScheme or if any key of
    `self.fmapper` is missing from `other.fmapper`. Otherwise, for each
    key `f`, the entry of `self` (`hse0`) is compared against that of
    `other` (`hse1`); True is returned only if, for all of them:

        * `hse0.halos` is a subset of `hse1.halos`;
        * `hse0.bundle` and `hse1.bundle` are the very same object;
        * `hse0.loc_dirs` equals `hse1.loc_dirs`;
        * the `loc_indices` produced by `process_loc_indices` out of the
          pairs `(hse0.loc_indices[d], hse1.loc_indices[d])` equal
          `hse1.loc_indices`.
    """
    if not isinstance(other, HaloScheme):
        return False

    # All keys of `self.fmapper` must appear in `other.fmapper` too
    if not all(f in other.fmapper for f in self.fmapper):
        return False

    for f, hse0 in self.fmapper.items():
        hse1 = other.fmapper[f]

        # Clearly, `hse0`'s halos must be a subset of `hse1`'s halos...
        if not hse0.halos.issubset(hse1.halos) or \
           hse0.bundle is not hse1.bundle:
            return False

        # But now, to be a subset, `hse0` must be expecting such halos
        # at a time index that is less than or equal to that of `hse1`
        if hse0.loc_dirs != hse1.loc_dirs:
            return False

        # NOTE(review): presumably `process_loc_indices` picks, for each
        # Dimension, one of the two paired indices based on `loc_dirs`, so
        # that the comparison below succeeds only when the picked index is
        # that of `hse1` -- confirm against its definition
        loc_dirs = hse0.loc_dirs
        raw_loc_indices = {d: (hse0.loc_indices[d], hse1.loc_indices[d])
                           for d in hse0.loc_indices}
        projected_loc_indices, _ = process_loc_indices(raw_loc_indices, loc_dirs)
        if projected_loc_indices != hse1.loc_indices:
            return False

    return True
455+
416456
def project(self, functions):
417457
"""
418458
Create a new HaloScheme that only retains the HaloSchemeEntries corresponding

devito/passes/clusters/cse.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import defaultdict
1+
from collections import defaultdict, Counter
22
from functools import cached_property, singledispatch
33

44
import numpy as np
@@ -13,6 +13,7 @@
1313
from devito.finite_differences.differentiable import IndexDerivative
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
16+
from devito.symbolics.search import search
1617
from devito.symbolics.manipulation import _uxreplace
1718
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1819
from devito.types import Eq, Symbol, Temp
@@ -25,9 +26,15 @@ class CTemp(Temp):
2526
"""
2627
A cluster-level Temp, similar to Temp, ensured to have different priority
2728
"""
29+
2830
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')
2931

3032

33+
def retrieve_ctemps(exprs, mode='all'):
    """
    Shorthand to retrieve the CTemps in `exprs`.

    Both `exprs` and `mode` are forwarded to `search`, which is also
    passed `'dfs'` as its last argument.
    """
    def is_ctemp(expr):
        return isinstance(expr, CTemp)

    return search(exprs, is_ctemp, mode, 'dfs')
36+
37+
3138
@cluster_pass
3239
def cse(cluster, sregistry=None, options=None, **kwargs):
3340
"""
@@ -225,8 +232,15 @@ def _compact(exprs, exclude):
225232

226233
mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)}
227234

228-
mapper.update({e.lhs: e.rhs for e in candidates
229-
if sum([i.rhs.count(e.lhs) for i in exprs]) == 1})
235+
# Find all the CTemps in expression right-hand-sides without removing duplicates
236+
ctemps = retrieve_ctemps(e.rhs for e in exprs)
237+
238+
# If there are ctemps in the expressions, then add any that only appear once to
239+
# the mapper
240+
if ctemps:
241+
ctemp_count = Counter(ctemps)
242+
mapper.update({e.lhs: e.rhs for e in candidates
243+
if ctemp_count[e.lhs] == 1})
230244

231245
processed = []
232246
for e in exprs:
@@ -244,6 +258,10 @@ def _toposort(exprs):
244258
"""
245259
Ensure the expression list is topologically sorted.
246260
"""
261+
if not any(isinstance(e.lhs, CTemp) for e in exprs):
262+
# No CSE temps, no need for a topological sort
263+
return exprs
264+
247265
dag = DAG(exprs)
248266

249267
for e0 in exprs:
@@ -255,9 +273,14 @@ def _toposort(exprs):
255273
dag.add_edge(e0, e1, force_add=True)
256274

257275
def choose_element(queue, scheduled):
258-
# Try to honor temporary names as much as possible
259-
first = sorted(queue, key=lambda i: str(i.lhs)).pop(0)
260-
queue.remove(first)
276+
tmps = [i for i in queue if isinstance(i.lhs, CTemp)]
277+
if tmps:
278+
# Try to honor temporary names as much as possible
279+
first = sorted(tmps, key=lambda i: i.lhs.name).pop(0)
280+
queue.remove(first)
281+
else:
282+
first = sorted(queue, key=lambda i: exprs.index(i)).pop(0)
283+
queue.remove(first)
261284
return first
262285

263286
processed = dag.topological_sort(choose_element)

devito/passes/clusters/misc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,7 @@ def is_cross(source, sink):
352352
v = len(cg0.exprs)
353353
return t0 < v <= t1 or t1 < v <= t0
354354

355-
for cg1 in cgroups[n+1:]:
356-
n1 = cgroups.index(cg1)
355+
for n1, cg1 in enumerate(cgroups[n+1:], start=n+1):
357356

358357
# A Scope to compute all cross-ClusterGroup anti-dependences
359358
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)

0 commit comments

Comments
 (0)