Skip to content

Commit e4b48fa

Browse files
Refactor project structure and add error handling utility
- Add Printf dependency to Project.toml and examples/Manifest.toml - Introduce `assert` utility function with custom error handling - Rename source and test files to use hyphen-separated naming - Remove individual substitution and LU factorization implementation files - Update module to import Printf macro and add LinearAlgebraError type - Modify test runner to reflect new file structure
1 parent d0df013 commit e4b48fa

15 files changed

+278
-137
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ version = "1.0.0-DEV"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
89

910
[compat]
1011
LinearAlgebra = "1.10.0"
12+
Printf = "1.11.0"
1113
julia = "1.10"
1214

1315
[extras]

examples/Manifest.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
2727
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
2828
version = "0.3.27+1"
2929

30+
[[deps.Printf]]
31+
deps = ["Unicode"]
32+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
33+
version = "1.11.0"
34+
3035
[[deps.SimpleLinearAlgebra]]
31-
deps = ["LinearAlgebra"]
32-
path = "/Users/yangjunjie/work/SimpleLinearAlgebra.jl/SimpleLinearAlgebra"
36+
deps = ["LinearAlgebra", "Printf"]
37+
path = "/Users/yangjunjie/.julia/dev/SimpleLinearAlgebra"
3338
uuid = "555e9691-8231-4cde-a0d7-6dc20204a59a"
3439
version = "1.0.0-DEV"
3540

41+
[[deps.Unicode]]
42+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
43+
version = "1.11.0"
44+
3645
[[deps.libblastrampoline_jll]]
3746
deps = ["Artifacts", "Libdl"]
3847
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"

src/SimpleLinearAlgebra.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleLinearAlgebra
22

3-
using LinearAlgebra
3+
using LinearAlgebra, Printf
4+
import Printf.@sprintf # Fixed macro import
45

56
abstract type ProblemMixin end
67
abstract type SolutionMixin end
@@ -9,7 +10,20 @@ module SimpleLinearAlgebra
910
MethodError(kernel, (typeof(prob),)) |> throw
1011
end
1112

12-
include("forward_substitution.jl")
13-
include("back_substitution.jl")
14-
include("lufact.jl")
13+
struct LinearAlgebraError <: Exception
14+
message::String
15+
end
16+
17+
function assert(obj, condition::Bool, message::String)
18+
if !condition
19+
cls = split(string(typeof(obj)), ".")[2]
20+
message = @sprintf("%s failed: %s", cls, message)
21+
LinearAlgebraError(message) |> throw
22+
end
23+
end
24+
25+
include("forward-substitution.jl")
26+
include("back-substitution.jl")
27+
include("lu-factorization.jl")
28+
include("partial-pivoting-lu-factorization.jl")
1529
end

src/back_substitution.jl renamed to src/back-substitution.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ struct BackSubstitutionProblem <: ProblemMixin
44
tol::Real
55

66
function BackSubstitutionProblem(u, b, tol)
7-
if !istriu(u)
8-
ArgumentError("Matrix must be upper triangular") |> throw
9-
end
10-
return new(u, b, tol)
7+
prob = new(u, b, tol)
8+
assert(prob, istriu(u), "matrix must be upper triangular")
9+
return prob
1110
end
1211
end
1312

@@ -25,13 +24,16 @@ function kernel(prob::BackSubstitutionProblem)
2524

2625
# solve the system
2726
x = zeros(n)
28-
x[n] = b[n] / u[n, n]
2927

30-
for i in n-1:-1:1
31-
if abs(u[i, i]) < tol
32-
throw(ArgumentError("Matrix is singular"))
28+
for i in n:-1:1
29+
assert(prob, abs(u[i, i]) > tol, "matrix is singular")
30+
x[i] += b[i] / u[i, i]
31+
32+
if i == n
33+
continue
3334
end
34-
x[i] = (b[i] - sum(u[i, j] * x[j] for j in i+1:n)) / u[i, i]
35+
36+
x[i] -= u[i, i+1:n]' * x[i+1:n] / u[i, i]
3537
end
3638

3739
return BackSubstitutionSolution(x)

src/forward_substitution.jl renamed to src/forward-substitution.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ struct ForwardSubstitutionProblem <: ProblemMixin
44
tol::Real
55

66
function ForwardSubstitutionProblem(l, b, tol)
7-
if !istril(l)
8-
ArgumentError("Matrix must be lower triangular") |> throw
9-
end
10-
return new(l, b, tol)
7+
prob = new(l, b, tol)
8+
assert(prob, istril(l), "matrix must be lower triangular")
9+
return prob
1110
end
1211
end
1312

@@ -24,12 +23,16 @@ function kernel(prob::ForwardSubstitutionProblem)
2423
tol = prob.tol
2524

2625
x = zeros(n)
27-
x[1] = b[1] / l[1, 1]
28-
for i in 2:n
29-
if abs(l[i, i]) < tol
30-
ArgumentError("Matrix is singular") |> throw
26+
27+
for i in 1:n
28+
assert(prob, abs(l[i, i]) > tol, "matrix is singular")
29+
x[i] += b[i] / l[i, i]
30+
31+
if i == 1
32+
continue
3133
end
32-
x[i] = (b[i] - sum(l[i, j] * x[j] for j in 1:i-1)) / l[i, i]
34+
35+
x[i] -= l[i, 1:i-1]' * x[1:i-1] / l[i, i]
3336
end
3437
return ForwardSubstitutionSolution(x)
3538
end

