Skip to content

Commit b10998a

Browse files
authored
Merge pull request #172 from JuliaSymbolics/ale/localbuffers
Allow local buffers
2 parents 7828344 + ce871dd commit b10998a

File tree

6 files changed

+62
-63
lines changed

6 files changed

+62
-63
lines changed

src/EGraphs/EGraphs.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ include("../docstrings.jl")
55
using DataStructures
66
using TermInterface
77
using TimerOutputs
8-
using Metatheory:
9-
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
8+
using Metatheory: alwaystrue, cleanast, binarize
109
using Metatheory.Patterns
1110
using Metatheory.Rules
1211
using Metatheory.EMatchCompiler

src/EGraphs/egraph.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
abstract type AbstractENode end
66

7+
import Metatheory: maybelock!
8+
79
const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple}
810
const EClassId = Int64
911
const TermTypes = Dict{Tuple{Any,Int},Type}
12+
# TODO document bindings
13+
const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
14+
const DEFAULT_BUFFER_SIZE = 1048576
1015

1116
struct ENodeLiteral <: AbstractENode
1217
value
@@ -190,14 +195,21 @@ mutable struct EGraph
190195
termtypes::TermTypes
191196
numclasses::Int
192197
numnodes::Int
198+
"If we use global buffers we may need to lock. Defaults to true."
199+
needslock::Bool
200+
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
201+
buffer::Vector{Bindings}
202+
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
203+
merges_buffer::Vector{Tuple{Int,Int}}
204+
lock::ReentrantLock
193205
end
194206

195207

196208
"""
197209
EGraph(expr)
198210
Construct an EGraph from a starting symbolic expression `expr`.
199211
"""
200-
function EGraph()
212+
function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
201213
EGraph(
202214
IntDisjointSet(),
203215
Dict{EClassId,EClass}(),
@@ -210,12 +222,19 @@ function EGraph()
210222
TermTypes(),
211223
0,
212224
0,
213-
# 0
225+
needslock,
226+
Bindings[],
227+
Tuple{Int,Int}[],
228+
ReentrantLock(),
214229
)
215230
end
216231

217-
function EGraph(e; keepmeta = false)
218-
g = EGraph()
232+
function maybelock!(f::Function, g::EGraph)
233+
g.needslock ? lock(f, g.buffer_lock) : f()
234+
end
235+
236+
function EGraph(e; keepmeta = false, kwargs...)
237+
g = EGraph(kwargs...)
219238
keepmeta && addanalysis!(g, :metadata_analysis)
220239
g.root = addexpr!(g, e; keepmeta = keepmeta)
221240
g

src/EGraphs/saturation.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ function eqsat_search!(
119119
)::Int
120120
n_matches = 0
121121

122-
lock(BUFFER_LOCK) do
123-
empty!(BUFFER[])
122+
maybelock!(g) do
123+
empty!(g.buffer)
124124
end
125125

126126
@debug "SEARCHING"
@@ -166,13 +166,13 @@ function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId
166166
end
167167

168168
function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
169-
push!(MERGES_BUF[], (id, instantiate_enode!(buf, g, rule.right)))
169+
push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
170170
nothing
171171
end
172172

173173
function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
174174
pat_to_inst = direction == 1 ? rule.right : rule.left
175-
push!(MERGES_BUF[], (id, instantiate_enode!(bindings, g, pat_to_inst)))
175+
push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
176176
nothing
177177
end
178178

@@ -207,46 +207,44 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
207207
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
208208
isnothing(r) && return nothing
209209
rcid = addexpr!(g, r)
210-
push!(MERGES_BUF[], (id, rcid))
210+
push!(g.merges_buffer, (id, rcid))
211211
return nothing
212212
end
213213

214214

215215

216216
function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams)
217217
i = 0
218-
@assert isempty(MERGES_BUF[])
218+
@assert isempty(g.merges_buffer)
219219

220-
@debug "APPLYING $(length(BUFFER[])) matches"
220+
@debug "APPLYING $(length(g.buffer)) matches"
221+
maybelock!(g) do
222+
while !isempty(g.buffer)
221223

222-
lock(BUFFER_LOCK) do
223-
while !isempty(BUFFER[])
224224
if reached(g, params.goal)
225225
@debug "Goal reached"
226226
rep.reason = :goalreached
227227
return
228228
end
229229

230-
bindings = popfirst!(BUFFER[])
230+
bindings = pop!(g.buffer)
231231
rule_idx, id = bindings[0]
232232
direction = sign(rule_idx)
233233
rule_idx = abs(rule_idx)
234234
rule = theory[rule_idx]
235235

236236

237-
halt_reason = lock(MERGES_BUF_LOCK) do
238-
apply_rule!(bindings, g, rule, id, direction)
239-
end
237+
halt_reason = apply_rule!(bindings, g, rule, id, direction)
240238

