fix csr mul broadcast error

This commit is contained in:
yanglf1121 2022-03-16 21:52:29 +08:00
parent cf01c631b5
commit 3aecb22983
4 changed files with 4 additions and 11 deletions

2
akg

@ -1 +1 @@
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1
Subproject commit 50d3082fdb2d084fff8509b6fbbdab5bc1e75e5c

View File

@ -46,10 +46,6 @@ inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
if (sparse_shp.size() < 1) {
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
}
if (dense_shp[0] != sparse_shp[0]) {
MS_EXCEPTION(mindspore::ValueError)
<< "Currently, dense tensor and sparse tensor shapes must equal in first dimension.";
}
for (size_t i = 0; i < sparse_shp.size(); i++) {
auto s = sparse_shp[i];
auto d = dense_shp[i];

View File

@ -2806,8 +2806,8 @@ class CSRTensor(CSRTensor_):
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> indptr = Tensor([0, 1, 2], dtype=ms.int32)
>>> indices = Tensor([0, 1], dtype=ms.int32)
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values = Tensor([2, 1], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)

View File

@ -171,8 +171,6 @@ def csr_mul(x, y):
Supported Platforms:
``GPU`` ``CPU``
"""
if x.shape[0] != 1 and y.shape[0] == 1:
y = y.expand_as(x)
return _csr_ops.CSRMul()(x, y)
def csr_div(x, y):
@ -195,8 +193,6 @@ def csr_div(x, y):
Supported Platforms:
``GPU`` ``CPU``
"""
if x.shape[0] != 1 and y.shape[0] == 1:
y = y.expand_as(x)
return _csr_ops.CSRDiv()(x, y)
csr_mv = _csr_ops.CSRMV()
@ -974,6 +970,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
def print_info(info):
print(info)
def make_sparse_tensor(indices, values, dense_shape):
"""Call make_coo_tensor in this function."""
print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +