modify 测试仓bug

This commit is contained in:
chenghaowang 2022-07-28 17:15:11 +08:00
parent 4a7c924c6e
commit 86cf41092e
1 changed files with 3 additions and 3 deletions

View File

@ -3421,7 +3421,7 @@ class Tensor(Tensor_):
tmp = []
for choice in choicelist:
tmp.append(tensor_operator_registry.get('broadcast_to')(shape_choice)(choice))
choices = tensor_operator_registry.get('stack')(0)(tmp)
choices = tensor_operator_registry.get('stack')(tmp, 0)
if self.ndim == 0 or choices.ndim == 0:
raise ValueError(f"For 'Tensor.choose', the original tensor and the argument 'choices' cannot be scalars."
@ -3441,7 +3441,7 @@ class Tensor(Tensor_):
dim_shape = validator.expanded_shape(ndim, a.shape[i], i)
dim_grid = tensor_operator_registry.get('broadcast_to')(a.shape)(dim_grid.reshape(dim_shape))
grids.append(dim_grid)
grid = tensor_operator_registry.get('stack')(-1)(grids)
grid = tensor_operator_registry.get('stack')(grids, -1)
indices = tensor_operator_registry.get('concatenate')(-1)((a.reshape(a.shape + (1,)), grid))
return tensor_operator_registry.get('gather_nd')(choices, indices).astype(dtype)
@ -5263,7 +5263,7 @@ class CSRTensor(CSRTensor_):
if self.ndim != 2:
raise ValueError("Currently only support 2-D CSRTensor when converting to COOTensor.")
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices), 1)
return COOTensor(coo_indices, self.values, self.shape)
def to_dense(self):