From 26df6fc50bb4b65641ccb3f07cef3395a19b2da0 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Tue, 22 Dec 2020 17:37:36 +0800 Subject: [PATCH] add typecheck for in_channels&out_channels of Conv2dBnFoldQuant --- mindspore/nn/layer/quant.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 608d6a24e60..9fbbfa9080f 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -407,8 +407,8 @@ class Conv2dBnFoldQuantOneConv(Cell): quant_dtype=QuantDtype.INT8): """Initialize Conv2dBnFoldQuant layer""" super(Conv2dBnFoldQuantOneConv, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.kernel_size = twice(kernel_size) self.stride = twice(stride) self.pad_mode = pad_mode @@ -626,8 +626,8 @@ class Conv2dBnFoldQuant(Cell): freeze_bn=100000): """Initialize Conv2dBnFoldQuant layer""" super(Conv2dBnFoldQuant, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.kernel_size = twice(kernel_size) self.stride = twice(stride) self.pad_mode = pad_mode