From 4017e6f1376b4b568d6a34748c97af2683b69989 Mon Sep 17 00:00:00 2001
From: fanjibin
Date: Fri, 25 Nov 2022 14:24:13 +0800
Subject: [PATCH] fix adaptive_max_pool2d_grad

---
 .../ops/mindspore.ops.AdaptiveMaxPool2D.rst   |  2 +-
 .../kernel/adaptive_max_pool2d_cpu_kernel.cc  | 44 ++-----------------
 .../cuda_ops/adaptive_max_pool2d_impl.cu      |  4 +-
 .../nn/adaptive_max_pool2d_gpu_kernel.h       | 34 +++-----------
 mindspore/core/ops/adaptive_max_pool2d.cc     | 30 ++-----------
 mindspore/core/ops/adaptive_max_pool2d.h      |  1 -
 .../python/mindspore/nn/layer/pooling.py      |  9 +++-
 .../python/mindspore/ops/_vmap/vmap_nn_ops.py | 19 +++-----
 .../python/mindspore/ops/function/nn_func.py  |  7 ++-
 .../python/mindspore/ops/operations/nn_ops.py | 10 ++---
 .../grad/test_adaptive_max_pool_2d.py         |  2 +-
 .../st/ops/gpu/test_adaptive_max_pool2d_op.py | 25 -----------
 .../ut/python/optimizer/test_bprop_mindir.py  |  2 +-
 13 files changed, 38 insertions(+), 151 deletions(-)

diff --git a/docs/api/api_python/ops/mindspore.ops.AdaptiveMaxPool2D.rst b/docs/api/api_python/ops/mindspore.ops.AdaptiveMaxPool2D.rst
index 6c5c67c47a2..676e32dc46d 100644
--- a/docs/api/api_python/ops/mindspore.ops.AdaptiveMaxPool2D.rst
+++ b/docs/api/api_python/ops/mindspore.ops.AdaptiveMaxPool2D.rst
@@ -1,7 +1,7 @@
 mindspore.ops.AdaptiveMaxPool2D
 ===============================
 
-.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size, return_indices=False)
+.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size)
 
 Two-dimensional adaptive max pooling.
 
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool2d_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool2d_cpu_kernel.cc
index 20d7f371c43..6350bd3dba7 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool2d_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool2d_cpu_kernel.cc
@@ -39,7 +39,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
   // (H_out, W_out)
   auto output_size = kernel_ptr->output_size();
   (void)std::copy(output_size.begin(), output_size.end(), std::back_inserter(attr_output_size_));
-  attr_return_indices_ = kernel_ptr->return_indices();
 
   return MatchKernelFunc(base_operator, inputs, outputs);
 }
@@ -80,33 +79,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::ResizedOutputSize() {
   return true;
 }
 
-bool AdaptiveMaxPool2dCpuKernelMod::UpdateOutputSizeList(const std::vector<KernelTensorPtr> &outputs,
-                                                         size_t input_type_size) {
-  output_size_list_.clear();
-  // If return_indices is true, the outputs num should be 2, otherwise should be 1.
-  if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!attr_return_indices_)) ||
-      (outputs.size() == ops::kOutputSizeAttrSize && attr_return_indices_)) {
-    MS_EXCEPTION_IF_NULL(outputs[0]);
-    auto output_shape = outputs[0]->GetShapeVector();
-    size_t output_number = 1;
-    for (size_t i = 0; i < output_shape.size(); i++) {
-      output_number *= static_cast<size_t>(output_shape[i]);
-    }
-    // N * C * H * W * type_size
-    auto output_mem_size = output_number * input_type_size;
-    output_size_list_.push_back(output_mem_size);
-    if (outputs.size() == ops::kOutputSizeAttrSize) {
-      output_size_list_.push_back(output_number * sizeof(int64_t));
-    }
-    return true;
-  }
-
-  MS_LOG(ERROR) << "For primitive[AdaptiveMaxPool2D], the number of outputs should be 2 when return_indices is True,"
-                   " or that should be 1 when return_indices is False, but got the number of outputs : "
-                << outputs.size() << ", and return_indices: " << attr_return_indices_;
-  return false;
-}
-
 int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs,
@@ -135,11 +107,6 @@ int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
     return KRET_RESIZE_FAILED;
   }
 
-  size_t input_type_size = abstract::TypeIdSize(inputs.at(0)->GetDtype());
-  if (!UpdateOutputSizeList(outputs, input_type_size)) {
-    return KRET_RESIZE_FAILED;
-  }
-
   return KRET_OK;
 }
 
