from rtd.planner.trajectory import Trajectory, InvalidTrajectory
from rtd.entity.states import ArmRobotState, EntityState
from rtd.planner.trajopt import TrajOptProps
from rtd.functional.vectools import rescale
from armour.reachsets import JRSInstance
from armour.legacy import bernstein_to_poly, match_deg5_bernstein_coefficients
import numpy as np
from rtd.util.mixins.Typings import Vecnp
class BernsteinArmTrajectory(Trajectory):
    '''
    Arm trajectory stored as one degree-5 polynomial per joint.

    The per-joint coefficients are obtained from
    `match_deg5_bernstein_coefficients` and converted to a power basis with
    `bernstein_to_poly`. The polynomial is evaluated on the normalized time
    `(t - startState.time) / trajOptProps.horizonTime`; at and after the
    horizon the commanded position is held at `q_end` with zero velocity
    and acceleration.
    '''

    def __init__(self, trajOptProps: TrajOptProps, startState: ArmRobotState, jrsInstance: JRSInstance):
        '''
        Store the planner properties, start state and JRS instance.

        The trajectory is not usable until `setParameters` has been called
        with a parameter vector.
        '''
        # initialize base classes
        Trajectory.__init__(self)
        # set properties
        self.vectorized = True
        # Power-basis coefficients, shape (n_q, 6), and the final joint
        # positions. Both are filled in by `internalUpdate`.
        self.alpha = None
        self.q_end = None
        # other properties
        self.trajOptProps = trajOptProps
        self.startState = startState
        # The JRS which contains the center and range to scale the
        # parameters
        self.jrsInstance = jrsInstance

    def setParameters(self, trajectoryParams: Vecnp, startState: ArmRobotState = None,
                      jrsInstance: JRSInstance = None):
        '''
        A validated method to set the parameters for the trajectory.

        Parameters
        ----------
        trajectoryParams
            Parameter vector; entries beyond the first `n_q` are dropped.
        startState
            Optional replacement for the stored start state.
        jrsInstance
            Optional replacement for the stored JRS instance.
        '''
        # Update the state and the JRS first so the truncation below uses
        # the joint count of the JRS that will actually be in effect.
        if startState is not None:
            self.startState = startState
        if jrsInstance is not None:
            self.jrsInstance = jrsInstance

        self.trajectoryParams = trajectoryParams
        if (self.trajectoryParams is not None
                and self.jrsInstance is not None
                and self.trajectoryParams.size > self.jrsInstance.n_q):
            self.trajectoryParams = self.trajectoryParams[:self.jrsInstance.n_q]

        # perform internal update
        self.internalUpdate()

    def validate(self, throwOnError: bool = False) -> bool:
        '''
        Validate that the trajectory is fully characterized.

        Parameters
        ----------
        throwOnError
            When True, raise instead of returning False.

        Returns
        -------
        bool
            True when parameters, JRS instance and start state are all set
            and the parameter count equals `jrsInstance.n_q`.

        Raises
        ------
        InvalidTrajectory
            If the trajectory is invalid and `throwOnError` is True.
        '''
        # trajectoryParams is not assigned in this class's __init__, so
        # tolerate it being absent as well as None.
        params = getattr(self, 'trajectoryParams', None)
        jrs = self.jrsInstance

        # non-empty
        valid = (params is not None
                 and jrs is not None
                 and self.startState is not None)
        # trajectory params makes sense (only checked once both exist, so a
        # missing piece reports as invalid rather than an AttributeError)
        valid = bool(valid and params.size == jrs.n_q)

        # throw error if wanted
        if not valid and throwOnError:
            raise InvalidTrajectory("Called trajectory object does not have complete parameterization!")
        return valid

    def internalUpdate(self):
        '''
        Update internal parameters to reduce long term calculations.

        Recomputes `alpha` and `q_end` from the current parameters. Does
        nothing when the trajectory does not validate.
        '''
        # internal update if valid
        if not self.validate():
            return

        # get the desired final position
        # NOTE(review): presumably `rescale` maps the parameters from the
        # JRS input range into its output range -- confirm against
        # rtd.functional.vectools.rescale.
        jout = self.jrsInstance.output_range
        jin = self.jrsInstance.input_range
        q_goal = rescale(self.trajectoryParams, jout[0], jout[1], jin[0], jin[1])
        q_goal = self.startState.q + q_goal

        n_q = self.jrsInstance.n_q
        self.alpha = np.zeros((n_q, 6))
        for j in range(n_q):
            # NOTE(review): the six entries look like start position,
            # velocity, acceleration followed by end position and zero end
            # velocity/acceleration -- confirm against armour.legacy.
            beta = match_deg5_bernstein_coefficients(
                [
                    self.startState.position[j],
                    self.startState.velocity[j],
                    self.startState.acceleration[j],
                    q_goal[j],
                    0,
                    0,
                ],
                self.trajOptProps.horizonTime,
            )
            self.alpha[j, :] = bernstein_to_poly(beta, 6)

        # precompute end position
        self.q_end = q_goal

    def getCommand(self, time: Vecnp) -> EntityState:
        '''
        Compute the desired joint state at the given time(s).

        Parameters
        ----------
        time
            Scalar or array of times; none may precede `startState.time`.

        Returns
        -------
        EntityState
            An `ArmRobotState` whose `state` has shape (3 * n_q, n_times)
            with position, velocity and acceleration rows stacked in that
            order, and whose `time` is the `time` argument as given.

        Raises
        ------
        InvalidTrajectory
            If the trajectory is not fully parameterized or any requested
            time is before the start time.
        '''
        # Do a parameter check and time check, and throw if anything is
        # invalid.
        self.validate(throwOnError=True)
        t_shifted = np.atleast_1d(np.asarray(time - self.startState.time))
        if np.any(t_shifted < 0):
            raise InvalidTrajectory("Invalid time provided to BernsteinArmTrajectory")

        # Parameters may have been assigned without going through
        # setParameters; make sure the cached coefficients exist.
        if self.alpha is None or self.q_end is None:
            self.internalUpdate()

        horizon = self.trajOptProps.horizonTime
        n_q = self.jrsInstance.n_q

        # Only times strictly inside the horizon are evaluated on the
        # polynomial, using time normalized by the horizon.
        in_horizon = t_shifted < horizon
        tau = t_shifted[in_horizon] / horizon

        q_des = np.zeros((n_q, tau.size))
        q_dot_des = np.zeros((n_q, tau.size))
        q_ddot_des = np.zeros((n_q, tau.size))
        # Accumulate one power of tau at a time, broadcasting over joints.
        for k in range(6):
            coeff = self.alpha[:, k:k + 1]  # column of shape (n_q, 1)
            q_des += coeff * np.power(tau, k)
            if k > 0:
                q_dot_des += (k * coeff) * np.power(tau, k - 1)
            if k > 1:
                q_ddot_des += (k * (k - 1) * coeff) * np.power(tau, k - 2)

        # move to a combined state variable
        pos_idx = np.arange(n_q)
        vel_idx = pos_idx + n_q
        acc_idx = vel_idx + n_q
        state = np.zeros((n_q * 3, t_shifted.size))
        state[np.ix_(pos_idx, in_horizon)] = q_des
        # Derivatives were taken with respect to normalized time, so scale
        # them back to real time.
        state[np.ix_(vel_idx, in_horizon)] = q_dot_des / horizon
        state[np.ix_(acc_idx, in_horizon)] = q_ddot_des / horizon**2

        # update all state times after the horizon time: hold the end
        # position; velocity and acceleration rows stay zero
        state[np.ix_(pos_idx, np.logical_not(in_horizon))] = np.reshape(self.q_end, (self.q_end.size, 1))

        # Generate the output.
        command = ArmRobotState(pos_idx, vel_idx, acc_idx)
        command.time = time
        command.state = state
        return command