Skip to content

Create bpots_decoders.jl #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 304 additions & 0 deletions src/decoders/bpots_decoders.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
using SparseArrays

# State for storing BP-OTS computations
mutable struct BPOTSState
messages_vc::Dict{Tuple{Int,Int}, Float64}
messages_cv::Dict{Tuple{Int,Int}, Float64}
oscillations::Vector{Int}
biased_nodes::Set{Int}
prior_decisions::Vector{Int}
prior_llrs::Vector{Float64}
end

# Main decoder struct
struct BPOTSDecoder <: AbstractDecoder
per::Float64 # Physical error rate
max_iters::Int # Maximum iterations
s::Int # Number of stabilizers
n::Int # Number of qubits
T::Int # Biasing period
C::Float64 # Bias constant
sparse_H::SparseMatrixCSC{Bool,Int} # Sparse parity check matrix
sparse_HT::SparseMatrixCSC{Bool,Int} # Transposed matrix
scratch::BPOTSState # Working space for computations
end

# BP-OTS State initialization
# Initialize state
function initialize_bpots_state(H::SparseMatrixCSC, n::Int)
messages_vc = Dict{Tuple{Int,Int}, Float64}()
messages_cv = Dict{Tuple{Int,Int}, Float64}()

Check warning on line 30 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L28-L30

Added lines #L28 - L30 were not covered by tests

# Initialize messages with small random values
rows = rowvals(H)
vals = nonzeros(H)

Check warning on line 34 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L33-L34

Added lines #L33 - L34 were not covered by tests

for j in 1:size(H,2)
for idx in nzrange(H, j)
i = rows[idx]
if vals[idx]
messages_vc[(j,i)] = 0.0 # Start neutral
messages_cv[(i,j)] = 0.0

Check warning on line 41 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L36-L41

Added lines #L36 - L41 were not covered by tests
end
end
end

Check warning on line 44 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L43-L44

Added lines #L43 - L44 were not covered by tests

return BPOTSState(

Check warning on line 46 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L46

Added line #L46 was not covered by tests
messages_vc,
messages_cv,
zeros(Int, n), # oscillations
Set{Int}(), # biased_nodes
zeros(Int, n), # prior_decisions
zeros(Float64, n) # prior_llrs
)
end

# Helper to get check node neighbors
function get_check_neighbors(H::SparseMatrixCSC, check_idx::Int)
neighbors = Int[]

Check warning on line 58 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L57-L58

Added lines #L57 - L58 were not covered by tests
# Look through the check node's row
for col in 1:size(H,2)
if H[check_idx, col]
push!(neighbors, col)

Check warning on line 62 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L60-L62

Added lines #L60 - L62 were not covered by tests
end
end
return neighbors

Check warning on line 65 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
end

# Helper to get variable node neighbors
function get_variable_neighbors(H::SparseMatrixCSC, var_idx::Int)
neighbors = Int[]

Check warning on line 70 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
# Look through column's non-zero entries
for idx in nzrange(H, var_idx)
row = rowvals(H)[idx]
if nonzeros(H)[idx]
push!(neighbors, row)

Check warning on line 75 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L72-L75

Added lines #L72 - L75 were not covered by tests
end
end
return neighbors

Check warning on line 78 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end

