code check

This commit is contained in:
hwjiaorui 2022-01-19 15:35:35 +08:00
parent 496cb8fadf
commit 3387b44317
2 changed files with 9 additions and 8 deletions

View File

@ -50,7 +50,7 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
}
std::string dtype = input_json[kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype);
(*size_i) *= nbyte;
(*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte);
input_size_list->push_back((*size_i));
}
@ -78,7 +78,7 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
auto output_max_shape = output_json[kJRange];
for (auto &max_shape : output_max_shape) {
(*size_i) *= LongToSize(max_shape[1]);
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(max_shape[1]));
}
MS_LOG(INFO) << "Dims is dynamic, change -2 Shape to Max Shape.";
} else {
@ -89,15 +89,15 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
}
MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j][1];
(*size_i) *= LongToSize(output_max_shape[j][1]);
(*size_i) = SizetMulWithOverflowCheck(*size_i, LongToSize(output_max_shape[j][1]));
continue;
}
(*size_i) *= static_cast<size_t>(output_json[kJShape][j]);
(*size_i) = SizetMulWithOverflowCheck(*size_i, static_cast<size_t>(output_json[kJShape][j]));
}
}
std::string dtype = output_json[kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype);
(*size_i) *= nbyte;
(*size_i) = SizetMulWithOverflowCheck(*size_i, nbyte);
output_size_list->push_back((*size_i));
}
@ -147,11 +147,11 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
size_t ret = 1;
for (const auto &shape_item : desc[kJShape]) {
ret *= static_cast<size_t>(shape_item);
ret = SizetMulWithOverflowCheck(ret, static_cast<size_t>(shape_item));
}
std::string data_type = desc[kJDataType];
size_t nbyte = tbe::GetDtypeNbyte(data_type);
ret *= nbyte;
ret = SizetMulWithOverflowCheck(ret, nbyte);
return ret;
}

View File

@ -56,12 +56,13 @@ class Var : public Base {
explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
EnsureTag();
}
Var(const Var &other) : Base(other), tag_(other.tag_) {}
Var(const Var &other) : Base(other), tag_(other.tag_), primitive_(other.primitive_) {}
virtual Var &operator=(const Var &other) {
if (&other == this) {
return *this;
}
this->tag_ = other.tag_;
this->primitive_ = other.primitive_;
return *this;
}
~Var() override = default;