diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index dfef34570d7..bb2e23858f7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -50,7 +50,7 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector *inp } std::string dtype = input_json[kJDtype]; size_t nbyte = tbe::GetDtypeNbyte(dtype); - (*size_i) *= nbyte; + (*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte); input_size_list->push_back((*size_i)); } @@ -78,7 +78,7 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector *o if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) { auto output_max_shape = output_json[kJRange]; for (auto &max_shape : output_max_shape) { - (*size_i) *= LongToSize(max_shape[1]); + (*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[1])); } MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape."; } else { @@ -89,15 +89,15 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector *o MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape"; } MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j][1]; - (*size_i) *= LongToSize(output_max_shape[j][1]); + (*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(output_max_shape[j][1])); continue; } - (*size_i) *= static_cast(output_json[kJShape][j]); + (*size_i) = SizetMulWithOverflowCheck(*size_i, static_cast(output_json[kJShape][j])); } } std::string dtype = output_json[kJDtype]; size_t nbyte = tbe::GetDtypeNbyte(dtype); - (*size_i) *= nbyte; + (*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte); output_size_list->push_back((*size_i)); } @@ -147,11 +147,11 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector(shape_item); + ret = SizetMulWithOverflowCheck(ret, static_cast(shape_item)); } std::string data_type = desc[kJDataType]; size_t nbyte = tbe::GetDtypeNbyte(data_type); - ret *= nbyte; + ret = SizetMulWithOverflowCheck(ret, nbyte); return ret; } diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h index 3be6edbc351..2625a914b0f 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h @@ -56,12 +56,13 @@ class Var : public Base { explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { EnsureTag(); } - Var(const Var &other) : Base(other), tag_(other.tag_) {} + Var(const Var &other) : Base(other), tag_(other.tag_), primitive_(other.primitive_) {} virtual Var &operator=(const Var &other) { if (&other == this) { return *this; } this->tag_ = other.tag_; + this->primitive_ = other.primitive_; return *this; } ~Var() override = default;