stack functional interface bugfix
This commit is contained in:
parent
c332f84d10
commit
d779cc97e3
|
@ -3381,7 +3381,7 @@ class Tensor(Tensor_):
|
|||
tmp = []
|
||||
for choice in choicelist:
|
||||
tmp.append(tensor_operator_registry.get('broadcast_to')(shape_choice)(choice))
|
||||
choices = tensor_operator_registry.get('stack')(0)(tmp)
|
||||
choices = tensor_operator_registry.get('stack')(tmp, 0)
|
||||
|
||||
if self.ndim == 0 or choices.ndim == 0:
|
||||
raise ValueError(f"For 'Tensor.choose', the original tensor and the argument 'choices' cannot be scalars."
|
||||
|
@ -3401,7 +3401,7 @@ class Tensor(Tensor_):
|
|||
dim_shape = validator.expanded_shape(ndim, a.shape[i], i)
|
||||
dim_grid = tensor_operator_registry.get('broadcast_to')(a.shape)(dim_grid.reshape(dim_shape))
|
||||
grids.append(dim_grid)
|
||||
grid = tensor_operator_registry.get('stack')(-1)(grids)
|
||||
grid = tensor_operator_registry.get('stack')(grids, -1)
|
||||
indices = tensor_operator_registry.get('concatenate')(-1)((a.reshape(a.shape + (1,)), grid))
|
||||
return tensor_operator_registry.get('gather_nd')(choices, indices).astype(dtype)
|
||||
|
||||
|
@ -5114,7 +5114,7 @@ class CSRTensor(CSRTensor_):
|
|||
if self.ndim != 2:
|
||||
raise ValueError("Currently only support 2-D CSRTensor when converting to COOTensor.")
|
||||
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
|
||||
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
|
||||
coo_indices = tensor_operator_registry.get("stack")((row_indices, self.indices), 1)
|
||||
return COOTensor(coo_indices, self.values, self.shape)
|
||||
|
||||
def to_dense(self):
|
||||
|
|
Loading…
Reference in New Issue