stack functional interface bugfix

This commit is contained in:
jianghui58 2022-07-19 11:11:35 +08:00
parent c332f84d10
commit d779cc97e3
1 changed files with 3 additions and 3 deletions

View File

@ -3381,7 +3381,7 @@ class Tensor(Tensor_):
tmp = []
for choice in choicelist:
tmp.append(tensor_operator_registry.get('broadcast_to')(shape_choice)(choice))
choices = tensor_operator_registry.get('stack')(0)(tmp)
choices = tensor_operator_registry.get('stack')(tmp, 0)
if self.ndim == 0 or choices.ndim == 0:
raise ValueError(f"For 'Tensor.choose', the original tensor and the argument 'choices' cannot be scalars."
@ -3401,7 +3401,7 @@ class Tensor(Tensor_):
dim_shape = validator.expanded_shape(ndim, a.shape[i], i)
dim_grid = tensor_operator_registry.get('broadcast_to')(a.shape)(dim_grid.reshape(dim_shape))
grids.append(dim_grid)
grid = tensor_operator_registry.get('stack')(-1)(grids)
grid = tensor_operator_registry.get('stack')(grids, -1)
indices = tensor_operator_registry.get('concatenate')(-1)((a.reshape(a.shape + (1,)), grid))
return tensor_operator_registry.get('gather_nd')(choices, indices).astype(dtype)
@ -5114,7 +5114,7 @@ class CSRTensor(CSRTensor_):
if self.ndim != 2:
raise ValueError("Currently only support 2-D CSRTensor when converting to COOTensor.")
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
coo_indices = tensor_operator_registry.get("stack")((row_indices, self.indices), 1)
return COOTensor(coo_indices, self.values, self.shape)
def to_dense(self):