Skip to content

Interpolating spline #141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion spatialmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from spatialmath.quaternion import Quaternion, UnitQuaternion
from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion
from spatialmath.spline import BSplineSE3
from spatialmath.spline import BSplineSE3, InterpSplineSE3, SplineFit

# from spatialmath.Plucker import *
# from spatialmath import base as smb
Expand Down Expand Up @@ -45,6 +45,8 @@
"Polygon2",
"Ellipse",
"BSplineSE3",
"InterpSplineSE3",
"SplineFit"
]

try:
Expand Down
10 changes: 5 additions & 5 deletions spatialmath/base/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def __init__(self, anim: Animate, h, xs, ys, zs):
self.anim = anim

def draw(self, T):
p = T @ self.p
p = T * self.p
self.h.set_data(p[0, :], p[1, :])
self.h.set_3d_properties(p[2, :])

Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(self, anim, h):

def draw(self, T):
# import ipdb; ipdb.set_trace()
p = T @ self.p
p = T * self.p

# reshape it
p = p[0:3, :].T.reshape(3, 2, 3)
Expand Down Expand Up @@ -432,7 +432,7 @@ def __init__(self, anim, h, x, y, z):
self.anim = anim

def draw(self, T):
p = T @ self.p
p = T * self.p
# x2, y2, _ = proj3d.proj_transform(
# p[0], p[1], p[2], self.anim.ax.get_proj())
# self.h.set_position((x2, y2))
Expand Down Expand Up @@ -759,7 +759,7 @@ def __init__(self, anim, h, xs, ys):
self.anim = anim

def draw(self, T):
p = T @ self.p
p = T * self.p
self.h.set_data(p[0, :], p[1, :])

def plot(self, x, y, *args, **kwargs):
Expand Down Expand Up @@ -837,7 +837,7 @@ def __init__(self, anim, h, x, y):
self.anim = anim

def draw(self, T):
p = T @ self.p
p = T * self.p
# x2, y2, _ = proj3d.proj_transform(
# p[0], p[1], p[2], self.anim.ax.get_proj())
# self.h.set_position((x2, y2))
Expand Down
223 changes: 202 additions & 21 deletions spatialmath/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# MIT Licence, see details in top-level file: LICENCE

"""
Classes for parameterizing a trajectory in SE3 with B-splines.

Copies parts of the API from scipy's B-spline class.
Classes for parameterizing a trajectory in SE3 with splines.
"""

from typing import Any, Dict, List, Optional
Expand All @@ -14,6 +12,182 @@
import matplotlib.pyplot as plt
from spatialmath.base.transforms3d import tranimate, trplot

from typing import Any, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.interpolate import CubicSpline
from scipy.spatial.transform import Rotation, RotationSpline
from spatialmath import SE3, SO3, Twist3
from spatialmath.base.transforms3d import tranimate


class InterpSplineSE3:
"""Class for an interpolated trajectory in SE3, as a function of time, through waypoints with a cubic spline.

A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
under the hood.
"""

def __init__(
self,
timepoints: list[float] | npt.NDArray,
waypoints: list[SE3],
*,
normalize_time: bool = True,
bc_type: str | tuple = "not-a-knot", # not-a-knot is scipy default; None is invalid
) -> None:
"""Construct a InterpSplineSE3 object

Extends the scipy CubicSpline object
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline

Args :
timepoints : list of times corresponding to provided poses
waypoints : list of SE3 objects that govern the shape of the spline.
normalize_time : flag to map times into the range [0, 1]
bc_type : boundary condition provided to scipy CubicSpline backend.
string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
For tuple options and details see the scipy docs link above.
"""

self.waypoints = waypoints
self.timepoints = np.array(timepoints)

if normalize_time:
self.timepoints = self.timepoints - self.timepoints[0]
self.timepoints = self.timepoints / self.timepoints[-1]

self.spline_xyz = CubicSpline(
self.timepoints,
np.array([pose.t for pose in self.waypoints]),
bc_type=bc_type
)
self.spline_so3 = RotationSpline(self.timepoints, Rotation.from_matrix(np.array([(pose.R) for pose in self.waypoints])))

