From 7541d3b067e67a79df3674638cbff9176853978b Mon Sep 17 00:00:00 2001 From: buxue Date: Thu, 2 Apr 2020 11:58:45 +0800 Subject: [PATCH] Develop op MaxPoolWithArgMax --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 44 --- mindspore/ccsrc/transform/convert.cc | 2 + mindspore/ccsrc/transform/op_declare.cc | 18 +- mindspore/ccsrc/transform/op_declare.h | 4 +- mindspore/model_zoo/resnet.py | 2 +- mindspore/nn/layer/pooling.py | 145 ++++---- mindspore/ops/_grad/grad_nn_ops.py | 11 +- .../_op_impl/tbe/max_pool_grad_with_argmax.py | 6 +- .../ops/_op_impl/tbe/max_pool_with_argmax.py | 6 +- mindspore/ops/operations/_grad_ops.py | 119 +++--- mindspore/ops/operations/nn_ops.py | 342 +++++++++--------- tests/perf_test/resnet_example.py | 2 +- tests/st/networks/test_cpu_lenet.py | 6 +- tests/st/networks/test_gpu_alexnet.py | 2 +- .../ops/davinci/test_maxpool_with_argmax.py | 15 +- .../davinci/test_maxpool_with_argmax_grad.py | 6 +- tests/st/tbe_networks/resnet.py | 2 +- .../gtest_input/pre_activate/hw_opt_test.py | 2 +- .../pre_activate/insert_trans_op_test.py | 2 +- .../pre_activate/mixed_precision_test.py | 2 +- .../pre_activate/transdata_split_test.py | 2 +- .../transpose_transdata_fusion_test.py | 2 +- .../gtest_input/session/session_test.py | 2 +- .../test_data_parallel_resnet.py | 4 +- tests/ut/python/exec/resnet_example.py | 2 +- tests/ut/python/exec/test_pooling.py | 6 +- tests/ut/python/model/res18_example.py | 4 +- tests/ut/python/nn/test_cell.py | 2 +- tests/ut/python/nn/test_pooling.py | 3 +- tests/ut/python/ops/test_nn_ops.py | 105 +++--- tests/ut/python/ops/test_ops.py | 2 +- tests/ut/python/ops/test_ops_check.py | 8 +- .../pynative_mode/ge/ops/test_pooling.py | 4 +- tests/ut/python/pynative_mode/nn/test_cell.py | 2 +- .../python/pynative_mode/nn/test_pooling.py | 11 +- tests/ut/python/pynative_mode/vm/test_vm.py | 8 +- tests/ut/python/utils/test_serialize.py | 2 +- tests/vm_impl/nn_ops_vm_impl.py | 90 ++++- tests/vm_impl/vm_me.py | 46 ++- 39 files changed, 537 insertions(+), 506 deletions(-) mode change 100755 => 100644 mindspore/ccsrc/transform/op_declare.cc diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index c0416f648b1..229a3eb34a7 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -148,8 +148,6 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector TbeAdapter::build_json_attr_pass_map_ = { - {"MaxPoolWithArgmax", TbeAdapter::MaxPoolWithArgmaxAttrJsonPass}, - {"MaxPoolGradWithArgmax", TbeAdapter::MaxPoolGradWithArgmaxAttrJsonPass}, {"Conv2D", TbeAdapter::Conv2DAttrJsonPass}, {"Conv2DBackpropFilter", TbeAdapter::Conv2DBackpropFilterAttrJsonPass}, {"Conv2DBackpropInput", TbeAdapter::Conv2DBackpropInputAttrJsonPass}, @@ -170,48 +168,6 @@ bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, return false; } -void TbeAdapter::MaxPoolWithArgmaxAttrJsonPass( - const mindspore::AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attr_num = op_info_attrs.size(); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - for (size_t i = 0; i < attr_num; i++) { - nlohmann::json attr_obj; - MS_EXCEPTION_IF_NULL(op_info_attrs[i]); - std::string attr_name = op_info_attrs[i]->name(); - if (primitive->GetAttr(attr_name) != nullptr) { - auto value = primitive->GetAttr(attr_name); - if (attr_name == "pad_mode") { - std::string attr_value = GetValue(value); - (void)transform(attr_value.begin(), attr_value.end(), attr_value.begin(), ::toupper); - attr_obj["value"] = attr_value; - } else { - std::vector attr_value; - int data = GetValue(value); - attr_value.push_back(1); - attr_value.push_back(data); - attr_value.push_back(data); - attr_value.push_back(1); - attr_obj["value"] = attr_value; - } - attr_obj["valid"] = true; - } else { - attr_obj["valid"] = false; - } - attr_obj["name"] = attr_name; - attrs_json->push_back(attr_obj); - } -} - -void TbeAdapter::MaxPoolGradWithArgmaxAttrJsonPass( - const mindspore::AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MaxPoolWithArgmaxAttrJsonPass(anf_node, op_info_attrs, attrs_json); -} - void TbeAdapter::Conv2DAttrJsonPass(const mindspore::AnfNodePtr &anf_node, const std::vector> &op_info_attrs, nlohmann::json *attrs_json) { diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index c400d1c5733..bee460d84ac 100755 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -161,6 +161,7 @@ const char kNameTopK[] = "TopK"; const char kNameSoftmaxGrad[] = "SoftmaxGrad"; const char kNameMaxPool[] = "MaxPool"; const char kNameAvgPool[] = "AvgPool"; +const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax"; const char kNameBatchNorm[] = "BatchNorm"; const char kNameBatchNormGrad[] = "BatchNormGrad"; const char kNameROIAlign[] = "ROIAlign"; @@ -198,6 +199,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)}, {string(kNameMaxPool), ADPT_DESC(MaxPool)}, {string(kNameAvgPool), ADPT_DESC(AvgPool)}, + {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, {string(kNameTopK), ADPT_DESC(TopKV2)}, {string(kNamePack), ADPT_DESC(Pack)}, {string(kNameSplitD), ADPT_DESC(SplitD)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc old mode 100755 new mode 100644 index 0af2923cc4a..419805c37fd --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -734,14 +734,22 @@ ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits< OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}}; // MaxPoolWithArgmax +INPUT_MAP(MaxPoolWithArgmax) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}}; + +// MaxPoolGradWithArgmax INPUT_MAP(MaxPoolGradWithArgmax) = { {1, INPUT_DESC(x)}, - {2, INPUT_DESC(argmax)}, - {3, INPUT_DESC(grad)}, + {2, INPUT_DESC(grad)}, + {3, INPUT_DESC(argmax)}, }; -ATTR_MAP(MaxPoolGradWithArgmax) = {{"pad_mode", ATTR_DESC(padding, AnyTraits())}, - {"window", ATTR_DESC(ksize, "window", AnyTraits>())}, - {"stride", ATTR_DESC(strides, "stride", AnyTraits>())}}; +ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; // Conv2D INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index d120c949892..e4d41011271 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -88,8 +88,10 @@ DECLARE_OP_ADAPTER(FusedBatchNormGrad) DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad) DECLARE_OP_ADAPTER(BiasAddGrad) DECLARE_OP_USE_OUTPUT(BiasAddGrad) +DECLARE_OP_ADAPTER(MaxPoolWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax) DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax) -DECLARE_OP_USE_ENUM(MaxPoolGradWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax) DECLARE_OP_ADAPTER(Conv2D) DECLARE_OP_USE_ENUM(Conv2D) DECLARE_OP_USE_OUTPUT(Conv2D) diff --git a/mindspore/model_zoo/resnet.py b/mindspore/model_zoo/resnet.py index 403f66e4150..9d010eede1f 100755 --- a/mindspore/model_zoo/resnet.py +++ b/mindspore/model_zoo/resnet.py @@ -168,7 +168,7 @@ class ResNet(nn.Cell): self.conv1 = _conv7x7(3, 64, stride=2) self.bn1 = _bn(64) self.relu = P.ReLU() - self.maxpool = P.MaxPoolWithArgmax(pad_mode='same', window=3, stride=2) + self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) self.layer1 = self._make_layer(block, layer_nums[0], diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 6ff28dd3629..bf90fcc9de9 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -13,36 +13,52 @@ # limitations under the License. # ============================================================================ """pooling""" - from mindspore.ops import operations as P from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Rel +from ... import context from ..cell import Cell class _PoolNd(Cell): """N-D AvgPool""" - def __init__(self, - kernel_size, - stride, - pad_mode, - padding=0, - pool=None): + def __init__(self, kernel_size, stride, pad_mode): + name = self.__class__.__name__ super(_PoolNd, self).__init__() - self.kernel_size = kernel_size - self.stride = stride - self.pad_mode = pad_mode - self.padding = validator.check_integer('padding', padding, 0, Rel.GE) - self.pool = pool - if self.pool is None: - raise NotImplementedError + validator.check_type('kernel_size', kernel_size, [int, tuple]) + validator.check_type('stride', stride, [int, tuple]) + self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) - def construct(self, x): - return self.pool(x) + if isinstance(kernel_size, int): + validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) + else: + if (len(kernel_size) != 2 or + (not isinstance(kernel_size[0], int)) or + (not isinstance(kernel_size[1], int)) or + kernel_size[0] <= 0 or + kernel_size[1] <= 0): + raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' + f'a tuple of two positive int numbers, but got {kernel_size}') + self.kernel_size = kernel_size + + if isinstance(stride, int): + validator.check_integer("stride", stride, 1, Rel.GE) + else: + if (len(stride) != 2 or + (not isinstance(stride[0], int)) or + (not isinstance(stride[1], int)) or + stride[0] <= 0 or + stride[1] <= 0): + raise ValueError(f'The stride passed to cell {name} should be an positive int number or' + f'a tuple of two positive int numbers, but got {stride}') + self.stride = stride + + def construct(self, *inputs): + pass def extend_repr(self): - return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__) + return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__) class MaxPool2d(_PoolNd): @@ -63,19 +79,23 @@ class MaxPool2d(_PoolNd): pad_mode for training only supports "same" and "valid". Args: - kernel_size (int): Size of the window to take a max over. Default 1. - stride (int): Stride size of the window. Default: 1. - pad_mode (str): Select the mode of the pad. The optional values are - "same" and "valid". Default: "valid". + kernel_size (Union[int, tuple[int]]): The size of kernel used to take the max value, + is an int number that represents height and width are both kernel_size, + or a tuple of two int numbers that represent height and width respectively. + Default: 1. + stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". - same: Adopts the way of completion. Output height and width will be the same as the input. Total number of padding will be calculated for horizontal and vertical - direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the - last extra padding will be done from the bottom and the right side. + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. - - valid: Adopts the way of discarding. The possibly largest height and width of output will be return - without padding. Extra pixels will be discarded. - padding (int): Implicit zero padding to be added on both sides. Default: 0. + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -103,31 +123,22 @@ class MaxPool2d(_PoolNd): [[7. 8.] [8. 8.]]]] """ - def __init__(self, - kernel_size=1, - stride=1, - pad_mode="VALID", - padding=0): - max_pool = P.MaxPool(ksize=kernel_size, - strides=stride, - padding=pad_mode) - self.is_autodiff_backend = False - if self.is_autodiff_backend: - # At present, pad mode of max pool is not unified, so it is a temporarily avoided - pad_mode = validator.check_string('pad_mode', pad_mode.lower(), ['valid', 'same']) - - max_pool = P.MaxPoolWithArgmax(window=kernel_size, - stride=stride, - pad_mode=pad_mode, - pad=padding) - super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, padding, max_pool) + def __init__(self, kernel_size=1, stride=1, pad_mode="valid"): + super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode) + self.max_pool = P.MaxPool(ksize=self.kernel_size, + strides=self.stride, + padding=self.pad_mode) + self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size, + strides=self.stride, + padding=self.pad_mode) + self.is_tbe = context.get_context("device_target") == "Ascend" def construct(self, x): - if self.is_autodiff_backend: - out = self.pool(x)[0] + if self.is_tbe and self.training: + out = self.max_pool_with_arg_max(x)[0] else: - out = self.pool(x) + out = self.max_pool(x) return out @@ -149,19 +160,24 @@ class AvgPool2d(_PoolNd): pad_mode for training only supports "same" and "valid". Args: - kernel_size (int): Size of the window to take a max over. Default: 1. - stride (int): Stride size of the window. Default: 1. - pad_mode (str): Select the mode of the pad. The optional values are - "same", "valid". Default: "valid". + kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value, + is an int number that represents height and width are both kernel_size, + or a tuple of two int numbers that represent height and width respectively. + Default: 1. + stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". - same: Adopts the way of completion. Output height and width will be the same as the input. Total number of padding will be calculated for horizontal and vertical - direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the - last extra padding will be done from the bottom and the right side. + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. - - valid: Adopts the way of discarding. The possibly largest height and width of output will be return - without padding. Extra pixels will be discarded. - padding (int): Implicit zero padding to be added on both sides. Default: 0. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -170,7 +186,7 @@ class AvgPool2d(_PoolNd): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> pool = AvgPool2d(kernel_size=3, stride=1) + >>> pool = AvgPool2d(kernel_size=3, strides=1) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) [[[[5. 5. 9. 9.] [8. 4. 3. 0.] @@ -189,12 +205,15 @@ class AvgPool2d(_PoolNd): [[4.2222223 4.5555553] [3.2222223 4.5555553]]]] """ + def __init__(self, kernel_size=1, stride=1, - pad_mode="VALID", - padding=0): - avg_pool = P.AvgPool(ksize=kernel_size, - strides=stride, - padding=pad_mode) - super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, padding, avg_pool) + pad_mode="valid"): + super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode) + self.avg_pool = P.AvgPool(ksize=self.kernel_size, + strides=self.stride, + padding=self.pad_mode) + + def construct(self, x): + return self.avg_pool(x) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index bad99351a50..fbe48aff973 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -76,14 +76,9 @@ def get_bprop_depthwise_conv2d_native(self): def get_bprop_max_pool_with_argmax(self): """Grad definition for `MaxPoolWithArgmax` operation.""" maxpool_grad = G.MaxPoolGradWithArgmax( - pad_mode=self.pad_mode, - window=self.window, - pad=self.pad, - stride=self.stride, - data_mode=self.data_mode, - ceil_mode=self.ceil_mode, - alpha=self.alpha, - beta=self.beta) + ksize=self.ksize, + strides=self.strides, + padding=self.padding,) def bprop(x, out, dout): dx = maxpool_grad(x, dout[0], out[1]) diff --git a/mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py b/mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py index a167ef85f8c..3730ee1b93e 100644 --- a/mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +++ b/mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py @@ -28,19 +28,19 @@ from mindspore.ops.op_info_register import op_info_register "partial_flag": true, "attr": [ { - "name": "window", + "name": "ksize", "param_type": "required", "type": "listInt", "value": "all" }, { - "name": "stride", + "name": "strides", "param_type": "required", "type": "listInt", "value": "all" }, { - "name": "pad_mode", + "name": "padding", "param_type": "required", "type": "str", "value": "all" diff --git a/mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py b/mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py index 04d0eeb92c1..2e081c10824 100644 --- a/mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +++ b/mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py @@ -28,19 +28,19 @@ from mindspore.ops.op_info_register import op_info_register "partial_flag": true, "attr": [ { - "name": "window", + "name": "ksize", "param_type": "required", "type": "listInt", "value": "all" }, { - "name": "stride", + "name": "strides", "param_type": "required", "type": "listInt", "value": "all" }, { - "name": "pad_mode", + "name": "padding", "param_type": "required", "type": "str", "value": "all" diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index a699c23adc0..f38044ab6ac 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -15,7 +15,6 @@ """Operators for gradients.""" -import math from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -340,59 +339,60 @@ class _PoolGrad(PrimitiveWithInfer): """Gradients of the max/avg pool operation.""" @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): + def __init__(self, ksize, strides, padding="VALID"): self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) - self.ksize = ksize - self.strides = strides - self.padding = padding - self.ksize = validator.check_type('ksize', self.ksize, [int, tuple]) - self.strides = validator.check_type('strides', self.strides, [int, tuple]) - - validator.check_type('padding', self.padding, [str]) - self.padding = validator.check_string('padding', self.padding, ['VALID', 'SAME']) + validator.check_type('ksize', ksize, [int, tuple]) + validator.check_type('strides', strides, [int, tuple]) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) self.add_prim_attr("padding", self.padding) - self.add_prim_attr('data_format', "NCHW") + self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") + if not self.is_maxpoolgradwithargmax: + self.add_prim_attr('data_format', "NCHW") - if isinstance(self.ksize, int): - self.pool_h = validator.check_integer("ksize", self.ksize, 1, Rel.GE) - self.pool_w = self.pool_h - self.add_prim_attr("ksize", (1, 1, self.ksize, self.ksize)) - elif isinstance(self.ksize, tuple): - if (len(self.ksize) != 2 and len(self.ksize) != 4): - raise ValueError('Attr \'ksize\' of \'Pool\' Op passed ' + - str(self.ksize)+', should be a int or a tuple of length 2 or 4.') - for ksize_val in self.ksize: - if (not isinstance(ksize_val, int)) or (ksize_val <= 0): - raise ValueError('Each value of attr \'ksize\' of \'MaxPool\' Op passed ' + - str(self.ksize)+', should be int and greater than 0.') - self.pool_h = self.ksize[-2] - self.pool_w = self.ksize[-1] - self.add_prim_attr("ksize", (1, 1, self.ksize[-2], self.ksize[-1])) - - if isinstance(self.strides, int): - self.stride_h = validator.check_integer("strides", self.strides, 1, Rel.GE) - self.stride_w = self.stride_h - self.add_prim_attr("strides", (1, 1, self.strides, self.strides)) - elif isinstance(self.strides, tuple): - if (len(self.strides) != 2 and len(self.strides) != 4): - raise ValueError('Attr \'strides\' of \'MaxPool\' Op passed ' + - str(self.strides)+', should be a int or a tuple of length 2 or 4.') - for stride_val in self.strides: - if (not isinstance(stride_val, int)) or (stride_val <= 0): - raise ValueError('Each value of attr \'strides\' of \'MaxPool\' Op passed ' + - str(self.strides)+', should be int and greater than 0.') - self.stride_h = self.strides[-2] - self.stride_w = self.strides[-1] - self.add_prim_attr("strides", (1, 1, self.strides[-2], self.strides[-1])) - - if self.padding == "VALID": - self.pad = 0 - elif self.padding == "SAME": - self.pad = math.floor((self.pool_h - 1) / 2) + if isinstance(ksize, int): + validator.check_integer("ksize", ksize, 1, Rel.GE) + if self.is_maxpoolgradwithargmax: + self.ksize = (1, ksize, ksize, 1) + else: + self.ksize = (1, 1, ksize, ksize) else: - raise ValueError('The padding should be str and must be SAME or VALID,' - ' but got {}.'.format(self.padding)) + ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number" + f"or a tuple of two or four positive int numbers, but got {ksize}") + if len(ksize) != 2 and len(ksize) != 4: + raise ksize_error + for ksize_val in ksize: + if not isinstance(ksize_val, int) or (ksize_val <= 0): + raise ksize_error + if len(ksize) == 2 and self.is_maxpoolgradwithargmax: + self.ksize = (1, ksize[0], ksize[1], 1) + elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax: + self.ksize = (1, 1, ksize[0], ksize[1]) + else: + self.ksize = ksize + self.add_prim_attr("ksize", self.ksize) + + if isinstance(strides, int): + validator.check_integer("strides", strides, 1, Rel.GE) + if self.is_maxpoolgradwithargmax: + self.strides = (1, strides, strides, 1) + else: + self.strides = (1, 1, strides, strides) + else: + strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number" + f"or a tuple of two or four positive int numbers, but got {strides}") + if len(strides) != 2 and len(strides) != 4: + raise strides_error + for strides_val in strides: + if not isinstance(strides_val, int) or (strides_val <= 0): + raise strides_error + if len(strides) == 2 and self.is_maxpoolgradwithargmax: + self.strides = (1, strides[0], strides[1], 1) + elif len(strides) == 2 and not self.is_maxpoolgradwithargmax: + self.strides = (1, 1, strides[0], strides[1]) + else: + self.strides = strides + self.add_prim_attr("strides", self.strides) class AvgPoolGrad(_PoolGrad): @@ -451,28 +451,13 @@ class MaximumGrad(Primitive): raise NotImplementedError -class MaxPoolGradWithArgmax(PrimitiveWithInfer): +class MaxPoolGradWithArgmax(_PoolGrad): """Computes the gradients of MaxPoolWithArgmax.""" @prim_attr_register - def __init__(self, - pad_mode="valid", - window=0, - pad=0, - stride=1, - data_mode=1, - ceil_mode=0, - alpha=1.0, - beta=0.0): + def __init__(self, ksize=1, strides=1, padding="VALID",): self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output']) - - self.window = window - self.pool_h = self.pool_w = window - self.pad = pad - self.pad_mode = pad_mode - self.stride = stride - self.data_mode = data_mode - self.ceil_mode = ceil_mode + super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding) def infer_shape(self, x_shape, grad_shape, argmax_shape): if not grad_shape: diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 21effd4bd32..9ee98d174e8 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -682,186 +682,83 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): return x_dtype -class MaxPoolWithArgmax(PrimitiveWithInfer): - r""" - Performs max pooling on the input Tensor and return both max values and indices. - - Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs - regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size - :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. - - .. math:: - \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} - \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) - - Args: - pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid". - window (Union[int, tuple[int]]): The size of window, which is the kernel size, two `int` for width - and height. Default: 1. - pad (Union[int, tuple[int]]): If `pad_mode` is `pad`, the pad value to fill, two `int` for width - and height. Default: 0. - stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for - width and height. Default: 1. - - Inputs: - - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tuple of 2 Tensor, the maxpool result and where max values from. - - - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - - **mask** (Tensor) - Max values' index represented by the mask. - """ - - @prim_attr_register - def __init__(self, - pad_mode="valid", - window=1, - pad=0, - stride=1, - data_mode=1, - ceil_mode=0, - alpha=1.0, - beta=0.0): - self.init_prim_io_names(inputs=['x'], outputs=['output', 'argmax']) - self.window = validator.check_type('window', window, [int, tuple]) - if isinstance(window, int) and window <= 0: - raise ValueError('Attr \'window\' of \'MaxPoolWithArgmax\' Op passed ' - + str(self.window)+', should be a int or tuple and greater than 0.') - if isinstance(window, tuple) and (len(window) != 2 or - (not isinstance(window[0], int)) or - (not isinstance(window[1], int)) or - window[0] <= 0 or window[1] <= 0): - raise ValueError('Attr \'window\' of \'MaxPoolWithArgmax\' Op passed ' - + str(self.window)+', should be a int or tuple and greater than 0.') - self.pool_h = self.pool_w = window - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - if self.pad_mode == "valid": - self.pad = 0 - elif self.pad_mode == "same": - self.pad = math.floor((self.window - 1) / 2) - elif self.pad_mode == "pad": - self.pad = validator.check_integer('pad', pad, 0, Rel.GE) - - self.data_mode = validator.check_integer('data_mode', data_mode, 1, Rel.EQ) - self.ceil_mode = validator.check_integer('ceil_mode', ceil_mode, 0, Rel.EQ) - self.stride = validator.check_integer('stride', stride, 1, Rel.GE) - self.alpha = validator.check_type('alpha', alpha, [int, float]) - self.beta = validator.check_type('beta', beta, [int, float]) - self.is_tbe = not context.get_context("enable_ge") and context.get_context("device_target") == "Ascend" - - def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - pad = self.pad - h_input = x_shape[2] - w_input = x_shape[3] - h_out = (h_input + 2 * pad - (self.window - 1) - 1) / self.stride + 1 - h_out = math.floor(h_out) - w_out = (w_input + 2 * pad - (self.window - 1) - 1) / self.stride + 1 - w_out = math.floor(w_out) - out_shape = [x_shape[0], x_shape[1], h_out, w_out] - for shape_value in out_shape: - if shape_value <= 0: - raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") - k_size_vec = [1, self.window, self.window, 1] - argmax_shape = [] - if self.is_tbe: - for i in range(4): - if i == 2: - dim = k_size_vec[i - 1] * k_size_vec[i] - argmax_shape.append(dim) - elif i == 3: - dim = math.ceil(out_shape[i - 1] * out_shape[i] / 16) + 1 - argmax_shape.append(dim) - else: - argmax_shape.append(x_shape[i]) - else: - argmax_shape = out_shape - return out_shape, argmax_shape - - def infer_dtype(self, x_dtype): - out_dtype = x_dtype - validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32)) - argmax_dtype = mstype.int32 - return out_dtype, argmax_dtype - - class _Pool(PrimitiveWithInfer): r""" Performs max/avg pooling operation. Args: - ksize (Union[int, tuple[int]]): The size of the window to take a max over, that should be a tuple - of two `int` for width and height. Default: 1. - stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for - width and height. Default: 1. - padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". + ksize (Union[int, tuple[int]]): The size of the kernel, that should be a tuple + of two `int` for height and width. Default: 1. + strides (Union[int, tuple[int]]): The stride of the window, that should be + a tuple of two `int` for height and width. Default: 1. + padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): - self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type('padding', padding, [str]) - self.ksize = ksize - self.strides = strides - self.padding = padding.upper() - self.ksize = validator.check_type('ksize', self.ksize, [int, tuple]) - self.strides = validator.check_type('strides', self.strides, [int, tuple]) - self.padding = validator.check_string('padding', self.padding, ['VALID', 'SAME']) + def __init__(self, ksize=1, strides=1, padding="valid"): self.init_prim_io_names(inputs=['x'], outputs=['output']) + validator.check_type('ksize', ksize, [int, tuple]) + validator.check_type('strides', strides, [int, tuple]) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) self.add_prim_attr("padding", self.padding) - self.add_prim_attr('data_format', "NCHW") + self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") + if not self.is_maxpoolwithargmax: + self.add_prim_attr('data_format', "NCHW") - if isinstance(self.ksize, int): - self.pool_h = validator.check_integer("ksize", self.ksize, 1, Rel.GE) - self.pool_w = self.pool_h - self.add_prim_attr("ksize", (1, 1, self.ksize, self.ksize)) - elif isinstance(self.ksize, tuple): - if (len(self.ksize) != 2 or (not isinstance(self.ksize[0], int)) or (not isinstance(self.ksize[1], int)) - or self.ksize[0] <= 0 or self.ksize[1] <= 0): - raise ValueError('Each value of attr \'ksize\' of \'MaxPool\' Op passed ' + - str(self.ksize) + ', should be a int or a tuple of length 2 and greater than 0.') - self.pool_h = self.ksize[0] - self.pool_w = self.ksize[1] - self.add_prim_attr("ksize", (1, 1, self.ksize[0], self.ksize[1])) - - if isinstance(self.strides, int): - self.stride_h = validator.check_integer("strides", self.strides, 1, Rel.GE) - self.stride_w = self.stride_h - self.add_prim_attr("strides", (1, 1, self.strides, self.strides)) - elif isinstance(self.strides, tuple): - if (len(self.strides) != 2 or (not isinstance(self.strides[0], int)) or - (not isinstance(self.strides[1], int)) or self.strides[0] <= 0 or self.strides[1] <= 0): - raise ValueError('Each value of attr \'strides\' of \'MaxPool\' Op passed ' + - str(self.strides) + ', should be a int or a tuple of length 2 and greater than 0.') - self.stride_h = self.strides[0] - self.stride_w = self.strides[1] - self.add_prim_attr("strides", (1, 1, self.strides[0], self.strides[1])) - - if self.padding == "VALID": - self.pad = 0 - elif self.padding == "SAME": - self.pad = math.floor((self.pool_h - 1) / 2) + if isinstance(ksize, int): + validator.check_integer("ksize", ksize, 1, Rel.GE) + self.ksize = (1, 1, ksize, ksize) else: - raise ValueError('The padding should be str and must be SAME or VALID,' - ' but got {}.'.format(self.padding)) - self.add_prim_attr('pad', self.pad) + if (len(ksize) != 2 or + (not isinstance(ksize[0], int)) or + (not isinstance(ksize[1], int)) or + ksize[0] <= 0 or + ksize[1] <= 0): + raise ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number or" + f"a tuple of two positive int numbers, but got {ksize}") + self.ksize = (1, 1, ksize[0], ksize[1]) + if self.is_maxpoolwithargmax: + self.ksize = (1, self.ksize[-2], self.ksize[-1], 1) + self.add_prim_attr("ksize", self.ksize) + + if isinstance(strides, int): + validator.check_integer("strides", strides, 1, Rel.GE) + self.strides = (1, 1, strides, strides) + else: + if (len(strides) != 2 or + (not isinstance(strides[0], int)) or + (not isinstance(strides[1], int)) or + strides[0] <= 0 or + strides[1] <= 0): + raise ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number or" + f"a tuple of two positive int numbers, but got {strides}") + self.strides = (1, 1, strides[0], strides[1]) + if self.is_maxpoolwithargmax: + self.strides = (1, self.strides[-2], self.strides[-1], 1) + self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - h_input = x_shape[2] - w_input = x_shape[3] - if self.padding == "VALID": - h_out = math.ceil((h_input - (self.pool_h - 1)) / self.stride_h) - w_out = math.ceil((w_input - (self.pool_w - 1)) / self.stride_w) - elif self.padding == "SAME": - h_out = math.ceil(h_input / self.stride_h) - w_out = math.ceil(w_input / self.stride_w) + batch, channel, input_h, input_w = x_shape + if self.is_maxpoolwithargmax: + _, kernel_h, kernel_w, _ = self.ksize + _, stride_h, stride_w, _ = self.strides else: - raise ValueError('The padding should be str and must be SAME or VALID,' - ' but got {}.'.format(self.padding)) + _, _, kernel_h, kernel_w = self.ksize + _, _, stride_h, stride_w = self.strides + + if self.padding == "VALID": + out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h) + out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w) + elif self.padding == "SAME": + out_h = math.ceil(input_h / stride_h) + out_w = math.ceil(input_w / stride_w) + else: + raise ValueError(f"The padding of operator {self.name} should be a str and must be 'SAME' or 'VALID', " + f"but got {self.padding}.") + out_shape = [batch, channel, out_h, out_w] - out_shape = [x_shape[0], x_shape[1], h_out, w_out] for shape_value in out_shape: if shape_value <= 0: raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") @@ -887,11 +784,22 @@ class MaxPool(_Pool): \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) Args: - ksize (Union[int, tuple[int]]): The size of the window to take a max over, that should be a tuple - of two `int` for width and height. Default: 1. - stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for - width and height. Default: 1. - padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". + ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value, + is an int number that represents height and width are both ksize, or a tuple + of two int numbers that represent height and width respectively. Default: 1. + strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". + + - same: Adopts the way of completion. Output height and width will be the same as + the input. Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -901,10 +809,83 @@ class MaxPool(_Pool): """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): + def __init__(self, ksize=1, strides=1, padding="valid"): super(MaxPool, self).__init__(ksize, strides, padding) +class MaxPoolWithArgmax(_Pool): + r""" + Performs max pooling on the input Tensor and return both max values and indices. + + Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs + regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size + :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. + + .. math:: + \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} + \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) + + Args: + ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value and arg value, + is an int number that represents height and width are both ksize, or a tuple of + two int numbers that represent height and width respectively. Default: 1. + strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". + + - same: Adopts the way of completion. Output height and width will be the same as + the input. Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. + + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tuple of 2 Tensor, the maxpool result and where max values from. + + - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. + - **mask** (Tensor) - Max values' index represented by the mask. + """ + + def __init__(self, ksize=1, strides=1, padding="valid"): + super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding) + self.is_tbe = context.get_context("device_target") == "Ascend" + + def infer_shape(self, x_shape): + out_shape = _Pool.infer_shape(self, x_shape) + _, _, out_h, out_w = out_shape + _, kernel_h, kernel_w, _ = self.ksize + + argmax_shape = [] + if self.is_tbe: + for i in range(4): + if i == 2: + dim = kernel_h * kernel_w + argmax_shape.append(dim) + elif i == 3: + dim = math.ceil(out_h * out_w / 16) + 1 + argmax_shape.append(dim) + else: + argmax_shape.append(x_shape[i]) + else: + argmax_shape = out_shape + + return out_shape, argmax_shape + + def infer_dtype(self, x_dtype): + out_dtype = x_dtype + validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32)) + argmax_dtype = mstype.uint16 + return out_dtype, argmax_dtype + + class AvgPool(_Pool): r""" Average pooling operation. @@ -919,11 +900,22 @@ class AvgPool(_Pool): \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) Args: - ksize (Union[int, tuple[int]]): The size of the window to take a average over, that should be a tuple - of two `int` for width and height. Default: 1. - stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for - width and height. Default: 1. - padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". + ksize (Union[int, tuple[int]]): The size of kernel used to take the average value, + is an int number that represents height and width are both ksize, or a tuple + of two int numbers that represent height and width respectively. Default: 1. + strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". + + - same: Adopts the way of completion. Output height and width will be the same as + the input. Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -933,7 +925,7 @@ class AvgPool(_Pool): """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): + def __init__(self, ksize=1, strides=1, padding="valid"): if context.get_context("device_target") == "GPU": self.target = "GPU" else: diff --git a/tests/perf_test/resnet_example.py b/tests/perf_test/resnet_example.py index 19d235c2b11..34413109de3 100644 --- a/tests/perf_test/resnet_example.py +++ b/tests/perf_test/resnet_example.py @@ -103,7 +103,7 @@ class ResNet50(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='valid') + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid') self.layer1 = self.MakeLayer( block, 3, in_channels=64, out_channels=256, stride=1) diff --git a/tests/st/networks/test_cpu_lenet.py b/tests/st/networks/test_cpu_lenet.py index a3105721d34..9fd50f5d9b3 100644 --- a/tests/st/networks/test_cpu_lenet.py +++ b/tests/st/networks/test_cpu_lenet.py @@ -21,6 +21,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore import Tensor + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -50,8 +51,10 @@ class LeNet(nn.Cell): output = self.fc3(output) return output + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + def train(net, data, label): learning_rate = 0.01 momentum = 0.9 @@ -67,11 +70,12 @@ def train(net, data, label): print("+++++++++++++++++++++++++++") assert res + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_lenet(): - data = Tensor(np.ones([32, 1 ,32, 32]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = LeNet() train(net, data, label) diff --git a/tests/st/networks/test_gpu_alexnet.py b/tests/st/networks/test_gpu_alexnet.py index 3b193e17d61..9f92fc630e6 100644 --- a/tests/st/networks/test_gpu_alexnet.py +++ b/tests/st/networks/test_gpu_alexnet.py @@ -38,7 +38,7 @@ class AlexNet(nn.Cell): self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode="same") self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode="same") self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2,pad_mode="valid",padding=0) + self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid") self.flatten = nn.Flatten() self.fc1 = nn.Dense(6*6*256, 4096) self.fc2 = nn.Dense(4096, 4096) diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax.py b/tests/st/ops/davinci/test_maxpool_with_argmax.py index c9312d666ca..a6c875a9e8e 100644 --- a/tests/st/ops/davinci/test_maxpool_with_argmax.py +++ b/tests/st/ops/davinci/test_maxpool_with_argmax.py @@ -20,26 +20,29 @@ import numpy as np import mindspore.context as context from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter + context.set_context(device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", - window=3, - stride=2) + self.maxpool = P.MaxPoolWithArgmax(padding="same", + ksize=3, + strides=2) self.x = Parameter(initializer( - 'normal', [1, 64, 112, 112]), name='w') + 'normal', [1, 64, 112, 112]), name='w') self.add = P.TensorAdd() - @ms_function def construct(self): output = self.maxpool(self.x) return output[0] + def test_net(): - x = np.random.randn(1,64,112,112).astype(np.float32) + x = np.random.randn(1, 64, 112, 112).astype(np.float32) maxpool = Net() output = maxpool() print("***********output output*********") diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py b/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py index d97e2a06f84..3bbc835c1b3 100644 --- a/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py +++ b/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py @@ -37,9 +37,9 @@ class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", - window=3, - stride=2) + self.maxpool = P.MaxPoolWithArgmax(padding="same", + ksize=3, + strides=2) @ms_function def construct(self, x): diff --git a/tests/st/tbe_networks/resnet.py b/tests/st/tbe_networks/resnet.py index a1ece6556ed..2024286b8fb 100644 --- a/tests/st/tbe_networks/resnet.py +++ b/tests/st/tbe_networks/resnet.py @@ -267,7 +267,7 @@ class ResNet(nn.Cell): self.bn1 = bn_with_initialize(64) self.relu = P.ReLU() - self.maxpool = P.MaxPoolWithArgmax(window=3, stride=2, pad_mode="same") + self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME") self.layer1 = MakeLayer0(block, layer_num[0], in_channels=64, out_channels=256, stride=1) self.layer2 = MakeLayer1(block, layer_num[1], in_channels=256, out_channels=512, stride=2) diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py index 0afffc99df6..2877bc7c7af 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py @@ -21,7 +21,7 @@ addn = P.AddN() add = P.TensorAdd() sub = P.Sub() mul = P.Mul() -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) make_tuple = Primitive('make_tuple') four2five = Primitive('Four2Five') five2four = Primitive('Five2Four') diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_trans_op_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_trans_op_test.py index 57bd2000c4e..a24501e8b14 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_trans_op_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_trans_op_test.py @@ -17,7 +17,7 @@ from mindspore.ops import Primitive tuple_getitem = Primitive('tuple_getitem') add = P.TensorAdd() -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) make_tuple = Primitive('make_tuple') transdata = Primitive("TransData") diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py index 8cbad52db1d..7d3985376b4 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py @@ -21,7 +21,7 @@ addn = P.AddN() add = P.TensorAdd() sub = P.Sub() mul = P.Mul() -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) make_tuple = Primitive('make_tuple') cast = Primitive('Cast') diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/transdata_split_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/transdata_split_test.py index 8cd18d1ac30..e353cf8fbe6 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/transdata_split_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/transdata_split_test.py @@ -17,7 +17,7 @@ from mindspore.ops import Primitive tuple_getitem = Primitive('tuple_getitem') add = P.TensorAdd() -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) make_tuple = Primitive('make_tuple') four2five = Primitive('Four2Five') five2four = Primitive('Five2Four') diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_transdata_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_transdata_fusion_test.py index ea3def743db..c4fc50e0da6 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_transdata_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_transdata_fusion_test.py @@ -17,7 +17,7 @@ from mindspore.ops import Primitive tuple_getitem = Primitive('tuple_getitem') add = P.TensorAdd() -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) make_tuple = Primitive('make_tuple') transdata = Primitive("TransData") Transpose = P.Transpose() diff --git a/tests/ut/cpp/python_input/gtest_input/session/session_test.py b/tests/ut/cpp/python_input/gtest_input/session/session_test.py index ed074fc8d61..ee034a1ae01 100644 --- a/tests/ut/cpp/python_input/gtest_input/session/session_test.py +++ b/tests/ut/cpp/python_input/gtest_input/session/session_test.py @@ -22,7 +22,7 @@ add = P.TensorAdd() reshape = P.Reshape() cast = P.Cast() tuple_getitem = Primitive('tuple_getitem') -max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) def test_addn_cast(x, y, z): sum = addn((x, y)) diff --git a/tests/ut/python/communication/test_data_parallel_resnet.py b/tests/ut/python/communication/test_data_parallel_resnet.py index 037152a0b7a..220e553b4f8 100644 --- a/tests/ut/python/communication/test_data_parallel_resnet.py +++ b/tests/ut/python/communication/test_data_parallel_resnet.py @@ -107,7 +107,7 @@ class ResNet18(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='pad') + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') self.layer1 = self.MakeLayer( block, 2, in_channels=64, out_channels=256, stride=1) @@ -176,7 +176,7 @@ class ResNet9(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='same') + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') self.layer1 = self.MakeLayer( block, 1, in_channels=64, out_channels=256, stride=1) diff --git a/tests/ut/python/exec/resnet_example.py b/tests/ut/python/exec/resnet_example.py index bfbb64f7321..913e90a0bb5 100644 --- a/tests/ut/python/exec/resnet_example.py +++ b/tests/ut/python/exec/resnet_example.py @@ -189,7 +189,7 @@ class ResNet50(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, weight_init=weight_conv) self.bn1 = bn_with_initialize(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) self.layer1 = MakeLayer3( block, in_channels=64, out_channels=256, stride=1) diff --git a/tests/ut/python/exec/test_pooling.py b/tests/ut/python/exec/test_pooling.py index 9c378c15c2c..0e526ff8d68 100644 --- a/tests/ut/python/exec/test_pooling.py +++ b/tests/ut/python/exec/test_pooling.py @@ -23,12 +23,10 @@ class MaxNet(nn.Cell): """MaxNet definition""" def __init__(self, kernel_size, - stride=None, - padding=0): + stride=None): super(MaxNet, self).__init__() self.maxpool = nn.MaxPool2d(kernel_size, - stride, - padding=padding) + stride) def construct(self, input_x): return self.maxpool(input_x) diff --git a/tests/ut/python/model/res18_example.py b/tests/ut/python/model/res18_example.py index eaf8bbc387d..88753334655 100644 --- a/tests/ut/python/model/res18_example.py +++ b/tests/ut/python/model/res18_example.py @@ -106,7 +106,7 @@ class ResNet18(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='pad') + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') self.layer1 = self.MakeLayer( block, 2, in_channels=64, out_channels=256, stride=1) @@ -175,7 +175,7 @@ class ResNet9(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) self.layer1 = self.MakeLayer( block, 1, in_channels=64, out_channels=256, stride=1) diff --git a/tests/ut/python/nn/test_cell.py b/tests/ut/python/nn/test_cell.py index 882756f3d29..c583b27c1d9 100644 --- a/tests/ut/python/nn/test_cell.py +++ b/tests/ut/python/nn/test_cell.py @@ -87,7 +87,7 @@ class ConvNet(nn.Cell): self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3) self.bn1 = nn.BatchNorm2d(ConvNet.output_ch) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="pad", padding=1) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.flatten = nn.Flatten() self.fc = nn.Dense( int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)), diff --git a/tests/ut/python/nn/test_pooling.py b/tests/ut/python/nn/test_pooling.py index 694d202d13a..10bb7632b27 100644 --- a/tests/ut/python/nn/test_pooling.py +++ b/tests/ut/python/nn/test_pooling.py @@ -46,8 +46,7 @@ class MaxNet(nn.Cell): padding=0): super(MaxNet, self).__init__() self.maxpool = nn.MaxPool2d(kernel_size, - stride, - padding=padding) + stride) def construct(self, x): return self.maxpool(x) diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 5b9f37864cb..cadac6dfb43 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -108,6 +108,7 @@ class ResidualBlock(nn.Cell): class VirtualLossGrad(PrimitiveWithInfer): """ VirtualLossGrad definition """ + @prim_attr_register def __init__(self): """init VirtualLossGrad""" @@ -124,6 +125,7 @@ class VirtualLossGrad(PrimitiveWithInfer): class VirtualLoss(PrimitiveWithInfer): """ VirtualLoss definition """ + @prim_attr_register def __init__(self): """init VirtualLoss""" @@ -138,6 +140,7 @@ class VirtualLoss(PrimitiveWithInfer): # pylint: disable=unused-argument dx = loss_grad(x, out, dout) return (dx,) + return bprop def infer_shape(self, x_shape): @@ -149,6 +152,7 @@ class VirtualLoss(PrimitiveWithInfer): class VirtualNetWithLoss(nn.Cell): """ VirtualNetWithLoss definition """ + def __init__(self, network): super(VirtualNetWithLoss, self).__init__() self.loss = VirtualLoss() @@ -161,6 +165,7 @@ class VirtualNetWithLoss(nn.Cell): class SoftMaxGrad(nn.Cell): """ SoftMaxGrad definition """ + def __init__(self, network): super(SoftMaxGrad, self).__init__() self.network = network @@ -171,6 +176,7 @@ class SoftMaxGrad(nn.Cell): class DropoutGrad(nn.Cell): """ DropoutGrad definition """ + def __init__(self, network): super(DropoutGrad, self).__init__() self.network = network @@ -181,6 +187,7 @@ class DropoutGrad(nn.Cell): class ScalarSummaryNet(nn.Cell): """ ScalarSummaryNet definition """ + def __init__(self): super(ScalarSummaryNet, self).__init__() self.summary = P.ScalarSummary() @@ -193,6 +200,7 @@ class ScalarSummaryNet(nn.Cell): class FusedBatchNormGrad(nn.Cell): """ FusedBatchNormGrad definition """ + def __init__(self, network): super(FusedBatchNormGrad, self).__init__() self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) @@ -204,6 +212,7 @@ class FusedBatchNormGrad(nn.Cell): class NetWithLoss(nn.Cell): """ NetWithLoss definition """ + def __init__(self, network): super(NetWithLoss, self).__init__() self.loss = P.SmoothL1Loss() @@ -216,6 +225,7 @@ class NetWithLoss(nn.Cell): class Grad(nn.Cell): """ GradWrap definition """ + def __init__(self, network): super(Grad, self).__init__() self.network = network @@ -227,6 +237,7 @@ class Grad(nn.Cell): class BatchnormNet(nn.Cell): """ BatchnormNet definition """ + def __init__(self): super(BatchnormNet, self).__init__() self.conv1 = nn.Conv2d(3, 4, kernel_size=8, stride=2, pad_mode="pad", padding=3) @@ -247,6 +258,7 @@ class BatchnormNet(nn.Cell): class NetWithLossClass(nn.Cell): """ NetWithLossClass definition """ + def __init__(self, network): super(NetWithLossClass, self).__init__(auto_prefix=False) self.loss = nn.SoftmaxCrossEntropyWithLogits() @@ -259,12 +271,13 @@ class NetWithLossClass(nn.Cell): class BlockNet(nn.Cell): """ BlockNet definition """ + def __init__(self): super(BlockNet, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) self.block_down_sample = ResidualBlock( 64, 256, stride=1, down_sample=True ) @@ -281,6 +294,7 @@ class BlockNet(nn.Cell): class Conv2dWithBiasNet(nn.Cell): """ Conv2dWithBiasNet definition """ + def __init__(self): super(Conv2dWithBiasNet, self).__init__() self.conv = nn.Conv2d(3, 10, 1, bias_init='zeros') @@ -292,6 +306,7 @@ class Conv2dWithBiasNet(nn.Cell): class Conv2dNativeNet(nn.Cell): """ Conv2dNativeNet definition """ + def __init__(self): super(Conv2dNativeNet, self).__init__() self.conv = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3)) @@ -309,9 +324,10 @@ class Conv2dNativeNet(nn.Cell): class MakeRefKeyNet(nn.Cell): """ MakeRefKeyNet definition """ + def __init__(self): super(MakeRefKeyNet, self).__init__() - self.y= Parameter(Tensor([1.0], mindspore.float32), name="y") + self.y = Parameter(Tensor([1.0], mindspore.float32), name="y") def construct(self, x): key = P.MakeRefKey("y")() @@ -321,6 +337,7 @@ class MakeRefKeyNet(nn.Cell): class StateNet(nn.Cell): """ StateTestTensor definition """ + def __init__(self): super(StateNet, self).__init__() weight = Tensor(np.ones([2, 1, 2, 2], np.float32)) @@ -347,6 +364,24 @@ class ComparisonNet(nn.Cell): return ret +def test_max_pool_with_arg_max(): + class NetMaxPoolWithArgMax(nn.Cell): + def __init__(self): + """ ComparisonNet definition """ + super(NetMaxPoolWithArgMax, self).__init__() + self.max_pool_with_arg_max = P.MaxPoolWithArgmax(padding="valid", ksize=2, strides=1) + + def construct(self, x): + ret = self.max_pool_with_arg_max(x) + return ret + + x = Tensor(np.ones([1, 1, 3, 3], np.float32)) + net = NetMaxPoolWithArgMax() + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + ret = net(x) + print(ret) + + test_cases = [ ('SoftMaxGrad', { 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), @@ -382,7 +417,7 @@ test_cases = [ 'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 64, 4, 4], np.float32))], }), ('Conv2dWithBiasGrad', { - 'block': Grad(NetWithLossClass(Conv2dWithBiasNet())), + 'block': Grad(NetWithLossClass(Conv2dWithBiasNet())), 'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 2560], np.float32))], }), ('Conv2dNativeGrad', { @@ -407,114 +442,93 @@ test_cases = [ }), ] - test_cases_for_verify_exception = [ ('Conv2d_ValueError_1', { - 'block': (lambda _ : P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), + 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), 'desc_inputs': [0], }), ('Conv2d_ValueError_2', { - 'block': (lambda _ : P.Conv2D(3, 4, mode=-2), {'exception': ValueError}), + 'block': (lambda _: P.Conv2D(3, 4, mode=-2), {'exception': ValueError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_1', { - 'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode='sane'), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {'exception': ValueError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_2', { - 'block': (lambda _ : P.MaxPoolWithArgmax(data_mode=2), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': ValueError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_3', { - 'block': (lambda _ : P.MaxPoolWithArgmax(ceil_mode=2), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {'exception': ValueError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_4', { - 'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode="pad", pad=-1), {'exception': ValueError}), - 'desc_inputs': [0], - }), - ('MaxPoolWithArgmax_ValueError_5', { - 'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode="pad", pad='1'), {'exception': ValueError}), - 'desc_inputs': [0], - }), - ('MaxPoolWithArgmax_ValueError_6', { - 'block': (lambda _ : P.MaxPoolWithArgmax(window='1'), {'exception': ValueError}), - 'desc_inputs': [0], - }), - ('MaxPoolWithArgmax_ValueError_7', { - 'block': (lambda _ : P.MaxPoolWithArgmax(window=-2), {'exception': ValueError}), - 'desc_inputs': [0], - }), - ('MaxPoolWithArgmax_ValueError_8', { - 'block': (lambda _ : P.MaxPoolWithArgmax(stride=-1), {'exception': ValueError}), - 'desc_inputs': [0], - }), - ('MaxPoolWithArgmax_ValueError_9', { - 'block': (lambda _ : P.MaxPoolWithArgmax(alpha='1'), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_1', { - 'block': (lambda _ : P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_2', { - 'block': (lambda _ : P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_3', { - 'block': (lambda _ : P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_4', { - 'block': (lambda _ : P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_5', { - 'block': (lambda _ : P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}), 'desc_inputs': [0], }), ('Softmax_ValueError_1', { - 'block': (lambda _ : P.Softmax("1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax("1"), {'exception': ValueError}), 'desc_inputs': [0], }), ('Softmax_ValueError_2', { - 'block': (lambda _ : P.Softmax(1.1), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(1.1), {'exception': ValueError}), 'desc_inputs': [0], }), ('Softmax_ValueError_3', { - 'block': (lambda _ : P.Softmax(axis="1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(axis="1"), {'exception': ValueError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_1', { - 'block': (lambda _ : P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_2', { - 'block': (lambda _ : P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_3', { - 'block': (lambda _ : P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_4', { - 'block': (lambda _ : P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}), 'desc_inputs': [0], }), ('MaxPool2d_ValueError_1', { - 'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid", padding=0), {'exception': ValueError}), + 'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid"), {'exception': ValueError}), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), ('MaxPool2d_ValueError_2', { 'block': ( - lambda _ : nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid", padding=0), + lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), {'exception': ValueError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), ('MaxPool2d_ValueError_3', { 'block': ( - lambda _ : nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid", padding=0), + lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), {'exception': ValueError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], @@ -532,4 +546,3 @@ def test_compile(): @mindspore_test(pipeline_for_verify_exception_for_case_by_case_config) def test_check_exception(): return test_cases_for_verify_exception - diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 0f5b716e390..092d6e32be8 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -571,7 +571,7 @@ test_case_nn_ops = [ 'desc_bprop': [[3, 4, 6, 6]], 'skip': ['backward']}), ('MaxPoolWithArgmax', { - 'block': P.MaxPoolWithArgmax(window=2, stride=2), + 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), 'desc_inputs': [[128, 32, 32, 64]], 'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}), ('SoftmaxCrossEntropyWithLogits', { diff --git a/tests/ut/python/ops/test_ops_check.py b/tests/ut/python/ops/test_ops_check.py index a7e1b41c4a6..aa379cc64e5 100644 --- a/tests/ut/python/ops/test_ops_check.py +++ b/tests/ut/python/ops/test_ops_check.py @@ -160,16 +160,16 @@ test_case_check_ops = [ 'block': nn.Dense(1, 6, has_bias=False, bias_init=Tensor(np.ones([6]).astype(np.float32))), 'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}), ('MaxPool2d_1', { - 'block': nn.MaxPool2d(5, pad_mode='same', padding=0), + 'block': nn.MaxPool2d(5, pad_mode='same'), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), ('MaxPool2d_2', { - 'block': nn.MaxPool2d(5, pad_mode='valid', padding=0), + 'block': nn.MaxPool2d(5, pad_mode='valid'), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), ('AvgPool2d_1', { - 'block': nn.AvgPool2d(5, pad_mode='same', padding=0), + 'block': nn.AvgPool2d(5, pad_mode='same'), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), ('AvgPool2d_2', { - 'block': nn.AvgPool2d(5, pad_mode='valid', padding=0), + 'block': nn.AvgPool2d(5, pad_mode='valid'), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), ('Conv2D_1', { 'block': P.Conv2D(1, 6, pad_mode='same', pad=0), diff --git a/tests/ut/python/pynative_mode/ge/ops/test_pooling.py b/tests/ut/python/pynative_mode/ge/ops/test_pooling.py index e6cf88a9ca9..d5b90b6edd0 100644 --- a/tests/ut/python/pynative_mode/ge/ops/test_pooling.py +++ b/tests/ut/python/pynative_mode/ge/ops/test_pooling.py @@ -42,12 +42,10 @@ def test_maxpool2d(): """ test_maxpool2d """ kernel_size = 3 stride = 3 - padding = 0 - max_pool = nn.MaxPool2d(kernel_size, stride, padding=padding) + max_pool = nn.MaxPool2d(kernel_size, stride) assert max_pool.kernel_size == 3 assert max_pool.stride == 3 - assert max_pool.padding == 0 input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) output = max_pool(input_data) output_np = output.asnumpy() diff --git a/tests/ut/python/pynative_mode/nn/test_cell.py b/tests/ut/python/pynative_mode/nn/test_cell.py index 16adcd61192..2d5196b80d8 100644 --- a/tests/ut/python/pynative_mode/nn/test_cell.py +++ b/tests/ut/python/pynative_mode/nn/test_cell.py @@ -89,7 +89,7 @@ class ConvNet(nn.Cell): self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode='pad', padding=3) self.bn1 = nn.BatchNorm2d(ConvNet.output_ch) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.flatten = nn.Flatten() self.fc = nn.Dense( int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)), diff --git a/tests/ut/python/pynative_mode/nn/test_pooling.py b/tests/ut/python/pynative_mode/nn/test_pooling.py index ab95fec091d..bb1822f8a8a 100644 --- a/tests/ut/python/pynative_mode/nn/test_pooling.py +++ b/tests/ut/python/pynative_mode/nn/test_pooling.py @@ -49,23 +49,14 @@ def test_maxpool2d(): """ test_maxpool2d """ kernel_size = 3 stride = 3 - padding = 2 - max_pool = nn.MaxPool2d(kernel_size, stride, pad_mode='SAME', padding=padding) + max_pool = nn.MaxPool2d(kernel_size, stride, pad_mode='SAME') assert max_pool.kernel_size == 3 assert max_pool.stride == 3 - assert max_pool.padding == 2 input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6])*0.1) output = max_pool(input_data) output_np = output.asnumpy() assert isinstance(output_np[0][0][0][0], (np.float32, np.float64)) -def test_maxpool2d_error_padding(): - """ test_maxpool2d_error_padding """ - kernel_size = 3.5 - stride = 3 - padding = 1 - with pytest.raises(ValueError): - nn.MaxPool2d(kernel_size, stride, padding=padding) diff --git a/tests/ut/python/pynative_mode/vm/test_vm.py b/tests/ut/python/pynative_mode/vm/test_vm.py index 4ea0abd7538..77510337b0e 100644 --- a/tests/ut/python/pynative_mode/vm/test_vm.py +++ b/tests/ut/python/pynative_mode/vm/test_vm.py @@ -23,7 +23,7 @@ def test_avg_pooling(): [-9., -1., 3., 4.], [1., -1., -3., -6.], [-2., -1., -2., -15.]]]]).astype(np.float32) - out = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) + out = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1) expect_out = [[[[-4.25, 0.0, 4.25], [-2.5, -0.5, -0.5], [-0.75, -1.75, -6.5]]]] @@ -37,9 +37,9 @@ def test_avg_pool_grad(): [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]]).astype(np.float32) - dout = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) + dout = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1) print("vm.avg_pooling dout: ", dout) - out = vm.avg_pool_grad(dout, input_data.shape, 2, 2, 1, 0) + out = vm.avg_pool_grad(dout, input_data.shape, 2, 2, 1) print("vm.avg_pool_grad: ", out) assert True @@ -202,7 +202,7 @@ def test_max_pooling(): [-9., -1., 3., 4.], [1., -1., -3., -6.], [-2., -1., -2., -15.]]]]).astype(np.float32) - out = vm.max_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) + out = vm.max_pooling(input_data, pool_h=2, pool_w=2, stride=1) expect_out = [[[[-1., 3., 9.], [1., 3., 4.], [1., -1., -2.]]]] diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 41da45ab255..cc6f346b772 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -44,7 +44,7 @@ class Net(nn.Cell): self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros") self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) self.flatten = nn.Flatten() self.fc = nn.Dense(int(224*224*64/16), num_classes) diff --git a/tests/vm_impl/nn_ops_vm_impl.py b/tests/vm_impl/nn_ops_vm_impl.py index f6bbdca55a2..fc1fa95024d 100644 --- a/tests/vm_impl/nn_ops_vm_impl.py +++ b/tests/vm_impl/nn_ops_vm_impl.py @@ -19,66 +19,82 @@ from mindspore.ops.operations import _grad_ops as G from mindspore.common.tensor import Tensor from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from .vm_interface import vm + + # pylint: disable=unused-argument @vm_impl_getters.register(P.ScalarSummary) def vm_impl_scalar_summary(self): """Generate vm_impl function for ScalarSummary""" + def vm_impl(string_in, scalar): """Implement by vm mode.""" return scalar + return vm_impl @vm_impl_getters.register(P.ReLU) def vm_impl_relu(self): """Generate vm_impl function for ReLU""" + def vm_impl(x): x = x.asnumpy() output = Tensor(vm.relu(x)) return output + return vm_impl + @vm_impl_getters.register(P.Flatten) def vm_impl_flatten(self): """Generate vm_impl function for Flatten""" + def vm_impl(x): x = x.asnumpy() return Tensor(vm.flatten_batch(x)) + return vm_impl @vm_impl_getters.register(P.Softmax) def vm_impl_softmax(self): """Generate vm_impl function for Softmax""" + def vm_impl(x): x = x.asnumpy() return Tensor(vm.softmax(x)) + return vm_impl @vm_impl_getters.register(P.LogSoftmax) def vm_impl_log_softmax(self): """Generate vm_impl function for LogSoftmax""" + def vm_impl(x): x = x.asnumpy() return Tensor(vm.logsoftmax(x)) + return vm_impl @vm_impl_getters.register(P.Tanh) def vm_impl_tanh(self): """Generate vm_impl function for Tanh""" + def vm_impl(x): x = x.asnumpy() return Tensor(vm.tanh(x)) + return vm_impl @vm_impl_getters.register(P.FusedBatchNorm) def vm_impl_fused_batch_norm(self): """Generate vm_impl function for FusedBatchNorm""" + def vm_impl(x, scale, b, mean, variance): # pylint: disable=unused-argument x = x.asnumpy() @@ -92,12 +108,14 @@ def vm_impl_fused_batch_norm(self): momentum=self.momentum) return Tensor(out), Tensor(x_mean), Tensor(x_var), \ Tensor(running_mean), Tensor(running_var) + return vm_impl @vm_impl_getters.register(P.BatchNorm) def vm_impl_batch_norm(self): """Generate vm_impl function for BatchNorm""" + def vm_impl(x, scale, b, mean, variance): # pylint: disable=unused-argument x = x.asnumpy() @@ -110,83 +128,106 @@ def vm_impl_batch_norm(self): eps=self.epsilon) return Tensor(out), Tensor(x_mean), Tensor(x_var), \ Tensor(running_mean), Tensor(running_var) + return vm_impl @vm_impl_getters.register(P.Conv2D) def vm_impl_conv2d(self): """Generate vm_impl function for Conv2D""" + def vm_impl(x, w): x = x.asnumpy() weight = w.asnumpy() bias = None out = vm.conv2d(x, weight, bias, self.stride, self.pad, self.dilation) return Tensor(out) + return vm_impl @vm_impl_getters.register(G.MaxPoolGradWithArgmax) def vm_impl_max_pool_grad_with_argmax(self): """Generate vm_impl function for MaxPoolGradWithArgmax""" - def vm_impl(x, argmax, dout): + + def vm_impl(x, dout, argmax): + print("buxue") + print(argmax) x = x.asnumpy() dout = dout.asnumpy() arg_max = argmax.asnumpy() - dx = vm.max_pool_grad_with_argmax(x, arg_max, dout, self.pool_h, self.pool_w, self.stride, self.pad) + dx = vm.max_pool_grad_with_argmax(x, dout, arg_max, + self.ksize[1], self.ksize[2], self.strides[1]) return Tensor(dx) + return vm_impl @vm_impl_getters.register(P.MaxPoolWithArgmax) def vm_impl_max_pool_with_argmax(self): """Generate vm_impl function for MaxPoolWithArgmax""" + def vm_impl(x): x = x.asnumpy() - out, out_argmax = vm.max_pool_with_argmax(x, self.pool_h, self.pool_w, self.stride, self.pad) + out, out_argmax = vm.max_pool_with_argmax(x, self.ksize[1], self.ksize[2], self.strides[1]) return Tensor(out), Tensor(out_argmax) + return vm_impl + @vm_impl_getters.register(P.MaxPool) def vm_impl_max_pool(self): """Generate vm_impl function for MaxPool""" + def vm_impl(x): x = x.asnumpy() - out = vm.max_pooling(x, self.pool_h, self.pool_w, self.stride_h, self.pad) + out = vm.max_pooling(x, self.ksize[-2], self.ksize[-1], self.strides[-2]) return Tensor(out) + return vm_impl + @vm_impl_getters.register(G.MaxPoolGrad) def vm_impl_max_pool_grad(self): """Generate vm_impl function for MaxPoolGrad""" + def vm_impl(x, out, dout): x = x.asnumpy() dout = dout.asnumpy() - out = vm.max_pool_grad(x, dout, self.pool_h, self.pool_w, self.stride_h, self.pad) + out = vm.max_pool_grad(x, dout, self.ksize[-2], self.ksize[-1], self.strides[-2]) return Tensor(out) + return vm_impl + @vm_impl_getters.register(P.AvgPool) -def vm_impl_max_pool(self): +def vm_impl_avg_pool(self): """Generate vm_impl function for AvgPool""" + def vm_impl(x): x = x.asnumpy() - out = vm.avg_pooling(x, self.pool_h, self.pool_w, self.stride_h, self.pad) + out = vm.avg_pooling(x, self.ksize[-2], self.ksize[-1], self.strides[-2]) return Tensor(out) + return vm_impl + @vm_impl_getters.register(G.AvgPoolGrad) def vm_impl_avg_pool_grad(self): """Generate vm_impl function for AvgPoolGrad""" + def vm_impl(dout, origin_shape): dout = dout.asnumpy() - out = vm.avg_pool_grad(dout, origin_shape, self.pool_h, self.pool_w, self.stride_h, self.pad) + out = vm.avg_pool_grad(dout, origin_shape, self.ksize[-2], self.ksize[-1], self.strides[-2]) return Tensor(out) + return vm_impl @vm_impl_getters.register(G.FusedBatchNormGrad) def vm_impl_fused_batch_norm_grad(self): """Generate vm_impl function for FusedBatchNormGrad""" + def vm_impl(dy, x, scale, save_mean, save_inv_variance): dy = dy.asnumpy() x = x.asnumpy() @@ -195,11 +236,14 @@ def vm_impl_fused_batch_norm_grad(self): save_inv_variance = save_inv_variance.asnumpy() dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance) return (Tensor(dx), Tensor(dscale), Tensor(dshift)) + return vm_impl + @vm_impl_getters.register(G.BatchNormGrad) def vm_impl_fused_batch_norm_grad(self): """Generate vm_impl function for BatchNormGrad""" + def vm_impl(dy, x, scale, save_mean, save_inv_variance): dy = dy.asnumpy() x = x.asnumpy() @@ -208,104 +252,123 @@ def vm_impl_fused_batch_norm_grad(self): save_inv_variance = save_inv_variance.asnumpy() dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance) return (Tensor(dx), Tensor(dscale), Tensor(dshift)) + return vm_impl @vm_impl_getters.register(G.ReluGrad) def vm_impl_relu_grad(self): """Generate vm_impl function for ReluGrad""" + def vm_impl(y_backprop, x): x = x.asnumpy() y_backprop = y_backprop.asnumpy() - y_backprop = vm.relu_grad(x.copy())*y_backprop + y_backprop = vm.relu_grad(x.copy()) * y_backprop return Tensor(y_backprop) + return vm_impl @vm_impl_getters.register(P.Conv2DBackpropInput) def vm_impl_conv2d_backprop_input(self): """Generate vm_impl function for Conv2DBackpropInput""" + def vm_impl(dout, w, x_size): dout = dout.asnumpy() w = w.asnumpy() dx = vm.conv2d_backprop_input(dout, x_size, w, self.stride, self.pad) return Tensor(dx) + return vm_impl @vm_impl_getters.register(G.Conv2DBackpropFilter) def vm_impl_conv2d_backprop_filter(self): """Generate vm_impl function for Conv2DBackpropFilter""" + def vm_impl(dout, x, w_size): x = x.asnumpy() dout = dout.asnumpy() dw = vm.conv2d_backprop_filter(dout, x, w_size, self.stride, self.pad) return Tensor(dw) + return vm_impl @vm_impl_getters.register(G.FlattenGrad) def vm_impl_flatten_grad(self): """Generate vm_impl function for FlattenGrad""" + def vm_impl(dout, x): dout = dout.asnumpy() dout = vm.flatten_grad(dout, x) return Tensor(dout) + return vm_impl @vm_impl_getters.register(P.BiasAdd) def vm_impl_bias_add(self): """Generate vm_impl function for BiasAdd""" + def vm_impl(wx, bias): wx = wx.asnumpy() bias = bias.asnumpy() out = wx + bias return Tensor(out) + return vm_impl @vm_impl_getters.register(G.BiasAddGrad) def vm_impl_bias_add_grad(self): """Generate vm_impl function for BiasAddGrad""" + def vm_impl(dout): dout = dout.asnumpy() shape = np.shape(dout) return Tensor(np.add.reduce(dout, axis=tuple(range(len(shape) - 1)))) + return vm_impl @vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits) def vm_impl_softmax_cross_entropy_with_logits(self): """Generate vm_impl function for SoftmaxCrossEntropyWithLogits""" + def vm_impl(logits, labels): logits = logits.asnumpy() labels = labels.asnumpy() loss, dx = vm.softmax_cross_entropy_with_logits(logits, labels) return (Tensor(np.array(loss)), Tensor(dx)) + return vm_impl @vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits) def vm_impl_sparse_softmax_cross_entropy_with_logits(self): """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits""" + def vm_impl(logits, labels): logits = logits.asnumpy() labels = labels.asnumpy() n_class = labels.max() + 1 n_sample = labels.shape[0] - one_hot_label = np.zeros((n_sample, n_class))#3个样本,4个类别 - one_hot_label[:, labels] = 1#非零列赋值为1 + one_hot_label = np.zeros((n_sample, n_class)) # 3个样本,4个类别 + one_hot_label[:, labels] = 1 # 非零列赋值为1 loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label) if self.is_grad: return (Tensor(dx),) return (Tensor(np.array(loss)),) + return vm_impl + @vm_impl_getters.register(P.ApplyMomentum) def vm_impl_momentum(self): """Generate vm_impl function for Momentum""" + def vm_impl(variable, accumulation, learning_rate, @@ -327,19 +390,24 @@ def vm_impl_momentum(self): return vm_impl + @vm_impl_getters.register(P.ResizeBilinear) def vm_impl_resize_bilinear(self): """Generate vm_impl function for ResizeBilinear""" + def vm_impl(x): out = vm.ResizeBilinear(x) return Tensor(out) + return vm_impl @vm_impl_getters.register(G.ResizeBilinearGrad) def vm_impl_resize_bilinear_grad(self): """Generate vm_impl function for ResizeBilinearGrad""" + def vm_impl(dout, original_image): out = vm.ResizeBilinearGrad(dout, original_image) return Tensor(out) + return vm_impl diff --git a/tests/vm_impl/vm_me.py b/tests/vm_impl/vm_me.py index 03a0e1a885c..ba51a3b13bc 100644 --- a/tests/vm_impl/vm_me.py +++ b/tests/vm_impl/vm_me.py @@ -19,7 +19,7 @@ from mindspore._checkparam import Rel from mindspore._checkparam import ParamValidator as validator -def avg_pooling(x, pool_h, pool_w, stride, pad): +def avg_pooling(x, pool_h, pool_w, stride): """ Applies average pooling over an input array. @@ -28,26 +28,25 @@ def avg_pooling(x, pool_h, pool_w, stride, pad): pool_h (int): Height of the pooling window. pool_w (int): Width of the pooling window. stride (int): The stride of the sliding window. - pad (int): Padding to be added on height and width. Returns: numpy.ndarray, an output array after applying average pooling on input array. """ validator.check_integer("stride", stride, 0, Rel.GT) num, channel, height, width = x.shape - out_h = (height + 2*pad - pool_h)//stride + 1 - out_w = (width + 2*pad - pool_w)//stride + 1 + out_h = (height - pool_h)//stride + 1 + out_w = (width - pool_w)//stride + 1 - col = im2col(x, pool_h, pool_w, stride, pad) + col = im2col(x, pool_h, pool_w, stride) col = col.reshape(-1, pool_h*pool_w) out = np.mean(col, axis=1) - out = out.reshape(num, out_h, out_w, channel).transpose(0, 3, 1, 2) + out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2) return out -def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride, pad): +def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride): """ Gets grad of average pooling. @@ -57,7 +56,6 @@ def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride, pad): pool_h (int): Height of the pooling window. pool_w (int): Width of the pooling window. stride (int): The stride of the sliding window. - pad (int): Padding to be added on height and width. Returns: numpy.ndarray, grad of avgerage pooling. @@ -324,38 +322,38 @@ def matmul(x, w, b=None): return y -def max_pooling(x, pool_h, pool_w, stride, pad): +def max_pooling(x, pool_h, pool_w, stride): """Max pooling.""" validator.check_integer("stride", stride, 0, Rel.GT) num, channel, height, width = x.shape - out_h = (height + 2*pad - pool_h)//stride + 1 - out_w = (width + 2*pad - pool_w)//stride + 1 + out_h = (height - pool_h)//stride + 1 + out_w = (width - pool_w)//stride + 1 - col = im2col(x, pool_h, pool_w, stride, pad) + col = im2col(x, pool_h, pool_w, stride) col = col.reshape(-1, pool_h*pool_w) out = np.max(col, axis=1) - out = out.reshape(num, out_h, out_w, channel).transpose(0, 3, 1, 2) + out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2) return out -def max_pool_grad(x, dout, pool_h, pool_w, stride, pad): +def max_pool_grad(x, dout, pool_h, pool_w, stride): """Grad of max pooling.""" dout = dout.transpose(0, 2, 3, 1) pool_size = pool_h * pool_w dmax = np.zeros((dout.size, pool_size)) - col = im2col(x, pool_h, pool_w, stride, pad) + col = im2col(x, pool_h, pool_w, stride) col = col.reshape(-1, pool_h*pool_w) arg_max = np.argmax(col, axis=1) dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0]*dmax.shape[1]*dmax.shape[2], -1) - dx = col2im(dcol, x.shape, pool_h, pool_w, stride, pad) + dx = col2im(dcol, x.shape, pool_h, pool_w, stride) return dx -def max_pool_grad_with_argmax(x, arg_max, dout, pool_h, pool_w, stride, pad): +def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride): """Grad of max pooling with argmax.""" dout = dout.transpose(0, 2, 3, 1) pool_size = pool_h * pool_w @@ -363,22 +361,22 @@ def max_pool_grad_with_argmax(x, arg_max, dout, pool_h, pool_w, stride, pad): dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0]*dmax.shape[1]*dmax.shape[2], -1) - dx = col2im(dcol, x.shape, pool_h, pool_w, stride, pad) + dx = col2im(dcol, x.shape, pool_h, pool_w, stride) return dx -def max_pool_with_argmax(x, pool_h, pool_w, stride, pad): +def max_pool_with_argmax(x, pool_h, pool_w, stride): """Max pooling with argmax.""" validator.check_integer("stride", stride, 0, Rel.GT) num, channel, height, width = x.shape - out_h = (height + 2*pad - pool_h)//stride + 1 - out_w = (width + 2*pad - pool_w)//stride + 1 - col = im2col(x, pool_h, pool_w, stride, pad) + out_h = (height - pool_h)//stride + 1 + out_w = (width - pool_w)//stride + 1 + col = im2col(x, pool_h, pool_w, stride) col = col.reshape(-1, pool_h*pool_w) out = np.max(col, axis=1) out_argmax = np.argmax(col, axis=1) - out = out.reshape(num, out_h, out_w, channel).transpose(0, 3, 1, 2) - out_argmax = out_argmax.reshape(num, out_h, out_w, channel).transpose(0, 3, 1, 2) + out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2) + out_argmax = out_argmax.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2) return out, out_argmax