Skip to content

Commit 7b74dde

Browse files
royessKrastanov
andauthored
add BP-OSD (#22)
Co-authored-by: Stefan Krastanov <github.acc@krastanov.org>
1 parent d58ed11 commit 7b74dde

File tree

7 files changed

+182
-5
lines changed

7 files changed

+182
-5
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# News
2+
3+
## v0.3.2 - 2024-11-15
4+
5+
- Add a (still unoptimized) implementation of a BP OSD decoder.
6+
7+
## Older - before 2021-10-28 unrecorded

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LDPCDecoders"
22
uuid = "3c486d74-64b9-4c60-8b1a-13a564e77efb"
33
authors = ["Krishna Praneet Gudipaty", "Stefan Krastanov", "QuantumSavory contributors"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"

src/LDPCDecoders.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using RowEchelon
1212
export
1313
decode!, batchdecode!,
1414
BeliefPropagationDecoder,
15+
BeliefPropagationOSDDecoder,
1516
BitFlipDecoder
1617

1718
include("generator.jl")
@@ -22,7 +23,9 @@ include("parity_generator.jl")
2223

2324
include("decoders/abstract_decoder.jl")
2425
include("decoders/belief_propagation.jl")
26+
include("decoders/belief_propagation_osd.jl")
2527
include("decoders/iterative_bitflip.jl")
28+
2629
include("syndrome_bp_decoder.jl")
2730
include("syndrome_simulator.jl")
2831
include("syndrome_it_decoder.jl")

src/decoders/belief_propagation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct BeliefPropagationDecoder <: AbstractDecoder
4444
scratch::BeliefPropagationScratchSpace
4545
end
4646

47-
function BeliefPropagationDecoder(H, per::Float64, max_iters::Int)
47+
function BeliefPropagationDecoder(H::Union{SparseArrays.SparseMatrixCSC{Bool,Int}, BitMatrix}, per::Float64, max_iters::Int)
4848
s, n = size(H)
4949
sparse_H = sparse(H)
5050
sparse_HT = sparse(H')
@@ -108,7 +108,7 @@ true
108108
function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) # TODO check if casting to bitarrays helps with performance -- if it does, set up warnings to the user for cases where they have not done the casting
109109
reset!(decoder)
110110
rows::Vector{Int} = rowvals(decoder.sparse_H);
111-
rowsT::Vector{Int} = rowvals(decoder.sparse_HT);
111+
rowsT::Vector{Int} = rowvals(decoder.sparse_HT);
112112
setup = decoder.scratch
113113

114114
for j in 1:decoder.n
@@ -138,7 +138,7 @@ function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) #
138138

139139
for j in 1:decoder.n
140140
temp::Float64 = setup.channel_probs[j] / (1 - setup.channel_probs[j])
141-
141+
142142
for k in nzrange(decoder.sparse_H, j)
143143
setup.bit_2_check[rows[k],j] = temp
144144
temp *= setup.check_2_bit[rows[k],j]
@@ -166,7 +166,7 @@ function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) #
166166

