Source code for zonopy.contset.interval.interval

'''
Define interval
Author: Qingyi Chen, Yongseok Kwon
Reference: CORA
'''

import torch
import zonopy.internal as zpi

[docs]class interval: r""" N-rank tensor intervals An interval is a set of real numbers that includes all numbers between two given numbers. Here, we define an interval as a set of real numbers given infinum and supremum tensors :math:`\underbar{X}` and :math:`\overline{\text{X}}` such that :math:`\underbar{X} \leq X \leq \overline{\text{X}}`. .. math:: \mathcal{I} := \left\{ x \in \mathbb{R}^{n,m,...} \; \middle\vert \; \begin{array}{c} \underbar{x}_{i,j,\ldots} \leq x_{i,j,\ldots} \leq \overline{\text{x}}_{i,j,\ldots} \\ \forall{i=1,\ldots,n} \\ \forall{j=1,\ldots,m} \\ \vdots \end{array} \right\} """
[docs] def __init__(self, inf=None, sup=None, dtype=None, device=None): """ Create an interval If ``inf`` and ``sup`` are both ``None``, an empty interval is created. If only one of ``inf`` or ``sup`` is ``None``, the interval is created as a point interval where ``inf = sup``. Args: inf (torch.Tensor, optional): infimum of the interval. Defaults to None. sup (torch.Tensor, optional): supremum of the interval. Defaults to None. dtype (torch.dtype, optional): data type of the interval. If None, the data type is inferred from the input tensors. Defaults to None. device (torch.device, optional): device of the interval. If None, the device is inferred from the input tensors. Defaults to None. Raises: AssertionError: If the shapes of ``inf`` and ``sup`` do not match. AssertionError: If the devices of ``inf`` and ``sup`` do not match. AssertionError: If ``inf`` is not less than or equal to ``sup`` entry-wise and :const:`zonopy.internal.__debug_extra__` is True. """ if inf is None and sup is None: inf = torch.empty(0, dtype=dtype, device=device) sup = torch.empty(0, dtype=dtype, device=device) elif inf is None: inf = sup elif sup is None: sup = inf # Make sure that the input is a tensor inf = torch.as_tensor(inf) sup = torch.as_tensor(sup) # Promote the data type if necessary if dtype is None: dtype = torch.promote_types(inf.dtype, sup.dtype) inf = inf.to(dtype=dtype, device=device) sup = sup.to(dtype=dtype, device=device) assert inf.shape == sup.shape, "inf and sup are expected to be of the same shape" assert inf.device == sup.device, "inf and sup are expected to be on the same device" if zpi.__debug_extra__: assert torch.all(inf <= sup), "inf should be less than sup entry-wise" self.__inf = inf self.__sup = sup
@property def dtype(self): ''' The data type of an interval properties return torch.float or torch.double ''' return self.inf.dtype @property def device(self): ''' The device of an interval properties return 'cpu', 'cuda:0', or ... ''' return self.inf.device @property def inf(self): ''' The infimum of an interval return <torch.Tensor> ,shape [n,m] ''' return self.__inf @inf.setter def inf(self,value): ''' Set value of the infimum of an interval ''' assert self.__inf.shape == value.shape self.__inf = value @property def sup(self): ''' The supremum of an interval return <torch.Tensor> ,shape [n,m] ''' return self.__sup @sup.setter def sup(self,value): ''' Set value of the supremum of an interval ''' assert self.__inf.shape == value.shape self.__inf = value @property def shape(self): ''' The shape of elements (infimum or supremum) of an interval ''' return tuple(self.__inf.shape)
[docs] def to(self,dtype=None,device=None): ''' Change the device and data type of an interval dtype: torch.float or torch.double device: 'cpu', 'gpu', 'cuda:0', ... ''' inf = self.__inf.to(dtype=dtype, device=device, non_blocking=True) sup = self.__sup.to(dtype=dtype, device=device, non_blocking=True) return interval(inf,sup)
[docs] def cpu(self): ''' Change the device of an interval to CPU ''' inf = self.__inf.cpu() sup = self.__sup.cpu() return interval(inf,sup)
def __repr__(self): ''' Representation of an interval as a text return <str>, ex. interval( inf([0., 0.]), sup([1., 1.]) ) ''' intv_repr1 = f"interval(\n"+str(self.__inf)+"," intv_repr2 = "\n"+str(self.__sup) intv_repr = intv_repr1.replace('tensor(',' inf(') + intv_repr2.replace('tensor(',' sup(') intv_repr = intv_repr.replace(' ',' ') return intv_repr+"\n )" def __add__(self, other): ''' Overloaded '+' operator for addition or Minkowski sum self: <interval> other: <torch.Tensor> OR <interval> return <interval> ''' if isinstance(other, interval): inf, sup = self.__inf+other.__inf, self.__sup+other.__sup elif isinstance(other, torch.Tensor) or isinstance(other, (int,float)): inf, sup = self.__inf+other, self.__sup+other else: assert False, f'the other object should be interval or numberic, but {type(other)}.' return interval(inf,sup) __radd__ = __add__ # '+' operator is commutative. def __sub__(self,other): ''' Overloaded '-' operator for substraction or Minkowski difference self: <interval> other: <torch.Tensor> OR <interval> return <interval> ''' return self.__add__(-other) def __rsub__(self,other): ''' Overloaded reverted '-' operator for substraction or Minkowski difference self: <interval> other: <torch.Tensor> OR <interval> return <interval> ''' return -self.__sub__(other) def __iadd__(self,other): ''' Overloaded '+=' operator for addition or Minkowski sum self: <interval> other: <torch.Tensor> OR <interval> return <interval> ''' return self+other def __isub__(self,other): ''' Overloaded '-=' operator for substraction or Minkowski difference self: <interval> other: <torch.Tensor> OR <interval> return <interval> ''' return self-other def __pos__(self): ''' Overloaded unary '+' operator for an interval ifself self: <interval> return <interval> ''' return self def __neg__(self): ''' Overloaded unary '-' operator for negation of an interval self: <interval> return <interval> ''' return interval(-self.__sup,-self.__inf) def __mul__(self, other): if isinstance(other,(int,float)): if other >= 0: return interval(other * self.__inf, other * self.__sup) else: return interval(other * self.__sup, other * self.__inf) if self.numel() == 1 and isinstance(other, interval): # candidates = other.inf.repeat(4,1).reshape((4,) + other.shape) candidates = torch.empty((4,) + other.shape, dtype=other.inf.dtype, device=other.inf.device) candidates[0] = self.__inf * other.__inf candidates[1] = self.__inf * other.__sup candidates[2] = self.__sup * other.__inf candidates[3] = self.__sup * other.__sup new_inf = torch.min(candidates,dim=0).values new_sup = torch.max(candidates,dim=0).values return interval(new_inf, new_sup) elif isinstance(other, interval) and (other.numel() == 1 or self.numel() == other.numel()): # candidates = self.inf.repeat(4,1).reshape((4,) + self.shape) candidates = torch.empty((4,) + self.shape, dtype=self.inf.dtype, device=self.inf.device) candidates[0] = self.__inf * other.__inf candidates[1] = self.__inf * other.__sup candidates[2] = self.__sup * other.__inf candidates[3] = self.__sup * other.__sup new_inf = torch.min(candidates,dim=0).values new_sup = torch.max(candidates,dim=0).values return interval(new_inf, new_sup) elif isinstance(other, torch.Tensor): candidates = torch.empty((2,) + self.shape, dtype=self.inf.dtype, device=self.inf.device) candidates[0] = self.__inf * other candidates[1] = self.__sup * other new_inf = torch.min(candidates,dim=0).values new_sup = torch.max(candidates,dim=0).values return interval(new_inf, new_sup) else: assert False, "such multiplication is not implemented yet" __rmul__ = __mul__ def __getitem__(self, pos): inf = self.__inf[pos] sup = self.__sup[pos] return interval(inf, sup) def __setitem__(self, pos, value): # set one interval if isinstance(value, interval): self.__inf[pos] = value.__inf self.__sup[pos] = value.__sup else: self.__inf[pos] = value self.__sup[pos] = value def __len__(self) -> int: """ Returns the length of the interval Returns: int: length of the interval (same as the first tensor dimension) """ return len(self.__inf)
[docs] def dim(self) -> int: """ Returns the number of dimensions of the interval Returns: int: number of dimensions of the interval """ return self.__inf.dim()
[docs] def t(self): """ Transposes the interval Returns: interval: transposed interval """ return interval(self.__inf.t(), self.__sup.t())
[docs] def numel(self) -> int: """ Returns the total number of elements in the interval Returns: int: number of elements in the interval """ return self.__inf.numel()
[docs] def center(self) -> torch.Tensor: """ Compute the center of the interval The center of the interval is the midpoint of the infimum and supremum. Returns: torch.Tensor: center of the interval """ return (self.inf+self.sup)/2
[docs] def rad(self) -> torch.Tensor: """ Compute the radius of the interval The radius of the interval is half of the difference between the supremum and infimum. It can be viewed as the distance from the center to the infimum or supremum. Returns: torch.Tensor: radius of the interval """ return (self.sup-self.inf)/2
#if __name__ == '__main__':