Skip to content

Commit 0cac932

Browse files
authored
Merge pull request #633 from JuliaHomotopyContinuation/automatic-triangle-inequality
Automatic triangle inequality
2 parents b8588c9 + 9e1fb4c commit 0cac932

File tree

5 files changed

+82
-43
lines changed

5 files changed

+82
-43
lines changed

src/monodromy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Base.@kwdef struct MonodromyOptions{D,GA<:Union{Nothing,GroupActions}}
3737
permutations::Bool = false
3838
# unique points options
3939
distance::D = EuclideanNorm()
40-
triangle_inequality::Bool = true
40+
triangle_inequality::Union{Nothing,Bool} = nothing
4141
unique_points_atol::Union{Nothing,Float64} = nothing
4242
unique_points_rtol::Union{Nothing,Float64} = nothing
4343
#
@@ -458,7 +458,7 @@ function MonodromySolver(
458458
unique_points = UniquePoints(
459459
x₀,
460460
1;
461-
metric = options.distance,
461+
distance = options.distance,
462462
group_actions = group_actions,
463463
triangle_inequality = options.triangle_inequality,
464464
)

src/unique_points.jl

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,14 @@ Base.iterate(p::SymmetricGroup, s) = iterate(p.permutations, s)
127127
UniquePoints{T, Id, M}
128128
129129
A data structure for assessing quickly whether a point is close to an indexed point as
130-
determined by the given distances function `M`. The distance function has to be a *metric*.
131-
The indexed points are only stored by their identifiers `Id`. `triangle_inequality` should be set to `true`, if the metric satisfies the triangle inequality. Otherwise, it should be set to `false`.
130+
determined by the given distance function `M`.
131+
The indexed points are only stored by their identifiers `Id`.
132+
`triangle_inequality` should be set to `true`, if the distance function satisfies the triangle inequality.
133+
Otherwise, it should be set to `false`. If `triangle_inequality` is nothing the algorithm will try to detect whether the triangle is satisfied.
132134
133135
UniquePoints(v::AbstractVector{T}, id::Id;
134-
metric = EuclideanNorm(),
135-
triangle_inequality = true,
136+
distance = EuclideanNorm(),
137+
triangle_inequality = nothing,
136138
group_actions = nothing)
137139
138140
@@ -167,16 +169,35 @@ end
167169
function UniquePoints(
168170
v::AbstractVector,
169171
id;
170-
metric = EuclideanNorm(),
171-
triangle_inequality = true,
172+
distance = EuclideanNorm(),
173+
triangle_inequality = nothing,
172174
group_action = nothing,
173175
group_actions = isnothing(group_action) ? nothing : GroupActions(group_action),
174176
)
175177
if (group_actions isa Tuple) || (group_actions isa AbstractVector)
176178
group_actions = GroupActions(group_actions)
177179
end
178180

179-
tree = VoronoiTree(v, id; metric = metric, triangle_inequality = triangle_inequality)
181+
d = distance
182+
183+
if isnothing(triangle_inequality)
184+
if typeof(d) <: AbstractNorm
185+
triangle_inequality = true
186+
else
187+
n = length(v)
188+
v₁ = randn(ComplexF64, n)
189+
v₂ = randn(ComplexF64, n)
190+
v₃ = randn(ComplexF64, n)
191+
if d(v₁, v₂) d(v₁, v₃) + d(v₃, v₂) &&
192+
d(v₁, 4 .* v₁) d(v₁, 2 .* v₁) + d(2 .* v₁, 4 .* v₁)
193+
triangle_inequality = true
194+
else
195+
triangle_inequality = false
196+
end
197+
end
198+
end
199+
200+
tree = VoronoiTree(v, id; distance = d, triangle_inequality = triangle_inequality)
180201
UniquePoints(tree, group_actions, zeros(eltype(v), length(v)))
181202
end
182203

@@ -257,7 +278,7 @@ function add!(
257278
atol::Float64 = 1e-14,
258279
rtol::Float64 = sqrt(eps()),
259280
) where {T,Id,M,GA}
260-
n = UP.tree.metric(v, UP.zero_vec)
281+
n = UP.tree.distance(v, UP.zero_vec)
261282
rad = max(atol, rtol * n)
262283
add!(UP, v, id, rad)
263284
end
@@ -266,10 +287,10 @@ end
266287
## Multiplicities ##
267288
####################
268289
"""
269-
multiplicities(vectors; metric = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
290+
multiplicities(vectors; distance = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
270291
271292
Returns a `Vector{Vector{Int}}` `v`. Each vector `w` in 'v' contains all indices `i`,`j`
272-
such that `w[i]` and `w[j]` have `distance` at most `max(atol, rtol * metric(0,w[i]))`.
293+
such that `w[i]` and `w[j]` have `distance` at most `max(atol, rtol * distance(0,w[i]))`.
273294
The remaining `kwargs` are things that can be passed to [`UniquePoints`](@ref).
274295
275296
```julia-repl
@@ -289,19 +310,19 @@ julia> m = multiplicities(X, group_action = permutation)
289310
```
290311
"""
291312
multiplicities(v; kwargs...) = multiplicities(identity, v; kwargs...)
292-
function multiplicities(f::F, v; metric = EuclideanNorm(), kwargs...) where {F<:Function}
313+
function multiplicities(f::F, v; distance = EuclideanNorm(), kwargs...) where {F<:Function}
293314
isempty(v) && return Vector{Vector{Int}}()
294-
_multiplicities(f, v, metric; kwargs...)
315+
_multiplicities(f, v, distance; kwargs...)
295316
end
296317
function _multiplicities(
297318
f::F,
298319
V,
299-
metric;
320+
distance;
300321
atol::Float64 = 1e-14,
301322
rtol::Float64 = 1e-8,
302323
kwargs...,
303324
) where {F<:Function}
304-
unique_points = UniquePoints(f(first(V)), 1; metric = metric, kwargs...)
325+
unique_points = UniquePoints(f(first(V)), 1; distance = distance, kwargs...)
305326
mults = Dict{Int,Vector{Int}}()
306327
for (i, vᵢ) in enumerate(V)
307328
wᵢ = f(vᵢ)
@@ -317,14 +338,14 @@ function _multiplicities(
317338
collect(values(mults))
318339
end
319340
"""
320-
unique_points(vectors; metric = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
341+
unique_points(vectors; distance = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
321342
322-
Returns all elements in `vector` for which two elements have `distance` at most `max(atol, rtol * metric(0,w[i]))`.
343+
Returns all elements in `vector` for which two elements have `distance` at most `max(atol, rtol * distance(0,w[i]))`.
323344
Note that the output can depend on the order of elements in vectors.
324345
The remaining `kwargs` are things that can be passed to [`UniquePoints`](@ref).
325346
"""
326-
function unique_points(V; metric = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
327-
unique_points = UniquePoints(first(V), 1; metric = metric, kwargs...)
347+
function unique_points(V; distance = EuclideanNorm(), atol = 1e-14, rtol = 1e-8, kwargs...)
348+
unique_points = UniquePoints(first(V), 1; distance = distance, kwargs...)
328349
out = Vector{eltype(V)}()
329350
for (i, vᵢ) in enumerate(V)
330351
_, new_point = add!(unique_points, vᵢ, i; atol = atol, rtol = rtol)

src/voronoi_tree.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ Base.length(node::VTNode) = node.nentries
2929
Base.isempty(node::VTNode) = length(node) == 0
3030
capacity(node::VTNode) = length(node.children)
3131

32-
function compute_distances!(node, x, metric::M) where {M}
32+
function compute_distances!(node, x, distance::M) where {M}
3333
for j = 1:length(node)
34-
node.distances[j] = (metric(x, view(node.values, :, j)), j)
34+
node.distances[j] = (distance(x, view(node.values, :, j)), j)
3535
end
3636
node.distances
3737
end
@@ -40,7 +40,7 @@ function search_in_radius(
4040
node::VTNode{T,Id},
4141
x,
4242
tol::Real,
43-
metric::M,
43+
distance::M,
4444
triangle_inequality::Bool,
4545
) where {T,Id,M}
4646
!isempty(node) || return nothing
@@ -62,7 +62,7 @@ function search_in_radius(
6262
m₁ = m₂ = m₃ = (Inf, 1)
6363

6464
# we have a distances cache per thread
65-
distances = compute_distances!(node, x, metric)
65+
distances = compute_distances!(node, x, distance)
6666
for i = 1:n
6767
dᵢ = first(distances[i])
6868
# early exit
@@ -95,7 +95,7 @@ function search_in_radius(
9595
# Check smallest element first
9696
if isassigned(node.children, last(m₁))
9797
retid =
98-
search_in_radius(node.children[last(m₁)], x, tol, metric, triangle_inequality)
98+
search_in_radius(node.children[last(m₁)], x, tol, distance, triangle_inequality)
9999
if !isnothing(retid)
100100
# we rely on the distances for insertion, so place the smallest element first
101101
distances[1] = m₁
@@ -112,7 +112,7 @@ function search_in_radius(
112112
# Case 2) If we have a triangle in equality, we know m₂[1] - m₁[1] ≤ 2tol
113113
if isassigned(node.children, last(m₂))
114114
retid =
115-
search_in_radius(node.children[last(m₂)], x, tol, metric, triangle_inequality)
115+
search_in_radius(node.children[last(m₂)], x, tol, distance, triangle_inequality)
116116
if !isnothing(retid)
117117
# we rely on the distances for insertion, so place the smallest element first
118118
distances[1] = m₁
@@ -128,7 +128,7 @@ function search_in_radius(
128128
# Since we know also the third element, let's check it
129129
if isassigned(node.children, last(m₃))
130130
retid =
131-
search_in_radius(node.children[last(m₃)], x, tol, metric, triangle_inequality)
131+
search_in_radius(node.children[last(m₃)], x, tol, distance, triangle_inequality)
132132
if !isnothing(retid)
133133
# we rely on the distances for insertion, so place at the first place the smallest element
134134
distances[1] = m₁
@@ -146,8 +146,13 @@ function search_in_radius(
146146
dᵢ, i = distances[k]
147147
if dᵢ - m₁[1] < 2tol || !triangle_inequality
148148
if isassigned(node.children, i)
149-
retid =
150-
search_in_radius(node.children[i], x, tol, metric, triangle_inequality)
149+
retid = search_in_radius(
150+
node.children[i],
151+
x,
152+
tol,
153+
distance,
154+
triangle_inequality,
155+
)
151156
if !isnothing(retid)
152157
return retid::Id
153158
end
@@ -165,7 +170,7 @@ function _insert!(
165170
node::VTNode{T,Id},
166171
v,
167172
id::Id,
168-
metric::M;
173+
distance::M;
169174
use_distances::Bool = false,
170175
) where {T,Id,M}
171176
# if not filled so far, just add it to the current node
@@ -179,14 +184,14 @@ function _insert!(
179184
if use_distances
180185
dᵢ, minᵢ = first(node.distances)
181186
else
182-
compute_distances!(node, v, metric)
187+
compute_distances!(node, v, distance)
183188
dᵢ, minᵢ = findmin(node.distances)
184189
end
185190

186191
if !isassigned(node.children, minᵢ)
187192
node.children[minᵢ] = VTNode{T,Id}(v, id; capacity = capacity(node))
188193
else # a node already exists, so recurse
189-
_insert!(node.children[minᵢ], v, id, metric)
194+
_insert!(node.children[minᵢ], v, id, distance)
190195
end
191196

192197
nothing
@@ -206,29 +211,30 @@ end
206211
VoronoiTree(
207212
v::AbstractVector{T},
208213
id::Id;
209-
metric = EuclideanNorm(),
214+
distance = EuclideanNorm(),
210215
capacity = 8,
211216
triangle_inequality = true
212217
)
213218
214219
Construct a Voronoi tree data structure for vector `v` of element type `T` and with identifiers
215-
`Id`. Each node has the given `capacity` and distances are measured by the given `metric`. `triangle_inequality` should be set to `true`, if `metric` satisfies the triangle inequality. Otherwise, it should be set to `false`.
220+
`Id`. Each node has the given `capacity` and distances are measured by the given `distance`.
221+
`triangle_inequality` should be set to `true`, if `distance` satisfies the triangle inequality. Otherwise, it should be set to `false`.
216222
"""
217223
mutable struct VoronoiTree{T,Id,M}
218224
root::VTNode{T,Id}
219225
nentries::Int
220-
metric::M
226+
distance::M
221227
triangle_inequality::Bool
222228
end
223229

224230
function VoronoiTree{T,Id}(
225231
d::Int;
226-
metric = EuclideanNorm(),
232+
distance = EuclideanNorm(),
227233
capacity::Int = 8,
228234
triangle_inequality::Bool = true,
229235
) where {T,Id}
230236
root = VTNode(T, Id, d; capacity = capacity)
231-
VoronoiTree(root, 0, metric, triangle_inequality)
237+
VoronoiTree(root, 0, distance, triangle_inequality)
232238
end
233239

234240
function VoronoiTree(v::AbstractVector{T}, id::Id; kwargs...) where {T,Id}
@@ -254,7 +260,7 @@ function Base.insert!(
254260
id::Id;
255261
use_distances::Bool = false,
256262
) where {T,Id}
257-
_insert!(tree.root, v, id, tree.metric; use_distances = use_distances)
263+
_insert!(tree.root, v, id, tree.distance; use_distances = use_distances)
258264
tree.nentries += 1
259265
tree
260266
end
@@ -266,7 +272,7 @@ Search whether the given `tree` contains a point `p` with distances at most `tol
266272
Returns `nothing` if no point exists, otherwise the identifier of `p` is returned.
267273
"""
268274
function search_in_radius(tree::VoronoiTree, v, tol::Real)
269-
search_in_radius(tree.root, v, tol, tree.metric, tree.triangle_inequality)
275+
search_in_radius(tree.root, v, tol, tree.distance, tree.triangle_inequality)
270276
end
271277

272278

test/monodromy_test.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,22 @@
6969
monodromy_solve(F.expressions, [x₀, rand(6)], p₀, parameters = F.parameters)
7070
@test length(solutions(result)) == 21
7171

72-
# different distance function
72+
# distance function that satisfies triangle inequality
7373
result = monodromy_solve(F, x₀, p₀, distance = (x, y) -> 0.0)
7474
@test length(solutions(result)) == 1
7575

76+
# distance function that does not satisfy triangle inequality
77+
result = monodromy_solve(F, x₀, p₀, distance = (x, y) -> norm(x - y, 2)^2)
78+
@test length(solutions(result)) == 21
79+
7680
# don't use triangle inequality
7781
result = monodromy_solve(F, x₀, p₀, triangle_inequality = false)
7882
@test length(solutions(result)) == 21
7983

84+
# use triangle inequality
85+
result = monodromy_solve(F, x₀, p₀, triangle_inequality = true)
86+
@test length(solutions(result)) == 21
87+
8088
# Test stop heuristic with no target solutions count
8189
result = monodromy_solve(F, x₀, p₀)
8290
@test is_heuristic_stop(result)
@@ -440,7 +448,7 @@
440448
target_solutions_count = 305,
441449
)
442450

443-
UP = unique_points(solutions(points), metric = dist, rtol = 1e-8, atol = 1e-14)
451+
UP = unique_points(solutions(points), distance = dist, rtol = 1e-8, atol = 1e-14)
444452

445453
@test length(solutions(points)) == length(UP)
446454
end

test/unique_points_test.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ end
338338
M = multiplicities(V)
339339
@test length(M) == 0
340340

341-
N = multiplicities(W, metric = InfNorm(), atol = 1e-5)
341+
N = multiplicities(W, distance = InfNorm(), atol = 1e-5)
342342
sort!(N, by = first)
343343
@test length(N) == 10
344344
@test unique([length(m) for m in N]) == [2]
@@ -347,7 +347,11 @@ end
347347
O = multiplicities([U; U])
348348
@test length(O) == 20
349349

350-
P = multiplicities(X, metric = (x, y) -> 1 - abs(LinearAlgebra.dot(x, y)), atol = 1e-5)
350+
P = multiplicities(
351+
X,
352+
distance = (x, y) -> 1 - abs(LinearAlgebra.dot(x, y)),
353+
atol = 1e-5,
354+
)
351355
@test length(P) == 3
352356

353357
# Test with group action

0 commit comments

Comments
 (0)