File unyt-pr512-np2.1.patch of Package python-unyt
From 5b419a27336796dd5f8e4e04e58c6b176550b89e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= <cr52@protonmail.com>
Date: Wed, 3 Jul 2024 11:19:49 +0200
Subject: [PATCH 1/3] BUG: test/implement array functions new in NumPy 2.1
(np.cumulative_sum and np.cumulative_prod)
---
unyt/_array_functions.py | 7 +++++++
unyt/tests/test_array_functions.py | 33 ++++++++++++++++++++++--------
2 files changed, 32 insertions(+), 8 deletions(-)
diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py
index bd752d7f..033fcb3a 100644
--- a/unyt/_array_functions.py
+++ b/unyt/_array_functions.py
@@ -696,6 +696,13 @@ def cumprod(a, *args, **kwargs):
)
+if NUMPY_VERSION >= Version("2.1.0.dev0"):
+
+ @implements(np.cumulative_prod)
+ def cumulative_prod(x, /, *args, **kwargs):
+ return cumprod(x, *args, **kwargs)
+
+
@implements(np.pad)
def pad(array, *args, **kwargs):
return np.pad._implementation(np.asarray(array), *args, **kwargs) * array.units
diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py
index a5960a94..50a59a7c 100644
--- a/unyt/tests/test_array_functions.py
+++ b/unyt/tests/test_array_functions.py
@@ -187,6 +187,7 @@
if NUMPY_VERSION >= Version("2.1.0dev0"):
NOOP_FUNCTIONS |= {
np.unstack,
+ np.cumulative_sum,
}
# Functions for which behaviour is intentionally left to default
@@ -1528,28 +1529,44 @@ def test_deltas(func, input_units, output_units):
@pytest.mark.parametrize(
- "func",
+ "func_name",
[
- np.cumsum,
- np.nancumsum,
+ "cumsum",
+ "nancumsum",
+ pytest.param(
+ "cumulative_sum",
+ marks=pytest.mark.skipif(
+ NUMPY_VERSION < Version("2.1.0dev0"),
+ reason="np.cumulative_sum is new in NumPy 2.1",
+ ),
+ ),
],
)
-def test_cumsum(func):
+def test_cumsum(func_name):
a = [1, 2, 3] * cm
+ func = getattr(np, func_name)
res = func(a)
assert type(res) is unyt_array
assert res.units == cm
@pytest.mark.parametrize(
- "func",
+ "func_name",
[
- np.cumprod,
- np.nancumprod,
+ "cumprod",
+ "nancumprod",
+ pytest.param(
+ "cumulative_prod",
+ marks=pytest.mark.skipif(
+ NUMPY_VERSION < Version("2.1.0dev0"),
+ reason="np.cumulative_prod is new in NumPy 2.1",
+ ),
+ ),
],
)
-def test_cumprod(func):
+def test_cumprod(func_name):
a = [1, 2, 3] * cm
+ func = getattr(np, func_name)
with pytest.raises(
UnytError,
match=re.escape(
From 0b039f7ff5dae29232ccc82289d16e2340b6c501 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= <cr52@protonmail.com>
Date: Wed, 3 Jul 2024 11:25:02 +0200
Subject: [PATCH 2/3] BUG: adapt np.clip's implementation signature at runtime
---
unyt/_array_functions.py | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py
index 033fcb3a..7d6bcb3b 100644
--- a/unyt/_array_functions.py
+++ b/unyt/_array_functions.py
@@ -823,8 +823,7 @@ def sinc(x, *args, **kwargs):
return np.sinc._implementation(np.asarray(x), *args, **kwargs)
-@implements(np.clip)
-def clip(a, a_min, a_max, out=None, *args, **kwargs):
+def clip_impl(a, a_min, a_max, out=None, *args, **kwargs):
_validate_units_consistency_v2(a.units, a_min, a_max)
if out is None:
return (
@@ -850,6 +849,20 @@ def clip(a, a_min, a_max, out=None, *args, **kwargs):
return unyt_array(res, a.units, bypass_validation=True)
+if NUMPY_VERSION >= Version("2.1.0.dev0"):
+ from numpy._globals import _NoValue
+
+ @implements(np.clip)
+ def clip(a, a_min=_NoValue, a_max=_NoValue, out=None, *args, **kwargs):
+ return clip_impl(a, a_min, a_max, out, *args, **kwargs)
+
+else:
+
+ @implements(np.clip)
+ def clip(a, a_min, a_max, out=None, *args, **kwargs):
+ return clip_impl(a, a_min, a_max, out, *args, **kwargs)
+
+
@implements(np.where)
def where(condition, *args, **kwargs):
if len(args) == 0:
From 47a82ba5a142e267bac455a219e3fc79d39e2be3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= <cr52@protonmail.com>
Date: Mon, 19 Aug 2024 08:59:53 +0200
Subject: [PATCH 3/3] DOC: refactor a doctest for compatibility with numpy 2.1
---
docs/usage.rst | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/docs/usage.rst b/docs/usage.rst
index a1a7d378..d12be774 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -269,10 +269,9 @@ controlled identically to NumPy arrays, using ``numpy.setprintoptions``:
>>> import numpy as np
>>> import unyt as u
...
- >>> np.set_printoptions(precision=4)
- >>> print([1.123456789]*u.km)
+ >>> with np.printoptions(precision=4):
+ ... print([1.123456789]*u.km)
[1.1235] km
- >>> np.set_printoptions(precision=8)
Print a :math:`\rm{\LaTeX}` representation of a set of units using the
:meth:`unyt.unit_object.Unit.latex_representation` function or