!13727 remove parameters of fuction in CheckAndConvertUtils

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @ginfung
This commit is contained in:
mindspore-ci-bot 2021-03-22 19:45:52 +08:00 committed by Gitee
commit b1043bcf55
10 changed files with 23 additions and 43 deletions

View File

@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -36,8 +36,8 @@ PadMode AvgPool::get_pad_mode() const {
return PadMode(GetValue<int64_t>(value_ptr));
}
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
false, true)));
this->AddAttr(kKernelSize,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}
std::vector<int64_t> AvgPool::get_kernel_size() const {
@ -45,8 +45,7 @@ std::vector<int64_t> AvgPool::get_kernel_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}
std::vector<int64_t> AvgPool::get_strides() const {

View File

@ -93,8 +93,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
w_out = floor(w_out);
}
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
primitive->AddAttr(kPadList,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
if (format == NHWC) {
out_shape = {x_shape[0], h_out, w_out, out_channel};
@ -144,11 +143,11 @@ void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
}
void Conv2D::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
}
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
}
void Conv2D::set_pad_mode(const PadMode &pad_mode) {
@ -166,7 +165,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
void Conv2D::set_mode(int64_t mode) {

View File

@ -111,7 +111,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
void Conv2dTranspose::set_mode(int64_t mode) {

View File

@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
if (strides[0] != strides[1]) {
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
<< ", width " << strides[1];
}
this->set_stride(strides);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
if (dilations[0] != dilations[1]) {
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
<< ", width " << dilations[1];
@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
}
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
this->set_out_channel(
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));

View File

@ -30,13 +30,13 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
if (strides[0] != strides[1]) {
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
<< ", width " << strides[1];
}
this->set_stride(strides);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
if (dilations[0] != dilations[1]) {
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
<< ", width " << dilations[1];
@ -52,7 +52,7 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
}
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
this->set_out_channel(
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));

View File

@ -105,11 +105,11 @@ void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_siz
}
void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
}
void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
}
void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
@ -127,7 +127,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
void Conv2DBackpropInput::set_mode(int64_t mode) {

View File

@ -36,8 +36,8 @@ PadMode MaxPool::get_pad_mode() const {
return PadMode(GetValue<int64_t>(value_ptr));
}
void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
false, true)));
this->AddAttr(kKernelSize,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}
std::vector<int64_t> MaxPool::get_kernel_size() const {
@ -45,8 +45,7 @@ std::vector<int64_t> MaxPool::get_kernel_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
void MaxPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}
std::vector<int64_t> MaxPool::get_strides() const {

View File

@ -330,24 +330,10 @@ bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, cons
std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
const std::vector<int64_t> &arg_value,
const std::string &prim_name, bool allow_four,
bool ret_four) {
auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void {
std::ostringstream buffer;
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
if (allow_four) {
buffer << "or four ";
}
buffer << " positive int64_t numbers , but got [";
for (auto item : arg_value) {
buffer << item << ",";
}
buffer << "]";
MS_EXCEPTION(ValueError) << buffer.str();
};
const std::string &prim_name) {
for (auto item : arg_value) {
if (item < 0) {
raise_message();
MS_EXCEPTION(ValueError) << "For " << prim_name << " attr " << arg_name << " should be a positive vector";
}
}
return arg_value;

View File

@ -162,8 +162,7 @@ const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeT
class CheckAndConvertUtils {
public:
static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value,
const std::string &prim_name, bool allow_four = false,
bool ret_four = false);
const std::string &prim_name);
static std::string CheckString(const std::string &arg_name, const std::string &arg_value,
const std::set<std::string> &check_list, const std::string &prim_name);