Skip to content

Commit 7828344

Browse files
authored
Merge pull request #169 from JuliaSymbolics/ale/debugging
Add debugging and GraphViz visualization utilities
2 parents 39349f9 + ef978aa commit 7828344

File tree

12 files changed

+413
-58
lines changed

12 files changed

+413
-58
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ makedocs(
2929
"index.md"
3030
"rewrite.md"
3131
"egraphs.md"
32+
"visualizing.md"
3233
"api.md"
3334
"Tutorials" => tutorials
3435
],

docs/src/assets/graphviz.svg

Lines changed: 240 additions & 0 deletions
Loading

docs/src/visualizing.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Visualizing E-Graphs
2+
3+
You can visualize e-graphs in VSCode by using [GraphViz.jl]()
4+
5+
All you need to do is to install GraphViz.jl and to evaluate an e-graph after including the extra script:
6+
7+
```julia
8+
using GraphViz
9+
10+
include(dirname(pathof(Metatheory)) * "/extras/graphviz.jl")
11+
12+
algebra_rules = @theory a b c begin
13+
a * (b * c) == (a * b) * c
14+
a + (b + c) == (a + b) + c
15+
16+
a + b == b + a
17+
a * (b + c) == (a * b) + (a * c)
18+
(a + b) * c == (a * c) + (b * c)
19+
20+
-a == -1 * a
21+
a - b == a + -b
22+
1 * a == a
23+
24+
0 * a --> 0
25+
a + 0 --> a
26+
27+
a::Number * b == b * a::Number
28+
a::Number * b::Number => a * b
29+
a::Number + b::Number => a + b
30+
end;
31+
32+
ex = :(a - a)
33+
g = EGraph(ex)
34+
params = SaturationParams(; timeout = 2)
35+
saturate!(g, algebra_rules, params)
36+
g
37+
```
38+
39+
And you will see a nice e-graph drawing in the Julia Plots VSCode panel:
40+
41+
![E-Graph Drawing](/assets/graphviz.svg)

src/EGraphs/EGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DataStructures
66
using TermInterface
77
using TimerOutputs
88
using Metatheory:
9-
alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
9+
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
1010
using Metatheory.Patterns
1111
using Metatheory.Rules
1212
using Metatheory.EMatchCompiler

src/EGraphs/egraph.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,6 @@ end
294294
Inserts an e-node in an [`EGraph`](@ref)
295295
"""
296296
function add!(g::EGraph, n::AbstractENode)::EClassId
297-
@debug("adding ", n)
298-
299297
n = canonicalize(g, n)
300298
haskey(g.memo, n) && return g.memo[n]
301299

@@ -378,9 +376,6 @@ function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId
378376

379377
id_a == id_b && return id_a
380378
to = union!(g.uf, id_a, id_b)
381-
382-
@debug "merging" id_a id_b
383-
384379
from = (to == id_a) ? id_b : id_a
385380

386381
push!(g.dirty, to)
@@ -432,15 +427,13 @@ function repair!(g::EGraph, id::EClassId)
432427
id = find(g, id)
433428
ecdata = g[id]
434429
ecdata.id = id
435-
@debug "repairing " id
436430

437431
new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}()
438432

439433
for (p_enode, p_eclass) in ecdata.parents
440434
p_enode = canonicalize!(g, p_enode)
441435
# deduplicate parents
442436
if haskey(new_parents, p_enode)
443-
@debug "merging classes" p_eclass (new_parents[p_enode])
444437
merge!(g, p_eclass, new_parents[p_enode])
445438
end
446439
n_id = find(g, p_eclass)
@@ -449,7 +442,6 @@ function repair!(g::EGraph, id::EClassId)
449442
end
450443

451444
ecdata.parents = collect(new_parents)
452-
@debug "updated parents " id g.parents[id]
453445

454446
# ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes)
455447

src/EGraphs/saturation.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ Base.@kwdef mutable struct SaturationParams
7171
schedulerparams::Tuple = ()
7272
threaded::Bool = false
7373
timer::Bool = true
74-
printiter::Bool = false
7574
end
7675

7776
# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
@@ -124,17 +123,20 @@ function eqsat_search!(
124123
empty!(BUFFER[])
125124
end
126125

126+
@debug "SEARCHING"
127127
for (rule_idx, rule) in enumerate(theory)
128128
@timeit report.to string(rule_idx) begin
129129
# don't apply banned rules
130130
if !cansearch(scheduler, rule)
131+
@debug "$rule is banned"
131132
continue
132133
end
133134
ids = cached_ids(g, rule.left)
134135
rule isa BidirRule && (ids = ids cached_ids(g, rule.right))
135136
for i in ids
136137
n_matches += rule.ematcher!(g, rule_idx, i)
137138
end
139+
n_matches > 0 && @debug "Rule $rule_idx: $rule produced $n_matches matches"
138140
inform!(scheduler, rule, n_matches)
139141
end
140142
end
@@ -180,7 +182,7 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClas
180182
other_id = instantiate_enode!(bindings, g, pat_to_inst)
181183

182184
if find(g, id) == find(g, other_id)
183-
@log "Contradiction!" rule
185+
@debug "$rule produced a contradiction!"
184186
return :contradiction
185187
end
186188
nothing
@@ -191,7 +193,7 @@ Instantiate argument for dynamic rule application in e-graph
191193
"""
192194
function instantiate_actual_param!(bindings::Bindings, g::EGraph, i)
193195
ecid, literal_position = bindings[i]
194-
ecid <= 0 && error("unbound pattern variable $pat in rule $rule")
196+
ecid <= 0 && error("unbound pattern variable")
195197
eclass = g[ecid]
196198
if literal_position > 0
197199
@assert eclass[literal_position] isa ENodeLiteral
@@ -215,10 +217,12 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
215217
i = 0
216218
@assert isempty(MERGES_BUF[])
217219

