!48563 bugfix for SparseMatrixMatMul and Conj

Merge pull request !48563 from 黄勇/bugfix_sparsematrixmatmul_and_conj
This commit is contained in:
i-robot 2023-02-09 06:57:51 +00:00 committed by Gitee
commit 7faee6de63
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 9 additions and 1 deletions

View File

@ -197,7 +197,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpCpuFuncCreator>>>
bool UnaryOpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
std::vector<int64_t> input_shape = inputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
if (inputs.empty() || outputs.empty() || is_null_input_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}

View File

@ -49,6 +49,7 @@ class UnaryOpCpuKernelMod : public NativeCpuKernelMod {
private:
std::shared_ptr<CpuKernelFunc> func_obj_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -48,6 +48,11 @@ void SparseMatrixMatMulCheckShape(const std::vector<AbstractBasePtr> &input_args
if (!is_dynamic) {
const int kInputNoBatch = 2;
const int kInputWithBatch = 3;
auto x1_dense_shape_size = x1_dense_shape.size();
if (x1_dense_shape_size == 0) {
MS_EXCEPTION(ValueError) << "For SparseMatrixMatMul, x1_dense_shape.size() = " << x1_dense_shape_size
<< ", which is invalid";
}
const int64_t rank_x1 = x1_dense_shape[0];
const int64_t rank_x2 = (SizeToLong)(x2_dense_shape.size());
if (rank_x1 != rank_x2) {