Skip to content

Commit 227db69

Browse files
committed
replace macros with multiple-dispatch
1 parent 7479676 commit 227db69

File tree

2 files changed

+148
-131
lines changed

2 files changed

+148
-131
lines changed

src/KLU.jl

Lines changed: 95 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export klu, klu!
88

99
const libklu = :libklu
1010
include("wrappers.jl")
11+
include("type_resolution.jl")
1112

1213
import Base: (\), size, getproperty, setproperty!, propertynames, show
1314

@@ -87,20 +88,25 @@ macro isok(A)
8788
:(kluerror($(esc(A))))
8889
end
8990

91+
# this ought to be handled via multiple dispatch, so it can be done at compile time.
9092
function _klu_name(name, Tv, Ti)
9193
outname = "klu_" * (Tv === :Float64 ? "" : "z") * (Ti === :Int64 ? "l_" : "_") * name
9294
return Symbol(replace(outname, "__"=>"_"))
9395
end
94-
function _common(T)
95-
if T == Int64
96-
common = klu_l_common()
97-
ok = klu_l_defaults(Ref(common))
98-
elseif T == Int32
99-
common = klu_common()
100-
ok = klu_defaults(Ref(common))
96+
97+
function _common(::Type{Int64})
98+
common = klu_l_common()
99+
ok = klu_l_defaults(Ref(common))
100+
if ok == 1
101+
return common
101102
else
102-
throw(ArgumentError("Index type must be Int64 or Int32"))
103+
throw(ErrorException("Could not initialize common struct."))
103104
end
105+
end
106+
107+
function _common(::Type{Int32})
108+
common = klu_common()
109+
ok = klu_defaults(Ref(common))
104110
if ok == 1
105111
return common
106112
else
@@ -144,24 +150,14 @@ end
144150

145151
function _free_symbolic(K::AbstractKLUFactorization{Tv, Ti}) where {Ti<:KLUITypes, Tv}
146152
K._symbolic == C_NULL && return C_NULL
147-
if Ti == Int64
148-
klu_l_free_symbolic(Ref(Ptr{klu_l_symbolic}(K._symbolic)), Ref(K.common))
149-
elseif Ti == Int32
150-
klu_free_symbolic(Ref(Ptr{klu_symbolic}(K._symbolic)), Ref(K.common))
151-
end
153+
__free_sym(Ti, Ref(Ptr{__symType(Ti)}(K._symbolic)), Ref(K.common))
152154
K._symbolic = C_NULL
153155
end
154156

155-
for Ti KLUIndexTypes, Tv KLUValueTypes
156-
klufree = _klu_name("free_numeric", Tv, Ti)
157-
ptr = _klu_name("numeric", :Float64, Ti)
158-
@eval begin
159-
function _free_numeric(K::AbstractKLUFactorization{$Tv, $Ti})
160-
K._numeric == C_NULL && return C_NULL
161-
$klufree(Ref(Ptr{$ptr}(K._numeric)), Ref(K.common))
162-
K._numeric = C_NULL
163-
end
164-
end
157+
function _free_numeric(K::AbstractKLUFactorization{Tv, Ti}) where {Tv<:KLUTypes, Ti<:KLUITypes}
158+
K._numeric == C_NULL && return C_NULL
159+
__free_num(Tv, Ti, Ref(Ptr{__numType(Ti)}(K._numeric)), Ref(K.common))
160+
K._numeric = C_NULL
165161
end
166162

167163
function KLUFactorization(A::SparseMatrixCSC{Tv, Ti}) where {Tv<:KLUTypes, Ti<:KLUITypes}
@@ -377,11 +373,7 @@ end
377373