167167
syndrome_decoded = (decoder.sparse_H * setup.err) .% 2
168168
if all(syndrome_decoded .== syndrome)
169-
converged = true
169+
converged = true
170170
break # Break if converged
171171
end
172172
end
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
struct BeliefPropagationOSDDecoder <: AbstractDecoder
2+
"""A belief propagation decoder as a subroutine"""
3+
bp_decoder::BeliefPropagationDecoder
4+
"""Dense form of the parity check matrix"""
5+
H::BitMatrix
6+
"""The order of OSD; defaulted to be 0 in the constructor"""
7+
osd_order::Int
8+
end
9+
10+
function BeliefPropagationOSDDecoder(H::BitMatrix, per::Float64, max_iters::Int; osd_order::Int=0)
11+
bp_decoder = BeliefPropagationDecoder(H, per, max_iters)
12+
return BeliefPropagationOSDDecoder(bp_decoder, H, osd_order)
13+
end
14+
15+
function rowswap!(H::BitMatrix, i, j)
16+
@inbounds H[i, :], H[j, :] = H[j, :], H[i, :] # TODO This could be further optimized?
17+
end
18+
19+
function decode!(decoder::BeliefPropagationOSDDecoder, syndrome::AbstractVector)
20+
# use BP to get hard and soft decisions
21+
bp_err, converged = decode!(decoder.bp_decoder, syndrome) # hard decisions
22+
bp_log_probabs = decoder.bp_decoder.scratch.log_probabs # soft decisions
23+
bp_probabs = exp.(bp_log_probabs)
24+
# sort columns by reliability, less reliable columns first
25+
sort_by_reliability = sortperm(max.(bp_probabs, 1 .- bp_probabs), rev=true)
26+
H_sorted = decoder.H[:, sort_by_reliability]
27+
bp_err_sorted = bp_err[sort_by_reliability]
28+
# TODO an optimized version of OSD can be implemented when osd_order = 0, see Algorithm 2 in https://doi.org/10.22331/q-2021-11-22-585
29+
err = osd(H_sorted, syndrome, bp_err_sorted, decoder.osd_order)
30+
return err[invperm(sort_by_reliability)], converged # also return whether BP is converged
31+
end
32+
33+
function osd(H, syndrome, bp_err, osd_order)
34+
m, n = size(H)
35+
# diagnolize the submatrix corresponding to independent columns via Gaussian elimination
36+
# first obtain the row canonical form
37+
# and find least reliable indices, i.e., the first r pivot columns (assume H is rearranged by reliability)
38+
least_reliable_rows = [] # row indices of pivot elements
39+
least_reliable_cols = [] # column indices of pivot elements
40+
r = 0 # compute rank of H
41+
i, j = 1, 1
42+
s = copy(syndrome) # transform syndrome along with H in Gaussian elimination
43+
44+
while i <= m && j <= n
45+
k = findfirst(H[i:end, j])
46+
if isnothing(k) # not an independent column
47+
j += 1
48+
else
49+
if k > 1
50+
ii = i + k - 1 # the first row after `i` with 1 in column `j`
51+
rowswap!(H, i, ii) # TODO For optimization: Is this swap necessary? We may just track the row index
52+
s[i], s[ii] = s[ii], s[i]
53+
end
54+
for ii in i+1:m
55+
if H[ii, j]
56+
H[ii, :] .⊻= H[i, :]
57+
s[ii] ⊻= s[i]
58+
end
59+
end
60+
push!(least_reliable_rows, i)
61+
push!(least_reliable_cols, j)
62+
i += 1
63+
j += 1
64+
r += 1
65+
end
66+
end
67+
68+
# then obtain a diagonal submatrix on the least reliable part
69+
for (i, j) in zip(reverse(least_reliable_rows), reverse(least_reliable_cols))
70+
for ii in 1:i-1
71+
if H[ii, j]
72+
H[ii, :] .⊻= H[i, :]
73+
s[ii] ⊻= s[i]
74+
end
75+
end
76+
end
77+
78+
if osd_order > n - r
79+
@warn "The order of OSD $osd_order is greater than the size of the information set $(n-r). We set osd_order = $(n-r)."
80+
osd_order = n - r
81+
end
82+
83+
best_err = copy(bp_err)
84+
err = Bool.(copy(bp_err)) # TODO why error is in Float in BP?
85+
most_reliable_cols = setdiff(1:n, least_reliable_cols)
86+
min_weight = n + 1
87+
88+
for x in 0:2^osd_order-1
89+
# first compute the `most_reliable_cols` part of errors
90+
# try all possible errors on the first `osd_order` bits within `most_reliable_cols`
91+
if x != 0
92+
trial_err = BitArray([x >> i & 1 for i in 0:osd_order-1])
93+
for j in 1:osd_order
94+
err[most_reliable_cols[j]] = trial_err[j]
95+
end
96+
end
97+
# then based on the `most_reliable_cols` part of errors, compute the `least_reliable_cols` part of errors
98+
for (i, j) in zip(least_reliable_rows, least_reliable_cols)
99+
err[j] = s[i]
100+
for k in most_reliable_cols
101+
err[j] ⊻= H[i, k] * err[k]
102+
end
103+
end
104+
weight = sum(err) # This weight is set for depolarizing noise
105+
# TODO More generally, it should be a function depending on the noise model
106+
if weight < min_weight
107+
min_weight = weight
108+
best_err = copy(err)
109+
end
110+
end
111+
112+
return best_err
113+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ println("ENV[\"PYTHON\"] = \"$(get(ENV,"PYTHON",nothing))\"")
3030

3131
@doset "oldtests"
3232
@doset "bp_decoder"
33+
@doset "bposd_decoder"
3334
@doset "bf_decoder"
3435

3536
VERSION >= v"1.10" && @doset "doctests"

test/test_bposd_decoder.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using Test
2+
using LDPCDecoders
3+
4+
@testset "test_bposd_decoder.jl" begin
5+
6+
"""Test for BP-OSD decoder"""
7+
function test_bposd_decoder()
8+
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
9+
per = 0.01
10+
err = rand(1000) .< per
11+
syn = (H * err) .% 2
12+
13+
bposd = BeliefPropagationOSDDecoder(H, per, 100)
14+
guess, success = decode!(bposd, syn)
15+
16+
return guess == err
17+
end
18+
19+
"""Test high order OSD"""
20+
function test_bposd_decoder_high_order()
21+
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
22+
per = 0.01
23+
err = rand(1000) .< per
24+
syn = (H * err) .% 2
25+
26+
orders = 2:5
27+
succ = true
28+
for osd_order in orders
29+
bposd = BeliefPropagationOSDDecoder(H, per, 100; osd_order=osd_order)
30+
guess, success = decode!(bposd, syn)
31+
succ = succ & (guess == err)
32+
end
33+
34+
return succ
35+
end
36+
37+
"""Test for BP-OSD decoder with large error rate. Even if the decoding is not accurate, OSD will still ensure consistency between guess and syndromes."""
38+
function test_bposd_decoder_large_error_rate()
39+
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
40+
per = 0.2
41+
err = rand(1000) .< per
42+
syn = (H * err) .% 2
43+
44+
bposd = BeliefPropagationOSDDecoder(H, per, 100)
45+
guess, success = decode!(bposd, syn)
46+
47+
return syn == (H * guess) .% 2
48+
end
49+
50+
@test test_bposd_decoder()
51+
@test test_bposd_decoder_high_order()
52+
@test test_bposd_decoder_large_error_rate()
53+
end

0 commit comments

Comments
 (0)