Source code for numpoly.array_function.apply_along_axis

"""Apply a function to 1-D slices along the given axis."""
from __future__ import annotations
from functools import wraps
from typing import Any, Callable, List

import numpy
import numpoly

from ..baseclass import ndpoly, PolyLike
from ..dispatch import implements


@implements(numpy.apply_along_axis)
def apply_along_axis(
    func1d: Callable[[PolyLike], PolyLike],
    axis: int,
    arr: PolyLike,
    *args: Any,
    **kwargs: Any,
) -> ndpoly:
    """
    Apply a function to 1-D slices along the given axis.

    Execute `func1d(a, *args, **kwargs)` where `func1d` operates on 1-D
    arrays and `a` is a 1-D slice of `arr` along `axis`.

    This is equivalent to (but faster than) the following use of `ndindex` and
    `s_`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of indices::

        Ni, Nk = arr.shape[:axis], arr.shape[axis+1:]
        for ii in ndindex(Ni):
            for kk in ndindex(Nk):
                f = func1d(arr[ii+s_[:,]+kk])
                Nj = f.shape
                for jj in ndindex(Nj):
                    out[ii+jj+kk] = f[jj]

    Equivalently, eliminating the inner loop, this can be expressed as::

        Ni, Nk = arr.shape[:axis], arr.shape[axis+1:]
        for ii in ndindex(Ni):
            for kk in ndindex(Nk):
                out[ii+s_[...,]+kk] = func1d(arr[ii+s_[:,]+kk])

    Args:
        func1d:
            This function should accept 1-D arrays. It is applied to 1-D slices
            of `arr` along the specified axis.
        axis:
            Axis along which `arr` is sliced.
        arr:
            Input array.
        args:
            Additional arguments to `func1d`.
        kwargs:
            Additional named arguments to `func1d`.

    Return:
        The output array. The shape of `out` is identical to the shape of
        `arr`, except along the `axis` dimension. This axis is removed, and
        replaced with new dimensions equal to the shape of the return value of
        `func1d`. So if `func1d` returns a scalar `out` will have one fewer
        dimensions than `arr`.

    Example:
        >>> q0, q1 = numpoly.variable(2)
        >>> b = numpoly.polynomial([[1, 2, 3*q0],
        ...                         [3, 6*q1, 6],
        ...                         [2, 7, 9]])
        >>> numpoly.apply_along_axis(numpoly.mean, 0, b)
        polynomial([2.0, 2.0*q1+3.0, q0+5.0])
        >>> numpoly.apply_along_axis(numpoly.mean, 1, b)
        polynomial([q0+1.0, 2.0*q1+3.0, 6.0])

    """
    # Normalise the input first so the wrapper below can rely on `arr`
    # exposing `indeterminants`, `allocation` and `values`.
    arr = numpoly.aspolynomial(arr)

    # Results of every `func1d` call, in call order. The integer placeholders
    # handed back to numpy are indices into this list.
    collection: List[ndpoly] = []

    @wraps(func1d)
    def wrapper_func(array: numpy.ndarray) -> numpy.ndarray:
        """Wrap func1d so numpy only ever sees integer placeholders."""
        # Align indeterminants in case slicing changed them.
        array = numpoly.polynomial(
            array, names=arr.indeterminants, allocation=arr.allocation
        )
        array, _ = numpoly.align.align_indeterminants(
            array, arr.indeterminants
        )

        # Evaluate function.
        out = func1d(array, *args, **kwargs)

        # Restore indeterminants in case func1d changed them.
        out, _ = numpoly.align.align_indeterminants(out, arr.indeterminants)

        # Return dummy index integer value that will be replaced with
        # polynomials afterwards.
        ret_val = len(collection) * numpy.ones(out.shape, dtype=int)
        collection.append(out)
        return ret_val

    # Let numpy drive the iteration and build the output layout; `out` holds,
    # for every output cell, the index of the `collection` entry it came from.
    out = numpy.apply_along_axis(wrapper_func, axis=axis, arr=arr.values)

    # Align exponents so every partial result shares one set of terms.
    polynomials = numpoly.align.align_exponents(*collection)
    dtype = numpoly.result_type(*polynomials)

    # Store results into new array.
    ret_val = numpoly.ndpoly(
        exponents=polynomials[0].exponents,
        shape=out.shape,
        names=polynomials[0].indeterminants,
        dtype=dtype,
    ).values
    for idx, polynomial in enumerate(polynomials):
        # Boolean-mask assignment only accepts 0-D or 1-D values and fills the
        # selected cells in C order. Flattening keeps scalar and 1-D results
        # unchanged and lets results with more than one dimension be written
        # instead of raising ValueError.
        ret_val[out == idx] = numpy.ravel(polynomial.values)

    return numpoly.polynomial(
        ret_val,
        dtype=dtype,
        names=polynomials[0].indeterminants,
        allocation=polynomials[0].allocation,
    )