"""
Define class for batch zonotope
Author: Yongseok Kwon
Reference: CORA
"""
import torch
from zonopy.contset.polynomial_zonotope.batch_poly_zono import batchPolyZonotope
from zonopy.contset.interval.interval import interval
from zonopy.contset.zonotope.utils import pickedBatchGenerators
from zonopy.contset.zonotope.zono import zonotope
from ..gen_ops import (
_add_genzono_impl,
_add_genzono_num_impl,
_mul_genzono_num_impl,
)
class batchZonotope:
    r''' Batched 1D zonotope class

    Batched form of the :class:`zonotope` class.

    This class is used to represent a batch of zonotopes over arbitrary batch dimensions,
    where each zonotope in the batch is expanded to have the same number of generators.
    This results in a :math:`\mathbf{Z}` tensor of shape :math:`B_1 \times B_2 \times \ldots \times (N+1) \times d`.
    Along the second-to-last axis, row 0 is the center and rows 1..N are the generators.

    Refer to the :class:`zonotope` class for more details on zonotopes.
    '''

    def __init__(self, Z, dtype=None, device=None):
        r""" Initialize a batch zonotope

        Args:
            Z (torch.Tensor): The :math:`\mathbf{Z}` tensor of shape :math:`B_1 \times B_2 \times \ldots \times (N+1) \times d`.
            dtype (torch.dtype, optional): The data type of the batch zonotope. If None, the data type is inferred. Defaults to None.
            device (str, optional): The device of the batch zonotope. If None, the device is inferred. Defaults to None.

        Raises:
            AssertionError: If the rank of Z is less than 3.
        """
        ################ may not need these for speed ################
        # Non-tensor input (lists, numpy arrays, ...) falls back to the torch
        # default dtype unless the caller asked for a specific one.
        if not isinstance(Z, torch.Tensor) and dtype is None:
            dtype = torch.get_default_dtype()
        Z = torch.as_tensor(Z, dtype=dtype, device=device)
        assert len(Z.shape) > 2, f'The rank of Z input should be at least 3, not {len(Z.shape)}.'
        ##############################################################
        self.Z = Z
        # Everything except the trailing (N+1, d) axes is a batch axis.
        self.batch_dim = len(Z.shape) - 2
        # Index prefix selecting every element along all batch axes.
        self.batch_idx_all = (slice(None),) * self.batch_dim

    def __getitem__(self, idx):
        """Index into the underlying tensor.

        Args:
            idx: Any index accepted by ``torch.Tensor.__getitem__``.

        Returns:
            batchZonotope if the indexed tensor still has rank > 2,
            otherwise a zonotope.
        """
        Z = self.Z[idx]
        if len(Z.shape) > 2:
            return batchZonotope(Z)
        return zonotope(Z)

    @property
    def batch_shape(self):
        """Shape of the batch axes, i.e. ``Z.shape[:batch_dim]``."""
        return self.Z.shape[:self.batch_dim]

    @property
    def dtype(self):
        """Data type of the underlying tensor ``Z``."""
        return self.Z.dtype

    @property
    def device(self):
        """Device of the underlying tensor ``Z``."""
        return self.Z.device

    @property
    def center(self):
        """Center of each zonotope (row 0 of ``Z``), shape [B1, ..., Bb, nx]."""
        return self.Z[self.batch_idx_all + (0,)]

    @center.setter
    def center(self, value):
        """Overwrite the center rows of ``Z`` in place."""
        self.Z[self.batch_idx_all + (0,)] = value

    @property
    def generators(self):
        """Generators of each zonotope (rows 1.. of ``Z``), shape [B1, ..., Bb, N, nx]."""
        return self.Z[self.batch_idx_all + (slice(1, None),)]

    @generators.setter
    def generators(self, value):
        """Overwrite the generator rows of ``Z`` in place."""
        self.Z[self.batch_idx_all + (slice(1, None),)] = value

    @property
    def shape(self):
        """Shape of one vector element (e.g. a center) as a tuple ``(nx,)``."""
        return (self.Z.shape[-1],)

    @property
    def dimension(self):
        """Dimension ``nx`` of the space, i.e. the size of the last axis of ``Z``."""
        return self.Z.shape[-1]

    @property
    def n_generators(self):
        """Number of generators ``N`` (second-to-last axis of ``Z`` minus the center row)."""
        return self.Z.shape[-2] - 1

    def to(self, dtype=None, device=None):
        """Return a batch zonotope with ``Z`` converted to another dtype/device.

        Args:
            dtype (torch.dtype, optional): Target data type.
            device (str or torch.device, optional): Target device,
                e.g. 'cpu' or 'cuda:0'.

        Returns:
            batchZonotope wrapping ``Z.to(dtype=dtype, device=device, non_blocking=True)``.
        """
        Z = self.Z.to(dtype=dtype, device=device, non_blocking=True)
        return batchZonotope(Z)

    def cpu(self):
        """Return a batch zonotope whose ``Z`` has been moved with ``Tensor.cpu()``."""
        return batchZonotope(self.Z.cpu())

    def __repr__(self):
        """Text form of ``Z`` with the word 'tensor' replaced by 'batchZonotope'.

        ex. batchZonotope([[[0., 0., 0.],[1., 0., 0.]],[[0., 0., 0.],[2., 0., 0.]]])
        """
        return str(self.Z).replace('tensor', 'batchZonotope')

    def __add__(self, other):
        """Overloaded '+' operator for addition or Minkowski sum.

        Args:
            other: torch.Tensor, float or int (handled by
                ``_add_genzono_num_impl``), or zonotope / batchZonotope
                (handled by ``_add_genzono_impl``).

        Returns:
            batchZonotope, or NotImplemented for unsupported operand types.
        """
        if isinstance(other, (torch.Tensor, float, int)):
            Z = _add_genzono_num_impl(self, other)
            return batchZonotope(Z)
        elif isinstance(other, (zonotope, batchZonotope)):
            Z = _add_genzono_impl(self, other, batch_shape=self.batch_shape)
            return batchZonotope(Z)
        else:
            return NotImplemented

    __radd__ = __add__  # '+' operator is commutative.

    def __sub__(self, other):
        """Overloaded '-' operator, implemented as ``self + (-other)``.

        Deprecated: emits a DeprecationWarning on every call.

        Args:
            other: torch.Tensor, number, zonotope or batchZonotope.

        Returns:
            batchZonotope, or NotImplemented if ``__add__`` rejects ``-other``.
        """
        import warnings
        warnings.warn(
            "PZ subtraction as addition of negative is deprecated and will be removed to reduce confusion!",
            DeprecationWarning)
        return self.__add__(-other)

    def __rsub__(self, other):
        """Overloaded reflected '-' operator, computed as ``-(self + (-other))``.

        Deprecated: emits a single DeprecationWarning per call.

        Args:
            other: torch.Tensor, number, zonotope or batchZonotope.

        Returns:
            batchZonotope, or NotImplemented if ``__add__`` rejects ``-other``.
        """
        import warnings
        warnings.warn(
            "PZ subtraction as addition of negative is deprecated and will be removed to reduce confusion!",
            DeprecationWarning)
        # Call __add__ directly (not __sub__) so the warning fires only once,
        # and pass NotImplemented through instead of negating it.
        difference = self.__add__(-other)
        if difference is NotImplemented:
            return NotImplemented
        return -difference

    def __pos__(self):
        """Overloaded unary '+' operator; returns ``self`` unchanged (no copy)."""
        return self

    def __neg__(self):
        """Overloaded unary '-' operator.

        Returns:
            batchZonotope built from a clone of ``Z`` whose center row is
            multiplied by -1.
        """
        Z = torch.clone(self.Z)
        # Only the center is flipped; generators are left as they are
        # (presumably because the generator set is symmetric about the
        # center, so negating them would describe the same set).
        Z[..., 0, :] *= -1
        return batchZonotope(Z)

    def __rmatmul__(self, other):
        """Overloaded reflected '@' operator: apply a matrix to every row of ``Z``.

        Args:
            other (torch.Tensor): Matrix (or batch of matrices) on the left.

        Returns:
            batchZonotope wrapping ``Z @ other.transpose(-2, -1)``.

        Raises:
            AssertionError: If ``other`` is not a torch.Tensor.
        """
        assert isinstance(other, torch.Tensor), f'The other object should be torch tensor, but {type(other)}.'
        Z = self.Z @ other.transpose(-2, -1)
        return batchZonotope(Z)

    def __mul__(self, other):
        """Overloaded '*' operator for scaling a batch zonotope.

        Args:
            other: torch.Tensor, int or float; forwarded to
                ``_mul_genzono_num_impl``.

        Returns:
            batchZonotope, or NotImplemented for unsupported operand types.
        """
        if isinstance(other, (torch.Tensor, int, float)):
            Z = _mul_genzono_num_impl(self, other)
            return batchZonotope(Z)
        else:
            return NotImplemented

    __rmul__ = __mul__  # '*' operator is commutative.

    def __len__(self):
        """Size of the first axis of ``Z`` (the leading batch axis)."""
        return self.Z.shape[0]

    def slice(self, slice_dim, slice_pt):
        """Slice the batch zonotope at given values along given dimensions.

        For every requested dimension exactly one generator per batch entry
        must have a non-zero component there. That generator's coefficient is
        fixed so the sliced coordinate equals ``slice_pt``, and the generator
        is folded into the center.

        Args:
            slice_dim (int, list or torch.Tensor): Dimension index/indices
                to slice on (scalar or 1-D).
            slice_pt (int, float, list or torch.Tensor): Slice value(s); the
                last axis must have one entry per slicing dimension.

        Returns:
            batchZonotope with the slicing generators absorbed into the center.

        Raises:
            AssertionError: On invalid input types/shapes, if a slicing
                dimension does not have exactly one non-zero generator, or if
                the required coefficient lies outside [-1, 1].
        """
        # Normalise both arguments to tensors on this object's device.
        if isinstance(slice_dim, list):
            slice_dim = torch.tensor(slice_dim, dtype=torch.long, device=self.device)
        elif isinstance(slice_dim, int) or (isinstance(slice_dim, torch.Tensor) and len(slice_dim.shape) == 0):
            slice_dim = torch.tensor([slice_dim], dtype=torch.long, device=self.device)

        if isinstance(slice_pt, list):
            slice_pt = torch.tensor(slice_pt, dtype=self.dtype, device=self.device)
        elif isinstance(slice_pt, int) or isinstance(slice_pt, float) or (isinstance(slice_pt, torch.Tensor) and len(slice_pt.shape) == 0):
            slice_pt = torch.tensor([slice_pt], dtype=self.dtype, device=self.device)

        assert isinstance(slice_dim, torch.Tensor) and isinstance(slice_pt, torch.Tensor), 'Invalid type of input'
        assert len(slice_dim.shape) == 1, 'slicing dimension should be 1-dim component.'
        assert len(slice_dim) == slice_pt.shape[-1], f'The number of slicing dimension ({len(slice_dim)}) and the number of slicing point ({slice_pt.shape[-1]}) should be the same.'

        N = len(slice_dim)
        # Sort the dimensions and permute the slice values the same way.
        slice_dim, ind = torch.sort(slice_dim)
        slice_pt = slice_pt[(slice(None),) * (len(slice_pt.shape) - 1) + (ind,)]

        c = self.center
        G = self.generators
        # Components of every generator along the slicing dimensions.
        G_dim = G[self.batch_idx_all + (slice(None), slice_dim)]
        non_zero_idx = G_dim != 0
        assert torch.all(torch.sum(non_zero_idx, -2) == 1), 'There should be one generator for each slice index.'

        # Rows of slice_idx are (batch indices..., slice-dim position, generator index).
        slice_idx = non_zero_idx.transpose(-2, -1).nonzero()
        slice_c = c[self.batch_idx_all + (slice_dim,)]
        ind = tuple(slice_idx[:, :-2].T)
        slice_g = G_dim[ind + (slice_idx[:, -1], slice_idx[:, -2])].reshape(self.batch_shape + (N,))
        # Coefficient that moves the sliced coordinate from the center to slice_pt.
        slice_lambda = (slice_pt - slice_c) / slice_g
        assert not (abs(slice_lambda) > 1).any(), 'slice point is ouside bounds of reach set, and therefore is not verified'

        sliced_gens = G[ind + (slice_idx[:, -1],)].reshape(self.batch_shape + (N, self.dimension))
        new_center = c.unsqueeze(-2) + slice_lambda.unsqueeze(-2) @ sliced_gens
        # Generators with no component along any slicing dimension are kept.
        kept_gens = G[~non_zero_idx.any(-1)].reshape(self.batch_shape + (-1, self.dimension))
        Z = torch.cat((new_center, kept_gens), -2)
        return batchZonotope(Z)

    def project(self, dim=[0, 1]):
        """Project the batch zonotope onto the specified dimensions.

        Args:
            dim (int, list or torch.Tensor): Index into the last axis of
                ``Z``. The default list is only read, never mutated.

        Returns:
            batchZonotope built from the selected columns of ``Z``.
        """
        Z = self.Z[self.batch_idx_all + (slice(None), dim)]
        return batchZonotope(Z)

    def polygon(self, nan=True):
        """Convert the first two dimensions of each zonotope into polygon vertices.

        NOTE: this is unstable for zero generators.

        Args:
            nan (bool): If True, generators with zero norm are replaced by
                NaN before sorting by angle, so padded entries yield NaN
                vertices instead of duplicated points.

        Returns:
            torch.Tensor of vertices with shape [B1, ..., Bb, 2*n+1, 2],
            where n is the generator count after ``deleteZerosGenerators``.
        """
        z = self.deleteZerosGenerators()
        c = z.center[self.batch_idx_all + (slice(2),)].unsqueeze(-2)
        G = torch.clone(z.generators[self.batch_idx_all + (slice(None), slice(2))])
        x_idx = self.batch_idx_all + (slice(None), 0)
        y_idx = self.batch_idx_all + (slice(None), 1)
        G_y = G[y_idx]
        x_max = torch.sum(abs(G[x_idx]), -1)
        y_max = torch.sum(abs(G_y), -1)
        # Flip generators so that every y-component is non-negative.
        G[G_y < 0] = -G[G_y < 0]
        if nan:
            G[torch.linalg.norm(G, dim=-1) == 0] = torch.nan
        # Order generators by angle and accumulate them to walk half the boundary.
        angles = torch.atan2(G[y_idx], G[x_idx])
        ang_idx = torch.argsort(angles, dim=-1).unsqueeze(-1).repeat((1,) * (self.batch_dim + 1) + (2,))
        origin = torch.zeros(self.batch_shape + (1, 2), dtype=self.dtype, device=self.device)
        vertices_half = torch.cat((origin, 2 * G.gather(-2, ang_idx).cumsum(dim=self.batch_dim)), -2)
        vertices_half[x_idx] += (x_max - torch.max(vertices_half[x_idx].nan_to_num(-torch.inf), dim=-1)[0]).unsqueeze(-1)
        vertices_half[y_idx] -= y_max.unsqueeze(-1)
        if nan:
            # Index of the last non-NaN vertex for every batch entry.
            last_idx = (z.n_generators - angles.isnan().sum(-1)).reshape(self.batch_shape + (1, 1)).repeat((1,) * self.batch_dim + (1, 2))
            temp = vertices_half[self.batch_idx_all + (0,)].unsqueeze(-2) + vertices_half.gather(-2, last_idx)
        else:
            temp = (vertices_half[self.batch_idx_all + (0,)] + vertices_half[self.batch_idx_all + (-1,)]).unsqueeze(-2)
        # The second half of the boundary is the first half reflected through temp/2.
        mirrored = -vertices_half[self.batch_idx_all + (slice(1, None),)] + temp
        full_vertices = torch.cat((vertices_half, mirrored), dim=self.batch_dim) + c
        return full_vertices

    def polytope(self, combs=None):
        """Convert the zonotopes from a G- to an H-representation.

        Only dimensions 1, 2 and 3 are supported.

        NOTE: the output may contain NaN values (zero-norm normals), so you
        might want to apply ``nan_to_num()`` to it.

        Args:
            combs (optional): Container indexed by the number of generators
                that yields a tensor of generator-index pairs; used only in
                the 3-D case. If None, or too short for the generator count,
                the pairs come from ``torch.combinations``.

        Returns:
            tuple (PA, Pb): ``PA`` stacks the normals ``C`` and ``-C`` along
            the second-to-last axis; ``Pb`` holds the matching offsets
            ``d + deltaD`` and ``-d + deltaD`` (presumably describing
            ``PA x <= Pb`` -- confirm against callers).

        Raises:
            AssertionError: If the dimension is not 1, 2 or 3.
        """
        c = self.center
        G = torch.clone(self.generators)
        h = torch.linalg.vector_norm(G, dim=-1)
        h_sort, indicies = torch.sort(h, dim=-1, descending=True)
        h_nonzero = h_sort > 1e-6
        # Sorted positions whose generator is (numerically) zero in every batch entry.
        h_zero_all = h_nonzero.sum(tuple(range(self.batch_dim))) == 0
        if torch.any(h_zero_all):
            first_reduce_idx = torch.nonzero(h_zero_all).squeeze(-1)[0]
            gather_idx = indicies.unsqueeze(-1).repeat((1,) * (self.batch_dim + 1) + self.shape)
            G = G.gather(self.batch_dim, gather_idx)[self.batch_idx_all + (slice(None, first_reduce_idx),)]

        n_gens, dim = G.shape[-2:]
        if dim == 1:
            C = G / torch.linalg.vector_norm(G, dim=-1).unsqueeze(-1)
        elif dim == 2:
            x_idx = self.batch_idx_all + (slice(None), slice(0, 1))
            y_idx = self.batch_idx_all + (slice(None), slice(1, 2))
            # Rotate each generator by 90 degrees to get a normal, then normalise.
            C = torch.cat((-G[y_idx], G[x_idx]), -1)
            C = C / torch.linalg.vector_norm(C, dim=-1).unsqueeze(-1)
        elif dim == 3:
            # not complete for example when n_gens < dim-1; n_gens =0 or n_gens =1
            if combs is None or n_gens >= len(combs):
                comb = torch.combinations(torch.arange(n_gens), r=dim - 1)
            else:
                comb = combs[n_gens]
            # Q holds each generator pair side by side: columns 0-2 and 3-5.
            Q = torch.cat((G[self.batch_idx_all + (comb[:, 0],)], G[self.batch_idx_all + (comb[:, 1],)]), dim=-1)
            col = [Q[self.batch_idx_all + (slice(None), k)] for k in range(6)]
            # Cross product of the two generators of every pair.
            temp1 = (col[1] * col[5] - col[2] * col[4]).unsqueeze(-1)
            temp2 = (-col[0] * col[5] + col[2] * col[3]).unsqueeze(-1)
            temp3 = (col[0] * col[4] - col[1] * col[3]).unsqueeze(-1)
            C = torch.cat((temp1, temp2, temp3), dim=-1)
            C = C / torch.norm(C, dim=-1, keepdim=True)
        else:
            assert False, f'polytope conversion is only implemented for dimension 1, 2 or 3, not {dim}.'

        # Support width along each normal uses all original generators.
        deltaD = torch.sum(abs(C @ self.generators.transpose(-2, -1)), dim=-1)
        d = (C @ c.unsqueeze(-1)).squeeze(-1)
        PA = torch.cat((C, -C), dim=-2)
        Pb = torch.cat((d + deltaD, -d + deltaD), dim=-1)
        # NOTE: torch.nan_to_num()
        return PA, Pb

    def deleteZerosGenerators(self, sorted=False, sort=False):
        """Remove all-zero generators.

        All batch entries must keep the same number of generators, so a
        generator slot is only dropped when the result can stay rectangular:

        * ``sorted=True``: drop the positions that are zero in *every* batch
          entry (positions are not reordered).
        * ``sorted=False``: within each batch entry move zero generators to
          the end (order of the others is preserved), then truncate to the
          largest non-zero count found in the batch.

        Args:
            sorted (bool): Selects the strategy described above.
            sort (bool): Unused.

        Returns:
            batchZonotope with the reduced generator set.
        """
        generators = self.generators
        if sorted:
            non_zero_idxs = torch.sum(torch.any(generators != 0, -1), tuple(range(self.batch_dim))) != 0
            g_red = generators[self.batch_idx_all + (non_zero_idxs,)]
        else:
            zero_mask = torch.all(generators == 0, dim=-1)
            # Sort on a uint8 copy (False < True pushes zero generators last),
            # but count on the boolean mask: '~' on uint8 is a bitwise NOT
            # (0 -> 255, 1 -> 254) and would wildly over-count.
            order = zero_mask.to(torch.uint8).sort(dim=-1, stable=True)[1]
            ind = order.unsqueeze(-1).repeat((1,) * (self.batch_dim + 1) + self.shape)
            max_non_zero_len = int((~zero_mask).sum(-1).max())
            g_red = generators.gather(-2, ind)[self.batch_idx_all + (slice(None, max_non_zero_len),)]
        Z = torch.cat((self.center.unsqueeze(self.batch_dim), g_red), self.batch_dim)
        return batchZonotope(Z)

    def reduce(self, order, option='girard'):
        """Reduce the number of generators.

        Args:
            order: Target order. For ``order == 1`` every generator is
                replaced by an axis-aligned box; otherwise
                ``pickedBatchGenerators`` decides which generators are kept
                and which are boxed.
            option (str): Reduction method; only 'girard' is supported.

        Returns:
            batchZonotope with rows ``[center, kept generators, box generators]``.

        Raises:
            AssertionError: If ``option`` is not 'girard'.
        """
        if option != 'girard':
            assert False, 'Invalid reduction option'

        Z = self.deleteZerosGenerators()
        if order == 1:
            center, G = Z.center, Z.generators
            # Diagonal generators whose entries are the per-axis sums of |G|.
            Gbox = torch.diag_embed(torch.sum(abs(G), -2))
            ZRed = torch.cat((center.unsqueeze(self.batch_dim), Gbox), -2)
        else:
            center, Gunred, Gred = pickedBatchGenerators(Z, order)
            Gbox = torch.diag_embed(torch.sum(abs(Gred), -2))
            ZRed = torch.cat((center.unsqueeze(self.batch_dim), Gunred, Gbox), -2)
        return batchZonotope(ZRed)

    def to_polyZonotope(self, dim=None, id=None):
        """Convert the batch zonotope to a batch polynomial zonotope.

        Args:
            dim (int, optional): Dimension to take as sliceable. Exactly one
                generator per batch entry must be non-zero along it; that
                generator is moved to the front. If None, ``Z`` is passed on
                unchanged.
            id (optional): Forwarded to ``batchPolyZonotope`` as ``id`` when
                ``dim`` is given.

        Returns:
            batchPolyZonotope: ``batchPolyZonotope(self.Z, 0)`` if ``dim`` is
            None, otherwise ``batchPolyZonotope(Z, 1, id=id)`` with the
            reordered ``Z`` (the second argument is presumably the number of
            dependent generators -- confirm against batchPolyZonotope).

        Raises:
            AssertionError: If ``dim`` is not an int below the zonotope
                dimension, or if the number of generators that are non-zero
                along ``dim`` is not exactly one.
        """
        if dim is None:
            return batchPolyZonotope(self.Z, 0)
        # Strict bound: dim == dimension would index past the last axis.
        assert isinstance(dim, int) and dim < self.dimension
        idx = self.generators[self.batch_idx_all + (slice(None), dim)] == 0
        assert ((~idx).sum(-1) == 1).all(), 'sliceable generator should be one for the dimension.'
        sliceable = self.generators[~idx].reshape(self.batch_shape + (-1, self.dimension))
        remaining = self.generators[idx].reshape(self.batch_shape + (-1, self.dimension))
        Z = torch.cat((self.center.unsqueeze(-2), sliceable, remaining), -2)
        return batchPolyZonotope(Z, 1, id=id)

    def to_interval(self):
        """Over-approximate each zonotope by an axis-aligned interval.

        Returns:
            interval with bounds ``center -/+ sum(|generators|)``.
        """
        c = self.center
        # Summing |Z| over the row axis includes |c|, so subtract it again.
        delta = torch.sum(abs(self.Z), self.batch_dim) - abs(c)
        leftLimit, rightLimit = c - delta, c + delta
        return interval(leftLimit, rightLimit)