# Constructor for the decoder
function BPOTSDecoder(H::Union{SparseMatrixCSC{Bool,Int}, BitMatrix}, per::Float64, max_iters::Int; T::Int=9, C::Float64=2.0)
s, n = size(H)
sparse_H = sparse(H)
sparse_HT = sparse(H')
scratch = initialize_bpots_state(sparse_H, n)

Check warning on line 86 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L82-L86

Added lines #L82 - L86 were not covered by tests

return BPOTSDecoder(per, max_iters, s, n, T, C, sparse_H, sparse_HT, scratch)

Check warning on line 88 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L88

Added line #L88 was not covered by tests
end

# Initialize beliefs with weaker bias
function initialize_beliefs(n::Int, per::Float64)

Check warning on line 92 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L92

Added line #L92 was not covered by tests
# Convert error probability to LLR with weak bias toward no errors
Π = fill(0.5 * log((1-per)/per), n) # Weak initial bias
return Π

Check warning on line 95 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L94-L95

Added lines #L94 - L95 were not covered by tests
end

# Compute beliefs and make decisions
function compute_beliefs!(decoder::BPOTSDecoder, state::BPOTSState, Ω::Vector{Float64})
n = decoder.n
decisions = zeros(Int, n)
llrs = zeros(Float64, n)

Check warning on line 102 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L99-L102

Added lines #L99 - L102 were not covered by tests

for j in 1:n

Check warning on line 104 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L104

Added line #L104 was not covered by tests
# Sum all incoming messages for this variable node
llr = Ω[j] # Start with prior
for i in get_variable_neighbors(decoder.sparse_H, j)
if haskey(state.messages_cv, (i,j))
llr += state.messages_cv[(i,j)]

Check warning on line 109 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L106-L109

Added lines #L106 - L109 were not covered by tests
end
end
llrs[j] = llr

Check warning on line 112 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L111-L112

Added lines #L111 - L112 were not covered by tests

# Decision threshold at 0 - flip bit if LLR is negative
decisions[j] = llr < 0.0 ? 1 : 0
end

Check warning on line 116 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L115-L116

Added lines #L115 - L116 were not covered by tests

return decisions, llrs

Check warning on line 118 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L118

Added line #L118 was not covered by tests
end

# Reset state between decodings with proper initialization
function reset!(decoder::BPOTSDecoder)
state = decoder.scratch
empty!(state.biased_nodes)
fill!(state.oscillations, 0)
fill!(state.prior_decisions, 0)
fill!(state.prior_llrs, 0)

Check warning on line 127 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L122-L127

Added lines #L122 - L127 were not covered by tests

# Reset messages to neutral values (paper starts with 0)
for k in keys(state.messages_vc)
state.messages_vc[k] = 0.0
end
for k in keys(state.messages_cv)
state.messages_cv[k] = 0.0
end

Check warning on line 135 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L130-L135

Added lines #L130 - L135 were not covered by tests

return decoder

Check warning on line 137 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L137

Added line #L137 was not covered by tests
end

# Update variable-to-check message according to Equation (1) in the paper
function update_variable_to_check!(state::BPOTSState, j::Int, i::Int, H::SparseMatrixCSC, Ω::Vector{Float64})
connected_checks = get_variable_neighbors(H, j)

Check warning on line 142 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L141-L142

Added lines #L141 - L142 were not covered by tests

# Sum all incoming messages except from target check
msg_sum = 0.0
for check in connected_checks
if check != i
msg = get(state.messages_cv, (check,j), 0.0)
msg_sum += msg

Check warning on line 149 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L145-L149

Added lines #L145 - L149 were not covered by tests
end
end

Check warning on line 151 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L151

Added line #L151 was not covered by tests

# Add prior and store - Equation (1) from the paper
msg = Ω[j] + msg_sum
state.messages_vc[(j,i)] = msg

Check warning on line 155 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end

# Update check-to-variable message according to Equation (2) in the paper
function update_check_to_variable!(state::BPOTSState, i::Int, j::Int, H::SparseMatrixCSC, syndrome::Vector{Bool})
connected_vars = get_check_neighbors(H, i)

Check warning on line 160 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L159-L160

Added lines #L159 - L160 were not covered by tests

# Compute product of tanh values from other variables
prod_tanh = 1.0
MAX_TANH = 0.99999 # For numerical stability

Check warning on line 164 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L163-L164

Added lines #L163 - L164 were not covered by tests

for var in connected_vars
if var != j
msg = get(state.messages_vc, (var,i), 0.0)
t = tanh(0.5 * msg)
t = min(MAX_TANH, max(-MAX_TANH, t))
prod_tanh *= t

Check warning on line 171 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L166-L171

Added lines #L166 - L171 were not covered by tests
end
end

Check warning on line 173 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L173

Added line #L173 was not covered by tests

# Apply syndrome as in Equation (2)
if syndrome[i]
prod_tanh = -prod_tanh

Check warning on line 177 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L176-L177

Added lines #L176 - L177 were not covered by tests
end

# Compute message using atanh function
msg = 2.0 * atanh(prod_tanh)
state.messages_cv[(i,j)] = msg

Check warning on line 182 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
end

# decode! function matching the paper's algorithm
function decode!(decoder::BPOTSDecoder, syndrome::Vector{Bool})
state = decoder.scratch
reset!(decoder)

Check warning on line 188 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L186-L188

Added lines #L186 - L188 were not covered by tests

println("\nDEBUG: Starting decode")
println("DEBUG: Syndrome: ", syndrome)

Check warning on line 191 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L190-L191

Added lines #L190 - L191 were not covered by tests

# Initialize priors
# Πj = log((1-(2ϵ/3))/(2ϵ/3))- from the paper
Π = fill(log((1-(2*decoder.per/3))/(2*decoder.per/3)), decoder.n)
Ω = copy(Π)

Check warning on line 196 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L195-L196

Added lines #L195 - L196 were not covered by tests

# Track best solution
best_decisions = zeros(Int, decoder.n)
best_mismatch = length(syndrome)
best_weight = decoder.n

Check warning on line 201 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L199-L201

Added lines #L199 - L201 were not covered by tests

for iter in 1:decoder.max_iters
println("\nDEBUG: === Iteration $iter ===")

Check warning on line 204 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L203-L204

Added lines #L203 - L204 were not covered by tests

# Message passing
println("\nDEBUG: Variable updates")
for j in 1:decoder.n
for i in get_variable_neighbors(decoder.sparse_H, j)
update_variable_to_check!(state, j, i, decoder.sparse_H, Ω)
end
end

Check warning on line 212 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L207-L212

Added lines #L207 - L212 were not covered by tests

println("\nDEBUG: Check updates")
for i in 1:decoder.s
for j in 1:decoder.n
if decoder.sparse_H[i,j]
update_check_to_variable!(state, i, j, decoder.sparse_H, syndrome)

Check warning on line 218 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L214-L218

Added lines #L214 - L218 were not covered by tests
end
end
end

Check warning on line 221 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L220-L221

Added lines #L220 - L221 were not covered by tests

# Compute marginals (beliefs)
decisions = zeros(Int, decoder.n)
llrs = zeros(Float64, decoder.n)

Check warning on line 225 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L224-L225

Added lines #L224 - L225 were not covered by tests

println("\nDEBUG: Computing beliefs")
for j in 1:decoder.n
llr = Ω[j]
println("DEBUG: Node $j")
println("DEBUG: Prior: $llr")

Check warning on line 231 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L227-L231

Added lines #L227 - L231 were not covered by tests

for i in get_variable_neighbors(decoder.sparse_H, j)
msg = state.messages_cv[(i,j)]
llr += msg
println("DEBUG: Added check $i msg: $msg")
end

Check warning on line 237 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L233-L237

Added lines #L233 - L237 were not covered by tests

llrs[j] = llr
decisions[j] = llr < 0.0 ? 1 : 0
println("DEBUG: Final LLR: $llr -> Decision: $(decisions[j])")
end

Check warning on line 242 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L239-L242

Added lines #L239 - L242 were not covered by tests

# Update oscillation vector using XOR between consecutive decisions
# This tracks how often bits flip
if iter > 1
for j in 1:decoder.n
state.oscillations[j] += (decisions[j] ⊻ state.prior_decisions[j])
end

Check warning on line 249 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L246-L249

Added lines #L246 - L249 were not covered by tests
end
state.prior_decisions = copy(decisions)
state.prior_llrs = copy(llrs)

Check warning on line 252 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L251-L252

Added lines #L251 - L252 were not covered by tests

# Check solution
check_result = Bool.(mod.(decoder.sparse_H * decisions, 2))
mismatch = count(check_result .!= syndrome)
weight = sum(decisions)
println("\nDEBUG: Mismatch: $mismatch, Weight: $weight")

Check warning on line 258 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L255-L258

Added lines #L255 - L258 were not covered by tests

if mismatch < best_mismatch || (mismatch == best_mismatch && weight < best_weight)
println("DEBUG: New best solution!")
best_mismatch = mismatch
best_weight = weight
best_decisions = copy(decisions)

Check warning on line 264 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L260-L264

Added lines #L260 - L264 were not covered by tests

if mismatch == 0
println("DEBUG: Found valid solution!")
return best_decisions, true

Check warning on line 268 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L266-L268

Added lines #L266 - L268 were not covered by tests
end
end

# Apply biasing if needed - following Algorithm 1 in the paper
if mismatch > 0 && iter % decoder.T == 0

Check warning on line 273 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L273

Added line #L273 was not covered by tests
# Reset priors to original values
Ω = copy(Π)

Check warning on line 275 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L275

Added line #L275 was not covered by tests

# Check if there are any oscillating nodes
if maximum(state.oscillations) > 0

Check warning on line 278 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L278

Added line #L278 was not covered by tests
# Step 1: Find nodes with maximum oscillations (F set in the paper)
max_osc = maximum(state.oscillations)
max_osc_indices = findall(o -> o == max_osc, state.oscillations)

Check warning on line 281 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L280-L281

Added lines #L280 - L281 were not covered by tests

# Step 2: From F, select j1 with minimum |LLR|
j1 = max_osc_indices[argmin([abs(llrs[j]) for j in max_osc_indices])]

Check warning on line 284 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L284

Added line #L284 was not covered by tests

# Reset oscillation counter for j1
state.oscillations[j1] = 0

Check warning on line 287 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L287

Added line #L287 was not covered by tests

# Bias j1
Ω[j1] = -decoder.C
println("DEBUG: Biasing j1 (node $j1): $(Ω[j1])")

Check warning on line 291 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L290-L291

Added lines #L290 - L291 were not covered by tests

# Step 3: Select j2 with minimum |LLR| overall
j2 = argmin(abs.(llrs))

Check warning on line 294 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L294

Added line #L294 was not covered by tests

# Bias j2 (which may coincide with j1)
Ω[j2] = -decoder.C
println("DEBUG: Biasing j2 (node $j2): $(Ω[j2])")

Check warning on line 298 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L297-L298

Added lines #L297 - L298 were not covered by tests
end
end
end

Check warning on line 301 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L301

Added line #L301 was not covered by tests

return best_decisions, false

Check warning on line 303 in src/decoders/bpots_decoders.jl

View check run for this annotation

Codecov / codecov/patch

src/decoders/bpots_decoders.jl#L303

Added line #L303 was not covered by tests
end
Loading