220+
@debug "APPLYING $(length(BUFFER[])) matches"
221+
218222
lock(BUFFER_LOCK) do
219223
while !isempty(BUFFER[])
220224
if reached(g, params.goal)
221-
@log "Goal reached"
225+
@debug "Goal reached"
222226
rep.reason = :goalreached
223227
return
224228
end
@@ -249,10 +253,6 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
249253
end
250254

251255

252-
253-
import ..@log
254-
255-
256256
"""
257257
Core algorithm of the library: the equality saturation step.
258258
"""
@@ -276,6 +276,8 @@ function eqsat_step!(
276276
end
277277
@timeit report.to "Rebuild" rebuild!(g)
278278

279+
@debug smallest_expr = extract!(g, astsize)
280+
279281
return report
280282
end
281283

@@ -297,7 +299,7 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
297299
while true
298300
curr_iter += 1
299301

300-
params.printiter && @info("iteration ", curr_iter)
302+
@debug "================ EQSAT ITERATION $curr_iter ================"
301303

302304
report = eqsat_step!(g, theory, curr_iter, sched, params, report)
303305

@@ -328,7 +330,6 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
328330
end
329331
end
330332
report.iterations = curr_iter
331-
@log report
332333

333334
return report
334335
end
@@ -339,13 +340,9 @@ function areequal(theory::Vector, exprs...; params = SaturationParams())
339340
end
340341

341342
function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams())
342-
@log "Checking equality for " exprs
343343
if length(exprs) == 1
344344
return true
345345
end
346-
# rebuild!(G)
347-
348-
@log "starting saturation"
349346

350347
n = length(exprs)
351348
ids = map(Base.Fix1(addexpr!, g), collect(exprs))

src/Metatheory.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,6 @@ using Base.Meta
2424
using Reexport
2525
using TermInterface
2626

27-
macro log(args...)
28-
quote
29-
haskey(ENV, "MT_DEBUG") && @info($(args...))
30-
end |> esc
31-
end
32-
3327
@inline alwaystrue(x) = true
3428

3529
function lookup_pat end

src/Syntax.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ Remove LineNumberNode from quoted blocks of code
1818
rmlines(e::Expr) = Expr(e.head, map(rmlines, filter(x -> !(x isa LineNumberNode), e.args))...)
1919
rmlines(a) = a
2020

21+
function_object_or_quote(op::Symbol, mod)::Expr = :(isdefined($mod, $(QuoteNode(op))) ? $op : $(QuoteNode(op)))
22+
function_object_or_quote(op, mod) = op
2123

2224
function makesegment(s::Expr, pvars)
2325
if !(exprhead(s) == :(::))
@@ -84,38 +86,41 @@ function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false)
8486
head = exprhead(ex)
8587
op = operation(ex)
8688
# Retrieve the function object if available
89+
# Optionally quote function objects
8790
args = arguments(ex)
8891
istree(op) && (op = makepattern(op, pvars, slots, mod))
8992

9093
if head === :call
9194
if operation(ex) === :(~) # is a variable or segment
92-
if args[1] isa Expr && operation(args[1]) == :(~)
93-
# matches ~~x::predicate or ~~x::predicate...
94-
return makesegment(arguments(args[1])[1], pvars)
95-
elseif splat
96-
# matches ~x::predicate...
97-
return makesegment(args[1], pvars)
98-
else
99-
return makevar(args[1], pvars)
95+
let v = args[1]
96+
if v isa Expr && operation(v) == :(~)
97+
# matches ~~x::predicate or ~~x::predicate...
98+
makesegment(arguments(v)[1], pvars)
99+
elseif splat
100+
# matches ~x::predicate...
101+
makesegment(v, pvars)
102+
else
103+
makevar(v, pvars)
104+
end
100105
end
101-
else # is a term
106+
else # Matches a term
102107
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
103-
return :($PatTerm(:call, $op, [$(patargs...)]))
108+
:($PatTerm(:call, $(function_object_or_quote(op, mod)), [$(patargs...)]))
104109
end
110+
105111
elseif head === :...
106112
makepattern(args[1], pvars, slots, mod, true)
107113
elseif head == :(::) && args[1] in slots
108-
return splat ? makesegment(ex, pvars) : makevar(ex, pvars)
114+
splat ? makesegment(ex, pvars) : makevar(ex, pvars)
109115
elseif head === :ref
110116
# getindex
111117
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
112-
return :($PatTerm(:ref, getindex, [$(patargs...)]))
118+
:($PatTerm(:ref, getindex, [$(patargs...)]))
113119
elseif head === :$
114-
return args[1]
120+
args[1]
115121
else
116122
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
117-
return :($PatTerm($(QuoteNode(head)), $(op isa Symbol ? QuoteNode(op) : op), [$(patargs...)]))
118-
# throw(Meta.ParseError("Unsupported pattern syntax $ex"))
123+
:($PatTerm($(QuoteNode(head)), $(function_object_or_quote(op, mod)), [$(patargs...)]))
119124
end
120125
end
121126