241239
if !isnothing(halt_reason)
242240
rep.reason = halt_reason
243241
return
244242
end
245243
end
246244
end
247-
lock(MERGES_BUF_LOCK) do
248-
while !isempty(MERGES_BUF[])
249-
(l, r) = popfirst!(MERGES_BUF[])
245+
maybelock!(g) do
246+
while !isempty(g.merges_buffer)
247+
(l, r) = pop!(g.merges_buffer)
250248
merge!(g, l, r)
251249
end
252250
end

src/Metatheory.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,14 @@ module Metatheory
22

33
using DataStructures
44

5-
import Base.ImmutableDict
6-
7-
const Bindings = ImmutableDict{Int,Tuple{Int,Int}}
8-
const DEFAULT_BUFFER_SIZE = 1048576
9-
const BUFFER = Ref(CircularDeque{Bindings}(DEFAULT_BUFFER_SIZE))
10-
const BUFFER_LOCK = ReentrantLock()
11-
const MERGES_BUF = Ref(CircularDeque{Tuple{Int,Int}}(DEFAULT_BUFFER_SIZE))
12-
const MERGES_BUF_LOCK = ReentrantLock()
13-
14-
function resetbuffers!(bufsize)
15-
BUFFER[] = CircularDeque{Bindings}(bufsize)
16-
MERGES_BUF[] = CircularDeque{Tuple{Int,Int}}(bufsize)
17-
end
18-
19-
function __init__()
20-
resetbuffers!(DEFAULT_BUFFER_SIZE)
21-
end
22-
235
using Base.Meta
246
using Reexport
257
using TermInterface
268

279
@inline alwaystrue(x) = true
2810

2911
function lookup_pat end
12+
function maybelock! end
3013

3114
include("docstrings.jl")
3215
include("utils.jl")

src/ematch_compiler.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module EMatchCompiler
22

33
using TermInterface
44
using ..Patterns
5-
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL
5+
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock!
66

77
function ematcher(p::Any)
88
function literal_ematcher(next, g, data, bindings)
@@ -48,7 +48,7 @@ function predicate_ematcher(p::PatVar, pred)
4848
end
4949
end
5050
end
51-
51+
5252
function ematcher(p::PatVar)
5353
pred_matcher = predicate_ematcher(p, p.predicate)
5454

@@ -115,14 +115,14 @@ function ematcher(p::PatTerm)
115115

116116
for n in g[car(data)]
117117
if canbindtop(n)
118-
loop(LL(arguments(n),1), bindings, ematchers)
118+
loop(LL(arguments(n), 1), bindings, ematchers)
119119
end
120120
end
121121
end
122-
end
122+
end
123123

124124

125-
const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int, Int}}()
125+
const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()
126126

127127
"""
128128
Substitutions are efficiently represented in memory as vector of tuples of two integers.
@@ -137,30 +137,30 @@ The format is as follows
137137
* The end of a substitution is delimited by (0,0)
138138
"""
139139
function ematcher_yield(p, npvars::Int, direction::Int)
140-
em = ematcher(p)
141-
function ematcher_yield(g, rule_idx, id)::Int
142-
n_matches = 0
143-
em(g, (id,), EMPTY_ECLASS_DICT) do b,n
144-
lock(BUFFER_LOCK) do
145-
push!(BUFFER[], assoc(b, 0, (rule_idx * direction, id)))
146-
n_matches+=1
147-
end
148-
end
149-
n_matches
140+
em = ematcher(p)
141+
function ematcher_yield(g, rule_idx, id)::Int
142+
n_matches = 0
143+
em(g, (id,), EMPTY_ECLASS_DICT) do b, n
144+
maybelock!(g) do
145+
push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
146+
n_matches += 1
147+
end
150148
end
149+
n_matches
150+
end
151151
end
152152

153-
ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1)
153+
ematcher_yield(p, npvars) = ematcher_yield(p, npvars, 1)
154154

155155
function ematcher_yield_bidir(l, r, npvars::Int)
156-
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
157-
function ematcher_yield_bidir(g, rule_idx, id)::Int
158-
eml(g,rule_idx,id) + emr(g,rule_idx,id)
159-
end
156+
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
157+
function ematcher_yield_bidir(g, rule_idx, id)::Int
158+
eml(g, rule_idx, id) + emr(g, rule_idx, id)
159+
end
160160
end
161161

162162
ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p")
163163

164164
export ematcher_yield, ematcher_yield_bidir
165165

166-
end
166+
end

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ end
151151
macro matchable(expr)
152152
@assert expr.head == :struct
153153
name = expr.args[2]
154-
if name isa Expr
155-
name.head === :(<:) && (name = name.args[1])
154+
if name isa Expr
155+
name.head === :(<:) && (name = name.args[1])
156156
name isa Expr && name.head === :curly && (name = name.args[1])
157157
end
158158
fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args)

0 commit comments

Comments
 (0)