File intc.patch of Package python-scipy
From 8501b7c2fb8a7121aeef94489ece988043c463d0 Mon Sep 17 00:00:00 2001
From: CJ Carey <perimosocordiae@gmail.com>
Date: Sun, 2 Jul 2023 16:29:22 -0400
Subject: [PATCH] BUG: sparse.linalg: Cast index arrays to intc before calling
SuperLU functions (#18644)
* BUG: sparse.linalg: Cast to intc before SuperLU
---------
Co-authored-by: Stefan van der Walt <stefanv@berkeley.edu>
Co-authored-by: Dan Schult <dschult@colgate.edu>
---
scipy/sparse/linalg/_dsolve/linsolve.py | 27 +++++++++++--
.../linalg/_dsolve/tests/test_linsolve.py | 38 ++++++++++---------
2 files changed, 44 insertions(+), 21 deletions(-)
diff --git a/scipy/sparse/linalg/_dsolve/linsolve.py b/scipy/sparse/linalg/_dsolve/linsolve.py
index 0c408f737266..78ba447bcd44 100644
--- a/scipy/sparse/linalg/_dsolve/linsolve.py
+++ b/scipy/sparse/linalg/_dsolve/linsolve.py
@@ -123,6 +123,21 @@ def _get_umf_family(A):
return family, A_new
+def _safe_downcast_indices(A):
+ # check for safe downcasting
+ max_value = np.iinfo(np.intc).max
+
+ if A.indptr[-1] > max_value: # indptr[-1] is max b/c indptr always sorted
+ raise ValueError("indptr values too large for SuperLU")
+
+ if max(*A.shape) > max_value: # only check large enough arrays
+ if np.any(A.indices > max_value):
+ raise ValueError("indices values too large for SuperLU")
+
+ indices = A.indices.astype(np.intc, copy=False)
+ indptr = A.indptr.astype(np.intc, copy=False)
+ return indices, indptr
+
def spsolve(A, b, permc_spec=None, use_umfpack=True):
"""Solve the sparse linear system Ax=b, where b may be a vector or a matrix.
@@ -269,8 +284,10 @@ def spsolve(A, b, permc_spec=None, use_umfpack=True):
else:
flag = 0 # CSR format
+ indices = A.indices.astype(np.intc, copy=False)
+ indptr = A.indptr.astype(np.intc, copy=False)
options = dict(ColPerm=permc_spec)
- x, info = _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr,
+ x, info = _superlu.gssv(N, A.nnz, A.data, indices, indptr,
b, flag, options=options)
if info != 0:
warn("Matrix is exactly singular", MatrixRankWarning)
@@ -402,6 +419,8 @@ def csc_construct_func(*a, cls=type(A)):
if (M != N):
raise ValueError("can only factor square matrices") # is this true?
+ indices, indptr = _safe_downcast_indices(A)
+
_options = dict(DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
PanelSize=panel_size, Relax=relax)
if options is not None:
@@ -411,7 +430,7 @@ def csc_construct_func(*a, cls=type(A)):
if (_options["ColPerm"] == "NATURAL"):
_options["SymmetricMode"] = True
- return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
+ return _superlu.gstrf(N, A.nnz, A.data, indices, indptr,
csc_construct_func=csc_construct_func,
ilu=False, options=_options)
@@ -495,6 +514,8 @@ def csc_construct_func(*a, cls=type(A)):
if (M != N):
raise ValueError("can only factor square matrices") # is this true?
+ indices, indptr = _safe_downcast_indices(A)
+
_options = dict(ILU_DropRule=drop_rule, ILU_DropTol=drop_tol,
ILU_FillFactor=fill_factor,
DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
@@ -506,7 +527,7 @@ def csc_construct_func(*a, cls=type(A)):
if (_options["ColPerm"] == "NATURAL"):
_options["SymmetricMode"] = True
- return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
+ return _superlu.gstrf(N, A.nnz, A.data, indices, indptr,
csc_construct_func=csc_construct_func,
ilu=True, options=_options)
diff --git a/scipy/sparse/linalg/_dsolve/tests/test_linsolve.py b/scipy/sparse/linalg/_dsolve/tests/test_linsolve.py
index 4740685a3dd7..00e0d1886dcc 100644
--- a/scipy/sparse/linalg/_dsolve/tests/test_linsolve.py
+++ b/scipy/sparse/linalg/_dsolve/tests/test_linsolve.py
@@ -220,8 +220,11 @@ def test_singular_gh_3312(self):
except RuntimeError:
pass
- def test_twodiags(self):
- A = spdiags([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]], [0, 1], 5, 5)
+ @pytest.mark.parametrize('format', ['csc', 'csr'])
+ @pytest.mark.parametrize('idx_dtype', [np.int32, np.int64])
+ def test_twodiags(self, format: str, idx_dtype: np.dtype):
+ A = spdiags([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]], [0, 1], 5, 5,
+ format=format)
b = array([1, 2, 3, 4, 5])
# condition number of A
@@ -230,13 +233,12 @@ def test_twodiags(self):
for t in ['f','d','F','D']:
eps = finfo(t).eps # floating point epsilon
b = b.astype(t)
+ Asp = A.astype(t)
+ Asp.indices = Asp.indices.astype(idx_dtype, copy=False)
+ Asp.indptr = Asp.indptr.astype(idx_dtype, copy=False)
- for format in ['csc','csr']:
- Asp = A.astype(t).asformat(format)
-
- x = spsolve(Asp,b)
-
- assert_(norm(b - Asp@x) < 10 * cond_A * eps)
+ x = spsolve(Asp, b)
+ assert_(norm(b - Asp@x) < 10 * cond_A * eps)
def test_bvector_smoketest(self):
Adense = array([[0., 1., 1.],
@@ -442,16 +444,18 @@ def setup_method(self):
n = 40
d = arange(n) + 1
self.n = n
- self.A = spdiags((d, 2*d, d[::-1]), (-3, 0, 5), n, n)
+ self.A = spdiags((d, 2*d, d[::-1]), (-3, 0, 5), n, n, format='csc')
random.seed(1234)
- def _smoketest(self, spxlu, check, dtype):
+ def _smoketest(self, spxlu, check, dtype, idx_dtype):
if np.issubdtype(dtype, np.complexfloating):
A = self.A + 1j*self.A.T
else:
A = self.A
A = A.astype(dtype)
+ A.indices = A.indices.astype(idx_dtype, copy=False)
+ A.indptr = A.indptr.astype(idx_dtype, copy=False)
lu = spxlu(A)
rng = random.RandomState(1234)
@@ -489,10 +493,9 @@ def check(A, b, x, msg=""):
r = A @ x
assert_(abs(r - b).max() < 1e3*eps, msg)
- self._smoketest(splu, check, np.float32)
- self._smoketest(splu, check, np.float64)
- self._smoketest(splu, check, np.complex64)
- self._smoketest(splu, check, np.complex128)
+ for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ for idx_dtype in [np.int32, np.int64]:
+ self._smoketest(splu, check, dtype, idx_dtype)
@sup_sparse_efficiency
def test_spilu_smoketest(self):
@@ -508,10 +511,9 @@ def check(A, b, x, msg=""):
if b.dtype in (np.float64, np.complex128):
errors.append(err)
- self._smoketest(spilu, check, np.float32)
- self._smoketest(spilu, check, np.float64)
- self._smoketest(spilu, check, np.complex64)
- self._smoketest(spilu, check, np.complex128)
+ for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ for idx_dtype in [np.int32, np.int64]:
+ self._smoketest(spilu, check, dtype, idx_dtype)
assert_(max(errors) > 1e-5)