Skip to content

Commit e32862c

Browse files
authored
Exposing single precision support from MKL (#108)
* - Exposing single precision support from MKL - Adding check for square matrix - Changing Project.lic to Panua.lic and name of of file form "project_pardiso.jl" to "panua_pardiso.jl" * Update runtests.jl Keeping tests running for 1.6 by always computing the reference solution using Float64/CompelxF64. * Update runtests.jl Converting results back to Float32/ComplexF32. * Adding a diagonal to the `herm posdef` to make it well-condtioned. Testing that the solution actually works rather than comparing with SparseArrays * Adding exception to catch that iparm[28]=1 for Float32/ComplexF32 and MKLPardisoSolver.
1 parent a645c66 commit e32862c

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed

src/Pardiso.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ end
6464
Base.showerror(io::IO, e::Union{PardisoException,PardisoPosDefException}) = print(io, e.info);
6565

6666

67-
const PardisoNumTypes = Union{Float64,ComplexF64}
67+
const PardisoNumTypes = Union{Float32, ComplexF32, Float64, ComplexF64}
6868

6969
abstract type AbstractPardisoSolver end
7070

@@ -173,7 +173,7 @@ function __init__()
173173
end
174174
end
175175
include("enums.jl")
176-
include("project_pardiso.jl")
176+
include("panua_pardiso.jl")
177177
include("mkl_pardiso.jl")
178178

179179
# Getters and setters
@@ -218,9 +218,9 @@ end
218218

219219
function solve(ps::AbstractPardisoSolver, A::SparseMatrixCSC{Tv,Ti},
220220
B::StridedVecOrMat{Tv}, T::Symbol=:N) where {Ti, Tv <: PardisoNumTypes}
221-
X = copy(B)
222-
solve!(ps, X, A, B, T)
223-
return X
221+
X = copy(B)
222+
solve!(ps, X, A, B, T)
223+
return X
224224
end
225225

226226
function fix_iparm!(ps::AbstractPardisoSolver, T::Symbol)
@@ -244,6 +244,7 @@ end
244244
function solve!(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv},
245245
A::SparseMatrixCSC{Tv,Ti}, B::StridedVecOrMat{Tv},
246246
T::Symbol=:N) where {Ti, Tv <: PardisoNumTypes}
247+
LinearAlgebra.checksquare(A)
247248
set_phase!(ps, ANALYSIS_NUM_FACT_SOLVE_REFINE)
248249

249250
# This is the heuristics for choosing what matrix type to use
@@ -253,9 +254,10 @@ function solve!(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv},
253254
# - If complex and symmetric, solve with symmetric complex solver
254255
# - Else solve as unsymmetric.
255256
if ishermitian(A)
256-
eltype(A) == Float64 ? set_matrixtype!(ps, REAL_SYM_POSDEF) : set_matrixtype!(ps, COMPLEX_HERM_POSDEF)
257+
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_SYM_POSDEF) : set_matrixtype!(ps, COMPLEX_HERM_POSDEF)
257258
pardisoinit(ps)
258259
fix_iparm!(ps, T)
260+
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
259261
try
260262
pardiso(ps, X, get_matrix(ps, A, T), B)
261263
catch e
@@ -265,20 +267,23 @@ function solve!(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv},
265267
if !isa(e, PardisoPosDefException)
266268
rethrow()
267269
end
268-
eltype(A) == Float64 ? set_matrixtype!(ps, REAL_SYM_INDEF) : set_matrixtype!(ps, COMPLEX_HERM_INDEF)
270+
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_SYM_INDEF) : set_matrixtype!(ps, COMPLEX_HERM_INDEF)
269271
pardisoinit(ps)
270272
fix_iparm!(ps, T)
273+
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
271274
pardiso(ps, X, get_matrix(ps, A, T), B)
272275
end
273276
elseif issymmetric(A)
274277
set_matrixtype!(ps, COMPLEX_SYM)
275278
pardisoinit(ps)
276279
fix_iparm!(ps, T)
280+
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
277281
pardiso(ps, X, get_matrix(ps, A, T), B)
278282
else
279-
eltype(A) == Float64 ? set_matrixtype!(ps, REAL_NONSYM) : set_matrixtype!(ps, COMPLEX_NONSYM)
283+
eltype(A) <: Union{Float32, Float64} ? set_matrixtype!(ps, REAL_NONSYM) : set_matrixtype!(ps, COMPLEX_NONSYM)
280284
pardisoinit(ps)
281285
fix_iparm!(ps, T)
286+
eltype(A) <: Union{Float32, ComplexF32} ? set_iparm!(ps, 28, 1) : nothing
282287
pardiso(ps, X, get_matrix(ps, A, T), B)
283288
end
284289