378374
function klu_analyze!(K::KLUFactorization{Tv, Ti}; check=true) where {Tv, Ti<:KLUITypes}
379375
if K._symbolic != C_NULL return K end
380-
if Ti == Int64
381-
sym = klu_l_analyze(K.n, K.colptr, K.rowval, Ref(K.common))
382-
else
383-
sym = klu_analyze(K.n, K.colptr, K.rowval, Ref(K.common))
384-
end
376+
sym = __analyze(Ti, K.n, K.colptr, K.rowval, Ref(K.common))
385377
if sym == C_NULL && check
386378
kluerror(K.common)
387379
else
@@ -393,11 +385,7 @@ end
393385
# User provided permutation vectors:
394386
function klu_analyze!(K::KLUFactorization{Tv, Ti}, P::Vector{Ti}, Q::Vector{Ti}; check=true) where {Tv, Ti<:KLUITypes}
395387
if K._symbolic != C_NULL return K end
396-
if Ti == Int64
397-
sym = klu_l_analyze_given(K.n, K.colptr, K.rowval, P, Q, Ref(K.common))
398-
else
399-
sym = klu_analyze_given(K.n, K.colptr, K.rowval, P, Q, Ref(K.common))
400-
end
388+
sym = __analyze!(K.n, K.colptr, K.rowval, P, Q, Ref(K.common))
401389
if sym == C_NULL && check
402390
kluerror(K.common)
403391
else
@@ -406,85 +394,74 @@ function klu_analyze!(K::KLUFactorization{Tv, Ti}, P::Vector{Ti}, Q::Vector{Ti};
406394
return K
407395
end
408396

409-
for Tv KLUValueTypes, Ti KLUIndexTypes
410-
factor = _klu_name("factor", Tv, Ti)
411-
@eval begin
412-
function klu_factor!(K::KLUFactorization{$Tv, $Ti}; check=true, allowsingular=false)
413-
K._symbolic == C_NULL && K.common.status >= KLU_OK && klu_analyze!(K)
414-
if K._symbolic != C_NULL && K.common.status >= KLU_OK
415-
K.common.halt_if_singular = !allowsingular && check
416-
num = $factor(K.colptr, K.rowval, K.nzval, K._symbolic, Ref(K.common))
417-
K.common.halt_if_singular = true
418-
else
419-
num = C_NULL
420-
end
421-
if num == C_NULL && check
422-
kluerror(K.common)
423-
else
424-
if allowsingular
425-
K.common.status < KLU_OK && check && kluerror(K.common)
426-
else
427-
(K.common.status == KLU_OK) || (check && kluerror(K.common))
428-
end
429-
end
430-
K._numeric = num
431-
return K
397+
398+
function klu_factor!(K::KLUFactorization{Tv, Ti}; check=true, allowsingular=false) where {Tv<:KLUTypes, Ti<:KLUITypes}
399+
K._symbolic == C_NULL && K.common.status >= KLU_OK && klu_analyze!(K)
400+
if K._symbolic != C_NULL && K.common.status >= KLU_OK
401+
K.common.halt_if_singular = !allowsingular && check
402+
num = __factor(Tv, Ti, K.colptr, K.rowval, K.nzval, K._symbolic, Ref(K.common))
403+
K.common.halt_if_singular = true
404+
else
405+
num = C_NULL
406+
end
407+
if num == C_NULL && check
408+
kluerror(K.common)
409+
else
410+
if allowsingular
411+
K.common.status < KLU_OK && check && kluerror(K.common)
412+
else
413+
(K.common.status == KLU_OK) || (check && kluerror(K.common))
432414
end
433415
end
416+
K._numeric = num
417+
return K
434418
end
435419

436-
for Tv KLUValueTypes, Ti KLUIndexTypes
437-
rgrowth = _klu_name("rgrowth", Tv, Ti)
438-
rcond = _klu_name("rcond", Tv, Ti)
439-
condest = _klu_name("condest", Tv, Ti)
440-
@eval begin
441-
"""
442-
rgrowth(K::KLUFactorization)
443-
444-
Calculate the reciprocal pivot growth.
445-
"""
446-
function rgrowth(K::KLUFactorization{$Tv, $Ti})
447-
K._numeric == C_NULL && klu_factor!(K)
448-
ok = $rgrowth(K.colptr, K.rowval, K.nzval, K._symbolic, K._numeric, Ref(K.common))
449-
if ok == 0
450-
kluerror(K.common)
451-
else
452-
return K.common.rgrowth
453-
end
454-
end
455420

456-
"""
457-
rcond(K::KLUFactorization)
421+
"""
422+
rgrowth(K::KLUFactorization)
458423
459-
Cheaply estimate the reciprocal condition number.
460-
"""
461-
function rcond(K::AbstractKLUFactorization{$Tv, $Ti})
462-
K._numeric == C_NULL && klu_factor!(K)
463-
ok = $rcond(K._symbolic, K._numeric, Ref(K.common))
464-
if ok == 0
465-
kluerror(K.common)
466-
else
467-
return K.common.rcond
468-
end
469-
end
424+
Calculate the reciprocal pivot growth.
425+
"""
426+
function rgrowth(K::KLUFactorization{Tv, Ti}) where {Tv<:KLUTypes, Ti<:KLUITypes}
427+
K._numeric == C_NULL && klu_factor!(K)
428+
ok = __rgrowth(Tv, Ti, K.colptr, K.rowval, K.nzval, K._symbolic, K._numeric, Ref(K.common))
429+
if ok == 0
430+
kluerror(K.common)
431+
else
432+
return K.common.rgrowth
433+
end
434+
end
470435

471-
"""
472-
condest(K::KLUFactorization)
436+
"""
437+
rcond(K::KLUFactorization)
473438
474-
Accurately estimate the 1-norm condition number of the factorization.
475-
"""
476-
function condest(K::KLUFactorization{$Tv, $Ti})
477-
K._numeric == C_NULL && klu_factor!(K)
478-
ok = $condest(K.colptr, K.nzval, K._symbolic, K._numeric, Ref(K.common))
479-
if ok == 0
480-
kluerror(K.common)
481-
else
482-
return K.common.condest
483-
end
484-
end
439+
Cheaply estimate the reciprocal condition number.
440+
"""
441+
function rcond(K::AbstractKLUFactorization{Tv, Ti}) where {Tv<:KLUTypes, Ti<:KLUITypes}
442+
K._numeric == C_NULL && klu_factor!(K)
443+
ok = __rcond(Tv, Ti, K._symbolic, K._numeric, Ref(K.common))
444+
if ok == 0
445+
kluerror(K.common)
446+
else
447+
return K.common.rcond
485448
end
486449
end
487450

451+
"""
452+
condest(K::KLUFactorization)
453+
454+
Accurately estimate the 1-norm condition number of the factorization.
455+
"""
456+
function condest(K::KLUFactorization{Tv, Ti}) where {Tv<:KLUTypes, Ti<:KLUITypes}
457+
K._numeric == C_NULL && klu_factor!(K)
458+
ok = __condest(Tv, Ti, K.colptr, K.nzval, K._symbolic, K._numeric, Ref(K.common))
459+
if ok == 0
460+
kluerror(K.common)
461+
else
462+
return K.common.condest
463+
end
464+
end
488465

489466
"""
490467
klu_factor!(K::KLUFactorization; check=true, allowsingular=false) -> K::KLUFactorization
@@ -599,23 +576,16 @@ See also: [`klu`](@ref)
599576
600577
[^ACM907]: Davis, Timothy A., & Palamadai Natarajan, E. (2010). Algorithm 907: KLU, A Direct Sparse Solver for Circuit Simulation Problems. ACM Trans. Math. Softw., 37(3). doi:10.1145/1824801.1824814
601578
"""
602-
klu!
603-
604-
for Tv KLUValueTypes, Ti KLUIndexTypes
605-
refactor = _klu_name("refactor", Tv, Ti)
606-
@eval begin
607-
function klu!(K::KLUFactorization{$Tv, $Ti}, nzval::Vector{$Tv}; check=true, allowsingular=false)
608-
length(nzval) != length(K.nzval) && throw(DimensionMismatch())
609-
K.nzval = nzval
610-
K.common.halt_if_singular = !allowsingular && check
611-
ok = $refactor(K.colptr, K.rowval, K.nzval, K._symbolic, K._numeric, Ref(K.common))
612-
K.common.halt_if_singular = true
613-
if (ok == 1 || !check || (allowsingular && K.common.status >= KLU_OK))
614-
return K
615-
else
616-
kluerror(K.common)
617-
end
618-
end
579+
function klu!(K::KLUFactorization{Tv, Ti}, nzval::Vector{Tv}; check=true, allowsingular=false) where {Tv<:KLUTypes, Ti<:KLUITypes}
580+
length(nzval) != length(K.nzval) && throw(DimensionMismatch())
581+
K.nzval = nzval
582+
K.common.halt_if_singular = !allowsingular && check
583+
ok = __refactor(Tv, Ti, K.colptr, K.rowval, K.nzval, K._symbolic, K._numeric, Ref(K.common))
584+
K.common.halt_if_singular = true
585+
if (ok == 1 || !check || (allowsingular && K.common.status >= KLU_OK))
586+
return K
587+
else
588+
kluerror(K.common)
619589
end
620590
end
621591

@@ -663,19 +633,13 @@ This function overwrites `B` with the solution `X`, for a new solution vector `X
663633
664634
This status should be checked by the user before solve calls if singularity checks were disabled on factorization using `check=false` or `allowsingular=true`.
665635
"""
666-
solve!
667-
for Tv KLUValueTypes, Ti KLUIndexTypes
668-
solve = _klu_name("solve", Tv, Ti)
669-
@eval begin
670-
function solve!(klu::AbstractKLUFactorization{$Tv, $Ti}, B::StridedVecOrMat{$Tv}; check=true)
671-
stride(B, 1) == 1 || throw(ArgumentError("B must have unit strides"))
672-
klu._numeric == C_NULL && klu_factor!(klu)
673-
size(B, 1) == size(klu, 1) || throw(DimensionMismatch())
674-
isok = $solve(klu._symbolic, klu._numeric, size(B, 1), size(B, 2), B, Ref(klu.common))
675-
isok == 0 && check && kluerror(klu.common)
676-
return B
677-
end
678-
end
636+
function solve!(klu::AbstractKLUFactorization{Tv, Ti}, B::StridedVecOrMat{Tv}; check=true) where {Tv<:KLUTypes, Ti<:KLUITypes}
637+
stride(B, 1) == 1 || throw(ArgumentError("B must have unit strides"))
638+
klu._numeric == C_NULL && klu_factor!(klu)
639+
size(B, 1) == size(klu, 1) || throw(DimensionMismatch())
640+
isok = __solve(Tv, Ti, klu._symbolic, klu._numeric, size(B, 1), size(B, 2), B, Ref(klu.common))
641+
isok == 0 && check && kluerror(klu.common)
642+
return B
679643
end
680644

681645
for Tv KLUValueTypes, Ti KLUIndexTypes

src/type_resolution.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
__free_sym(::Type{Int32}, args...) = KLU.klu_free_symbolic(args...)
2+
__free_sym(::Type{Int64}, args...) = KLU.klu_l_free_symbolic(args...)
3+
4+
__free_num(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_free_numeric(args...)
5+
__free_num(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_free_numeric(args...)
6+
__free_num(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_free_numeric(args...)
7+
__free_num(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_free_numeric(args...)
8+
9+
__analyze(::Type{Int32}, args...) = KLU.klu_analyze(args...)
10+
__analyze(::Type{Int64}, args...) = KLU.klu_l_analyze(args...)
11+
12+
__analyze!(::Type{Int32}, args...) = KLU.klu_analyze_given(args...)
13+
__analyze!(::Type{Int64}, args...) = KLU.klu_l_analyze_given(args...)
14+
15+
__factor(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_factor(args...)
16+
__factor(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_factor(args...)
17+
__factor(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_factor(args...)
18+
__factor(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_factor(args...)
19+
20+
__rcond(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_rcond(args...)
21+
__rcond(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_rcond(args...)
22+
__rcond(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_rcond(args...)
23+
__rcond(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_rcond(args...)
24+
25+
__rgrowth(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_rgrowth(args...)
26+
__rgrowth(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_rgrowth(args...)
27+
__rgrowth(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_rgrowth(args...)
28+
__rgrowth(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_rgrowth(args...)
29+
30+
__condest(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_condest(args...)
31+
__condest(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_condest(args...)
32+
__condest(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_condest(args...)
33+
__condest(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_condest(args...)
34+
35+
__refactor(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_refactor(args...)
36+
__refactor(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_refactor(args...)
37+
__refactor(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_refactor(args...)
38+
__refactor(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_refactor(args...)
39+
40+
__solve(::Type{Float64}, ::Type{Int32}, args...) = KLU.klu_solve(args...)
41+
__solve(::Type{Float64}, ::Type{Int64}, args...) = KLU.klu_l_solve(args...)
42+
__solve(::Type{ComplexF64}, ::Type{Int32}, args...) = KLU.klu_z_solve(args...)
43+
__solve(::Type{ComplexF64}, ::Type{Int64}, args...) = KLU.klu_zl_solve(args...)
44+
45+
__symType(::Type{Int32}) = KLU.klu_symbolic
46+
__symType(::Type{Int64}) = KLU.klu_l_symbolic
47+
48+
__numType(::Type{Int32}) = KLU.klu_numeric
49+
__numType(::Type{Int64}) = KLU.klu_l_numeric
50+
51+
# Could rewrite with multiple dispatch, but more awkward because of the
52+
# varying number of arguments between real vs complex cases: tsolve, extract.
53+
# sort is only used in extract, so also not implemented here.

0 commit comments

Comments
 (0)