forked from mindspore-Ecosystem/mindspore
!13727 remove parameters of fuction in CheckAndConvertUtils
From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @ginfung
This commit is contained in:
commit
b1043bcf55
|
@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -36,8 +36,8 @@ PadMode AvgPool::get_pad_mode() const {
|
|||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
|
||||
false, true)));
|
||||
this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_kernel_size() const {
|
||||
|
@ -45,8 +45,7 @@ std::vector<int64_t> AvgPool::get_kernel_size() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
this->AddAttr(kStrides,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
|
||||
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_strides() const {
|
||||
|
|
|
@ -93,8 +93,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
w_out = floor(w_out);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
|
||||
primitive->AddAttr(kPadList,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
|
||||
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
|
||||
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
||||
if (format == NHWC) {
|
||||
out_shape = {x_shape[0], h_out, w_out, out_channel};
|
||||
|
@ -144,11 +143,11 @@ void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
|||
}
|
||||
|
||||
void Conv2D::set_stride(const std::vector<int64_t> &stride) {
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
||||
|
@ -166,7 +165,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_mode(int64_t mode) {
|
||||
|
|
|
@ -111,7 +111,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_mode(int64_t mode) {
|
||||
|
|
|
@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
|
|||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
||||
|
||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
|
||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
|
||||
if (strides[0] != strides[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
||||
<< ", width " << strides[1];
|
||||
}
|
||||
this->set_stride(strides);
|
||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
|
||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
|
||||
if (dilations[0] != dilations[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
||||
<< ", width " << dilations[1];
|
||||
|
@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
|
|||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
||||
}
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
|
||||
|
||||
this->set_out_channel(
|
||||
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
||||
|
|
|
@ -30,13 +30,13 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
|
|||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
||||
|
||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
|
||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
|
||||
if (strides[0] != strides[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
||||
<< ", width " << strides[1];
|
||||
}
|
||||
this->set_stride(strides);
|
||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
|
||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
|
||||
if (dilations[0] != dilations[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
||||
<< ", width " << dilations[1];
|
||||
|
@ -52,7 +52,7 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
|
|||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
||||
}
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
|
||||
|
||||
this->set_out_channel(
|
||||
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
||||
|
|
|
@ -105,11 +105,11 @@ void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_siz
|
|||
}
|
||||
|
||||
void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
|
||||
}
|
||||
|
||||
void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
|
||||
}
|
||||
|
||||
void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
|
||||
|
@ -127,7 +127,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2DBackpropInput::set_mode(int64_t mode) {
|
||||
|
|
|
@ -36,8 +36,8 @@ PadMode MaxPool::get_pad_mode() const {
|
|||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
|
||||
false, true)));
|
||||
this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> MaxPool::get_kernel_size() const {
|
||||
|
@ -45,8 +45,7 @@ std::vector<int64_t> MaxPool::get_kernel_size() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
void MaxPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
this->AddAttr(kStrides,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
|
||||
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> MaxPool::get_strides() const {
|
||||
|
|
|
@ -330,24 +330,10 @@ bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, cons
|
|||
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
|
||||
const std::vector<int64_t> &arg_value,
|
||||
const std::string &prim_name, bool allow_four,
|
||||
bool ret_four) {
|
||||
auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void {
|
||||
std::ostringstream buffer;
|
||||
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
|
||||
if (allow_four) {
|
||||
buffer << "or four ";
|
||||
}
|
||||
buffer << " positive int64_t numbers , but got [";
|
||||
for (auto item : arg_value) {
|
||||
buffer << item << ",";
|
||||
}
|
||||
buffer << "]";
|
||||
MS_EXCEPTION(ValueError) << buffer.str();
|
||||
};
|
||||
const std::string &prim_name) {
|
||||
for (auto item : arg_value) {
|
||||
if (item < 0) {
|
||||
raise_message();
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << " attr " << arg_name << " should be a positive vector";
|
||||
}
|
||||
}
|
||||
return arg_value;
|
||||
|
|
|
@ -162,8 +162,7 @@ const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeT
|
|||
class CheckAndConvertUtils {
|
||||
public:
|
||||
static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value,
|
||||
const std::string &prim_name, bool allow_four = false,
|
||||
bool ret_four = false);
|
||||
const std::string &prim_name);
|
||||
static std::string CheckString(const std::string &arg_name, const std::string &arg_value,
|
||||
const std::set<std::string> &check_list, const std::string &prim_name);
|
||||
|
||||
|
|
Loading…
Reference in New Issue