!39177 Modify to compatibale with dynamic input

Merge pull request !39177 from YijieChen/sparse_add_dev
This commit is contained in:
i-robot 2022-08-01 06:56:31 +00:00 committed by Gitee
commit 67c6d9733e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 18 additions and 14 deletions

View File

@ -235,7 +235,7 @@ const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> &
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, kNumberTypeFloat32, int64_t,
std::complex<float>, float),
// complex64 values
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, kNumberTypeFloat32, int64_t,
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, kNumberTypeFloat64, int64_t,
std::complex<double>, double),
};
return func_list;

View File

@ -49,9 +49,8 @@ inline void CheckSparseAddGradNNZ(const int64_t indices_nnz, const int64_t value
const std::string &value_name, const std::string &op_name) {
if (indices_nnz != value_nnz) {
MS_EXCEPTION(mindspore::ValueError) << "For " << op_name << ", the length of " << indices_name << " and "
<< value_name << " must be same, but got"
<< "length of " << indices_name << " is " << indices_nnz << ", and length of "
<< value_name << " is " << value_nnz;
<< value_name << " must be same, but got length of " << indices_name << " is "
<< indices_nnz << ", and length of " << value_name << " is " << value_nnz;
}
}
@ -124,8 +123,10 @@ AbstractBasePtr SparseAddGradInfer(const abstract::AnalysisEnginePtr &, const Pr
}
auto sum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSparseAddGradIndex3]->BuildShape());
(void)CheckSparseAddGradShape(sum_shape[kShape].size(), kIndicesShapeSize, "sum_indices", name);
(void)CheckSparseAddGradNNZ(sum_shape[kShape][0], val_grad_shape[kShape][0], "sum_indices", "backprop_val_grad",
name);
if (sum_shape[kShape][0] >= 0 && val_grad_shape[kShape][0] >= 0) {
(void)CheckSparseAddGradNNZ(sum_shape[kShape][0], val_grad_shape[kShape][0], "sum_indices", "backprop_val_grad",
name);
}
AbstractBasePtrList ret = {dx1, dx2};
return std::make_shared<AbstractTuple>(ret);
}

View File

@ -43,9 +43,8 @@ inline void CheckSparseAddNNZ(const int64_t indices_nnz, const int64_t value_nnz
const std::string &value_name, const std::string &op_name) {
if (indices_nnz != value_nnz) {
MS_EXCEPTION(mindspore::ValueError) << "For " << op_name << ", the length of " << indices_name << " and "
<< value_name << " must be same, but got"
<< "length of " << indices_name << " is " << indices_nnz << ", and length of "
<< value_name << " is " << value_nnz;
<< value_name << " must be same, but got length of " << indices_name << " is "
<< indices_nnz << ", and length of " << value_name << " is " << value_nnz;
}
}
@ -54,8 +53,9 @@ inline void CheckSparseAddSameDtype(const mindspore::TypePtr a_dtype, const mind
const std::string &op_name) {
if (a_dtype->type_id() != b_dtype->type_id()) {
MS_EXCEPTION(mindspore::TypeError) << "For " << op_name << ", the type of " << a_arg_name << ", and " << b_arg_name
<< " should be the same data type, but got type of x1_value is "
<< a_dtype->ToString() << ", and type of x2_value is " << b_dtype->ToString();
<< " should be the same data type, but got type of " << a_arg_name << " is "
<< a_dtype->ToString() << ", and type of " << b_arg_name << " is "
<< b_dtype->ToString();
}
}
} // namespace
@ -138,9 +138,12 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
CheckAndConvertUtils::CheckTensorTypeValid("thresh", thresh->BuildType(), thresh_valid_types, op_name);
// Check same nnz
CheckSparseAddNNZ(a_indices_shape[0], a_values_shape[0], "x1_indices", "x1_values", op_name);
CheckSparseAddNNZ(b_indices_shape[0], b_values_shape[0], "x2_indices", "x2_values", op_name);
if (a_indices_shape[0] >= 0 && a_values_shape[0] >= 0) {
CheckSparseAddNNZ(a_indices_shape[0], a_values_shape[0], "x1_indices", "x1_values", op_name);
}
if (b_indices_shape[0] >= 0 && b_values_shape[0] >= 0) {
CheckSparseAddNNZ(b_indices_shape[0], b_values_shape[0], "x2_indices", "x2_values", op_name);
}
// Check same type
// value
CheckSparseAddSameDtype(a_value_type, b_value_type, "x1_values", "x2_values", op_name);