Skip to content

Commit 7fc2d1c

Browse files
committed
Add equalization and extrapolation to interpolated curves
1 parent a508cda commit 7fc2d1c

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

src/classy_blocks/construct/curves/interpolated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class InterpolatedCurveBase(FunctionCurveBase, abc.ABC):
2525

2626
_interpolator: Type[InterpolatorBase]
2727

28-
def __init__(self, points: PointListType):
28+
def __init__(self, points: PointListType, extrapolate: bool = False, equalize: bool = True):
2929
self.array = Array(points)
30-
self.function = self._interpolator(self.array, False)
30+
self.function = self._interpolator(self.array, extrapolate, equalize)
3131
self.bounds = (0, 1)
3232

3333
@property

src/classy_blocks/construct/curves/interpolators.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import numpy as np
44
import scipy.interpolate
5-
from numpy.typing import NDArray
65

76
from classy_blocks.construct.array import Array
8-
from classy_blocks.types import NPPointType, ParamCurveFuncType
7+
from classy_blocks.types import FloatListType, NPPointType, ParamCurveFuncType
98

109

1110
class InterpolatorBase(abc.ABC):
@@ -19,9 +18,10 @@ class InterpolatorBase(abc.ABC):
1918
def _get_function(self) -> ParamCurveFuncType:
2019
"""Returns an interpolation function from stored points"""
2120

22-
def __init__(self, points: Array, extrapolate: bool):
21+
def __init__(self, points: Array, extrapolate: bool, equalize: bool = True):
2322
self.points = points
2423
self.extrapolate = extrapolate
24+
self.equalize = equalize
2525

2626
self.function = self._get_function()
2727
self._valid = True
@@ -37,7 +37,16 @@ def invalidate(self) -> None:
3737
self._valid = False
3838

3939
@property
40-
def params(self) -> NDArray:
40+
def params(self) -> FloatListType:
41+
"""A list of parameters for the interpolation curve.
42+
If not equalized, it's just linearly spaced floats;
43+
if equalized, scaled distances between provided points are taken so that
44+
evenly spaced parameters will produce evenly spaced points even if
45+
interpolation points are unequally spaced."""
46+
if self.equalize:
47+
lengths = np.cumsum(np.sqrt(np.sum((self.points[:-1] - self.points[1:]) ** 2, axis=1)))
48+
return np.concatenate(([0], lengths / lengths[-1]))
49+
4150
return np.linspace(0, 1, num=len(self.points))
4251

4352

tests/test_construct/test_curves/test_interpolated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_param_at_length(self):
6969
self.assertAlmostEqual(self.curve.get_param_at_length(1), 0.25)
7070

7171
def test_shear(self):
72-
curve = self.curve
72+
curve = LinearInterpolatedCurve(self.points, equalize=False)
7373

7474
curve.shear([0, 1, 0], [0, 0, 0], [1, 0, 0], np.pi / 4)
7575

0 commit comments

Comments
 (0)