Skip to content

Commit 86c7484

Browse files
Follow matrix type branching from MKL Pardiso docs (#114)
Co-authored-by: Kristoffer <kcarlsson89@gmail.com>
1 parent e32862c commit 86c7484

File tree

4 files changed

+167
-48
lines changed

4 files changed

+167
-48
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ julia = "1.6"
1717
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1818
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1919
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
20+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2021
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2122

2223
[targets]
23-
test = ["Test", "Random", "Printf", "Pkg"]
24+
test = ["StableRNGs", "Test", "Random", "Printf", "Pkg"]

src/Pardiso.jl

Lines changed: 144 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
MKL_LOAD_FAILED = false
3030

31-
mkl_is_available() = (LOCAL_MKL_FOUND || MKL_jll.is_available()) && !MKL_LOAD_FAILED
31+
mkl_is_available() = (LOCAL_MKL_FOUND || MKL_jll.is_available()) && !MKL_LOAD_FAILED
3232

3333
if LinearAlgebra.BLAS.vendor() === :mkl && LinearAlgebra.BlasInt == Int64
3434
const MklInt = Int64
@@ -128,7 +128,7 @@ function __init__()
128128
elseif MKL_jll.is_available()
129129
libmkl_rt[] = MKL_jll.libmkl_rt_path
130130
end
131-
131+
132132
if !haskey(ENV, "PARDISOLICMESSAGE")
133133
ENV["PARDISOLICMESSAGE"] = 1
134134
end
@@ -137,7 +137,7 @@ function __init__()
137137
@warn "MKLROOT not set, MKL Pardiso solver will not be functional"
138138
end
139139

140-
if mkl_is_available()
140+
if mkl_is_available()
141141
try
142142
libmklpardiso = Libdl.dlopen(libmkl_rt[])
143143
mklpardiso_f = Libdl.dlsym(libmklpardiso, "pardiso")
@@ -146,7 +146,7 @@ function __init__()
146146
MKL_LOAD_FAILED = true
147147
end
148148
end
149-
149+
150150
# This is apparently needed for MKL to not get stuck on 1 thread when
151151
# libpardiso is loaded in the block below...
152152
if libmkl_rt[] !== ""
@@ -241,6 +241,85 @@ function fix_iparm!(ps::AbstractPardisoSolver, T::Symbol)
241241
end
242242
end
243243

244+
# Copied from SparseArrays.jl but with a tweak
245+
# to check symmetry of the sparsity pattern
246+
function _is_hermsym(A::SparseMatrixCSC, check::Function)
247+
m, n = size(A)
248+
if m != n; return false; end
249+
250+
colptr = SparseArrays.getcolptr(A)
251+
rowval = rowvals(A)
252+
nzval = nonzeros(A)
253+
tracker = copy(SparseArrays.getcolptr(A))
254+
for col in axes(A,2)
255+
# `tracker` is updated such that, for symmetric matrices,
256+
# the loop below starts from an element at or below the
257+
# diagonal element of column `col`"
258+
for p = tracker[col]:colptr[col+1]-1
259+
val = nzval[p]
260+
row = rowval[p]
261+
262+
# Ignore stored zeros
263+
if iszero(val)
264+
continue
265+
end
266+
267+
# If the matrix was symmetric we should have updated
268+
# the tracker to start at the diagonal or below. Here
269+
# we are above the diagonal so the matrix can't be symmetric.
270+
if row < col
271+
return false
272+
end
273+
274+
# Diagonal element
275+
if row == col
276+
if !check(val, val)
277+
return false
278+
end
279+
else
280+
offset = tracker[row]
281+
282+
# If the matrix is unsymmetric, there might not exist
283+
# a rowval[offset]
284+
if offset > length(rowval)
285+
return false
286+
end
287+
288+
row2 = rowval[offset]
289+
290+
# row2 can be less than col if the tracker didn't
291+
# get updated due to stored zeros in previous elements.
292+
# We therefore "catch up" here while making sure that
293+
# the elements are actually zero.
294+
while row2 < col
295+
if _isnotzero(nzval[offset])
296+
return false
297+
end
298+
offset += 1
299+
row2 = rowval[offset]
300+
tracker[row] += 1
301+
end
302+
303+
# Non zero A[i,j] exists but A[j,i] does not exist
304+
if row2 > col
305+
return false
306+
end
307+
308+
# A[i,j] and A[j,i] exists
309+
if row2 == col
310+
if !check(val, nzval[offset])
311+
return false
312+
end
313+
tracker[row] += 1
314+
end
315+
end
316+
end
317+
end
318+
return true
319+
end
320+
321+
isstructurallysymmetric(A::SparseMatrixCSC) = _is_hermsym(A, (x,y) -> true)
322+
244323
function solve!(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv},
245324
A::SparseMatrixCSC{Tv,Ti}, B::StridedVecOrMat{Tv},
246325
T::Symbol=:N) where {Ti, Tv <: PardisoNumTypes}
@@ -253,38 +332,72 @@ function solve!(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv},
253332
# - On pos def exception, solve instead with symmetric indefinite.
254333
# - If complex and symmetric, solve with symmetric complex solver
255334
# - Else solve as unsymmetric.
256-
if ishermitian(A)
257-
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_SYM_POSDEF) : set_matrixtype!(ps, COMPLEX_HERM_POSDEF)
258-
pardisoinit(ps)
259-
fix_iparm!(ps, T)
260-
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
261-
try
335+
if Tv <: Union{Float32, Float64}
336+
if issymmetric(A)
337+
set_matrixtype!(ps, REAL_SYM_POSDEF)
338+
pardisoinit(ps)
339+
fix_iparm!(ps, T)
340+
Tv == Float32 && set_iparm!(ps, 28, 1)
341+
try
342+
pardiso(ps, X, get_matrix(ps, A, T), B)
343+
catch e
344+
set_phase!(ps, RELEASE_ALL)
345+
pardiso(ps, X, A, B)
346+
set_phase!(ps, ANALYSIS_NUM_FACT_SOLVE_REFINE)
347+
if !isa(e, PardisoPosDefException)
348+
rethrow()
349+
end
350+
set_matrixtype!(ps, REAL_SYM_INDEF)
351+
pardisoinit(ps)
352+
fix_iparm!(ps, T)
353+
Tv == Float32 && set_iparm!(ps, 28, 1)
354+
pardiso(ps, X, get_matrix(ps, A, T), B)
355+
end
356+
else
357+
if isstructurallysymmetric(A)
358+
set_matrixtype!(ps, REAL_SYM)
359+
else
360+
set_matrixtype!(ps, REAL_NONSYM)
361+
end
362+
pardisoinit(ps)
363+
fix_iparm!(ps, T)
364+
Tv == Float32 && set_iparm!(ps, 28, 1)
262365
pardiso(ps, X, get_matrix(ps, A, T), B)
263-
catch e
264-
set_phase!(ps, RELEASE_ALL)
265-
pardiso(ps, X, A, B)
266-
set_phase!(ps, ANALYSIS_NUM_FACT_SOLVE_REFINE)
267-
if !isa(e, PardisoPosDefException)
268-
rethrow()
366+
end
367+
else # Tv <: Union{ComplexF64, ComplexF32}
368+
if ishermitian(A)
369+
set_matrixtype!(ps, COMPLEX_HERM_POSDEF)
370+
pardisoinit(ps)
371+
fix_iparm!(ps, T)
372+
Tv == ComplexF32 && set_iparm!(ps, 28, 1)
373+
try
374+
pardiso(ps, X, get_matrix(ps, A, T), B)
375+
catch e
376+
set_phase!(ps, RELEASE_ALL)
377+
pardiso(ps, X, A, B)
378+
set_phase!(ps, ANALYSIS_NUM_FACT_SOLVE_REFINE)
379+
if !isa(e, PardisoPosDefException)
380+
rethrow()
381+
end
382+
set_matrixtype!(ps, COMPLEX_HERM_INDEF)
383+
pardisoinit(ps)
384+
fix_iparm!(ps, T)
385+
Tv == ComplexF32 && set_iparm!(ps, 28, 1)
386+
pardiso(ps, X, get_matrix(ps, A, T), B)
387+
end
388+
else
389+
if issymmetric(A)
390+
set_matrixtype!(ps, COMPLEX_SYM)
391+
elseif isstructurallysymmetric(A)
392+
set_matrixtype!(ps, COMPLEX_STRUCT_SYM)
393+
else
394+
set_matrixtype!(ps, COMPLEX_NONSYM)
269395
end
270-
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_SYM_INDEF) : set_matrixtype!(ps, COMPLEX_HERM_INDEF)
271396
pardisoinit(ps)
272397
fix_iparm!(ps, T)
273-
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
398+
Tv == ComplexF32 && set_iparm!(ps, 28, 1)
274399
pardiso(ps, X, get_matrix(ps, A, T), B)
275400
end
276-
elseif issymmetric(A)
277-
set_matrixtype!(ps, COMPLEX_SYM)
278-
pardisoinit(ps)
279-
fix_iparm!(ps, T)
280-
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
281-
pardiso(ps, X, get_matrix(ps, A, T), B)
282-
else
283-
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_NONSYM) : set_matrixtype!(ps, COMPLEX_NONSYM)
284-
pardisoinit(ps)
285-
fix_iparm!(ps, T)
286-
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
287-
pardiso(ps, X, get_matrix(ps, A, T), B)
288401
end
289402

