from rtd.planner.trajectory import Trajectory
from rtd.entity.states import EntityState
from rtd.functional.sequences import toSequence
import numpy as np
[docs]class BadTrajectoryException(Exception):
def __init__(self, message: str = ""):
super().__init__(message)
[docs]class TrajectoryContainer:
[docs] def __init__(self):
# vector of start times for each trajectory
# last element is always inf to ensure the last trajectory is always used
self._startTimes: list[float] = np.array([np.inf], np.double)
# list of trajectories corresponding to the start times
self._trajectories: list[Trajectory] = np.array([], dtype=Trajectory)
[docs] def setInitialTrajectory(self, initialTrajectory: Trajectory):
'''
Sets the initial trajectory for the container.
This method must be called before any other method is called.
Parameters
----------
initialTrajectory : Trajectory
the initial trajectory to add
'''
if initialTrajectory.startTime != 0:
raise BadTrajectoryException("Provided initial trajectory does not start at 0!")
if not initialTrajectory.validate():
raise BadTrajectoryException("Provided initial trajectory is invalid!")
if not self.isValid(False):
self._startTimes = np.array([initialTrajectory.startTime, np.inf], dtype=np.double)
self._trajectories = np.array([initialTrajectory], dtype=Trajectory)
else:
self._startTimes[0] = initialTrajectory.startTime
self._trajectories[0] = initialTrajectory
[docs] def clear(self):
'''
Clears the container and resets it to the initial state
It's expected that there's already some initial trajectory set.
If not, a warning is thrown.
'''
if self.isValid():
self._startTimes = np.array([0, np.inf], np.double)
self._trajectories = np.array([self._trajectories[0]], dtype=Trajectory)
else:
print("Warning: clear() for TrajectoryContainer was called before valid initial trajectory was set!")
[docs] def isValid(self, errorIfInvalid: bool = False) -> bool:
'''
Checks if the container is valid.
Parameters
----------
errorIfInvalid : bool
whether to raise an error when the container is invalid
Returns
-------
valid : bool
whether the container is valid or not
'''
valid = len(self._trajectories)>=1 and len(self._startTimes)==1+len(self._trajectories)
if not valid and errorIfInvalid:
raise BadTrajectoryException("Initial trajectory for the container has not been set!")
return valid
[docs] def setTrajectory(self, trajectory: Trajectory, errorIfInvalid: bool = False):
'''
Sets a new trajectory for the container to the end.
The new trajectory must start at a time greater than equal to
the end of the last trajectory.
Parameters
----------
trajectory : Trajectory
the trajectory to add
errorIfInvalid : bool
whether to raise an error when the trajectory is invalid
'''
self.isValid(True)
# add the trajectory if it is valid
if trajectory.validate() and trajectory.startTime>=self._startTimes[-2]:
self._startTimes[-1] = trajectory.startTime
np.append(self._startTimes, np.inf)
np.append(self._trajectories, trajectory)
elif errorIfInvalid:
raise BadTrajectoryException("Provided trajectory starts before the end of the last trajectory!")
else:
print("Warning: Invalid trajectory provided to TrajectoryContainer")
[docs] def getCommand(self, time: float | list[float]) -> list[EntityState]:
'''
Generates a command based on the time.
The command is generated based on the trajectory that is active
at the time. If the time is before the start of the first trajectory,
then the command is generated based on the initial trajectory. If the
time is after the last trajectory, then the command is generated based
on the last trajectory.
Parameters
----------
time : float | list[float]
the time(s) to generate the command for
Returns
-------
commands : NDArray[EntityState]
the generated commands
'''
time = np.array(toSequence(time))
self.isValid(True)
# generate an output trajectory based on the provided time
ncommands = len(time)
commands = np.empty(ncommands, dtype=EntityState)
commands[-1] = self._trajectories[0].getCommand(np.array([0]))
for i in range(self._startTimes.size-1):
mask = np.logical_and(time>=self._startTimes[i], time<=self._startTimes[i+1])
if np.sum(mask) == 0:
continue
elif self._trajectories[i].vectorized:
commands[mask] = self._trajectories[i].getCommand(time[mask])
else:
commands[mask] = (self._trajectories[j].getCommand(time[j]) for j in np.argwhere(mask).T[0])
return commands