def __call__(self, t: float) -> Any:

return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix())

def derivative(self, t: float) -> Twist3:
linear_vel = self.spline_xyz.derivative()(t)
angular_vel = self.spline_so3(t, 1)
return Twist3(linear_vel, angular_vel)

def visualize(
self,
num_samples: int,
pose_marker_length: float = 0.2,
animate: bool = False,
ax: plt.Axes | None = None,
input_poses: List[SE3] | None = None
) -> None:
"""Displays an animation of the trajectory with the control poses."""
if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

samples = [self(t) for t in np.linspace(0, 1, num_samples)]
if not animate:
x = [pose.x for pose in samples]
y = [pose.y for pose in samples]
z = [pose.z for pose in samples]
ax.plot(x, y, z, "c", linewidth=1.0) # plot spline fit

x = [pose.x for pose in self.waypoints]
y = [pose.y for pose in self.waypoints]
z = [pose.z for pose in self.waypoints]
ax.plot(x, y, z, "r*") # plot waypoints

if input_poses is not None:
x = [pose.x for pose in input_poses]
y = [pose.y for pose in input_poses]
z = [pose.z for pose in input_poses]
ax.plot(x, y, z, "go", fillstyle="none") # plot compare to input poses

if animate:
tranimate(samples, repeat=True, length=pose_marker_length, wait=True) # animate pose along trajectory
else:
plt.show()


class SplineFit:

def __init__(
self,
time_data: npt.NDArray,
pose_data: npt.NDArray,
) -> None:
self.time_data = time_data
self.pose_data = pose_data

self.xyz_data = np.array([pose.t for pose in pose_data])
self.so3_data = Rotation.from_matrix(np.array([(pose.R) for pose in pose_data]))

self.spline: InterpSplineSE3 | BSplineSE3 | None = None

def downsampled_interpolation(
self,
epsilon_xyz: float = 1e-3,
epsilon_angle: float = 1e-1,
normalize_time: bool = True,
bc_type: str | tuple = "not-a-knot",
) -> Tuple[InterpSplineSE3, List[int]]:
"""
Return:
downsampled interpolating spline,
list of removed indices from input data
"""
spline = InterpSplineSE3(
self.time_data,
self.pose_data,
normalize_time = normalize_time,
bc_type=bc_type

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent formatting; do we not have a way of auto checking that on this repo?

)
chosen_indices: set[int] = set()
interpolation_indices = list(range(len(self.pose_data)))

for _ in range(len(self.time_data) - 2): # you must have at least 2 indices

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May as well use len(interpolation_indices) and drop the - 2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the length of interpolation indices is changing as the loop executes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your suggestion might work anyway, but that's why i didn't code it like that originally

choices = list(set(interpolation_indices).difference(chosen_indices))

index = np.random.choice(choices)

chosen_indices.add(index)
interpolation_indices.remove(index)

spline.spline_xyz = CubicSpline(self.time_data[interpolation_indices], self.xyz_data[interpolation_indices])
spline.spline_so3 = RotationSpline(
self.time_data[interpolation_indices], self.so3_data[interpolation_indices]
)

time = self.time_data[index]
angular_error = SO3(self.pose_data[index]).angdist(SO3(spline.spline_so3(time).as_matrix()))
euclidean_error = np.linalg.norm(self.pose_data[index].t - spline.spline_xyz(time))
if angular_error > epsilon_angle or euclidean_error > epsilon_xyz:
interpolation_indices.insert(int(np.searchsorted(interpolation_indices, index, side="right")), index)

self.spline = spline
return spline, interpolation_indices

def max_angular_error(self) -> float:
return np.max(self.angular_errors())

def angular_errors(self) -> list[float]:
return [
pose.angdist(self.spline(t))
for pose, t in zip(self.waypoints, self.timepoints, strict=True)
]

def max_euclidean_error(self) -> float:
return np.max(self.euclidean_errors())

def euclidean_errors(self) -> List[float]:
return [
np.linalg.norm(pose.t - self.spline(t).t)
for pose, t in zip(self.waypoints, self.timepoints, strict=True)
]


