# Source code for zonopyrobots.joint_reachable_set.offline_jrs

from __future__ import annotations
import torch
from .jrs_trig.process_jrs_trig import process_batch_JRS_trig as _process_batch_JRS_trig
from .jrs_trig.load_jrs_trig import preload_batch_JRS_trig as _preload_batch_JRS_trig
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Tuple, Union
    from numpy import ndarray
    from torch import Tensor
    Array = Union[Tensor, ndarray]

class OfflineJRS:
    """Wrapper for preloading and processing ARMTD-style JRS tensors generated offline.

    These tensors are generated offline using the MATLAB scripts in the
    ``jrs_trig/gen_jrs_trig`` folder. This class wraps some of the
    ``jrs_trig.load_jrs_trig`` and ``jrs_trig.process_jrs_trig`` functions to
    make the JRS tensors easier to use.

    The JRS tensors are preloaded once in ``__init__``. Each call to
    ``__call__`` then processes them for a given configuration and velocity,
    returning the JRS and the corresponding rotatotopes.

    This specifically loads the tensors from the
    ``jrs_trig/jrs_trig_tensor_saved`` folder.
    """

    def __init__(
        self,
        device: Union[str, torch.device] = 'cpu',
        dtype: torch.dtype = torch.float,
    ):
        """Preload the offline JRS tensors.

        Args:
            device (str or torch.device, optional): The device to place the
                JRS tensors on. Inputs to ``__call__`` are also moved to this
                device. Defaults to ``'cpu'``.
            dtype (torch.dtype, optional): The dtype of the JRS tensors. Inputs
                to ``__call__`` are also cast to this dtype. Defaults to
                ``torch.float``.
        """
        from .jrs_trig.load_jrs_trig import g_ka

        self.jrs_tensor = _preload_batch_JRS_trig(device=device, dtype=dtype)
        # Re-exposes the module-level ``g_ka`` from load_jrs_trig on the
        # instance; presumably the trajectory-parameter generator the offline
        # JRS was built with — TODO confirm against jrs_trig.load_jrs_trig.
        self.g_ka = g_ka
        # Remembered so __call__ can convert its inputs to match jrs_tensor.
        self.device = device
        self.dtype = dtype

    def __call__(
        self,
        qpos: Array,
        qvel: Array,
        joint_axes: Array,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the JRS and the corresponding rotatotopes for a configuration and velocity.

        Inputs may be torch tensors or numpy arrays (anything accepted by
        ``torch.as_tensor``). They are converted to ``self.dtype`` on
        ``self.device`` before being passed, together with the preloaded JRS
        tensors, to ``process_batch_JRS_trig``.

        Args:
            qpos (Tensor or ndarray): The configuration of the robot.
            qvel (Tensor or ndarray): The velocity of the robot.
            joint_axes (Tensor or ndarray): The joint axes of the robot.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The JRS and the corresponding
            rotatotopes, as returned by ``process_batch_JRS_trig``.
        """
        # torch.as_tensor returns the input itself (no copy) when it already
        # has the requested dtype and device; otherwise it converts/moves it.
        qpos = torch.as_tensor(qpos, dtype=self.dtype, device=self.device)
        qvel = torch.as_tensor(qvel, dtype=self.dtype, device=self.device)
        joint_axes = torch.as_tensor(joint_axes, dtype=self.dtype, device=self.device)
        return _process_batch_JRS_trig(self.jrs_tensor, qpos, qvel, joint_axes)