From 2e1e46d284997e436d3d047b473f6b6c37e0df4b Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 14:28:15 +0200 Subject: [PATCH 01/22] Output parent lists and check parents after rebuilding --- src/EGraphs/egraph.jl | 45 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index a4f0e5f3..8c549ef6 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -7,7 +7,7 @@ The `modify!` function for EGraph Analysis can optionally modify the eclass `eclass` after it has been analyzed, typically by adding an e-node. -It should be **idempotent** if no other changes occur to the EClass. +It should be **idempotent** if no other changes occur to the EClass. (See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). """ function modify! end @@ -25,7 +25,7 @@ function join end """ make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr)::AnalysisType where {ExpressionType} -Given an e-node `n`, `make` should return the corresponding analysis value. +Given an e-node `n`, `make` should return the corresponding analysis value. 
""" function make end @@ -198,9 +198,9 @@ function to_expr(g::EGraph, n::VecExpr) end function pretty_dict(g::EGraph) - d = Dict{Int,Vector{Any}}() + d = Dict{Int,Tuple{Vector{Any},Vector{Any}}}() for (class_id, eclass) in g.classes - d[class_id.val] = map(n -> to_expr(g, n), eclass.nodes) + d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes), map(pair -> (Int(pair[2]), Int(find(g, pair[2]))), eclass.parents)) end d end @@ -209,8 +209,8 @@ export pretty_dict function Base.show(io::IO, g::EGraph) d = pretty_dict(g) t = "$(typeof(g)) with $(length(d)) e-classes:" - cs = map(sort!(collect(d); by = first)) do (k, vect) - " $k => [$(Base.join(vect, ", "))]" + cs = map(sort!(collect(d); by = first)) do (k, (nodes, parents)) + " $k => [$(Base.join(nodes, ", "))], parents: [$(Base.join(parents, ", "))]" end print(io, Base.join([t; cs], "\n")) end @@ -284,7 +284,7 @@ end """ Extend this function on your types to do preliminary -preprocessing of a symbolic term before adding it to +preprocessing of a symbolic term before adding it to an EGraph. Most common preprocessing techniques are binarization of n-ary terms and metadata stripping. 
""" @@ -429,7 +429,7 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp if !isnothing(node_data) if !isnothing(eclass.data) joined_data = join(eclass.data, node_data) - + if joined_data != eclass.data g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data) # eclass.data = joined_data @@ -448,6 +448,32 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp n_unions end +function check_parents(g::EGraph)::Bool + for (id, class) in g.classes + # make sure that the parent node and parent eclass occurs in the parents vector for all children + for n in class.nodes + for chd_id in v_children(n) + chd_class = g[chd_id] + any(pair -> canonicalize!(g, copy(pair[1])) == n, chd_class.parents) || error("parent node is missing from child_class.parents") + any(pair -> find(g, pair[2]) == id.val, chd_class.parents) || error("missing parent reference from child") + end + end + + # make sure all nodes and parent ids occuring in the parent vector have this eclass as a child + for pair in class.parents + parent_id = pair[2] + parent_node = pair[1] + parent_class = g[parent_id] + any(n -> any(ch -> ch == id.val, v_children(n)), parent_class.nodes) || error("no node in the parent references the eclass") # nodes are canonicalized + parent_node_copy = copy(parent_node) + canonicalize!(g, parent_node_copy) + (parent_node_copy in parent_class.nodes) || error("the node from the parent list does not occur in the parent nodes") # might fail because parent_node is probably not canonical + end + end + + true +end + function check_memo(g::EGraph)::Bool test_memo = Dict{VecExpr,Id}() for (id, class) in g.classes @@ -483,9 +509,10 @@ upwards merging in an [`EGraph`](@ref). See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for more details. 
""" -function rebuild!(g::EGraph; should_check_memo=false, should_check_analysis=false) +function rebuild!(g::EGraph; should_check_parents=false, should_check_memo=false, should_check_analysis=false) n_unions = process_unions!(g) trimmed_nodes = rebuild_classes!(g) + @assert !should_check_parents || check_parents(g) @assert !should_check_memo || check_memo(g) @assert !should_check_analysis || check_analysis(g) g.clean = true From 130142225e4b33c88a95a2a0fb02a39c380d6eff Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 14:29:23 +0200 Subject: [PATCH 02/22] Path splitting procedure to shorten path length with find call. --- src/EGraphs/unionfind.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/EGraphs/unionfind.jl b/src/EGraphs/unionfind.jl index f53c0f23..86ef2b20 100644 --- a/src/EGraphs/unionfind.jl +++ b/src/EGraphs/unionfind.jl @@ -18,8 +18,10 @@ function Base.union!(uf::UnionFind, i::Id, j::Id) end function find(uf::UnionFind, i::Id) + # path splitting while i != uf.parents[i] - i = uf.parents[i] + (i, uf.parents[i]) = (uf.parents[i], uf.parents[uf.parents[i]]) end + i end From 500e25e03ac1f5eccb6154d97815c503a12224e5 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 15:33:35 +0200 Subject: [PATCH 03/22] Store original e-nodes in egraph and keep only e-node ids in parent lists. 
--- examples/prove.jl | 2 +- src/EGraphs/egraph.jl | 50 +++++++++++++++++++++------------------ src/EGraphs/saturation.jl | 2 ++ 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/examples/prove.jl b/examples/prove.jl index dfce791f..4c2a34d0 100644 --- a/examples/prove.jl +++ b/examples/prove.jl @@ -1,4 +1,4 @@ -# Sketch function for basic iterative saturation and extraction +# Sketch function for basic iterative saturation and extraction function prove( t, ex, diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 8c549ef6..15c09d31 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -42,7 +42,7 @@ they represent. The [`EGraph`](@ref) itself comes with pretty printing for human struct EClass{D} id::Id nodes::Vector{VecExpr} - parents::Vector{Pair{VecExpr,Id}} + parents::Vector{Id} # The original Ids of parent enodes. data::Union{D,Nothing} end @@ -65,10 +65,6 @@ function Base.show(io::IO, a::EClass) end end -function addparent!(@nospecialize(a::EClass), n::VecExpr, id::Id) - push!(a.parents, (n => id)) -end - function merge_analysis_data!(a::EClass{D}, b::EClass{D})::Tuple{Bool,Bool,Union{D,Nothing}} where {D} if !isnothing(a.data) && !isnothing(b.data) @@ -119,13 +115,16 @@ mutable struct EGraph{ExpressionType,Analysis} uf::UnionFind "map from eclass id to eclasses" classes::Dict{IdKey,EClass{Analysis}} + "vector of the original e-nodes" + nodes::Vector{VecExpr} "hashcons mapping e-nodes to their e-class id" memo::Dict{VecExpr,Id} "Hashcons the constants in the e-graph" constants::Dict{UInt64,Any} - "Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass." - pending::Vector{Pair{VecExpr,Id}} - analysis_pending::UniqueQueue{Pair{VecExpr,Id}} + "Nodes which need to be processed for rebuilding. The id is the id of the e-node, not the canonical id of the e-class." + pending::Vector{Id} + "E-classes that have to be updated for semantic analysis. 
The id is the id of the e-class." + analysis_pending::UniqueQueue{Id} root::Id "a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol." classes_by_op::Dict{IdKey,Vector{Id}} @@ -144,10 +143,11 @@ function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {Expre EGraph{ExpressionType,Analysis}( UnionFind(), Dict{IdKey,EClass{Analysis}}(), + Vector{VecExpr}(), Dict{VecExpr,Id}(), Dict{UInt64,Any}(), - Pair{VecExpr,Id}[], - UniqueQueue{Pair{VecExpr,Id}}(), + Id[], + UniqueQueue{Id}(), 0, Dict{IdKey,Vector{Id}}(), false, @@ -263,20 +263,21 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) end id = push!(g.uf) # create new singleton eclass + push!(g.nodes, n) if v_isexpr(n) for c_id in v_children(n) - addparent!(g.classes[IdKey(c_id)], n, id) + push!(g.classes[IdKey(c_id)].parents, id) end end g.memo[n] = id add_class_by_op(g, n, id) - eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n)) + eclass = EClass{Analysis}(id, VecExpr[n], Id[], make(g, n)) # TODO: check do we need to copy n for the nodes vector here? g.classes[IdKey(id)] = eclass modify!(g, eclass) - push!(g.pending, n => id) + push!(g.pending, id) return id end @@ -407,10 +408,12 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp while !isempty(g.pending) || !isempty(g.analysis_pending) while !isempty(g.pending) - (node::VecExpr, eclass_id::Id) = pop!(g.pending) + enode_id = pop!(g.pending) + eclass_id = find(g, enode_id) + node = g.nodes[enode_id] node = copy(node) canonicalize!(g, node) - old_class_id = get!(g.memo, node, eclass_id) + old_class_id = get!(g.memo, node, eclass_id) # TODO: check if we should pop the old node from memo if old_class_id != eclass_id did_something = union!(g, old_class_id, eclass_id) # TODO unique! can node dedup be moved here? 
compare performance @@ -420,8 +423,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp end while !isempty(g.analysis_pending) - (node::VecExpr, eclass_id::Id) = pop!(g.analysis_pending) - eclass_id = find(g, eclass_id) + enode_id = pop!(g.analysis_pending) + node = g.nodes[enode_id] + eclass_id = find(g, enode_id) eclass_id_key = IdKey(eclass_id) eclass = g.classes[eclass_id_key] @@ -454,17 +458,17 @@ function check_parents(g::EGraph)::Bool for n in class.nodes for chd_id in v_children(n) chd_class = g[chd_id] - any(pair -> canonicalize!(g, copy(pair[1])) == n, chd_class.parents) || error("parent node is missing from child_class.parents") - any(pair -> find(g, pair[2]) == id.val, chd_class.parents) || error("missing parent reference from child") + any(nid -> canonicalize!(g, copy(g.nodes[nid])) == n, chd_class.parents) || error("parent node is missing from child_class.parents") + any(nid -> find(g, nid) == id.val, chd_class.parents) || error("missing parent reference from child") end end # make sure all nodes and parent ids occuring in the parent vector have this eclass as a child - for pair in class.parents - parent_id = pair[2] - parent_node = pair[1] - parent_class = g[parent_id] + for nid in class.parents + parent_class = g[nid] any(n -> any(ch -> ch == id.val, v_children(n)), parent_class.nodes) || error("no node in the parent references the eclass") # nodes are canonicalized + + parent_node = g.nodes[nid] parent_node_copy = copy(parent_node) canonicalize!(g, parent_node_copy) (parent_node_copy in parent_class.nodes) || error("the node from the parent list does not occur in the parent nodes") # might fail because parent_node is probably not canonical diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index fcc3555f..0d0a70e4 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -40,6 +40,8 @@ Base.@kwdef mutable struct SaturationParams check_memo::Bool = false "Activate check for 
join-semilattice invariant for semantic analysis values after rebuilding" check_analysis::Bool = false + "Activate check for parent vectors" + check_parents::Bool = false end function cached_ids(g::EGraph, p::PatExpr)::Vector{Id} From ac4829c6625587c4adfe2771bb8fd6be9305a997 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 15:40:46 +0200 Subject: [PATCH 04/22] Revert changes to pretty_dict output. --- src/EGraphs/egraph.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 15c09d31..e73ec754 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -198,9 +198,9 @@ function to_expr(g::EGraph, n::VecExpr) end function pretty_dict(g::EGraph) - d = Dict{Int,Tuple{Vector{Any},Vector{Any}}}() + d = Dict{Int,Vector{Any}}() for (class_id, eclass) in g.classes - d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes), map(pair -> (Int(pair[2]), Int(find(g, pair[2]))), eclass.parents)) + d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes)) end d end @@ -210,7 +210,7 @@ function Base.show(io::IO, g::EGraph) d = pretty_dict(g) t = "$(typeof(g)) with $(length(d)) e-classes:" cs = map(sort!(collect(d); by = first)) do (k, (nodes, parents)) - " $k => [$(Base.join(nodes, ", "))], parents: [$(Base.join(parents, ", "))]" + " $k => [$(Base.join(nodes, ", "))]" end print(io, Base.join([t; cs], "\n")) end From 6341b3b6da893bafa4ecfe4f6eddae6b72329b4c Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 15:41:35 +0200 Subject: [PATCH 05/22] Change lookup in classes_by_op dictionary to prevent allocation of a new vector for each lookup --- src/EGraphs/egraph.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index e73ec754..1e2b4625 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -245,7 +245,11 @@ end function add_class_by_op(g::EGraph, n, eclass_id) key = 
IdKey(v_signature(n)) - vec = get!(g.classes_by_op, key, Vector{Id}()) + vec = get(g.classes_by_op, key, nothing) + if isnothing(vec) + vec = Id[eclass_id] + g.classes_by_op[key] = vec + end push!(vec, eclass_id) end From 9fb097c2968aa039d2bfadc839c41416d6f30277 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 16:45:59 +0200 Subject: [PATCH 06/22] Fixed implementation of iterate for optbuffer (currently only affected debug output). --- src/optbuffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optbuffer.jl b/src/optbuffer.jl index 94a1d3f6..82a183b3 100644 --- a/src/optbuffer.jl +++ b/src/optbuffer.jl @@ -34,4 +34,4 @@ end Base.isempty(b::OptBuffer{T}) where {T} = b.i === 0 Base.empty!(b::OptBuffer{T}) where {T} = (b.i = 0) @inline Base.length(b::OptBuffer{T}) where {T} = b.i -Base.iterate(b::OptBuffer{T}, i=1) where {T} = iterate(b.v[1:b.i], i) \ No newline at end of file +Base.iterate(b::OptBuffer{T}, i=1) where {T} = i <= b.i ? (b.v[i], i + 1) : nothing \ No newline at end of file From 80a46969b6a705b6ba2c638edb66a8cd047fafcd Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 16:46:51 +0200 Subject: [PATCH 07/22] Fix compile error. 
--- src/EGraphs/egraph.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 1e2b4625..f6a9501c 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -209,7 +209,7 @@ export pretty_dict function Base.show(io::IO, g::EGraph) d = pretty_dict(g) t = "$(typeof(g)) with $(length(d)) e-classes:" - cs = map(sort!(collect(d); by = first)) do (k, (nodes, parents)) + cs = map(sort!(collect(d); by = first)) do (k, nodes) " $k => [$(Base.join(nodes, ", "))]" end print(io, Base.join([t; cs], "\n")) From 68d40d1f7de8bc92642bd8d38e2529091dac6b37 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Thu, 10 Oct 2024 17:48:45 +0200 Subject: [PATCH 08/22] Find of eclass_id for enode_id is not necessary here --- src/EGraphs/egraph.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index f6a9501c..15df6dfc 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -123,7 +123,6 @@ mutable struct EGraph{ExpressionType,Analysis} constants::Dict{UInt64,Any} "Nodes which need to be processed for rebuilding. The id is the id of the e-node, not the canonical id of the e-class." pending::Vector{Id} - "E-classes that have to be updated for semantic analysis. The id is the id of the e-class." analysis_pending::UniqueQueue{Id} root::Id "a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol." 
@@ -413,13 +412,11 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp while !isempty(g.pending) || !isempty(g.analysis_pending) while !isempty(g.pending) enode_id = pop!(g.pending) - eclass_id = find(g, enode_id) - node = g.nodes[enode_id] - node = copy(node) + node = copy(g.nodes[enode_id]) canonicalize!(g, node) - old_class_id = get!(g.memo, node, eclass_id) # TODO: check if we should pop the old node from memo - if old_class_id != eclass_id - did_something = union!(g, old_class_id, eclass_id) + memo_class = get!(g.memo, node, enode_id) + if memo_class != enode_id + did_something = union!(g, memo_class, enode_id) # TODO unique! can node dedup be moved here? compare performance # did_something && unique!(g[eclass_id].nodes) n_unions += did_something From 013253e61b804137c04c662ed89193e53cbb21a8 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 12 Oct 2024 12:28:02 +0200 Subject: [PATCH 09/22] Add some test assertions for internal datastructures used for egraph rebuilding. 
--- test/egraphs/egraphs.jl | 70 ++++++++++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index 4bc4b51e..2399d35d 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -4,10 +4,25 @@ using Test, Metatheory @testset "Merging" begin testexpr = :((a * 2) / 2) testmatch = :(a << 1) - g = EGraph(testexpr) + g = EGraph() + testexpr_id = addexpr!(g, testexpr) + t1 = addexpr!(g, :(a * 2)) # get eclass id of a * 2 t2 = addexpr!(g, testmatch) - union!(g, t2, Id(3)) - @test find(g, t2) == find(g, Id(3)) + union!(g, t2, t1) + + @testset "Behaviour" begin + @test find(g, t2) == find(g, t1) + end + + @testset "Internals" begin + @test length(g[t1].nodes) == 2 # a << 1, a * 2 + @test g[t1].parents == [testexpr_id] + + id_1 = addexpr!(g, 1) # get id of constant 1 + @test g[id_1].parents == [find(g, t1)] # just eclass [a << 1, a * 2] + id_a = addexpr!(g, :a) + @test g[id_a].parents == [find(g, t1)] # just eclass [a << 1, a * 2] + end end # testexpr = :(42a + b * (foo($(Dict(:x => 2)), 42))) @@ -25,10 +40,20 @@ end t2 = addexpr!(g, :c) union!(g, t2, t1) - @test find(g, t2) == find(g, t1) - @test find(g, t2) == find(g, t1) - rebuild!(g) - @test find(g, ec1) == find(g, ec2) + @testset "Behaviour" begin + @test find(g, t2) == find(g, t1) + rebuild!(g) + @test find(g, ec1) == find(g, ec2) + @test length(g[ec2].nodes) == 1 + end + + @testset "Internals" begin + aid = addexpr!(g, :a) # get id of :a + @test g[aid].parents == [find(g, ec1)] + @test g[t1].parents == [find(g, ec1)] + @test g[ec1].parents == [find(g, testec)] + @test length(g[testec].nodes) == 1 + end end @@ -53,19 +78,26 @@ end t3 = addexpr!(g, apply(3, f, :a)) t4 = addexpr!(g, apply(7, f, :a)) - # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a - @test find(g, t1) == find(g, a) - @test find(g, t2) == find(g, a) - @test find(g, t3) == find(g, a) - @test find(g, t4) != find(g, a) + @testset "Behaviour" begin + # f^m(a) 
= a = f^n(a) ⟹ f^(gcd(m,n))(a) = a + @test find(g, t1) == find(g, a) + @test find(g, t2) == find(g, a) + @test find(g, t3) == find(g, a) + @test find(g, t4) != find(g, a) - # if m or n is prime, f(a) = a - t5 = addexpr!(g, apply(11, f, :a)) - t6 = addexpr!(g, apply(1, f, :a)) - c5_id = union!(g, t5, a) # a == apply(11,f,a) + # if m or n is prime, f(a) = a + t5 = addexpr!(g, apply(11, f, :a)) + t6 = addexpr!(g, apply(1, f, :a)) + c5_id = union!(g, t5, a) # a == apply(11,f,a) - rebuild!(g) + rebuild!(g) + + @test find(g, t5) == find(g, a) + @test find(g, t6) == find(g, a) + end - @test find(g, t5) == find(g, a) - @test find(g, t6) == find(g, a) + @testset "Internals" begin + @test length(g.classes) == 1 # only a single class %id [:a, f(%id)] remains + @test g[a].parents = [find(g, a)] # there can be only a single parent + end end From 9db374dad7d804805b99ce137801fe253b150675 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:18:25 +0200 Subject: [PATCH 10/22] Set root to allow debugging (requires extraction) --- examples/prove.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/prove.jl b/examples/prove.jl index 4c2a34d0..acac87d5 100644 --- a/examples/prove.jl +++ b/examples/prove.jl @@ -32,9 +32,9 @@ function test_equality(t, exprs...; params = SaturationParams(), g = EGraph()) params = deepcopy(params) params.goal = (g::EGraph) -> in_same_class(g, ids...) + g.root = first(ids) # to allow extraction (for debugging) report = saturate!(g, t, params) goal_reached = params.goal(g) - if !(report.reason === :saturated) && !goal_reached return false # failed to prove end From 3ac0718a813ff294702132e5c3f000a315bcd91e Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:20:25 +0200 Subject: [PATCH 11/22] isless for VecExpr to allow sorting. 
--- src/vecexpr.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vecexpr.jl b/src/vecexpr.jl index c18059f9..eee3a7eb 100644 --- a/src/vecexpr.jl +++ b/src/vecexpr.jl @@ -121,4 +121,6 @@ v_pair_last(p::UInt128)::UInt64 = UInt64(p & 0xffffffffffffffff) @inline Base.lastindex(n::VecExpr) = lastindex(n.data) @inline Base.firstindex(n::VecExpr) = firstindex(n.data) +Base.isless(a::VecExpr,b::VecExpr) = isless(a.data,b.data) + end From 90719d0e9932fe4bb4fcfaa449005be54deb70a7 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:22:07 +0200 Subject: [PATCH 12/22] Check most specific constants first. --- src/ematch_compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index ecb9b5ae..61430a65 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -96,12 +96,12 @@ end check_constant_exprs!(buf, p::PatLiteral) = push!(buf, :(has_constant(g, $(last(p.n))) || return 0)) check_constant_exprs!(buf, ::AbstractPat) = buf function check_constant_exprs!(buf, p::PatExpr) - if !(p.head isa AbstractPat) - push!(buf, :(has_constant(g, $(p.head_hash)) || has_constant(g, $(p.quoted_head_hash)) || return 0)) - end for child in children(p) check_constant_exprs!(buf, child) end + if !(p.head isa AbstractPat) + push!(buf, :(has_constant(g, $(p.head_hash)) || has_constant(g, $(p.quoted_head_hash)) || return 0)) + end buf end From 6ddffa462c6c0043eeb5fd1e302cc61233cbcec5 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:22:50 +0200 Subject: [PATCH 13/22] Comment and removed unnecessary parentheses. 
--- src/ematch_compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 61430a65..43bcdb48 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -209,8 +209,8 @@ function bind_expr(addr, p::PatExpr, memrange) n = eclass.nodes[$(Symbol(:enode_idx, addr))] v_flags(n) === $(v_flags(p.n)) || @goto $(Symbol(:skip_node, addr)) - v_signature(n) === $(v_signature(p.n)) || @goto $(Symbol(:skip_node, addr)) - v_head(n) === $(v_head(p.n)) || (v_head(n) === $(p.quoted_head_hash) || @goto $(Symbol(:skip_node, addr))) + v_signature(n) === $(v_signature(p.n)) || @goto $(Symbol(:skip_node, addr)) # TODO better to check signature before flags? check perf. + v_head(n) === $(v_head(p.n)) || v_head(n) === $(p.quoted_head_hash) || @goto $(Symbol(:skip_node, addr)) # Node has matched. $([:($(Symbol(:σ, j)) = n[$i + $VECEXPR_META_LENGTH]) for (i, j) in enumerate(memrange)]...) From 3c07d51bf28cbd8dc05ab6517a31f8d7cc9aaf6e Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:24:02 +0200 Subject: [PATCH 14/22] Fixes for constant matching from different PR --- src/EGraphs/saturation.jl | 8 +++----- src/ematch_compiler.jl | 24 +++++++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 0d0a70e4..f590e110 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -146,14 +146,12 @@ end Instantiate argument for dynamic rule application in e-graph """ function instantiate_actual_param!(bindings, g::EGraph, i) + const_hash = v_pair_last(bindings[i]) + const_hash == 0 || return get_constant(g, const_hash) + ecid = v_pair_first(bindings[i]) - literal_position = reinterpret(Int, v_pair_last(bindings[i])) ecid <= 0 && error("unbound pattern variable") eclass = g[ecid] - if literal_position > 0 - @assert !v_isexpr(eclass[literal_position]) - return get_constant(g, 
v_head(eclass[literal_position])) - end return eclass end diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 43bcdb48..57bc6437 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -249,12 +249,12 @@ function check_var_expr(addr, predicate::Function) quote eclass = g[$(Symbol(:σ, addr))] if ($predicate)(g, eclass) - for (j, n) in enumerate(eclass.nodes) - if !v_isexpr(n) - $(Symbol(:enode_idx, addr)) = j + 1 - break - end - end + # for (j, n) in enumerate(eclass.nodes) + # if !v_isexpr(n) + # $(Symbol(:enode_idx, addr)) = j + 1 + # break + # end + # end pc += 0x0001 @goto compute end @@ -322,7 +322,17 @@ end function yield_expr(patvar_to_addr, direction) push_exprs = [ - :(push!(ematch_buffer, v_pair($(Symbol(:σ, addr)), reinterpret(UInt64, $(Symbol(:enode_idx, addr)) - 1)))) for + quote + id = $(Symbol(:σ, addr)) + eclass = g[id] + node_idx = $(Symbol(:enode_idx, addr)) - 1 + if node_idx <= 0 + push!(ematch_buffer, v_pair(id, reinterpret(UInt64, 0))) + else + n = eclass.nodes[node_idx] + push!(ematch_buffer, v_pair(id, v_head(n))) + end + end for addr in patvar_to_addr ] quote From 97c272b473d3ce1c7390d065b361d680a6c3551e Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:24:58 +0200 Subject: [PATCH 15/22] Allow to set SaturationParams for simplify for testing, and mark two broken tests. --- test/integration/cas.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/integration/cas.jl b/test/integration/cas.jl index 5fd1c068..b1be5bac 100644 --- a/test/integration/cas.jl +++ b/test/integration/cas.jl @@ -145,14 +145,14 @@ function simplcost(n::VecExpr, op, costs) 1 + sum(costs) + (op in (:∂, diff, :diff) ? 
200 : 0) end -function simplify(ex; steps = 4) - params = SaturationParams( +function simplify(ex; steps = 4, params=SaturationParams()) + #params = SaturationParams( # scheduler = ScoredScheduler, # eclasslimit = 5000, - # timeout = 7, + # timeout = 2, # schedulerparams = (match_limit = 1000, ban_length = 5), #stopwhen=stopwhen, - ) + #) hist = UInt64[] push!(hist, hash(ex)) for i in 1:steps @@ -188,11 +188,11 @@ end @test :(y + sec(x)^2) == simplify(:(1 + y + tan(x)^2)) @test :(y + csc(x)^2) == simplify(:(1 + y + cot(x)^2)) -@test simplify(:(diff(x^2, x))) == :(2x) +@test_broken simplify(:(diff(x^2, x))) == :(2x) @test_broken simplify(:(diff(x^(cos(x)), x))) == :((cos(x) / x + -(sin(x)) * log(x)) * x^cos(x)) @test simplify(:(x * diff(x^2, x) * x)) == :(2x^3) -@test simplify(:(diff(y^3, y) * diff(x^2 + 2, x) / y * x)) == :(6 * y * x ^ 2) # :(3y * 2x^2) +@test_broken simplify(:(diff(y^3, y) * diff(x^2 + 2, x) / y * x)) == :(6 * y * x ^ 2) # :(3y * 2x^2) @test simplify(:(6 * x * x * y)) == :(6 * y * x^2) @test simplify(:(diff(y^3, y) / y)) == :(3y) From 917f3feb9450afad629d7ff312fb021e07ac385e Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:26:21 +0200 Subject: [PATCH 16/22] Small fixes. 
--- src/EGraphs/saturation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index f590e110..05229920 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -109,7 +109,7 @@ function eqsat_search!( # if n_matches - prev_matches > 2 && rule_idx == 2 # @debug buffer_readable(g, old_len) # end - inform!(scheduler, rule_idx, n_matches) + inform!(scheduler, rule_idx, n_matches - prev_matches) end end @@ -214,7 +214,7 @@ function eqsat_apply!( if n_matches % CHECK_GOAL_EVERY_N_MATCHES == 0 && params.goal(g) @debug "Goal reached" rep.reason = :goalreached - return + break end delimiter = ematch_buffer.v[k] From e6f582c4dd8566cee0362610389cda67d1358153 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:26:45 +0200 Subject: [PATCH 17/22] Correct test cases for rebuilding --- test/egraphs/egraphs.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index 2399d35d..b600fb40 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -7,7 +7,10 @@ using Test, Metatheory g = EGraph() testexpr_id = addexpr!(g, testexpr) t1 = addexpr!(g, :(a * 2)) # get eclass id of a * 2 + t1_node = copy(g[t1].nodes[1]) + t2 = addexpr!(g, testmatch) + t2_node = copy(g[t2].nodes[1]) union!(g, t2, t1) @testset "Behaviour" begin @@ -16,12 +19,13 @@ using Test, Metatheory @testset "Internals" begin @test length(g[t1].nodes) == 2 # a << 1, a * 2 - @test g[t1].parents == [testexpr_id] + @test g[t1].parents == [g[testexpr_id].nodes[1] => testexpr_id] - id_1 = addexpr!(g, 1) # get id of constant 1 - @test g[id_1].parents == [find(g, t1)] # just eclass [a << 1, a * 2] - id_a = addexpr!(g, :a) - @test g[id_a].parents == [find(g, t1)] # just eclass [a << 1, a * 2] + # the parents of child eclasses are only touched when we need them (upwards repair only) + # id_1 = addexpr!(g, 
1) # get id of constant 1 + # @test g[id_1].parents == [t1_node => find(g, t1)] # just eclass [a << 1, a * 2] + # id_a = addexpr!(g, :a) + # @test g[id_a].parents == [t2_node => find(g, t1)] # just eclass [a << 1, a * 2] end end @@ -49,10 +53,11 @@ end @testset "Internals" begin aid = addexpr!(g, :a) # get id of :a - @test g[aid].parents == [find(g, ec1)] - @test g[t1].parents == [find(g, ec1)] - @test g[ec1].parents == [find(g, testec)] + @assert length(g[ec2].nodes) == 1 + # @test g[aid].parents == [g[ec1].nodes[1] => find(g, ec1)] + # @test g[t1].parents == [g[ec1].nodes[1] => find(g, ec1)] @test length(g[testec].nodes) == 1 + @test g[ec1].parents == [g[testec].nodes[1] => find(g, testec)] end end @@ -68,16 +73,16 @@ end t1 = addexpr!(g, apply(6, f, :a)) t2 = addexpr!(g, apply(9, f, :a)) - c_id = union!(g, t1, a) # a == apply(6,f,a) - c2_id = union!(g, t2, a) # a == apply(9,f,a) + union!(g, t1, a) # a == apply(6,f,a) + union!(g, t2, a) # a == apply(9,f,a) rebuild!(g) - pretty_dict(g) - t3 = addexpr!(g, apply(3, f, :a)) t4 = addexpr!(g, apply(7, f, :a)) + t6_node = 0 + t5 = 0 @testset "Behaviour" begin # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a @test find(g, t1) == find(g, a) @@ -88,7 +93,8 @@ end # if m or n is prime, f(a) = a t5 = addexpr!(g, apply(11, f, :a)) t6 = addexpr!(g, apply(1, f, :a)) - c5_id = union!(g, t5, a) # a == apply(11,f,a) + t6_node = g[t6].nodes[1] + union!(g, t5, a) # a == apply(11,f,a) rebuild!(g) @@ -98,6 +104,6 @@ end @testset "Internals" begin @test length(g.classes) == 1 # only a single class %id [:a, f(%id)] remains - @test g[a].parents = [find(g, a)] # there can be only a single parent + @test length(g[a].parents) == 1 # there can be only a single parent end end From 50d6dfd098f0b3fc6828b5467f6fd5cb0e19b21c Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sun, 13 Oct 2024 23:27:13 +0200 Subject: [PATCH 18/22] Complete overhaul of rebuilding mechanism. 
--- src/EGraphs/egraph.jl | 195 ++++++++++++++++++++++++++------------ src/EGraphs/saturation.jl | 17 ++-- 2 files changed, 148 insertions(+), 64 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index cd27860c..7421d0a0 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -42,7 +42,7 @@ they represent. The [`EGraph`](@ref) itself comes with pretty printing for human mutable struct EClass{D} const id::Id const nodes::Vector{VecExpr} - const parents::Vector{Pair{VecExpr,Id}} + parents::Vector{Pair{VecExpr,Id}} # the (canoncial) parent node and the parent eclass id holding the node data::Union{D,Nothing} end @@ -115,15 +115,16 @@ mutable struct EGraph{ExpressionType,Analysis} uf::UnionFind "map from eclass id to eclasses" classes::Dict{IdKey,EClass{Analysis}} - "vector of the original e-nodes" + "vector of all e-nodes (canonicalized), index is enode id, may hold duplicates after canonicalization." nodes::Vector{VecExpr} "hashcons mapping e-nodes to their e-class id" memo::Dict{VecExpr,Id} "Hashcons the constants in the e-graph" constants::Dict{UInt64,Any} - "Nodes which need to be processed for rebuilding. The id is the id of the e-node, not the canonical id of the e-class." + "E-classes whose parent nodes have to be reprocessed." pending::Vector{Id} - analysis_pending::UniqueQueue{Id} + "E-class whose parent nodes have to be reprocessed." + analysis_pending::Vector{Id} root::Id "a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol." 
classes_by_op::Dict{IdKey,Vector{Id}} @@ -145,8 +146,8 @@ function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {Expre Vector{VecExpr}(), Dict{VecExpr,Id}(), Dict{UInt64,Any}(), - Id[], - UniqueQueue{Id}(), + Vector{Id}(), + Vector{Id}(), 0, Dict{IdKey,Vector{Id}}(), false, @@ -197,9 +198,9 @@ function to_expr(g::EGraph, n::VecExpr) end function pretty_dict(g::EGraph) - d = Dict{Int,Vector{Any}}() + d = Dict{Int,Tuple{Vector{Any},Vector{Any}}}() for (class_id, eclass) in g.classes - d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes)) + d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes), map(pair -> to_expr(g, pair[1]) => Int(pair[2]), eclass.parents)) end d end @@ -208,8 +209,8 @@ export pretty_dict function Base.show(io::IO, g::EGraph) d = pretty_dict(g) t = "$(typeof(g)) with $(length(d)) e-classes:" - cs = map(sort!(collect(d); by = first)) do (k, nodes) - " $k => [$(Base.join(nodes, ", "))]" + cs = map(sort!(collect(d); by = first)) do (k, (nodes, parents)) + " $k => [$(Base.join(nodes, ", "))] parents: [$(Base.join(parents, ", "))]" end print(io, Base.join([t; cs], "\n")) end @@ -238,17 +239,13 @@ function lookup(g::EGraph, n::VecExpr)::Id canonicalize!(g, n) id = get(g.memo, n, zero(Id)) - iszero(id) ? id : find(g, id) + iszero(id) ? 
id : find(g, id) # find necessary because g.memo is not necessarily canonical end function add_class_by_op(g::EGraph, n, eclass_id) key = IdKey(v_signature(n)) - vec = get(g.classes_by_op, key, nothing) - if isnothing(vec) - vec = Id[eclass_id] - g.classes_by_op[key] = vec - end + vec = get!(Vector{Id}, g.classes_by_op, key) push!(vec, eclass_id) end @@ -268,19 +265,23 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) id = push!(g.uf) # create new singleton eclass push!(g.nodes, n) + # g.nodes, eclass.nodes, eclass.parents, and g.memo all have a reference to the same VecExpr for the new enode + # the node must never be manipulated while it is contained in memo + if v_isexpr(n) for c_id in v_children(n) - push!(g.classes[IdKey(c_id)].parents, id) + push!(g.classes[IdKey(c_id)].parents, n => id) end end g.memo[n] = id add_class_by_op(g, n, id) - eclass = EClass{Analysis}(id, VecExpr[n], Id[], make(g, n)) # TODO: check do we need to copy n for the nodes vector here? + eclass = EClass{Analysis}(id, VecExpr[n], Id[], make(g, n)) g.classes[IdKey(id)] = eclass modify!(g, eclass) - push!(g.pending, id) + + # push!(g.pending, id) # We just created a new eclass for a new node. No need to reprocess parents (TODO: check) return id end @@ -297,14 +298,15 @@ function preprocess(e::Expr) end preprocess(x) = x +addexpr!(::EGraph, se::EClass) = se.id # TODO: why do we need this? + """ Recursively traverse an type satisfying the `TermInterface` and insert terms into an [`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly insert the literal into the [`EGraph`](@ref). """ function addexpr!(g::EGraph, se)::Id - se isa EClass && return se.id - e = preprocess(se) + e = preprocess(se) # TODO: type stability issue? 
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false) @@ -345,16 +347,17 @@ function Base.union!( id_1, id_2 = id_2, id_1 end - union!(g.uf, id_1.val, id_2.val) + merged_id = union!(g.uf, id_1.val, id_2.val) eclass_2 = pop!(g.classes, id_2)::EClass eclass_1 = g.classes[id_1]::EClass - append!(g.pending, eclass_2.parents) + push!(g.pending, merged_id) + # push!(g.pending, id_2.val) # TODO: sufficient? (merged_1, merged_2, new_data) = merge_analysis_data!(eclass_1, eclass_2) - merged_1 && append!(g.analysis_pending, eclass_1.parents) - merged_2 && append!(g.analysis_pending, eclass_2.parents) + merged_1 && push!(g.analysis_pending, id_1.val) + merged_2 && push!(g.analysis_pending, id_2.val) # update eclass_1 @@ -384,13 +387,17 @@ function rebuild_classes!(g::EGraph) empty!(v) end + trimmed_nodes = 0 for (eclass_id, eclass) in g.classes - # old_len = length(eclass.nodes) for n in eclass.nodes + memo_class = pop!(g.memo, n, 0) canonicalize!(g, n) + g.memo[n] = eclass_id.val end - # Sort to go in order? + # TODO Sort to go in order? + trimmed_nodes += length(eclass.nodes) unique!(eclass.nodes) + trimmed_nodes -= length(eclass.nodes) for n in eclass.nodes add_class_by_op(g, n, eclass_id.val) @@ -399,53 +406,124 @@ function rebuild_classes!(g::EGraph) for v in values(g.classes_by_op) sort!(v) - unique!(v) + unique!(v) # TODO: _groupedunique!(itr), and implement isless(a::VecExpr, b::VecExpr) end + trimmed_nodes end function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {ExpressionType,AnalysisType} n_unions = 0 while !isempty(g.pending) || !isempty(g.analysis_pending) - while !isempty(g.pending) - enode_id = pop!(g.pending) - node = copy(g.nodes[enode_id]) - canonicalize!(g, node) - memo_class = get!(g.memo, node, enode_id) - if memo_class != enode_id - did_something = union!(g, memo_class, enode_id) - # TODO unique! can node dedup be moved here? 
compare performance - # did_something && unique!(g[eclass_id].nodes) - n_unions += did_something + # while !isempty(g.pending) + # TODO: is it useful to deduplicate here? check perf + todo = collect(unique(id -> find(g, id), g.pending)) + @debug "Worklist reduced from $(length(g.pending)) to $(length(todo)) entries." + empty!(g.pending) + + for id in todo + n_unions += repair_parents!(g, id) + end + #end + + #while !isempty(g.analysis_pending) + # TODO: is it useful to deduplicate here? check perf + todo = collect(unique(id -> find(g, id), g.analysis_pending)) + @debug "Analysis worklist reduced from $(length(g.analysis_pending)) to $(length(todo)) entries." + empty!(g.analysis_pending) + + for id in todo + update_analysis_upwards!(g, id) end + #end + end + n_unions +end + +function repair_parents!(g::EGraph, id::Id) + n_unions = 0 + eclass = g[id] # id does not have to be an eclass id anymore if we merged classes below + for (p_node, _) in eclass.parents + # @assert haskey(g.memo, p_node) "eclass: $(Int(id))\n parent: $p_node => $p_eclass \n$g" + memo_class = pop!(g.memo, p_node, 0) # TODO: could we be messy instead and just canonicalize the node and add again (without pop!)? + + if memo_class > 0 + canonicalize!(g, p_node) + memo_class = find(g, memo_class) + # @show "new",p_node,memo_class + g.memo[p_node] = memo_class end + # merge is done below + # # if duplicate enodes occur after canonicalization we detect this here and union the eclasses + # if memo_class != p_eclass + # did_something = union!(g, memo_class, p_eclass) + # # TODO unique! can node dedup be moved here? 
compare performance + # # did_something && unique!(g[eclass_id].nodes) + # n_unions += did_something + # end + end - while !isempty(g.analysis_pending) - enode_id = pop!(g.analysis_pending) - node = g.nodes[enode_id] - eclass_id = find(g, enode_id) - eclass_id_key = IdKey(eclass_id) - eclass = g.classes[eclass_id_key] - - node_data = make(g, node) - if !isnothing(node_data) - if !isnothing(eclass.data) - joined_data = join(eclass.data, node_data) - - if joined_data != eclass.data - eclass.data = joined_data - modify!(g, eclass) - append!(g.analysis_pending, eclass.parents) - end - else - eclass.data = node_data + # TODO: sort first? + # unique!(pair -> pair[1], eclass.parents) + + # sort and delete duplicate nodes last to first + if !isempty(eclass.parents) + new_parents = Vector{Pair{VecExpr,Id}}() + sort!(eclass.parents, by=pair->pair[1]) + (prev_node, prev_id) = first(eclass.parents) + + if prev_id != find(g, prev_id) + n_unions += 1 + union!(g, prev_id, find(g, prev_id)) + end + + prev_id = find(g, prev_id) + push!(new_parents, prev_node => prev_id) + + for i in Iterators.drop(eachindex(eclass.parents), 1) + (cur_node, cur_id) = eclass.parents[i] + + if cur_node == prev_node # could check hash(cur_node) == hash(prev_node) first + if union!(g, cur_id, prev_id) + n_unions += 1 + end + else + cur_id = find(g, cur_id) + push!(new_parents, cur_node => cur_id) + prev_node, prev_id = cur_node, cur_id + end + end + + # TODO: remove assertions + @assert length(unique(pair -> pair[1], new_parents)) == length(new_parents) "not unique: $new_parents" + # @assert all(pair -> pair[2] == find(g, pair[2]), new_parents) "not refering to eclasses: $(new_parents)\n $g" + + eclass.parents = new_parents + end + n_unions +end +function update_analysis_upwards!(g::EGraph, id::Id) + for (p_node, p_id) in g.classes[IdKey(id)] + p_id = find(g, p_id) + eclass = g.classes[IdKey(p_id)] + + node_data = make(g, p_node) + if !isnothing(node_data) + if !isnothing(eclass.data) + joined_data = 
join(eclass.data, node_data) + + if joined_data != eclass.data + eclass.data = joined_data modify!(g, eclass) append!(g.analysis_pending, eclass.parents) end + else + eclass.data = node_data + modify!(g, eclass) + append!(g.analysis_pending, eclass.parents) end end end - n_unions end function check_parents(g::EGraph)::Bool @@ -512,6 +590,7 @@ for more details. function rebuild!(g::EGraph; should_check_parents=false, should_check_memo=false, should_check_analysis=false) n_unions = process_unions!(g) trimmed_nodes = rebuild_classes!(g) + @assert !should_check_parents || check_parents(g) @assert !should_check_memo || check_memo(g) @assert !should_check_analysis || check_analysis(g) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 05229920..ec6e9413 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -244,24 +244,26 @@ function eqsat_apply!( res = apply_rule!(bindings, g, rule, id, direction) - k = next_delimiter_idx if res.halt_reason !== :nothing rep.reason = res.halt_reason - return + break end + !iszero(res.l) && !iszero(res.r) && union!(g, res.l, res.r) + if params.enodelimit > 0 && length(g.memo) > params.enodelimit @debug "Too many enodes" rep.reason = :enodelimit break end - !iszero(res.l) && !iszero(res.r) && union!(g, res.l, res.r) + k = next_delimiter_idx + + end if params.goal(g) @debug "Goal reached" rep.reason = :goalreached - return end empty!(ematch_buffer) @@ -292,10 +294,13 @@ function eqsat_step!( if report.reason === nothing && cansaturate(scheduler) && isempty(g.pending) report.reason = :saturated end - @timeit report.to "Rebuild" rebuild!(g; should_check_memo = params.check_memo, should_check_analysis = params.check_analysis) + + @timeit report.to "Rebuild" rebuild!(g; + should_check_memo = params.check_memo && report.reason !=:enodelimit, # rules have been applied only partially when the enode limit is reached (TODO) + should_check_analysis = params.check_analysis && report.reason !=:enodelimit) 
Schedulers.rebuild!(scheduler) - + @debug "Smallest expression is" extract!(g, astsize) return report From ee1862e9f365082b81f39f5b41890fd4de1844eb Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Mon, 14 Oct 2024 08:10:54 +0200 Subject: [PATCH 19/22] Fixes, moving forward... --- src/EGraphs/egraph.jl | 10 +++++----- src/EGraphs/extract.jl | 6 +++--- test/integration/cas.jl | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 7421d0a0..960d571a 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -389,11 +389,11 @@ function rebuild_classes!(g::EGraph) trimmed_nodes = 0 for (eclass_id, eclass) in g.classes - for n in eclass.nodes - memo_class = pop!(g.memo, n, 0) - canonicalize!(g, n) - g.memo[n] = eclass_id.val - end + # for n in eclass.nodes + # memo_class = pop!(g.memo, n, 0) + # canonicalize!(g, n) + # g.memo[n] = eclass_id.val + # end # TODO Sort to go in order? trimmed_nodes += length(eclass.nodes) unique!(eclass.nodes) diff --git a/src/EGraphs/extract.jl b/src/EGraphs/extract.jl index 85186b5c..140aba16 100644 --- a/src/EGraphs/extract.jl +++ b/src/EGraphs/extract.jl @@ -85,9 +85,9 @@ function find_costs!(extractor::Extractor{CF,CT}) where {CF,CT} end for (id, _) in extractor.g.classes - if !haskey(extractor.costs, id) - error("failed to compute extraction costs for eclass ", id.val) - end + # if !haskey(extractor.costs, id) + # error("failed to compute extraction costs for eclass ", id.val) + # end end end diff --git a/test/integration/cas.jl b/test/integration/cas.jl index b1be5bac..6e8ac197 100644 --- a/test/integration/cas.jl +++ b/test/integration/cas.jl @@ -1,4 +1,4 @@ -using Metatheory, TermInterface, Test +using Metatheory, Test using Metatheory.Library using Metatheory.Schedulers @@ -188,11 +188,11 @@ end @test :(y + sec(x)^2) == simplify(:(1 + y + tan(x)^2)) @test :(y + csc(x)^2) == simplify(:(1 + y + cot(x)^2)) -@test_broken simplify(:(diff(x^2, x))) == 
:(2x) +@test simplify(:(diff(x^2, x))) == :(2x) @test_broken simplify(:(diff(x^(cos(x)), x))) == :((cos(x) / x + -(sin(x)) * log(x)) * x^cos(x)) @test simplify(:(x * diff(x^2, x) * x)) == :(2x^3) -@test_broken simplify(:(diff(y^3, y) * diff(x^2 + 2, x) / y * x)) == :(6 * y * x ^ 2) # :(3y * 2x^2) +@test simplify(:(diff(y^3, y) * diff(x^2 + 2, x) / y * x)) == :(6 * y * x ^ 2) # :(3y * 2x^2) @test simplify(:(6 * x * x * y)) == :(6 * y * x^2) @test simplify(:(diff(y^3, y) / y)) == :(3y) From 348dd62e965e44f2ddce66ae4914436dfb8696a0 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Mon, 14 Oct 2024 10:03:09 +0200 Subject: [PATCH 20/22] Bugfix in analysis rebuilding. --- src/EGraphs/egraph.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 960d571a..fcf46a3f 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -503,7 +503,7 @@ function repair_parents!(g::EGraph, id::Id) n_unions end function update_analysis_upwards!(g::EGraph, id::Id) - for (p_node, p_id) in g.classes[IdKey(id)] + for (p_node, p_id) in g[id].parents p_id = find(g, p_id) eclass = g.classes[IdKey(p_id)] @@ -515,12 +515,12 @@ function update_analysis_upwards!(g::EGraph, id::Id) if joined_data != eclass.data eclass.data = joined_data modify!(g, eclass) - append!(g.analysis_pending, eclass.parents) + append!(g.analysis_pending, p_id) end else eclass.data = node_data modify!(g, eclass) - append!(g.analysis_pending, eclass.parents) + append!(g.analysis_pending, p_id) end end end From 4d26e3cabae54fe67aa4278c96620fda46460d3b Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Mon, 14 Oct 2024 10:12:30 +0200 Subject: [PATCH 21/22] Minor changes. 
--- src/EGraphs/egraph.jl | 10 +++++----- test/tutorials/lambda_theory.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index fcf46a3f..be79df14 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -281,7 +281,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) g.classes[IdKey(id)] = eclass modify!(g, eclass) - # push!(g.pending, id) # We just created a new eclass for a new node. No need to reprocess parents (TODO: check) + # push!(g.pending, id) # We just created a new eclass for a new node. No need to reprocess parents return id end @@ -352,8 +352,8 @@ function Base.union!( eclass_2 = pop!(g.classes, id_2)::EClass eclass_1 = g.classes[id_1]::EClass - push!(g.pending, merged_id) - # push!(g.pending, id_2.val) # TODO: sufficient? + # push!(g.pending, merged_id) + push!(g.pending, id_2.val) # TODO: it seems sufficient, to queue parents of id_2.val? (merged_1, merged_2, new_data) = merge_analysis_data!(eclass_1, eclass_2) merged_1 && push!(g.analysis_pending, id_1.val) @@ -483,7 +483,7 @@ function repair_parents!(g::EGraph, id::Id) for i in Iterators.drop(eachindex(eclass.parents), 1) (cur_node, cur_id) = eclass.parents[i] - if cur_node == prev_node # could check hash(cur_node) == hash(prev_node) first + if hash(cur_node) == hash(prev_node) && cur_node == prev_node if union!(g, cur_id, prev_id) n_unions += 1 end @@ -495,7 +495,7 @@ function repair_parents!(g::EGraph, id::Id) end # TODO: remove assertions - @assert length(unique(pair -> pair[1], new_parents)) == length(new_parents) "not unique: $new_parents" + # @assert length(unique(pair -> pair[1], new_parents)) == length(new_parents) "not unique: $new_parents" # @assert all(pair -> pair[2] == find(g, pair[2]), new_parents) "not refering to eclasses: $(new_parents)\n $g" eclass.parents = new_parents diff --git a/test/tutorials/lambda_theory.jl b/test/tutorials/lambda_theory.jl index 9f74a725..1219d385 
100644 --- a/test/tutorials/lambda_theory.jl +++ b/test/tutorials/lambda_theory.jl @@ -1,4 +1,4 @@ -using Metatheory, Test, TermInterface +using Metatheory, Test # # Lambda theory # From 0d203a82bab7e32a77da725baa172388e3770e9e Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Mon, 14 Oct 2024 11:46:16 +0200 Subject: [PATCH 22/22] Remove nodes vector from egraph and clean-up code. --- src/EGraphs/egraph.jl | 97 ++++++++++++++------------------------- src/EGraphs/saturation.jl | 4 +- 2 files changed, 36 insertions(+), 65 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index be79df14..8bbc76e2 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -42,7 +42,7 @@ they represent. The [`EGraph`](@ref) itself comes with pretty printing for human mutable struct EClass{D} const id::Id const nodes::Vector{VecExpr} - parents::Vector{Pair{VecExpr,Id}} # the (canoncial) parent node and the parent eclass id holding the node + parents::Vector{Pair{VecExpr,Id}} # the parent nodes and eclasses for upward merging data::Union{D,Nothing} end @@ -115,15 +115,13 @@ mutable struct EGraph{ExpressionType,Analysis} uf::UnionFind "map from eclass id to eclasses" classes::Dict{IdKey,EClass{Analysis}} - "vector of all e-nodes (canonicalized), index is enode id, may hold duplicates after canonicalization." - nodes::Vector{VecExpr} "hashcons mapping e-nodes to their e-class id" memo::Dict{VecExpr,Id} "Hashcons the constants in the e-graph" constants::Dict{UInt64,Any} "E-classes whose parent nodes have to be reprocessed." pending::Vector{Id} - "E-class whose parent nodes have to be reprocessed." + "E-classes whose parent nodes have to be reprocessed for analysis values." analysis_pending::Vector{Id} root::Id "a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol." 
@@ -143,7 +141,6 @@ function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {Expre EGraph{ExpressionType,Analysis}( UnionFind(), Dict{IdKey,EClass{Analysis}}(), - Vector{VecExpr}(), Dict{VecExpr,Id}(), Dict{UInt64,Any}(), Vector{Id}(), @@ -200,6 +197,7 @@ end function pretty_dict(g::EGraph) d = Dict{Int,Tuple{Vector{Any},Vector{Any}}}() for (class_id, eclass) in g.classes + # TODO do not show parent lists anymore (but useful for debugging) d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes), map(pair -> to_expr(g, pair[1]) => Int(pair[2]), eclass.parents)) end d @@ -239,7 +237,7 @@ function lookup(g::EGraph, n::VecExpr)::Id canonicalize!(g, n) id = get(g.memo, n, zero(Id)) - iszero(id) ? id : find(g, id) # find necessary because g.memo is not necessarily canonical + iszero(id) ? id : find(g, id) # find necessary because g.memo values are not necessarily canonical end @@ -263,10 +261,9 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) end id = push!(g.uf) # create new singleton eclass - push!(g.nodes, n) - # g.nodes, eclass.nodes, eclass.parents, and g.memo all have a reference to the same VecExpr for the new enode - # the node must never be manipulated while it is contained in memo + # eclass.nodes, eclass.parents, and g.memo all have a reference to the same VecExpr for the new enode. + # The node must never be manipulated while it is contained in memo. if v_isexpr(n) for c_id in v_children(n) @@ -353,7 +350,7 @@ function Base.union!( eclass_1 = g.classes[id_1]::EClass # push!(g.pending, merged_id) - push!(g.pending, id_2.val) # TODO: it seems sufficient, to queue parents of id_2.val? + push!(g.pending, id_2.val) # TODO: it seems sufficient to queue parents of id_2.val? 
(merged_1, merged_2, new_data) = merge_analysis_data!(eclass_1, eclass_2) merged_1 && push!(g.analysis_pending, id_1.val) @@ -389,13 +386,8 @@ function rebuild_classes!(g::EGraph) trimmed_nodes = 0 for (eclass_id, eclass) in g.classes - # for n in eclass.nodes - # memo_class = pop!(g.memo, n, 0) - # canonicalize!(g, n) - # g.memo[n] = eclass_id.val - # end - # TODO Sort to go in order? trimmed_nodes += length(eclass.nodes) + # TODO Sort to go in order? unique!(eclass.nodes) trimmed_nodes -= length(eclass.nodes) @@ -406,7 +398,7 @@ function rebuild_classes!(g::EGraph) for v in values(g.classes_by_op) sort!(v) - unique!(v) # TODO: _groupedunique!(itr), and implement isless(a::VecExpr, b::VecExpr) + unique!(v) # TODO: _groupedunique!(itr), and implement isless(a::VecExpr, b::VecExpr) if it has a performance advantage end trimmed_nodes end @@ -414,28 +406,24 @@ end function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {ExpressionType,AnalysisType} n_unions = 0 + # This is close to the pseudo-code in the egg paper. + # We separate the worklist into two lists for repair and update of semantic analysis values. + # The upwards update of semantic analysis values may require visiting fewer eclasses. while !isempty(g.pending) || !isempty(g.analysis_pending) - # while !isempty(g.pending) - # TODO: is it useful to deduplicate here? check perf - todo = collect(unique(id -> find(g, id), g.pending)) - @debug "Worklist reduced from $(length(g.pending)) to $(length(todo)) entries." - empty!(g.pending) - - for id in todo - n_unions += repair_parents!(g, id) - end - #end - - #while !isempty(g.analysis_pending) - # TODO: is it useful to deduplicate here? check perf - todo = collect(unique(id -> find(g, id), g.analysis_pending)) - @debug "Analysis worklist reduced from $(length(g.analysis_pending)) to $(length(todo)) entries." 
- empty!(g.analysis_pending) + todo = collect(unique(id -> find(g, id), g.pending)) + @debug "Worklist reduced from $(length(g.pending)) to $(length(todo)) entries." + empty!(g.pending) + + for id in todo + n_unions += repair_parents!(g, id) + end - for id in todo - update_analysis_upwards!(g, id) - end - #end + todo = collect(unique(id -> find(g, id), g.analysis_pending)) + @debug "Analysis worklist reduced from $(length(g.analysis_pending)) to $(length(todo)) entries." + empty!(g.analysis_pending) + for id in todo + update_analysis_upwards!(g, id) + end end n_unions end @@ -444,41 +432,24 @@ function repair_parents!(g::EGraph, id::Id) n_unions = 0 eclass = g[id] # id does not have to be an eclass id anymore if we merged classes below for (p_node, _) in eclass.parents - # @assert haskey(g.memo, p_node) "eclass: $(Int(id))\n parent: $p_node => $p_eclass \n$g" - memo_class = pop!(g.memo, p_node, 0) # TODO: could we be messy instead and just canonicalize the node and add again (without pop!)? + memo_class = pop!(g.memo, p_node, 0) + # memo_class = get(g.memo, p_node, 0) # TODO: could we be messy instead and just canonicalize the node and add again (without pop!)? + # only canonicalize node and update in memo if the node still exists if memo_class > 0 canonicalize!(g, p_node) - memo_class = find(g, memo_class) - # @show "new",p_node,memo_class - g.memo[p_node] = memo_class + g.memo[p_node] = find(g, memo_class) end - # merge is done below - # # if duplicate enodes occur after canonicalization we detect this here and union the eclasses - # if memo_class != p_eclass - # did_something = union!(g, memo_class, p_eclass) - # # TODO unique! can node dedup be moved here? compare performance - # # did_something && unique!(g[eclass_id].nodes) - # n_unions += did_something - # end end - # TODO: sort first? 
- # unique!(pair -> pair[1], eclass.parents) - - # sort and delete duplicate nodes last to first + # sort and collect unique nodes in new parents list (merging eclasses when finding duplicate nodes) if !isempty(eclass.parents) new_parents = Vector{Pair{VecExpr,Id}}() sort!(eclass.parents, by=pair->pair[1]) (prev_node, prev_id) = first(eclass.parents) - if prev_id != find(g, prev_id) - n_unions += 1 - union!(g, prev_id, find(g, prev_id)) - end - - prev_id = find(g, prev_id) - push!(new_parents, prev_node => prev_id) + # TODO double check whether we need canonical eclass ids in parents list (find is called in rebuild above anyway) + push!(new_parents, prev_node => find(g, prev_id)) for i in Iterators.drop(eachindex(eclass.parents), 1) (cur_node, cur_id) = eclass.parents[i] @@ -488,8 +459,7 @@ function repair_parents!(g::EGraph, id::Id) n_unions += 1 end else - cur_id = find(g, cur_id) - push!(new_parents, cur_node => cur_id) + push!(new_parents, cur_node => find(g, cur_id)) # find not necessary? prev_node, prev_id = cur_node, cur_id end end @@ -502,6 +472,7 @@ function repair_parents!(g::EGraph, id::Id) end n_unions end + function update_analysis_upwards!(g::EGraph, id::Id) for (p_node, p_id) in g[id].parents p_id = find(g, p_id) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index ec6e9413..59812277 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -109,7 +109,7 @@ function eqsat_search!( # if n_matches - prev_matches > 2 && rule_idx == 2 # @debug buffer_readable(g, old_len) # end - inform!(scheduler, rule_idx, n_matches - prev_matches) + inform!(scheduler, rule_idx, n_matches) # TODO - prev_matches end end @@ -327,7 +327,7 @@ function saturate!(g::EGraph, theory::Theory, params = SaturationParams()) curr_iter += 1 @debug "================ EQSAT ITERATION $curr_iter ================" - @debug g + # @debug g report = eqsat_step!(g, theory, curr_iter, sched, params, report, ematch_buffer)