class BSplineSE3:
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
Expand Down Expand Up @@ -78,28 +252,35 @@ def __call__(self, t: float) -> SE3:
def visualize(
self,
num_samples: int,
length: float = 1.0,
repeat: bool = False,
ax: Optional[plt.Axes] = None,
kwargs_trplot: Dict[str, Any] = {"color": "green"},
kwargs_tranimate: Dict[str, Any] = {"wait": True},
kwargs_plot: Dict[str, Any] = {},
pose_marker_length: float = 0.2,
animate: bool = False,
ax: plt.Axes | None = None,
input_poses: List[SE3] | None = None
) -> None:
"""Displays an animation of the trajectory with the control poses."""
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)]
x = [pose.x for pose in out_poses]
y = [pose.y for pose in out_poses]
z = [pose.z for pose in out_poses]

if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

trplot(
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot
) # plot control points
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory
samples = [self(t) for t in np.linspace(0, 1, num_samples)]
if not animate:
x = [pose.x for pose in samples]
y = [pose.y for pose in samples]
z = [pose.z for pose in samples]
ax.plot(x, y, z, "c", linewidth=1.0) # plot spline fit

x = [pose.x for pose in self.control_poses]
y = [pose.y for pose in self.control_poses]
z = [pose.z for pose in self.control_poses]
ax.plot(x, y, z, "r*") # plot waypoints

if input_poses is not None:
x = [pose.x for pose in input_poses]
y = [pose.y for pose in input_poses]
z = [pose.z for pose in input_poses]
ax.plot(x, y, z, "go", fillstyle="none") # plot compare to input poses

tranimate(
out_poses, repeat=repeat, length=length, **kwargs_tranimate
) # animate pose along trajectory
if animate:
tranimate(samples, repeat=True, length=pose_marker_length, wait=True) # animate pose along trajectory
else:
plt.show()
55 changes: 53 additions & 2 deletions tests/test_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import pytest

from spatialmath import BSplineSE3, SE3
from spatialmath import BSplineSE3, SE3, InterpSplineSE3, SplineFit, SO3


class TestBSplineSE3(unittest.TestCase):
Expand All @@ -29,4 +29,55 @@ def test_evaluation(self):

def test_visualize(self):
spline = BSplineSE3(self.control_poses)
spline.visualize(num_samples=100, repeat=False)
spline.visualize(num_samples=100, animate=True)

class TestInterpSplineSE3:
waypoints = [
SE3.Trans([e, 2 * np.cos(e / 2 * np.pi), 2 * np.sin(e / 2 * np.pi)])
* SE3.Ry(e / 8 * np.pi)
for e in range(0, 8)
]
times = np.linspace(0, 10, len(waypoints))

@classmethod
def tearDownClass(cls):
plt.close("all")

def test_constructor(self):
InterpSplineSE3(self.times, self.waypoints)

def test_evaluation(self):
spline = InterpSplineSE3(self.times, self.waypoints)
nt.assert_almost_equal(spline(0).A, self.waypoints[0].A)
nt.assert_almost_equal(spline(1).A, self.waypoints[-1].A)

def test_visualize(self):
spline = InterpSplineSE3(self.times, self.waypoints)
spline.visualize(num_samples=100, animate=True)


class TestSpineFit:
num_data_points = 300
time_horizon = 5
num_viz_points = 100

# make a helix
timestamps = np.linspace(0, 1, num_data_points)
trajectory = [
SE3.Rt(
t=[t * 0.4, 0.4 * np.sin(t * 2 * np.pi * 0.5), 0.4 * np.cos(t * 2 * np.pi * 0.5)],
R=SO3.Rx(t * 2 * np.pi * 0.5),
)
for t in timestamps * time_horizon
]

fit = SplineFit(timestamps, trajectory)
spline, kept_indices = fit.downsampled_interpolation()

fraction_points_removed = 1.0 - len(kept_indices) / num_data_points
assert(fraction_points_removed > 0.2)

sample = spline(timestamps[50])
true_pose = trajectory[50]
assert( sample.angdist(true_pose) <np.deg2rad(5.0) )
assert( np.linalg.norm(sample.t - true_pose.t) < 0.1 )
Loading