!32211 fix sparse validator master

Merge pull request !32211 from 杨林枫/sparse_fix
This commit is contained in:
i-robot 2022-04-01 06:37:41 +00:00 committed by Gitee
commit e9b0e5268d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 9 additions and 7 deletions

View File

@ -399,7 +399,7 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const Primitiv
}
constexpr int64_t kDimTwo = 2;
if (indices_shp[kIndexOne] != kDimTwo) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << "but got "
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
<< indices_shp[kIndexOne];
}
@ -718,11 +718,13 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
return elem;
});
if (shape_vec.size() != kSizeTwo) {
MS_EXCEPTION(ValueError) << "Currently only supports 2-dimensional csr tensor, got shape length = "
MS_EXCEPTION(ValueError) << "Currently only supports 2-dimensional csr tensor, but got shape length = "
<< shape_vec.size() << ".";
}
if (values_shp.size() + 1 != shape_vec.size()) {
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1.";
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
<< "Values' dimension: " << values_shp.size()
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
}
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];

View File

@ -884,7 +884,7 @@ class Validator:
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
"""Checks input tensors' shapes for CSRTensor."""
if len(csr_shp) != 2:
raise ValueError(f"Currently only supports 2-dimensional csr tensor, got `shape length`={len(csr_shp)}.")
raise ValueError(f"Currently only supports 2-dimensional csr tensor, but got shape length={len(csr_shp)}.")
shape_size = 1
for item in csr_shp:
if item <= 0:
@ -938,7 +938,7 @@ class Validator:
raise ValueError(f"For COOTensor, `values` must be a 1-dimensional tensor, but got a {len(values_shp)}" \
f"dimension tensor.")
if indices_shp[0] != values_shp[0]:
raise ValueError(f"For COOTensor, `indices.shape[0]` must be euqla to `values.shape[0]`, but got " \
raise ValueError(f"For COOTensor, `indices.shape[0]` must be euqal to `values.shape[0]`, but got " \
f"`indices.shape[0]` = {indices_shp[0]} and `values.shape[0]` = {values_shp[0]}.")
if indices_shp[1] != 2:
raise ValueError(f"For COOTensor, `indices.shape[1]` must be 2, but got {indices_shp[1]}.")

View File

@ -1931,7 +1931,7 @@ def csr_abs(x):
def csr_mv(x, dense_vector):
"""Implementation of `abs` for CSRTensor."""
check_value_type('dense_vector', dense_vector, (Tensor_,), 'Tensor')
check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
return F.csr_mv(x, dense_vector)

View File

@ -2846,7 +2846,7 @@ class CSRTensor(CSRTensor_):
[[2.]
[1.]]
"""
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'Tensor')
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
def sum(self, axis):