forked from mindspore-Ecosystem/mindspore
!32211 fix sparse validator master
Merge pull request !32211 from 杨林枫/sparse_fix
This commit is contained in:
commit
e9b0e5268d
|
@ -399,7 +399,7 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
constexpr int64_t kDimTwo = 2;
|
||||
if (indices_shp[kIndexOne] != kDimTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << "but got "
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
|
||||
|
@ -718,11 +718,13 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
return elem;
|
||||
});
|
||||
if (shape_vec.size() != kSizeTwo) {
|
||||
MS_EXCEPTION(ValueError) << "Currently only supports 2-dimensional csr tensor, got shape length = "
|
||||
MS_EXCEPTION(ValueError) << "Currently only supports 2-dimensional csr tensor, but got shape length = "
|
||||
<< shape_vec.size() << ".";
|
||||
}
|
||||
if (values_shp.size() + 1 != shape_vec.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1.";
|
||||
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
|
||||
<< "Values' dimension: " << values_shp.size()
|
||||
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
|
||||
}
|
||||
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
|
||||
|
|
|
@ -884,7 +884,7 @@ class Validator:
|
|||
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
|
||||
"""Checks input tensors' shapes for CSRTensor."""
|
||||
if len(csr_shp) != 2:
|
||||
raise ValueError(f"Currently only supports 2-dimensional csr tensor, got `shape length`={len(csr_shp)}.")
|
||||
raise ValueError(f"Currently only supports 2-dimensional csr tensor, but got shape length={len(csr_shp)}.")
|
||||
shape_size = 1
|
||||
for item in csr_shp:
|
||||
if item <= 0:
|
||||
|
@ -938,7 +938,7 @@ class Validator:
|
|||
raise ValueError(f"For COOTensor, `values` must be a 1-dimensional tensor, but got a {len(values_shp)}" \
|
||||
f"dimension tensor.")
|
||||
if indices_shp[0] != values_shp[0]:
|
||||
raise ValueError(f"For COOTensor, `indices.shape[0]` must be euqla to `values.shape[0]`, but got " \
|
||||
raise ValueError(f"For COOTensor, `indices.shape[0]` must be euqal to `values.shape[0]`, but got " \
|
||||
f"`indices.shape[0]` = {indices_shp[0]} and `values.shape[0]` = {values_shp[0]}.")
|
||||
if indices_shp[1] != 2:
|
||||
raise ValueError(f"For COOTensor, `indices.shape[1]` must be 2, but got {indices_shp[1]}.")
|
||||
|
|
|
@ -1931,7 +1931,7 @@ def csr_abs(x):
|
|||
|
||||
def csr_mv(x, dense_vector):
|
||||
"""Implementation of `abs` for CSRTensor."""
|
||||
check_value_type('dense_vector', dense_vector, (Tensor_,), 'Tensor')
|
||||
check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||
return F.csr_mv(x, dense_vector)
|
||||
|
||||
|
||||
|
|
|
@ -2846,7 +2846,7 @@ class CSRTensor(CSRTensor_):
|
|||
[[2.]
|
||||
[1.]]
|
||||
"""
|
||||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'Tensor')
|
||||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
||||
|
||||
def sum(self, axis):
|
||||
|
|
Loading…
Reference in New Issue