!39177 Modify to compatibale with dynamic input
Merge pull request !39177 from YijieChen/sparse_add_dev
This commit is contained in:
commit
67c6d9733e
|
@ -235,7 +235,7 @@ const std::vector<std::pair<KernelAttr, SparseAddCpuKernelMod::KernelRunFunc>> &
|
|||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, kNumberTypeFloat32, int64_t,
|
||||
std::complex<float>, float),
|
||||
// complex64 values
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, kNumberTypeFloat32, int64_t,
|
||||
CPU_SPARSE_ADD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, kNumberTypeFloat64, int64_t,
|
||||
std::complex<double>, double),
|
||||
};
|
||||
return func_list;
|
||||
|
|
|
@ -49,9 +49,8 @@ inline void CheckSparseAddGradNNZ(const int64_t indices_nnz, const int64_t value
|
|||
const std::string &value_name, const std::string &op_name) {
|
||||
if (indices_nnz != value_nnz) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << "For " << op_name << ", the length of " << indices_name << " and "
|
||||
<< value_name << " must be same, but got"
|
||||
<< "length of " << indices_name << " is " << indices_nnz << ", and length of "
|
||||
<< value_name << " is " << value_nnz;
|
||||
<< value_name << " must be same, but got length of " << indices_name << " is "
|
||||
<< indices_nnz << ", and length of " << value_name << " is " << value_nnz;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,8 +123,10 @@ AbstractBasePtr SparseAddGradInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
}
|
||||
auto sum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSparseAddGradIndex3]->BuildShape());
|
||||
(void)CheckSparseAddGradShape(sum_shape[kShape].size(), kIndicesShapeSize, "sum_indices", name);
|
||||
(void)CheckSparseAddGradNNZ(sum_shape[kShape][0], val_grad_shape[kShape][0], "sum_indices", "backprop_val_grad",
|
||||
name);
|
||||
if (sum_shape[kShape][0] >= 0 && val_grad_shape[kShape][0] >= 0) {
|
||||
(void)CheckSparseAddGradNNZ(sum_shape[kShape][0], val_grad_shape[kShape][0], "sum_indices", "backprop_val_grad",
|
||||
name);
|
||||
}
|
||||
AbstractBasePtrList ret = {dx1, dx2};
|
||||
return std::make_shared<AbstractTuple>(ret);
|
||||
}
|
||||
|
|
|
@ -43,9 +43,8 @@ inline void CheckSparseAddNNZ(const int64_t indices_nnz, const int64_t value_nnz
|
|||
const std::string &value_name, const std::string &op_name) {
|
||||
if (indices_nnz != value_nnz) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << "For " << op_name << ", the length of " << indices_name << " and "
|
||||
<< value_name << " must be same, but got"
|
||||
<< "length of " << indices_name << " is " << indices_nnz << ", and length of "
|
||||
<< value_name << " is " << value_nnz;
|
||||
<< value_name << " must be same, but got length of " << indices_name << " is "
|
||||
<< indices_nnz << ", and length of " << value_name << " is " << value_nnz;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -54,8 +53,9 @@ inline void CheckSparseAddSameDtype(const mindspore::TypePtr a_dtype, const mind
|
|||
const std::string &op_name) {
|
||||
if (a_dtype->type_id() != b_dtype->type_id()) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "For " << op_name << ", the type of " << a_arg_name << ", and " << b_arg_name
|
||||
<< " should be the same data type, but got type of x1_value is "
|
||||
<< a_dtype->ToString() << ", and type of x2_value is " << b_dtype->ToString();
|
||||
<< " should be the same data type, but got type of " << a_arg_name << " is "
|
||||
<< a_dtype->ToString() << ", and type of " << b_arg_name << " is "
|
||||
<< b_dtype->ToString();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -138,9 +138,12 @@ AbstractBasePtr SparseAddInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
CheckAndConvertUtils::CheckTensorTypeValid("thresh", thresh->BuildType(), thresh_valid_types, op_name);
|
||||
|
||||
// Check same nnz
|
||||
CheckSparseAddNNZ(a_indices_shape[0], a_values_shape[0], "x1_indices", "x1_values", op_name);
|
||||
CheckSparseAddNNZ(b_indices_shape[0], b_values_shape[0], "x2_indices", "x2_values", op_name);
|
||||
|
||||
if (a_indices_shape[0] >= 0 && a_values_shape[0] >= 0) {
|
||||
CheckSparseAddNNZ(a_indices_shape[0], a_values_shape[0], "x1_indices", "x1_values", op_name);
|
||||
}
|
||||
if (b_indices_shape[0] >= 0 && b_values_shape[0] >= 0) {
|
||||
CheckSparseAddNNZ(b_indices_shape[0], b_values_shape[0], "x2_indices", "x2_values", op_name);
|
||||
}
|
||||
// Check same type
|
||||
// value
|
||||
CheckSparseAddSameDtype(a_value_type, b_value_type, "x1_values", "x2_values", op_name);
|
||||
|
|
Loading…
Reference in New Issue