# Source code for numpoly.array_function.dsplit

"""Split array into multiple sub-arrays along the 3rd axis (depth)."""
from __future__ import annotations
from typing import List

import numpy
import numpy.typing
import numpoly

from ..baseclass import ndpoly, PolyLike
from ..dispatch import implements


@implements(numpy.dsplit)
def dsplit(
    ary: PolyLike,
    indices_or_sections: numpy.typing.ArrayLike,
) -> List[ndpoly]:
    """
    Split array into multiple sub-arrays along the 3rd axis (depth).

    Thin wrapper around `numpy.dsplit`: the polynomial's ``values`` array is
    split along the third axis, and every piece is turned back into a
    polynomial using the same indeterminants and allocation as the input.
    `dsplit` is equivalent to `split` with ``axis=2``; please refer to the
    `split` documentation for the full splitting rules. The array must have
    at least three dimensions.

    Args:
        ary:
            Polynomial, or anything `numpoly.aspolynomial` can convert into
            one, to divide into sub-arrays.
        indices_or_sections:
            Forwarded unchanged to `numpy.dsplit`: either an integer N,
            dividing the third axis into N equal parts, or a sequence of
            indices marking where along the third axis to cut.

    Returns:
        The sub-polynomials, ordered along the third axis.

    Raises:
        Errors raised by `numpy.dsplit` propagate unchanged (per the NumPy
        documentation, e.g. ``ValueError`` when the input has fewer than
        three dimensions or an integer N does not divide the axis evenly).

    Example:
        >>> poly = numpoly.monomial(8).reshape(2, 2, 2)
        >>> poly
        polynomial([[[1, q0],
                     [q0**2, q0**3]],
        <BLANKLINE>
                    [[q0**4, q0**5],
                     [q0**6, q0**7]]])
        >>> part1, part2 = numpoly.dsplit(poly, 2)
        >>> part1
        polynomial([[[1],
                     [q0**2]],
        <BLANKLINE>
                    [[q0**4],
                     [q0**6]]])
        >>> part2
        polynomial([[[q0],
                     [q0**3]],
        <BLANKLINE>
                    [[q0**5],
                     [q0**7]]])

    """
    poly = numpoly.aspolynomial(ary)
    pieces = numpy.dsplit(poly.values, indices_or_sections)

    # Each piece is rebuilt with the parent's indeterminants and allocation,
    # so every sub-polynomial is expressed in the same variables as the input.
    names = poly.indeterminants
    allocation = poly.allocation
    output: List[ndpoly] = []
    for piece in pieces:
        output.append(
            numpoly.polynomial(piece, names=names, allocation=allocation)
        )
    return output