diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index ea4025a3ab..0eb1bed0fa 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -497,6 +497,7 @@ def parse_product_arch(): def get_visible_devices(): device_vars = ( 'CUDA_VISIBLE_DEVICES', + 'NVIDIA_VISIBLE_DEVICES', 'ROCR_VISIBLE_DEVICES', 'HIP_VISIBLE_DEVICES' ) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 24bbdea972..08feb17839 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -6,7 +6,9 @@ import sympy -from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer +from devito.tools import ( + Pickable, as_mapper, as_tuple, frozendict, is_integer, memoized_func +) from devito.types.dimension import Dimension from devito.types.utils import DimensionTuple from devito.warnings import warn @@ -557,6 +559,7 @@ def _evaluate(self, **kwargs): def _eval_deriv(self): return self._eval_fd(self.expr) + @memoized_func(scope='build') def _eval_fd(self, expr, **kwargs): """ Evaluate the finite-difference approximation of the Derivative. diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 6322b8ca8f..fed3baecc3 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -1028,6 +1028,10 @@ def compare(self, other): def base(self): return self.expr.func(*[a for a in self.expr.args if a is not self.weights]) + @cached_property + def pivot(self): + return self.base.subs({d: 0 for d in self.dimensions}) + @property def weights(self): return self._weights diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 30199fb3d8..bdf3199b0d 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -170,14 +170,15 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici # `coefficients` method (`taylor` or `symbolic`) if weights is None: weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0) - if isinstance(weights, Iterable) and len(weights) != len(indices): + _, wdim, _ = process_weights(weights, expr, dim) + elif isinstance(weights, Iterable) and len(weights) != len(indices): warning(f"Number of weights ({len(weights)}) does not match " f"number of indices ({len(indices)}), reverting to Taylor") scale = False + wdim = None weights = fd_weights_registry['taylor'](expr, deriv_order, indices, x0) # Did fd_weights_registry return a new Function/Expression instead of a values? - _, wdim, _ = process_weights(weights, expr, dim) if wdim is not None: weights = [weights._subs(wdim, i) for i in range(len(indices))] diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 438c4da0c9..8c9304b126 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -228,10 +228,14 @@ def make_stencil_dimension(expr, _min, _max): @cacheit -def numeric_weights(function, deriv_order, indices, x0): +def _numeric_weights(deriv_order, indices, x0): return finite_diff_weights(deriv_order, indices, x0)[-1][-1] +def numeric_weights(function, deriv_order, indices, x0): + return _numeric_weights(deriv_order, indices, x0) + + fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights, 'symbolic': numeric_weights} # Backward compat for 'symbolic' coeff_priority = {'taylor': 1, 'standard': 1} diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 8ef479955e..96ae8c56ae 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -18,10 +18,12 @@ from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.symbolics.queries import q_leaf -from devito.tools import ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype +from devito.tools import ( + ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype, memoized_func +) from devito.types.basic import AbstractFunction -__all__ = ['BasePrinter', 'ccode'] +__all__ = ['BasePrinter', 'ccode', 'get_printer'] class BasePrinter(CodePrinter): @@ -449,15 +451,20 @@ def _print_Fallback(self, expr): sympy.printing.str.StrPrinter._print_Add = BasePrinter._print_Add -def ccode(expr, printer=None, **settings): +@memoized_func +def get_printer(printer, dtype): + return printer(settings={'dtype': dtype}) + + +def ccode(expr, printer=None, dtype=None): """Generate C++ code from an expression. Parameters ---------- expr : expr-like The expression to be printed. - settings : dict - Options for code printing. + dtype : data-type, optional + Data type used by the printer. Returns ------- @@ -468,4 +475,5 @@ def ccode(expr, printer=None, **settings): if printer is None: from devito.passes.iet.languages.C import CPrinter printer = CPrinter - return printer(settings=settings).doprint(expr, None) + dtype = printer._default_settings['dtype'] if dtype is None else dtype + return get_printer(printer, dtype).doprint(expr, None) diff --git a/devito/ir/clusters/analysis.py b/devito/ir/clusters/analysis.py index 5ebae71b0f..f78f1ee456 100644 --- a/devito/ir/clusters/analysis.py +++ b/devito/ir/clusters/analysis.py @@ -101,7 +101,7 @@ def _callback(self, clusters, dim, prefix): is_parallel_atomic = False scope = Scope(flatten(c.exprs for c in clusters)) - for dep in scope.d_all_gen(): + for dep in scope.d_all_gen(writes=scope.writes_tensor): test00 = dep.is_indep(dim) and not dep.is_storage_related(dim) test01 = all(dep.is_reduce_atmost(i) for i in prev) if test00 and test01: @@ -112,10 +112,6 @@ def _callback(self, clusters, dim, prefix): is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0) continue - if dep.function in scope.initialized: - # False alarm, the dependence is over a locally-defined symbol - continue - if dep.is_reduction: is_parallel_atomic = True continue diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index cbf206b3ff..717b508b92 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -8,13 +8,15 @@ from devito.ir.support import ( PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext, DataSpace, Forward, Guards, Interval, IntervalGroup, IterationSpace, PrefetchUpdate, Properties, Scope, WaitLock, WithLock, - detect_accesses, detect_io, maximum, minimum, normalize_properties, normalize_syncs, - null_ispace, tailor_properties, update_properties + detect_accesses, maximum, minimum, normalize_properties, normalize_syncs, null_ispace, + tailor_properties, update_properties ) from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import estimate_cost -from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype +from devito.tools import ( + CacheInstances, as_tuple, cached_hash, filter_ordered, flatten, infer_dtype +) from devito.types import ( CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence @@ -23,110 +25,45 @@ __all__ = ["Cluster", "ClusterGroup"] -class Cluster: +class EqBlock(CacheInstances): """ - A Cluster is an ordered sequence of expressions in an IterationSpace. - - Parameters - ---------- - exprs : expr-like or list of expr-like - An ordered sequence of expressions computing a tensor. - ispace : IterationSpace, optional - The Cluster iteration space. - guards : dict, optional - Mapper from Dimensions to expr-like, representing the conditions under - which the Cluster should be computed. - properties : dict, optional - Mapper from Dimensions to Property, describing the Cluster properties - such as its parallel Dimensions. - syncs : dict, optional - Mapper from Dimensions to lists of SyncOps, that is ordered sequences of - synchronization operations that must be performed in order to compute the - Cluster asynchronously. - halo_scheme : HaloScheme, optional - The halo exchanges required by the Cluster. + A sequence of equations with associated metadata. """ + @classmethod + def _preprocess_args(cls, exprs, ispace=null_ispace, guards=None, + properties=None, syncs=None, halo_scheme=None): + exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + guards = Guards(guards or {}) + properties = Properties(properties or {}) + syncs = normalize_syncs(syncs or {}) + + return (exprs, ispace, guards, properties, syncs, halo_scheme), {} + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, syncs=None, halo_scheme=None): - self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + self._exprs = exprs self._ispace = ispace - self._guards = Guards(guards or {}) - self._syncs = normalize_syncs(syncs or {}) - - properties = Properties(properties or {}) - properties = tailor_properties(properties, ispace) - self._properties = update_properties(properties, self.exprs) - + self._guards = guards + self._syncs = syncs self._halo_scheme = halo_scheme - def __repr__(self): - return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) - - @classmethod - def from_clusters(cls, *clusters): - """ - Build a new Cluster from a sequence of pre-existing Clusters with - compatible IterationSpace. - """ - assert len(clusters) > 0 - root = clusters[0] - - if len(clusters) == 1: - return root - - if not all(root.ispace.is_compatible(c.ispace) for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "incompatible IterationSpace") - if not all(root.guards == c.guards for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "non-homogeneous guards") - - writes = set().union(*[c.scope.writes for c in clusters]) - reads = set().union(*[c.scope.reads for c in clusters]) - if any(f._mem_shared for f in writes & reads): - raise ValueError("Cannot build a Cluster from Clusters with " - "read-write conflicts on shared-memory Functions") - - exprs = chain(*[c.exprs for c in clusters]) - ispace = IterationSpace.union(*[c.ispace for c in clusters]) - - guards = root.guards - - properties = reduce_properties(clusters) - - try: - syncs = normalize_syncs(*[c.syncs for c in clusters]) - except ValueError as e: - raise ValueError( - "Cannot build a Cluster from Clusters with " - "non-compatible synchronization operations" - ) from e - - halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) - - return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) + properties = tailor_properties(properties, ispace) + self._properties = update_properties(properties, self._exprs) - def rebuild(self, *args, **kwargs): - """ - Build a new Cluster from the attributes given as keywords. All other - attributes are taken from ``self``. - """ - # Shortcut for backwards compatibility - if args: - if len(args) != 1: - raise ValueError("rebuild takes at most one positional argument (exprs)") - if kwargs.get('exprs'): - raise ValueError("`exprs` provided both as arg and kwarg") - kwargs['exprs'] = args[0] + def __eq__(self, other): + return (type(self) is type(other) and + self.exprs == other.exprs and + self.ispace == other.ispace and + self.guards == other.guards and + self.properties == other.properties and + self.syncs == other.syncs and + self.halo_scheme == other.halo_scheme) - return self.__class__(exprs=kwargs.get('exprs', self.exprs), - ispace=kwargs.get('ispace', self.ispace), - guards=kwargs.get('guards', self.guards), - properties=kwargs.get('properties', self.properties), - syncs=kwargs.get('syncs', self.syncs), - halo_scheme=kwargs.get('halo_scheme', self.halo_scheme)) + def __hash__(self): + return hash((self.exprs, self.ispace, self.guards, self.properties, + self.syncs, self.halo_scheme)) @property def exprs(self): @@ -382,8 +319,8 @@ def dtype(self): performing integer arithmetic are ignored, assuming that they are only carrying out array index calculations. - If two expressions perform calculations with different precision, the - data type with highest precision is returned. + If two expressions perform calculations with different precision, + the data type with highest precision is returned. """ dtypes = set() for i in self.exprs: @@ -399,8 +336,8 @@ def dtype(self): @cached_property def dspace(self): """ - Derive the DataSpace of the Cluster from its expressions, IterationSpace, - and Guards. + Derive the DataSpace of the Cluster from its expressions, + IterationSpace, and Guards. """ accesses = detect_accesses(self.exprs) @@ -491,7 +428,8 @@ def traffic(self): ----- If a Function is both read and written, then it is counted twice. """ - reads, writes = detect_io(self.exprs, relax=True) + reads = flatten(i.read_functions_relaxed for i in self.exprs) + writes = flatten(i.write_functions_relaxed for i in self.exprs) accesses = [(i, 'r') for i in reads] + [(i, 'w') for i in writes] # Ordering isn't important at this point, so returning an unordered @@ -525,6 +463,156 @@ def traffic(self): return ret +class Cluster: + + """ + A context-sensitive sequence of equations. + + The structural payload (equations, IterationSpace, ...) lives in the + underlying EqBlock. A Cluster, unlike EqBlock, deliberately keeps identity + semantics because its position in a sequence of Clusters does matter. It + follows that two Cluster instances may share the same EqBlock, but they + remain distinct: Clusters intentionally use object identity for equality + and hashing, so only references to the same Cluster object compare equal. + + Parameters + ---------- + exprs : expr-like or list of expr-like + An ordered sequence of expressions computing a tensor. + ispace : IterationSpace, optional + The Cluster iteration space. + guards : dict, optional + Mapper from Dimensions to expr-like, representing the conditions under + which the Cluster should be computed. + properties : dict, optional + Mapper from Dimensions to Property, describing the Cluster properties + such as its parallel Dimensions. + syncs : dict, optional + Mapper from Dimensions to lists of SyncOps, that is ordered sequences of + synchronization operations that must be performed in order to compute the + Cluster asynchronously. + halo_scheme : HaloScheme, optional + The halo exchanges required by the Cluster. + """ + + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, + syncs=None, halo_scheme=None): + self._block = EqBlock(exprs, ispace, guards, properties, syncs, halo_scheme) + + def __repr__(self): + return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) + + def __getattr__(self, name): + try: + block = object.__getattribute__(self, '_block') + except AttributeError: + raise AttributeError(name) from None + return getattr(block, name) + + @property + def exprs(self): + return self._block.exprs + + @property + def ispace(self): + return self._block.ispace + + @property + def guards(self): + return self._block.guards + + @property + def properties(self): + return self._block.properties + + @property + def syncs(self): + return self._block.syncs + + @property + def halo_scheme(self): + return self._block.halo_scheme + + @classmethod + def from_clusters(cls, *clusters): + """ + Build a new Cluster from a sequence of pre-existing Clusters with + compatible IterationSpace. + """ + assert len(clusters) > 0 + root = clusters[0] + + if len(clusters) == 1: + return root + + if not all(root.ispace.is_compatible(c.ispace) for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "incompatible IterationSpace") + if not all(root.guards == c.guards for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "non-homogeneous guards") + + writes = set().union(*[c.scope.writes for c in clusters]) + reads = set().union(*[c.scope.reads for c in clusters]) + if any(f._mem_shared for f in writes & reads): + raise ValueError("Cannot build a Cluster from Clusters with " + "read-write conflicts on shared-memory Functions") + + exprs = chain(*[c.exprs for c in clusters]) + ispace = IterationSpace.union(*[c.ispace for c in clusters]) + + guards = root.guards + + properties = reduce_properties(clusters) + + try: + syncs = normalize_syncs(*[c.syncs for c in clusters]) + except ValueError as e: + raise ValueError( + "Cannot build a Cluster from Clusters with " + "non-compatible synchronization operations" + ) from e + + halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) + + return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) + + def rebuild(self, *args, **kwargs): + """ + Build a new Cluster from the attributes given as keywords. All other + attributes are taken from ``self``. + """ + # Shortcut for backwards compatibility + if args: + if len(args) != 1: + raise ValueError("rebuild takes at most one positional argument (exprs)") + if kwargs.get('exprs'): + raise ValueError("`exprs` provided both as arg and kwarg") + kwargs['exprs'] = args[0] + + exprs = kwargs.get('exprs', self.exprs) + ispace = kwargs.get('ispace', self.ispace) + guards = kwargs.get('guards', self.guards) + properties = kwargs.get('properties', self.properties) + syncs = kwargs.get('syncs', self.syncs) + halo_scheme = kwargs.get('halo_scheme', self.halo_scheme) + + if exprs is self.exprs and \ + ispace is self.ispace and \ + guards is self.guards and \ + properties is self.properties and \ + syncs is self.syncs and \ + halo_scheme is self.halo_scheme: + return self + + return self.__class__(exprs=exprs, + ispace=ispace, + guards=guards, + properties=properties, + syncs=syncs, + halo_scheme=halo_scheme) + + class ClusterGroup(tuple): """ @@ -552,6 +640,18 @@ def __new__(cls, clusters, ispace=None): return obj + def __eq__(self, other): + return (isinstance(other, ClusterGroup) and + super().__eq__(other) and + self._ispace == other._ispace) + + def __ne__(self, other): + return not self == other + + @cached_hash + def __hash__(self): + return hash((tuple(self), self._ispace)) + @classmethod def concatenate(cls, *cgroups): return list(chain(*cgroups)) diff --git a/devito/ir/clusters/visitors.py b/devito/ir/clusters/visitors.py index 11bcad5365..da0a62344a 100644 --- a/devito/ir/clusters/visitors.py +++ b/devito/ir/clusters/visitors.py @@ -2,7 +2,7 @@ from itertools import groupby from devito.ir.support import IterationSpace, null_ispace -from devito.tools import flatten, timed_pass +from devito.tools import cached_hash, flatten, timed_pass __all__ = ['Queue', 'cluster_pass'] @@ -113,6 +113,10 @@ def _process_fatd(self, clusters, level, prefix=None, **kwargs): class Prefix(IterationSpace): + @classmethod + def _preprocess_args(cls, ispace, guards, properties, syncs): + return (ispace, guards, properties, syncs), {} + def __init__(self, ispace, guards, properties, syncs): super().__init__(ispace.intervals, ispace.sub_iterators, ispace.directions) @@ -127,6 +131,7 @@ def __eq__(self, other): self.properties == other.properties and self.syncs == other.syncs) + @cached_hash def __hash__(self): return hash((self.intervals, self.sub_iterators, self.directions, self.guards, self.properties, self.syncs)) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 8d72704b79..73d0065a58 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -1,3 +1,4 @@ +from contextlib import suppress from functools import cached_property import numpy as np @@ -6,11 +7,12 @@ from devito.finite_differences.differentiable import diff2sympy from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.ir.support import ( - GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses, - detect_io + GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses +) +from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace +from devito.tools import ( + Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged ) -from devito.symbolics import IntDiv, limits_mapper, uxreplace -from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ @@ -31,8 +33,8 @@ class IREq(sympy.Eq, Pickable): __rkwargs__ = ('ispace', 'conditionals', 'implicit_dims', 'operation') def _hashable_content(self): - return (*super()._hashable_content(), - *tuple(getattr(self, i) for i in self.__rkwargs__)) + return (super()._hashable_content() + + tuple(as_hashable(getattr(self, i)) for i in self.__rkwargs__)) @property def is_Scalar(self): @@ -80,6 +82,74 @@ def is_Reduction(self): def is_Increment(self): return self.operation is OpInc + @cached_property + def writes(self): + from devito.symbolics.queries import q_routine + + terminals = set(retrieve_accesses(self.lhs)) + if q_routine(self.rhs): + with suppress(AttributeError): + # Everything except: foreign routines, such as `cos` or `sin` etc. + terminals.update(self.rhs.writes) + + return tuple(terminals) + + @cached_property + def reads_explicit(self): + terminals = set(retrieve_accesses(self.rhs, deep=True)) + with suppress(AttributeError): + terminals.update(retrieve_accesses(self.lhs.indices)) + + return tuple(terminals) + + @cached_property + def reads_conditional(self): + accesses = [] + for v in self.conditionals.values(): + accesses.extend(retrieve_accesses(v)) + + return tuple(accesses) + + @cached_property + def reads(self): + return tuple(set(self.reads_explicit) | set(self.reads_conditional)) + + @cached_property + def _read_functions(self): + found = [] + for i in self.reads: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def _write_functions(self): + found = [] + for i in self.writes: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def read_functions(self): + return tuple(i for i in self._read_functions if i.is_Input) + + @cached_property + def write_functions(self): + return tuple(i for i in self._write_functions if i.is_Input) + + @cached_property + def read_functions_relaxed(self): + return tuple(i for i in self._read_functions + if i.is_Input or i.is_AbstractFunction) + + @cached_property + def write_functions_relaxed(self): + return tuple(i for i in self._write_functions + if i.is_Input or i.is_AbstractFunction) + def apply(self, func): """ Apply a callable to `self` and each expr-like attribute carried by `self`, @@ -175,7 +245,7 @@ class LoweredEq(IREq): `LoweredEq.__rkwargs__` must appear in `kwargs`. """ - __rkwargs__ = IREq.__rkwargs__ + ('reads', 'writes') + __rkwargs__ = IREq.__rkwargs__ def __new__(cls, *args, **kwargs): if len(args) == 1 and isinstance(args[0], LoweredEq): @@ -250,20 +320,11 @@ def __new__(cls, *args, **kwargs): expr._ispace = ispace expr._conditionals = conditionals - expr._reads, expr._writes = detect_io(expr) expr._implicit_dims = input_expr.implicit_dims expr._operation = Operation.detect(input_expr) return expr - @property - def reads(self): - return self._reads - - @property - def writes(self): - return self._writes - def xreplace(self, rules): return LoweredEq(self.lhs.xreplace(rules), self.rhs.xreplace(rules), **self.state) @@ -292,6 +353,7 @@ class ClusterizedEq(IREq): These two properties make a ClusterizedEq suitable for use in a Cluster. """ + @reuse_if_unchanged('__rkwargs__') def __new__(cls, *args, **kwargs): if len(args) == 1: # origin: ClusterizedEq(expr, **kwargs) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index d106b5e811..bc4b5dc48c 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -6,7 +6,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Iterable from contextlib import suppress -from functools import cached_property +from functools import cache, cached_property import cgen as c from sympy import IndexedBase, sympify @@ -16,7 +16,7 @@ from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin, OpMinMax from devito.ir.support import ( AFFINE, INBOUND, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, SEQUENTIAL, - VECTORIZED, Forward, PrefetchUpdate, Property, WithLock, detect_io + VECTORIZED, Forward, PrefetchUpdate, Property, WithLock ) from devito.symbolics import CallFromPointer, ListInitializer from devito.tools import ( @@ -102,27 +102,36 @@ class Node(Signer): def __new__(cls, *args, **kwargs): obj = super().__new__(cls) - argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) - try: - defaults = dict( - zip(argnames[-len(defaultvalues):], defaultvalues, strict=True) - ) - except TypeError: - # No default kwarg values - defaults = {} - obj._args = {k: v for k, v in zip(argnames[1:], args, strict=False)} + argnames, defaults = _constructor_args(cls) + obj._args = {k: v for k, v in zip(argnames, args, strict=False)} obj._args.update(kwargs.items()) - obj._args.update({k: defaults.get(k) for k in argnames[1:] if k not in obj._args}) + obj._args.update({k: defaults.get(k) for k in argnames if k not in obj._args}) return obj def _rebuild(self, *args, **kwargs): """Reconstruct ``self``.""" handle = self._args.copy() # Original constructor arguments argnames = [i for i in self._traversable if i not in kwargs] - handle.update(OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)])) - handle.update(kwargs) + updates = OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)]) + updates.update(kwargs) + + if updates and all(self._same_arg(k, v) for k, v in updates.items()): + return self + + handle.update(updates) return type(self)(**handle) + def _same_arg(self, key, value): + with suppress(AttributeError): + if _same_as_before(getattr(self, key), value): + return True + + with suppress(KeyError): + if _same_as_before(self._args[key], value): + return True + + return False + @cached_property def ccode(self): """ @@ -452,7 +461,7 @@ def rhs(self): @cached_property def reads(self): """The Functions read by the Expression.""" - return detect_io(self.expr, relax=True)[0] + return self.expr.read_functions_relaxed @cached_property def write(self): @@ -1558,9 +1567,6 @@ def DummyExpr(*args, init=False): return Expression(DummyEq(*args), init=init) -BlankLine = CBlankLine() - - # Nodes required for distributed-memory halo exchange @@ -1635,3 +1641,54 @@ def functions(self): Iteration/Expression tree. ``local`` is a boolean indicating whether the definition of the callable is known or not. """ + + +# *** Utils + + +@cache +def _constructor_args(cls): + """ + Return cached constructor argument names and default values for an IET type. + + IET node construction records the original constructor arguments in + ``_args``. This helper avoids repeating ``inspect.getfullargspec`` for every + node instance of the same class. + """ + argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) + if defaultvalues is None: + defaults = {} + else: + defaults = dict(zip(argnames[-len(defaultvalues):], defaultvalues, strict=True)) + + return tuple(argnames[1:]), defaults + + +def _same_as_before(old, new): + """ + Return True if ``new`` preserves the object identity structure of ``old``. + + This intentionally does not use equality for arbitrary objects. It only + recurses through common containers and otherwise requires object identity, + which keeps no-op rebuild detection compatible with IET mapper semantics. + """ + if old is new: + return True + + if isinstance(old, (tuple, list)) and isinstance(new, (tuple, list)): + return len(old) == len(new) and all( + _same_as_before(i, j) for i, j in zip(old, new, strict=True) + ) + + if type(old) is not type(new): + return False + + if isinstance(old, dict) and isinstance(new, dict): + return old.keys() == new.keys() and all( + _same_as_before(v, new[k]) for k, v in old.items() + ) + + return False + + +BlankLine = CBlankLine() diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 1eae6433e3..4208c4b237 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -15,9 +15,10 @@ from sympy.core.function import Application from devito.exceptions import CompilationError +from devito.ir.cgen.printer import get_printer from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - Section + Section, _same_as_before ) from devito.ir.support.space import Backward from devito.symbolics import ( @@ -26,7 +27,7 @@ from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import ( GenericVisitor, as_tuple, c_restrict_void_p, filter_ordered, filter_sorted, flatten, - is_external_ctype, sorted_priority + is_external_ctype, memoized_weak_meth, sorted_priority ) from devito.types import ( ArrayObject, CompositeObject, DeviceMap, Dimension, IndexedData, Pointer @@ -255,8 +256,9 @@ def __init__(self, *args, printer=None, **kwargs): printer = CPrinter self.printer = printer - def ccode(self, expr, **kwargs): - return self.printer(settings=kwargs).doprint(expr, None) + def ccode(self, expr, dtype=None): + dtype = self.printer._default_settings['dtype'] if dtype is None else dtype + return get_printer(self.printer, dtype).doprint(expr, None) @property def _qualifiers_mapper(self): @@ -1113,12 +1115,17 @@ def _defines_aliases(n): def __init__(self, mode: str = 'symbolics') -> None: super().__init__() + self.mode = mode modes = mode.split('|') if len(modes) == 1: self.rule = self.rules[mode] else: self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes]) + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return sorted(filter_ordered(ret, key=id), key=str) @@ -1165,8 +1172,13 @@ class FindNodes(LazyVisitor[Node, list[Node], None]): def __init__(self, match: type, mode: str = 'type') -> None: super().__init__() self.match = match + self.mode = mode self.rule = self.rules[mode] + @memoized_weak_meth(key=lambda i: (i.match, i.mode), freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]: if self.rule(self.match, o): yield o @@ -1187,6 +1199,10 @@ def __init__(self, match: type, start: Node, stop: Node | None = None) -> None: self.start = start self.stop = stop + def visit(self, o, *args, **kwargs): + # `start` and `stop` are part of this visitor's state. + return GenericVisitor.visit(self, o, *args, **kwargs) + def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]: yield from () return flag # noqa: B901 @@ -1234,8 +1250,13 @@ class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]) def __init__(self, cls: type[ApplicationType] = Application): super().__init__() + self.cls = cls self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic) + @memoized_weak_meth(key=lambda i: i.cls, freeze=frozenset, thaw=set) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return set(ret) @@ -1319,6 +1340,13 @@ def __init__(self, mapper, nested=False): self.mapper = mapper self.nested = nested + def visit(self, o, *args, **kwargs): + # Subclasses may implement mapper-independent transformations. + if type(self) is Transformer and not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def transform(self, o, handle, **kwargs): if handle is None: # None -> drop `o` @@ -1332,12 +1360,12 @@ def transform(self, o, handle, **kwargs): else: children = o.children children = (tuple(handle) + children[0],) + tuple(children[1:]) - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) else: # Replace `o` with `handle` if self.nested: children = [self._visit(i, **kwargs) for i in handle.children] - return handle._rebuild(*children, **handle.args_frozen) + return reuse_if_unchanged(handle, *children, **handle.args_frozen) else: return handle @@ -1346,7 +1374,12 @@ def visit_object(self, o, **kwargs): def visit_tuple(self, o, **kwargs): visited = tuple(self._visit(i, **kwargs) for i in o) - return tuple(i for i in visited if i is not None) + processed = tuple(i for i in visited if i is not None) + + if _same_as_before(o, processed): + return o + + return processed visit_list = visit_tuple @@ -1357,7 +1390,7 @@ def visit_Node(self, o, **kwargs): children = [self._visit(i, **kwargs) for i in o.children] if o._traversable and not any(children) and any(o.children): return None - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) def visit_Operator(self, o, **kwargs): raise ValueError("Cannot apply a Transformer visitor to an Operator directly") @@ -1374,8 +1407,14 @@ class Uxreplace(Transformer): The substitution rules. """ + def visit(self, o, *args, **kwargs): + if not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def visit_Expression(self, o): - return o._rebuild(expr=uxreplace(o.expr, self.mapper)) + return reuse_if_unchanged(o, expr=uxreplace(o.expr, self.mapper)) def _visit_Iteration_common(self, o): nodes = self._visit(o.nodes) @@ -1392,8 +1431,8 @@ def visit_Iteration(self, o): nodes, dimension, limits, pragmas, uindices = \ self._visit_Iteration_common(o) - return o._rebuild(nodes=nodes, dimension=dimension, limits=limits, - pragmas=pragmas, uindices=uindices) + return reuse_if_unchanged(o, nodes=nodes, dimension=dimension, limits=limits, + pragmas=pragmas, uindices=uindices) def visit_PragmaIteration(self, o): nodes, dimension, limits, pragmas, uindices = \ @@ -1420,7 +1459,7 @@ def visit_Return(self, o): def visit_Callable(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def visit_Call(self, o): arguments = [] @@ -1431,47 +1470,47 @@ def visit_Call(self, o): arguments.append(uxreplace(i, self.mapper)) if o.retobj is not None: retobj = uxreplace(o.retobj, self.mapper) - return o._rebuild(arguments=arguments, retobj=retobj) + return reuse_if_unchanged(o, arguments=arguments, retobj=retobj) else: - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_Lambda(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def visit_Conditional(self, o): condition = uxreplace(o.condition, self.mapper) then_body = self._visit(o.then_body) else_body = self._visit(o.else_body) - return o._rebuild(condition=condition, then_body=then_body, - else_body=else_body) + return reuse_if_unchanged(o, condition=condition, then_body=then_body, + else_body=else_body) def visit_Switch(self, o): condition = uxreplace(o.condition, self.mapper) nodes = self._visit(o.nodes) default = self._visit(o.default) - return o._rebuild(condition=condition, nodes=nodes, default=default) + return reuse_if_unchanged(o, condition=condition, nodes=nodes, default=default) def visit_PointerCast(self, o): function = self.mapper.get(o.function, o.function) obj = self.mapper.get(o.obj, o.obj) - return o._rebuild(function=function, obj=obj) + return reuse_if_unchanged(o, function=function, obj=obj) def visit_Dereference(self, o): pointee = self.mapper.get(o.pointee, o.pointee) pointer = self.mapper.get(o.pointer, o.pointer) - return o._rebuild(pointee=pointee, pointer=pointer) + return reuse_if_unchanged(o, pointee=pointee, pointer=pointer) def visit_Pragma(self, o): arguments = [uxreplace(i, self.mapper) for i in o.arguments] - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_PragmaTransfer(self, o): function = uxreplace(o.function, self.mapper) arguments = [uxreplace(i, self.mapper) for i in o.arguments] if o.imask is None: - return o._rebuild(function=function, arguments=arguments) + return reuse_if_unchanged(o, function=function, arguments=arguments) # An `imask` may be None, a list of symbols/numbers, or a list of # 2-tuples representing ranges @@ -1483,25 +1522,26 @@ def visit_PragmaTransfer(self, o): uxreplace(j, self.mapper))) except TypeError: imask.append(uxreplace(v, self.mapper)) - return o._rebuild(function=function, imask=imask, arguments=arguments) + return reuse_if_unchanged(o, function=function, imask=imask, + arguments=arguments) def visit_ParallelTree(self, o): prefix = self._visit(o.prefix) body = self._visit(o.body) nthreads = self.mapper.get(o.nthreads, o.nthreads) - return o._rebuild(prefix=prefix, body=body, nthreads=nthreads) + return reuse_if_unchanged(o, prefix=prefix, body=body, nthreads=nthreads) def visit_HaloSpot(self, o): hs = o.halo_scheme fmapper = {self.mapper.get(k, k): v for k, v in hs.fmapper.items()} halo_scheme = hs._rebuild(fmapper=fmapper) body = self._visit(o.body) - return o._rebuild(halo_scheme=halo_scheme, body=body) + return reuse_if_unchanged(o, halo_scheme=halo_scheme, body=body) def visit_While(self, o, **kwargs): condition = uxreplace(o.condition, self.mapper) body = self._visit(o.body) - return o._rebuild(condition=condition, body=body) + return reuse_if_unchanged(o, condition=condition, body=body) visit_ThreadedProdder = visit_Call @@ -1510,8 +1550,8 @@ def visit_KernelLaunch(self, o): grid = self.mapper.get(o.grid, o.grid) block = self.mapper.get(o.block, o.block) stream = self.mapper.get(o.stream, o.stream) - return o._rebuild(grid=grid, block=block, stream=stream, - arguments=arguments) + return reuse_if_unchanged(o, grid=grid, block=block, stream=stream, + arguments=arguments) # Utils @@ -1519,6 +1559,16 @@ def visit_KernelLaunch(self, o): blankline = c.Line("") +def reuse_if_unchanged(o, *children, **kwargs): + if children and not _same_as_before(o.children, children): + return o._rebuild(*children, **kwargs) + + if kwargs: + return o._rebuild(*children, **kwargs) + + return o + + def printAST(node, verbose=True): return PrintAST(verbose=verbose)._visit(node) diff --git a/devito/ir/stree/algorithms.py b/devito/ir/stree/algorithms.py index d4a761dfc8..68fc697d3e 100644 --- a/devito/ir/stree/algorithms.py +++ b/devito/ir/stree/algorithms.py @@ -111,7 +111,7 @@ def stree_build(clusters, profiler=None, **kwargs): else: parent = tip - NodeExprs(exprs, c.ispace, c.dspace, c.ops, c.traffic, parent) + NodeExprs(exprs, c.ispace, c.ops, c.traffic, parent) # Nest within a NodeSection if possible if profiler is None or \ diff --git a/devito/ir/stree/tree.py b/devito/ir/stree/tree.py index e033c9fd15..96e498396d 100644 --- a/devito/ir/stree/tree.py +++ b/devito/ir/stree/tree.py @@ -115,11 +115,10 @@ class NodeExprs(ScheduleTree): is_Exprs = True - def __init__(self, exprs, ispace, dspace, ops, traffic, parent=None): + def __init__(self, exprs, ispace, ops, traffic, parent=None): super().__init__(parent) self.exprs = exprs self.ispace = ispace - self.dspace = dspace self.ops = ops self.traffic = traffic diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 7939ee8fe8..9283f2c2df 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Iterable from contextlib import suppress -from functools import cached_property +from functools import cached_property, wraps from itertools import chain, product import sympy @@ -10,12 +10,11 @@ from devito.ir.support.utils import AccessMode, extrema from devito.ir.support.vector import LabeledVector, Vector from devito.symbolics import ( - compare_ops, q_affine, q_comp_acc, q_constant, q_routine, retrieve_indexed, - retrieve_terminals, search, uxreplace + compare_ops, q_affine, q_comp_acc, q_constant, retrieve_indexed ) from devito.tools import ( - CacheInstances, Tag, as_mapper, as_tuple, filter_sorted, flatten, is_integer, - memoized_generator, memoized_meth, smart_gt, smart_lt + CacheInstances, Tag, as_mapper, as_tuple, cached_hash, filter_sorted, flatten, + is_integer, memoized_func, memoized_generator, memoized_meth, smart_gt, smart_lt ) from devito.types import ( ComponentAccess, CriticalRegion, Dimension, DimensionTuple, Fence, Function, Symbol, @@ -200,7 +199,7 @@ def is_scalar(self): return self.rank == 0 -class TimedAccess(IterationInstance, AccessMode): +class TimedAccess(IterationInstance, AccessMode, CacheInstances): """ A TimedAccess ties together an IterationInstance and an AccessMode. @@ -218,6 +217,10 @@ class TimedAccess(IterationInstance, AccessMode): on the values of the index functions and the access mode (read, write). """ + @classmethod + def _preprocess_args(cls, access, mode, timestamp, ispace=null_ispace): + return (access, mode, timestamp, ispace), {} + def __new__(cls, access, mode, timestamp, ispace=None): obj = super().__new__(cls, access) AccessMode.__init__(obj, mode=mode) @@ -247,6 +250,7 @@ def __eq__(self, other): self.access == other.access and self.ispace == other.ispace) + @cached_hash def __hash__(self): return hash((self.access, self.mode, self.timestamp, self.ispace)) @@ -320,6 +324,21 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp + def rebuild(self, **kwargs): + access = kwargs.get('access', self.access) + mode = kwargs.get('mode', self.mode) + timestamp = kwargs.get('timestamp', self.timestamp) + ispace = kwargs.get('ispace', self.ispace) + + if access is self.access and \ + mode is self.mode and \ + timestamp is self.timestamp and \ + ispace is self.ispace: + return self + + return TimedAccess(access, mode, timestamp, ispace) + + @memoized_meth def distance(self, other, logical=False): """ Compute the distance from ``self`` to ``other``. @@ -853,18 +872,91 @@ class Scope(CacheInstances): # Describes a rule for dependencies Rule = Callable[[TimedAccess, TimedAccess], bool] + def normalize_input(func): + + @wraps(func) + def wrapper(self, *args, writes=None, **kwargs): + mapped = {} + for k in as_tuple(writes or self.writes): + v = self.getwrites(k) + if v: + mapped[k] = v + return func(self, *args, writes=mapped, **kwargs) + + return wrapper + + @classmethod + @memoized_func(scope='build') + def from_scopes(cls, scope0, scope1): + """ + Build a synthetic Scope out of two existing Scopes by reusing their + cached reads and writes rather than rediscovering accesses from the + underlying expressions. + + This is used to analyze cross-scope dependences cheaply, for example in + loop-fusion hazard checks. Return None if the two Scopes cannot induce + any cross-scope dependences. + """ + offset = len(scope0.exprs) + + targets = ( + set(scope0.writes) & scope1.functions + ) | ( + set(scope1.writes) & scope0.functions + ) + if not targets: + return None + + def is_cross(source, sink): + t0 = source.timestamp + t1 = sink.timestamp + return t0 < offset <= t1 or t1 < offset <= t0 + + reads = {} + writes = {} + + for f in targets: + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getreads(f) + ) + accesses = scope0.getreads(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + reads[f] = accesses + + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getwrites(f) + ) + accesses = scope0.getwrites(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + writes[f] = accesses + + return cls((), rules=is_cross, reads=reads.items(), writes=writes.items()) + @classmethod def _preprocess_args(cls, exprs: Expr | Iterable[Expr], **kwargs) -> tuple[tuple, dict]: + for i in ('reads', 'writes'): + with suppress(KeyError): + kwargs[i] = tuple(kwargs[i]) + return (as_tuple(exprs),), kwargs def __init__(self, exprs: tuple[Expr], - rules: Rule | tuple[Rule] | None = None) -> None: + rules: Rule | tuple[Rule] | None = None, + reads=None, writes=None) -> None: """ A Scope enables data dependence analysis on a totally ordered sequence of expressions. """ self.exprs = exprs + self._reads = dict(reads) if reads is not None else None + self._writes = dict(writes) if writes is not None else None # A set of rules to drive the collection of dependencies self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment] @@ -876,13 +968,7 @@ def writes_gen(self): Generate all write accesses. """ for i, e in enumerate(self.exprs): - terminals = retrieve_accesses(e.lhs) - if q_routine(e.rhs): - with suppress(AttributeError): - # Everything except: foreign routines, such as `cos` or `sin` etc. - terminals.update(e.rhs.writes) - - for j in terminals: + for j in e.writes: mode = 'WR' if e.is_Reduction else 'W' yield TimedAccess(j, mode, i, e.ispace) @@ -909,8 +995,17 @@ def writes(self): """ Create a mapper from functions to write accesses. """ + if self._writes is not None: + return self._writes + return as_mapper(self.writes_gen(), key=lambda i: i.function) + @cached_property + def writes_tensor(self): + initialized = frozenset(e.lhs.function for e in self.exprs + if not e.is_Reduction and e.is_scalar) + return frozenset(self.writes) - initialized + @memoized_generator def reads_explicit_gen(self): """ @@ -919,11 +1014,7 @@ def reads_explicit_gen(self): expressions. """ for i, e in enumerate(self.exprs): - # Reads - terminals = retrieve_accesses(e.rhs, deep=True) - with suppress(AttributeError): - terminals.update(retrieve_accesses(e.lhs.indices)) - for j in terminals: + for j in e.reads_explicit: mode = 'RR' if j.function is e.lhs.function and e.is_Reduction else 'R' yield TimedAccess(j, mode, i, e.ispace) @@ -932,9 +1023,8 @@ def reads_explicit_gen(self): yield TimedAccess(e.lhs, 'RR', i, e.ispace) # Look up ConditionalDimensions - for v in e.conditionals.values(): - for j in retrieve_accesses(v): - yield TimedAccess(j, 'R', -1, e.ispace) + for j in e.reads_conditional: + yield TimedAccess(j, 'R', -1, e.ispace) @memoized_generator def reads_implicit_gen(self): @@ -1008,21 +1098,22 @@ def reads_smart_gen(self, f): the iteration symbols. """ if isinstance(f, (Function, Temp, TempArray, TBArray)): - for i in chain(self.reads_explicit_gen(), self.reads_synchro_gen()): - if f is i.function: - for j in extrema(i.access): - yield TimedAccess(j, i.mode, i.timestamp, i.ispace) + for i in self.getreads(f): + for j in extrema(i.access): + yield TimedAccess(j, i.mode, i.timestamp, i.ispace) else: - for i in self.reads_gen(): - if f is i.function: - yield i + for i in self.getreads(f): + yield i @cached_property def reads(self): """ Create a mapper from functions to read accesses. """ + if self._reads is not None: + return self._reads + return as_mapper(self.reads_gen(), key=lambda i: i.function) @cached_property @@ -1033,9 +1124,9 @@ def read_only(self): return set(self.reads) - set(self.writes) @cached_property - def initialized(self): - return frozenset(e.lhs.function for e in self.exprs - if not e.is_Reduction and e.is_scalar) + def has_barrier(self): + """True if the Scope contains a fence-like control-flow object.""" + return any(isinstance(e.rhs, (Fence, CriticalRegion)) for e in self.exprs) def getreads(self, function): return as_tuple(self.reads.get(function)) @@ -1095,11 +1186,17 @@ def a_query(self, timestamps=None, modes=None): if a.timestamp in timestamps and a.mode in modes) @memoized_generator - def d_flow_gen(self): - """Generate the flow (or "read-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_flow_gen(self, writes=None): + """ + Generate the flow (or "read-after-write") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(w, r) for rule in self.rules): continue @@ -1126,11 +1223,17 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self, depcls=Dependence): - """Generate the anti (or "write-after-read") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_anti_gen(self, depcls=Dependence, writes=None): + """ + Generate the anti (or "write-after-read") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(r, w) for rule in self.rules): continue @@ -1165,11 +1268,16 @@ def d_anti_logical(self): return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence)) @memoized_generator - def d_output_gen(self): - """Generate the output (or "write-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_output_gen(self, writes=None): + """ + Generate the output (or "write-after-write") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for v in writes.values(): for w1 in v: - for w2 in self.writes.get(k, []): + for w2 in v: if any(not rule(w2, w1) for rule in self.rules): continue @@ -1193,9 +1301,15 @@ def d_output(self): """Output (or "write-after-write") dependences.""" return DependenceGroup(self.d_output_gen()) - def d_all_gen(self): - """Generate all flow, anti and output dependences.""" - return chain(self.d_flow_gen(), self.d_anti_gen(), self.d_output_gen()) + def d_all_gen(self, writes=None): + """ + Generate all flow, anti and output dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + return chain(self.d_flow_gen(writes=writes), + self.d_anti_gen(writes=writes), + self.d_output_gen(writes=writes)) @cached_property def d_all(self): @@ -1381,23 +1495,6 @@ def vinf(entries): return Vector(*(entries + [S.Infinity])) -def retrieve_accesses(exprs, **kwargs): - """ - Like retrieve_terminals, but ensure that if a ComponentAccess is found, - the ComponentAccess itself is returned, while the wrapped Indexed is discarded. - """ - kwargs['mode'] = 'unique' - - compaccs = search(exprs, ComponentAccess) - if not compaccs: - return retrieve_terminals(exprs, **kwargs) - - subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} - exprs1 = uxreplace(exprs, subs) - - return compaccs | retrieve_terminals(exprs1, **kwargs) - set(subs.values()) - - def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index b8a335b1f4..697ddb0f8e 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -272,31 +272,34 @@ class Guards(frozendict): def get(self, d, v=true): return super().get(d, v) + def _reuse_if_untouched(self, mapper): + return self if mapper == self else Guards(mapper) + def andg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = simplify_and(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._reuse_if_untouched(m) def xandg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = And(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._reuse_if_untouched(m) def pairwise_or(self, d, *guards): m = dict(self) @@ -311,17 +314,17 @@ def pairwise_or(self, d, *guards): else: m[d] = g - return Guards(m) + return self._reuse_if_untouched(m) def impose(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self m[d] = guard - return Guards(m) + return self._reuse_if_untouched(m) def popany(self, dims): m = dict(self) @@ -329,12 +332,12 @@ def popany(self, dims): for d in as_tuple(dims): m.pop(d, None) - return Guards(m) + return self._reuse_if_untouched(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Guards(m) + return self._reuse_if_untouched(m) def as_map(self, d, cls): if cls not in (Le, Lt, Ge, Gt): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index 9e787a8b9e..6827e7c7bf 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -199,19 +199,27 @@ class Properties(frozendict): A mapper {Dimension -> {properties}}. """ + def __init__(self, *args, **kwargs): + mapper = dict(*args, **kwargs) + mapper = {d: frozenset(as_tuple(v)) for d, v in mapper.items()} + super().__init__(mapper) + @property def dimensions(self): return tuple(self) + def _reuse_if_untouched(self, mapper): + return self if mapper == self else Properties(mapper) + def add(self, dims, properties=None): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | set(as_tuple(properties)) - return Properties(m) + return self._reuse_if_untouched(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Properties(m) + return self._reuse_if_untouched(m) def drop(self, dims=None, properties=None): if dims is None: @@ -222,7 +230,7 @@ def drop(self, dims=None, properties=None): m.pop(d, None) else: m[d] = self[d] - set(as_tuple(properties)) - return Properties(m) + return self._reuse_if_untouched(m) def parallelize(self, dims): m = dict(self) @@ -231,13 +239,13 @@ def parallelize(self, dims): v.difference_update({PARALLEL_IF_PVT, PARALLEL_IF_ATOMIC, SEQUENTIAL}) v.add(PARALLEL) m[d] = v - return Properties(m) + return self._reuse_if_untouched(m) def affine(self, dims): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {AFFINE} - return Properties(m) + return self._reuse_if_untouched(m) def sequentialize(self, dims=None): if dims is None: @@ -245,13 +253,13 @@ def sequentialize(self, dims=None): m = dict(self) for d in as_tuple(dims): m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL}) - return Properties(m) + return self._reuse_if_untouched(m) def prefetchable(self, dims, v=PREFETCHABLE): m = dict(self) for d in as_tuple(dims): m[d] = self.get(d, set()) | {v} - return Properties(m) + return self._reuse_if_untouched(m) def block(self, dims, kind='default'): if kind == 'default': @@ -263,7 +271,7 @@ def block(self, dims, kind='default'): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {p} - return Properties(m) + return self._reuse_if_untouched(m) def inbound(self, dims): return self.add(dims, INBOUND) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 1c760128b5..4e95ebc4cb 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -8,8 +8,8 @@ from devito.ir.support.utils import maximum, minimum from devito.ir.support.vector import Vector, vmax, vmin from devito.tools import ( - Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, flatten, frozendict, - is_integer, toposort + CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, cached_hash, + filter_ordered, flatten, frozendict, is_integer, toposort ) from devito.types import Dimension, ModuloDimension @@ -53,6 +53,7 @@ def __eq__(self, o): is_compatible = __eq__ + @cached_hash def __hash__(self): return hash(self.dim.name) @@ -88,7 +89,7 @@ def negate(self): translate = negate -class NullInterval(AbstractInterval): +class NullInterval(AbstractInterval, CacheInstances): """ A degenerate iterated closed interval on Z. @@ -96,9 +97,14 @@ class NullInterval(AbstractInterval): is_Null = True + @classmethod + def _preprocess_args(cls, dim, stamp=S0): + return (dim, stamp), {} + def __repr__(self): return f"{self.dim}[Null]{self.stamp}" + @cached_hash def __hash__(self): return hash(self.dim) @@ -120,7 +126,7 @@ def switch(self, d): return NullInterval(d, self.stamp) -class Interval(AbstractInterval): +class Interval(AbstractInterval, CacheInstances): """ Interval(dim, lower, upper) @@ -134,6 +140,18 @@ class Interval(AbstractInterval): is_Defined = True + @classmethod + def _preprocess_args(cls, dim, lower=0, upper=0, stamp=S0): + try: + lower = int(lower) + except TypeError: + assert isinstance(lower, Expr) + try: + upper = int(upper) + except TypeError: + assert isinstance(upper, Expr) + return (dim, lower, upper, stamp), {} + def __init__(self, dim, lower=0, upper=0, stamp=S0): super().__init__(dim, stamp) @@ -151,6 +169,7 @@ def __init__(self, dim, lower=0, upper=0, stamp=S0): def __repr__(self): return f"{self.dim}[{self.lower},{self.upper}]{self.stamp}" + @cached_hash def __hash__(self): return hash((self.dim, self.offsets)) @@ -304,12 +323,18 @@ def expand(self): ) -class IntervalGroup(Ordering): +class IntervalGroup(Ordering, CacheInstances): """ A sequence of Intervals equipped with set-like operations. """ + @classmethod + def _preprocess_args(cls, items=None, relations=None, mode='total'): + items = as_tuple(items) + relations = tuple(tuple(i) for i in as_tuple(relations)) + return (items,), {'relations': relations, 'mode': mode} + @classmethod def reorder(cls, items, relations): if not all(isinstance(i, AbstractInterval) for i in items): @@ -335,13 +360,14 @@ def simplify_relations(cls, relations, items, mode): return super().simplify_relations(relations, items, mode) def __eq__(self, o): - return len(self) == len(o) and all(i == j for i, j in zip(self, o, strict=True)) + return isinstance(o, IntervalGroup) and super().__eq__(o) def __contains__(self, d): return any(i.dim is d for i in self) + @cached_hash def __hash__(self): - return hash(tuple(self)) + return hash((tuple(self), self.relations, self.mode)) def __repr__(self): return "IntervalGroup[{}]".format(', '.join([repr(i) for i in self])) @@ -598,6 +624,7 @@ def __eq__(self, other): def __repr__(self): return self._name + @cached_hash def __hash__(self): return hash(self._name) @@ -618,6 +645,11 @@ class IterationInterval(Interval): An Interval associated with metadata. """ + @classmethod + def _preprocess_args(cls, interval, sub_iterators=(), direction=Forward): + sub_iterators = tuple(filter_ordered(as_tuple(sub_iterators))) + return (interval, sub_iterators, direction), {} + def __init__(self, interval, sub_iterators=(), direction=Forward): super().__init__(interval.dim, *interval.offsets, stamp=interval.stamp) self.sub_iterators = sub_iterators @@ -631,6 +663,7 @@ def __eq__(self, other): return False return self.direction is other.direction and super().__eq__(other) + @cached_hash def __hash__(self): return hash((self.dim, self.offsets, self.direction)) @@ -665,9 +698,6 @@ def __repr__(self): def __eq__(self, other): return isinstance(other, Space) and self.intervals == other.intervals - def __hash__(self): - return hash(self.intervals) - def __len__(self): return len(self.intervals) @@ -731,8 +761,9 @@ def __eq__(self, other): self.intervals == other.intervals and self.parts == other.parts) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.parts)) + return hash((self.intervals, self.parts)) @classmethod def union(cls, *others): @@ -769,7 +800,7 @@ def reset(self): return DataSpace(intervals, parts) -class IterationSpace(Space): +class IterationSpace(Space, CacheInstances): """ Represent an iteration space as a Space with additional metadata and operations. @@ -785,23 +816,29 @@ class IterationSpace(Space): A mapper from Dimensions in ``intervals`` to IterationDirections. """ - def __init__(self, intervals, sub_iterators=None, directions=None): - super().__init__(intervals) + @classmethod + def _preprocess_args(cls, intervals, sub_iterators=None, directions=None): + if not isinstance(intervals, IntervalGroup): + intervals = IntervalGroup(as_tuple(intervals)) - # Normalize sub-iterators sub_iterators = sub_iterators or {} sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) - for d, v in sub_iterators.items() if d in self.intervals} - sub_iterators.update({i.dim: () for i in self.intervals + for d, v in sub_iterators.items() if d in intervals} + sub_iterators.update({i.dim: () for i in intervals if i.dim not in sub_iterators}) - self._sub_iterators = frozendict(sub_iterators) - # Normalize directions directions = directions or {} - directions = {d: v for d, v in directions.items() if d in self.intervals} - directions.update({i.dim: Any for i in self.intervals + directions = {d: v for d, v in directions.items() if d in intervals} + directions.update({i.dim: Any for i in intervals if i.dim not in directions}) - self._directions = frozendict(directions) + + return (intervals, frozendict(sub_iterators), frozendict(directions)), {} + + def __init__(self, intervals, sub_iterators=None, directions=None): + super().__init__(intervals) + + self._sub_iterators = sub_iterators + self._directions = directions def __repr__(self): ret = ', '.join([f"{repr(i)}{repr(self.directions[i.dim])}" @@ -822,8 +859,9 @@ def __lt__(self, other): """ return len(self.itintervals) < len(other.itintervals) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, self.directions)) + return hash((self.intervals, self.sub_iterators, self.directions)) def __contains__(self, d): try: diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index 644bab5d4c..5f2ee39af7 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -3,8 +3,8 @@ from itertools import product from devito.finite_differences import IndexDerivative -from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search -from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split +from devito.symbolics import retrieve_indexed, search +from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, split from devito.types import ( Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove ) @@ -14,7 +14,6 @@ 'IMask', 'Stencil', 'detect_accesses', - 'detect_io', 'erange', 'extrema', 'maximum', @@ -217,70 +216,6 @@ def detect_accesses(exprs): return mapper -def detect_io(exprs, relax=False): - """ - ``{exprs} -> ({reads}, {writes})`` - - Parameters - ---------- - exprs : expr-like or list of expr-like - The searched expressions. - relax : bool, optional - If False, as by default, collect all Input objects, such as - Constants and Functions. Otherwise, also collect AbstractFunctions. - """ - exprs = as_tuple(exprs) - if relax is False: - rule = lambda i: i.is_Input - else: - rule = lambda i: i.is_Input or i.is_AbstractFunction - - # Don't forget the nasty case with indirections on the LHS: - # >>> u[t, a[x]] = f[x] -> (reads={a, f}, writes={u}) - - roots = [] - for i in exprs: - try: - roots.append(i.rhs) - roots.extend(list(i.lhs.indices)) - roots.extend(list(i.conditionals.values())) - except AttributeError: - # E.g., CallFromPointer - roots.append(i) - - reads = [] - terminals = flatten(retrieve_terminals(i, deep=True) for i in roots) - for i in terminals: - candidates = set(i.free_symbols) - with suppress(AttributeError): - candidates.update({i.function}) - for j in candidates: - try: - if rule(j): - reads.append(j) - except AttributeError: - pass - - writes = [] - for i in exprs: - try: - f = i.lhs.function - except AttributeError: - continue - try: - if rule(f): - writes.append(f) - except AttributeError: - # We only end up here after complex IET transformations which make - # use of composite types - assert isinstance(i.lhs, CallFromPointer) - f = i.lhs.base.function - if rule(f): - writes.append(f) - - return tuple(filter_sorted(reads)), tuple(filter_sorted(writes)) - - def pull_dims(exprs, flag=True): """ Extract all Dimensions from one or more expressions. If `flag=True` diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 6e77c39281..a57ce5bd04 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -32,14 +32,14 @@ from devito.operator.registry import operator_selector from devito.parameters import configuration from devito.passes import ( - Graph, error_mapper, generate_implicit, generate_macros, is_on_device, lower_dtypes, - lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate + Graph, error_mapper, finalize_args, generate_implicit, generate_macros, is_on_device, + lower_dtypes, lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate ) from devito.symbolics import estimate_cost, subs_op_args from devito.tools import ( DAG, CacheInstances, MemoryEstimate, OrderedSet, ReducerMap, Signer, as_mapper, - as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, split, - timed_pass, timed_region + as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, memoized_func, + split, timed_pass, timed_region ) from devito.types import Buffer, Evaluable, device_layer, disk_layer, host_layer from devito.types.dimension import Thickness @@ -184,7 +184,11 @@ def __new__(cls, expressions, **kwargs): # Lower to a JIT-compilable object with timed_region('op-compile') as r: - op = cls._build(expressions, **kwargs) + try: + op = cls._build(expressions, **kwargs) + finally: + CacheInstances.clear_caches() + memoized_func.clear_build_caches() op._profiler.py_timers.update(r.timings) # Emit info about how long it took to perform the lowering @@ -261,15 +265,12 @@ def _build(cls, expressions, **kwargs): op._state = cls._initialize_state(**kwargs) # Produced by the various compilation passes - op._reads = filter_sorted(flatten(e.reads for e in irs.expressions)) - op._writes = filter_sorted(flatten(e.writes for e in irs.expressions)) + op._reads = filter_sorted(flatten(e.read_functions for e in irs.expressions)) + op._writes = filter_sorted(flatten(e.write_functions for e in irs.expressions)) op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler - # Clear build-scoped instance caches - CacheInstances.clear_caches() - return op def __init__(self, *args, **kwargs): @@ -521,6 +522,9 @@ def _lower_iet(cls, uiet, **kwargs): # Target-independent optimizations minimize_symbols(graph) + # Finalize helper signatures after all IET transformations have settled. + finalize_args(graph) + return graph.root, graph # Read-only properties exposed to the outside world @@ -1422,11 +1426,13 @@ def _physical_deviceid(self): if isinstance(self.platform, Device): # Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set) logical_deviceid = self.get('deviceid', -1) + visible_device_var, visible_devices = get_visible_devices() if logical_deviceid < 0: rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0 - logical_deviceid = rank - - visible_device_var, visible_devices = get_visible_devices() + if visible_devices is None: + logical_deviceid = rank + else: + logical_deviceid = rank % len(visible_devices) if visible_devices is None: return logical_deviceid elif len(visible_devices) == 1: diff --git a/devito/parameters.py b/devito/parameters.py index 2aba1112ad..7fb957ce39 100644 --- a/devito/parameters.py +++ b/devito/parameters.py @@ -286,7 +286,9 @@ def __exit__(self, exc_type, exc_val, traceback): class switchenv(SwitchDecorator): """ - Temporarily set environment variables from a dictionary + Temporarily set environment variables from a dictionary. A value of None + unsets the corresponding environment variable for the duration of the + context. Note: This does not propagate any environment variables that change inside the context manager, so should be used cautiously. @@ -296,7 +298,11 @@ def __init__(self, params): self.params = params def __enter__(self): - os.environ.update(self.params) + for k, v in self.params.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v def __exit__(self, exc_type, exc_val, traceback): os.environ.clear() diff --git a/devito/passes/clusters/__init__.py b/devito/passes/clusters/__init__.py index c41a628e06..e27d2b755d 100644 --- a/devito/passes/clusters/__init__.py +++ b/devito/passes/clusters/__init__.py @@ -3,6 +3,7 @@ from .cse import * # noqa from .aliases import * # noqa from .factorization import * # noqa +from .fusion import * # noqa from .blocking import * # noqa from .asynchrony import * # noqa from .implicit import * # noqa diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 194d26523d..b8c4108b83 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -5,8 +5,8 @@ from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue -from devito.passes.clusters.misc import fuse -from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace +from devito.passes.clusters.fusion import fuse +from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, search, uxreplace from devito.tools import infer_dtype, timed_pass from devito.types import Eq, Inc, Indexed, Symbol @@ -15,13 +15,16 @@ @timed_pass() def lower_index_derivatives(clusters, mode=None, **kwargs): + max_depth = _max_index_derivative_depth(clusters) clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs) if not weights: return clusters if mode != 'noop': - clusters = fuse(clusters, toposort='maximal') + for _ in range(max_depth): + clusters = fuse(clusters, toposort='nofuse') + clusters = fuse(clusters, toposort=False) # At this point we can detect redundancies induced by inner derivatives that # previously were just not detectable via e.g. plain CSE. For example, if @@ -258,3 +261,16 @@ def callback(self, clusters, prefix, subs0=None, seen=None): seen.update(processed) return processed + + +# *** Utils + + +def _max_index_derivative_depth(clusters): + max_depth = 0 + + for c in clusters: + for i in search(c.exprs, IndexDerivative): + max_depth = max(max_depth, i.depth) + + return max_depth diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py new file mode 100644 index 0000000000..837b76df13 --- /dev/null +++ b/devito/passes/clusters/fusion.py @@ -0,0 +1,338 @@ +from collections import Counter, defaultdict +from itertools import groupby + +from devito.finite_differences import IndexDerivative +from devito.ir.clusters import Cluster, ClusterGroup, Queue +from devito.ir.support import ( + InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, WaitLock, WithLock +) +from devito.symbolics import search +from devito.tools import ( + DAG, as_tuple, flatten, frozendict, memoized_func, memoized_meth, timed_pass +) + +__all__ = ['fuse'] + + +# No hazard: fusion may proceed. +NO_HAZARD = None +# Ordering hazard: preserve program order and forbid fusion. +EDGE = 'edge' +# Prefix anti-dependence: break the execution flow across the pair. +BREAK = 'break' + + +@memoized_func(scope='build') +def _fusion_hazards(scope0, scope1, prefix): + """ + Classify the dependence hazard that would arise from fusing two scopes. + """ + scope = Scope.from_scopes(scope0, scope1) + if scope is None: + return NO_HAZARD + + anti = False + for i in scope.d_anti_gen(): + if i.cause & prefix: + return BREAK + anti = True + + if anti: + return EDGE + + for i in scope.d_flow_gen(): + if not (i.cause & prefix): + return EDGE + + for _ in scope.d_output_gen(): + return EDGE + + return NO_HAZARD + + +class Fusion(Queue): + + """ + Fuse Clusters with compatible IterationSpace. + """ + + _q_guards_in_key = True + _q_syncs_in_key = True + + def __init__(self, toposort, options=None): + options = options or {} + + self.toposort = toposort + self.fusetasks = options.get('fuse-tasks', False) + + super().__init__() + + def process(self, clusters): + cgroups = [ClusterGroup(c, c.ispace) for c in clusters] + cgroups = self._process_fdta(cgroups, 1) + clusters = ClusterGroup.concatenate(*cgroups) + return clusters + + def callback(self, cgroups, prefix): + # Toposort to maximize fusion + if self.toposort: + clusters = self._toposort(cgroups, prefix) + if self.toposort == 'nofuse': + return [clusters] + else: + clusters = ClusterGroup(cgroups) + + # Fusion + processed = [] + for _, group in groupby(clusters, key=self._key): + g = list(group) + + for maybe_fusible in self._apply_heuristics(g): + try: + # Perform fusion + processed.append(Cluster.from_clusters(*maybe_fusible)) + except ValueError: + # We end up here if, for example, some Clusters have same + # iteration Dimensions but different (partial) orderings + processed.extend(maybe_fusible) + + # Maximize effectiveness of topo-sorting at next stage by only + # grouping together Clusters characterized by data dependencies + if self.toposort and prefix: + dag = self._build_dag(processed, prefix) + mapper = dag.connected_components(enumerated=True) + groups = groupby(processed, key=mapper.get) + return [ClusterGroup(tuple(g), prefix) for _, g in groups] + else: + return [ClusterGroup(processed, prefix)] + + class Key(tuple): + + """ + A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that + two Clusters (ClusterGroups) are topo-fusible if and only if their Key is + identical. + + A Key contains elements that can logically be split into two groups -- the + `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) + having same `strict` but different `weak` parts are, by definition, not + fusible; however, since at least their `strict` parts match, they can at + least be topologically reordered. + """ + + def __new__(cls, itintervals, guards, syncs, weak): + strict = [itintervals, guards, syncs] + obj = super().__new__(cls, strict + weak) + + obj.itintervals = itintervals + obj.guards = guards + obj.syncs = syncs + + obj.strict = tuple(strict) + obj.weak = tuple(weak) + + return obj + + @memoized_meth + def _key(self, c): + itintervals = frozenset(c.ispace.itintervals) + guards = c.guards if any(c.guards) else None + + # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and + # WithLocks, but not with any other SyncOps + mapper = defaultdict(set) + for d, v in c.syncs.items(): + for s in v: + if isinstance(s, PrefetchUpdate): + continue + elif isinstance(s, WaitLock) and not self.fusetasks: + # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely + # be fused, as in the worst case scenario the WaitLocks + # get "hoisted" above the first Cluster in the sequence + continue + elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): + mapper[d].add(type(s)) + elif isinstance(s, WithLock) and self.fusetasks: + # NOTE: Different WithLocks aren't fused unless the user + # explicitly asks for it + mapper[d].add(type(s)) + else: + mapper[d].add(s) + if d in mapper: + mapper[d] = frozenset(mapper[d]) + syncs = frozendict(mapper) + + # Clusters representing HaloTouches should get merged, if possible + weak = [c.is_halo_touch] + + # If there are writes to thread-shared object, make it part of the key. + # This will promote fusion of non-adjacent Clusters writing to (some + # form of) shared memory, which in turn will minimize the number of + # necessary barriers. Same story for reads from thread-shared objects + weak.extend([ + any(f._mem_shared for f in c.scope.writes), + any(f._mem_shared for f in c.scope.reads) + ]) + weak.append(c.properties.is_core_init()) + + # Prefetchable Clusters should get merged, if possible + weak.append(c.is_glb_load_to_mem_shared) + + # Promoting adjacency of IndexDerivatives will maximize their reuse + weak.append(any(search(c.exprs, IndexDerivative))) + + # Promote adjacency of Clusters with same guard + weak.append(c.guards) + + key = self.Key(itintervals, guards, syncs, weak) + + return key + + def _apply_heuristics(self, clusters): + # We know at this point that `clusters` are potentially fusible since + # they have same `_key`, but should we actually fuse them? In most cases + # yes, but there are exceptions... + + # 1) Consider the following scenario with three Clusters: + # c0[no syncs] + # c1[WaitLock] + # c2[no syncs] + # Then we return two groups [[c0], [c1, c2]] rather than a single group + # [[c0, c1, c2]] because this way c0 can be computed without having to + # wait on a lock for a longer period + processed = [] + + group = [] + flag = False # True -> need to dump before creating a new group + + def dump(): + processed.append(tuple(group)) + group[:] = [] + + for c in clusters: + if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): + if flag: + dump() + flag = False + else: + flag = True + group.append(c) + dump() + + # 2) Don't group HaloTouch's + groups, processed = processed, [] + for group in groups: + for flag, minigroup in groupby(group, key=lambda c: c.is_wild): + if flag: + processed.extend([(c,) for c in minigroup]) + else: + processed.append(tuple(minigroup)) + + return processed + + def _toposort(self, cgroups, prefix): + # Are there any ClusterGroups that could potentially be topologically + # reordered? If not, do not waste time + counter = Counter(self._key(cg).strict for cg in cgroups) + if len(counter.most_common()) == 1 or \ + not any(v > 1 for it, v in counter.most_common()): + return ClusterGroup(cgroups, prefix) + + dag = self._build_dag(cgroups, prefix) + + def choose_element(queue, scheduled): + if not scheduled: + return queue.pop() + + k = self._key(scheduled[-1]) + m = {i: self._key(i) for i in queue} + + # Process the `strict` part of the key + candidates = [i for i in queue if m[i].itintervals == k.itintervals] + + compatible = [i for i in candidates if m[i].guards == k.guards] + candidates = compatible or candidates + + compatible = [i for i in candidates if m[i].syncs == k.syncs] + candidates = compatible or candidates + + # Process the `weak` part of the key + for i in range(len(k.weak), -1, -1): + choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] + try: + # Ensure stability + e = min(choosable, key=lambda i: cgroups.index(i)) + except ValueError: + continue + queue.remove(e) + return e + + # Fallback + e = min(queue, key=lambda i: cgroups.index(i)) + queue.remove(e) + return e + + return ClusterGroup(dag.topological_sort(choose_element), prefix) + + def _build_dag(self, cgroups, prefix): + """ + A DAG representing the data dependences across the ClusterGroups within + a given scope. + """ + prefix = frozenset(i.dim for i in as_tuple(prefix)) + + dag = DAG(nodes=cgroups) + for n, cg0 in enumerate(cgroups): + # Track whether there is any fence between `cg0` and the current `cg1`. + fenced = cg0.scope.has_barrier + + for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): + fenced = fenced or cg1.scope.has_barrier + + hazard = _fusion_hazards(cg0.scope, cg1.scope, prefix) + if not (hazard or fenced): + continue + + # Anti-dependences along `prefix` break the execution flow + # (intuitively, "the loop nests are to be kept separated") + # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` + # * All ClusterGroups after `cg1` cannot precede `cg1` + if hazard == BREAK: + for cg2 in cgroups[n:n1]: + dag.add_edge(cg2, cg1) + for cg2 in cgroups[n1+1:]: + dag.add_edge(cg1, cg2) + break + elif fenced or hazard == EDGE: + # Any anti- and iaw-dependences impose that `cg1` follows `cg0` + # and forbid any sort of fusion. Fences have the same effect + dag.add_edge(cg0, cg1) + + return dag + + +@timed_pass() +def fuse(clusters, toposort=False, options=None): + """ + Clusters fusion. + + If `toposort=True`, then the Clusters are reordered to maximize the likelihood + of fusion; the new ordering is computed such that all data dependencies are + honored. + + If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple + times to actually maximize Clusters fusion. Hence, this is more aggressive than + `toposort=True`. + """ + if toposort != 'maximal': + return Fusion(toposort, options).process(clusters) + + nxt = clusters + while True: + nxt = fuse(clusters, toposort='nofuse', options=options) + if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): + break + clusters = nxt + clusters = fuse(clusters, toposort=False, options=options) + + return clusters diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 494ebe7490..68b982eedc 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -1,18 +1,13 @@ -from collections import Counter, defaultdict from itertools import groupby, product -from devito.finite_differences import IndexDerivative -from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass -from devito.ir.support import ( - SEPARABLE, SEQUENTIAL, InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, - WaitLock, WithLock -) +from devito.ir.clusters import Queue, cluster_pass +from devito.ir.support import SEPARABLE, SEQUENTIAL, Scope from devito.passes.clusters.utils import in_critical_region -from devito.symbolics import pow_to_mul, search -from devito.tools import DAG, Stamp, as_tuple, flatten, frozendict, timed_pass +from devito.symbolics import pow_to_mul +from devito.tools import Stamp, flatten, frozendict, timed_pass from devito.types import Hyperplane -__all__ = ['Lift', 'fission', 'fuse', 'optimize_hyperplanes', 'optimize_pows'] +__all__ = ['Lift', 'fission', 'optimize_hyperplanes', 'optimize_pows'] class Lift(Queue): @@ -107,309 +102,12 @@ def callback(self, clusters, prefix): return lifted + processed -class Fusion(Queue): - - """ - Fuse Clusters with compatible IterationSpace. - """ - - _q_guards_in_key = True - _q_syncs_in_key = True - - def __init__(self, toposort, options=None): - options = options or {} - - self.toposort = toposort - self.fusetasks = options.get('fuse-tasks', False) - - super().__init__() - - def process(self, clusters): - cgroups = [ClusterGroup(c, c.ispace) for c in clusters] - cgroups = self._process_fdta(cgroups, 1) - clusters = ClusterGroup.concatenate(*cgroups) - return clusters - - def callback(self, cgroups, prefix): - # Toposort to maximize fusion - if self.toposort: - clusters = self._toposort(cgroups, prefix) - if self.toposort == 'nofuse': - return [clusters] - else: - clusters = ClusterGroup(cgroups) - - # Fusion - processed = [] - for _, group in groupby(clusters, key=self._key): - g = list(group) - - for maybe_fusible in self._apply_heuristics(g): - try: - # Perform fusion - processed.append(Cluster.from_clusters(*maybe_fusible)) - except ValueError: - # We end up here if, for example, some Clusters have same - # iteration Dimensions but different (partial) orderings - processed.extend(maybe_fusible) - - # Maximize effectiveness of topo-sorting at next stage by only - # grouping together Clusters characterized by data dependencies - if self.toposort and prefix: - dag = self._build_dag(processed, prefix) - mapper = dag.connected_components(enumerated=True) - groups = groupby(processed, key=mapper.get) - return [ClusterGroup(tuple(g), prefix) for _, g in groups] - else: - return [ClusterGroup(processed, prefix)] - - class Key(tuple): - - """ - A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that - two Clusters (ClusterGroups) are topo-fusible if and only if their Key is - identical. - - A Key contains elements that can logically be split into two groups -- the - `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) - having same `strict` but different `weak` parts are, by definition, not - fusible; however, since at least their `strict` parts match, they can at - least be topologically reordered. - """ - - def __new__(cls, itintervals, guards, syncs, weak): - strict = [itintervals, guards, syncs] - obj = super().__new__(cls, strict + weak) - - obj.itintervals = itintervals - obj.guards = guards - obj.syncs = syncs - - obj.strict = tuple(strict) - obj.weak = tuple(weak) - - return obj - - def _key(self, c): - itintervals = frozenset(c.ispace.itintervals) - guards = c.guards if any(c.guards) else None - - # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and - # WithLocks, but not with any other SyncOps - mapper = defaultdict(set) - for d, v in c.syncs.items(): - for s in v: - if isinstance(s, PrefetchUpdate): - continue - elif isinstance(s, WaitLock) and not self.fusetasks: - # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely - # be fused, as in the worst case scenario the WaitLocks - # get "hoisted" above the first Cluster in the sequence - continue - elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): - mapper[d].add(type(s)) - elif isinstance(s, WithLock) and self.fusetasks: - # NOTE: Different WithLocks aren't fused unless the user - # explicitly asks for it - mapper[d].add(type(s)) - else: - mapper[d].add(s) - if d in mapper: - mapper[d] = frozenset(mapper[d]) - syncs = frozendict(mapper) - - # Clusters representing HaloTouches should get merged, if possible - weak = [c.is_halo_touch] - - # If there are writes to thread-shared object, make it part of the key. - # This will promote fusion of non-adjacent Clusters writing to (some - # form of) shared memory, which in turn will minimize the number of - # necessary barriers. Same story for reads from thread-shared objects - weak.extend([ - any(f._mem_shared for f in c.scope.writes), - any(f._mem_shared for f in c.scope.reads) - ]) - weak.append(c.properties.is_core_init()) - - # Prefetchable Clusters should get merged, if possible - weak.append(c.is_glb_load_to_mem_shared) - - # Promoting adjacency of IndexDerivatives will maximize their reuse - weak.append(any(search(c.exprs, IndexDerivative))) - - # Promote adjacency of Clusters with same guard - weak.append(c.guards) - - key = self.Key(itintervals, guards, syncs, weak) - - return key - - def _apply_heuristics(self, clusters): - # We know at this point that `clusters` are potentially fusible since - # they have same `_key`, but should we actually fuse them? In most cases - # yes, but there are exceptions... - - # 1) Consider the following scenario with three Clusters: - # c0[no syncs] - # c1[WaitLock] - # c2[no syncs] - # Then we return two groups [[c0], [c1, c2]] rather than a single group - # [[c0, c1, c2]] because this way c0 can be computed without having to - # wait on a lock for a longer period - processed = [] - - group = [] - flag = False # True -> need to dump before creating a new group - - def dump(): - processed.append(tuple(group)) - group[:] = [] - - for c in clusters: - if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): - if flag: - dump() - flag = False - else: - flag = True - group.append(c) - dump() - - # 2) Don't group HaloTouch's - - groups, processed = processed, [] - for group in groups: - for flag, minigroup in groupby(group, key=lambda c: c.is_wild): - if flag: - processed.extend([(c,) for c in minigroup]) - else: - processed.append(tuple(minigroup)) - - return processed - - def _toposort(self, cgroups, prefix): - # Are there any ClusterGroups that could potentially be topologically - # reordered? If not, do not waste time - counter = Counter(self._key(cg).strict for cg in cgroups) - if len(counter.most_common()) == 1 or \ - not any(v > 1 for it, v in counter.most_common()): - return ClusterGroup(cgroups, prefix) - - dag = self._build_dag(cgroups, prefix) - - def choose_element(queue, scheduled): - if not scheduled: - return queue.pop() - - k = self._key(scheduled[-1]) - m = {i: self._key(i) for i in queue} - - # Process the `strict` part of the key - candidates = [i for i in queue if m[i].itintervals == k.itintervals] - - compatible = [i for i in candidates if m[i].guards == k.guards] - candidates = compatible or candidates - - compatible = [i for i in candidates if m[i].syncs == k.syncs] - candidates = compatible or candidates - - # Process the `weak` part of the key - for i in range(len(k.weak), -1, -1): - choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] - try: - # Ensure stability - e = min(choosable, key=lambda i: cgroups.index(i)) - except ValueError: - continue - queue.remove(e) - return e - - # Fallback - e = min(queue, key=lambda i: cgroups.index(i)) - queue.remove(e) - return e - - return ClusterGroup(dag.topological_sort(choose_element), prefix) - - def _build_dag(self, cgroups, prefix): - """ - A DAG representing the data dependences across the ClusterGroups within - a given scope. - """ - prefix = {i.dim for i in as_tuple(prefix)} - - dag = DAG(nodes=cgroups) - for n, cg0 in enumerate(cgroups): - - def is_cross(source, sink): - # True if a cross-ClusterGroup dependence, False otherwise - t0 = source.timestamp - t1 = sink.timestamp - v = len(cg0.exprs) # noqa: B023 - return t0 < v <= t1 or t1 < v <= t0 - - for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): - - # A Scope to compute all cross-ClusterGroup anti-dependences - scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross) - - # Anti-dependences along `prefix` break the execution flow - # (intuitively, "the loop nests are to be kept separated") - # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` - # * All ClusterGroups after `cg1` cannot precede `cg1` - if any(i.cause & prefix for i in scope.d_anti_gen()): - for cg2 in cgroups[n:cgroups.index(cg1)]: - dag.add_edge(cg2, cg1) - for cg2 in cgroups[cgroups.index(cg1)+1:]: - dag.add_edge(cg1, cg2) - break - - # Any anti- and iaw-dependences impose that `cg1` follows `cg0` - # and forbid any sort of fusion. Fences have the same effect - elif ( - any(scope.d_anti_gen()) or - any(i.is_iaw for i in scope.d_output_gen()) or - any(c.is_fence for c in flatten(cgroups[n:n1+1])) - ) or any(not (i.cause and i.cause & prefix) for i in scope.d_flow_gen()) \ - or any(scope.d_output_gen()): - dag.add_edge(cg0, cg1) - - return dag - - -@timed_pass() -def fuse(clusters, toposort=False, options=None): - """ - Clusters fusion. - - If `toposort=True`, then the Clusters are reordered to maximize the likelihood - of fusion; the new ordering is computed such that all data dependencies are - honored. - - If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple - times to actually maximize Clusters fusion. Hence, this is more aggressive than - `toposort=True`. - """ - if toposort != 'maximal': - return Fusion(toposort, options).process(clusters) - - nxt = clusters - while True: - nxt = fuse(clusters, toposort='nofuse', options=options) - if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): - break - clusters = nxt - clusters = fuse(clusters, toposort=False, options=options) - - return clusters - - @cluster_pass(mode='all') def optimize_pows(cluster, *args): """ Convert integer powers into Muls, such as ``a**2 => a*a``. """ - return cluster.rebuild(exprs=[pow_to_mul(e) for e in cluster.exprs]) + return cluster.rebuild(exprs=pow_to_mul(cluster.exprs)) class Fission(Queue): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 9b936fba76..6d1a9527ee 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -15,7 +15,9 @@ from devito.mpi.routines import Gather, HaloUpdate, HaloWait, MPIMsg, Scatter from devito.passes import needs_transfer from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search -from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass +from devito.tools import ( + DAG, as_hashable, as_tuple, filter_ordered, memoized_func, sorted_priority, timed_pass +) from devito.types import ( Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, @@ -25,7 +27,7 @@ from devito.types.dense import DiscreteFunction from devito.types.dimension import AbstractIncrDimension, BlockDimension -__all__ = ['Graph', 'iet_pass', 'iet_visit'] +__all__ = ['Graph', 'finalize_args', 'iet_pass', 'iet_visit'] class Byproduct: @@ -102,7 +104,7 @@ def sync_mapper(self): A mapper {Iteration -> SyncSpot} describing the Iterations, if any, living an asynchronous region, across all Callables in the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) mapper = MapNodes(SyncSpot, (Iteration, Call)).visit(self.root) @@ -127,14 +129,20 @@ def sync_mapper(self): def apply(self, func, **kwargs): """ - Apply `func` to all nodes in the Graph. This changes the state of the Graph. + Apply ``func`` to all nodes in the Graph. + + Callable parameters and Call arguments are reconciled before the graph + walk, after each changed node, and after the pass has completed. """ - dag = create_call_graph(self.root.name, self.efuncs) + _update_args(self) + + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) # Apply `func` efuncs = dict(self.efuncs) for i in dag.topological_sort(): efunc, metadata = func(efuncs[i], **kwargs) + new_efuncs = metadata.get('efuncs', []) self.includes.extend(as_tuple(metadata.get('includes'))) self.headers.extend(as_tuple(metadata.get('headers'))) @@ -151,17 +159,13 @@ def apply(self, func, **kwargs): except KeyError: pass - if efunc is efuncs[i]: + if efunc is efuncs[i] and not new_efuncs: continue - new_efuncs = metadata.get('efuncs', []) - efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) - # Update the parameters / arguments lists since `func` may have - # introduced or removed objects - efuncs = update_args(efunc, efuncs, dag) + efuncs = _update_args_efunc(efunc, efuncs, dag) # Minimize code size if len(efuncs) > len(self.efuncs): @@ -170,6 +174,7 @@ def apply(self, func, **kwargs): efuncs = reuse_efuncs(self.root, efuncs, self.sregistry) self.efuncs = efuncs + _update_args(self) # Uniqueness self.includes = filter_ordered(self.includes) @@ -184,7 +189,7 @@ def visit(self, func, **kwargs): from nodes to info. Unlike `apply`, `visit` does not change the state of the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) toposort = dag.topological_sort() mapper = dict([(i, func(self.efuncs[i], **kwargs)) for i in toposort]) @@ -206,7 +211,35 @@ def filter(self, key): ) -def iet_pass(func): +@timed_pass(name='finalize_args') +def finalize_args(graph): + """ + Finalize Callable parameter lists and Call argument lists across ``graph``. + + IET passes may temporarily leave helper signatures stale while introducing + or eliminating symbols. This pass reconciles the whole call graph once, + after lowering has settled. + """ + _update_args(graph) + + +def _update_args(graph): + dag = create_call_graph(graph.root.name, as_hashable(graph.efuncs)) + + efuncs = graph.efuncs + for i in dag.topological_sort(): + efuncs = _update_args_efunc(efuncs[i], efuncs, dag) + + graph.efuncs = efuncs + + +def iet_pass(func=None): + """ + Decorate an IET pass. + """ + if func is None: + return iet_pass + if isinstance(func, tuple): assert len(func) == 2 and func[0] is iet_visit call = lambda graph: graph.visit @@ -231,6 +264,7 @@ def wrapper(*args, **kwargs): # Instance method case self, graph = args return maybe_timed(call(graph), func.__name__)(partial(func, self), **kwargs) + return wrapper @@ -238,11 +272,14 @@ def iet_visit(func): return iet_pass((iet_visit, func)) +@memoized_func(scope='build') def create_call_graph(root, efuncs): """ Create a Call graph -- a Direct Acyclic Graph with edges from callees to callers. """ + efuncs = dict(efuncs) + dag = DAG(nodes=[root]) queue = [root] @@ -438,7 +475,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): # assuming that `bar0` and `bar1` are compatible, we first process the # `bar`'s to obtain `[foo0(u(x)): bar0(u), foo1(u(x)): bar0(u)]`, # and finally `foo0(u(x)): bar0(u)` - dag = create_call_graph(root.name, efuncs) + dag = create_call_graph(root.name, as_hashable(efuncs)) mapper = {} for i in dag.topological_sort(): @@ -480,6 +517,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): return retval +@memoized_func(scope='build') def abstract_efunc(efunc): """ Abstract `efunc` applying a set of rules: @@ -492,7 +530,7 @@ def abstract_efunc(efunc): """ functions = FindSymbols('basics|symbolics|dimensions').visit(efunc) - mapper = abstract_objects(functions) + mapper = abstract_objects(tuple(functions)) efunc = Uxreplace(mapper).visit(efunc) efunc = efunc._rebuild(name='foo') @@ -500,7 +538,8 @@ def abstract_efunc(efunc): return efunc -def abstract_objects(objects0, sregistry=None): +@memoized_func(scope='build') +def abstract_objects(objects0): """ Proxy for `abstract_object`. """ @@ -519,7 +558,7 @@ def abstract_objects(objects0, sregistry=None): # Build abstraction mappings mapper = {} - sregistry = sregistry or SymbolRegistry() + sregistry = SymbolRegistry() for i in objects: abstract_object(i, mapper, sregistry) @@ -690,7 +729,7 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='nthreads')) -def update_args(root, efuncs, dag): +def _update_args_efunc(root, efuncs, dag): """ Re-derive the parameters of `root` and apply the changes in cascade through the `efuncs`. @@ -780,6 +819,14 @@ def _filter(v, efunc=None): mapper = {c: c._rebuild(arguments=_filter(c.arguments)) for c in FindNodes(Call).visit(efuncs[n]) if c.name == root.name} - efuncs[n] = Transformer(mapper).visit(efuncs[n]) + if not mapper: + continue + + efunc = Transformer(mapper).visit(efuncs[n]) + if efunc is efuncs[n]: + continue + + efuncs[n] = efunc + efuncs = _update_args_efunc(efunc, efuncs, dag) return efuncs diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 2de99ee002..db835400ae 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -293,17 +293,16 @@ def _mark_overlappable(iet): scope = Scope([n.expr for n in exprs]) - for dep in scope.d_all_gen(): - if dep.function in hs.functions: - cause = dep.cause & hs.dimensions - if any(dep.distance_mapper[d] is S.Infinity for d in cause): - # E.g., dependencies across PARALLEL iterations - # for x - # for y - # ... = ... f[x, y-1] ... - # for y - # f[x, y] = ... - break + for dep in scope.d_all_gen(writes=hs.functions): + cause = dep.cause & hs.dimensions + if any(dep.distance_mapper[d] is S.Infinity for d in cause): + # E.g., dependencies across PARALLEL iterations + # for x + # for y + # ... = ... f[x, y-1] ... + # for y + # f[x, y] = ... + break else: # All good -- we can perform comp/comm overlap! found.append(hs) @@ -507,7 +506,7 @@ def rule1(dep, loc_indices): for d, v in loc_indices.items()) for f, v in hsf.fmapper.items(): - for dep in scope.d_flow.project(f): + for dep in scope.d_flow_gen(writes=f): if not rule0(dep) and not rule1(dep, v.loc_indices): return False diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 57d9314e16..44dabb7790 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -64,6 +64,9 @@ def uxreplace(expr, rule): Finally, `uxreplace` supports Reconstructable objects, that is, it searches for replacement opportunities inside the Reconstructable's `__rkwargs__`. """ + if not rule: + return expr + return _uxreplace(expr, rule)[0] @@ -129,13 +132,15 @@ def _(iterable, rule): ax, flag = _uxreplace(a, rule) ret.append(ax) changed |= flag - return iterable.__class__(ret), changed + return (iterable.__class__(ret), True) if changed else (iterable, False) @_uxreplace_dispatch.register(EnrichedTuple) def _(iterable, rule): retval, changed = _uxreplace_dispatch(tuple(iterable), rule) - return iterable.__class__(*retval, getters=iterable.getters), changed + if changed: + return iterable.__class__(*retval, getters=iterable.getters), True + return iterable, False @_uxreplace_dispatch.register(dict) @@ -146,7 +151,7 @@ def _(mapper, rule): vx, flag = _uxreplace_dispatch(v, rule) ret[k] = vx changed |= flag - return ret, changed + return (ret, True) if changed else (mapper, False) @singledispatch @@ -282,10 +287,18 @@ def subs_if_composite(expr, subs): Indexed"). Instead, if `subs` consists of just "primitive" expressions, then resort to the much faster `uxreplace`. """ - if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): + if not subs: + return expr + + if type(expr) is tuple: + return reuse_if_untouched(expr, (subs_if_composite(e, subs) for e in expr)) + elif type(expr) is list: + return reuse_if_untouched(expr, [subs_if_composite(e, subs) for e in expr]) + elif all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): return uxreplace(expr, subs) else: - return expr.subs(subs) + processed = expr.subs(subs) + return expr if processed == expr else processed def xreplace_indices(exprs, mapper, key=None): @@ -304,14 +317,26 @@ def xreplace_indices(exprs, mapper, key=None): callable, apply the replacement to a symbol S if and only if ``key(S)`` gives True. """ - handle = flatten(retrieve_indexed(i) for i in as_tuple(exprs)) + exprs0 = as_tuple(exprs) + + handle = flatten(retrieve_indexed(i) for i in exprs0) if isinstance(key, Iterable): handle = [i for i in handle if i.base.label in key] elif callable(key): handle = [i for i in handle if key(i)] - mapper = dict(zip(handle, [i.xreplace(mapper) for i in handle], strict=True)) - replaced = [uxreplace(i, mapper) for i in as_tuple(exprs)] - return replaced if isinstance(exprs, Iterable) else replaced[0] + mapper = {i: v for i in handle if (v := i.xreplace(mapper)) != i} + if not mapper: + return exprs + + replaced = [uxreplace(i, mapper) for i in exprs0] + + if isinstance(exprs, Iterable): + if len(replaced) == len(exprs0) and \ + all(i is j for i, j in zip(replaced, exprs0, strict=True)): + return exprs + return replaced + else: + return replaced[0] def _eval_numbers(expr, args): @@ -344,7 +369,9 @@ def flatten_args(args, op, ignore=None): def pow_to_mul(expr): - if q_leaf(expr) or isinstance(expr, Basic): + if isinstance(expr, (tuple, list)): + return reuse_if_untouched(expr, (pow_to_mul(i) for i in expr)) + elif q_leaf(expr) or isinstance(expr, Basic): return expr elif expr.is_Pow: base, exp = expr.as_base_exp() @@ -359,7 +386,7 @@ def pow_to_mul(expr): elif (int(exp) - exp != 0): # Fractional powers also remain untouched, # but at least we traverse the base looking for other Pows - return expr.func(pow_to_mul(base), exp, evaluate=False) + return reuse_if_untouched(expr, (pow_to_mul(base), exp), evaluate=False) elif exp > 0: return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False) elif exp < 0: @@ -383,7 +410,7 @@ def pow_to_mul(expr): except ValueError: pass - return expr.func(*args, evaluate=False) + return reuse_if_untouched(expr, args, evaluate=False) def indexify(expr): @@ -429,10 +456,18 @@ def normalize_args(args): def reuse_if_untouched(expr, args, evaluate=False): """ - Reconstruct `expr` iff any of the provided `args` is different than - the corresponding arg in `expr.args`. + Reconstruct `expr` iff any of the provided `args` is different from + the corresponding arg in `expr.args`, or from the corresponding item + for plain tuples/lists. """ - if all(a is b for a, b in zip(expr.args, args, strict=False)): + args = tuple(args) + + if isinstance(expr, (tuple, list)): + if len(args) == len(expr) and \ + all(a is b for a, b in zip(expr, args, strict=True)): + return expr + return type(expr)(args) + elif all(a is b for a, b in zip(expr.args, args, strict=False)): return expr else: return expr.func(*args, evaluate=evaluate) diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 9c30948064..522210fbca 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -8,9 +8,10 @@ from devito.symbolics.queries import ( q_derivative, q_dimension, q_function, q_indexed, q_leaf, q_symbol, q_terminal ) -from devito.tools import as_tuple +from devito.tools import as_tuple, memoized_func __all__ = [ + 'retrieve_accesses', 'retrieve_derivatives', 'retrieve_dimensions', 'retrieve_function_carriers', @@ -140,10 +141,19 @@ def retrieve_indexed(exprs, mode='all', deep=False): def retrieve_functions(exprs, mode='all', deep=False): """Shorthand to retrieve the DiscreteFunctions in `exprs`.""" - indexeds = search(exprs, q_indexed, mode, 'dfs', deep) + query = lambda i: q_function(i) or q_indexed(i) + found = search(exprs, query, 'all', 'dfs', deep) + + functions = modes[mode]() + indexed_functions = set() + + for i in found: + if q_function(i): + functions.add(i) if mode == 'unique' else functions.append(i) + else: + indexed_functions.add(i.function) - functions = search(exprs, q_function, mode, 'dfs', deep) - functions.update({i.function for i in indexeds}) + functions.update(indexed_functions) return functions @@ -177,6 +187,26 @@ def retrieve_terminals(exprs, mode='all', deep=False): return search(exprs, q_terminal, mode, 'dfs', deep) +@memoized_func(scope='build') +def retrieve_accesses(exprs, deep=False): + """ + Like retrieve_terminals, but ensure that if a ComponentAccess is found, + the ComponentAccess itself is returned, while the wrapped Indexed is discarded. + """ + from devito.symbolics.manipulation import uxreplace + from devito.types import ComponentAccess, Symbol + + compaccs = search(exprs, ComponentAccess) + if not compaccs: + return frozenset(retrieve_terminals(exprs, mode='unique', deep=deep)) + + subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} + exprs1 = uxreplace(exprs, subs) + + return frozenset(compaccs | retrieve_terminals(exprs1, mode='unique', deep=deep) - + set(subs.values())) + + def retrieve_dimensions(exprs, mode='all', deep=False): """Shorthand to retrieve the dimensions in ``exprs``.""" return search(exprs, q_dimension, mode, 'dfs', deep) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index d875878d02..2bf901c3f4 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -14,6 +14,7 @@ __all__ = [ 'DAG', 'Bunch', + 'DefaultFrozenDict', 'DefaultOrderedDict', 'EnrichedTuple', 'MemoryEstimate', @@ -672,6 +673,42 @@ def __hash__(self): return self._hash +class DefaultFrozenDict(frozendict): + """ + An immutable mapper that returns a configured default value for missing + keys when accessed via ``obj[key]``. + + Unlike :class:`collections.defaultdict`, the mapping remains immutable and + missing-key access does not mutate internal state. The ``get`` method + preserves the standard dictionary semantics, defaulting to ``None`` unless + the caller provides an explicit fallback. + """ + + _sentinel = object() + + def __init__(self, *args, default=_sentinel, **kwargs): + self._default = default + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + try: + return self._dict[key] + except KeyError: + if self._default is self._sentinel: + raise + + if callable(self._default): + return self._default() + else: + return self._default + + def get(self, key, default=None): + return self._dict.get(key, default) + + def copy(self, **add_or_replace): + return self.__class__(self, default=self._default, **add_or_replace) + + class MemoryEstimate(frozendict): """ An immutable mapper for a memory estimate, providing the estimated memory diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index c10f5ea092..7ffc461591 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -1,9 +1,39 @@ from collections.abc import Callable, Hashable -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from itertools import tee from typing import TypeVar +from weakref import WeakKeyDictionary -__all__ = ['CacheInstances', 'memoized_func', 'memoized_generator', 'memoized_meth'] +__all__ = [ + 'CacheInstances', + 'cached_hash', + 'memoized_func', + 'memoized_generator', + 'memoized_meth', + 'memoized_weak_meth', + 'reuse_if_unchanged' +] + + +def cached_hash(func): + """ + Cache an immutable object's ``__hash__`` return value in ``_mhash``. + + Warning: avoid explicitly calling a superclass' cached ``__hash__`` on a + subclass instance, as that would stash the superclass hash in ``_mhash``. + + Warning: avoid using it on pickled objects. + """ + @wraps(func) + def wrapper(self): + try: + return self._mhash + except AttributeError: + ret = func(self) + self._mhash = ret + return ret + + return wrapper class memoized_func: @@ -19,9 +49,22 @@ class memoized_func: https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize """ - def __init__(self, func): + # Long-lived caches for process-global helpers, such as arch discovery. + _scope_persistent = 'persistent' + # Build-scoped caches that may retain compiler inputs during Operator construction. + _scope_build = 'build' + _scoped_caches = {} + + def __new__(cls, func=None, *, scope=None): + if func is None: + return lambda f: cls(f, scope=scope) + return super().__new__(cls) + + def __init__(self, func, *, scope=None): self.func = func + self.scope = scope or self._scope_persistent self.cache = {} + self._scoped_caches.setdefault(self.scope, set()).add(self) def __call__(self, *args, **kw): if not isinstance(args, Hashable): @@ -44,6 +87,18 @@ def __get__(self, obj, objtype): """Support instance methods.""" return partial(self.__call__, obj) + def clear(self): + self.cache.clear() + + @classmethod + def clear_scoped_caches(cls, scope): + for cache in cls._scoped_caches.get(scope, ()): + cache.clear() + + @classmethod + def clear_build_caches(cls): + cls.clear_scoped_caches(cls._scope_build) + class memoized_meth: """ @@ -86,11 +141,19 @@ def __call__(self, *args, **kw): cache = obj.__cache_meth except AttributeError: cache = obj.__cache_meth = {} - key = (self.func, args[1:], frozenset(kw.items())) + if kw: + key = (self.func, args[1:], frozenset(kw.items())) + else: + key = (self.func, args[1:]) + try: res = cache[key] except KeyError: res = cache[key] = self.func(*args, **kw) + except TypeError: + # Uncacheable, e.g. an unhashable item within ``args``. + return self.func(*args, **kw) + return res @@ -128,6 +191,54 @@ def __call__(self, *args, **kwargs): return result +def memoized_weak_meth(*, key=None, freeze=None, thaw=None): + """ + Cache a method result against its first argument using weak references. + + This is useful for visitors operating on temporary IR roots: the cache can + be shared across short-lived visitor instances without keeping those roots + alive. Only calls without extra arguments are cached; all other calls fall + back to the wrapped method. + + Parameters + ---------- + key : callable, optional + A callable receiving ``self`` and returning a hashable cache partition. + freeze : callable, optional + Convert the method result before storing it in the cache. + thaw : callable, optional + Convert the cached value before returning it to the caller. + """ + def decorator(func): + caches = {} + + @wraps(func) + def wrapper(self, o, *args, **kwargs): + if args or kwargs: + return func(self, o, *args, **kwargs) + + try: + partition = key(self) if key is not None else None + cache = caches.setdefault(partition, WeakKeyDictionary()) + ret = cache[o] + except KeyError: + ret = func(self, o) + if freeze is not None: + ret = freeze(ret) + cache[o] = ret + except TypeError: + return func(self, o) + + if thaw is not None: + return thaw(ret) + + return ret + + return wrapper + + return decorator + + # Describes the type of a subclass of CacheInstances InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True) @@ -154,6 +265,9 @@ def __init__(cls: type[InstanceType], *args) -> None: # type: ignore def __call__(cls: type[InstanceType], # type: ignore *args, **kwargs) -> InstanceType: + if cls._instance_cache_size == 0: + return super().__call__(*args, **kwargs) + args, kwargs = cls._preprocess_args(*args, **kwargs) return cls._instance_cache(*args, **kwargs) @@ -173,7 +287,7 @@ class CacheInstances(metaclass=CacheInstancesMeta): """ _instance_cache: Callable | None = None - _instance_cache_size: int = 128 + _instance_cache_size: int = 8192 @classmethod def _preprocess_args(cls, *args, **kwargs): @@ -189,3 +303,36 @@ def clear_caches() -> None: Clears all IR instance caches. """ CacheInstancesMeta.clear_caches() + + +def reuse_if_unchanged(fields): + """ + Decorator for wrapper-style constructors that should return the original + object when called as ``Cls(existing_obj, **same_metadata)``. + + The wrapped callable is assumed to be a classmethod-like constructor + receiving ``cls`` as first argument. The fast path triggers only when: + + * the constructor is called with exactly one positional argument; + * that argument is already an exact instance of ``cls``; + * any explicitly provided metadata fields are the same objects as the + corresponding attributes on the input object. + """ + def decorator(func): + @wraps(func) + def wrapper(cls, *args, **kwargs): + if len(args) == 1: + input_obj = args[0] + if type(input_obj) is cls: + names = getattr(cls, fields) if isinstance(fields, str) else fields + for name in names: + if name in kwargs and \ + kwargs[name] is not getattr(input_obj, name, None): + break + else: + return input_obj + return func(cls, *args, **kwargs) + + return wrapper + + return decorator diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 91b5bcdbf7..470be7e79e 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -1,6 +1,6 @@ import types from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from functools import reduce, wraps from itertools import chain, combinations, groupby, product, zip_longest from operator import attrgetter, mul @@ -10,6 +10,7 @@ __all__ = [ 'all_equal', + 'as_hashable', 'as_list', 'as_mapper', 'as_set', @@ -87,6 +88,28 @@ def as_tuple(item, type=None, length=None): return t +def as_hashable(item): + """ + Convert common containers into a hashable representation. + + Unknown unhashable objects fall back to identity, avoiding false cache hits. + """ + if isinstance(item, Mapping): + items = ((as_hashable(k), as_hashable(v)) for k, v in item.items()) + return tuple(sorted(items, key=repr)) + if isinstance(item, (tuple, list)): + return tuple(as_hashable(i) for i in item) + if isinstance(item, (set, frozenset)): + return tuple(sorted((as_hashable(i) for i in item), key=repr)) + + try: + hash(item) + except TypeError: + return (type(item), id(item)) + else: + return item + + def as_mapper(iterable, key=None, get=None): """ Rearrange an iterable into a dictionary of lists in which keys are diff --git a/devito/types/basic.py b/devito/types/basic.py index 22a87e848f..f80f23ba0c 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -498,6 +498,8 @@ def _C_name(self): @property def _C_ctype(self): + if self.dtype is None: + return CustomDtype('void') return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): diff --git a/devito/types/caching.py b/devito/types/caching.py index 742c0b3d33..9fbbcaa638 100644 --- a/devito/types/caching.py +++ b/devito/types/caching.py @@ -4,7 +4,7 @@ import sympy from sympy.core import cache -from devito.tools import safe_dict_copy +from devito.tools import memoized_func, safe_dict_copy __all__ = ['CacheManager', 'Cached', 'Uncached', '_SymbolCache'] @@ -175,6 +175,10 @@ def clear(cls, force=True): # SymPy 1.14 and later pass + # Drop compiler-scoped Python memoization that may still hold strong + # references to symbolic objects pending collection. + memoized_func.clear_build_caches() + # Take a copy of the dictionary so we can safely iterate over it # even if another thread is making changes cache_copied = safe_dict_copy(_SymbolCache) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 18d0d3609c..861ceb17d7 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -10,7 +10,7 @@ from devito import Constant, Eq, Function, Grid, Operator, configuration, exp, log, sin from devito.arch.compiler import CustomCompiler, GNUCompiler from devito.exceptions import InvalidOperator -from devito.ir.cgen.printer import BasePrinter +from devito.ir.cgen.printer import BasePrinter, get_printer from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C import CBB, CPrinter from devito.passes.iet.languages.openacc import AccBB, AccPrinter @@ -204,6 +204,18 @@ def test_math_functions(dtype: np.dtype[np.inexact], assert call_str in str(op) +def test_printer_registry() -> None: + default = get_printer(CPrinter, np.float32) + + assert get_printer(CPrinter, np.float32) is default + + float64 = get_printer(CPrinter, np.float64) + assert get_printer(CPrinter, np.float64) is float64 + + float16 = get_printer(CPrinter, np.float16) + assert get_printer(CPrinter, np.float16) is float16 + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: """ diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e84d0df5d8..a11aea7a5f 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -79,9 +79,18 @@ class TestDeviceID: CUDA_VISIBLE_DEVICES are correctly handled. """ + visible_device_envs = ( + 'CUDA_VISIBLE_DEVICES', + 'NVIDIA_VISIBLE_DEVICES', + 'ROCR_VISIBLE_DEVICES', + 'HIP_VISIBLE_DEVICES' + ) + @pytest.mark.parametrize('env_variables', [{"CUDA_VISIBLE_DEVICES": "1"}, {"CUDA_VISIBLE_DEVICES": "1,2"}, {"CUDA_VISIBLE_DEVICES": "1,0"}, + {"NVIDIA_VISIBLE_DEVICES": "1"}, + {"NVIDIA_VISIBLE_DEVICES": "1,2"}, {"ROCR_VISIBLE_DEVICES": "1"}, {"HIP_VISIBLE_DEVICES": " 1"}]) def test_visible_devices(self, env_variables): @@ -100,12 +109,21 @@ def test_visible_devices(self, env_variables): # All variants in parameterisation should yield deviceid 1 assert argmap1._physical_deviceid == 1 - # Check that physical deviceid is 0 when no environment variables set - op2 = Operator(eq) + def test_default_physical_deviceid(self): + """ + Test that the default physical device ID is 0 when no visible-device + environment variable is set. + """ + grid = Grid(shape=(10, 10)) + u = Function(name='u', grid=grid) + + eq = Eq(u, u+1) + + with switchenv({i: None for i in self.visible_device_envs}): + op2 = Operator(eq) - argmap2 = op2.arguments() - # Default physical deviceid expected to be 0 - assert argmap2._physical_deviceid == 0 + argmap2 = op2.arguments() + assert argmap2._physical_deviceid == 0 @pytest.mark.parallel(mode=2) @pytest.mark.parametrize('visible_devices', [ @@ -142,9 +160,10 @@ def test_visible_devices_mpi(self, visible_devices, mode): assert argmap1._physical_deviceid == expected # In default case, physical deviceid will equal rank - op2 = Operator(eq) - argmap2 = op2.arguments() - assert argmap2._physical_deviceid == rank + with switchenv({i: None for i in self.visible_device_envs}): + op2 = Operator(eq) + argmap2 = op2.arguments() + assert argmap2._physical_deviceid == rank def test_visible_devices_with_devito_deviceid(self): """Test interaction between CUDA_VISIBLE_DEVICES and DEVITO_DEVICEID""" @@ -184,8 +203,9 @@ def test_deviceid_per_rank(self, mode): op = Operator(Eq(u, u+1)) - argmap = op.arguments(deviceid=deviceid) - assert argmap._physical_deviceid == deviceid + with switchenv({i: None for i in self.visible_device_envs}): + argmap = op.arguments(deviceid=deviceid) + assert argmap._physical_deviceid == deviceid class TestCodeGeneration: diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index eda7351bb4..75b125d83d 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -887,7 +887,8 @@ def test_interp_complex(self, dtype): sc.coordinates.data[:] = [.5, .5, .5] fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) - fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + rng = np.random.RandomState(0) + fc.data[:] = rng.randn(*grid.shape) + 1j * rng.randn(*grid.shape) opC = Operator([sc.interpolate(expr=fc)], name="OpC") opC() @@ -903,7 +904,8 @@ def test_interp_complex_and_real(self, dtype): coordinates=sc.coordinates) fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) - fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + rng = np.random.RandomState(0) + fc.data[:] = rng.randn(*grid.shape) + 1j * rng.randn(*grid.shape) exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc)) opC = Operator(exprs, name="OpC") opC() diff --git a/tests/test_ir.py b/tests/test_ir.py index 16440ec54a..a805fb01cf 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -8,6 +8,7 @@ switchconfig ) from devito.ir.cgen import ccode +from devito.ir.clusters import Cluster, ClusterGroup from devito.ir.equations import LoweredEq from devito.ir.equations.algorithms import dimension_sort from devito.ir.iet import FindNodes, Iteration @@ -17,7 +18,8 @@ ) from devito.ir.support.guards import GuardOverflow from devito.ir.support.space import ( - Backward, Forward, Interval, IntervalGroup, IterationSpace, NullInterval + Backward, Forward, Interval, IntervalGroup, IterationInterval, IterationSpace, + NullInterval, null_ispace ) from devito.symbolics import DefFunction, FieldFromPointer from devito.tools import prod @@ -140,6 +142,12 @@ def test_vector_cmp(self, v_num, v_literal): assert v2 <= vs3 assert vs3 > v2 + def test_timedaccess_cached(self, fc, x, y): + ta0 = TimedAccess(fc[x, y], 'R', 0) + ta1 = TimedAccess(fc[x, y], 'R', 0, null_ispace) + + assert ta0 is ta1 + def test_iteration_instance_arithmetic(self, x, y, ii_num, ii_literal): """ Test arithmetic operations involving objects of type IterationInstance. @@ -359,6 +367,60 @@ def x(self, grid): def y(self, grid): return grid.dimensions[1] + def test_null_interval_cache_identity(self, x): + i0 = NullInterval(x) + i1 = NullInterval(x) + + assert i0 is i1 + + def test_interval_cache_identity(self, x): + i0 = Interval(x, -2, 2) + i1 = Interval(x, -2, 2) + + assert i0 is i1 + + def test_iteration_interval_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (xi,), Forward) + + assert i0 is i1 + + def test_iteration_interval_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (), Forward) + + assert i0 is not i1 + + def test_interval_group_cache_identity(self, x, y): + ig0 = IntervalGroup([Interval(x, -2, 2), Interval(y, -1, 1)], + relations=((x, y),), mode='partial') + ig1 = IntervalGroup((Interval(x, -2, 2), Interval(y, -1, 1)), + relations=((x, y),), mode='partial') + + assert ig0 is ig1 + + def test_iteration_space_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + + assert ispace0 is ispace1 + assert isinstance(ispace0[x], IterationInterval) + assert ispace0[x] is ispace1[x] + + def test_iteration_space_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = IterationSpace([Interval(x)], directions={x: Forward}) + + assert ispace0 is not ispace1 + def test_intervals_intersection(self, x, y): nullx = NullInterval(x) @@ -788,6 +850,18 @@ def test_indirect_access(self): v = scope.d_flow.pop() assert v.function is s1 + def test_ireq_function_views_indirect_indices(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + u = Function(name='u', grid=grid) + f = Function(name='f', grid=grid) + a = Function(name='a', grid=grid) + + expr = LoweredEq(Eq(u, f[a[x]])) + + assert set(expr.read_functions) == {f, a} + def test_array_shared(self): grid = Grid(shape=(4, 4)) x, y = grid.dimensions @@ -1088,6 +1162,25 @@ def test_dimension_sort(self, expr, expected): assert list(dimension_sort(expr)) == eval(expected) +class TestClusterGroup: + + def test_eq_hash_include_ispace(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + f = Function(name='f', grid=grid) + cluster = Cluster(Eq(f[x], 1)) + + ispace0 = IterationSpace([Interval(x, 0, 0)], directions={x: Forward}) + ispace1 = IterationSpace([Interval(x, 0, 0)], directions={x: Backward}) + + cgroup0 = ClusterGroup((cluster,), ispace0) + cgroup1 = ClusterGroup((cluster,), ispace1) + + assert cgroup0 != cgroup1 + assert len({cgroup0, cgroup1}) == 2 + + class TestGuards: def test_guard_overflow(self): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index aab8502a20..ba93e08470 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -18,7 +18,7 @@ INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite, FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue, SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, - retrieve_indexed, uxreplace + retrieve_indexed, subs_if_composite, uxreplace, xreplace_indices ) from devito.tools import CustomDtype, as_tuple, dtypes_vector_mapper from devito.types import ( @@ -848,6 +848,18 @@ def test_is_on_grid(): assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate)) +def test_retrieve_functions_mixed_carriers(): + grid = Grid((10,)) + x = grid.dimensions[0] + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + expr = f + FIndexed(g.base, x) + + assert retrieve_functions(expr, mode='unique') == {f, g} + + @pytest.mark.parametrize('expr,expected', [ ('f[x+2]*g[x+4] + f[x+3]*g[x+5] + f[x+4] + f[x+1]', ['f[x+2]', 'g[x+4]', 'f[x+3]', 'g[x+5]', 'f[x+1]', 'f[x+4]']), @@ -905,6 +917,55 @@ def test_expressions(self, expr, subs, expected): assert uxreplace(eval(expr), eval(subs)) == eval(expected) + def test_uxreplace_reuses_empty_substitution(self): + grid = Grid(shape=(4, 4)) + f = Function(name='f', grid=grid) + expr = f.indexify() + 1 + + assert uxreplace(expr, {}) is expr + + def test_subs_if_composite_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert subs_if_composite(exprs, {}) is exprs + assert subs_if_composite(exprs, {g[x, y]: f[x, y]}) is exprs + assert subs_if_composite(exprs, {g[x, y] + 1: f[x, y]}) is exprs + + processed = subs_if_composite(exprs, {f[x, y]: g[x, y]}) + + assert processed is not exprs + assert processed[0] is not exprs[0] + + def test_pow_to_mul_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert pow_to_mul(exprs) is exprs + assert pow_to_mul([exprs[0]])[0] is exprs[0] + + processed = pow_to_mul((Eq(f[x, y], f[x, y]**2),)) + + assert processed is not exprs + + def test_xreplace_indices_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + z = Dimension(name='z') + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert xreplace_indices(exprs, {z: z + 1}) is exprs + assert xreplace_indices(exprs, {x: x + 1}) is not exprs + def test_custom_reconstructable(self): class MyDefFunction(DefFunction): diff --git a/tests/test_tools.py b/tests/test_tools.py index 0b06883e78..2996844ee5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -7,8 +7,8 @@ from devito import Eq, Operator, switchenv from devito.tools import ( - CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, - toposort, transitive_closure + CacheInstances, DefaultFrozenDict, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, + filter_ordered, memoized_meth, memoized_weak_meth, toposort, transitive_closure ) from devito.types.basic import Symbol @@ -61,6 +61,93 @@ def test_transitive_closure(): assert mapper == {a: d, b: d, c: d, f: e} +def test_memoized_meth(): + + class Obj: + + def __init__(self): + self.calls = 0 + + @memoized_meth + def f(self, x=None): + self.calls += 1 + return x + + obj = Obj() + + assert obj.f(1) == 1 + assert obj.f(1) == 1 + assert obj.calls == 1 + + assert obj.f(x=2) == 2 + assert obj.f(x=2) == 2 + assert obj.calls == 2 + + assert obj.f([3]) == [3] + assert obj.f([3]) == [3] + assert obj.calls == 4 + + +def test_memoized_weak_meth(): + + class Root: + pass + + class Obj: + + def __init__(self, mode): + self.mode = mode + self.calls = 0 + + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def f(self, root): + self.calls += 1 + return [self.mode] + + root = Root() + obj0 = Obj('a') + obj1 = Obj('a') + obj2 = Obj('b') + + ret = obj0.f(root) + ret.append('mutated') + + assert obj1.f(root) == ['a'] + assert obj0.calls == 1 + assert obj1.calls == 0 + + assert obj2.f(root) == ['b'] + assert obj2.calls == 1 + + assert obj0.f([]) == ['a'] + assert obj0.f([]) == ['a'] + assert obj0.calls == 3 + + +def test_default_frozen_dict(): + mapper = DefaultFrozenDict({'a': 'b'}, default='c') + + assert mapper['a'] == 'b' + assert mapper['d'] == 'c' + assert mapper.get('d') is None + assert mapper.get('d', 'e') == 'e' + + copied = mapper.copy(c='d') + assert copied['c'] == 'd' + assert copied['e'] == 'c' + + +def test_default_frozen_dict_factory(): + mapper = DefaultFrozenDict(default=lambda: []) + + v0 = mapper[a] + v1 = mapper[b] + + assert v0 == [] + assert v1 == [] + assert v0 is not v1 + + def test_loops_in_transitive_closure(): a = Symbol('a') b = Symbol('b') @@ -212,6 +299,31 @@ def __init__(self, value: int): cache_size = Object._instance_cache.cache_info()[-1] assert cache_size == 0 + def test_uncached_subclass_bypasses_parent_preprocess(self): + """ + Tests that an uncached subclass does not inherit its parent's + preprocessing contract. + """ + class Parent(CacheInstances): + @classmethod + def _preprocess_args(cls, value): + return (value + 1,), {} + + def __init__(self, value: int): + self.value = value + + class Child(Parent): + _instance_cache_size = 0 + + def __init__(self, left: int, right: int): + self.value = (left, right) + + obj0 = Child(1, 2) + obj1 = Child(1, 2) + + assert obj0.value == (1, 2) + assert obj0 is not obj1 + def test_switchenv(): # Save previous environment @@ -224,5 +336,11 @@ def test_switchenv(): # Check a temporary variable is unset inside the context manager assert os.environ.get('TEST_VAR') is None + # Check an existing variable can be temporarily unset inside the context manager + with switchenv({'TEST_VAR_UNSET': 'foo'}): + with switchenv({'TEST_VAR_UNSET': None}): + assert os.environ.get('TEST_VAR_UNSET') is None + assert os.environ['TEST_VAR_UNSET'] == 'foo' + # Make sure the switchenv does not persist to verify switchenv works as intended assert dict(os.environ) == previous_environ diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 06eb933351..b5d12f81d5 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -6,8 +6,8 @@ from devito.ir.equations import DummyEq from devito.ir.iet import ( Block, Call, Callable, Conditional, Expression, FindApplications, FindNodes, - FindSections, FindSymbols, IsPerfectIteration, Iteration, MapNodes, Transformer, - printAST + FindSections, FindSymbols, FindWithin, IsPerfectIteration, Iteration, MapNodes, + Transformer, Uxreplace, printAST ) from devito.types import Array, SpaceDimension, Symbol @@ -210,6 +210,15 @@ def test_find_sections(exprs, block1, block2, block3): assert len(found[2]) == 1 +def test_find_within_not_cached_like_findnodes(block3): + expr0 = FindWithin(Expression, block3.nodes[0], block3.nodes[1]).visit(block3) + expr1 = FindWithin(Expression, block3.nodes[1], block3.nodes[2]).visit(block3) + + assert len(expr0) == 3 + assert len(expr1) == 3 + assert expr0 != expr1 + + def test_is_perfect_iteration(block1, block2, block3, block4): checker = IsPerfectIteration() @@ -249,6 +258,14 @@ def test_transformer_wrap(exprs, block1, block2, block3): assert "a[i] = a[i] + b[i] + 5.0F;" in newcode +def test_transformer_reuses_untouched_node(block1): + assert Transformer({}).visit(block1) is block1 + + +def test_uxreplace_reuses_untouched_node(block1): + assert Uxreplace({}).visit(block1) is block1 + + def test_transformer_replace(exprs, block1, block2, block3): """Basic transformer test that replaces an expression""" line1 = '// Replaced expression'