Skip to content

Commit 1d6db7a

Browse files
committed
supporting forward and backward difference in FD
1 parent 7c94442 commit 1d6db7a

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/greeks/greeks_problem.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@ export ForwardAD, FiniteDifference, GreekProblem, SecondOrderGreekProblem
44
# Method types
55
abstract type GreekMethod end
66

7+
abstract type FDScheme end
8+
9+
struct FDForward <: FDScheme end
10+
struct FDBackward <: FDScheme end
11+
struct FDCentral <: FDScheme end
12+
713
struct ForwardAD <: GreekMethod end
8-
struct FiniteDifference <: GreekMethod
9-
bump::Float64
14+
struct FiniteDifference{S<:FDScheme} <: GreekMethod
15+
bump
16+
scheme::S
1017
end
1118

19+
FiniteDifference(bump) = FiniteDifference(bump, CentralFiniteDifference())
20+
1221
# First-order GreekProblem
1322
struct GreekProblem{P, L}
1423
pricing_problem::P
@@ -26,19 +35,38 @@ function solve(gprob::GreekProblem, ::ForwardAD, pricing_method::P) where P<:Abs
2635
return (greek = deriv,)
2736
end
2837

29-
function solve(gprob::GreekProblem, method::FiniteDifference, pricing_method::P) where P<:AbstractPricingMethod
30-
prob = gprob.pricing_problem
31-
lens = gprob.wrt
32-
ε = method.bump
33-
38+
function compute_fd_derivative(::ForwardFiniteDifference, prob, lens, ε, pricing_method)
3439
x₀ = lens(prob)
3540
prob_up = set(prob, lens, x₀ + ε)
41+
v_up = solve(prob_up, pricing_method).price
42+
v₀ = solve(prob, pricing_method).price
43+
return (v_up - v₀) / ε
44+
end
45+
46+
function compute_fd_derivative(::BackwardFiniteDifference, prob, lens, ε, pricing_method)
47+
x₀ = lens(prob)
3648
prob_down = set(prob, lens, x₀ - ε)
49+
v_down = solve(prob_down, pricing_method).price
50+
v₀ = solve(prob, pricing_method).price
51+
return (v₀ - v_down) / ε
52+
end
3753

54+
function compute_fd_derivative(::CentralFiniteDifference, prob, lens, ε, pricing_method)
55+
x₀ = lens(prob)
56+
prob_up = set(prob, lens, x₀ + ε)
57+
prob_down = set(prob, lens, x₀ - ε)
3858
v_up = solve(prob_up, pricing_method).price
3959
v_down = solve(prob_down, pricing_method).price
60+
return (v_up - v_down) / (2ε)
61+
end
62+
63+
function solve(gprob::GreekProblem, method::FiniteDifference{S}, pricing_method::P) where {S<:FiniteDifferenceScheme, P<:AbstractPricingMethod}
64+
prob = gprob.pricing_problem
65+
lens = gprob.wrt
66+
ε = method.bump
67+
scheme = method.scheme
4068

41-
deriv = (v_up - v_down) / (2ε)
69+
deriv = compute_fd_derivative(scheme, prob, lens, ε, pricing_method)
4270
return (greek = deriv,)
4371
end
4472

0 commit comments

Comments
 (0)