code check

This commit is contained in:
hwjiaorui 2022-01-19 15:35:35 +08:00
parent 496cb8fadf
commit 3387b44317
2 changed files with 9 additions and 8 deletions

View File

@ -50,7 +50,7 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
} }
std::string dtype = input_json[kJDtype]; std::string dtype = input_json[kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype); size_t nbyte = tbe::GetDtypeNbyte(dtype);
(*size_i) *= nbyte; (*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte);
input_size_list->push_back((*size_i)); input_size_list->push_back((*size_i));
} }
@ -78,7 +78,7 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) { if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
auto output_max_shape = output_json[kJRange]; auto output_max_shape = output_json[kJRange];
for (auto &max_shape : output_max_shape) { for (auto &max_shape : output_max_shape) {
(*size_i) *= LongToSize(max_shape[1]); (*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[1]));
} }
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape."; MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
} else { } else {
@ -89,15 +89,15 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape"; MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
} }
MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j][1]; MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j][1];
(*size_i) *= LongToSize(output_max_shape[j][1]); (*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(output_max_shape[j][1]));
continue; continue;
} }
(*size_i) *= static_cast<size_t>(output_json[kJShape][j]); (*size_i) = SizetMulWithOverflowCheck(*size_i, static_cast<size_t>(output_json[kJShape][j]));
} }
} }
std::string dtype = output_json[kJDtype]; std::string dtype = output_json[kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype); size_t nbyte = tbe::GetDtypeNbyte(dtype);
(*size_i) *= nbyte; (*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte);
output_size_list->push_back((*size_i)); output_size_list->push_back((*size_i));
} }
@ -147,11 +147,11 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) { size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
size_t ret = 1; size_t ret = 1;
for (const auto &shape_item : desc[kJShape]) { for (const auto &shape_item : desc[kJShape]) {
ret *= static_cast<size_t>(shape_item); ret = SizetMulWithOverflowCheck(ret, static_cast<size_t>(shape_item));
} }
std::string data_type = desc[kJDataType]; std::string data_type = desc[kJDataType];
size_t nbyte = tbe::GetDtypeNbyte(data_type); size_t nbyte = tbe::GetDtypeNbyte(data_type);
ret *= nbyte; ret = SizetMulWithOverflowCheck(ret, nbyte);
return ret; return ret;
} }

View File

@ -56,12 +56,13 @@ class Var : public Base {
explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
EnsureTag(); EnsureTag();
} }
Var(const Var &other) : Base(other), tag_(other.tag_) {} Var(const Var &other) : Base(other), tag_(other.tag_), primitive_(other.primitive_) {}
virtual Var &operator=(const Var &other) { virtual Var &operator=(const Var &other) {
if (&other == this) { if (&other == this) {
return *this; return *this;
} }
this->tag_ = other.tag_; this->tag_ = other.tag_;
this->primitive_ = other.primitive_;
return *this; return *this;
} }
~Var() override = default; ~Var() override = default;