290403
# Release memory, TODO: We are running the convert on IA and JA here
@@ -344,7 +457,7 @@ function pardiso(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv}, A::SparseMat
344457
end
345458

346459
N = size(A, 2)
347-
460+
348461
resize!(ps.perm, size(B, 1))
349462

350463
NRHS = size(B, 2)

src/enums.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414
Base.isreal(v::MatrixType) = v in (REAL_SYM, REAL_SYM_POSDEF, REAL_SYM_INDEF, REAL_NONSYM)
15-
LinearAlgebra.issymmetric(v::MatrixType) = v in (REAL_SYM, REAL_SYM_POSDEF, REAL_SYM_INDEF, COMPLEX_STRUCT_SYM,
15+
LinearAlgebra.issymmetric(v::MatrixType) = v in (REAL_SYM_POSDEF, REAL_SYM_INDEF,
1616
COMPLEX_HERM_POSDEF, COMPLEX_HERM_INDEF, COMPLEX_SYM)
1717
LinearAlgebra.ishermitian(v::MatrixType) = v in (REAL_SYM_POSDEF, COMPLEX_HERM_POSDEF, COMPLEX_HERM_INDEF)
1818
isposornegdef(v::MatrixType) = v in (REAL_SYM_POSDEF, REAL_SYM_INDEF, COMPLEX_HERM_POSDEF, COMPLEX_HERM_INDEF)
@@ -30,7 +30,7 @@ const MATRIX_STRING = Dict{MatrixType, String}(
3030
)
3131

3232
const REAL_MATRIX_TYPES = [REAL_SYM, REAL_SYM_POSDEF, REAL_SYM_INDEF, REAL_NONSYM]
33-
const COMPLEX_MATRIX_TYPES = [COMPLEX_STRUCT_SYM, COMPLEX_HERM_POSDEF, COMPLEX_HERM_INDEF, COMPLEX_NONSYM]
33+
const COMPLEX_MATRIX_TYPES = [COMPLEX_STRUCT_SYM, COMPLEX_HERM_POSDEF, COMPLEX_HERM_INDEF, COMPLEX_NONSYM, COMPLEX_SYM]
3434

3535
# Messages
3636
@enum(MessageLevel::Int32,

test/runtests.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@ end
99

1010
using Test
1111
using Pardiso
12+
using StableRNGs
1213
using Random
1314
using SparseArrays
1415
using LinearAlgebra
1516

16-
Random.seed!(1234)
17-
1817
available_solvers = empty([Pardiso.AbstractPardisoSolver])
1918
if Pardiso.mkl_is_available()
2019
push!(available_solvers, MKLPardisoSolver)
@@ -27,6 +26,8 @@ else
2726
@warn "Not testing panua Pardiso solver"
2827
end
2928

29+
const rng = StableRNG(1)
30+
3031
@show Pardiso.MklInt
3132

3233
println("Testing ", available_solvers)
@@ -38,9 +39,10 @@ supported_eltypes(ps::MKLPardisoSolver) = (Float32, ComplexF32, Float64, Complex
3839
@testset "solving" begin
3940
for pardiso_type in available_solvers
4041
ps = pardiso_type()
42+
Random.seed!(rng, 1234)
4143
for T in supported_eltypes(ps)
42-
A1 = sparse(rand(T, 10,10))
43-
for B in (rand(T, 10, 2), view(rand(T, 10, 4), 1:10, 2:3))
44+
A1 = sparse(rand(rng, T, 10,10))
45+
for B in (rand(rng, T, 10, 2), view(rand(rng, T, 10, 4), 1:10, 2:3))
4446
X = similar(B)
4547
# Test unsymmetric, herm indef, herm posdef and symmetric
4648
for A in SparseMatrixCSC[A1, A1 + A1', A1'A1 + I, transpose(A1) + A1]
@@ -91,6 +93,7 @@ end
9193

9294
if Pardiso.PARDISO_LOADED[]
9395
@testset "schur" begin
96+
Random.seed!(rng, 1234)
9497
# reproduce example from Pardiso website
9598
include("schur_matrix_def.jl")
9699
@test norm(real(D) - real(C)*rA⁻¹*real(B) - s) < 1e-10*(8)^2
@@ -108,11 +111,11 @@ if Pardiso.PARDISO_LOADED[]
108111
set_matrixtype!(ps, 13)
109112
end
110113
for j 1:100
111-
A = 5I + sprand(T,m,m,p)
114+
A = 5I + sprand(rng,T,m,m,p)
112115
A⁻¹ = inv(Matrix(A))
113-
B = sprand(T,m,n,p)
114-
C = sprand(T,n,m,p)
115-
D = 5I + sprand(T,n,n,p)
116+
B = sprand(rng,T,m,n,p)
117+
C = sprand(rng,T,n,m,p)
118+
D = 5I + sprand(rng,T,n,n,p)
116119
M = [A B; C D]
117120

118121
# test integer block specification
@@ -137,13 +140,14 @@ end # testset
137140
end
138141

139142
@testset "error checks" begin
143+
Random.seed!(rng, 1234)
140144
for pardiso_type in available_solvers
141145

142146
ps = pardiso_type()
143147

144-
A = sparse(rand(10,10))
145-
B = rand(10, 2)
146-
X = rand(10, 2)
148+
A = sparse(rand(rng,10,10))
149+
B = rand(rng, 10, 2)
150+
X = rand(rng, 10, 2)
147151

148152
if pardiso_type == PardisoSolver
149153
printstats(ps, A, B)
@@ -160,7 +164,7 @@ for pardiso_type in available_solvers
160164
X = zeros(12, 2)
161165
@test_throws DimensionMismatch solve!(ps,X, A, B)
162166

163-
B = rand(12, 2)
167+
B = rand(rng, 12, 2)
164168
@test_throws DimensionMismatch solve(ps, A, B)
165169
end
166170
end # testset
@@ -201,9 +205,10 @@ for pardiso_type in available_solvers
201205
end
202206

203207
@testset "pardiso" begin
208+
Random.seed!(rng, 1234)
204209
for pardiso_type in available_solvers
205-
A = sparse(rand(2,2) + im * rand(2,2))
206-
b = rand(2) + im * rand(2)
210+
A = sparse(rand(rng,2,2) + im * rand(rng,2,2))
211+
b = rand(rng,2) + im * rand(rng,2)
207212
ps = pardiso_type()
208213
set_matrixtype!(ps, Pardiso.COMPLEX_NONSYM)
209214
x = Pardiso.solve(ps, A, b);

0 commit comments

Comments
 (0)