@@ -333,6 +338,11 @@ function pardiso(ps::AbstractPardisoSolver, X::StridedVecOrMat{Tv}, A::SparseMat
333338
"has a complex matrix type set: $(get_matrixtype(ps))")))
334339
end
335340

341+
if Tv <: Union{Float32, ComplexF32} && typeof(ps) <: MKLPardisoSolver && ps.iparm[28] != 1
342+
throw(ErrorException(string("input matrix is Float32/ComplexF32 while MKLPardisoSolver ",
343+
"have iparm[28]=$(ps.iparm[28]) rather than 1.")))
344+
end
345+
336346
N = size(A, 2)
337347

338348
resize!(ps.perm, size(B, 1))
@@ -435,6 +445,7 @@ function pardisogetschur(ps::AbstractPardisoSolver)
435445
end
436446

437447
function dim_check(X, A, B)
448+
LinearAlgebra.checksquare(A)
438449
size(X) == size(B) || throw(DimensionMismatch(string("solution has $(size(X)), ",
439450
"RHS has size as $(size(B)).")))
440451
size(A, 1) == size(B, 1) || throw(DimensionMismatch(string("matrix has $(size(A,1)) ",

src/project_pardiso.jl renamed to src/panua_pardiso.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ mutable struct PardisoSolver <: AbstractPardisoSolver
99
maxfct::Int32
1010
mnum::Int32
1111
perm::Vector{Int32}
12+
colptr::Vector{Int32}
13+
rowval::Vector{Int32}
1214
end
1315

1416
function PardisoSolver()
@@ -34,9 +36,11 @@ function PardisoSolver()
3436
mnum = 1
3537
maxfct = 1
3638
perm = Int32[]
39+
colptr = Int32[]
40+
rowval = Int32[]
3741

3842
ps = PardisoSolver(pt, iparm, dparm, mtype, solver,
39-
phase, msglvl, maxfct, mnum, perm)
43+
phase, msglvl, maxfct, mnum, perm, colptr, rowval)
4044

4145
return ps
4246
end
@@ -77,9 +81,14 @@ end
7781

7882
@inline function ccall_pardiso(ps::PardisoSolver, N::Integer, nzval::Vector{Tv},
7983
colptr, rowval, NRHS::Integer, B::StridedVecOrMat{Tv}, X::StridedVecOrMat{Tv}) where {Tv}
84+
(Tv == Float32 || Tv == ComplexF32) && throw(ArgumentError("Single precision input matrix only supported by MKL."))
85+
8086
N = Int32(N)
81-
colptr = convert(Vector{Int32}, colptr)
82-
rowval = convert(Vector{Int32}, rowval)
87+
# Save new colptr and rowvals if a new analysis phase is run
88+
if ps.phase in [ANALYSIS, ANALYSIS_NUM_FACT, ANALYSIS_NUM_FACT_SOLVE_REFINE]
89+
ps.colptr = convert(Vector{Int32}, colptr)
90+
ps.rowval = convert(Vector{Int32}, rowval)
91+
end
8392
resize!(ps.perm, size(B, 1))
8493
NRHS = Int32(NRHS)
8594

@@ -90,7 +99,7 @@ end
9099
Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Tv}, Ptr{Tv},
91100
Ptr{Int32}, Ptr{Float64}),
92101
ps.pt, Ref(ps.maxfct), Ref(Int32(ps.mnum)), Ref(Int32(ps.mtype)), Ref(Int32(ps.phase)),
93-
Ref(N), nzval, colptr, rowval, ps.perm,
102+
Ref(N), nzval, ps.colptr, ps.rowval, ps.perm,
94103
Ref(NRHS), ps.iparm, Ref(Int32(ps.msglvl)), B, X,
95104
ERR, ps.dparm)
96105
check_error(ps, ERR[])
@@ -99,6 +108,8 @@ end
99108

