forked from mindspore-Ecosystem/mindspore
fix csr mul broadcast error
This commit is contained in:
parent
cf01c631b5
commit
3aecb22983
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1
|
||||
Subproject commit 50d3082fdb2d084fff8509b6fbbdab5bc1e75e5c
|
|
@ -46,10 +46,6 @@ inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
|
|||
if (sparse_shp.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
|
||||
}
|
||||
if (dense_shp[0] != sparse_shp[0]) {
|
||||
MS_EXCEPTION(mindspore::ValueError)
|
||||
<< "Currently, dense tensor and sparse tensor shapes must equal in first dimension.";
|
||||
}
|
||||
for (size_t i = 0; i < sparse_shp.size(); i++) {
|
||||
auto s = sparse_shp[i];
|
||||
auto d = dense_shp[i];
|
||||
|
|
|
@ -2806,8 +2806,8 @@ class CSRTensor(CSRTensor_):
|
|||
Examples:
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> indptr = Tensor([0, 1, 2], dtype=ms.int32)
|
||||
>>> indices = Tensor([0, 1], dtype=ms.int32)
|
||||
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||
>>> indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
>>> values = Tensor([2, 1], dtype=mstype.float32)
|
||||
>>> dense_shape = (2, 4)
|
||||
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||
|
|
|
@ -171,8 +171,6 @@ def csr_mul(x, y):
|
|||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
"""
|
||||
if x.shape[0] != 1 and y.shape[0] == 1:
|
||||
y = y.expand_as(x)
|
||||
return _csr_ops.CSRMul()(x, y)
|
||||
|
||||
def csr_div(x, y):
|
||||
|
@ -195,8 +193,6 @@ def csr_div(x, y):
|
|||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
"""
|
||||
if x.shape[0] != 1 and y.shape[0] == 1:
|
||||
y = y.expand_as(x)
|
||||
return _csr_ops.CSRDiv()(x, y)
|
||||
|
||||
csr_mv = _csr_ops.CSRMV()
|
||||
|
@ -974,6 +970,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
|
|||
def print_info(info):
|
||||
print(info)
|
||||
|
||||
|
||||
def make_sparse_tensor(indices, values, dense_shape):
|
||||
"""Call make_coo_tensor in this function."""
|
||||
print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
|
||||
|
|
Loading…
Reference in New Issue