@@ -171,10 +138,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                  const std::vector<AddressPtr> &outputs) {
   T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
   T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
-  int64_t *indices_addr = nullptr;
-  if (outputs.size() > 1) {
-    indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
-  }
+  int64_t *indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
 
   auto task = [this, &input_addr, &output_addr, &indices_addr](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
@@ -182,7 +146,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
       size_t output_offset = i * output_hw_;
       T *input_ptr = input_addr + input_offset;
       T *output_ptr = output_addr + output_offset;
-      int64_t *indices_ptr = (indices_addr == nullptr) ? indices_addr : indices_addr + output_offset;
+      int64_t *indices_ptr = indices_addr + output_offset;
 
       for (size_t oh_index = 0; oh_index < output_height_; ++oh_index) {
         size_t h_begin = start_index(oh_index, output_height_, input_height_);
@@ -205,9 +169,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
             ComputeLocalMax(&max_indice, &max_val, lw, input_width_, input_ptr);
             size_t output_index = oh_index * output_width_ + ow_index;
             output_ptr[output_index] = max_val;
-            if (indices_addr != nullptr) {
-              indices_ptr[output_index] = SizeToLong(max_indice);
-            }
+            indices_ptr[output_index] = SizeToLong(max_indice);
           }
         }
       }
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_impl.cu
index 4305a6fbaa7..643386e965d 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_impl.cu
@@ -69,9 +69,7 @@ __global__ void AdaptiveMaxPool2DKernel(const size_t size, const size_t input_he
       sub_input_ptr += input_width;
     }
     output_ptr[oh * output_width + ow] = max;
-    if (indices_data != nullptr) {
-      indices_ptr[oh * output_width + ow] = indice;
-    }
+    indices_ptr[oh * output_width + ow] = indice;
   }
 }
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool2d_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool2d_gpu_kernel.h
index 0043ffee1a6..f0c51c2a9a8 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool2d_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool2d_gpu_kernel.h
@@ -67,9 +67,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     int64_t *indices_addr = nullptr;
-    if (outputs.size() > 1) {
-      indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
-    }
+    indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
 
     ApplyAdaptiveMaxPool2D(size_, input_height_, input_width_, output_height_, output_width_, input_addr, output_addr,
                            indices_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -90,6 +88,10 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
 
   bool InitSize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                 const std::vector<KernelTensorPtr> &outputs) {
+    int ret = KernelMod::Resize(base_operator, inputs, outputs);
+    if (ret != KRET_OK) {
+      return false;
+    }
     auto kernel_ptr = std::dynamic_pointer_cast<ops::AdaptiveMaxPool2D>(base_operator);
     if (kernel_ptr == nullptr) {
       MS_EXCEPTION(ValueError)
@@ -116,9 +118,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
       return false;
     }
 
-    input_size_list_.clear();
-    output_size_list_.clear();
-
     input_height_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize]);
     input_width_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize + 1]);
     size_ = static_cast<size_t>(len_ == ops::kFormatCHWShapeSize ? input_shape[0] : input_shape[0] * input_shape[1]);
@@ -126,7 +125,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
     for (size_t i = 0; i < len_; i++) {
       input_size_ *= input_shape[i];
    }
-    input_size_list_.push_back(input_size_);
 
     auto output_size = kernel_ptr->output_size();
     if (output_size.size() == ops::kOutputSizeAttrSize) {
@@ -139,27 +137,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
                                << output_size.size();
       return false;
     }
-
-    size_t output_num = 1;
-    // If return indices is true, the outputs num should be 2, otherwise should be 1.
-    if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!kernel_ptr->return_indices())) ||
-        (outputs.size() == ops::kOutputSizeAttrSize && kernel_ptr->return_indices())) {
-      MS_EXCEPTION_IF_NULL(outputs[0]);
-      auto output_shape = outputs[0]->GetShapeVector();
-      for (size_t i = 0; i < output_shape.size(); i++) {
-        output_num *= output_shape[i];
-      }
-      output_size_ = output_num * sizeof(T);
-      output_size_list_.push_back(output_size_);
-      if (outputs.size() == ops::kOutputSizeAttrSize) {
-        output_size_list_.push_back(output_num * sizeof(int64_t));
-      }
-      return true;
-    }
-
-    MS_EXCEPTION(ValueError) << "For primitive[AdaptiveMaxPool2D], the size of attr[output_size] should be 2, but got:"
-                             << output_size.size();
-    return false;
+    return true;
   }
 
  private:
