# Source code for numpoly.array_function.apply_over_axes
"""Apply a function repeatedly over multiple axes."""
from __future__ import annotations
from functools import wraps
from typing import Callable
import numpy
import numpy.typing
import numpoly
from ..baseclass import ndpoly, PolyLike
from ..dispatch import implements
# [docs]
@implements(numpy.apply_over_axes)
def apply_over_axes(
    func: Callable,
    a: PolyLike,
    axes: numpy.typing.ArrayLike,
) -> ndpoly:
    """
    Apply a function repeatedly over multiple axes.

    `func` is called as `res = func(a, axis=axis)`, where `axis` is the first
    element of `axes`. The result `res` of the function call must have either
    the same dimensions as `a` or one less dimension. If `res` has one less
    dimension than `a`, a dimension is inserted before `axis`. The call to
    `func` is then repeated for each axis in `axes`, with `res` as the first
    argument.

    Args:
        func:
            This function must take two arguments, `func(a, axis)`. The
            second argument is passed by keyword, as `axis=...`.
        a:
            Input array.
        axes:
            Axes over which `func` is applied; the elements must be integers.

    Return:
        The output array. The number of dimensions is the same as `a`, but
        the shape can be different. This depends on whether `func` changes
        the shape of its output with respect to its input. If `axes` is
        empty, `a` is returned unchanged, as a polynomial.

    Example:
        >>> polynomial = numpy.arange(24).reshape(2, 4, 3)
        >>> polynomial = polynomial*numpoly.variable(3)
        >>> numpoly.apply_over_axes(
        ...     func=numpoly.sum, a=polynomial, axes=1)
        polynomial([[[18*q0, 22*q1, 26*q2]],
        <BLANKLINE>
                    [[66*q0, 70*q1, 74*q2]]])
        >>> numpoly.apply_over_axes(
        ...     func=numpoly.sum, a=polynomial, axes=[0, 2])
        polynomial([[[16*q2+14*q1+12*q0],
                     [22*q2+20*q1+18*q0],
                     [28*q2+26*q1+24*q0],
                     [34*q2+32*q1+30*q0]]])

    """
    poly = numpoly.aspolynomial(a)
    indeterminants = poly.indeterminants

    @wraps(func)
    def wrapper_func(array, axis):
        """Wrap func function."""
        # On the first call numpy hands over the raw structured array
        # (`poly.values`), which carries no variable names. Later calls
        # receive the previous ndpoly result. Either way, re-attach and
        # align the original indeterminants before calling `func`.
        array = numpoly.polynomial(array, names=indeterminants)
        array, _ = numpoly.align.align_indeterminants(array, indeterminants)
        # Evaluate function
        out = func(array, axis=axis)
        # Restore indeterminants in case func changed them.
        out, _ = numpoly.align.align_indeterminants(out, indeterminants)
        return out

    # Pass the raw `values` rather than the ndpoly itself. This function is
    # registered through `implements` as the ndpoly handler for
    # `numpy.apply_over_axes`, so passing `poly` would presumably dispatch
    # straight back here.
    out = numpy.apply_over_axes(wrapper_func, a=poly.values, axes=axes)

    # With empty `axes`, numpy never calls `wrapper_func` and returns the
    # plain structured array. Convert it back so the declared `ndpoly`
    # return type holds.
    if not isinstance(out, ndpoly):
        out = numpoly.polynomial(out, names=indeterminants)
    return out