File fix-get_h-test.patch of Package python-hankel

From 3307d7527325dc80ac441d845982f93fd84d5311 Mon Sep 17 00:00:00 2001
From: Steven Murray <steven.g.murray@asu.edu>
Date: Mon, 23 Sep 2024 12:48:01 +0200
Subject: [PATCH] test: better test ids for get_h

---
 hankel/tools.py     | 11 ++++-------
 tests/test_get_h.py | 11 ++++++++++-
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/hankel/tools.py b/hankel/tools.py
index 5e7b7de..5096509 100755
--- a/hankel/tools.py
+++ b/hankel/tools.py
@@ -234,7 +234,7 @@ def get_h(
         i += 1
 
     if i == maxiter:
-        raise Exception("Maxiter reached while checking for non-zero values")
+        raise RuntimeError("Maxiter reached while checking for non-zero values")
 
     if K is None:  # Do a normal integral of f(x)J_nu(x)
         K = 1
@@ -263,7 +263,7 @@ def getres(h):
         res = getres(hstart)
 
     if i == maxiter:
-        raise Exception("Maxiter reached while checking convergence")
+        raise RuntimeError("Maxiter reached while checking convergence")
 
     # Can do some more trimming of N potentially, by seeing where f(x)~0.
     def consecutive(data, stepsize=1):
@@ -272,17 +272,14 @@ def consecutive(data, stepsize=1):
 
     hstart *= hdecrement
 
-    x = cls(nu, h=hstart, N=int(np.pi / hstart)).x
+    x = np.atleast_1d(cls(nu, h=hstart, N=int(np.pi / hstart)).x)
     lastk = np.where(f(x / np.max(K)) == 0)[0]
     if len(lastk) > 1:
         # if there are any that are zero,
         # and if there are more than 1 in a row
         # (otherwise might just be oscillatory)
         lastk = consecutive(lastk)  # split into arrays of consecutive zeros
-        if len(lastk[-1]) == 1:
-            lastk = int(np.pi / hstart)
-        else:
-            lastk = lastk[-1][0]
+        lastk = int(np.pi / hstart) if len(lastk[-1]) == 1 else lastk[-1][0]
     else:  # otherwise set back to N
         lastk = int(np.pi / hstart)
 
diff --git a/tests/test_get_h.py b/tests/test_get_h.py
index bc32ee3..f856c4b 100644
--- a/tests/test_get_h.py
+++ b/tests/test_get_h.py
@@ -26,7 +26,7 @@ def gammaincc_(a, x):
 @pytest.mark.parametrize(
     "f, anl",
     [
-        (lambda x: 1, 1),  # Ogata 05
+        (lambda x: np.ones_like(x), 1),  # Ogata 05
         (lambda x: x / (x**2 + 1), k0(1)),  # Ogata 05
         (lambda x: x**2, -1),  # wikipedia
         (lambda x: x**4, 9),  # wikipedia
@@ -40,6 +40,15 @@ def gammaincc_(a, x):
             1.0 / 2**2 * np.exp(-0.5 / 2**2),
         ),
     ],
+    ids=[
+        "ones",
+        "x / (x^2 + 1)",
+        "x^2",
+        "x^4",
+        "1 / sqrt(x)",
+        "x / sqrt(x^2 + 1)",
+        "exp(1/2 * x^2)",
+    ]
 )
 def test_nu0(f, anl):
     ans = get_h(f=f, nu=0, hstart=0.5, atol=0, rtol=1e-3, maxiter=20)[1]
openSUSE Build Service is sponsored by