diff --git a/mindspore/core/ops/adaptive_max_pool2d.cc b/mindspore/core/ops/adaptive_max_pool2d.cc
index df695fdaf9e..0790c48b818 100644
--- a/mindspore/core/ops/adaptive_max_pool2d.cc
+++ b/mindspore/core/ops/adaptive_max_pool2d.cc
@@ -33,12 +33,6 @@ std::vector<int64_t> AdaptiveMaxPool2D::output_size() const {
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-bool AdaptiveMaxPool2D::return_indices() const {
-  auto value_ptr = GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(value_ptr);
-  return GetValue<bool>(value_ptr);
-}
-
 namespace {
 abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive,
                                                    const std::vector<AbstractBasePtr> &input_args) {
@@ -81,18 +75,9 @@ abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive
       }
     }
   }
-
-  const auto &return_indices_ptr = primitive->GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(return_indices_ptr);
-  const auto &return_indices = GetValue<bool>(return_indices_ptr);
   auto in_shape = std::make_shared<abstract::Shape>(in_shape_vector);
-
-  // If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
-  // as the max value.
-  if (return_indices) {
-    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
-  }
-  return in_shape;
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
 }
 
 TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -107,17 +92,8 @@ TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<
   (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim->name());
 
-  const auto &return_indices_ptr = prim->GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(return_indices_ptr);
-  const auto &return_indices = GetValue<bool>(return_indices_ptr);
-
-  // If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
-  // as the max value.
-  if (return_indices) {
-    auto indices_type = kInt64;
-    return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
-  }
-  return input_type;
+  auto indices_type = kInt64;
+  return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
 }
 }  // namespace
 
diff --git a/mindspore/core/ops/adaptive_max_pool2d.h b/mindspore/core/ops/adaptive_max_pool2d.h
index 420379387d7..02dd9647dca 100644
--- a/mindspore/core/ops/adaptive_max_pool2d.h
+++ b/mindspore/core/ops/adaptive_max_pool2d.h
@@ -38,7 +38,6 @@ class MIND_API AdaptiveMaxPool2D : public BaseOperator {
   MIND_API_BASE_MEMBER(AdaptiveMaxPool2D);
   AdaptiveMaxPool2D() : BaseOperator(kAdaptiveMaxPool2D) { InitIOName({"input_x"}, {"output"}); }
   std::vector<int64_t> output_size() const;
-  bool return_indices() const;
 };
 
 abstract::AbstractBasePtr AdaptiveMaxPool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/python/mindspore/nn/layer/pooling.py b/mindspore/python/mindspore/nn/layer/pooling.py
index 30e16f713c1..4407a4231ff 100644
--- a/mindspore/python/mindspore/nn/layer/pooling.py
+++ b/mindspore/python/mindspore/nn/layer/pooling.py
@@ -1188,10 +1188,15 @@ class AdaptiveMaxPool2d(Cell):
 
     def __init__(self, output_size, return_indices=False):
         """Initialize AdaptiveMaxPool2d."""
         super(AdaptiveMaxPool2d, self).__init__()
-        self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size, return_indices)
+        validator.check_value_type('return_indices', return_indices, [bool], self.cls_name)
+        self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size)
+        self.return_indices = return_indices
 
     def construct(self, input_x):
-        return self.adaptive_max_pool2d(input_x)
+        output = self.adaptive_max_pool2d(input_x)
+        if self.return_indices:
+            return output
+        return output[0]
 
 
 class AdaptiveMaxPool3d(Cell):
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
index 97759e15fbf..3c378c4e7f4 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
@@ -1516,7 +1516,6 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
     nchw_index = 4
     chw_reverse_index = -3
     hw_size = 2
-    return_indices = prim.return_indices
     output_size = prim.output_size
 
     @constexpr
@@ -1554,20 +1553,14 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
             x_ori_shape = F.shape(x)
             x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
             output_shape = get_output_shape(x_ori_shape, output_size)
-            if return_indices:
-                out, indices = prim(x)
-                out = F.reshape(out, output_shape)
-                indices = F.reshape(indices, output_shape)
-                return (out, 0), (indices, 0)
-            out = prim(x)
-            out = F.reshape(out, output_shape)
-            return out, 0
-        # for the case of CHW
-        if return_indices:
             out, indices = prim(x)
+            out = F.reshape(out, output_shape)
+            indices = F.reshape(indices, output_shape)
             return (out, 0), (indices, 0)
-        out = prim(x)
-        return out, 0
+
+        # for the case of CHW
+        out, indices = prim(x)
+        return (out, 0), (indices, 0)
 
     return vmap_rule
diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py
index 973ced58b43..3bee1fc2772 100644
--- a/mindspore/python/mindspore/ops/function/nn_func.py
+++ b/mindspore/python/mindspore/ops/function/nn_func.py
@@ -623,8 +623,11 @@ def adaptive_max_pool2d(input_x, output_size, return_indices=False):
          [[8. 9.]]
          [[8. 9.]]]]
     """
-    _adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size, return_indices)
-    return _adaptive_max_pool2d(input_x)
+    validator.check_value_type("return_indices", return_indices, bool, "adaptive_max_pool2d")
+    _adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size)
+    out = _adaptive_max_pool2d(input_x)
+    output = out if return_indices else out[0]
+    return output
 
 
 def adaptive_max_pool3d(x, output_size, return_indices=False):
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index 99c3e8a40ae..fe5510f380e 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -366,7 +366,7 @@ class AdaptiveMaxPool2D(Primitive):
         ...                       [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32)
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((None, 2))
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[2. 3.]
            [5. 6.]
            [8. 9.]]
@@ -379,7 +379,7 @@ class AdaptiveMaxPool2D(Primitive):
         >>> # case 2: output_size=2
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D(2)
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[5. 6.]
            [8. 9.]]
          [[5. 6.]
@@ -389,17 +389,16 @@ class AdaptiveMaxPool2D(Primitive):
         >>> # case 3: output_size=(1, 2)
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((1, 2))
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[8. 9.]]
          [[8. 9.]]
          [[8. 9.]]]]
     """
 
     @prim_attr_register
-    def __init__(self, output_size, return_indices=False):
+    def __init__(self, output_size):
         """Initialize AdaptiveMaxPool2D."""
         validator.check_value_type("output_size", output_size, [int, tuple], self.name)
-        validator.check_value_type("return_indices", return_indices, [bool], self.name)
         if isinstance(output_size, tuple):
             validator.check_int(len(output_size), 2, Rel.EQ, 'length of output_size', self.name)
@@ -409,7 +408,6 @@ class AdaptiveMaxPool2D(Primitive):
 
         for size in self.output_size:
             validator.check_number("output_size", size, -1, Rel.GE, None)
         self.add_prim_attr('output_size', self.output_size)
-        self.add_prim_attr('return_indices', return_indices)
 
 
 class AdaptiveMaxPool3D(Primitive):
diff --git a/tests/st/ops/dynamic_shape/grad/test_adaptive_max_pool_2d.py b/tests/st/ops/dynamic_shape/grad/test_adaptive_max_pool_2d.py
index 3dc004b6878..ebeb42e80e5 100644
--- a/tests/st/ops/dynamic_shape/grad/test_adaptive_max_pool_2d.py
+++ b/tests/st/ops/dynamic_shape/grad/test_adaptive_max_pool_2d.py
@@ -23,7 +23,7 @@ class Net(nn.Cell):
 
     def __init__(self, output_size):
         super(Net, self).__init__()
-        self.op = P.AdaptiveMaxPool2D(output_size=output_size, return_indices=True)
+        self.op = P.AdaptiveMaxPool2D(output_size=output_size)
 
     def construct(self, x):
         return self.op(x)
diff --git a/tests/st/ops/gpu/test_adaptive_max_pool2d_op.py b/tests/st/ops/gpu/test_adaptive_max_pool2d_op.py
index 60f9821eb6c..8573fc4fb9f 100644
--- a/tests/st/ops/gpu/test_adaptive_max_pool2d_op.py
+++ b/tests/st/ops/gpu/test_adaptive_max_pool2d_op.py
@@ -151,31 +151,6 @@ def test_net_nn():
     assert output.asnumpy().shape == expect_shape
 
 
-def test_tensor_interface_pynative():
-    """
-    Feature: test adaptivemaxpool2d op.
-    Description: test the ops in tensor interface in pynative mode.
-    Expectation: expect correct shape result.
- """ - x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32) - y = x.adaptive_max_pool2d((3, 5), True) - expect_shape = (1, 32, 3, 5) - assert y[1].asnumpy().shape == expect_shape - - -def test_tensor_interface_graph(): - """ - Feature: test adaptivemaxpool2d op. - Description: test the ops in tensor interface in graph mode. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32) - y = x.adaptive_max_pool2d((3, 5)) - expect_shape = (1, 32, 3, 5) - assert y.asnumpy().shape == expect_shape - - @pytest.mark.level1 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard diff --git a/tests/ut/python/optimizer/test_bprop_mindir.py b/tests/ut/python/optimizer/test_bprop_mindir.py index 9d1db9c2472..187e553ee80 100644 --- a/tests/ut/python/optimizer/test_bprop_mindir.py +++ b/tests/ut/python/optimizer/test_bprop_mindir.py @@ -1247,7 +1247,7 @@ def test_adaptive_max_pool2d(): input_x = Tensor(np.array([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32) - net = Net(ops.AdaptiveMaxPool2D((None, 2), True)) + net = Net(ops.AdaptiveMaxPool2D((None, 2))) grad = GradNet(net) grad.compile(input_x)