100109
@inline function ccall_pardiso_get_schur(ps::PardisoSolver, S::Vector{Tv},
101110
IS::Vector{Int32}, JS::Vector{Int32}) where Tv
111+
(Tv == Float32 || Tv == ComplexF32) && throw(ArgumentError("Single precision input matrix only supported by MKL."))
112+
102113
ccall(pardiso_get_schur_f[], Cvoid,
103114
(Ptr{Int}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{Tv},
104115
Ptr{Int32}, Ptr{Int32}),
@@ -112,8 +123,8 @@ function printstats(ps::PardisoSolver, A::SparseMatrixCSC{Tv, Ti},
112123
B::StridedVecOrMat{Tv}) where {Ti,Tv <: PardisoNumTypes}
113124
N = Int32(size(A, 2))
114125
AA = A.nzval
115-
IA = convert(Vector{Int32}, A.colptr)
116-
JA = convert(Vector{Int32}, A.rowval)
126+
IA = ps.colptr
127+
JA = ps.rowval
117128
NRHS = Int32(size(B, 2))
118129
ERR = Ref{Int32}(0)
119130
if Tv <: Complex
@@ -181,7 +192,7 @@ function check_error(ps::PardisoSolver, err::Integer)
181192
err != -6 || throw(PardisoException("Preordering failed (matrix types 11, 13 only)."))
182193
err != -7 || throw(PardisoException("Diagonal matrix problem."))
183194
err != -8 || throw(PardisoException("32-bit integer overflow problem."))
184-
err != -10 || throw(PardisoException("No license file pardiso.lic found."))
195+
err != -10 || throw(PardisoException("No license file panua.lic found."))
185196
err != -11 || throw(PardisoException("License is expired."))
186197
err != -12 || throw(PardisoException("Wrong username or hostname."))
187198
err != -100|| throw(PardisoException("Reached maximum number of Krylov-subspace iteration in iterative solver."))

test/runtests.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,43 @@ end
2424
if Pardiso.PARDISO_LOADED[]
2525
push!(available_solvers, PardisoSolver)
2626
else
27-
@warn "Not testing project Pardiso solver"
27+
@warn "Not testing panua Pardiso solver"
2828
end
2929

3030
@show Pardiso.MklInt
3131

3232
println("Testing ", available_solvers)
3333

34+
supported_eltypes(ps::PardisoSolver) = (Float64, ComplexF64)
35+
supported_eltypes(ps::MKLPardisoSolver) = (Float32, ComplexF32, Float64, ComplexF64)
36+
3437
# Test solver + for real and complex data
3538
@testset "solving" begin
3639
for pardiso_type in available_solvers
3740
ps = pardiso_type()
38-
for T in (Float64, ComplexF64)
41+
for T in supported_eltypes(ps)
3942
A1 = sparse(rand(T, 10,10))
4043
for B in (rand(T, 10, 2), view(rand(T, 10, 4), 1:10, 2:3))
4144
X = similar(B)
4245
# Test unsymmetric, herm indef, herm posdef and symmetric
43-
for A in SparseMatrixCSC[A1, A1 + A1', A1'A1, transpose(A1) + A1]
46+
for A in SparseMatrixCSC[A1, A1 + A1', A1'A1 + I, transpose(A1) + A1]
4447
solve!(ps, X, A, B)
45-
@test X A\Matrix(B)
48+
@test A*X B
4649

4750
X = solve(ps, A, B)
48-
@test X A\Matrix(B)
51+
@test A*X B
4952

5053
solve!(ps, X, A, B, :C)
51-
@test X A'\Matrix(B)
54+
@test A'*X B
5255

5356
X = solve(ps, A, B, :C)
54-
@test X A'\Matrix(B)
57+
@test A'*X B
5558

5659
solve!(ps, X, A, B, :T)
57-
@test X copy(transpose(A))\Matrix(B)
60+
@test transpose(A)*X B
5861

5962
X = solve(ps, A, B, :T)
60-
@test X copy(transpose(A))\Matrix(B)
63+
@test transpose(A)*X B
6164
end
6265
end
6366
end

0 commit comments

Comments
 (0)