Skip to content

Commit c2f7bf1

Browse files
authored
Merge pull request #3 from aleCombi/Greek_Gradients
Greek gradients
2 parents ab0a04a + 17a0e35 commit c2f7bf1

File tree

13 files changed

+450
-86
lines changed

13 files changed

+450
-86
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2626
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2727
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2828
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
29+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[compat]
@@ -51,6 +52,7 @@ SciMLBase = "2.77.2"
5152
SpecialFunctions = "2.5.0"
5253
StaticArrays = "1.9.13"
5354
Statistics = "1.11.1"
55+
Test = "1.11.0"
5456
Zygote = "0.6.76"
5557
julia = "1.11"
5658

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
5+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
56
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
67
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
78
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
911
Hedgehog2 = "7f16798b-0e18-40de-98af-932948254698"
1012
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
1113
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -17,3 +19,4 @@ Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
1719
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
1820
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1921
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
22+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

examples/comparisons/euro.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Revise, Hedgehog2, BenchmarkTools, Dates
2+
using Accessors
3+
import Accessors: @optic
4+
using Test
5+
using DataFrames
6+
7+
include("run_model_comparison.jl")
8+
9+
# ------------------------------
10+
# Define payoff and pricing problem
11+
# ------------------------------
12+
strike = 1.0
13+
expiry = Date(2020, 1, 2)
14+
underlying = Hedgehog2.Forward()
15+
euro_payoff = VanillaOption(strike, expiry, European(), Put(), underlying)
16+
17+
reference_date = Date(2020, 1, 1)
18+
rate = 0.03
19+
spot = 1.0
20+
sigma = 1.0
21+
market_inputs = BlackScholesInputs(reference_date, rate, spot, sigma)
22+
euro_pricing_prob = PricingProblem(euro_payoff, market_inputs)
23+
24+
# ------------------------------
25+
# Define pricing methods
26+
# ------------------------------
27+
bs_method = BlackScholesAnalytic()
28+
crr_method = CoxRossRubinsteinMethod(800)
29+
30+
# ------------------------------
31+
# Define lenses for Greeks
32+
# ------------------------------
33+
vol_lens = VolLens(1, 1)
34+
spot_lens = @optic _.market_inputs.spot
35+
rate_lens = ZeroRateSpineLens(1)
36+
lenses = (spot_lens, vol_lens, rate_lens)
37+
38+
# ------------------------------
39+
# Run comparison table
40+
# ------------------------------
41+
df = run_model_comparison_table(
42+
euro_pricing_prob,
43+
[bs_method, crr_method],
44+
lenses;
45+
ad_method = ForwardAD(),
46+
fd_method = FiniteDifference(1e-4),
47+
analytic_method = AnalyticGreek(),
48+
)
49+
println(df)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
using DelimitedFiles
2+
using DataFrames
3+
using Hedgehog2
4+
5+
function run_model_comparison_table(
6+
prob::PricingProblem,
7+
models::Vector{Hedgehog2.AbstractPricingMethod},
8+
lenses::Tuple;
9+
ad_method::Hedgehog2.GreekMethod = ForwardAD(),
10+
fd_method::Hedgehog2.GreekMethod = FiniteDifference(1e-3),
11+
analytic_method::Union{Nothing, Hedgehog2.GreekMethod} = nothing,
12+
)
13+
14+
results = Dict{String, Any}()
15+
rows = Vector{NamedTuple}()
16+
17+
for model in models
18+
model_name = string(typeof(model).name.name)
19+
20+
# Price
21+
price_time = @belapsed Hedgehog2.solve($prob, $model)
22+
sol = Hedgehog2.solve(prob, model)
23+
price = sol.price
24+
25+
# Batch Greeks (AD & FD)
26+
batch_prob = BatchGreekProblem(prob, lenses)
27+
ad_time = @belapsed solve($batch_prob, $ad_method, $model)
28+
fd_time = @belapsed solve($batch_prob, $fd_method, $model)
29+
greeks_ad = solve(batch_prob, ad_method, model)
30+
greeks_fd = solve(batch_prob, fd_method, model)
31+
32+
# Attempt full AnalyticGreek batch solve
33+
greeks_analytic = Dict{Any, Union{Float64, Missing}}()
34+
if analytic_method !== nothing
35+
try
36+
greeks_full = solve(batch_prob, analytic_method, model)
37+
for lens in lenses
38+
greeks_analytic[lens] = greeks_full[lens]
39+
end
40+
catch
41+
# Fallback to individual lens-based attempts
42+
for lens in lenses
43+
try
44+
single_prob = BatchGreekProblem(prob, (lens,))
45+
val = solve(single_prob, analytic_method, model)[lens]
46+
greeks_analytic[lens] = val
47+
catch
48+
greeks_analytic[lens] = missing
49+
end
50+
end
51+
end
52+
else
53+
for lens in lenses
54+
greeks_analytic[lens] = missing
55+
end
56+
end
57+
58+
results[model_name] = (
59+
price=price,
60+
price_time=price_time,
61+
greeks_ad=greeks_ad,
62+
ad_time=ad_time,
63+
greeks_fd=greeks_fd,
64+
fd_time=fd_time,
65+
greeks_analytic=greeks_analytic,
66+
)
67+
end
68+
69+
baseline = first(models)
70+
baseline_name = string(typeof(baseline).name.name)
71+
72+
for lens in lenses
73+
for (name, data) in results
74+
ad_val = data.greeks_ad[lens]
75+
fd_val = data.greeks_fd[lens]
76+
analytic_val = data.greeks_analytic[lens]
77+
price = data.price
78+
ad_time = data.ad_time / length(lenses)
79+
fd_time = data.fd_time / length(lenses)
80+
price_time = data.price_time
81+
82+
push!(rows, (
83+
greek = string(lens),
84+
model = name,
85+
metric = "value",
86+
ad_value = ad_val,
87+
fd_value = fd_val,
88+
analytic_value = analytic_val,
89+
price = price,
90+
ad_us = ad_time * 1e6,
91+
fd_us = fd_time * 1e6,
92+
price_us = price_time * 1e6,
93+
))
94+
end
95+
96+
baseline_data = results[baseline_name]
97+
98+
for (name, data) in results
99+
if name == baseline_name
100+
continue
101+
end
102+
103+
ad_diff = data.greeks_ad[lens] - baseline_data.greeks_ad[lens]
104+
fd_diff = data.greeks_fd[lens] - baseline_data.greeks_fd[lens]
105+
price_diff = data.price - baseline_data.price
106+
ad_time_diff = (data.ad_time / length(lenses) - baseline_data.ad_time / length(lenses)) * 1e6
107+
fd_time_diff = (data.fd_time / length(lenses) - baseline_data.fd_time / length(lenses)) * 1e6
108+
price_time_diff = (data.price_time - baseline_data.price_time) * 1e6
109+
110+
a1 = data.greeks_analytic[lens]
111+
a0 = baseline_data.greeks_analytic[lens]
112+
analytic_diff = (!ismissing(a1) && !ismissing(a0)) ? a1 - a0 : missing
113+
114+
push!(rows, (
115+
greek = string(lens),
116+
model = "Δ " * name,
117+
metric = "diff",
118+
ad_value = ad_diff,
119+
fd_value = fd_diff,
120+
analytic_value = analytic_diff,
121+
price = price_diff,
122+
ad_us = ad_time_diff,
123+
fd_us = fd_time_diff,
124+
price_us = price_time_diff,
125+
))
126+
127+
add_separator!(rows)
128+
end
129+
end
130+
131+
return DataFrame(rows)
132+
end
133+
134+
function add_separator!(rows)
135+
# Add a separator row to visually split groups when displayed
136+
push!(rows, (
137+
greek = "────────────────────────",
138+
model = "",
139+
metric = "",
140+
ad_value = missing,
141+
fd_value = missing,
142+
analytic_value = missing,
143+
price = missing,
144+
ad_us = missing,
145+
fd_us = missing,
146+
price_us = missing,
147+
))
148+
end

