# Source code for numpoly.utils.cross_truncation

"""Truncation rules for indices."""
from __future__ import annotations

import numpy
import numpy.typing


def cross_truncate(
    indices: numpy.typing.ArrayLike,
    bound: numpy.typing.ArrayLike,
    norm: float,
) -> numpy.ndarray:
    r"""
    Truncate indices using the L_p norm.

    .. math::
        L_p(x) = (\sum_i |x_i/b_i|^p )^{1/p} \leq 1

    where :math:`b_i` are bounds that each :math:`x_i` should follow.

    Args:
        indices:
            Indices to be truncated, one index per row; ``indices.shape[1]``
            is used as the number of dimensions.
        bound:
            The bound which the indices can not be larger than. Either a
            scalar shared by all dimensions, or one bound per dimension.
            A zero bound only admits zero in that dimension, and any
            negative bound rejects every index.
        norm:
            The `p` in the `L_p`-norm. Support includes both `L_0` and
            `L_inf`. Must be non-negative.

    Returns:
        Boolean indices to ``indices`` with True for each index where the
        truncation criteria holds.

    Example:
        >>> indices = numpy.array(numpy.mgrid[:10, :10]).reshape(2, -1).T
        >>> indices[cross_truncate(indices, 2, norm=0.)].T
        array([[0, 0, 0, 1, 2],
               [0, 1, 2, 0, 0]])
        >>> indices[cross_truncate(indices, 2, norm=1.)].T
        array([[0, 0, 0, 1, 1, 2],
               [0, 1, 2, 0, 1, 0]])
        >>> indices[cross_truncate(indices, [0, 1], norm=1.)].T
        array([[0, 0],
               [0, 1]])

    """
    assert norm >= 0, "negative L_p norm not allowed"
    indices = numpy.asarray(indices)
    # ``numpy.asfarray`` was removed in NumPy 2.0; ``asarray(..., dtype=float)``
    # performs the same float conversion on every NumPy version.
    bound_ = numpy.broadcast_to(
        numpy.asarray(bound, dtype=float).ravel(), (indices.shape[1],)
    )
    # Tolerance added to the threshold 1 so that indices landing on the
    # boundary are kept despite floating point round-off.
    nudge_factor = 1e-12 * indices.shape[1]

    if numpy.any(bound_ < 0):
        return numpy.zeros((len(indices),), dtype=bool)

    if numpy.any(bound_ == 0):
        # Zero-bound dimensions only admit zero; the remaining dimensions are
        # truncated recursively with their own (non-zero) bounds.
        out = numpy.all(indices[:, bound_ == 0] == 0, axis=-1)
        if numpy.any(bound_):
            out &= cross_truncate(
                indices[:, bound_ != 0], bound_[bound_ != 0], norm=norm
            )
        return out

    if norm == 0:
        # L_0: at most one positive entry, and every entry within its bound.
        out = numpy.sum(indices > 0, axis=-1) <= 1 + nudge_factor
        out[numpy.any(indices > bound_, axis=-1)] = False
    elif norm == numpy.inf:
        out = numpy.max(indices / bound_, axis=-1) <= 1 + nudge_factor
    else:
        out = (
            numpy.sum((indices / bound_) ** norm, axis=-1) ** (1.0 / norm)
            <= 1 + nudge_factor
        )

    # Sanity check: the all-zero index must always pass the truncation.
    assert numpy.all(out[numpy.all(indices == 0, axis=-1)])
    return out