!39085 修复 测试仓bug

Merge pull request !39085 from 王程浩/master
This commit is contained in:
i-robot 2022-07-29 07:28:27 +00:00 committed by Gitee
commit 11fddc374e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 3 deletions

View File

@ -3425,7 +3425,7 @@ class Tensor(Tensor_):
tmp = []
for choice in choicelist:
tmp.append(tensor_operator_registry.get('broadcast_to')(shape_choice)(choice))
choices = tensor_operator_registry.get('stack')(0)(tmp)
choices = tensor_operator_registry.get('stack')(tmp, 0)
if self.ndim == 0 or choices.ndim == 0:
raise ValueError(f"For 'Tensor.choose', the original tensor and the argument 'choices' cannot be scalars."
@ -3445,7 +3445,7 @@ class Tensor(Tensor_):
dim_shape = validator.expanded_shape(ndim, a.shape[i], i)
dim_grid = tensor_operator_registry.get('broadcast_to')(a.shape)(dim_grid.reshape(dim_shape))
grids.append(dim_grid)
grid = tensor_operator_registry.get('stack')(-1)(grids)
grid = tensor_operator_registry.get('stack')(grids, -1)
indices = tensor_operator_registry.get('concatenate')(-1)((a.reshape(a.shape + (1,)), grid))
return tensor_operator_registry.get('gather_nd')(choices, indices).astype(dtype)
@ -5267,7 +5267,7 @@ class CSRTensor(CSRTensor_):
if self.ndim != 2:
raise ValueError("Currently only support 2-D CSRTensor when converting to COOTensor.")
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices), 1)
return COOTensor(coo_indices, self.values, self.shape)
def to_dense(self):