Book follow-up

Tykayn 2025-08-30 18:14:14 +02:00 committed by tykayn
parent 70a5c3465c
commit cffb31c1ef
12198 changed files with 2562132 additions and 35 deletions


@@ -0,0 +1,161 @@
# This file is generated by SciPy's build process
# It contains system_info results at the time of building this package.
from enum import Enum
__all__ = ["show"]
_built_with_meson = True
class DisplayModes(Enum):
    stdout = "stdout"
    dicts = "dicts"


def _cleanup(d):
    """
    Removes empty values in a `dict` recursively
    This ensures we remove values that Meson could not provide to CONFIG
    """
    if isinstance(d, dict):
        return { k: _cleanup(v) for k, v in d.items() if v != '' and _cleanup(v) != '' }
    else:
        return d
CONFIG = _cleanup(
{
"Compilers": {
"c": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"cc",
"args": r"",
"linker args": r"",
},
"cython": {
"name": r"cython",
"linker": r"cython",
"version": r"3.1.2",
"commands": r"cython",
"args": r"",
"linker args": r"",
},
"c++": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"c++",
"args": r"",
"linker args": r"",
},
"fortran": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"gfortran",
"args": r"",
"linker args": r"",
},
"pythran": {
"version": r"0.18.0",
"include directory": r"../../tmp/build-env-0fmy4m5d/lib/python3.13/site-packages/pythran"
},
},
"Machine Information": {
"host": {
"cpu": r"x86_64",
"family": r"x86_64",
"endian": r"little",
"system": r"linux",
},
"build": {
"cpu": r"x86_64",
"family": r"x86_64",
"endian": r"little",
"system": r"linux",
},
"cross-compiled": bool("False".lower().replace('false', '')),
},
"Build Dependencies": {
"blas": {
"name": "scipy-openblas",
"found": bool("True".lower().replace('false', '')),
"version": "0.3.28",
"detection method": "pkgconfig",
"include directory": r"/opt/_internal/cpython-3.13.5/lib/python3.13/site-packages/scipy_openblas32/include",
"lib directory": r"/opt/_internal/cpython-3.13.5/lib/python3.13/site-packages/scipy_openblas32/lib",
"openblas configuration": r"OpenBLAS 0.3.28 DYNAMIC_ARCH NO_AFFINITY Haswell MAX_THREADS=64",
"pc file directory": r"/project",
},
"lapack": {
"name": "scipy-openblas",
"found": bool("True".lower().replace('false', '')),
"version": "0.3.28",
"detection method": "pkgconfig",
"include directory": r"/opt/_internal/cpython-3.13.5/lib/python3.13/site-packages/scipy_openblas32/include",
"lib directory": r"/opt/_internal/cpython-3.13.5/lib/python3.13/site-packages/scipy_openblas32/lib",
"openblas configuration": r"OpenBLAS 0.3.28 DYNAMIC_ARCH NO_AFFINITY Haswell MAX_THREADS=64",
"pc file directory": r"/project",
},
"pybind11": {
"name": "pybind11",
"version": "3.0.0",
"detection method": "config-tool",
"include directory": r"unknown",
},
},
"Python Information": {
"path": r"/tmp/build-env-0fmy4m5d/bin/python",
"version": "3.13",
},
}
)
def _check_pyyaml():
    import yaml
    return yaml


def show(mode=DisplayModes.stdout.value):
    """
    Show libraries and system information on which SciPy was built
    and is being used

    Parameters
    ----------
    mode : {`'stdout'`, `'dicts'`}, optional.
        Indicates how to display the config information.
        `'stdout'` prints to console, `'dicts'` returns a dictionary
        of the configuration.

    Returns
    -------
    out : {`dict`, `None`}
        If mode is `'dicts'`, a dict is returned, else None

    Notes
    -----
    1. The `'stdout'` mode will give more readable
       output if ``pyyaml`` is installed

    """
    if mode == DisplayModes.stdout.value:
        try:  # Non-standard library, check import
            yaml = _check_pyyaml()
            print(yaml.dump(CONFIG))
        except ModuleNotFoundError:
            import warnings
            import json
            warnings.warn("Install `pyyaml` for better output", stacklevel=1)
            print(json.dumps(CONFIG, indent=2))
    elif mode == DisplayModes.dicts.value:
        return CONFIG
    else:
        raise AttributeError(
            f"Invalid `mode`, use one of: {', '.join([e.value for e in DisplayModes])}"
        )
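# Hedged usage sketch (editorial example, not part of the generated file): the public
# entry point is exposed as `scipy.show_config`; `mode="dicts"` returns the CONFIG
# mapping above so individual entries can be inspected programmatically.
if __name__ == "__main__":  # example only
    cfg = show(mode="dicts")
    print(cfg["Compilers"]["c"]["name"])                 # "gcc" for this build
    print(cfg["Build Dependencies"]["blas"]["version"])  # "0.3.28" for this build
    show()  # pretty-printed report (YAML if pyyaml is installed, JSON otherwise)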


@@ -0,0 +1,138 @@
"""
SciPy: A scientific computing package for Python
================================================
Documentation is available in the docstrings and
online at https://docs.scipy.org/doc/scipy/
Subpackages
-----------
::
cluster --- Vector Quantization / Kmeans
constants --- Physical and mathematical constants and units
datasets --- Dataset methods
differentiate --- Finite difference differentiation tools
fft --- Discrete Fourier transforms
fftpack --- Legacy discrete Fourier transforms
integrate --- Integration routines
interpolate --- Interpolation Tools
io --- Data input and output
linalg --- Linear algebra routines
ndimage --- N-D image package
odr --- Orthogonal Distance Regression
optimize --- Optimization Tools
signal --- Signal Processing Tools
sparse --- Sparse Matrices
spatial --- Spatial data structures and algorithms
special --- Special functions
stats --- Statistical Functions
Public API in the main SciPy namespace
--------------------------------------
::
__version__ --- SciPy version string
LowLevelCallable --- Low-level callback function
show_config --- Show scipy build configuration
test --- Run scipy unittests
"""
import importlib as _importlib
from numpy import __version__ as __numpy_version__
try:
    from scipy.__config__ import show as show_config
except ImportError as e:
    msg = """Error importing SciPy: you cannot import SciPy while
    being in scipy source directory; please exit the SciPy source
    tree first and relaunch your Python interpreter."""
    raise ImportError(msg) from e
from scipy.version import version as __version__
# Allow distributors to run custom init code
from . import _distributor_init
del _distributor_init
from scipy._lib import _pep440
# In maintenance branch, change to np_maxversion N+3 if numpy is at N
np_minversion = '1.25.2'
np_maxversion = '2.6.0'
if (_pep440.parse(__numpy_version__) < _pep440.Version(np_minversion) or
        _pep440.parse(__numpy_version__) >= _pep440.Version(np_maxversion)):
    import warnings
    warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
                  f" is required for this version of SciPy (detected "
                  f"version {__numpy_version__})",
                  UserWarning, stacklevel=2)
del _pep440
# This is the first import of an extension module within SciPy. If there's
# a general issue with the install, such that extension modules are missing
# or cannot be imported, this is where we'll get a failure - so give an
# informative error message.
try:
    from scipy._lib._ccallback import LowLevelCallable
except ImportError as e:
    msg = "The `scipy` install you are using seems to be broken, " + \
          "(extension modules cannot be imported), " + \
          "please try reinstalling."
    raise ImportError(msg) from e
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
submodules = [
'cluster',
'constants',
'datasets',
'differentiate',
'fft',
'fftpack',
'integrate',
'interpolate',
'io',
'linalg',
'ndimage',
'odr',
'optimize',
'signal',
'sparse',
'spatial',
'special',
'stats'
]
__all__ = submodules + [
'LowLevelCallable',
'test',
'show_config',
'__version__',
]
def __dir__():
    return __all__


def __getattr__(name):
    if name in submodules:
        return _importlib.import_module(f'scipy.{name}')
    else:
        try:
            return globals()[name]
        except KeyError:
            raise AttributeError(
                f"Module 'scipy' has no attribute '{name}'"
            )
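# Hedged illustration (editorial example): the module-level `__getattr__` above makes
# subpackage imports lazy -- `import scipy` does not load `scipy.linalg`; the
# subpackage is typically only imported on first attribute access:
#
#     import sys, scipy
#     'scipy.linalg' in sys.modules    # usually False right after `import scipy`
#     scipy.linalg                     # __getattr__ -> import_module('scipy.linalg')
#     'scipy.linalg' in sys.modules    # True afterwards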


@@ -0,0 +1,18 @@
""" Distributor init file
Distributors: you can replace the contents of this file with your own custom
code to support particular distributions of SciPy.
For example, this is a good place to put any checks for hardware requirements
or BLAS/LAPACK library initialization.
The SciPy standard source distribution will not put code in this file beyond
the try-except import of `_distributor_init_local` (which is not part of a
standard source distribution), so you can safely replace this file with your
own version.
"""
try:
    from . import _distributor_init_local  # noqa: F401
except ImportError:
    pass


@@ -0,0 +1,14 @@
"""
Module containing private utility functions
===========================================
The ``scipy._lib`` namespace is empty (for now). Tests for all
utilities in submodules of ``_lib`` can be run with::
from scipy import _lib
_lib.test()
"""
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester


@@ -0,0 +1,931 @@
"""Utility functions to use Python Array API compatible libraries.
For the context about the Array API see:
https://data-apis.org/array-api/latest/purpose_and_scope.html
The SciPy use case of the Array API is described on the following page:
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
"""
import contextlib
import dataclasses
import functools
import os
import textwrap
from collections.abc import Generator, Iterable, Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from types import ModuleType
from typing import Any, Literal, TypeAlias
import numpy as np
import numpy.typing as npt
from scipy._lib import array_api_compat
from scipy._lib.array_api_compat import (
is_array_api_obj,
is_lazy_array,
size as xp_size,
numpy as np_compat,
device as xp_device,
is_numpy_namespace as is_numpy,
is_cupy_namespace as is_cupy,
is_torch_namespace as is_torch,
is_jax_namespace as is_jax,
is_dask_namespace as is_dask,
is_array_api_strict_namespace as is_array_api_strict
)
from scipy._lib._sparse import issparse
from scipy._lib._docscrape import FunctionDoc
__all__ = [
'_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
'default_xp', 'eager_warns', 'is_lazy_array', 'is_marray',
'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
'xp_copy', 'xp_device', 'xp_ravel', 'xp_size',
'xp_unsupported_param_msg', 'xp_vector_norm', 'xp_capabilities',
'xp_result_type', 'xp_promote'
]
# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
_GLOBAL_CONFIG = {
"SCIPY_ARRAY_API": SCIPY_ARRAY_API,
"SCIPY_DEVICE": SCIPY_DEVICE,
}
Array: TypeAlias = Any # To be changed to a Protocol later (see array-api#589)
ArrayLike: TypeAlias = Array | npt.ArrayLike
def _compliance_scipy(arrays: Iterable[ArrayLike]) -> Iterator[Array]:
"""Raise exceptions on known-bad subclasses. Discard 0-dimensional ArrayLikes
and convert 1+-dimensional ArrayLikes to numpy.
The following subclasses are not supported and raise an error:
- `numpy.ma.MaskedArray`
- `numpy.matrix`
- NumPy arrays which do not have a boolean or numerical dtype
- Any array-like which is neither array API compatible nor coercible by NumPy
- Any array-like which is coerced by NumPy to an unsupported dtype
"""
for array in arrays:
if array is None:
continue
# this comes from `_util._asarray_validated`
if issparse(array):
msg = ('Sparse arrays/matrices are not supported by this function. '
'Perhaps one of the `scipy.sparse.linalg` functions '
'would work instead.')
raise ValueError(msg)
if isinstance(array, np.ma.MaskedArray):
raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
if isinstance(array, np.matrix):
raise TypeError("Inputs of type `numpy.matrix` are not supported.")
if isinstance(array, np.ndarray | np.generic):
dtype = array.dtype
if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
raise TypeError(f"An argument has dtype `{dtype!r}`; "
f"only boolean and numerical dtypes are supported.")
if is_array_api_obj(array):
yield array
else:
try:
array = np.asanyarray(array)
except TypeError:
raise TypeError("An argument is neither array API compatible nor "
"coercible by NumPy.")
dtype = array.dtype
if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
message = (
f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
f"only boolean and numerical dtypes are supported."
)
raise TypeError(message)
# Ignore 0-dimensional arrays, coherently with array-api-compat.
# Raise if there are 1+-dimensional array-likes mixed with non-numpy
# Array API objects.
if array.ndim:
yield array
def _check_finite(array: Array, xp: ModuleType) -> None:
"""Check for NaNs or Infs."""
if not xp.all(xp.isfinite(array)):
msg = "array must not contain infs or NaNs"
raise ValueError(msg)
def array_namespace(*arrays: Array) -> ModuleType:
    """Get the array API compatible namespace for the arrays xs.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
       dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
    2. `_compliance_scipy` raises exceptions on known-bad subclasses. See
       its definition for more details.

    When the global switch is False, it defaults to the `numpy` namespace.
    In that case, there is no compliance check. This is a convenience to
    ease the adoption. Otherwise, arrays must comply with the new rules.
    """
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # here we could wrap the namespace if needed
        return np_compat

    api_arrays = list(_compliance_scipy(arrays))
    # In case of a mix of array API compliant arrays and scalars, return
    # the array API namespace. If there are only ArrayLikes (e.g. lists),
    # return NumPy (wrapped by array-api-compat).
    if api_arrays:
        return array_api_compat.array_namespace(*api_arrays)
    return np_compat
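# Hedged usage sketch (editorial example): with SCIPY_ARRAY_API unset, this simply
# returns the array-api-compat NumPy wrapper, so callers can use the array API surface
# uniformly regardless of which backend produced the inputs.
if __name__ == "__main__":  # example only
    a = np.asarray([1.0, 2.0])
    xp = array_namespace(a)   # NumPy wrapped by array-api-compat
    print(xp.sum(a))          # 3.0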
def _asarray(
array: ArrayLike,
dtype: Any = None,
order: Literal['K', 'A', 'C', 'F'] | None = None,
copy: bool | None = None,
*,
xp: ModuleType | None = None,
check_finite: bool = False,
subok: bool = False,
) -> Array:
"""SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
`subok`.
Memory layout parameter `order` is not exposed in the Array API standard.
`order` is only enforced if the input array implementation
is NumPy based, otherwise `order` is just silently ignored.
`check_finite` is also not a keyword in the array API standard; included
here for convenience rather than that having to be a separate function
call inside SciPy functions.
`subok` is included to allow this function to preserve the behaviour of
`np.asanyarray` for NumPy based inputs.
"""
if xp is None:
xp = array_namespace(array)
if is_numpy(xp):
# Use NumPy API to support order
if copy is True:
array = np.array(array, order=order, dtype=dtype, subok=subok)
elif subok:
array = np.asanyarray(array, order=order, dtype=dtype)
else:
array = np.asarray(array, order=order, dtype=dtype)
else:
try:
array = xp.asarray(array, dtype=dtype, copy=copy)
except TypeError:
coerced_xp = array_namespace(xp.asarray(3))
array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
if check_finite:
_check_finite(array, xp)
return array
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
"""
Copies an array.
Parameters
----------
x : array
xp : array_namespace
Returns
-------
copy : array
Copied array
Notes
-----
This copy function does not offer all the semantics of `np.copy`, i.e. the
`subok` and `order` keywords are not used.
"""
# Note: for older NumPy versions, `np.asarray` did not support the `copy` kwarg,
# so this uses our other helper `_asarray`.
if xp is None:
xp = array_namespace(x)
return _asarray(x, copy=True, xp=xp)
_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")
@contextmanager
def default_xp(xp: ModuleType) -> Generator[None, None, None]:
    """In all ``xp_assert_*`` and ``assert_*`` function calls executed within this
    context manager, test by default that the array namespace is the provided one
    across all arrays, unless one explicitly passes the ``xp=`` parameter or
    ``check_namespace=False``.

    Without this context manager, the default value for `xp` is the namespace
    for the desired array (the second parameter of the tests).
    """
    token = _default_xp_ctxvar.set(xp)
    try:
        yield
    finally:
        _default_xp_ctxvar.reset(token)
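# Hedged usage sketch (editorial example): in test code, wrapping assertions in this
# context manager pins the expected namespace without passing `xp=` at every call,
# e.g.:
#
#     with default_xp(np_compat):
#         xp_assert_close(np.asarray([1.0]), np.asarray([1.0]))  # asserts numpy namespace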
def eager_warns(x, warning_type, match=None):
"""pytest.warns context manager, but only if x is not a lazy array."""
import pytest
# This attribute is interpreted by pytest-run-parallel, ensuring that tests that use
# `eager_warns` aren't run in parallel (since pytest.warns isn't thread-safe).
__thread_safe__ = False # noqa: F841
if is_lazy_array(x):
return contextlib.nullcontext()
return pytest.warns(warning_type, match=match)
def _strict_check(actual, desired, xp, *,
check_namespace=True, check_dtype=True, check_shape=True,
check_0d=True):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
try:
xp = _default_xp_ctxvar.get()
except LookupError:
xp = array_namespace(desired)
if check_namespace:
_assert_matching_namespace(actual, desired, xp)
# only NumPy distinguishes between scalars and arrays; we do if check_0d=True.
# do this first so we can then cast to array (and thus use the array API) below.
if is_numpy(xp) and check_0d:
_msg = ("Array-ness does not match:\n Actual: "
f"{type(actual)}\n Desired: {type(desired)}")
assert ((xp.isscalar(actual) and xp.isscalar(desired))
or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg
actual = xp.asarray(actual)
desired = xp.asarray(desired)
if check_dtype:
_msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
assert actual.dtype == desired.dtype, _msg
if check_shape:
if is_dask(xp):
actual.compute_chunk_sizes()
desired.compute_chunk_sizes()
_msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
assert actual.shape == desired.shape, _msg
desired = xp.broadcast_to(desired, actual.shape)
return actual, desired, xp
def _assert_matching_namespace(actual, desired, xp):
__tracebackhide__ = True # Hide traceback for py.test
desired_arr_space = array_namespace(desired)
_msg = ("Namespace of desired array does not match expectations "
"set by the `default_xp` context manager or by the `xp`"
"pytest fixture.\n"
f"Desired array's space: {desired_arr_space.__name__}\n"
f"Expected namespace: {xp.__name__}")
assert desired_arr_space == xp, _msg
actual_arr_space = array_namespace(actual)
_msg = ("Namespace of actual and desired arrays do not match.\n"
f"Actual: {actual_arr_space.__name__}\n"
f"Desired: {xp.__name__}")
assert actual_arr_space == xp, _msg
def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
check_shape=True, check_0d=True, err_msg='', xp=None):
__tracebackhide__ = True # Hide traceback for py.test
actual, desired, xp = _strict_check(
actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape,
check_0d=check_0d
)
if is_cupy(xp):
return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
elif is_torch(xp):
# PyTorch recommends using `rtol=0, atol=0` like this
# to test for exact equality
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
check_dtype=False, msg=err_msg)
# JAX uses `np.testing`
return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True,
check_dtype=True, check_shape=True, check_0d=True,
err_msg='', xp=None):
__tracebackhide__ = True # Hide traceback for py.test
actual, desired, xp = _strict_check(
actual, desired, xp,
check_namespace=check_namespace, check_dtype=check_dtype,
check_shape=check_shape, check_0d=check_0d
)
floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
if rtol is None and floating:
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
# roughly half way between sqrt(eps) and the default for
# `numpy.testing.assert_allclose`, 1e-7
rtol = xp.finfo(actual.dtype).eps**0.5 * 4
elif rtol is None:
rtol = 1e-7
if is_cupy(xp):
return xp.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)
elif is_torch(xp):
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
equal_nan=True, check_dtype=False, msg=err_msg)
# JAX uses `np.testing`
return np.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)
def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
check_shape=True, check_0d=True, err_msg='', verbose=True, xp=None):
__tracebackhide__ = True # Hide traceback for py.test
actual, desired, xp = _strict_check(
actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape,
check_0d=check_0d
)
if is_cupy(xp):
return xp.testing.assert_array_less(actual, desired,
err_msg=err_msg, verbose=verbose)
elif is_torch(xp):
if actual.device.type != 'cpu':
actual = actual.cpu()
if desired.device.type != 'cpu':
desired = desired.cpu()
# JAX uses `np.testing`
return np.testing.assert_array_less(actual, desired,
err_msg=err_msg, verbose=verbose)
def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.
    """
    rtol, atol = 0, 1.5*10**(-decimal)
    return xp_assert_close(actual, desired,
                           atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
                           *args, **kwds)


def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.
    """
    rtol, atol = 0, 1.5*10**(-decimal)
    return xp_assert_close(actual, desired,
                           atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
                           *args, **kwds)
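# Editorial note on the decimal -> tolerance mapping above: `decimal=6` gives
# `atol = 1.5e-06` and `decimal=7` gives `atol = 1.5e-07` (with `rtol = 0`), mirroring
# numpy.testing's "almost equal to the given number of decimals" convention.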
def xp_unsupported_param_msg(param: Any) -> str:
return f'Providing {param!r} is only supported for numpy arrays.'
def is_complex(x: Array, xp: ModuleType) -> bool:
return xp.isdtype(x.dtype, 'complex floating')
def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
"""Return the `scipy`-like namespace of a non-NumPy backend
That is, return the namespace corresponding with backend `xp` that contains
`scipy` sub-namespaces like `linalg` and `special`. If no such namespace
exists, return ``None``. Useful for dispatching.
"""
if is_cupy(xp):
import cupyx # type: ignore[import-not-found,import-untyped]
return cupyx.scipy
if is_jax(xp):
import jax # type: ignore[import-not-found]
return jax.scipy
if is_torch(xp):
return xp
return None
# maybe use `scipy.linalg` if/when array API support is added
def xp_vector_norm(x: Array, /, *,
axis: int | tuple[int] | None = None,
keepdims: bool = False,
ord: int | float = 2,
xp: ModuleType | None = None) -> Array:
xp = array_namespace(x) if xp is None else xp
if SCIPY_ARRAY_API:
# check for optional `linalg` extension
if hasattr(xp, 'linalg'):
return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
else:
if ord != 2:
raise ValueError(
"only the Euclidean norm (`ord=2`) is currently supported in "
"`xp_vector_norm` for backends not implementing the `linalg` "
"extension."
)
# return (x @ x)**0.5
# or to get the right behavior with nd, complex arrays
return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
else:
# to maintain backwards compatibility
return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# Equivalent of np.ravel written in terms of array API
# Even though it's one line, it comes up so often that it's worth having
# this function for readability
xp = array_namespace(x) if xp is None else xp
return xp.reshape(x, (-1,))
def xp_swapaxes(a, axis1, axis2, xp=None):
# Equivalent of np.swapaxes written in terms of array API
xp = array_namespace(a) if xp is None else xp
axes = list(range(a.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
a = xp.permute_dims(a, axes)
return a
# utility to find common dtype with option to force floating
def xp_result_type(*args, force_floating=False, xp):
"""
Returns the dtype that results from applying type promotion rules
(see Array API Standard Type Promotion Rules) to the arguments. Augments
standard `result_type` in a few ways:
- There is a `force_floating` argument that ensures that the result type
is floating point, even when all args are integer.
- When a TypeError is raised (e.g. due to an unsupported promotion)
and `force_floating=True`, we define a custom rule: use the result type
of the default float and any other floats passed. See
https://github.com/scipy/scipy/pull/22695/files#r1997905891
for rationale.
- This function accepts array-like iterables, which are immediately converted
to the namespace's arrays before result type calculation. Consequently, the
result dtype may be different when an argument is `1.` vs `[1.]`.
Typically, this function will be called shortly after `array_namespace`
on a subset of the arguments passed to `array_namespace`.
"""
args = [(_asarray(arg, subok=True, xp=xp) if np.iterable(arg) else arg)
for arg in args]
args_not_none = [arg for arg in args if arg is not None]
if force_floating:
args_not_none.append(1.0)
if is_numpy(xp) and xp.__version__ < '2.0':
# Follow NEP 50 promotion rules anyway
args_not_none = [arg.dtype if getattr(arg, 'size', 0) == 1 else arg
for arg in args_not_none]
return xp.result_type(*args_not_none)
try: # follow library's preferred promotion rules
return xp.result_type(*args_not_none)
except TypeError: # mixed type promotion isn't defined
if not force_floating:
raise
# use `result_type` of default floating point type and any floats present
# This can be revisited, but right now, the only backends that get here
# are array-api-strict (which is not for production use) and PyTorch
# (due to data-apis/array-api-compat#279).
float_args = []
for arg in args_not_none:
arg_array = xp.asarray(arg) if np.isscalar(arg) else arg
dtype = getattr(arg_array, 'dtype', arg)
if xp.isdtype(dtype, ('real floating', 'complex floating')):
float_args.append(arg)
return xp.result_type(*float_args, xp_default_dtype(xp))
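# Hedged usage sketch (editorial example, NumPy backend): `force_floating=True`
# appends a Python float to the promotion, so an all-integer input still yields a
# floating result type.
if __name__ == "__main__":  # example only
    print(xp_result_type(np.asarray([1, 2]), force_floating=True, xp=np))  # float64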
def xp_promote(*args, broadcast=False, force_floating=False, xp):
"""
Promotes elements of *args to result dtype, ignoring `None`s.
Includes options for forcing promotion to floating point and
broadcasting the arrays, again ignoring `None`s.
Type promotion rules follow `xp_result_type` instead of `xp.result_type`.
Typically, this function will be called shortly after `array_namespace`
on a subset of the arguments passed to `array_namespace`.
This function accepts array-like iterables, which are immediately converted
to the namespace's arrays before result type calculation. Consequently, the
result dtype may be different when an argument is `1.` vs `[1.]`.
See Also
--------
xp_result_type
"""
args = [(_asarray(arg, subok=True, xp=xp) if np.iterable(arg) else arg)
for arg in args] # solely to prevent double conversion of iterable to array
dtype = xp_result_type(*args, force_floating=force_floating, xp=xp)
args = [(_asarray(arg, dtype=dtype, subok=True, xp=xp) if arg is not None else arg)
for arg in args]
if not broadcast:
return args[0] if len(args)==1 else tuple(args)
args_not_none = [arg for arg in args if arg is not None]
# determine result shape
shapes = {arg.shape for arg in args_not_none}
try:
shape = (np.broadcast_shapes(*shapes) if len(shapes) != 1
else args_not_none[0].shape)
except ValueError as e:
message = "Array shapes are incompatible for broadcasting."
raise ValueError(message) from e
out = []
for arg in args:
if arg is None:
out.append(arg)
continue
# broadcast only if needed
# Even if two arguments need broadcasting, this is faster than
# `broadcast_arrays`, especially since we've already determined `shape`
if arg.shape != shape:
kwargs = {'subok': True} if is_numpy(xp) else {}
arg = xp.broadcast_to(arg, shape, **kwargs)
# This is much faster than xp.astype(arg, dtype, copy=False)
if arg.dtype != dtype:
arg = xp.astype(arg, dtype)
out.append(arg)
return out[0] if len(out)==1 else tuple(out)
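# Hedged usage sketch (editorial example, NumPy backend): scalars and array-likes are
# promoted to a common result dtype (float64 here, driven by the float inputs) and,
# with `broadcast=True`, to a common shape, while `None` placeholders pass through
# untouched.
if __name__ == "__main__":  # example only
    a, b, c = xp_promote(1, [2.0, 3.0], None, broadcast=True, xp=np)
    print(a, b, c)  # [1. 1.] [2. 3.] None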
def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
xp = array_namespace(arr) if xp is None else xp
arr_dtype = arr.dtype
# The standard float dtypes are float32 and float64.
# Convert float32 to complex64,
# and float64 (and non-standard real dtypes) to complex128
if xp.isdtype(arr_dtype, xp.float32):
arr = xp.astype(arr, xp.complex64)
elif xp.isdtype(arr_dtype, 'real floating'):
arr = xp.astype(arr, xp.complex128)
return arr
def xp_default_dtype(xp):
"""Query the namespace-dependent default floating-point dtype.
"""
if is_torch(xp):
# historically, we allow pytorch to keep its default of float32
return xp.get_default_dtype()
else:
# we default to float64
return xp.float64
def xp_result_device(*args):
"""Return the device of an array in `args`, for the purpose of
input-output device propagation.
If there are multiple devices, return an arbitrary one.
If there are no arrays, return None (this typically happens only on NumPy).
"""
for arg in args:
# Do not do a duck-type test for the .device attribute, as many backends today
# don't have it yet. See workarounds in array_api_compat.device().
if is_array_api_obj(arg):
return xp_device(arg)
return None
def is_marray(xp):
"""Returns True if `xp` is an MArray namespace; False otherwise."""
return "marray" in xp.__name__
@dataclasses.dataclass(repr=False)
class _XPSphinxCapability:
cpu: bool | None # None if not applicable
gpu: bool | None
warnings: list[str] = dataclasses.field(default_factory=list)
def _render(self, value):
if value is None:
return "n/a"
if not value:
return ""
if self.warnings:
res = "⚠️ " + '; '.join(self.warnings)
assert len(res) <= 20, "Warnings too long"
return res
return ""
def __str__(self):
cpu = self._render(self.cpu)
gpu = self._render(self.gpu)
return f"{cpu:20} {gpu:20}"
def _make_sphinx_capabilities(
# lists of tuples [(module name, reason), ...]
skip_backends=(), xfail_backends=(),
# @pytest.mark.skip/xfail_xp_backends kwargs
cpu_only=False, np_only=False, exceptions=(),
# xpx.lazy_xp_backends kwargs
allow_dask_compute=False, jax_jit=True,
# list of tuples [(module name, reason), ...]
warnings = (),
# unused in documentation
reason=None,
):
exceptions = set(exceptions)
# Default capabilities
capabilities = {
"numpy": _XPSphinxCapability(cpu=True, gpu=None),
"array_api_strict": _XPSphinxCapability(cpu=True, gpu=None),
"cupy": _XPSphinxCapability(cpu=None, gpu=True),
"torch": _XPSphinxCapability(cpu=True, gpu=True),
"jax.numpy": _XPSphinxCapability(cpu=True, gpu=True,
warnings=[] if jax_jit else ["no JIT"]),
# Note: Dask+CuPy is currently untested and unsupported
"dask.array": _XPSphinxCapability(cpu=True, gpu=None,
warnings=["computes graph"] if allow_dask_compute else []),
}
# documentation doesn't display the reason
for module, _ in list(skip_backends) + list(xfail_backends):
backend = capabilities[module]
if backend.cpu is not None:
backend.cpu = False
if backend.gpu is not None:
backend.gpu = False
for module, backend in capabilities.items():
if np_only and module not in exceptions | {"numpy"}:
if backend.cpu is not None:
backend.cpu = False
if backend.gpu is not None:
backend.gpu = False
elif cpu_only and module not in exceptions and backend.gpu is not None:
backend.gpu = False
for module, warning in warnings:
backend = capabilities[module]
backend.warnings.append(warning)
return capabilities
def _make_capabilities_note(fun_name, capabilities):
# Note: deliberately not documenting array-api-strict
note = f"""
`{fun_name}` has experimental support for Python Array API Standard compatible
backends in addition to NumPy. Please consider testing these features
by setting an environment variable ``SCIPY_ARRAY_API=1`` and providing
CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following
combinations of backend and device (or other capability) are supported.
==================== ==================== ====================
Library CPU GPU
==================== ==================== ====================
NumPy {capabilities['numpy'] }
CuPy {capabilities['cupy'] }
PyTorch {capabilities['torch'] }
JAX {capabilities['jax.numpy'] }
Dask {capabilities['dask.array'] }
==================== ==================== ====================
See :ref:`dev-arrayapi` for more information.
"""
return textwrap.dedent(note)
def xp_capabilities(
*,
# Alternative capabilities table.
# Used only for testing this decorator.
capabilities_table=None,
# Generate pytest.mark.skip/xfail_xp_backends.
# See documentation in conftest.py.
# lists of tuples [(module name, reason), ...]
skip_backends=(), xfail_backends=(),
cpu_only=False, np_only=False, reason=None, exceptions=(),
# lists of tuples [(module name, reason), ...]
warnings=(),
# xpx.testing.lazy_xp_function kwargs.
# Refer to array-api-extra documentation.
allow_dask_compute=False, jax_jit=True,
):
"""Decorator for a function that states its support among various
Array API compatible backends.
This decorator has two effects:
1. It allows tagging tests with ``@make_xp_test_case`` or
``make_xp_pytest_param`` (see below) to automatically generate
SKIP/XFAIL markers and perform additional backend-specific
testing, such as extra validation for Dask and JAX;
2. It automatically adds a note to the function's docstring, containing
a table matching what has been tested.
See Also
--------
make_xp_test_case
make_xp_pytest_param
array_api_extra.testing.lazy_xp_function
"""
capabilities_table = (xp_capabilities_table if capabilities_table is None
else capabilities_table)
capabilities = dict(
skip_backends=skip_backends,
xfail_backends=xfail_backends,
cpu_only=cpu_only,
np_only=np_only,
reason=reason,
exceptions=exceptions,
allow_dask_compute=allow_dask_compute,
jax_jit=jax_jit,
warnings=warnings,
)
sphinx_capabilities = _make_sphinx_capabilities(**capabilities)
def decorator(f):
# Don't use a wrapper, as in some cases @xp_capabilities is
# applied to a ufunc
capabilities_table[f] = capabilities
note = _make_capabilities_note(f.__name__, sphinx_capabilities)
doc = FunctionDoc(f)
doc['Notes'].append(note)
doc = str(doc).split("\n", 1)[1] # remove signature
try:
f.__doc__ = doc
except AttributeError:
# Can't update __doc__ on ufuncs if SciPy
# was compiled against NumPy < 2.2.
pass
return f
return decorator
def _make_xp_pytest_marks(*funcs, capabilities_table=None):
capabilities_table = (xp_capabilities_table if capabilities_table is None
else capabilities_table)
import pytest
from scipy._lib.array_api_extra.testing import lazy_xp_function
marks = []
for func in funcs:
capabilities = capabilities_table[func]
exceptions = capabilities['exceptions']
reason = capabilities['reason']
if capabilities['cpu_only']:
marks.append(pytest.mark.skip_xp_backends(
cpu_only=True, exceptions=exceptions, reason=reason))
if capabilities['np_only']:
marks.append(pytest.mark.skip_xp_backends(
np_only=True, exceptions=exceptions, reason=reason))
for mod_name, reason in capabilities['skip_backends']:
marks.append(pytest.mark.skip_xp_backends(mod_name, reason=reason))
for mod_name, reason in capabilities['xfail_backends']:
marks.append(pytest.mark.xfail_xp_backends(mod_name, reason=reason))
lazy_kwargs = {k: capabilities[k]
for k in ('allow_dask_compute', 'jax_jit')}
lazy_xp_function(func, **lazy_kwargs)
return marks
def make_xp_test_case(*funcs, capabilities_table=None):
capabilities_table = (xp_capabilities_table if capabilities_table is None
else capabilities_table)
"""Generate pytest decorator for a test function that tests functionality
of one or more Array API compatible functions.
Read the parameters of the ``@xp_capabilities`` decorator applied to the
listed functions and:
- Generate the ``@pytest.mark.skip_xp_backends`` and
``@pytest.mark.xfail_xp_backends`` decorators
for the decorated test function
- Tag the function with `xpx.testing.lazy_xp_function`
See Also
--------
xp_capabilities
make_xp_pytest_param
array_api_extra.testing.lazy_xp_function
"""
marks = _make_xp_pytest_marks(*funcs, capabilities_table=capabilities_table)
return lambda func: functools.reduce(lambda f, g: g(f), marks, func)
def make_xp_pytest_param(func, *args, capabilities_table=None):
"""Variant of ``make_xp_test_case`` that returns a pytest.param for a function,
with all necessary skip_xp_backends and xfail_xp_backends marks applied::
@pytest.mark.parametrize(
"func", [make_xp_pytest_param(f1), make_xp_pytest_param(f2)]
)
def test(func, xp):
...
The above is equivalent to::
@pytest.mark.parametrize(
"func", [
pytest.param(f1, marks=[
pytest.mark.skip_xp_backends(...),
pytest.mark.xfail_xp_backends(...), ...]),
pytest.param(f2, marks=[
pytest.mark.skip_xp_backends(...),
pytest.mark.xfail_xp_backends(...), ...]),
)
def test(func, xp):
...
Parameters
----------
func : Callable
Function to be tested. It must be decorated with ``@xp_capabilities``.
*args : Any, optional
Extra pytest parameters for the use case, e.g.::
@pytest.mark.parametrize("func,verb", [
make_xp_pytest_param(f1, "hello"),
make_xp_pytest_param(f2, "world")])
def test(func, verb, xp):
# iterates on (func=f1, verb="hello")
# and (func=f2, verb="world")
See Also
--------
xp_capabilities
make_xp_test_case
array_api_extra.testing.lazy_xp_function
"""
import pytest
marks = _make_xp_pytest_marks(func, capabilities_table=capabilities_table)
return pytest.param(func, *args, marks=marks, id=func.__name__)
# Is it OK to have a dictionary that is mutated (once upon import) in many places?
xp_capabilities_table = {} # type: ignore[var-annotated]


@@ -0,0 +1,9 @@
# DO NOT RENAME THIS FILE
# This is a hook for array_api_extra/src/array_api_extra/_lib/_compat.py
# to override functions of array_api_compat.
from .array_api_compat import * # noqa: F403
from ._array_api import array_namespace as scipy_array_namespace
# overrides array_api_compat.array_namespace inside array-api-extra
array_namespace = scipy_array_namespace # type: ignore[assignment]


@@ -0,0 +1,103 @@
"""
Extra testing functions that forbid 0d-input, see #21044
While the xp_assert_* functions generally aim to follow the conventions of the
underlying `xp` library, NumPy in particular is inconsistent in its handling
of scalars vs. 0d-arrays, see https://github.com/numpy/numpy/issues/24897.
For example, this means that the following operations (as of v2.0.1) currently
return scalars, even though a 0d-array would often be more appropriate:
import numpy as np
np.array(0) * 2 # scalar, not 0d array
- np.array(0) # scalar, not 0d-array
np.sin(np.array(0)) # scalar, not 0d array
np.mean([1, 2, 3]) # scalar, not 0d array
Libraries like CuPy tend to return a 0d-array in scenarios like those above,
and even `xp.asarray(0)[()]` remains a 0d-array there. To deal with the reality
of the inconsistencies present in NumPy, as well as 20+ years of code on top,
the `xp_assert_*` functions here enforce consistency in the only way that
doesn't go against the tide, i.e. by forbidding 0d-arrays as the return type.
However, when scalars are not generally the expected NumPy return type,
it remains preferable to use the assert functions from
the `scipy._lib._array_api` module, which have less surprising behaviour.
"""
from scipy._lib._array_api import array_namespace, is_numpy
from scipy._lib._array_api import (xp_assert_close as xp_assert_close_base,
xp_assert_equal as xp_assert_equal_base,
xp_assert_less as xp_assert_less_base)
__all__: list[str] = []
def _check_scalar(actual, desired, *, xp=None, **kwargs):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
xp = array_namespace(actual)
# necessary to handle non-numpy scalars, e.g. bare `0.0` has no shape
desired = xp.asarray(desired)
# Only NumPy distinguishes between scalars and arrays;
# shape check in xp_assert_* is sufficient except for shape == ()
if not (is_numpy(xp) and desired.shape == ()):
return
_msg = ("Result is a NumPy 0d-array. Many SciPy functions intend to follow "
"the convention of many NumPy functions, returning a scalar when a "
"0d-array would be correct. The specialized `xp_assert_*` functions "
"in the `scipy._lib._array_api_no_0d` module err on the side of "
"caution and do not accept 0d-arrays by default. If the correct "
"result may legitimately be a 0d-array, pass `check_0d=True`, "
"or use the `xp_assert_*` functions from `scipy._lib._array_api`.")
assert xp.isscalar(actual), _msg
def xp_assert_equal(actual, desired, *, check_0d=False, **kwargs):
# in contrast to xp_assert_equal_base, this defaults to check_0d=False,
# but will do an extra check in that case, which forbids 0d-arrays for `actual`
__tracebackhide__ = True # Hide traceback for py.test
# array-ness (check_0d == True) is taken care of by the *_base functions
if not check_0d:
_check_scalar(actual, desired, **kwargs)
return xp_assert_equal_base(actual, desired, check_0d=check_0d, **kwargs)
def xp_assert_close(actual, desired, *, check_0d=False, **kwargs):
# as for xp_assert_equal
__tracebackhide__ = True
if not check_0d:
_check_scalar(actual, desired, **kwargs)
return xp_assert_close_base(actual, desired, check_0d=check_0d, **kwargs)
def xp_assert_less(actual, desired, *, check_0d=False, **kwargs):
# as for xp_assert_equal
__tracebackhide__ = True
if not check_0d:
_check_scalar(actual, desired, **kwargs)
return xp_assert_less_base(actual, desired, check_0d=check_0d, **kwargs)
def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
"""Backwards compatible replacement. In new code, use xp_assert_close instead.
"""
rtol, atol = 0, 1.5*10**(-decimal)
return xp_assert_close(actual, desired,
atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
*args, **kwds)
def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
"""Backwards compatible replacement. In new code, use xp_assert_close instead.
"""
rtol, atol = 0, 1.5*10**(-decimal)
return xp_assert_close(actual, desired,
atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
*args, **kwds)


@@ -0,0 +1,229 @@
import sys as _sys
from keyword import iskeyword as _iskeyword
def _validate_names(typename, field_names, extra_field_names):
"""
Ensure that all the given names are valid Python identifiers that
do not start with '_'. Also check that there are no duplicates
among field_names + extra_field_names.
"""
for name in [typename] + field_names + extra_field_names:
if not isinstance(name, str):
raise TypeError('typename and all field names must be strings')
if not name.isidentifier():
raise ValueError('typename and all field names must be valid '
f'identifiers: {name!r}')
if _iskeyword(name):
raise ValueError('typename and all field names cannot be a '
f'keyword: {name!r}')
seen = set()
for name in field_names + extra_field_names:
if name.startswith('_'):
raise ValueError('Field names cannot start with an underscore: '
f'{name!r}')
if name in seen:
raise ValueError(f'Duplicate field name: {name!r}')
seen.add(name)
# Note: This code is adapted from CPython:Lib/collections/__init__.py
def _make_tuple_bunch(typename, field_names, extra_field_names=None,
module=None):
"""
Create a namedtuple-like class with additional attributes.
This function creates a subclass of tuple that acts like a namedtuple
and that has additional attributes.
The additional attributes are listed in `extra_field_names`. The
values assigned to these attributes are not part of the tuple.
The reason this function exists is to allow functions in SciPy
that currently return a tuple or a namedtuple to return objects
that have additional attributes, while maintaining backwards
compatibility.
This should only be used to enhance *existing* functions in SciPy.
New functions are free to create objects as return values without
having to maintain backwards compatibility with an old tuple or
namedtuple return value.
Parameters
----------
typename : str
The name of the type.
field_names : list of str
List of names of the values to be stored in the tuple. These names
will also be attributes of instances, so the values in the tuple
can be accessed by indexing or as attributes. At least one name
is required. See the Notes for additional restrictions.
extra_field_names : list of str, optional
List of names of values that will be stored as attributes of the
object. See the notes for additional restrictions.
Returns
-------
cls : type
The new class.
Notes
-----
There are restrictions on the names that may be used in `field_names`
and `extra_field_names`:
* The names must be unique--no duplicates allowed.
* The names must be valid Python identifiers, and must not begin with
an underscore.
* The names must not be Python keywords (e.g. 'def', 'and', etc., are
not allowed).
Examples
--------
>>> from scipy._lib._bunch import _make_tuple_bunch
Create a class that acts like a namedtuple with length 2 (with field
names `x` and `y`) that will also have the attributes `w` and `beta`:
>>> Result = _make_tuple_bunch('Result', ['x', 'y'], ['w', 'beta'])
`Result` is the new class. We call it with keyword arguments to create
a new instance with given values.
>>> result1 = Result(x=1, y=2, w=99, beta=0.5)
>>> result1
Result(x=1, y=2, w=99, beta=0.5)
`result1` acts like a tuple of length 2:
>>> len(result1)
2
>>> result1[:]
(1, 2)
The values assigned when the instance was created are available as
attributes:
>>> result1.y
2
>>> result1.beta
0.5
"""
if len(field_names) == 0:
raise ValueError('field_names must contain at least one name')
if extra_field_names is None:
extra_field_names = []
_validate_names(typename, field_names, extra_field_names)
typename = _sys.intern(str(typename))
field_names = tuple(map(_sys.intern, field_names))
extra_field_names = tuple(map(_sys.intern, extra_field_names))
all_names = field_names + extra_field_names
arg_list = ', '.join(field_names)
full_list = ', '.join(all_names)
repr_fmt = ''.join(('(',
', '.join(f'{name}=%({name})r' for name in all_names),
')'))
tuple_new = tuple.__new__
_dict, _tuple, _zip = dict, tuple, zip
# Create all the named tuple methods to be added to the class namespace
s = f"""\
def __new__(_cls, {arg_list}, **extra_fields):
return _tuple_new(_cls, ({arg_list},))
def __init__(self, {arg_list}, **extra_fields):
for key in self._extra_fields:
if key not in extra_fields:
raise TypeError("missing keyword argument '%s'" % (key,))
for key, val in extra_fields.items():
if key not in self._extra_fields:
raise TypeError("unexpected keyword argument '%s'" % (key,))
self.__dict__[key] = val
def __setattr__(self, key, val):
if key in {repr(field_names)}:
raise AttributeError("can't set attribute %r of class %r"
% (key, self.__class__.__name__))
else:
self.__dict__[key] = val
"""
del arg_list
namespace = {'_tuple_new': tuple_new,
'__builtins__': dict(TypeError=TypeError,
AttributeError=AttributeError),
'__name__': f'namedtuple_{typename}'}
exec(s, namespace)
__new__ = namespace['__new__']
__new__.__doc__ = f'Create new instance of {typename}({full_list})'
__init__ = namespace['__init__']
__init__.__doc__ = f'Instantiate instance of {typename}({full_list})'
__setattr__ = namespace['__setattr__']
def __repr__(self):
'Return a nicely formatted representation string'
return self.__class__.__name__ + repr_fmt % self._asdict()
def _asdict(self):
'Return a new dict which maps field names to their values.'
out = _dict(_zip(self._fields, self))
out.update(self.__dict__)
return out
def __getnewargs_ex__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return _tuple(self), self.__dict__
# Modify function metadata to help with introspection and debugging
for method in (__new__, __repr__, _asdict, __getnewargs_ex__):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
# and use type() to build the result class
class_namespace = {
'__doc__': f'{typename}({full_list})',
'_fields': field_names,
'__new__': __new__,
'__init__': __init__,
'__repr__': __repr__,
'__setattr__': __setattr__,
'_asdict': _asdict,
'_extra_fields': extra_field_names,
'__getnewargs_ex__': __getnewargs_ex__,
# _field_defaults and _replace are added to get Polars to detect
# a bunch object as a namedtuple. See gh-22450
'_field_defaults': {},
'_replace': None,
}
for index, name in enumerate(field_names):
def _get(self, index=index):
return self[index]
class_namespace[name] = property(_get)
for name in extra_field_names:
def _get(self, name=name):
return self.__dict__[name]
class_namespace[name] = property(_get)
result = type(typename, (tuple,), class_namespace)
# For pickling to work, the __module__ variable needs to be set to the
# frame where the named tuple is created. Bypass this step in environments
# where sys._getframe is not defined (Jython for example) or sys._getframe
# is not defined for arguments greater than 0 (IronPython), or where the
# user has specified a particular module.
if module is None:
try:
module = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
if module is not None:
result.__module__ = module
__new__.__module__ = module
return result


@@ -0,0 +1,251 @@
from . import _ccallback_c
import ctypes
PyCFuncPtr = ctypes.CFUNCTYPE(ctypes.c_void_p).__bases__[0]
ffi = None
class CData:
pass
def _import_cffi():
global ffi, CData
if ffi is not None:
return
try:
import cffi
ffi = cffi.FFI()
CData = ffi.CData
except ImportError:
ffi = False
class LowLevelCallable(tuple):
"""
Low-level callback function.
Some functions in SciPy take as arguments callback functions, which
can either be python callables or low-level compiled functions. Using
compiled callback functions can improve performance somewhat by
avoiding wrapping data in Python objects.
Such low-level functions in SciPy are wrapped in `LowLevelCallable`
objects, which can be constructed from function pointers obtained from
ctypes, cffi, Cython, or contained in Python `PyCapsule` objects.
.. seealso::
Functions accepting low-level callables:
`scipy.integrate.quad`, `scipy.ndimage.generic_filter`,
`scipy.ndimage.generic_filter1d`, `scipy.ndimage.geometric_transform`
Usage examples:
:ref:`ndimage-ccallbacks`, :ref:`quad-callbacks`
Parameters
----------
function : {PyCapsule, ctypes function pointer, cffi function pointer}
Low-level callback function.
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}
User data to pass on to the callback function.
signature : str, optional
Signature of the function. If omitted, determined from *function*,
if possible.
Attributes
----------
function
Callback function given.
user_data
User data given.
signature
Signature of the function.
Methods
-------
from_cython
Class method for constructing callables from Cython C-exported
functions.
Notes
-----
The argument ``function`` can be one of:
- PyCapsule, whose name contains the C function signature
- ctypes function pointer
- cffi function pointer
The signature of the low-level callback must match one of those expected
by the routine it is passed to.
If constructing low-level functions from a PyCapsule, the name of the
capsule must be the corresponding signature, in the format::
return_type (arg1_type, arg2_type, ...)
For example::
"void (double)"
"double (double, int *, void *)"
The context of a PyCapsule passed in as ``function`` is used as ``user_data``,
if an explicit value for ``user_data`` was not given.
"""
# Make the class immutable
__slots__ = ()
def __new__(cls, function, user_data=None, signature=None):
# We need to hold a reference to the function & user data,
# to prevent them going out of scope
item = cls._parse_callback(function, user_data, signature)
return tuple.__new__(cls, (item, function, user_data))
def __repr__(self):
return f"LowLevelCallable({self.function!r}, {self.user_data!r})"
@property
def function(self):
return tuple.__getitem__(self, 1)
@property
def user_data(self):
return tuple.__getitem__(self, 2)
@property
def signature(self):
return _ccallback_c.get_capsule_signature(tuple.__getitem__(self, 0))
def __getitem__(self, idx):
raise ValueError()
@classmethod
def from_cython(cls, module, name, user_data=None, signature=None):
"""
Create a low-level callback function from an exported Cython function.
Parameters
----------
module : module
Cython module where the exported function resides
name : str
Name of the exported function
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}, optional
User data to pass on to the callback function.
signature : str, optional
Signature of the function. If omitted, determined from *function*.
"""
try:
function = module.__pyx_capi__[name]
except AttributeError as e:
message = "Given module is not a Cython module with __pyx_capi__ attribute"
raise ValueError(message) from e
except KeyError as e:
message = f"No function {name!r} found in __pyx_capi__ of the module"
raise ValueError(message) from e
return cls(function, user_data, signature)
@classmethod
def _parse_callback(cls, obj, user_data=None, signature=None):
_import_cffi()
if isinstance(obj, LowLevelCallable):
func = tuple.__getitem__(obj, 0)
elif isinstance(obj, PyCFuncPtr):
func, signature = _get_ctypes_func(obj, signature)
elif isinstance(obj, CData):
func, signature = _get_cffi_func(obj, signature)
elif _ccallback_c.check_capsule(obj):
func = obj
else:
raise ValueError("Given input is not a callable or a "
"low-level callable (pycapsule/ctypes/cffi)")
if isinstance(user_data, ctypes.c_void_p):
context = _get_ctypes_data(user_data)
elif isinstance(user_data, CData):
context = _get_cffi_data(user_data)
elif user_data is None:
context = 0
elif _ccallback_c.check_capsule(user_data):
context = user_data
else:
raise ValueError("Given user data is not a valid "
"low-level void* pointer (pycapsule/ctypes/cffi)")
return _ccallback_c.get_raw_capsule(func, signature, context)
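# Hedged usage sketch (editorial example): a C function with signature
# `double f(double, void *)`, compiled into a hypothetical shared library "mylib.so",
# could be wrapped and handed to `scipy.integrate.quad` roughly like this
# (library name and symbol are illustrative, not part of SciPy):
#
#     import ctypes
#     from scipy import LowLevelCallable, integrate
#
#     lib = ctypes.CDLL("./mylib.so")                      # hypothetical library
#     lib.f.restype = ctypes.c_double
#     lib.f.argtypes = (ctypes.c_double, ctypes.c_void_p)
#
#     func = LowLevelCallable(lib.f)                       # signature inferred from ctypes
#     result, abserr = integrate.quad(func, 0.0, 1.0)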
#
# ctypes helpers
#
def _get_ctypes_func(func, signature=None):
# Get function pointer
func_ptr = ctypes.cast(func, ctypes.c_void_p).value
# Construct function signature
if signature is None:
signature = _typename_from_ctypes(func.restype) + " ("
for j, arg in enumerate(func.argtypes):
if j == 0:
signature += _typename_from_ctypes(arg)
else:
signature += ", " + _typename_from_ctypes(arg)
signature += ")"
return func_ptr, signature
def _typename_from_ctypes(item):
if item is None:
return "void"
elif item is ctypes.c_void_p:
return "void *"
name = item.__name__
pointer_level = 0
while name.startswith("LP_"):
pointer_level += 1
name = name[3:]
if name.startswith('c_'):
name = name[2:]
if pointer_level > 0:
name += " " + "*"*pointer_level
return name
def _get_ctypes_data(data):
# Get voidp pointer
return ctypes.cast(data, ctypes.c_void_p).value
#
# CFFI helpers
#
def _get_cffi_func(func, signature=None):
# Get function pointer
func_ptr = ffi.cast('uintptr_t', func)
# Get signature
if signature is None:
signature = ffi.getctype(ffi.typeof(func)).replace('(*)', ' ')
return func_ptr, signature
def _get_cffi_data(data):
# Get pointer
return ffi.cast('uintptr_t', data)


@@ -0,0 +1,254 @@
"""
Disjoint set data structure
"""
class DisjointSet:
""" Disjoint set data structure for incremental connectivity queries.
.. versionadded:: 1.6.0
Attributes
----------
n_subsets : int
The number of subsets.
Methods
-------
add
merge
connected
subset
subset_size
subsets
__getitem__
Notes
-----
This class implements the disjoint set [1]_, also known as the *union-find*
or *merge-find* data structure. The *find* operation (implemented in
`__getitem__`) implements the *path halving* variant. The *merge* method
implements the *merge by size* variant.
References
----------
.. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Examples
--------
>>> from scipy.cluster.hierarchy import DisjointSet
Initialize a disjoint set:
>>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
Merge some subsets:
>>> disjoint_set.merge(1, 2)
True
>>> disjoint_set.merge(3, 'a')
True
>>> disjoint_set.merge('a', 'b')
True
>>> disjoint_set.merge('b', 'b')
False
Find root elements:
>>> disjoint_set[2]
1
>>> disjoint_set['b']
3
Test connectivity:
>>> disjoint_set.connected(1, 2)
True
>>> disjoint_set.connected(1, 'b')
False
List elements in disjoint set:
>>> list(disjoint_set)
[1, 2, 3, 'a', 'b']
Get the subset containing 'a':
>>> disjoint_set.subset('a')
{'a', 3, 'b'}
Get the size of the subset containing 'a' (without actually instantiating
the subset):
>>> disjoint_set.subset_size('a')
3
Get all subsets in the disjoint set:
>>> disjoint_set.subsets()
[{1, 2}, {'a', 3, 'b'}]
"""
def __init__(self, elements=None):
self.n_subsets = 0
self._sizes = {}
self._parents = {}
# _nbrs is a circular linked list which links connected elements.
self._nbrs = {}
# _indices tracks the element insertion order in `__iter__`.
self._indices = {}
if elements is not None:
for x in elements:
self.add(x)
def __iter__(self):
"""Returns an iterator of the elements in the disjoint set.
Elements are ordered by insertion order.
"""
return iter(self._indices)
def __len__(self):
return len(self._indices)
def __contains__(self, x):
return x in self._indices
def __getitem__(self, x):
"""Find the root element of `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
root : hashable object
Root element of `x`.
"""
if x not in self._indices:
raise KeyError(x)
# find by "path halving"
parents = self._parents
while self._indices[x] != self._indices[parents[x]]:
parents[x] = parents[parents[x]]
x = parents[x]
return x
def add(self, x):
"""Add element `x` to disjoint set
"""
if x in self._indices:
return
self._sizes[x] = 1
self._parents[x] = x
self._nbrs[x] = x
self._indices[x] = len(self._indices)
self.n_subsets += 1
def merge(self, x, y):
"""Merge the subsets of `x` and `y`.
The smaller subset (the child) is merged into the larger subset (the
parent). If the subsets are of equal size, the root element which was
first inserted into the disjoint set is selected as the parent.
Parameters
----------
x, y : hashable object
Elements to merge.
Returns
-------
merged : bool
True if `x` and `y` were in disjoint sets, False otherwise.
"""
xr = self[x]
yr = self[y]
if self._indices[xr] == self._indices[yr]:
return False
sizes = self._sizes
if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
xr, yr = yr, xr
self._parents[yr] = xr
self._sizes[xr] += self._sizes[yr]
self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
self.n_subsets -= 1
return True
def connected(self, x, y):
"""Test whether `x` and `y` are in the same subset.
Parameters
----------
x, y : hashable object
Elements to test.
Returns
-------
result : bool
True if `x` and `y` are in the same set, False otherwise.
"""
return self._indices[self[x]] == self._indices[self[y]]
def subset(self, x):
"""Get the subset containing `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
result : set
Subset containing `x`.
"""
if x not in self._indices:
raise KeyError(x)
result = [x]
nxt = self._nbrs[x]
while self._indices[nxt] != self._indices[x]:
result.append(nxt)
nxt = self._nbrs[nxt]
return set(result)
def subset_size(self, x):
"""Get the size of the subset containing `x`.
Note that this method is faster than ``len(self.subset(x))`` because
the size is directly read off an internal field, without the need to
instantiate the full subset.
Parameters
----------
x : hashable object
Input element.
Returns
-------
result : int
Size of the subset containing `x`.
"""
return self._sizes[self[x]]
def subsets(self):
"""Get all the subsets in the disjoint set.
Returns
-------
result : list
Subsets in the disjoint set.
"""
result = []
visited = set()
for x in self:
if x not in visited:
xset = self.subset(x)
visited.update(xset)
result.append(xset)
return result
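# --- Illustrative sketch (not part of the original module) -------------------
# A small example of using DisjointSet to compute connected components from an
# edge list; the edges below are made up purely for demonstration.
if __name__ == "__main__":
    edges = [("a", "b"), ("b", "c"), ("d", "e")]
    ds = DisjointSet()
    for u, v in edges:
        ds.add(u)
        ds.add(v)
        ds.merge(u, v)
    print(ds.n_subsets)   # 2
    print(ds.subsets())   # [{'a', 'b', 'c'}, {'d', 'e'}]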


@@ -0,0 +1,761 @@
# copied from numpydoc/docscrape.py, commit 97a6026508e0dd5382865672e9563a72cc113bd2
"""Extract reference documentation from the NumPy source tree."""
import copy
import inspect
import pydoc
import re
import sys
import textwrap
from collections import namedtuple
from collections.abc import Callable, Mapping
from functools import cached_property
from warnings import warn
def strip_blank_lines(l):
"Remove leading and trailing blank lines from a list of lines"
while l and not l[0].strip():
del l[0]
while l and not l[-1].strip():
del l[-1]
return l
class Reader:
"""A line-based string reader."""
def __init__(self, data):
"""
Parameters
----------
data : str
String with lines separated by '\\n'.
"""
if isinstance(data, list):
self._str = data
else:
self._str = data.split("\n") # store string as list of lines
self.reset()
def __getitem__(self, n):
return self._str[n]
def reset(self):
self._l = 0 # current line nr
def read(self):
if not self.eof():
out = self[self._l]
self._l += 1
return out
else:
return ""
def seek_next_non_empty_line(self):
for l in self[self._l :]:
if l.strip():
break
else:
self._l += 1
def eof(self):
return self._l >= len(self._str)
def read_to_condition(self, condition_func):
start = self._l
for line in self[start:]:
if condition_func(line):
return self[start : self._l]
self._l += 1
if self.eof():
return self[start : self._l + 1]
return []
def read_to_next_empty_line(self):
self.seek_next_non_empty_line()
def is_empty(line):
return not line.strip()
return self.read_to_condition(is_empty)
def read_to_next_unindented_line(self):
def is_unindented(line):
return line.strip() and (len(line.lstrip()) == len(line))
return self.read_to_condition(is_unindented)
def peek(self, n=0):
if self._l + n < len(self._str):
return self[self._l + n]
else:
return ""
def is_empty(self):
return not "".join(self._str).strip()
class ParseError(Exception):
def __str__(self):
message = self.args[0]
if hasattr(self, "docstring"):
message = f"{message} in {self.docstring!r}"
return message
Parameter = namedtuple("Parameter", ["name", "type", "desc"])
class NumpyDocString(Mapping):
"""Parses a numpydoc string to an abstract representation
Instances define a mapping from section title to structured data.
"""
sections = {
"Signature": "",
"Summary": [""],
"Extended Summary": [],
"Parameters": [],
"Attributes": [],
"Methods": [],
"Returns": [],
"Yields": [],
"Receives": [],
"Other Parameters": [],
"Raises": [],
"Warns": [],
"Warnings": [],
"See Also": [],
"Notes": [],
"References": "",
"Examples": "",
"index": {},
}
def __init__(self, docstring, config=None):
orig_docstring = docstring
docstring = textwrap.dedent(docstring).split("\n")
self._doc = Reader(docstring)
self._parsed_data = copy.deepcopy(self.sections)
try:
self._parse()
except ParseError as e:
e.docstring = orig_docstring
raise
def __getitem__(self, key):
return self._parsed_data[key]
def __setitem__(self, key, val):
if key not in self._parsed_data:
self._error_location(f"Unknown section {key}", error=False)
else:
self._parsed_data[key] = val
def __iter__(self):
return iter(self._parsed_data)
def __len__(self):
return len(self._parsed_data)
def _is_at_section(self):
self._doc.seek_next_non_empty_line()
if self._doc.eof():
return False
l1 = self._doc.peek().strip() # e.g. Parameters
if l1.startswith(".. index::"):
return True
l2 = self._doc.peek(1).strip() # ---------- or ==========
if len(l2) >= 3 and (set(l2) in ({"-"}, {"="})) and len(l2) != len(l1):
snip = "\n".join(self._doc._str[:2]) + "..."
self._error_location(
f"potentially wrong underline length... \n{l1} \n{l2} in \n{snip}",
error=False,
)
return l2.startswith("-" * len(l1)) or l2.startswith("=" * len(l1))
def _strip(self, doc):
i = 0
j = 0
for i, line in enumerate(doc):
if line.strip():
break
for j, line in enumerate(doc[::-1]):
if line.strip():
break
return doc[i : len(doc) - j]
def _read_to_next_section(self):
section = self._doc.read_to_next_empty_line()
while not self._is_at_section() and not self._doc.eof():
if not self._doc.peek(-1).strip(): # previous line was empty
section += [""]
section += self._doc.read_to_next_empty_line()
return section
def _read_sections(self):
while not self._doc.eof():
data = self._read_to_next_section()
name = data[0].strip()
if name.startswith(".."): # index section
yield name, data[1:]
elif len(data) < 2:
yield StopIteration
else:
yield name, self._strip(data[2:])
def _parse_param_list(self, content, single_element_is_type=False):
content = dedent_lines(content)
r = Reader(content)
params = []
while not r.eof():
header = r.read().strip()
if " : " in header:
arg_name, arg_type = header.split(" : ", maxsplit=1)
else:
                # NOTE: a param line with a single element should never have
                # a " :" before the description line, so this should probably
# warn.
if header.endswith(" :"):
header = header[:-2]
if single_element_is_type:
arg_name, arg_type = "", header
else:
arg_name, arg_type = header, ""
desc = r.read_to_next_unindented_line()
desc = dedent_lines(desc)
desc = strip_blank_lines(desc)
params.append(Parameter(arg_name, arg_type, desc))
return params
# See also supports the following formats.
#
# <FUNCNAME>
# <FUNCNAME> SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)+ (COMMA | PERIOD)? SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)* SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> is one of
# <PLAIN_FUNCNAME>
# COLON <ROLE> COLON BACKTICK <PLAIN_FUNCNAME> BACKTICK
# where
# <PLAIN_FUNCNAME> is a legal function name, and
# <ROLE> is any nonempty sequence of word characters.
# Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j`
# <DESC> is a string describing the function.
_role = r":(?P<role>(py:)?\w+):"
_funcbacktick = r"`(?P<name>(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`"
_funcplain = r"(?P<name2>[a-zA-Z0-9_\.-]+)"
_funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")"
_funcnamenext = _funcname.replace("role", "rolenext")
_funcnamenext = _funcnamenext.replace("name", "namenext")
_description = r"(?P<description>\s*:(\s+(?P<desc>\S+.*))?)?\s*$"
_func_rgx = re.compile(r"^\s*" + _funcname + r"\s*")
_line_rgx = re.compile(
r"^\s*"
+ r"(?P<allfuncs>"
+ _funcname # group for all function names
+ r"(?P<morefuncs>([,]\s+"
+ _funcnamenext
+ r")*)"
+ r")"
+ r"(?P<trailing>[,\.])?" # end of "allfuncs"
+ _description # Some function lists have a trailing comma (or period) '\s*'
)
# Empty <DESC> elements are replaced with '..'
empty_description = ".."
def _parse_see_also(self, content):
"""
func_name : Descriptive text
continued text
another_func_name : Descriptive text
func_name1, func_name2, :meth:`func_name`, func_name3
"""
content = dedent_lines(content)
items = []
def parse_item_name(text):
"""Match ':role:`name`' or 'name'."""
m = self._func_rgx.match(text)
if not m:
self._error_location(f"Error parsing See Also entry {line!r}")
role = m.group("role")
name = m.group("name") if role else m.group("name2")
return name, role, m.end()
rest = []
for line in content:
if not line.strip():
continue
line_match = self._line_rgx.match(line)
description = None
if line_match:
description = line_match.group("desc")
if line_match.group("trailing") and description:
self._error_location(
"Unexpected comma or period after function list at index %d of "
'line "%s"' % (line_match.end("trailing"), line),
error=False,
)
if not description and line.startswith(" "):
rest.append(line.strip())
elif line_match:
funcs = []
text = line_match.group("allfuncs")
while True:
if not text.strip():
break
name, role, match_end = parse_item_name(text)
funcs.append((name, role))
text = text[match_end:].strip()
if text and text[0] == ",":
text = text[1:].strip()
rest = list(filter(None, [description]))
items.append((funcs, rest))
else:
self._error_location(f"Error parsing See Also entry {line!r}")
return items
def _parse_index(self, section, content):
"""
.. index:: default
:refguide: something, else, and more
"""
def strip_each_in(lst):
return [s.strip() for s in lst]
out = {}
section = section.split("::")
if len(section) > 1:
out["default"] = strip_each_in(section[1].split(","))[0]
for line in content:
line = line.split(":")
if len(line) > 2:
out[line[1]] = strip_each_in(line[2].split(","))
return out
def _parse_summary(self):
"""Grab signature (if given) and summary"""
if self._is_at_section():
return
# If several signatures present, take the last one
while True:
summary = self._doc.read_to_next_empty_line()
summary_str = " ".join([s.strip() for s in summary]).strip()
compiled = re.compile(r"^([\w., ]+=)?\s*[\w\.]+\(.*\)$")
if compiled.match(summary_str):
self["Signature"] = summary_str
if not self._is_at_section():
continue
break
if summary is not None:
self["Summary"] = summary
if not self._is_at_section():
self["Extended Summary"] = self._read_to_next_section()
def _parse(self):
self._doc.reset()
self._parse_summary()
sections = list(self._read_sections())
section_names = {section for section, content in sections}
has_yields = "Yields" in section_names
        # We could do more consistency checks here, but (somewhat arbitrarily)
        # we don't.
if not has_yields and "Receives" in section_names:
msg = "Docstring contains a Receives section but not Yields."
raise ValueError(msg)
for section, content in sections:
if not section.startswith(".."):
section = (s.capitalize() for s in section.split(" "))
section = " ".join(section)
if self.get(section):
self._error_location(
"The section %s appears twice in %s"
% (section, "\n".join(self._doc._str))
)
if section in ("Parameters", "Other Parameters", "Attributes", "Methods"):
self[section] = self._parse_param_list(content)
elif section in ("Returns", "Yields", "Raises", "Warns", "Receives"):
self[section] = self._parse_param_list(
content, single_element_is_type=True
)
elif section.startswith(".. index::"):
self["index"] = self._parse_index(section, content)
elif section == "See Also":
self["See Also"] = self._parse_see_also(content)
else:
self[section] = content
@property
def _obj(self):
if hasattr(self, "_cls"):
return self._cls
elif hasattr(self, "_f"):
return self._f
return None
def _error_location(self, msg, error=True):
if self._obj is not None:
# we know where the docs came from:
try:
filename = inspect.getsourcefile(self._obj)
except TypeError:
filename = None
# Make UserWarning more descriptive via object introspection.
# Skip if introspection fails
name = getattr(self._obj, "__name__", None)
if name is None:
name = getattr(getattr(self._obj, "__class__", None), "__name__", None)
if name is not None:
msg += f" in the docstring of {name}"
msg += f" in {filename}." if filename else ""
if error:
raise ValueError(msg)
else:
warn(msg, stacklevel=3)
# string conversion routines
def _str_header(self, name, symbol="-"):
return [name, len(name) * symbol]
def _str_indent(self, doc, indent=4):
return [" " * indent + line for line in doc]
def _str_signature(self):
if self["Signature"]:
return [self["Signature"].replace("*", r"\*")] + [""]
return [""]
def _str_summary(self):
if self["Summary"]:
return self["Summary"] + [""]
return []
def _str_extended_summary(self):
if self["Extended Summary"]:
return self["Extended Summary"] + [""]
return []
def _str_param_list(self, name):
out = []
if self[name]:
out += self._str_header(name)
for param in self[name]:
parts = []
if param.name:
parts.append(param.name)
if param.type:
parts.append(param.type)
out += [" : ".join(parts)]
if param.desc and "".join(param.desc).strip():
out += self._str_indent(param.desc)
out += [""]
return out
def _str_section(self, name):
out = []
if self[name]:
out += self._str_header(name)
out += self[name]
out += [""]
return out
def _str_see_also(self, func_role):
if not self["See Also"]:
return []
out = []
out += self._str_header("See Also")
out += [""]
last_had_desc = True
for funcs, desc in self["See Also"]:
assert isinstance(funcs, list)
links = []
for func, role in funcs:
if role:
link = f":{role}:`{func}`"
elif func_role:
link = f":{func_role}:`{func}`"
else:
link = f"`{func}`_"
links.append(link)
link = ", ".join(links)
out += [link]
if desc:
out += self._str_indent([" ".join(desc)])
last_had_desc = True
else:
last_had_desc = False
out += self._str_indent([self.empty_description])
if last_had_desc:
out += [""]
out += [""]
return out
def _str_index(self):
idx = self["index"]
out = []
output_index = False
default_index = idx.get("default", "")
if default_index:
output_index = True
out += [f".. index:: {default_index}"]
for section, references in idx.items():
if section == "default":
continue
output_index = True
out += [f" :{section}: {', '.join(references)}"]
if output_index:
return out
return ""
def __str__(self, func_role=""):
out = []
out += self._str_signature()
out += self._str_summary()
out += self._str_extended_summary()
out += self._str_param_list("Parameters")
for param_list in ("Attributes", "Methods"):
out += self._str_param_list(param_list)
for param_list in (
"Returns",
"Yields",
"Receives",
"Other Parameters",
"Raises",
"Warns",
):
out += self._str_param_list(param_list)
out += self._str_section("Warnings")
out += self._str_see_also(func_role)
for s in ("Notes", "References", "Examples"):
out += self._str_section(s)
out += self._str_index()
return "\n".join(out)
def dedent_lines(lines):
"""Deindent a list of lines maximally"""
return textwrap.dedent("\n".join(lines)).split("\n")
class FunctionDoc(NumpyDocString):
def __init__(self, func, role="func", doc=None, config=None):
self._f = func
self._role = role # e.g. "func" or "meth"
if doc is None:
if func is None:
raise ValueError("No function or docstring given")
doc = inspect.getdoc(func) or ""
if config is None:
config = {}
NumpyDocString.__init__(self, doc, config)
def get_func(self):
func_name = getattr(self._f, "__name__", self.__class__.__name__)
if inspect.isclass(self._f):
func = getattr(self._f, "__call__", self._f.__init__)
else:
func = self._f
return func, func_name
def __str__(self):
out = ""
func, func_name = self.get_func()
roles = {"func": "function", "meth": "method"}
if self._role:
if self._role not in roles:
print(f"Warning: invalid role {self._role}")
out += f".. {roles.get(self._role, '')}:: {func_name}\n \n\n"
out += super().__str__(func_role=self._role)
return out
class ObjDoc(NumpyDocString):
def __init__(self, obj, doc=None, config=None):
self._f = obj
if config is None:
config = {}
NumpyDocString.__init__(self, doc, config=config)
class ClassDoc(NumpyDocString):
extra_public_methods = ["__call__"]
def __init__(self, cls, doc=None, modulename="", func_doc=FunctionDoc, config=None):
if not inspect.isclass(cls) and cls is not None:
raise ValueError(f"Expected a class or None, but got {cls!r}")
self._cls = cls
if "sphinx" in sys.modules:
from sphinx.ext.autodoc import ALL
else:
ALL = object()
if config is None:
config = {}
self.show_inherited_members = config.get("show_inherited_class_members", True)
if modulename and not modulename.endswith("."):
modulename += "."
self._mod = modulename
if doc is None:
if cls is None:
raise ValueError("No class or documentation string given")
doc = pydoc.getdoc(cls)
NumpyDocString.__init__(self, doc)
_members = config.get("members", [])
if _members is ALL:
_members = None
_exclude = config.get("exclude-members", [])
if config.get("show_class_members", True) and _exclude is not ALL:
def splitlines_x(s):
if not s:
return []
else:
return s.splitlines()
for field, items in [
("Methods", self.methods),
("Attributes", self.properties),
]:
if not self[field]:
doc_list = []
for name in sorted(items):
if name in _exclude or (_members and name not in _members):
continue
try:
doc_item = pydoc.getdoc(getattr(self._cls, name))
doc_list.append(Parameter(name, "", splitlines_x(doc_item)))
except AttributeError:
pass # method doesn't exist
self[field] = doc_list
@property
def methods(self):
if self._cls is None:
return []
return [
name
for name, func in inspect.getmembers(self._cls)
if (
(not name.startswith("_") or name in self.extra_public_methods)
and isinstance(func, Callable)
and self._is_show_member(name)
)
]
@property
def properties(self):
if self._cls is None:
return []
return [
name
for name, func in inspect.getmembers(self._cls)
if (
not name.startswith("_")
and not self._should_skip_member(name, self._cls)
and (
func is None
or isinstance(func, property | cached_property)
or inspect.isdatadescriptor(func)
)
and self._is_show_member(name)
)
]
@staticmethod
def _should_skip_member(name, klass):
return (
# Namedtuples should skip everything in their ._fields as the
# docstrings for each of the members is: "Alias for field number X"
issubclass(klass, tuple)
and hasattr(klass, "_asdict")
and hasattr(klass, "_fields")
and name in klass._fields
)
def _is_show_member(self, name):
return (
# show all class members
self.show_inherited_members
# or class member is not inherited
or name in self._cls.__dict__
)
def get_doc_object(
obj,
what=None,
doc=None,
config=None,
class_doc=ClassDoc,
func_doc=FunctionDoc,
obj_doc=ObjDoc,
):
if what is None:
if inspect.isclass(obj):
what = "class"
elif inspect.ismodule(obj):
what = "module"
elif isinstance(obj, Callable):
what = "function"
else:
what = "object"
if config is None:
config = {}
if what == "class":
return class_doc(obj, func_doc=func_doc, doc=doc, config=config)
elif what in ("function", "method"):
return func_doc(obj, doc=doc, config=config)
else:
if doc is None:
doc = pydoc.getdoc(obj)
return obj_doc(obj, doc, config=config)
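# --- Illustrative sketch (not part of the original module) -------------------
# A brief demonstration of parsing a numpydoc-style docstring; the sample
# docstring below is invented for this example.
if __name__ == "__main__":
    sample = """Add two numbers.

    Parameters
    ----------
    a : int
        First operand.
    b : int
        Second operand.

    Returns
    -------
    int
        The sum of `a` and `b`.
    """
    doc = NumpyDocString(sample)
    print(doc["Summary"])                       # ['Add two numbers.']
    print([p.name for p in doc["Parameters"]])  # ['a', 'b']
    print(doc["Returns"][0].type)               # 'int'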


@@ -0,0 +1,346 @@
# `_elementwise_iterative_method.py` includes tools for writing functions that
# - are vectorized to work elementwise on arrays,
# - implement non-trivial, iterative algorithms with a callback interface, and
# - return rich objects with iteration count, termination status, etc.
#
# Examples include:
# `scipy.optimize._chandrupatla._chandrupatla` for scalar rootfinding,
# `scipy.optimize._chandrupatla._chandrupatla_minimize` for scalar minimization,
# `scipy.optimize._differentiate._differentiate` for numerical differentiation,
# `scipy.optimize._bracket._bracket_root` for finding rootfinding brackets,
# `scipy.optimize._bracket._bracket_minimize` for finding minimization brackets,
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature,
# `scipy.differentiate.derivative` for finite difference based differentiation.
import math
import numpy as np
from ._util import _RichResult, _call_callback_maybe_halt
from ._array_api import array_namespace, xp_size, xp_result_type
import scipy._lib.array_api_extra as xpx
_ESIGNERR = -1
_ECONVERR = -2
_EVALUEERR = -3
_ECALLBACK = -4
_EINPUTERR = -5
_ECONVERGED = 0
_EINPROGRESS = 1
def _initialize(func, xs, args, complex_ok=False, preserve_shape=None, xp=None):
"""Initialize abscissa, function, and args arrays for elementwise function
Parameters
----------
func : callable
An elementwise function with signature
func(x: ndarray, *args) -> ndarray
where each element of ``x`` is a finite real and ``args`` is a tuple,
which may contain an arbitrary number of arrays that are broadcastable
with ``x``.
xs : tuple of arrays
Finite real abscissa arrays. Must be broadcastable.
args : tuple, optional
Additional positional arguments to be passed to `func`.
    preserve_shape : bool, default: False
        When ``preserve_shape=False`` (default), `func` may be passed
        arguments of any shape; `_scalar_optimization_loop` is permitted
        to reshape and compress arguments at will. When
        ``preserve_shape=True``, arguments passed to `func` must have shape
        `shape` or ``shape + (n,)``, where ``n`` is any integer.
xp : namespace
Namespace of array arguments in `xs`.
Returns
-------
xs, fs, args : tuple of arrays
Broadcasted, writeable, 1D abscissa and function value arrays (or
NumPy floats, if appropriate). The dtypes of the `xs` and `fs` are
        `xfat`; the dtypes of the `args` are unchanged.
shape : tuple of ints
Original shape of broadcasted arrays.
xfat : NumPy dtype
Result dtype of abscissae, function values, and args determined using
`np.result_type`, except integer types are promoted to `np.float64`.
Raises
------
ValueError
If the result dtype is not that of a real scalar
Notes
-----
Useful for initializing the input of SciPy functions that accept
an elementwise callable, abscissae, and arguments; e.g.
`scipy.optimize._chandrupatla`.
"""
nx = len(xs)
xp = array_namespace(*xs) if xp is None else xp
# Try to preserve `dtype`, but we need to ensure that the arguments are at
# least floats before passing them into the function; integers can overflow
# and cause failure.
# There might be benefit to combining the `xs` into a single array and
# calling `func` once on the combined array. For now, keep them separate.
xat = xp_result_type(*xs, force_floating=True, xp=xp)
xas = xp.broadcast_arrays(*xs, *args) # broadcast and rename
xs, args = xas[:nx], xas[nx:]
xs = [xp.asarray(x, dtype=xat) for x in xs] # use copy=False when implemented
fs = [xp.asarray(func(x, *args)) for x in xs]
shape = xs[0].shape
fshape = fs[0].shape
if preserve_shape:
# bind original shape/func now to avoid late-binding gotcha
def func(x, *args, shape=shape, func=func, **kwargs):
i = (0,)*(len(fshape) - len(shape))
return func(x[i], *args, **kwargs)
shape = np.broadcast_shapes(fshape, shape) # just shapes; use of NumPy OK
xs = [xp.broadcast_to(x, shape) for x in xs]
args = [xp.broadcast_to(arg, shape) for arg in args]
message = ("The shape of the array returned by `func` must be the same as "
"the broadcasted shape of `x` and all other `args`.")
if preserve_shape is not None: # only in tanhsinh for now
message = f"When `preserve_shape=False`, {message.lower()}"
shapes_equal = [f.shape == shape for f in fs]
if not all(shapes_equal): # use Python all to reduce overhead
raise ValueError(message)
# These algorithms tend to mix the dtypes of the abscissae and function
# values, so figure out what the result will be and convert them all to
# that type from the outset.
xfat = xp.result_type(*([f.dtype for f in fs] + [xat]))
if not complex_ok and not xp.isdtype(xfat, "real floating"):
raise ValueError("Abscissae and function output must be real numbers.")
xs = [xp.asarray(x, dtype=xfat, copy=True) for x in xs]
fs = [xp.asarray(f, dtype=xfat, copy=True) for f in fs]
# To ensure that we can do indexing, we'll work with at least 1d arrays,
# but remember the appropriate shape of the output.
xs = [xp.reshape(x, (-1,)) for x in xs]
fs = [xp.reshape(f, (-1,)) for f in fs]
args = [xp.reshape(xp.asarray(arg, copy=True), (-1,)) for arg in args]
return func, xs, fs, args, shape, xfat, xp
def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
post_func_eval, check_termination, post_termination_check,
customize_result, res_work_pairs, xp, preserve_shape=False):
"""Main loop of a vectorized scalar optimization algorithm
Parameters
----------
work : _RichResult
All variables that need to be retained between iterations. Must
contain attributes `nit`, `nfev`, and `success`. All arrays are
subject to being "compressed" if `preserve_shape is False`; nest
arrays that should not be compressed inside another object (e.g.
`dict` or `_RichResult`).
callback : callable
User-specified callback function
shape : tuple of ints
The shape of all output arrays
    maxiter : int
        Maximum number of iterations of the algorithm
func : callable
The user-specified callable that is being optimized or solved
args : tuple
Additional positional arguments to be passed to `func`.
dtype : NumPy dtype
The common dtype of all abscissae and function values
pre_func_eval : callable
A function that accepts `work` and returns `x`, the active elements
of `x` at which `func` will be evaluated. May modify attributes
of `work` with any algorithmic steps that need to happen
at the beginning of an iteration, before `func` is evaluated,
post_func_eval : callable
A function that accepts `x`, `func(x)`, and `work`. May modify
attributes of `work` with any algorithmic steps that need to happen
in the middle of an iteration, after `func` is evaluated but before
the termination check.
check_termination : callable
A function that accepts `work` and returns `stop`, a boolean array
indicating which of the active elements have met a termination
condition.
post_termination_check : callable
A function that accepts `work`. May modify `work` with any algorithmic
steps that need to happen after the termination check and before the
end of the iteration.
customize_result : callable
A function that accepts `res` and `shape` and returns `shape`. May
modify `res` (in-place) according to preferences (e.g. rearrange
elements between attributes) and modify `shape` if needed.
res_work_pairs : list of (str, str)
Identifies correspondence between attributes of `res` and attributes
of `work`; i.e., attributes of active elements of `work` will be
copied to the appropriate indices of `res` when appropriate. The order
determines the order in which _RichResult attributes will be
pretty-printed.
    preserve_shape : bool, default: False
        When False (the default), the attributes of `work` are compressed to
        avoid unnecessary computation on elements that have already converged;
        when True, the original array shapes are preserved throughout.
Returns
-------
res : _RichResult
The final result object
Notes
-----
Besides providing structure, this framework provides several important
services for a vectorized optimization algorithm.
- It handles common tasks involving iteration count, function evaluation
count, a user-specified callback, and associated termination conditions.
- It compresses the attributes of `work` to eliminate unnecessary
computation on elements that have already converged.
"""
if xp is None:
raise NotImplementedError("Must provide xp.")
cb_terminate = False
# Initialize the result object and active element index array
n_elements = math.prod(shape)
active = xp.arange(n_elements) # in-progress element indices
res_dict = {i: xp.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
res_dict['success'] = xp.zeros(n_elements, dtype=xp.bool)
res_dict['status'] = xp.full(n_elements, xp.asarray(_EINPROGRESS), dtype=xp.int32)
res_dict['nit'] = xp.zeros(n_elements, dtype=xp.int32)
res_dict['nfev'] = xp.zeros(n_elements, dtype=xp.int32)
res = _RichResult(res_dict)
work.args = args
active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape, xp)
if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True
while work.nit < maxiter and xp_size(active) and not cb_terminate and n_elements:
x = pre_func_eval(work)
if work.args and work.args[0].ndim != x.ndim:
# `x` always starts as 1D. If the SciPy function that uses
# _loop added dimensions to `x`, we need to
# add them to the elements of `args`.
args = []
for arg in work.args:
n_new_dims = x.ndim - arg.ndim
new_shape = arg.shape + (1,)*n_new_dims
args.append(xp.reshape(arg, new_shape))
work.args = args
x_shape = x.shape
if preserve_shape:
x = xp.reshape(x, (shape + (-1,)))
f = func(x, *work.args)
f = xp.asarray(f, dtype=dtype)
if preserve_shape:
x = xp.reshape(x, x_shape)
f = xp.reshape(f, x_shape)
work.nfev += 1 if x.ndim == 1 else x.shape[-1]
post_func_eval(x, f, work)
work.nit += 1
active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape, xp)
if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True
break
if xp_size(active) == 0:
break
post_termination_check(work)
work.status = xpx.at(work.status)[:].set(_ECALLBACK if cb_terminate else _ECONVERR)
return _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
def _check_termination(work, res, res_work_pairs, active, check_termination,
preserve_shape, xp):
# Checks termination conditions, updates elements of `res` with
# corresponding elements of `work`, and compresses `work`.
stop = check_termination(work)
if xp.any(stop):
# update the active elements of the result object with the active
# elements for which a termination condition has been met
_update_active(work, res, res_work_pairs, active, stop, preserve_shape, xp)
if preserve_shape:
stop = stop[active]
proceed = ~stop
active = active[proceed]
if not preserve_shape:
# compress the arrays to avoid unnecessary computation
for key, val in work.items():
# `continued_fraction` hacks `n`; improve if this becomes a problem
if key in {'args', 'n'}:
continue
work[key] = val[proceed] if getattr(val, 'ndim', 0) > 0 else val
work.args = [arg[proceed] for arg in work.args]
return active
def _update_active(work, res, res_work_pairs, active, mask, preserve_shape, xp):
# Update `active` indices of the arrays in result object `res` with the
# contents of the scalars and arrays in `update_dict`. When provided,
# `mask` is a boolean array applied both to the arrays in `update_dict`
# that are to be used and to the arrays in `res` that are to be updated.
update_dict = {key1: work[key2] for key1, key2 in res_work_pairs}
update_dict['success'] = work.status == 0
if mask is not None:
if preserve_shape:
active_mask = xp.zeros_like(mask)
active_mask = xpx.at(active_mask)[active].set(True)
active_mask = active_mask & mask
for key, val in update_dict.items():
val = val[active_mask] if getattr(val, 'ndim', 0) > 0 else val
res[key] = xpx.at(res[key])[active_mask].set(val)
else:
active_mask = active[mask]
for key, val in update_dict.items():
val = val[mask] if getattr(val, 'ndim', 0) > 0 else val
res[key] = xpx.at(res[key])[active_mask].set(val)
else:
for key, val in update_dict.items():
if preserve_shape and getattr(val, 'ndim', 0) > 0:
val = val[active]
res[key] = xpx.at(res[key])[active].set(val)
def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
preserve_shape, xp):
# Prepare the result object `res` by creating a copy, copying the latest
# data from work, running the provided result customization function,
# and reshaping the data to the original shapes.
res = res.copy()
_update_active(work, res, res_work_pairs, active, None, preserve_shape, xp)
shape = customize_result(res, shape)
for key, val in res.items():
# this looks like it won't work for xp != np if val is not numeric
temp = xp.reshape(val, shape)
res[key] = temp[()] if temp.ndim == 0 else temp
res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
return _RichResult(**res)
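# --- Illustrative sketch (not part of the original module) -------------------
# A self-contained toy version of the "active element" idea used by `_loop`
# above: converged entries are dropped from the working arrays so that later
# iterations only operate on unconverged elements. This is a simplified
# stand-in written for illustration, not the SciPy framework itself.
if __name__ == "__main__":
    def _toy_elementwise_bisect(f, a, b, args=(), maxiter=80, tol=1e-12):
        a = np.asarray(a, dtype=float).ravel().copy()
        b = np.asarray(b, dtype=float).ravel().copy()
        args = [np.broadcast_to(arg, a.shape).ravel() for arg in args]
        root = 0.5 * (a + b)
        active = np.arange(a.size)          # indices still being refined
        for _ in range(maxiter):
            if active.size == 0:
                break
            sub_args = [arg[active] for arg in args]
            m = 0.5 * (a[active] + b[active])
            same_sign = np.sign(f(a[active], *sub_args)) == np.sign(f(m, *sub_args))
            a[active] = np.where(same_sign, m, a[active])   # root is in [m, b]
            b[active] = np.where(same_sign, b[active], m)   # root is in [a, m]
            root[active] = 0.5 * (a[active] + b[active])
            done = (b[active] - a[active]) < tol
            active = active[~done]          # "compress" the active set
        return root
    c = np.array([1.0, 2.0, 4.0])
    # square roots of c; expected approximately [1., 1.41421356, 2.]
    print(_toy_elementwise_bisect(lambda x, c: x**2 - c,
                                  np.zeros(3), np.full(3, 3.0), args=(c,)))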


@@ -0,0 +1,105 @@
"""
Module for testing automatic garbage collection of objects
.. autosummary::
:toctree: generated/
set_gc_state - enable or disable garbage collection
gc_state - context manager for given state of garbage collector
assert_deallocated - context manager to check for circular references on object
"""
import weakref
import gc
from contextlib import contextmanager
from platform import python_implementation
__all__ = ['set_gc_state', 'gc_state', 'assert_deallocated']
IS_PYPY = python_implementation() == 'PyPy'
class ReferenceError(AssertionError):
pass
def set_gc_state(state):
""" Set status of garbage collector """
if gc.isenabled() == state:
return
if state:
gc.enable()
else:
gc.disable()
@contextmanager
def gc_state(state):
""" Context manager to set state of garbage collector to `state`
Parameters
----------
state : bool
True for gc enabled, False for disabled
Examples
--------
>>> with gc_state(False):
... assert not gc.isenabled()
>>> with gc_state(True):
... assert gc.isenabled()
"""
orig_state = gc.isenabled()
set_gc_state(state)
yield
set_gc_state(orig_state)
@contextmanager
def assert_deallocated(func, *args, **kwargs):
"""Context manager to check that object is deallocated
This is useful for checking that an object can be freed directly by
reference counting, without requiring gc to break reference cycles.
GC is disabled inside the context manager.
This check is not available on PyPy.
Parameters
----------
func : callable
Callable to create object to check
\\*args : sequence
positional arguments to `func` in order to create object to check
\\*\\*kwargs : dict
keyword arguments to `func` in order to create object to check
Examples
--------
>>> class C: pass
>>> with assert_deallocated(C) as c:
... # do something
... del c
>>> class C:
... def __init__(self):
... self._circular = self # Make circular reference
>>> with assert_deallocated(C) as c: #doctest: +IGNORE_EXCEPTION_DETAIL
... # do something
... del c
Traceback (most recent call last):
...
ReferenceError: Remaining reference(s) to object
"""
if IS_PYPY:
raise RuntimeError("assert_deallocated is unavailable on PyPy")
with gc_state(False):
obj = func(*args, **kwargs)
ref = weakref.ref(obj)
yield obj
del obj
if ref() is not None:
raise ReferenceError("Remaining reference(s) to object")
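# --- Illustrative sketch (not part of the original module) -------------------
# A quick check that gc_state restores the collector to its previous state.
if __name__ == "__main__":
    was_enabled = gc.isenabled()
    with gc_state(False):
        assert not gc.isenabled()
    print("gc state restored:", gc.isenabled() == was_enabled)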


@@ -0,0 +1,487 @@
"""Utility to compare pep440 compatible version strings.
The LooseVersion and StrictVersion classes that distutils provides don't
work; they don't recognize anything like alpha/beta/rc/dev versions.
"""
# Copyright (c) Donald Stufft and individual contributors.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import collections
import itertools
import re
__all__ = [
"parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN",
]
# BEGIN packaging/_structures.py
class Infinity:
def __repr__(self):
return "Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return False
def __le__(self, other):
return False
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return True
def __ge__(self, other):
return True
def __neg__(self):
return NegativeInfinity
Infinity = Infinity()
class NegativeInfinity:
def __repr__(self):
return "-Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return True
def __le__(self, other):
return True
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return False
def __ge__(self, other):
return False
def __neg__(self):
return Infinity
# BEGIN packaging/version.py
NegativeInfinity = NegativeInfinity()
_Version = collections.namedtuple(
"_Version",
["epoch", "release", "dev", "pre", "post", "local"],
)
def parse(version):
"""
Parse the given version string and return either a :class:`Version` object
or a :class:`LegacyVersion` object depending on if the given version is
a valid PEP 440 version or a legacy version.
"""
try:
return Version(version)
except InvalidVersion:
return LegacyVersion(version)
class InvalidVersion(ValueError):
"""
An invalid version was found, users should refer to PEP 440.
"""
class _BaseVersion:
def __hash__(self):
return hash(self._key)
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def _compare(self, other, method):
if not isinstance(other, _BaseVersion):
return NotImplemented
return method(self._key, other._key)
class LegacyVersion(_BaseVersion):
def __init__(self, version):
self._version = str(version)
self._key = _legacy_cmpkey(self._version)
def __str__(self):
return self._version
def __repr__(self):
return f"<LegacyVersion({repr(str(self))})>"
@property
def public(self):
return self._version
@property
def base_version(self):
return self._version
@property
def local(self):
return None
@property
def is_prerelease(self):
return False
@property
def is_postrelease(self):
return False
_legacy_version_component_re = re.compile(
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
)
_legacy_version_replacement_map = {
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
}
def _parse_version_parts(s):
for part in _legacy_version_component_re.split(s):
part = _legacy_version_replacement_map.get(part, part)
if not part or part == ".":
continue
if part[:1] in "0123456789":
# pad for numeric comparison
yield part.zfill(8)
else:
yield "*" + part
# ensure that alpha/beta/candidate are before final
yield "*final"
def _legacy_cmpkey(version):
# We hardcode an epoch of -1 here. A PEP 440 version can only have an epoch
# greater than or equal to 0. This will effectively put the LegacyVersion,
    # which uses the de facto standard originally implemented by setuptools,
    # before all PEP 440 versions.
epoch = -1
    # This scheme is taken from setuptools' pkg_resources.parse_version, prior
    # to its adoption of the packaging library.
parts = []
for part in _parse_version_parts(version.lower()):
if part.startswith("*"):
# remove "-" before a prerelease tag
if part < "*final":
while parts and parts[-1] == "*final-":
parts.pop()
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
parts = tuple(parts)
return epoch, parts
# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
_regex = re.compile(
r"^\s*" + VERSION_PATTERN + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, version):
# Validate the version and parse it into pieces
match = self._regex.search(version)
if not match:
raise InvalidVersion(f"Invalid version: '{version}'")
# Store the parsed out pieces of the version
self._version = _Version(
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
release=tuple(int(i) for i in match.group("release").split(".")),
pre=_parse_letter_version(
match.group("pre_l"),
match.group("pre_n"),
),
post=_parse_letter_version(
match.group("post_l"),
match.group("post_n1") or match.group("post_n2"),
),
dev=_parse_letter_version(
match.group("dev_l"),
match.group("dev_n"),
),
local=_parse_local_version(match.group("local")),
)
# Generate a key which will be used for sorting
self._key = _cmpkey(
self._version.epoch,
self._version.release,
self._version.pre,
self._version.post,
self._version.dev,
self._version.local,
)
def __repr__(self):
return f"<Version({repr(str(self))})>"
def __str__(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append(f"{self._version.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
# Pre-release
if self._version.pre is not None:
parts.append("".join(str(x) for x in self._version.pre))
# Post-release
if self._version.post is not None:
parts.append(f".post{self._version.post[1]}")
# Development release
if self._version.dev is not None:
parts.append(f".dev{self._version.dev[1]}")
# Local version segment
if self._version.local is not None:
parts.append(
"+{}".format(".".join(str(x) for x in self._version.local))
)
return "".join(parts)
@property
def public(self):
return str(self).split("+", 1)[0]
@property
def base_version(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append(f"{self._version.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
return "".join(parts)
@property
def local(self):
version_string = str(self)
if "+" in version_string:
return version_string.split("+", 1)[1]
@property
def is_prerelease(self):
return bool(self._version.dev or self._version.pre)
@property
def is_postrelease(self):
return bool(self._version.post)
def _parse_letter_version(letter, number):
if letter:
# We assume there is an implicit 0 in a pre-release if there is
# no numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower-case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume that if we are given a number but not given a letter,
# then this is using the implicit post release syntax (e.g., 1.0-1)
letter = "post"
return letter, int(number)
_local_version_seperators = re.compile(r"[\._-]")
def _parse_local_version(local):
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in _local_version_seperators.split(local)
)
def _cmpkey(epoch, release, pre, post, dev, local):
# When we compare a release version, we want to compare it with all of the
    # trailing zeros removed. So we'll reverse the list, drop all the now-
    # leading zeros until we come to something non-zero, then take the rest,
    # re-reverse it back into the correct order, and use that tuple as our
    # sorting key.
release = tuple(
reversed(list(
itertools.dropwhile(
lambda x: x == 0,
reversed(release),
)
))
)
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
# We'll do this by abusing the pre-segment, but we _only_ want to do this
# if there is no pre- or a post-segment. If we have one of those, then
# the normal sorting rules will handle this case correctly.
if pre is None and post is None and dev is not None:
pre = -Infinity
# Versions without a pre-release (except as noted above) should sort after
# those with one.
elif pre is None:
pre = Infinity
# Versions without a post-segment should sort before those with one.
if post is None:
post = -Infinity
# Versions without a development segment should sort after those with one.
if dev is None:
dev = Infinity
if local is None:
# Versions without a local segment should sort before those with one.
local = -Infinity
else:
# Versions with a local segment need that segment parsed to implement
# the sorting rules in PEP440.
# - Alphanumeric segments sort before numeric segments
# - Alphanumeric segments sort lexicographically
# - Numeric segments sort numerically
# - Shorter versions sort before longer versions when the prefixes
# match exactly
local = tuple(
(i, "") if isinstance(i, int) else (-Infinity, i)
for i in local
)
return epoch, release, pre, post, dev, local
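# --- Illustrative sketch (not part of the original module) -------------------
# A brief demonstration of PEP 440 ordering with the vendored classes above.
if __name__ == "__main__":
    versions = ["1.0.post1", "1.0", "1.0rc1", "1.0a1", "1.0.dev0", "1.1"]
    print(sorted(versions, key=Version))
    # expected: ['1.0.dev0', '1.0a1', '1.0rc1', '1.0', '1.0.post1', '1.1']
    print(repr(parse("not-a-pep440-version")))   # falls back to LegacyVersion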


@@ -0,0 +1,41 @@
from abc import ABC
__all__ = ["SparseABC", "issparse"]
class SparseABC(ABC):
pass
def issparse(x):
"""Is `x` of a sparse array or sparse matrix type?
Parameters
----------
x
object to check for being a sparse array or sparse matrix
Returns
-------
bool
True if `x` is a sparse array or a sparse matrix, False otherwise
Notes
-----
    Use `isinstance(x, sp.sparse.sparray)` to distinguish a sparse array from a
    sparse matrix.
Use `a.format` to check the sparse format, e.g. `a.format == 'csr'`.
Examples
--------
>>> import numpy as np
>>> from scipy.sparse import csr_array, csr_matrix, issparse
>>> issparse(csr_matrix([[5]]))
True
>>> issparse(csr_array([[5]]))
True
>>> issparse(np.array([[5]]))
False
>>> issparse(5)
False
"""
return isinstance(x, SparseABC)
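# --- Illustrative sketch (not part of the original module) -------------------
# issparse is simply an isinstance check against SparseABC, so any class that
# inherits from (or is registered with) SparseABC is reported as sparse. The
# toy class below is hypothetical and exists only to show the mechanism.
if __name__ == "__main__":
    class _ToySparse(SparseABC):
        pass
    print(issparse(_ToySparse()))   # True
    print(issparse([[0, 1]]))       # False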


@@ -0,0 +1,373 @@
"""
Generic test utilities.
"""
import inspect
import os
import re
import shutil
import subprocess
import sys
import sysconfig
import threading
from importlib.util import module_from_spec, spec_from_file_location
import numpy as np
import scipy
try:
# Need type: ignore[import-untyped] for mypy >= 1.6
import cython # type: ignore[import-untyped]
from Cython.Compiler.Version import ( # type: ignore[import-untyped]
version as cython_version,
)
except ImportError:
cython = None
else:
from scipy._lib import _pep440
required_version = '3.0.8'
if _pep440.parse(cython_version) < _pep440.Version(required_version):
# too old or wrong cython, skip Cython API tests
cython = None
__all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc', 'IS_MUSL']
IS_MUSL = False
# alternate way is
# from packaging.tags import sys_tags
# _tags = list(sys_tags())
# if 'musllinux' in _tags[0].platform:
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
if 'musl' in _v:
IS_MUSL = True
IS_EDITABLE = 'editable' in scipy.__path__[0]
class FPUModeChangeWarning(RuntimeWarning):
"""Warning about FPU mode change"""
pass
class PytestTester:
"""
Run tests for this namespace
``scipy.test()`` runs tests for all of SciPy, with the default settings.
    When used from a submodule (e.g., ``scipy.cluster.test()``), only the tests
for that namespace are run.
Parameters
----------
label : {'fast', 'full'}, optional
Whether to run only the fast tests, or also those marked as slow.
Default is 'fast'.
verbose : int, optional
Test output verbosity. Default is 1.
extra_argv : list, optional
Arguments to pass through to Pytest.
doctests : bool, optional
Whether to run doctests or not. Default is False.
coverage : bool, optional
Whether to run tests with code coverage measurements enabled.
Default is False.
tests : list of str, optional
List of module names to run tests for. By default, uses the module
from which the ``test`` function is called.
parallel : int, optional
Run tests in parallel with pytest-xdist, if number given is larger than
1. Default is 1.
"""
def __init__(self, module_name):
self.module_name = module_name
def __call__(self, label="fast", verbose=1, extra_argv=None, doctests=False,
coverage=False, tests=None, parallel=None):
import pytest
module = sys.modules[self.module_name]
module_path = os.path.abspath(module.__path__[0])
pytest_args = ['--showlocals', '--tb=short']
if extra_argv is None:
extra_argv = []
pytest_args += extra_argv
if any(arg == "-m" or arg == "--markers" for arg in extra_argv):
# Likely conflict with default --mode=fast
raise ValueError("Must specify -m before --")
if verbose and int(verbose) > 1:
pytest_args += ["-" + "v"*(int(verbose)-1)]
if coverage:
pytest_args += ["--cov=" + module_path]
if label == "fast":
pytest_args += ["-m", "not slow"]
elif label != "full":
pytest_args += ["-m", label]
if tests is None:
tests = [self.module_name]
if parallel is not None and parallel > 1:
if _pytest_has_xdist():
pytest_args += ['-n', str(parallel)]
else:
import warnings
warnings.warn('Could not run tests in parallel because '
'pytest-xdist plugin is not available.',
stacklevel=2)
pytest_args += ['--pyargs'] + list(tests)
try:
code = pytest.main(pytest_args)
except SystemExit as exc:
code = exc.code
return (code == 0)
class _TestPythranFunc:
'''
These are situations that can be tested in our pythran tests:
- A function with multiple array arguments and then
other positional and keyword arguments.
    - A function with array-like keywords (e.g. `def somefunc(x0, x1=None)`).
    Note: list/tuple input is not yet tested!
    `self.arguments`: A dictionary whose keys are argument indices and whose
                      values are tuples of (array value, all supported dtypes).
    `self.partialfunc`: A function used to freeze some non-array arguments
                        that are of no interest to the original function.
'''
ALL_INTEGER = [np.int8, np.int16, np.int32, np.int64, np.intc, np.intp]
ALL_FLOAT = [np.float32, np.float64]
ALL_COMPLEX = [np.complex64, np.complex128]
def setup_method(self):
self.arguments = {}
self.partialfunc = None
self.expected = None
def get_optional_args(self, func):
# get optional arguments with its default value,
# used for testing keywords
signature = inspect.signature(func)
optional_args = {}
for k, v in signature.parameters.items():
if v.default is not inspect.Parameter.empty:
optional_args[k] = v.default
return optional_args
def get_max_dtype_list_length(self):
# get the max supported dtypes list length in all arguments
max_len = 0
for arg_idx in self.arguments:
cur_len = len(self.arguments[arg_idx][1])
if cur_len > max_len:
max_len = cur_len
return max_len
def get_dtype(self, dtype_list, dtype_idx):
# get the dtype from dtype_list via index
# if the index is out of range, then return the last dtype
if dtype_idx > len(dtype_list)-1:
return dtype_list[-1]
else:
return dtype_list[dtype_idx]
def test_all_dtypes(self):
for type_idx in range(self.get_max_dtype_list_length()):
args_array = []
for arg_idx in self.arguments:
new_dtype = self.get_dtype(self.arguments[arg_idx][1],
type_idx)
args_array.append(self.arguments[arg_idx][0].astype(new_dtype))
self.pythranfunc(*args_array)
def test_views(self):
args_array = []
for arg_idx in self.arguments:
args_array.append(self.arguments[arg_idx][0][::-1][::-1])
self.pythranfunc(*args_array)
def test_strided(self):
args_array = []
for arg_idx in self.arguments:
args_array.append(np.repeat(self.arguments[arg_idx][0],
2, axis=0)[::2])
self.pythranfunc(*args_array)
def _pytest_has_xdist():
"""
Check if the pytest-xdist plugin is installed, providing parallel tests
"""
# Check xdist exists without importing, otherwise pytests emits warnings
from importlib.util import find_spec
return find_spec('xdist') is not None
def check_free_memory(free_mb):
"""
Check *free_mb* of memory is available, otherwise do pytest.skip
"""
import pytest
try:
mem_free = _parse_size(os.environ['SCIPY_AVAILABLE_MEM'])
msg = '{} MB memory required, but environment SCIPY_AVAILABLE_MEM={}'.format(
free_mb, os.environ['SCIPY_AVAILABLE_MEM'])
except KeyError:
mem_free = _get_mem_available()
if mem_free is None:
pytest.skip("Could not determine available memory; set SCIPY_AVAILABLE_MEM "
"variable to free memory in MB to run the test.")
msg = f'{free_mb} MB memory required, but {mem_free/1e6} MB available'
if mem_free < free_mb * 1e6:
pytest.skip(msg)
def _parse_size(size_str):
suffixes = {'': 1e6,
'b': 1.0,
'k': 1e3, 'M': 1e6, 'G': 1e9, 'T': 1e12,
'kb': 1e3, 'Mb': 1e6, 'Gb': 1e9, 'Tb': 1e12,
'kib': 1024.0, 'Mib': 1024.0**2, 'Gib': 1024.0**3, 'Tib': 1024.0**4}
m = re.match(r'^\s*(\d+)\s*({})\s*$'.format('|'.join(suffixes.keys())),
size_str,
re.I)
if not m or m.group(2) not in suffixes:
raise ValueError("Invalid size string")
return float(m.group(1)) * suffixes[m.group(2)]
def _get_mem_available():
"""
Get information about memory available, not counting swap.
"""
try:
import psutil
return psutil.virtual_memory().available
except (ImportError, AttributeError):
pass
if sys.platform.startswith('linux'):
info = {}
with open('/proc/meminfo') as f:
for line in f:
p = line.split()
info[p[0].strip(':').lower()] = float(p[1]) * 1e3
if 'memavailable' in info:
# Linux >= 3.14
return info['memavailable']
else:
return info['memfree'] + info['cached']
return None
def _test_cython_extension(tmp_path, srcdir):
"""
Helper function to test building and importing Cython modules that
make use of the Cython APIs for BLAS, LAPACK, optimize, and special.
"""
import pytest
try:
subprocess.check_call(["meson", "--version"])
except FileNotFoundError:
pytest.skip("No usable 'meson' found")
# Make safe for being called by multiple threads within one test
tmp_path = tmp_path / str(threading.get_ident())
# build the examples in a temporary directory
mod_name = os.path.split(srcdir)[1]
shutil.copytree(srcdir, tmp_path / mod_name)
build_dir = tmp_path / mod_name / 'tests' / '_cython_examples'
target_dir = build_dir / 'build'
os.makedirs(target_dir, exist_ok=True)
# Ensure we use the correct Python interpreter even when `meson` is
# installed in a different Python environment (see numpy#24956)
native_file = str(build_dir / 'interpreter-native-file.ini')
with open(native_file, 'w') as f:
f.write("[binaries]\n")
f.write(f"python = '{sys.executable}'")
if sys.platform == "win32":
subprocess.check_call(["meson", "setup",
"--buildtype=release",
"--native-file", native_file,
"--vsenv", str(build_dir)],
cwd=target_dir,
)
else:
subprocess.check_call(["meson", "setup",
"--native-file", native_file, str(build_dir)],
cwd=target_dir
)
subprocess.check_call(["meson", "compile", "-vv"], cwd=target_dir)
# import without adding the directory to sys.path
suffix = sysconfig.get_config_var('EXT_SUFFIX')
def load(modname):
so = (target_dir / modname).with_suffix(suffix)
spec = spec_from_file_location(modname, so)
mod = module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# test that the module can be imported
return load("extending"), load("extending_cpp")
def _run_concurrent_barrier(n_workers, fn, *args, **kwargs):
"""
Run a given function concurrently across a given number of threads.
This is equivalent to using a ThreadPoolExecutor, but using the threading
primitives instead. This function ensures that the closure passed by
parameter gets called concurrently by setting up a barrier before it gets
called before any of the threads.
Arguments
---------
n_workers: int
Number of concurrent threads to spawn.
fn: callable
Function closure to execute concurrently. Its first argument will
be the thread id.
*args: tuple
Variable number of positional arguments to pass to the function.
**kwargs: dict
Keyword arguments to pass to the function.
"""
barrier = threading.Barrier(n_workers)
def closure(i, *args, **kwargs):
barrier.wait()
fn(i, *args, **kwargs)
workers = []
for i in range(0, n_workers):
workers.append(threading.Thread(
target=closure,
args=(i,) + args, kwargs=kwargs))
for worker in workers:
worker.start()
for worker in workers:
worker.join()
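# Illustrative usage sketch (a hedged example, not part of the original
# file). A thread-safety test might drive a function under test like this,
# where `results` and `fn_under_test` are hypothetical names:
#
#     results = [None] * 4
#
#     def worker(tid):
#         results[tid] = fn_under_test()
#
#     _run_concurrent_barrier(4, worker)
#
# All four workers reach the barrier before any of them calls
# `fn_under_test`, so the calls genuinely overlap instead of running one
# after another as the threads happen to start.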

View file

@@ -0,0 +1,58 @@
import threading
import scipy._lib.decorator
__all__ = ['ReentrancyError', 'ReentrancyLock', 'non_reentrant']
class ReentrancyError(RuntimeError):
pass
class ReentrancyLock:
"""
Threading lock that raises an exception for reentrant calls.
Calls from different threads are serialized, and nested calls from the
    same thread result in an error.
The object can be used as a context manager or to decorate functions
via the decorate() method.
"""
def __init__(self, err_msg):
self._rlock = threading.RLock()
self._entered = False
self._err_msg = err_msg
def __enter__(self):
self._rlock.acquire()
if self._entered:
self._rlock.release()
raise ReentrancyError(self._err_msg)
self._entered = True
def __exit__(self, type, value, traceback):
self._entered = False
self._rlock.release()
def decorate(self, func):
def caller(func, *a, **kw):
with self:
return func(*a, **kw)
return scipy._lib.decorator.decorate(func, caller)
def non_reentrant(err_msg=None):
"""
Decorate a function with a threading lock and prevent reentrant calls.
"""
def decorator(func):
msg = err_msg
if msg is None:
msg = f"{func.__name__} is not re-entrant"
lock = ReentrancyLock(msg)
return lock.decorate(func)
return decorator
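# Illustrative usage sketch (a hedged example, not part of the original
# file):
#
#     @non_reentrant()
#     def solve():
#         ...  # calling solve() again from inside this body raises
#              # ReentrancyError("solve is not re-entrant")
#
# Calls from different threads are still serialized by the underlying RLock;
# only nested calls from the same thread trigger the error.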

View file

@@ -0,0 +1,86 @@
''' Contexts for *with* statement providing temporary directories
'''
import os
from contextlib import contextmanager
from shutil import rmtree
from tempfile import mkdtemp
@contextmanager
def tempdir():
"""Create and return a temporary directory. This has the same
behavior as mkdtemp but can be used as a context manager.
Upon exiting the context, the directory and everything contained
in it are removed.
Examples
--------
>>> import os
>>> with tempdir() as tmpdir:
... fname = os.path.join(tmpdir, 'example_file.txt')
... with open(fname, 'wt') as fobj:
... _ = fobj.write('a string\\n')
>>> os.path.exists(tmpdir)
False
"""
d = mkdtemp()
yield d
rmtree(d)
@contextmanager
def in_tempdir():
''' Create, return, and change directory to a temporary directory
Examples
--------
>>> import os
>>> my_cwd = os.getcwd()
>>> with in_tempdir() as tmpdir:
... _ = open('test.txt', 'wt').write('some text')
... assert os.path.isfile('test.txt')
... assert os.path.isfile(os.path.join(tmpdir, 'test.txt'))
>>> os.path.exists(tmpdir)
False
>>> os.getcwd() == my_cwd
True
'''
pwd = os.getcwd()
d = mkdtemp()
os.chdir(d)
yield d
os.chdir(pwd)
rmtree(d)
@contextmanager
def in_dir(dir=None):
""" Change directory to given directory for duration of ``with`` block
Useful when you want to use `in_tempdir` for the final test, but
you are still debugging. For example, you may want to do this in the end:
>>> with in_tempdir() as tmpdir:
... # do something complicated which might break
... pass
But, indeed, the complicated thing does break, and meanwhile, the
``in_tempdir`` context manager wiped out the directory with the
temporary files that you wanted for debugging. So, while debugging, you
replace with something like:
>>> with in_dir() as tmpdir: # Use working directory by default
... # do something complicated which might break
... pass
You can then look at the temporary file outputs to debug what is happening,
fix, and finally replace ``in_dir`` with ``in_tempdir`` again.
"""
cwd = os.getcwd()
if dir is None:
yield cwd
return
os.chdir(dir)
yield dir
os.chdir(cwd)

View file

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2018, Quansight-Labs
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@@ -0,0 +1,116 @@
"""
.. note::
If you are looking for overrides for NumPy-specific methods, see the
documentation for :obj:`unumpy`. This page explains how to write
back-ends and multimethods.
``uarray`` is built around a back-end protocol, and overridable multimethods.
It is necessary to define multimethods for back-ends to be able to override them.
See the documentation of :obj:`generate_multimethod` on how to write multimethods.
Let's start with the simplest:
``__ua_domain__`` defines the back-end *domain*. The domain is a period-
separated string consisting of the modules you extend plus the submodule. For
example, if a submodule ``module2.submodule`` extends ``module1``
(i.e., it exposes dispatchables marked as types available in ``module1``),
then the domain string should be ``"module1.module2.submodule"``.
For the purpose of this demonstration, we'll be creating an object and setting
its attributes directly. However, note that you can use a module or your own type
as a backend as well.
>>> class Backend: pass
>>> be = Backend()
>>> be.__ua_domain__ = "ua_examples"
It might be useful at this point to sidetrack to the documentation of
:obj:`generate_multimethod` to find out how to generate a multimethod
overridable by :obj:`uarray`. Needless to say, writing a backend and
creating multimethods are mostly orthogonal activities, and knowing
one doesn't necessarily require knowledge of the other, although it
is certainly helpful. We expect core API designers/specifiers to write the
multimethods, and implementors to override them. But, as is often the case,
similar people write both.
Without further ado, here's an example multimethod:
>>> import uarray as ua
>>> from uarray import Dispatchable
>>> def override_me(a, b):
... return Dispatchable(a, int),
>>> def override_replacer(args, kwargs, dispatchables):
... return (dispatchables[0], args[1]), {}
>>> overridden_me = ua.generate_multimethod(
... override_me, override_replacer, "ua_examples"
... )
Next comes the part about overriding the multimethod. This requires
the ``__ua_function__`` protocol, and the ``__ua_convert__``
protocol. The ``__ua_function__`` protocol has the signature
``(method, args, kwargs)`` where ``method`` is the passed
multimethod and ``args``/``kwargs`` are the arguments to call it with; any
dispatchables have already been replaced by their converted values.
>>> def __ua_function__(method, args, kwargs):
... return method.__name__, args, kwargs
>>> be.__ua_function__ = __ua_function__
The other protocol of interest is the ``__ua_convert__`` protocol. It has the
signature ``(dispatchables, coerce)``. When ``coerce`` is ``False``, conversion
between the formats should ideally be an ``O(1)`` operation; in particular,
no memory copying should be involved, only views of the existing data.
>>> def __ua_convert__(dispatchables, coerce):
... for d in dispatchables:
... if d.type is int:
... if coerce and d.coercible:
... yield str(d.value)
... else:
... yield d.value
>>> be.__ua_convert__ = __ua_convert__
Now that we have defined the backend, the next thing to do is to call the multimethod.
>>> with ua.set_backend(be):
... overridden_me(1, "2")
('override_me', (1, '2'), {})
Note that the marked type has no effect on the actual type of the passed object.
We can also coerce the type of the input.
>>> with ua.set_backend(be, coerce=True):
... overridden_me(1, "2")
... overridden_me(1.0, "2")
('override_me', ('1', '2'), {})
('override_me', ('1.0', '2'), {})
Another feature is that if you remove ``__ua_convert__``, the arguments are not
converted at all and it's up to the backend to handle that.
>>> del be.__ua_convert__
>>> with ua.set_backend(be):
... overridden_me(1, "2")
('override_me', (1, '2'), {})
You also have the option to return ``NotImplemented``, in which case processing moves on
to the next back-end, which in this case, doesn't exist. The same applies to
``__ua_convert__``.
>>> be.__ua_function__ = lambda *a, **kw: NotImplemented
>>> with ua.set_backend(be):
... overridden_me(1, "2")
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
The last possibility is if we don't have ``__ua_convert__``, in which case the job is
left up to ``__ua_function__``, but putting things back into arrays after conversion
will not be possible.
"""
from ._backend import *
__version__ = '0.8.8.dev0+aa94c5a4.scipy'

View file

@@ -0,0 +1,707 @@
import typing
import types
import inspect
import functools
from . import _uarray
import copyreg
import pickle
import contextlib
import threading
from ._uarray import ( # type: ignore
BackendNotImplementedError,
_Function,
_SkipBackendContext,
_SetBackendContext,
_BackendState,
)
__all__ = [
"set_backend",
"set_global_backend",
"skip_backend",
"register_backend",
"determine_backend",
"determine_backend_multi",
"clear_backends",
"create_multimethod",
"generate_multimethod",
"_Function",
"BackendNotImplementedError",
"Dispatchable",
"wrap_single_convertor",
"wrap_single_convertor_instance",
"all_of_type",
"mark_as",
"set_state",
"get_state",
"reset_state",
"_BackendState",
"_SkipBackendContext",
"_SetBackendContext",
]
ArgumentExtractorType = typing.Callable[..., tuple["Dispatchable", ...]]
ArgumentReplacerType = typing.Callable[
[tuple, dict, tuple], tuple[tuple, dict]
]
def unpickle_function(mod_name, qname, self_):
import importlib
try:
module = importlib.import_module(mod_name)
qname = qname.split(".")
func = module
for q in qname:
func = getattr(func, q)
if self_ is not None:
func = types.MethodType(func, self_)
return func
except (ImportError, AttributeError) as e:
from pickle import UnpicklingError
raise UnpicklingError from e
def pickle_function(func):
mod_name = getattr(func, "__module__", None)
qname = getattr(func, "__qualname__", None)
self_ = getattr(func, "__self__", None)
try:
test = unpickle_function(mod_name, qname, self_)
except pickle.UnpicklingError:
test = None
if test is not func:
raise pickle.PicklingError(
f"Can't pickle {func}: it's not the same object as {test}"
)
return unpickle_function, (mod_name, qname, self_)
def pickle_state(state):
return _uarray._BackendState._unpickle, state._pickle()
def pickle_set_backend_context(ctx):
return _SetBackendContext, ctx._pickle()
def pickle_skip_backend_context(ctx):
return _SkipBackendContext, ctx._pickle()
copyreg.pickle(_Function, pickle_function)
copyreg.pickle(_uarray._BackendState, pickle_state)
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
def get_state():
"""
Returns an opaque object containing the current state of all the backends.
Can be used for synchronization between threads/processes.
See Also
--------
set_state
Sets the state returned by this function.
"""
return _uarray.get_state()
@contextlib.contextmanager
def reset_state():
"""
Returns a context manager that resets all state once exited.
See Also
--------
set_state
Context manager that sets the backend state.
get_state
Gets a state to be set by this context manager.
"""
with set_state(get_state()):
yield
@contextlib.contextmanager
def set_state(state):
"""
A context manager that sets the state of the backends to one returned by :obj:`get_state`.
See Also
--------
get_state
Gets a state to be set by this context manager.
""" # noqa: E501
old_state = get_state()
_uarray.set_state(state)
try:
yield
finally:
_uarray.set_state(old_state, True)
def create_multimethod(*args, **kwargs):
"""
Creates a decorator for generating multimethods.
This function creates a decorator that can be used with an argument
extractor in order to generate a multimethod. Other than for the
argument extractor, all arguments are passed on to
:obj:`generate_multimethod`.
See Also
--------
generate_multimethod
Generates a multimethod.
"""
def wrapper(a):
return generate_multimethod(a, *args, **kwargs)
return wrapper
def generate_multimethod(
argument_extractor: ArgumentExtractorType,
argument_replacer: ArgumentReplacerType,
domain: str,
default: typing.Callable | None = None,
):
"""
Generates a multimethod.
Parameters
----------
argument_extractor : ArgumentExtractorType
A callable which extracts the dispatchable arguments. Extracted arguments
should be marked by the :obj:`Dispatchable` class. It has the same signature
as the desired multimethod.
argument_replacer : ArgumentReplacerType
A callable with the signature (args, kwargs, dispatchables), which should also
return an (args, kwargs) pair with the dispatchables replaced inside the
args/kwargs.
domain : str
A string value indicating the domain of this multimethod.
default: Optional[Callable], optional
The default implementation of this multimethod, where ``None`` (the default)
specifies there is no default implementation.
Examples
--------
In this example, ``a`` is to be dispatched over, so we return it, while marking it
as an ``int``.
The trailing comma is needed because the args have to be returned as an iterable.
>>> def override_me(a, b):
... return Dispatchable(a, int),
Next, we define the argument replacer that replaces the dispatchables inside
args/kwargs with the supplied ones.
>>> def override_replacer(args, kwargs, dispatchables):
... return (dispatchables[0], args[1]), {}
Next, we define the multimethod.
>>> overridden_me = generate_multimethod(
... override_me, override_replacer, "ua_examples"
... )
Notice that there's no default implementation, unless you supply one.
>>> overridden_me(1, "a")
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
>>> overridden_me2 = generate_multimethod(
... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
... )
>>> overridden_me2(1, "a")
(1, 'a')
See Also
--------
uarray
See the module documentation for how to override the method by creating
backends.
"""
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
ua_func = _Function(
argument_extractor,
argument_replacer,
domain,
arg_defaults,
kw_defaults,
default,
)
return functools.update_wrapper(ua_func, argument_extractor)
def set_backend(backend, coerce=False, only=False):
"""
A context manager that sets the preferred backend.
Parameters
----------
backend
The backend to set.
coerce
Whether or not to coerce to a specific backend's types. Implies ``only``.
only
Whether or not this should be the last backend to try.
See Also
--------
skip_backend: A context manager that allows skipping of backends.
set_global_backend: Set a single, global backend for a domain.
"""
tid = threading.get_native_id()
try:
return backend.__ua_cache__[tid, "set", coerce, only]
except AttributeError:
backend.__ua_cache__ = {}
except KeyError:
pass
ctx = _SetBackendContext(backend, coerce, only)
backend.__ua_cache__[tid, "set", coerce, only] = ctx
return ctx
def skip_backend(backend):
"""
A context manager that allows one to skip a given backend from processing
entirely. This allows one to use another backend's code in a library that
is also a consumer of the same backend.
Parameters
----------
backend
The backend to skip.
See Also
--------
set_backend: A context manager that allows setting of backends.
set_global_backend: Set a single, global backend for a domain.
"""
tid = threading.get_native_id()
try:
return backend.__ua_cache__[tid, "skip"]
except AttributeError:
backend.__ua_cache__ = {}
except KeyError:
pass
ctx = _SkipBackendContext(backend)
backend.__ua_cache__[tid, "skip"] = ctx
return ctx
def get_defaults(f):
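    # Descriptive note (added comment, not in the original): this helper walks
    # the argument extractor's signature and collects (1) a dict of keyword
    # defaults for every parameter that has one, (2) a tuple with the default
    # (or inspect.Parameter.empty) of each positional parameter, and (3) the
    # set of all parameter names. The results are handed to the C-level
    # _Function, presumably so the dispatcher can recognise arguments that
    # were left at their defaults.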
sig = inspect.signature(f)
kw_defaults = {}
arg_defaults = []
opts = set()
for k, v in sig.parameters.items():
if v.default is not inspect.Parameter.empty:
kw_defaults[k] = v.default
if v.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
arg_defaults.append(v.default)
opts.add(k)
return kw_defaults, tuple(arg_defaults), opts
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
"""
This utility method replaces the default backend for permanent use. It
will be tried in the list of backends automatically, unless the
``only`` flag is set on a backend. This will be the first tried
backend outside the :obj:`set_backend` context manager.
Note that this method is not thread-safe.
.. warning::
We caution library authors against using this function in
their code. We do *not* support this use-case. This function
is meant to be used only by users themselves, or by a reference
implementation, if one exists.
Parameters
----------
backend
The backend to register.
coerce : bool
Whether to coerce input types when trying this backend.
only : bool
If ``True``, no more backends will be tried if this fails.
Implied by ``coerce=True``.
try_last : bool
If ``True``, the global backend is tried after registered backends.
See Also
--------
set_backend: A context manager that allows setting of backends.
skip_backend: A context manager that allows skipping of backends.
"""
_uarray.set_global_backend(backend, coerce, only, try_last)
def register_backend(backend):
"""
    This utility method registers a backend for permanent use. It
will be tried in the list of backends automatically, unless the
``only`` flag is set on a backend.
Note that this method is not thread-safe.
Parameters
----------
backend
The backend to register.
"""
_uarray.register_backend(backend)
def clear_backends(domain, registered=True, globals=False):
"""
This utility method clears registered backends.
.. warning::
We caution library authors against using this function in
their code. We do *not* support this use-case. This function
is meant to be used only by users themselves.
.. warning::
Do NOT use this method inside a multimethod call, or the
program is likely to crash.
Parameters
----------
domain : Optional[str]
The domain for which to de-register backends. ``None`` means
de-register for all domains.
registered : bool
Whether or not to clear registered backends. See :obj:`register_backend`.
globals : bool
Whether or not to clear global backends. See :obj:`set_global_backend`.
See Also
--------
register_backend : Register a backend globally.
set_global_backend : Set a global backend.
"""
_uarray.clear_backends(domain, registered, globals)
class Dispatchable:
"""
A utility class which marks an argument with a specific dispatch type.
Attributes
----------
value
The value of the Dispatchable.
type
The type of the Dispatchable.
Examples
--------
>>> x = Dispatchable(1, str)
>>> x
<Dispatchable: type=<class 'str'>, value=1>
See Also
--------
all_of_type
Marks all unmarked parameters of a function.
mark_as
Allows one to create a utility function to mark as a given type.
"""
def __init__(self, value, dispatch_type, coercible=True):
self.value = value
self.type = dispatch_type
self.coercible = coercible
def __getitem__(self, index):
return (self.type, self.value)[index]
def __str__(self):
return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"
__repr__ = __str__
def mark_as(dispatch_type):
"""
Creates a utility function to mark something as a specific type.
Examples
--------
>>> mark_int = mark_as(int)
>>> mark_int(1)
<Dispatchable: type=<class 'int'>, value=1>
"""
return functools.partial(Dispatchable, dispatch_type=dispatch_type)
def all_of_type(arg_type):
"""
Marks all unmarked arguments as a given type.
Examples
--------
>>> @all_of_type(str)
... def f(a, b):
... return a, Dispatchable(b, int)
>>> f('a', 1)
(<Dispatchable: type=<class 'str'>, value='a'>,
<Dispatchable: type=<class 'int'>, value=1>)
"""
def outer(func):
@functools.wraps(func)
def inner(*args, **kwargs):
extracted_args = func(*args, **kwargs)
return tuple(
Dispatchable(arg, arg_type)
if not isinstance(arg, Dispatchable)
else arg
for arg in extracted_args
)
return inner
return outer
def wrap_single_convertor(convert_single):
"""
Wraps a ``__ua_convert__`` defined for a single element to all elements.
If any of them return ``NotImplemented``, the operation is assumed to be
undefined.
Accepts a signature of (value, type, coerce).
"""
@functools.wraps(convert_single)
def __ua_convert__(dispatchables, coerce):
converted = []
for d in dispatchables:
c = convert_single(d.value, d.type, coerce and d.coercible)
if c is NotImplemented:
return NotImplemented
converted.append(c)
return converted
return __ua_convert__
def wrap_single_convertor_instance(convert_single):
"""
Wraps a ``__ua_convert__`` defined for a single element to all elements.
If any of them return ``NotImplemented``, the operation is assumed to be
undefined.
Accepts a signature of (value, type, coerce).
"""
@functools.wraps(convert_single)
def __ua_convert__(self, dispatchables, coerce):
converted = []
for d in dispatchables:
c = convert_single(self, d.value, d.type, coerce and d.coercible)
if c is NotImplemented:
return NotImplemented
converted.append(c)
return converted
return __ua_convert__
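# Illustrative usage sketch (a hedged example, not part of the original
# file): a backend can define a per-element converter and let the wrapper
# handle iteration over all dispatchables,
#
#     @wrap_single_convertor
#     def __ua_convert__(value, dispatch_type, coerce):
#         if dispatch_type is int:
#             return str(value) if coerce else value
#         return NotImplemented
#
# The wrapped function receives the full list of Dispatchables and returns
# NotImplemented as soon as any single conversion does.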
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
"""Set the backend to the first active backend that supports ``value``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend` to ensure the same backend
is used everywhere in a block of multimethod calls.
Parameters
----------
value
The value being tested
dispatch_type
The dispatch type associated with ``value``, aka
":ref:`marking <MarkingGlossary>`".
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
See Also
--------
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:
>>> with ua.set_backend(ex.BackendA):
... ex.call_multimethod(ex.TypeB(), ex.TypeB())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
Now consider a multimethod that creates a new object of ``TypeA``, or
``TypeB`` depending on the active backend.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, ex.TypeA())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
innermost with statement. So, ``call_multimethod`` fails since the types
don't match.
Instead, we need to first find a backend suitable for all of our objects.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... x = ex.TypeA()
... with ua.determine_backend(x, "mark", domain="ua_examples"):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, x)
TypeA
"""
dispatchables = (Dispatchable(value, dispatch_type, coerce),)
backend = _uarray.determine_backend(domain, dispatchables, coerce)
return set_backend(backend, coerce=coerce, only=only)
def determine_backend_multi(
dispatchables, *, domain, only=True, coerce=False, **kwargs
):
"""Set a backend supporting all ``dispatchables``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend_multi` to ensure the same
backend is used everywhere in a block of multimethod calls involving
multiple arrays.
Parameters
----------
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
The dispatchables that must be supported
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
dispatch_type: Optional[Any]
The default dispatch type associated with ``dispatchables``, aka
":ref:`marking <MarkingGlossary>`".
See Also
--------
determine_backend: For a single dispatch value
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
:func:`determine_backend` allows the backend to be set from a single
object. :func:`determine_backend_multi` allows multiple objects to be
checked simultaneously for support in the backend. Suppose we have a
``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
and a ``BackendBC`` that doesn't support ``TypeA``.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
... domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
This won't call ``BackendBC`` because it doesn't support ``TypeA``.
    We can also leave out the ``ua.Dispatchable`` if we specify the
default ``dispatch_type`` for the ``dispatchables`` argument.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [a, b], dispatch_type="mark", domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
"""
if "dispatch_type" in kwargs:
disp_type = kwargs.pop("dispatch_type")
dispatchables = tuple(
d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
for d in dispatchables
)
else:
dispatchables = tuple(dispatchables)
if not all(isinstance(d, Dispatchable) for d in dispatchables):
raise TypeError("dispatchables must be instances of uarray.Dispatchable")
if len(kwargs) != 0:
raise TypeError(f"Received unexpected keyword arguments: {kwargs}")
backend = _uarray.determine_backend(domain, dispatchables, coerce)
return set_backend(backend, coerce=coerce, only=only)

File diff suppressed because it is too large

View file

@@ -0,0 +1,22 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike array_api_strict, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.
Library authors using the Array API may wish to test against array_api_strict
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy arrays.
"""
__version__ = '1.12.0'
from .common import * # noqa: F401, F403

View file

@@ -0,0 +1,59 @@
"""
Internal helpers
"""
from collections.abc import Callable
from functools import wraps
from inspect import signature
from types import ModuleType
from typing import TypeVar
_T = TypeVar("_T")
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""
Decorator to automatically replace xp with the corresponding array module.
Use like
import numpy as np
@get_xp(np)
def func(x, /, xp, kwarg=None):
return xp.func(x, kwarg=kwarg)
Note that xp must be a keyword argument and come after all non-keyword
arguments.
"""
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
@wraps(f)
def wrapped_f(*args: object, **kwargs: object) -> object:
return f(*args, xp=xp, **kwargs)
sig = signature(f)
new_sig = sig.replace(
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
)
if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.
See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
return inner
__all__ = ["get_xp"]
def __dir__() -> list[str]:
return __all__

View file

@@ -0,0 +1 @@
from ._helpers import * # noqa: F403

View file

@@ -0,0 +1,727 @@
"""
These are functions that are just aliases of existing functions in NumPy.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._typing import Array, Device, DType, Namespace
if TYPE_CHECKING:
# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs
# These functions are modified from the NumPy versions.
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
def arange(
start: float,
/,
stop: float | None = None,
step: float = 1,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
def empty(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
def empty_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
def eye(
n_rows: int,
n_cols: int | None = None,
/,
*,
xp: Namespace,
k: int = 0,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
def full(
shape: int | tuple[int, ...],
fill_value: complex,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
def full_like(
x: Array,
/,
fill_value: complex,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
def linspace(
start: float,
stop: float,
/,
num: int,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
endpoint: bool = True,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
def ones(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
def ones_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
def zeros(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
def zeros_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).
# The functions here return namedtuples (np.unique() returns a normal
# tuple).
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array
class UniqueCountsResult(NamedTuple):
values: Array
counts: Array
class UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array
def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
# trying to parse version numbers, just check if equal_nan is in the
# signature.
s = inspect.signature(xp.unique)
if "equal_nan" in s.parameters:
return {"equal_nan": False}
return {}
def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
kwargs = _unique_kwargs(xp)
values, indices, inverse_indices, counts = xp.unique(
x,
return_counts=True,
return_index=True,
return_inverse=True,
**kwargs,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueAllResult(
values,
indices,
inverse_indices,
counts,
)
def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
kwargs = _unique_kwargs(xp)
res = xp.unique(
x, return_counts=True, return_index=False, return_inverse=False, **kwargs
)
return UniqueCountsResult(*res)
def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
kwargs = _unique_kwargs(xp)
values, inverse_indices = xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=True,
**kwargs,
)
# xp.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueInverseResult(values, inverse_indices)
def unique_values(x: Array, /, xp: Namespace) -> Array:
kwargs = _unique_kwargs(xp)
return xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=False,
**kwargs,
)
# These functions have different keyword argument names
def std(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs: object,
) -> Array:
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
def var(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs: object,
) -> Array:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
# argument
def cumulative_sum(
x: Array,
/,
xp: Namespace,
*,
axis: int | None = None,
dtype: DType | None = None,
include_initial: bool = False,
**kwargs: object,
) -> Array:
wrapped_xp = array_namespace(x)
# TODO: The standard is not clear about what should happen when x.ndim == 0.
if axis is None:
if x.ndim > 1:
raise ValueError(
"axis must be specified in cumulative_sum for more than one dimension"
)
axis = 0
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
# np.cumsum does not support include_initial
if include_initial:
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
[
wrapped_xp.zeros(
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
),
res,
],
axis=axis,
)
return res
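# Illustrative sketch (hedged, not part of the original file): with NumPy as
# the wrapped namespace,
#     cumulative_sum(np.asarray([1, 2, 3]), xp=np)                       -> [1, 3, 6]
#     cumulative_sum(np.asarray([1, 2, 3]), xp=np, include_initial=True) -> [0, 1, 3, 6]
# i.e. include_initial prepends the additive identity along the summed axis;
# cumulative_prod below does the same with ones.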
def cumulative_prod(
x: Array,
/,
xp: Namespace,
*,
axis: int | None = None,
dtype: DType | None = None,
include_initial: bool = False,
**kwargs: object,
) -> Array:
wrapped_xp = array_namespace(x)
if axis is None:
if x.ndim > 1:
raise ValueError(
"axis must be specified in cumulative_prod for more than one dimension"
)
axis = 0
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
# np.cumprod does not support include_initial
if include_initial:
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
[
wrapped_xp.ones(
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
),
res,
],
axis=axis,
)
return res
# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
def clip(
x: Array,
/,
min: float | Array | None = None,
max: float | Array | None = None,
*,
xp: Namespace,
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
wrapped_xp = array_namespace(x)
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
# avoiding uint64 -> float64 promotion).
# Note: cases where min or max overflow (integer) or round (float) in the
# wrong direction when downcasting to x.dtype are unspecified. This code
# just does whatever NumPy does when it downcasts in the assignment, but
# other behavior could be preferred, especially for integers. For example,
# this code produces:
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
# -128
# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if wrapped_xp.isdtype(x.dtype, "integral"):
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
min = None
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
max = None
dev = _get_device(x)
if out is None:
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
assert out is not None # workaround for a type-narrowing issue in pyright
out[()] = x
if min is not None:
a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
a = xp.broadcast_to(a, result_shape)
ia = (out < a) | xp.isnan(a)
out[ia] = a[ia]
if max is not None:
b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
b = xp.broadcast_to(b, result_shape)
ib = (out > b) | xp.isnan(b)
out[ib] = b[ib]
# Return a scalar for 0-D
return out[()]
# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
return xp.transpose(x, axes)
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(
x: Array,
/,
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
**kwargs: object,
) -> Array:
if copy is True:
x = x.copy()
elif copy is False:
y = x.view()
y.shape = shape
return y
return xp.reshape(x, shape, **kwargs)
# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
x: Array,
/,
xp: Namespace,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
**kwargs: object,
) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs["kind"] = "stable"
if not descending:
res = xp.argsort(x, axis=axis, **kwargs)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of xp.argsort(x, ...) would not
# respect the relative order like it would in native descending sorts.
res = xp.flip(
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
axis=axis,
)
# Rely on flip()/argsort() to validate axis
normalised_axis = axis if axis >= 0 else x.ndim + axis
max_i = x.shape[normalised_axis] - 1
res = max_i - res
return res
def sort(
x: Array,
/,
xp: Namespace,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
**kwargs: object,
) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs["kind"] = "stable"
res = xp.sort(x, axis=axis, **kwargs)
if descending:
res = xp.flip(res, axis=axis)
return res
# nonzero should error for zero-dimensional arrays
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)
# ceil, floor, and trunc return integers for integer inputs
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.ceil(x, **kwargs)
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.floor(x, **kwargs)
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.trunc(x, **kwargs)
# linear algebra functions
def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.matmul(x1, x2, **kwargs)
# Unlike transpose, matrix_transpose only transposes the last two axes.
def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
if x.ndim < 2:
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
return xp.swapaxes(x, -1, -2)
def tensordot(
x1: Array,
x2: Array,
/,
xp: Namespace,
*,
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
**kwargs: object,
) -> Array:
return xp.tensordot(x1, x2, axes=axes, **kwargs)
def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
if hasattr(xp, "broadcast_tensors"):
_broadcast = xp.broadcast_tensors
else:
_broadcast = xp.broadcast_arrays
x1_ = xp.moveaxis(x1, axis, -1)
x2_ = xp.moveaxis(x2, axis, -1)
x1_, x2_ = _broadcast(x1_, x2_)
res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
return res[..., 0, 0]
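# Illustrative note (hedged, not part of the original file): vecdot computes
# the conjugated dot product along one axis with broadcasting over the rest;
# for x1 and x2 of shape (3, 4) and axis=-1 the result has shape (3,) and
# equals xp.sum(xp.conj(x1) * x2, axis=-1).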
# isdtype is a new function in the 2022.12 array API specification.
def isdtype(
dtype: DType,
kind: DType | str | tuple[DType | str, ...],
xp: Namespace,
*,
_tuple: bool = True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
Note that outside of this function, this compat library does not yet fully
support complex numbers.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return any(
isdtype(dtype, k, xp, _tuple=False)
for k in cast("tuple[DType | str, ...]", kind)
)
elif isinstance(kind, str):
if kind == "bool":
return dtype == xp.bool_
elif kind == "signed integer":
return xp.issubdtype(dtype, xp.signedinteger)
elif kind == "unsigned integer":
return xp.issubdtype(dtype, xp.unsignedinteger)
elif kind == "integral":
return xp.issubdtype(dtype, xp.integer)
elif kind == "real floating":
return xp.issubdtype(dtype, xp.floating)
elif kind == "complex floating":
return xp.issubdtype(dtype, xp.complexfloating)
elif kind == "numeric":
return xp.issubdtype(dtype, xp.number)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
# This will allow things that aren't required by the spec, like
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
# more strict here to match the type annotation? Note that the
# array_api_strict implementation will be very strict.
return dtype == kind
# unstack is a new function in the 2023.12 array API standard
def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("Input array must be at least 1-d.")
return tuple(xp.moveaxis(x, axis, 0))
# numpy 1.26 does not use the standard definition for sign on complex numbers
def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if isdtype(x.dtype, "complex floating", xp=xp):
out = (x / xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
out[x == 0j] = 0j
else:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]
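# Illustrative note (hedged, not part of the original file): the complex
# branch uses the array API definition sign(x) = x / |x|, so for example
# sign(np.asarray(3 + 4j), xp=np) gives 0.6 + 0.8j, with sign(0j) fixed up
# to 0j afterwards.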
def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
# It is surprisingly difficult to recognize a dtype apart from an array.
# np.int64 is not the same as np.asarray(1).dtype!
try:
return xp.finfo(type_)
except (ValueError, TypeError):
return xp.finfo(type_.dtype)
def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
try:
return xp.iinfo(type_)
except (ValueError, TypeError):
return xp.iinfo(type_.dtype)
__all__ = [
"arange",
"empty",
"empty_like",
"eye",
"full",
"full_like",
"linspace",
"ones",
"ones_like",
"zeros",
"zeros_like",
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"std",
"var",
"cumulative_sum",
"cumulative_prod",
"clip",
"permute_dims",
"reshape",
"argsort",
"sort",
"nonzero",
"ceil",
"floor",
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
"vecdot",
"isdtype",
"unstack",
"sign",
"finfo",
"iinfo",
]
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
def __dir__() -> list[str]:
return __all__

View file

@@ -0,0 +1,213 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Literal, TypeAlias
from ._typing import Array, Device, DType, Namespace
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
def fft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
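# Illustrative note (hedged, not part of the original file): NumPy's fft
# routines return complex128 even for float32/complex64 input, so the
# wrappers cast the result back, e.g.
#     fft(np.ones(8, dtype=np.float32), xp=np).dtype  -> complex64
# The same single-precision-preserving pattern is applied to every wrapper
# in this module.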
def ifft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def rfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def rfftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def hfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res
def ihfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftfreq(
n: int,
/,
xp: Namespace,
*,
d: float = 1.0,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.fftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res
def rfftfreq(
n: int,
/,
xp: Namespace,
*,
d: float = 1.0,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.rfftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res
def fftshift(
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.fftshift(x, axes=axes)
def ifftshift(
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.ifftshift(x, axes=axes)
__all__ = [
"fft",
"ifft",
"fftn",
"ifftn",
"rfft",
"irfft",
"rfftn",
"irfftn",
"hfft",
"ihfft",
"fftfreq",
"rfftfreq",
"fftshift",
"ifftshift",
]
def __dir__() -> list[str]:
return __all__

File diff suppressed because it is too large

View file

@@ -0,0 +1,232 @@
from __future__ import annotations
import math
from typing import Literal, NamedTuple, cast
import numpy as np
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
from ._typing import Array, DType, JustFloat, JustInt, Namespace
# These are in the main NumPy namespace but not in numpy.linalg
def cross(
x1: Array,
x2: Array,
/,
xp: Namespace,
*,
axis: int = -1,
**kwargs: object,
) -> Array:
return xp.cross(x1, x2, axis=axis, **kwargs)
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.outer(x1, x2, **kwargs)
class EighResult(NamedTuple):
eigenvalues: Array
eigenvectors: Array
class QRResult(NamedTuple):
Q: Array
R: Array
class SlogdetResult(NamedTuple):
sign: Array
logabsdet: Array
class SVDResult(NamedTuple):
U: Array
S: Array
Vh: Array
# These functions are the same as their NumPy counterparts except they return
# a namedtuple.
def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
return EighResult(*xp.linalg.eigh(x, **kwargs))
def qr(
x: Array,
/,
xp: Namespace,
*,
mode: Literal["reduced", "complete"] = "reduced",
**kwargs: object,
) -> QRResult:
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
def svd(
x: Array,
/,
xp: Namespace,
*,
full_matrices: bool = True,
**kwargs: object,
) -> SVDResult:
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
# These functions have additional keyword arguments
# The upper keyword argument is new from NumPy
def cholesky(
x: Array,
/,
xp: Namespace,
*,
upper: bool = False,
**kwargs: object,
) -> Array:
L = xp.linalg.cholesky(x, **kwargs)
if upper:
U = get_xp(xp)(matrix_transpose)(L)
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
return U
return L
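# Illustrative note (hedged, not part of the original file): with the default
# upper=False the lower factor L is returned, satisfying L @ L^H == A; with
# upper=True the conjugate transpose U = L^H is returned instead, so
# U^H @ U == A (up to floating-point error).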
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
# Note that it has a different semantic meaning from tol and rcond.
def matrix_rank(
x: Array,
/,
xp: Namespace,
*,
rtol: float | Array | None = None,
**kwargs: object,
) -> Array:
# this is different from xp.linalg.matrix_rank, which supports 1
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S: Array = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
# this is different from xp.linalg.matrix_rank, which does not
# multiply the tolerance by the largest singular value.
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
return xp.count_nonzero(S > tol, axis=-1)
def pinv(
x: Array,
/,
xp: Namespace,
*,
rtol: float | Array | None = None,
**kwargs: object,
) -> Array:
# this is different from xp.linalg.pinv, which does not multiply the
# default tolerance by max(M, N).
if rtol is None:
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
# These functions are new in the array API spec
def matrix_norm(
x: Array,
/,
xp: Namespace,
*,
keepdims: bool = False,
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
) -> Array:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# xp.linalg.svd(compute_uv=False).
def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
return xp.linalg.svd(x, compute_uv=False)
def vector_norm(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
ord: JustInt | JustFloat = 2,
) -> Array:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
# it so the input is 1-D (for axis=None), or reshape so that norm is done
# on a single dimension.
if axis is None:
# Note: xp.linalg.norm() doesn't handle 0-D arrays
_x = x.ravel()
_axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas
# xp.linalg.norm() only supports a single axis for vector norm.
normalized_axis = cast(
"tuple[int, ...]",
normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
)
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
newshape = axis + rest
_x = xp.transpose(x, newshape).reshape(
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
_axis = 0
else:
_x = x
_axis = axis
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
if keepdims:
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
_axis = cast(
"tuple[int, ...]",
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
range(x.ndim) if axis is None else axis,
x.ndim,
),
)
for i in _axis:
shape[i] = 1
res = xp.reshape(res, tuple(shape))
return res
# xp.diagonal and xp.trace operate on the first two axes whereas these
# operate on the last two
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
def trace(
x: Array,
/,
xp: Namespace,
*,
offset: int = 0,
dtype: DType | None = None,
**kwargs: object,
) -> Array:
return xp.asarray(
xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
)
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
def __dir__() -> list[str]:
return __all__

View file

@@ -0,0 +1,192 @@
from __future__ import annotations
from collections.abc import Mapping
from types import ModuleType as Namespace
from typing import (
TYPE_CHECKING,
Literal,
Protocol,
TypeAlias,
TypedDict,
TypeVar,
final,
)
if TYPE_CHECKING:
from _typeshed import Incomplete
SupportsBufferProtocol: TypeAlias = Incomplete
Array: TypeAlias = Incomplete
Device: TypeAlias = Incomplete
DType: TypeAlias = Incomplete
else:
SupportsBufferProtocol = object
Array = object
Device = object
DType = object
_T_co = TypeVar("_T_co", covariant=True)
# These "Just" types are equivalent to the `Just` type from the `optype` library,
# apart from them not being `@runtime_checkable`.
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
@property
def __class__(self, /) -> type[int]: ...
@__class__.setter
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustFloat(Protocol):
@property
def __class__(self, /) -> type[float]: ...
@__class__.setter
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustComplex(Protocol):
@property
def __class__(self, /) -> type[complex]: ...
@__class__.setter
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
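# Editor's illustration (hypothetical helper, not part of this module): under a
# static type checker, a parameter annotated with JustInt accepts a plain int
# but rejects subclasses such as bool, because __class__ must be exactly type[int].
def _expects_just_int(n: JustInt, /) -> None:
    del n
# _expects_just_int(1)     # accepted by pyright/mypy
# _expects_just_int(True)  # rejected: bool is a subclass of int, not "just" int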
#
class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
class SupportsArrayNamespace(Protocol[_T_co]):
def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
class HasShape(Protocol[_T_co]):
@property
def shape(self, /) -> _T_co: ...
# Return type of `__array_namespace_info__.capabilities`
Capabilities = TypedDict(
"Capabilities",
{
"boolean indexing": bool,
"data-dependent shapes": bool,
"max dimensions": int,
},
)
# Return type of `__array_namespace_info__.default_dtypes`
DefaultDTypes = TypedDict(
"DefaultDTypes",
{
"real floating": DType,
"complex floating": DType,
"integral": DType,
"indexing": DType,
},
)
_DTypeKind: TypeAlias = Literal[
"bool",
"signed integer",
"unsigned integer",
"integral",
"real floating",
"complex floating",
"numeric",
]
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
# `__array_namespace_info__.dtypes(kind="bool")`
class DTypesBool(TypedDict):
bool: DType
# `__array_namespace_info__.dtypes(kind="signed integer")`
class DTypesSigned(TypedDict):
int8: DType
int16: DType
int32: DType
int64: DType
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
class DTypesUnsigned(TypedDict):
uint8: DType
uint16: DType
uint32: DType
uint64: DType
# `__array_namespace_info__.dtypes(kind="integral")`
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
pass
# `__array_namespace_info__.dtypes(kind="real floating")`
class DTypesReal(TypedDict):
float32: DType
float64: DType
# `__array_namespace_info__.dtypes(kind="complex floating")`
class DTypesComplex(TypedDict):
complex64: DType
complex128: DType
# `__array_namespace_info__.dtypes(kind="numeric")`
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
pass
# `__array_namespace_info__.dtypes(kind=None)` (default)
class DTypesAll(DTypesBool, DTypesNumeric):
pass
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
DTypesAny: TypeAlias = Mapping[str, DType]
__all__ = [
"Array",
"Capabilities",
"DType",
"DTypeKind",
"DTypesAny",
"DTypesAll",
"DTypesBool",
"DTypesNumeric",
"DTypesIntegral",
"DTypesSigned",
"DTypesUnsigned",
"DTypesReal",
"DTypesComplex",
"DefaultDTypes",
"Device",
"HasShape",
"Namespace",
"JustInt",
"JustFloat",
"JustComplex",
"NestedSequence",
"SupportsArrayNamespace",
"SupportsBufferProtocol",
]
def __dir__() -> list[str]:
return __all__

View file

@ -0,0 +1,13 @@
from cupy import * # noqa: F403
# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round # noqa: F401
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
__array_api_version__ = '2024.12'

View file

@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Optional
import cupy as cp
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from .._internal import get_xp
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
bool = cp.bool_
# Basic renames
acos = cp.arccos
acosh = cp.arccosh
asin = cp.arcsin
asinh = cp.arcsinh
atan = cp.arctan
atan2 = cp.arctan2
atanh = cp.arctanh
bitwise_left_shift = cp.left_shift
bitwise_invert = cp.invert
bitwise_right_shift = cp.right_shift
concat = cp.concatenate
pow = cp.power
arange = get_xp(cp)(_aliases.arange)
empty = get_xp(cp)(_aliases.empty)
empty_like = get_xp(cp)(_aliases.empty_like)
eye = get_xp(cp)(_aliases.eye)
full = get_xp(cp)(_aliases.full)
full_like = get_xp(cp)(_aliases.full_like)
linspace = get_xp(cp)(_aliases.linspace)
ones = get_xp(cp)(_aliases.ones)
ones_like = get_xp(cp)(_aliases.ones_like)
zeros = get_xp(cp)(_aliases.zeros)
zeros_like = get_xp(cp)(_aliases.zeros_like)
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
unique_all = get_xp(cp)(_aliases.unique_all)
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)
finfo = get_xp(cp)(_aliases.finfo)
iinfo = get_xp(cp)(_aliases.iinfo)
# asarray also adds the copy keyword, which is not present in numpy 1.x.
def asarray(
obj: (
Array
| bool | int | float | complex
| NestedSequence[bool | int | float | complex]
| SupportsBufferProtocol
),
/,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
copy: Optional[bool] = None,
**kwargs,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
if copy is None:
return cp.asarray(obj, dtype=dtype, **kwargs)
else:
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
if not copy and res is not obj:
raise ValueError("Unable to avoid copy while creating an array as requested")
return res
def astype(
x: Array,
dtype: DType,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> Array:
if device is None:
return x.astype(dtype=dtype, copy=copy)
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
return out.copy() if copy and out is x else out
# cupy.count_nonzero does not have keepdims
def count_nonzero(
x: Array,
axis=None,
keepdims=False
) -> Array:
result = cp.count_nonzero(x, axis)
if keepdims:
if axis is None:
return cp.reshape(result, [1]*x.ndim)
return cp.expand_dims(result, axis)
return result
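def _demo_count_nonzero_keepdims():
    # Editor's sketch (hypothetical helper, not part of the library; requires a
    # CUDA device): keepdims is emulated with expand_dims because
    # cupy.count_nonzero lacks it.
    x = cp.asarray([[0, 1], [2, 3]])
    return count_nonzero(x, axis=0, keepdims=True)  # [[1, 2]]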
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
return cp.take_along_axis(x, indices, axis=axis)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
vecdot = cp.vecdot
else:
vecdot = get_xp(cp)(_aliases.vecdot)
if hasattr(cp, 'isdtype'):
isdtype = cp.isdtype
else:
isdtype = get_xp(cp)(_aliases.isdtype)
if hasattr(cp, 'unstack'):
unstack = cp.unstack
else:
unstack = get_xp(cp)(_aliases.unstack)
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
'take_along_axis']
_all_ignore = ['cp', 'get_xp']

View file

@ -0,0 +1,336 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
from cupy import (
dtype,
cuda,
bool_ as bool,
intp,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
complex64,
complex128,
)
class __array_namespace_info__:
"""
Get the array API inspection namespace for CuPy.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for CuPy.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""
__module__ = 'cupy'
def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for CuPy.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
CuPy.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self):
"""
The default device used for new CuPy arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new CuPy arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
Device(0)
Notes
-----
This method returns the static default device when CuPy is initialized.
However, the *current* device used by creation functions (``empty`` etc.)
can be changed globally or with a context manager.
See Also
--------
https://github.com/data-apis/array-api/issues/835
"""
return cuda.Device(0)
def default_dtypes(self, *, device=None):
"""
The default data types used for new CuPy arrays.
For CuPy, this always returns the following dictionary:
- **"real floating"**: ``cupy.float64``
- **"complex floating"**: ``cupy.complex128``
- **"integral"**: ``cupy.intp``
- **"indexing"**: ``cupy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new CuPy
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""
# TODO: Does this depend on device?
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}
def dtypes(self, *, device=None, kind=None):
"""
The array API data types supported by CuPy.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
CuPy data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': cupy.int8,
'int16': cupy.int16,
'int32': cupy.int32,
'int64': cupy.int64}
"""
# TODO: Does this depend on device?
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
res = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
def devices(self):
"""
The devices supported by CuPy.
Returns
-------
devices : list[Device]
The devices supported by CuPy.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
"""
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]

View file

@ -0,0 +1,31 @@
from __future__ import annotations
__all__ = ["Array", "DType", "Device"]
_all_ignore = ["cp"]
from typing import TYPE_CHECKING
import cupy as cp
from cupy import ndarray as Array
from cupy.cuda.device import Device
if TYPE_CHECKING:
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
DType = cp.dtype[
cp.intp
| cp.int8
| cp.int16
| cp.int32
| cp.int64
| cp.uint8
| cp.uint16
| cp.uint32
| cp.uint64
| cp.float32
| cp.float64
| cp.complex64
| cp.complex128
| cp.bool_
]
else:
DType = cp.dtype

View file

@ -0,0 +1,36 @@
from cupy.fft import * # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as fft_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n
from ..common import _fft
from .._internal import get_xp
import cupy as cp
fft = get_xp(cp)(_fft.fft)
ifft = get_xp(cp)(_fft.ifft)
fftn = get_xp(cp)(_fft.fftn)
ifftn = get_xp(cp)(_fft.ifftn)
rfft = get_xp(cp)(_fft.rfft)
irfft = get_xp(cp)(_fft.irfft)
rfftn = get_xp(cp)(_fft.rfftn)
irfftn = get_xp(cp)(_fft.irfftn)
hfft = get_xp(cp)(_fft.hfft)
ihfft = get_xp(cp)(_fft.ihfft)
fftfreq = get_xp(cp)(_fft.fftfreq)
rfftfreq = get_xp(cp)(_fft.rfftfreq)
fftshift = get_xp(cp)(_fft.fftshift)
ifftshift = get_xp(cp)(_fft.ifftshift)
__all__ = fft_all + _fft.__all__
del get_xp
del cp
del fft_all
del _fft

View file

@ -0,0 +1,49 @@
from cupy.linalg import * # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n
from ..common import _linalg
from .._internal import get_xp
import cupy as cp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
__all__ = linalg_all + _linalg.__all__
del get_xp
del cp
del linalg_all
del _linalg

View file

@ -0,0 +1,12 @@
from typing import Final
from dask.array import * # noqa: F403
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
__array_api_version__: Final = "2024.12"
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')

View file

@ -0,0 +1,376 @@
# pyright: reportPrivateUsage=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
from __future__ import annotations
from builtins import bool as py_bool
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from typing_extensions import TypeIs
import dask.array as da
import numpy as np
from numpy import bool_ as bool
from numpy import (
can_cast,
complex64,
complex128,
float32,
float64,
int8,
int16,
int32,
int64,
result_type,
uint8,
uint16,
uint32,
uint64,
)
from ..._internal import get_xp
from ...common import _aliases, _helpers, array_namespace
from ...common._typing import (
Array,
Device,
DType,
NestedSequence,
SupportsBufferProtocol,
)
from ._info import __array_namespace_info__
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)
# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: DType,
/,
*,
copy: py_bool = True,
device: Device | None = None,
) -> Array:
"""
Array API compatibility wrapper for astype().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
if not copy and dtype == x.dtype:
return x
x = x.astype(dtype)
return x.copy() if copy else x
# Common aliases
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
def arange(
start: float,
/,
stop: float | None = None,
step: float = 1,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for arange().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
args: list[Any] = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return da.arange(*args, dtype=dtype, **kwargs)
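def _demo_arange():
    # Editor's sketch (hypothetical helper, not part of the library): stop and
    # step are forwarded positionally, so both the one- and three-argument
    # forms work.
    assert list(arange(5).compute()) == [0, 1, 2, 3, 4]
    assert list(arange(2, 10, 2).compute()) == [2, 4, 6, 8]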
eye = get_xp(da)(_aliases.eye)
linspace = get_xp(da)(_aliases.linspace)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
nonzero = get_xp(da)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
finfo = get_xp(np)(_aliases.finfo)
iinfo = get_xp(np)(_aliases.iinfo)
# asarray also adds the copy keyword, which is not present in numpy 1.x.
def asarray(
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: py_bool | None = None,
**kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
if copy is False:
raise ValueError(
"Unable to avoid copy when converting a non-dask object to dask"
)
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)
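def _demo_asarray_copy_semantics():
    # Editor's sketch (hypothetical helper, not part of the library): converting
    # a non-dask object with copy=False must raise, since the conversion
    # necessarily copies.
    x = asarray([1, 2, 3])          # list -> numpy -> dask
    try:
        asarray([1, 2, 3], copy=False)
    except ValueError:
        return x
    raise AssertionError("copy=False should have raised")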
# Element wise aliases
from dask.array import arccos as acos
from dask.array import arccosh as acosh
from dask.array import arcsin as asin
from dask.array import arcsinh as asinh
from dask.array import arctan as atan
from dask.array import arctan2 as atan2
from dask.array import arctanh as atanh
# Other
from dask.array import concatenate as concat
from dask.array import invert as bitwise_invert
from dask.array import left_shift as bitwise_left_shift
from dask.array import power as pow
from dask.array import right_shift as bitwise_right_shift
# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
def clip(
x: Array,
/,
min: float | Array | None = None,
max: float | Array | None = None,
) -> Array:
"""
Array API compatibility wrapper for clip().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
return a is None or isinstance(a, (int, float))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
# TODO: This won't handle dask unknown shapes
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
if min is not None:
min = da.broadcast_to(da.asarray(min), result_shape)
if max is not None:
max = da.broadcast_to(da.asarray(max), result_shape)
if min is None and max is None:
return da.positive(x)
if min is None:
return astype(da.minimum(x, max), x.dtype)
if max is None:
return astype(da.maximum(x, min), x.dtype)
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
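def _demo_clip():
    # Editor's sketch (hypothetical helper, not part of the library): min and
    # max are broadcast to the result shape because dask.array.clip needs all
    # three arguments.
    x = da.arange(6)
    return clip(x, 1, 4).compute()  # array([1, 1, 2, 3, 4, 4])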
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
"""
Make sure that Array is not broken into multiple chunks along axis.
Returns
-------
x : Array
The input Array with a single chunk along axis.
restore : Callable[Array, Array]
function to apply to the output to rechunk it back into reasonable chunks
"""
if axis < 0:
axis += x.ndim
if x.numblocks[axis] < 2:
return x, lambda x: x
# Break chunks on other axes in an attempt to keep chunk size low
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
# Rather than reconstructing the original chunks, which can be a
# very expensive affair, just break down oversized chunks without
# incurring any transfers over the network.
# This has the downside of a risk of overchunking if the array is
# then used in operations against other arrays that match the
# original chunking pattern.
return x, lambda x: x.rechunk()
def sort(
x: Array,
/,
*,
axis: int = -1,
descending: py_bool = False,
stable: py_bool = True,
) -> Array:
"""
Array API compatibility layer around the lack of sort() in Dask.
Warnings
--------
This function temporarily rechunks the array along `axis` to a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
x, restore = _ensure_single_chunk(x, axis)
meta_xp = array_namespace(x._meta)
x = da.map_blocks(
meta_xp.sort,
x,
axis=axis,
meta=x._meta,
dtype=x.dtype,
descending=descending,
stable=stable,
)
return restore(x)
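def _demo_sort_descending():
    # Editor's sketch (hypothetical helper, not part of the library): the array
    # is temporarily rechunked to a single chunk along the sort axis, then
    # restored.
    x = da.from_array(np.array([3, 1, 2]), chunks=2)
    return sort(x, descending=True).compute()  # array([3, 2, 1])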
def argsort(
x: Array,
/,
*,
axis: int = -1,
descending: py_bool = False,
stable: py_bool = True,
) -> Array:
"""
Array API compatibility layer around the lack of argsort() in Dask.
See the corresponding documentation in the array library and/or the array API
specification for more details.
Warnings
--------
This function temporarily rechunks the array along `axis` into a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
"""
x, restore = _ensure_single_chunk(x, axis)
meta_xp = array_namespace(x._meta)
dtype = meta_xp.argsort(x._meta).dtype
meta = meta_xp.astype(x._meta, dtype)
x = da.map_blocks(
meta_xp.argsort,
x,
axis=axis,
meta=meta,
dtype=dtype,
descending=descending,
stable=stable,
)
return restore(x)
# dask.array.count_nonzero does not have keepdims
def count_nonzero(
x: Array,
axis: int | None = None,
keepdims: py_bool = False,
) -> Array:
result = da.count_nonzero(x, axis)
if keepdims:
if axis is None:
return da.reshape(result, [1] * x.ndim)
return da.expand_dims(result, axis)
return result
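def _demo_count_nonzero_keepdims():
    # Editor's sketch (hypothetical helper, not part of the library): keepdims
    # is emulated with expand_dims because dask.array.count_nonzero lacks it.
    x = da.from_array(np.array([[0, 1], [2, 3]]))
    return count_nonzero(x, axis=0, keepdims=True).compute()  # array([[1, 2]])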
__all__ = [
"__array_namespace_info__",
"count_nonzero",
"bool",
"int8", "int16", "int32", "int64",
"uint8", "uint16", "uint32", "uint64",
"float32", "float64",
"complex64", "complex128",
"asarray", "astype", "can_cast", "result_type",
"pow",
"concat",
"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
] # fmt: skip
__all__ += _aliases.__all__
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
def __dir__() -> list[str]:
return __all__

View file

@ -0,0 +1,416 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
# pyright: reportPrivateUsage=false
from __future__ import annotations
from typing import Literal as L
from typing import TypeAlias, overload
from numpy import bool_ as bool
from numpy import (
complex64,
complex128,
dtype,
float32,
float64,
int8,
int16,
int32,
int64,
intp,
uint8,
uint16,
uint32,
uint64,
)
from ...common._helpers import _DASK_DEVICE, _dask_device
from ...common._typing import (
Capabilities,
DefaultDTypes,
DType,
DTypeKind,
DTypesAll,
DTypesAny,
DTypesBool,
DTypesComplex,
DTypesIntegral,
DTypesNumeric,
DTypesReal,
DTypesSigned,
DTypesUnsigned,
)
_Device: TypeAlias = L["cpu"] | _dask_device
class __array_namespace_info__:
"""
Get the array API inspection namespace for Dask.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for Dask.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
'integral': dask.int64,
'indexing': dask.int64}
"""
__module__ = "dask.array"
def capabilities(self) -> Capabilities:
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing.
Dask supports boolean indexing as long as both the index
and the indexed arrays have known shapes.
Note however that the output .shape and .size properties
will contain a non-compliant math.nan instead of None.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes.
Dask implements unique_values et al.
Note however that the output .shape and .size properties
will contain a non-compliant math.nan instead of None.
- **"max dimensions"**: integer indicating the maximum number of
dimensions supported by the array library.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self) -> L["cpu"]:
"""
The default device used for new Dask arrays.
For Dask, this always returns ``'cpu'``.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new Dask arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
'cpu'
"""
return "cpu"
def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
"""
The default data types used for new Dask arrays.
For Dask, this always returns the following dictionary:
- **"real floating"**: ``numpy.float64``
- **"complex floating"**: ``numpy.complex128``
- **"integral"**: ``numpy.intp``
- **"indexing"**: ``numpy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new Dask
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
'integral': dask.int64,
'indexing': dask.int64}
"""
if device not in ["cpu", _DASK_DEVICE, None]:
raise ValueError(
f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
f"but received: {device!r}"
)
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: None = None
) -> DTypesAll: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["bool"]
) -> DTypesBool: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["signed integer"]
) -> DTypesSigned: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
) -> DTypesUnsigned: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["integral"]
) -> DTypesIntegral: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["real floating"]
) -> DTypesReal: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["complex floating"]
) -> DTypesComplex: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["numeric"]
) -> DTypesNumeric: ...
def dtypes(
self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
) -> DTypesAny:
"""
The array API data types supported by Dask.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
Dask data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': dask.int8,
'int16': dask.int16,
'int32': dask.int32,
'int64': dask.int64}
"""
if device not in ["cpu", _DASK_DEVICE, None]:
raise ValueError(
'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
f" {device}"
)
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall]
res: dict[str, DType] = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
def devices(self) -> list[_Device]:
"""
The devices supported by Dask.
For Dask, this always returns ``['cpu', DASK_DEVICE]``.
Returns
-------
devices : list[Device]
The devices supported by Dask.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.devices()
['cpu', DASK_DEVICE]
"""
return ["cpu", _DASK_DEVICE]

View file

@ -0,0 +1,21 @@
from dask.array.fft import * # noqa: F403
# dask.array.fft doesn't have __all__. If it is added, replace this with
#
# from dask.array.fft import __all__ as fft_all
_n = {}
exec('from dask.array.fft import *', _n)
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
_n.pop(k, None)
fft_all = list(_n)
del _n, k
from ...common import _fft
from ..._internal import get_xp
import dask.array as da
fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)
__all__ = fft_all + ["fftfreq", "rfftfreq"]
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]

View file

@ -0,0 +1,72 @@
from __future__ import annotations
from typing import Literal
import dask.array as da
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
from dask.array import matmul, outer, tensordot
# Exports
from dask.array.linalg import * # noqa: F403
from ..._internal import get_xp
from ...common import _linalg
from ...common._typing import Array as _Array
from ._aliases import matrix_transpose, vecdot
# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
_n.pop(k, None)
linalg_all = list(_n)
del _n, k
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(
x: _Array,
mode: Literal["reduced", "complete"] = "reduced",
**kwargs: object,
) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
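def _demo_qr_reduced():
    # Editor's sketch (hypothetical helper, not part of the library): only
    # mode='reduced' is accepted, and dask's qr also expects a single chunk
    # along the columns (tall-and-skinny chunking).
    x = da.random.random((6, 3), chunks=(3, 3))
    q, r = qr(x)
    return q.shape, r.shape  # ((6, 3), (3, 3))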
trace = get_xp(da)(_linalg.trace)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
# Wrap the svd functions so that full_matrices is not passed through to dask
# when full_matrices=False (reduced SVD is dask's only behavior, and dask has
# no full_matrices keyword).
def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult:
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, coerce_signs=False, **kwargs)
def svdvals(x: _Array) -> _Array:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x, full_matrices=False)
return s
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]
_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']

View file

@ -0,0 +1,28 @@
# ruff: noqa: PLC0414
from typing import Final
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
# from numpy import * doesn't overwrite these builtin names
from numpy import abs as abs
from numpy import max as max
from numpy import min as min
from numpy import round as round
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
#
# from . import linalg
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + ".linalg")
__import__(__package__ + ".fft")
from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
__array_api_version__: Final = "2024.12"

Some files were not shown because too many files have changed in this diff.