examples/polished_examples/american_options.jl

Whitespace-only changes.

examples/polished_examples/black_scholes.jl

Lines changed: 0 additions & 58 deletions
This file was deleted.

examples/polished_examples/carr_madan.jl

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,43 +20,49 @@ method_analytic = BlackScholesAnalytic()
2020
method_carr_madan = CarrMadan(1.0, 16.0, LognormalDynamics())
2121

2222
# -- Solve for prices
23-
sol_analytic = Hedgehog2.solve(prob, method_analytic)
24-
sol_carr = Hedgehog2.solve(prob, method_carr_madan)
25-
2623
@btime Hedgehog2.solve($prob, $method_analytic)
2724
@btime Hedgehog2.solve($prob, $method_carr_madan)
2825

26+
sol_analytic = Hedgehog2.solve(prob, method_analytic)
27+
sol_carr = Hedgehog2.solve(prob, method_carr_madan)
28+
2929
println("Analytic price: ", sol_analytic.price)
3030
println("Carr-Madan price: ", sol_carr.price)
3131

32-
# --- Greeks via GreekProblem
33-
34-
# Accessors
32+
# --- Greeks via Accessors
3533
spot_lens = @optic _.market_inputs.spot
36-
sigma_lens = Hedgehog2.VolLens(1,1)
34+
sigma_lens = Hedgehog2.VolLens(1, 1)
35+
lenses = (sigma_lens, spot_lens)
3736

