Skip to content

Commit d0df013

Browse files
Refactor LU factorization with multiple implementation versions
- Introduce two LU factorization implementations (LUFactVersion1 and LUFactVersion2) - Update abstract type names from LinearSystem* to Problem/SolutionMixin - Modify kernel methods for both LU factorization versions - Update test suite to cover both implementation versions - Set LUFactorizationProblem as an alias for LUFactVersion2
1 parent e3c2533 commit d0df013

File tree

5 files changed

+83
-40
lines changed

5 files changed

+83
-40
lines changed

src/SimpleLinearAlgebra.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ module SimpleLinearAlgebra
22

33
using LinearAlgebra
44

5-
abstract type LinearSystemProblemMixin end
6-
abstract type LinearSystemSolutionMixin end
5+
abstract type ProblemMixin end
6+
abstract type SolutionMixin end
77

8-
function kernel(prob::LinearSystemProblemMixin)
8+
function kernel(prob::ProblemMixin)
99
MethodError(kernel, (typeof(prob),)) |> throw
1010
end
1111

src/back_substitution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct BackSubstitutionProblem <: LinearSystemProblemMixin
1+
struct BackSubstitutionProblem <: ProblemMixin
22
u::AbstractMatrix
33
b::AbstractVector
44
tol::Real
@@ -11,7 +11,7 @@ struct BackSubstitutionProblem <: LinearSystemProblemMixin
1111
end
1212
end
1313

14-
struct BackSubstitutionSolution <: LinearSystemSolutionMixin
14+
struct BackSubstitutionSolution <: SolutionMixin
1515
x::AbstractVector
1616
end
1717

src/forward_substitution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct ForwardSubstitutionProblem <: LinearSystemProblemMixin
1+
struct ForwardSubstitutionProblem <: ProblemMixin
22
l::AbstractMatrix
33
b::AbstractVector
44
tol::Real
@@ -13,7 +13,7 @@ end
1313

1414
ForwardSubstitution = ForwardSubstitutionProblem
1515

16-
struct ForwardSubstitutionSolution <: LinearSystemSolutionMixin
16+
struct ForwardSubstitutionSolution <: SolutionMixin
1717
x::AbstractVector
1818
end
1919

src/lufact.jl

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
struct LUFactorizationProblem <: LinearSystemProblemMixin
2-
a::AbstractMatrix
3-
tol::Real
4-
end
5-
6-
struct LUFactorizationSolution <: LinearSystemSolutionMixin
1+
struct LUFactorizationSolution <: SolutionMixin
72
l::AbstractMatrix
83
u::AbstractMatrix
94

@@ -20,54 +15,72 @@ struct LUFactorizationSolution <: LinearSystemSolutionMixin
2015
end
2116
end
2217

23-
LUFactorization = LUFactorizationProblem
18+
struct LUFactVersion1 <: ProblemMixin
19+
a::AbstractMatrix
20+
tol::Real
21+
end
2422

25-
function elemination_step(a::AbstractMatrix{T}, k::Integer) where T
26-
# given a matrix A, perform the k-th elimination step.
27-
# M A = N, which M is a lower triangular matrix with ones on the diagonal,
28-
# and N is the matrix A with the k-th column eliminated.
29-
# The inverse of M is simply given by Minv = 2I - m
23+
function kernel(p::LUFactVersion1)
24+
tol = p.tol
25+
u = copy(p.a)
26+
n = size(u, 1)
3027

31-
n = size(a, 1)
32-
@assert size(a, 2) == n
33-
@assert 1 <= k <= n-1
28+
l = zero(u)
29+
l[1:n+1:end] .= 1
30+
31+
for k in 1:n-1
32+
if abs(u[k, k]) < tol
33+
ArgumentError("Gaussian elimination failed") |> throw
34+
end
3435

35-
m = identity_matrix(T, n)
36-
for i in k+1:n
37-
m[i, k] = -a[i, k] / a[k, k]
36+
m_k = zero(u)
37+
m_k[1:n+1:end] .= 1
38+
for i in k+1:n
39+
m_k[i, k] = -u[i, k] / u[k, k]
40+
end
41+
42+
l = l * inv(m_k)
43+
u .= m_k * u
3844
end
39-
return m
45+
46+
return LUFactorizationSolution(tril(l), triu(u))
4047
end
4148

42-
function kernel(prob::LUFactorizationProblem)
43-
u = copy(prob.a)
44-
n = size(u, 1)
45-
@assert size(u, 2) == n
4649

47-
tol = prob.tol
4850

49-
l = zero(u)
51+
struct LUFactVersion2 <: ProblemMixin
52+
a::AbstractMatrix
53+
tol::Real
54+
end
5055

51-
l[1:n+1:end] .= 1
56+
function kernel(p::LUFactVersion2)
57+
tol = p.tol
58+
u = copy(p.a)
59+
n = size(u, 1)
5260

53-
for k in 1:n-1
61+
l = zero(u)
62+
l[1:n+1:end] .+= 1
5463

64+
for k in 1:n-1
5565
if abs(u[k, k]) < tol
5666
ArgumentError("Gaussian elimination failed") |> throw
5767
end
5868

59-
for i=k+1:n
69+
for i in k+1:n
6070
l[i, k] = u[i, k] / u[k, k]
6171
end
6272

6373
for j in k+1:n
64-
for i in k+1:n
74+
for i=k+1:n
6575
u[i, j] -= l[i, k] * u[k, j]
6676
end
6777
end
6878
end
6979

70-
return LUFactorizationSolution(l, triu(u))
80+
return LUFactorizationSolution(tril(l), triu(u))
7181
end
7282

73-
export LUFactorization, kernel
83+
export LUFactVersion1, LUFactVersion2
84+
85+
LUFactorizationProblem = LUFactVersion2
86+
export LUFactorizationProblem

test/lufact.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
1-
@testset "LU factorization" begin
1+
@testset "LU factorization version 1" begin
2+
n = 4
3+
tol = 1e-10
4+
a = rand(n, n)
5+
6+
p = LUFactVersion1(a, tol)
7+
s = kernel(p)
8+
l = s.l
9+
u = s.u
10+
11+
@test istril(l)
12+
@test istriu(u)
13+
@test maximum(abs, a - l * u) < tol
14+
end
15+
16+
@testset "LU factorization version 2" begin
217
n = 10
318
tol = 1e-10
419
a = rand(n, n)
520

6-
p = LUFactorization(a, tol)
21+
p = LUFactVersion2(a, tol)
22+
s = kernel(p)
23+
l = s.l
24+
u = s.u
25+
26+
@test istril(l)
27+
@test istriu(u)
28+
@test maximum(abs, a - l * u) < tol
29+
end
30+
31+
@testset "LU factorization" begin
32+
n = 20
33+
tol = 1e-10
34+
a = rand(n, n)
35+
36+
p = LUFactorizationProblem(a, tol)
737
s = kernel(p)
838
l = s.l
939
u = s.u

0 commit comments

Comments
 (0)