src/lu-factorization.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
struct LUFactorizationSolution <: SolutionMixin
2+
l::AbstractMatrix
3+
u::AbstractMatrix
4+
5+
function LUFactorizationSolution(l, u)
6+
s = new(l, u)
7+
assert(s, istril(l), "matrix must be lower triangular")
8+
assert(s, istriu(u), "matrix must be upper triangular")
9+
return s
10+
end
11+
end
12+
13+
abstract type LUFactorizationProblemMixin <: ProblemMixin end
14+
15+
struct Version1 <: LUFactorizationProblemMixin
16+
a::AbstractMatrix
17+
tol::Real
18+
end
19+
20+
function kernel(p::Version1)
21+
tol = p.tol
22+
u = copy(p.a)
23+
n = size(u, 1)
24+
25+
l = Matrix{Float64}(I, n, n)
26+
27+
for k in 1:n-1
28+
assert(p, abs(u[k, k]) > tol, "Gaussian elimination failed")
29+
30+
m_k = Matrix{Float64}(I, n, n)
31+
for i in k+1:n
32+
m_k[i, k] = -u[i, k] / u[k, k]
33+
end
34+
35+
l = l * inv(m_k)
36+
u .= m_k * u
37+
end
38+
39+
return LUFactorizationSolution(tril(l), triu(u))
40+
end
41+
42+
struct Version2 <: LUFactorizationProblemMixin
43+
a::AbstractMatrix
44+
tol::Real
45+
end
46+
47+
function kernel(p::Version2)
48+
tol = p.tol
49+
u = copy(p.a)
50+
n = size(u, 1)
51+
52+
# initialize l to be the identity matrix
53+
l = Matrix{Float64}(I, n, n)
54+
55+
for k in 1:n-1
56+
assert(p, abs(u[k, k]) > tol, "Gaussian elimination failed")
57+
l[k+1:n, k] = u[k+1:n, k] / u[k, k]
58+
u[k+1:n, k+1:n] -= l[k+1:n, k] * u[k, k+1:n]'
59+
end
60+
61+
return LUFactorizationSolution(tril(l), triu(u))
62+
end
63+
64+
LUFactVersion1 = Version1
65+
LUFactVersion2 = Version2
66+
export LUFactVersion1, LUFactVersion2
67+
68+
LUFactorizationProblem = Version2
69+
export LUFactorizationProblem

src/lufact.jl

Lines changed: 0 additions & 86 deletions
This file was deleted.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
struct PartialPivotingLUFactorizationSolution <: SolutionMixin
2+
l::AbstractMatrix
3+
u::AbstractMatrix
4+
p::Vector{Int}
5+
function PartialPivotingLUFactorizationSolution(l, u, p)
6+
s = new(l, u, p)
7+
assert(s, istril(l), "matrix must be lower triangular")
8+
assert(s, istriu(u), "matrix must be upper triangular")
9+
assert(s, isperm(p), "matrix must be a permutation matrix")
10+
return s
11+
end
12+
end
13+
14+
struct PartialPivotingLUFactorizationProblem <: ProblemMixin
15+
a::AbstractMatrix
16+
tol::Real
17+
end
18+
19+
function kernel(p::PartialPivotingLUFactorizationProblem)
20+
tol = p.tol
21+
u = copy(p.a)
22+
n = size(u, 1)
23+
24+
l = zeros(n, n)
25+
p = collect(1:n)
26+
27+
for k in 1:n-1
28+
v, x = findmax(abs.(u[k:n, k]))
29+
x += k - 1
30+
31+
if x != k
32+
# swap row k and row x of matrix u
33+
for i in 1:n
34+
u[k, i], u[x, i] = u[x, i], u[k, i]
35+
end
36+
37+
# swap row k and row x of matrix l
38+
for i in 1:n
39+
l[k, i], l[x, i] = l[x, i], l[k, i]
40+
end
41+
42+
# swap row k and row x of matrix p
43+
p[k], p[x] = p[x], p[k]
44+
end
45+
46+
if abs(u[k, k]) < tol
47+
continue
48+
end
49+
50+
l[k, k] = 1.0
51+
l[k+1:n, k] = u[k+1:n, k] / u[k, k]
52+
u[k+1:n, k+1:n] -= l[k+1:n, k] * u[k, k+1:n]'
53+
end
54+
55+
l[n, n] = 1.0
56+
return PartialPivotingLUFactorizationSolution(tril(l), triu(u), p)
57+
end
58+
59+
export PartialPivotingLUFactorizationProblem

test/back-substitution.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
@testset "backward substitution" begin
2+
n = 10
3+
tol = 1e-10
4+
u = LinearAlgebra.triu(rand(n, n))
5+
b = rand(n)
6+
7+
p = BackSubstitution(u, b, tol)
8+
s = kernel(p)
9+
x = s.x
10+
@test maximum(abs, u * x - b) < tol
11+
end
12+
13+
# test for me error message
14+
@testset "not upper triangular" begin
15+
n = 10
16+
tol = 1e-10
17+
u = rand(n, n)
18+
b = rand(n)
19+
@test_throws LinearAlgebraError BackSubstitution(u, b, tol)
20+
end
21+
22+
@testset "singular matrix" begin
23+
n = 10
24+
tol = 1e-10
25+
u = zeros(n, n)
26+
b = rand(n)
27+
prob = BackSubstitution(u, b, tol)
28+
@test_throws LinearAlgebraError kernel(prob)
29+
end

test/backward_substitution.jl

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)