38-
# Methods
37+
# -- Greek methods
3938
fd = FiniteDifference(1e-3)
4039
ad = ForwardAD()
4140

4241
println("\n--- Greeks (Analytic Method) ---")
43-
delta_fd = Hedgehog2.solve(GreekProblem(prob, spot_lens), fd, method_analytic).greek
44-
vega_fd = Hedgehog2.solve(GreekProblem(prob, sigma_lens), fd, method_analytic).greek
45-
delta_ad = Hedgehog2.solve(GreekProblem(prob, spot_lens), ad, method_analytic).greek
46-
vega_ad = Hedgehog2.solve(GreekProblem(prob, sigma_lens), ad, method_analytic).greek
42+
batch_prob = BatchGreekProblem(prob, lenses)
43+
greeks_fd = solve(batch_prob, fd, method_analytic)
44+
greeks_ad = solve(batch_prob, ad, method_analytic)
4745

48-
println("FD Delta (analytic): ", delta_fd)
49-
println("AD Delta (analytic): ", delta_ad)
50-
println("FD Vega (analytic): ", vega_fd)
51-
println("AD Vega (analytic): ", vega_ad)
46+
println("FD Delta (analytic): ", greeks_fd[spot_lens])
47+
println("AD Delta (analytic): ", greeks_ad[spot_lens])
48+
println("FD Vega (analytic): ", greeks_fd[sigma_lens])
49+
println("AD Vega (analytic): ", greeks_ad[sigma_lens])
5250

5351
println("\n--- Greeks (Carr-Madan Method) ---")
54-
delta_fd_cm = Hedgehog2.solve(GreekProblem(prob, spot_lens), fd, method_carr_madan).greek
55-
vega_fd_cm = Hedgehog2.solve(GreekProblem(prob, sigma_lens), fd, method_carr_madan).greek
56-
delta_ad_cm = Hedgehog2.solve(GreekProblem(prob, spot_lens), ad, method_carr_madan).greek
57-
vega_ad_cm = Hedgehog2.solve(GreekProblem(prob, sigma_lens), ad, method_carr_madan).greek
58-
59-
println("FD Delta (Carr-Madan): ", delta_fd_cm)
60-
println("AD Delta (Carr-Madan): ", delta_ad_cm)
61-
println("FD Vega (Carr-Madan): ", vega_fd_cm)
62-
println("AD Vega (Carr-Madan): ", vega_ad_cm)
52+
greeks_fd_cm = solve(batch_prob, fd, method_carr_madan)
53+
greeks_ad_cm = solve(batch_prob, ad, method_carr_madan)
54+
55+
println("FD Delta (Carr-Madan): ", greeks_fd_cm[spot_lens])
56+
println("AD Delta (Carr-Madan): ", greeks_ad_cm[spot_lens])
57+
println("FD Vega (Carr-Madan): ", greeks_fd_cm[sigma_lens])
58+
println("AD Vega (Carr-Madan): ", greeks_ad_cm[sigma_lens])
59+
60+
# -- Benchmarks
61+
println("\n--- Benchmarking AD Greeks ---")
62+
@btime solve($batch_prob, $ad, $method_carr_madan)
63+
64+
delta_prob = GreekProblem(prob, spot_lens)
65+
vega_prob = GreekProblem(prob, sigma_lens)
66+
67+
@btime solve($delta_prob, $ad, $method_carr_madan)
68+
@btime solve($vega_prob, $ad, $method_carr_madan)

0 commit comments

Comments
 (0)