Skip to content

Commit c93edd4

Browse files
committed
Added what should be a direct copy of the scipy misc.derivative function that has been removed from scipy. For issue #1343 (and #1418)
1 parent d967c00 commit c93edd4

File tree

1 file changed

+89
-1
lines changed

1 file changed

+89
-1
lines changed

xga/models/__init__.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
2-
# Last modified by David J Turner (turne540@msu.edu) 29/07/2024, 21:58. Copyright (c) The Contributors
2+
# Last modified by David J Turner (turne540@msu.edu) 15/08/2025, 13:36. Copyright (c) The Contributors
33

44
import inspect
55
from types import FunctionType
@@ -89,6 +89,94 @@ def convert_to_odr_compatible(model_func: FunctionType, new_par_name: str = 'β'
8989
return new_model_func
9090

9191

92+
def derivative(func: FunctionType, x0: float, dx: float = 1.0, n: int = 1, args: tuple= (), order: int = 3):
93+
"""
94+
Find the nth derivative of a function at a point.
95+
96+
Given a function, use a central difference formula with spacing `dx` to
97+
compute the nth derivative at `x0`.
98+
99+
This is intended as a drop-in replacement for Scipy's misc.derivative function, which was deprecated in
100+
Scipy v1.10.0 and removed after Scipy v1.14.1. It has been directly copied/reconstructed from Scipy code.
101+
102+
:param FunctionType func: Input function
103+
:param x0: The point at which the nth derivative is found.
104+
:param dx: Spacing.
105+
:param n: Order of the derivative. Default is 1.
106+
:param args: Arguments
107+
:param order: Number of points to use, must be odd.
108+
"""
109+
110+
def _central_diff_weights(Np, ndiv=1):
111+
"""
112+
Return weights for an Np-point central derivative.
113+
114+
Assumes equally-spaced function points.
115+
116+
If weights are in the vector w, then
117+
derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+h0*dx)
118+
"""
119+
120+
if Np < ndiv + 1:
121+
raise ValueError(
122+
"Number of points must be at least the derivative order + 1."
123+
)
124+
if Np % 2 == 0:
125+
raise ValueError("The number of points must be odd.")
126+
127+
ho = Np >> 1
128+
x = np.arange(-ho, ho + 1.0)
129+
x = x[:, np.newaxis]
130+
X = x ** 0.0
131+
for k in range(1, Np):
132+
X = np.hstack([X, x ** k])
133+
w = np.prod(np.arange(1, ndiv + 1), axis=0) * np.linalg.inv(X)[ndiv]
134+
return w
135+
136+
if order < n + 1:
137+
raise ValueError(
138+
"'order' (the number of points used to compute the derivative), "
139+
"must be at least the derivative order 'n' + 1."
140+
)
141+
if order % 2 == 0:
142+
raise ValueError(
143+
"'order' (the number of points used to compute the derivative) "
144+
"must be odd."
145+
)
146+
# pre-computed for n=1 and 2 and low-order for speed.
147+
if n == 1:
148+
if order == 3:
149+
weights = np.array([-1, 0, 1]) / 2.0
150+
elif order == 5:
151+
weights = np.array([1, -8, 0, 8, -1]) / 12.0
152+
elif order == 7:
153+
weights = np.array([-1, 9, -45, 0, 45, -9, 1]) / 60.0
154+
elif order == 9:
155+
weights = np.array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0
156+
else:
157+
weights = _central_diff_weights(order, 1)
158+
elif n == 2:
159+
if order == 3:
160+
weights = np.array([1, -2.0, 1])
161+
elif order == 5:
162+
weights = np.array([-1, 16, -30, 16, -1]) / 12.0
163+
elif order == 7:
164+
weights = np.array([2, -27, 270, -490, 270, -27, 2]) / 180.0
165+
elif order == 9:
166+
weights = (
167+
np.array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9])
168+
/ 5040.0
169+
)
170+
else:
171+
weights = _central_diff_weights(order, 2)
172+
else:
173+
weights = _central_diff_weights(order, n)
174+
val = 0.0
175+
ho = order >> 1
176+
for k in range(order):
177+
val += weights[k] * func(x0 + (k - ho) * dx, *args)
178+
return val / np.prod((dx,) * n, axis=0)
179+
92180

93181

94182

0 commit comments

Comments
 (0)