-
Notifications
You must be signed in to change notification settings - Fork 94
Interpolating spline #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Interpolating spline #141
Changes from 6 commits
183c854
af175d3
1e0f7ed
578f4d3
bce7f58
3d99327
806726e
aa8148a
d5f2c42
ecc9b59
2c5190f
1b8fc78
c7b7be6
51f9204
9634c24
a2294ed
4e7aa35
551e473
37181dd
8dbc059
02b1f52
bf28479
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,7 @@ | |
# MIT Licence, see details in top-level file: LICENCE | ||
|
||
""" | ||
Classes for parameterizing a trajectory in SE3 with B-splines. | ||
|
||
Copies parts of the API from scipy's B-spline class. | ||
Classes for parameterizing a trajectory in SE3 with splines. | ||
""" | ||
|
||
from typing import Any, Dict, List, Optional | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -14,6 +12,182 @@ | |
import matplotlib.pyplot as plt | ||
from spatialmath.base.transforms3d import tranimate, trplot | ||
|
||
from typing import Any, List, Tuple | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import numpy.typing as npt | ||
from scipy.interpolate import CubicSpline | ||
from scipy.spatial.transform import Rotation, RotationSpline | ||
from spatialmath import SE3, SO3, Twist3 | ||
from spatialmath.base.transforms3d import tranimate | ||
|
||
|
||
class InterpSplineSE3: | ||
"""Class for an interpolated trajectory in SE3, as a function of time, through waypoints with a cubic spline. | ||
|
||
A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic) | ||
under the hood. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
timepoints: list[float] | npt.NDArray, | ||
waypoints: list[SE3], | ||
*, | ||
normalize_time: bool = True, | ||
bc_type: str | tuple = "not-a-knot", # not-a-knot is scipy default; None is invalid | ||
) -> None: | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Construct a InterpSplineSE3 object | ||
|
||
Extends the scipy CubicSpline object | ||
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline | ||
|
||
Args : | ||
timepoints : list of times corresponding to provided poses | ||
waypoints : list of SE3 objects that govern the shape of the spline. | ||
normalize_time : flag to map times into the range [0, 1] | ||
bc_type : boundary condition provided to scipy CubicSpline backend. | ||
string options: ["not-a-knot" (default), "clamped", "natural", "periodic"]. | ||
For tuple options and details see the scipy docs link above. | ||
""" | ||
|
||
self.waypoints = waypoints | ||
self.timepoints = np.array(timepoints) | ||
|
||
if normalize_time: | ||
self.timepoints = self.timepoints - self.timepoints[0] | ||
self.timepoints = self.timepoints / self.timepoints[-1] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.spline_xyz = CubicSpline( | ||
self.timepoints, | ||
np.array([pose.t for pose in self.waypoints]), | ||
bc_type=bc_type | ||
) | ||
self.spline_so3 = RotationSpline(self.timepoints, Rotation.from_matrix(np.array([(pose.R) for pose in self.waypoints]))) | ||
|
||
def __call__(self, t: float) -> Any: | ||
|
||
return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix()) | ||
|
||
def derivative(self, t: float) -> Twist3: | ||
linear_vel = self.spline_xyz.derivative()(t) | ||
angular_vel = self.spline_so3(t, 1) | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Twist3(linear_vel, angular_vel) | ||
|
||
def visualize( | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, | ||
num_samples: int, | ||
pose_marker_length: float = 0.2, | ||
animate: bool = False, | ||
ax: plt.Axes | None = None, | ||
input_poses: List[SE3] | None = None | ||
) -> None: | ||
"""Displays an animation of the trajectory with the control poses.""" | ||
if ax is None: | ||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(projection="3d") | ||
|
||
samples = [self(t) for t in np.linspace(0, 1, num_samples)] | ||
if not animate: | ||
x = [pose.x for pose in samples] | ||
y = [pose.y for pose in samples] | ||
z = [pose.z for pose in samples] | ||
ax.plot(x, y, z, "c", linewidth=1.0) # plot spline fit | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
x = [pose.x for pose in self.waypoints] | ||
y = [pose.y for pose in self.waypoints] | ||
z = [pose.z for pose in self.waypoints] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ax.plot(x, y, z, "r*") # plot waypoints | ||
|
||
if input_poses is not None: | ||
x = [pose.x for pose in input_poses] | ||
y = [pose.y for pose in input_poses] | ||
z = [pose.z for pose in input_poses] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ax.plot(x, y, z, "go", fillstyle="none") # plot compare to input poses | ||
|
||
if animate: | ||
tranimate(samples, repeat=True, length=pose_marker_length, wait=True) # animate pose along trajectory | ||
else: | ||
plt.show() | ||
|
||
|
||
class SplineFit: | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__( | ||
self, | ||
time_data: npt.NDArray, | ||
pose_data: npt.NDArray, | ||
) -> None: | ||
self.time_data = time_data | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.pose_data = pose_data | ||
|
||
self.xyz_data = np.array([pose.t for pose in pose_data]) | ||
self.so3_data = Rotation.from_matrix(np.array([(pose.R) for pose in pose_data])) | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.spline: InterpSplineSE3 | BSplineSE3 | None = None | ||
|
||
def downsampled_interpolation( | ||
self, | ||
epsilon_xyz: float = 1e-3, | ||
epsilon_angle: float = 1e-1, | ||
normalize_time: bool = True, | ||
bc_type: str | tuple = "not-a-knot", | ||
) -> Tuple[InterpSplineSE3, List[int]]: | ||
""" | ||
Return: | ||
downsampled interpolating spline, | ||
list of removed indices from input data | ||
""" | ||
spline = InterpSplineSE3( | ||
self.time_data, | ||
self.pose_data, | ||
normalize_time = normalize_time, | ||
bc_type=bc_type | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent formatting; do we not have a way of auto checking that on this repo? |
||
) | ||
chosen_indices: set[int] = set() | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
interpolation_indices = list(range(len(self.pose_data))) | ||
|
||
for _ in range(len(self.time_data) - 2): # you must have at least 2 indices | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May as well use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the length of interpolation indices is changing as the loop executes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. your suggestion might work anyway, but that's why i didn't code it like that originally |
||
choices = list(set(interpolation_indices).difference(chosen_indices)) | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
index = np.random.choice(choices) | ||
|
||
chosen_indices.add(index) | ||
interpolation_indices.remove(index) | ||
|
||
spline.spline_xyz = CubicSpline(self.time_data[interpolation_indices], self.xyz_data[interpolation_indices]) | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
spline.spline_so3 = RotationSpline( | ||
self.time_data[interpolation_indices], self.so3_data[interpolation_indices] | ||
) | ||
|
||
time = self.time_data[index] | ||
angular_error = SO3(self.pose_data[index]).angdist(SO3(spline.spline_so3(time).as_matrix())) | ||
euclidean_error = np.linalg.norm(self.pose_data[index].t - spline.spline_xyz(time)) | ||
if angular_error > epsilon_angle or euclidean_error > epsilon_xyz: | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
interpolation_indices.insert(int(np.searchsorted(interpolation_indices, index, side="right")), index) | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.spline = spline | ||
return spline, interpolation_indices | ||
|
||
def max_angular_error(self) -> float: | ||
return np.max(self.angular_errors()) | ||
|
||
def angular_errors(self) -> list[float]: | ||
return [ | ||
pose.angdist(self.spline(t)) | ||
for pose, t in zip(self.waypoints, self.timepoints, strict=True) | ||
] | ||
|
||
def max_euclidean_error(self) -> float: | ||
return np.max(self.euclidean_errors()) | ||
|
||
def euclidean_errors(self) -> List[float]: | ||
return [ | ||
np.linalg.norm(pose.t - self.spline(t).t) | ||
for pose, t in zip(self.waypoints, self.timepoints, strict=True) | ||
] | ||
|
||
|
||
class BSplineSE3: | ||
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline. | ||
|
@@ -78,28 +252,35 @@ def __call__(self, t: float) -> SE3: | |
def visualize( | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, | ||
num_samples: int, | ||
length: float = 1.0, | ||
repeat: bool = False, | ||
ax: Optional[plt.Axes] = None, | ||
kwargs_trplot: Dict[str, Any] = {"color": "green"}, | ||
kwargs_tranimate: Dict[str, Any] = {"wait": True}, | ||
kwargs_plot: Dict[str, Any] = {}, | ||
pose_marker_length: float = 0.2, | ||
animate: bool = False, | ||
ax: plt.Axes | None = None, | ||
input_poses: List[SE3] | None = None | ||
) -> None: | ||
"""Displays an animation of the trajectory with the control poses.""" | ||
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)] | ||
x = [pose.x for pose in out_poses] | ||
y = [pose.y for pose in out_poses] | ||
z = [pose.z for pose in out_poses] | ||
|
||
if ax is None: | ||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(projection="3d") | ||
|
||
trplot( | ||
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot | ||
) # plot control points | ||
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory | ||
samples = [self(t) for t in np.linspace(0, 1, num_samples)] | ||
if not animate: | ||
x = [pose.x for pose in samples] | ||
y = [pose.y for pose in samples] | ||
z = [pose.z for pose in samples] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ax.plot(x, y, z, "c", linewidth=1.0) # plot spline fit | ||
|
||
x = [pose.x for pose in self.control_poses] | ||
y = [pose.y for pose in self.control_poses] | ||
z = [pose.z for pose in self.control_poses] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ax.plot(x, y, z, "r*") # plot waypoints | ||
|
||
if input_poses is not None: | ||
x = [pose.x for pose in input_poses] | ||
y = [pose.y for pose in input_poses] | ||
z = [pose.z for pose in input_poses] | ||
myeatman-bdai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ax.plot(x, y, z, "go", fillstyle="none") # plot compare to input poses | ||
|
||
tranimate( | ||
out_poses, repeat=repeat, length=length, **kwargs_tranimate | ||
) # animate pose along trajectory | ||
if animate: | ||
tranimate(samples, repeat=True, length=pose_marker_length, wait=True) # animate pose along trajectory | ||
else: | ||
plt.show() |
Uh oh!
There was an error while loading. Please reload this page.