Skip to content

Commit 3aec5ed

Browse files
authored
feat: add trig_to_exp(f::BasicSymbolic) (#34)
1 parent a655cf8 commit 3aec5ed

File tree

4 files changed

+77
-25
lines changed

4 files changed

+77
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuestBase"
22
uuid = "7e80f742-43d6-403d-a9ea-981410111d43"
33
authors = ["Orjan Ameye <orjan.ameye@hotmail.com>", "Jan Kosata <kosataj@phys.ethz.ch>", "Javier del Pino <jdelpino@phys.ethz.ch>"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/Symbolics/Symbolics_utils.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,20 @@ function get_independent(x::BasicSymbolic, t::Num)
8888
end
8989

9090
"Return all the terms contained in `x`"
91-
get_all_terms(x::Num) = unique(_get_all_terms(Symbolics.expand(x).val))
91+
get_all_terms(x::Num) = Num.(unique(_get_all_terms(Symbolics.expand(x).val)))
92+
get_all_terms(x::BasicSymbolic) = unique(_get_all_terms(Symbolics.expand(x)))
9293
function get_all_terms(x::Equation)
9394
return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1))
9495
end
9596
function _get_all_terms(x::BasicSymbolic)
9697
@compactified x::BasicSymbolic begin
9798
Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...)
98-
Mul => Num.(SymbolicUtils.arguments(x))
99-
Div => Num.([_get_all_terms(x.num)..., _get_all_terms(x.den)...])
100-
_ => Num(x)
99+
Mul => SymbolicUtils.arguments(x)
100+
Div => [_get_all_terms(x.num)..., _get_all_terms(x.den)...]
101+
_ => [x]
101102
end
102103
end
103-
_get_all_terms(x) = Num(x)
104+
_get_all_terms(x) = x
104105

105106
function is_harmonic(x::Num, t::Num)::Bool
106107
all_terms = get_all_terms(x)

src/Symbolics/fourier.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ function trig_reduce(x)
2929
x = simplify_exp_products(x) # simplify products of exps
3030
x = exp_to_trig(x)
3131
x = Num(simplify_complex(expand(x)))
32-
return x# simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c
32+
return x # simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c
3333
end
3434

3535
"Return true if `f` is a sin or cos."
36-
function is_trig(f::Num)
37-
f = ispow(f.val) ? f.val.base : f.val
36+
is_trig(f::Num) = is_trig(f.val)
37+
is_trig(f) = false
38+
function is_trig(f::BasicSymbolic)
39+
f = ispow(f) ? f.base : f
3840
isterm(f) && SymbolicUtils.operation(f) [cos, sin] && return true
3941
return false
4042
end
@@ -148,6 +150,35 @@ trig_to_exp(x::Complex{Num}) = trig_to_exp(x.re) + im * trig_to_exp(x.im)
148150
convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments))
149151
convert_to_Num(x::Num)::Num = x
150152

153+
"""
154+
trig_to_exp(x::BasicSymbolic)
155+
156+
Convert all trigonometric terms (sin, cos) in expression `x` to their exponential form
157+
using Euler's formula: ``\\exp(ix) = \\cos(x) + i*\\sin(x)``.
158+
"""
159+
function trig_to_exp(x::BasicSymbolic)
160+
all_terms = get_all_terms(x)
161+
trigs = filter(z -> is_trig(z), all_terms)
162+
163+
rules = []
164+
for trig in trigs
165+
is_pow = ispow(trig) # trig is either a trig or a power of a trig
166+
power = is_pow ? trig.exp : 1
167+
arg = is_pow ? arguments(trig.base)[1] : arguments(trig)[1]
168+
type = is_pow ? operation(trig.base) : operation(trig)
169+
170+
if type == cos
171+
term = (exp(im * arg) + exp(-im * arg))^power * (1 // 2)^power
172+
elseif type == sin
173+
term =
174+
(1 * im^power) * ((exp(-im * arg) - exp(im * arg)))^power * (1 // 2)^power
175+
end
176+
177+
append!(rules, [trig => term])
178+
end
179+
return Symbolics.substitute(x, Dict(rules))
180+
end
181+
151182
"""
152183
exp_to_trig(x::BasicSymbolic)
153184
exp_to_trig(x)

test/symbolics.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ end
4040
@eqtest max_power(a^2 + b, a) == 2
4141
@eqtest max_power(a * ((a + b)^4)^2 + a, a) == 9
4242
@eqtest max_power([a * ((a + b)^4)^2 + a, a^2], a) == 9
43-
@eqtest max_power(a + im*a^2, a) == 2
43+
@eqtest max_power(a + im * a^2, a) == 2
4444

4545
@eqtest drop_powers(a^2 + b, a, 1) == b
4646
@eqtest drop_powers((a + b)^2, a, 1) == b^2
@@ -52,29 +52,49 @@ end
5252
# eq = drop_powers(a^2 + a ~ b, [a, b], 2) # broken
5353
@eqtest [eq.lhs, eq.rhs] == [a, a]
5454
eq = drop_powers(a^2 + a + b ~ a, a, 2)
55-
@test string(eq.rhs) == "a" broken=true
55+
@test string(eq.rhs) == "a" broken = true
5656

5757
@eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b]
5858
@eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b]
5959
end
6060

6161
@testset "trig_to_exp and trig_to_exp" begin
6262
using QuestBase: expand_all, trig_to_exp, exp_to_trig
63-
@variables f t
64-
cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
65-
sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)
66-
67-
# automatic conversion between trig and exp form
68-
trigs = [cos(f * t), sin(f * t)]
69-
for (i, trig) in pairs(trigs)
70-
z = trig_to_exp(trig)
71-
@eqtest expand(exp_to_trig(z)) == trig
72-
end
73-
trigs′ = [cos_euler(f * t), sin_euler(f * t)]
74-
for (i, trig) in pairs(trigs′)
75-
z = trig_to_exp(trig)
76-
@eqtest expand(exp_to_trig(z)) == trigs[i]
63+
@testset "Num" begin
64+
@variables f t
65+
cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
66+
sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)
67+
68+
# automatic conversion between trig and exp form
69+
trigs = [cos(f * t), sin(f * t)]
70+
for (i, trig) in pairs(trigs)
71+
z = trig_to_exp(trig)
72+
@eqtest expand(exp_to_trig(z)) == trig
73+
end
74+
trigs′ = [cos_euler(f * t), sin_euler(f * t)]
75+
for (i, trig) in pairs(trigs′)
76+
z = trig_to_exp(trig)
77+
@eqtest expand(exp_to_trig(z)) == trigs[i]
78+
end
7779
end
80+
81+
# @testset "BasicSymbolic" begin
82+
# @syms f t
83+
# cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
84+
# sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)
85+
86+
# # automatic conversion between trig and exp form
87+
# trigs = [cos(f * t), sin(f * t)]
88+
# for (i, trig) in pairs(trigs)
89+
# z = trig_to_exp(trig)
90+
# @eqtest expand(exp_to_trig(z)) == trig
91+
# end
92+
# trigs′ = [cos_euler(f * t), sin_euler(f * t)]
93+
# for (i, trig) in pairs(trigs′)
94+
# z = trig_to_exp(trig)
95+
# @eqtest expand(exp_to_trig(z)) == trigs[i]
96+
# end
97+
# end
7898
end
7999

80100
@testset "harmonic" begin

0 commit comments

Comments
 (0)