@@ -328,7 +333,6 @@ macro rule(args...)
328333

329334
e = macroexpand(__module__, expr)
330335
e = rmlines(e)
331-
op = operation(e)
332336
RuleType = rule_sym_map(e)
333337

334338
l, r = arguments(e)

src/extras/graphviz.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
using GraphViz
2+
using Metatheory
3+
using TermInterface
4+
5+
function render_egraph!(io::IO, g::EGraph)
6+
print(
7+
io,
8+
"""digraph {
9+
compound=true
10+
clusterrank=local
11+
remincross=false
12+
ranksep=0.9
13+
""",
14+
)
15+
for (_, eclass) in g.classes
16+
render_eclass!(io, g, eclass)
17+
end
18+
println(io, "\n}\n")
19+
end
20+
21+
function render_eclass!(io::IO, g::EGraph, eclass::EClass)
22+
print(
23+
io,
24+
""" subgraph cluster_$(eclass.id) {
25+
style="dotted,rounded";
26+
rank=same;
27+
label="#$(eclass.id). Smallest: $(extract!(g, astsize; root=eclass.id))"
28+
fontcolor = gray
29+
fontsize = 8
30+
""",
31+
)
32+
33+
# if g.root == find(g, eclass.id)
34+
# println(io, " penwidth=2")
35+
# end
36+
37+
for (i, node) in enumerate(eclass.nodes)
38+
render_enode_node!(io, g, eclass.id, i, node)
39+
end
40+
print(io, "\n }\n")
41+
42+
for (i, node) in enumerate(eclass.nodes)
43+
render_enode_edges!(io, g, eclass.id, i, node)
44+
end
45+
println(io)
46+
end
47+
48+
49+
function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::AbstractENode)
50+
label = operation(node)
51+
# (mr, style) = if node in diff && get(report.cause, node, missing) !== missing
52+
# pair = get(report.cause, node, missing)
53+
# split(split("$(pair[1].rule) ", "=>")[1], "-->")[1], " color=\"red\""
54+
# else
55+
# " ", ""
56+
# end
57+
# sg *= " $id.$os [label=<$label<br /><font point-size=\"8\" color=\"gray\">$mr</font>> $style];"
58+
println(io, " $eclass_id.$i [label=<$label> shape=box style=rounded]")
59+
end
60+
61+
render_enode_edges!(::IO, ::EGraph, eclass_id, i, ::ENodeLiteral) = nothing
62+
63+
function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENodeTerm)
64+
len = length(arguments(node))
65+
for (ite, child) in enumerate(arguments(node))
66+
cluster_id = find(g, child)
67+
# The limitation of graphviz is that it cannot point to the eclass outer frame,
68+
# so when pointing to the same e-class, the next best thing is to point to the same e-node.
69+
target_id = "$cluster_id" * (cluster_id == eclass_id ? ".$i" : ".1")
70+
71+
# In order from left to right, if there are more than 3 children, label the order.
72+
dir = if len == 2
73+
ite == 1 ? ":sw" : ":se"
74+
elseif len == 3
75+
ite == 1 ? ":sw" : (ite == 2 ? ":s" : ":se")
76+
else
77+
""
78+
end
79+
80+
linelabel = len > 3 ? " label=$ite" : " "
81+
println(io, " $eclass_id.$i$dir -> $target_id [arrowsize=0.5 lhead=cluster_$cluster_id $linelabel]")
82+
end
83+
end
84+
85+
function Base.convert(::Type{GraphViz.Graph}, g::EGraph)::GraphViz.Graph
86+
io = IOBuffer()
87+
render_egraph!(io, g)
88+
gs = String(take!(io))
89+
g = GraphViz.Graph(gs)
90+
GraphViz.layout!(g; engine = "dot")
91+
g
92+
end
93+
94+
function Base.show(io::IO, mime::MIME"image/svg+xml", g::EGraph)
95+
show(io, mime, convert(GraphViz.Graph, g))
96+
end

0 commit comments

Comments
 (0)