fix adaptive_max_pool2d_grad

fanjibin 2022-11-25 14:24:13 +08:00
parent 6625b98253
commit 4017e6f137
13 changed files with 38 additions and 151 deletions

View File

@@ -1,7 +1,7 @@
mindspore.ops.AdaptiveMaxPool2D
===============================
.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size, return_indices=False)
.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size)
Two-dimensional adaptive max pooling.
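
For orientation: adaptive pooling derives every pooling window purely from the input and output sizes, so no kernel size or stride is supplied. A minimal NumPy sketch of the usual bin arithmetic (the same floor/ceil formula as the start_index/end_index helpers in the CPU kernel below; a reference sketch, not the shipped implementation):

import numpy as np

def adaptive_max_pool2d_ref(x, out_h, out_w):
    """Adaptive max pool over one (H, W) plane, returning values and flat indices."""
    in_h, in_w = x.shape
    out = np.empty((out_h, out_w), dtype=x.dtype)
    idx = np.empty((out_h, out_w), dtype=np.int64)
    for oh in range(out_h):
        h0 = (oh * in_h) // out_h                    # start_index: floor(oh * H / out_h)
        h1 = ((oh + 1) * in_h + out_h - 1) // out_h  # end_index: ceil((oh + 1) * H / out_h)
        for ow in range(out_w):
            w0 = (ow * in_w) // out_w
            w1 = ((ow + 1) * in_w + out_w - 1) // out_w
            window = x[h0:h1, w0:w1]
            r, c = divmod(int(np.argmax(window)), w1 - w0)
            out[oh, ow] = window[r, c]
            idx[oh, ow] = (h0 + r) * in_w + (w0 + c)  # position in the flattened H*W plane
    return out, idx
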

View File

@@ -39,7 +39,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
// (H_out, W_out)
auto output_size = kernel_ptr->output_size();
(void)std::copy(output_size.begin(), output_size.end(), std::back_inserter(attr_output_size_));
attr_return_indices_ = kernel_ptr->return_indices();
return MatchKernelFunc(base_operator, inputs, outputs);
}
@@ -80,33 +79,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::ResizedOutputSize() {
return true;
}
bool AdaptiveMaxPool2dCpuKernelMod::UpdateOutputSizeList(const std::vector<KernelTensorPtr> &outputs,
size_t input_type_size) {
output_size_list_.clear();
// If return_indices is true, the outputs num should be 2, otherwise should be 1.
if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!attr_return_indices_)) ||
(outputs.size() == ops::kOutputSizeAttrSize && attr_return_indices_)) {
MS_EXCEPTION_IF_NULL(outputs[0]);
auto output_shape = outputs[0]->GetShapeVector();
size_t output_number = 1;
for (size_t i = 0; i < output_shape.size(); i++) {
output_number *= static_cast<size_t>(output_shape[i]);
}
// N * C * H * W * type_size
auto output_mem_size = output_number * input_type_size;
output_size_list_.push_back(output_mem_size);
if (outputs.size() == ops::kOutputSizeAttrSize) {
output_size_list_.push_back(output_number * sizeof(int64_t));
}
return true;
}
MS_LOG(ERROR) << "For primitive[AdaptiveMaxPool2D], the number of outputs should be 2 when return_indices is True,"
" or that should be 1 when return_indices is False, but got the number of outputs : "
<< outputs.size() << ", and return_indices: " << attr_return_indices_;
return false;
}
int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
@@ -135,11 +107,6 @@ int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
return KRET_RESIZE_FAILED;
}
size_t input_type_size = abstract::TypeIdSize(inputs.at(0)->GetDtype());
if (!UpdateOutputSizeList(outputs, input_type_size)) {
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
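
With UpdateOutputSizeList deleted, output buffer sizing presumably falls to the base KernelMod::Resize, which can read both inferred outputs now that the primitive always produces the (values, indices) pair. The quantities involved amount to the following (a sketch of the arithmetic, not framework code):

import math

def output_buffer_bytes(out_shape, dtype_size):
    """Bytes needed for the two outputs: values in the input dtype, indices as int64."""
    n = math.prod(out_shape)         # N * C * H_out * W_out
    return [n * dtype_size, n * 8]   # 8 == sizeof(int64_t)
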
@@ -171,10 +138,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
const std::vector<AddressPtr> &outputs) {
T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
int64_t *indices_addr = nullptr;
if (outputs.size() > 1) {
indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
}
int64_t *indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
auto task = [this, &input_addr, &output_addr, &indices_addr](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
@@ -182,7 +146,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
size_t output_offset = i * output_hw_;
T *input_ptr = input_addr + input_offset;
T *output_ptr = output_addr + output_offset;
int64_t *indices_ptr = (indices_addr == nullptr) ? indices_addr : indices_addr + output_offset;
int64_t *indices_ptr = indices_addr + output_offset;
for (size_t oh_index = 0; oh_index < output_height_; ++oh_index) {
size_t h_begin = start_index(oh_index, output_height_, input_height_);
@@ -205,9 +169,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
ComputeLocalMax(&max_indice, &max_val, lw, input_width_, input_ptr);
size_t output_index = oh_index * output_width_ + ow_index;
output_ptr[output_index] = max_val;
if (indices_addr != nullptr) {
indices_ptr[output_index] = SizeToLong(max_indice);
}
indices_ptr[output_index] = SizeToLong(max_indice);
}
}
}
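
The launch loop above parallelizes over the N * C planes of a flat buffer, stepping by input_hw_ and output_hw_ per plane and now writing indices unconditionally. Roughly, reusing the single-plane sketch from the first file above:

def adaptive_max_pool2d_nchw(x, out_h, out_w):
    """Apply the single-plane reference pool to every (n, c) plane of an NCHW array."""
    n, c, in_h, in_w = x.shape
    planes = x.reshape(n * c, in_h, in_w)
    outs, idxs = zip(*(adaptive_max_pool2d_ref(p, out_h, out_w) for p in planes))
    return (np.stack(outs).reshape(n, c, out_h, out_w),
            np.stack(idxs).reshape(n, c, out_h, out_w))
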

View File

@@ -69,9 +69,7 @@ __global__ void AdaptiveMaxPool2DKernel(const size_t size, const size_t input_he
sub_input_ptr += input_width;
}
output_ptr[oh * output_width + ow] = max;
if (indices_data != nullptr) {
indices_ptr[oh * output_width + ow] = indice;
}
indices_ptr[oh * output_width + ow] = indice;
}
}
}

View File

@@ -67,9 +67,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
int64_t *indices_addr = nullptr;
if (outputs.size() > 1) {
indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
}
indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
ApplyAdaptiveMaxPool2D(size_, input_height_, input_width_, output_height_, output_width_, input_addr, output_addr,
indices_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -90,6 +88,10 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
bool InitSize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::AdaptiveMaxPool2D>(base_operator);
if (kernel_ptr == nullptr) {
MS_EXCEPTION(ValueError)
@@ -116,9 +118,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
return false;
}
input_size_list_.clear();
output_size_list_.clear();
input_height_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize]);
input_width_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize + 1]);
size_ = static_cast<size_t>(len_ == ops::kFormatCHWShapeSize ? input_shape[0] : input_shape[0] * input_shape[1]);
@@ -126,7 +125,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
for (size_t i = 0; i < len_; i++) {
input_size_ *= input_shape[i];
}
input_size_list_.push_back(input_size_);
auto output_size = kernel_ptr->output_size();
if (output_size.size() == ops::kOutputSizeAttrSize) {
@@ -139,27 +137,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
<< output_size.size();
return false;
}
size_t output_num = 1;
// If return indices is true, the outputs num should be 2, otherwise should be 1.
if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!kernel_ptr->return_indices())) ||
(outputs.size() == ops::kOutputSizeAttrSize && kernel_ptr->return_indices())) {
MS_EXCEPTION_IF_NULL(outputs[0]);
auto output_shape = outputs[0]->GetShapeVector();
for (size_t i = 0; i < output_shape.size(); i++) {
output_num *= output_shape[i];
}
output_size_ = output_num * sizeof(T);
output_size_list_.push_back(output_size_);
if (outputs.size() == ops::kOutputSizeAttrSize) {
output_size_list_.push_back(output_num * sizeof(int64_t));
}
return true;
}
MS_EXCEPTION(ValueError) << "For primitive[AdaptiveMaxPool2D], the size of attr[output_size] should be 2, but got:"
<< output_size.size();
return false;
return true;
}
private:

View File

@@ -33,12 +33,6 @@ std::vector<int64_t> AdaptiveMaxPool2D::output_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
bool AdaptiveMaxPool2D::return_indices() const {
auto value_ptr = GetAttr("return_indices");
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
namespace {
abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@@ -81,18 +75,9 @@ abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive
}
}
}
const auto &return_indices_ptr = primitive->GetAttr("return_indices");
MS_EXCEPTION_IF_NULL(return_indices_ptr);
const auto &return_indices = GetValue<bool>(return_indices_ptr);
auto in_shape = std::make_shared<abstract::Shape>(in_shape_vector);
// If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
// as the max value.
if (return_indices) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
}
return in_shape;
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
}
TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -107,17 +92,8 @@ TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<A
auto input_type =
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim->name());
const auto &return_indices_ptr = prim->GetAttr("return_indices");
MS_EXCEPTION_IF_NULL(return_indices_ptr);
const auto &return_indices = GetValue<bool>(return_indices_ptr);
// If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
// as the max value.
if (return_indices) {
auto indices_type = kInt64;
return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
}
return input_type;
auto indices_type = kInt64;
return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
}
} // namespace
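
After this change both inference paths unconditionally describe a two-element tuple: pooled values in the input dtype plus int64 indices of the same shape. In doctest form (output shapes follow from the inference above; a hedged illustration, not output copied from a run):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.random.randn(1, 3, 9, 9), mindspore.float32)
>>> out, indices = ops.AdaptiveMaxPool2D((3, 3))(x)
>>> print(out.shape, indices.shape)
(1, 3, 3, 3) (1, 3, 3, 3)
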

View File

@@ -38,7 +38,6 @@ class MIND_API AdaptiveMaxPool2D : public BaseOperator {
MIND_API_BASE_MEMBER(AdaptiveMaxPool2D);
AdaptiveMaxPool2D() : BaseOperator(kAdaptiveMaxPool2D) { InitIOName({"input_x"}, {"output"}); }
std::vector<int64_t> output_size() const;
bool return_indices() const;
};
abstract::AbstractBasePtr AdaptiveMaxPool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@@ -1188,10 +1188,15 @@ class AdaptiveMaxPool2d(Cell):
def __init__(self, output_size, return_indices=False):
"""Initialize AdaptiveMaxPool2d."""
super(AdaptiveMaxPool2d, self).__init__()
self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size, return_indices)
validator.check_value_type('return_indices', return_indices, [bool], self.cls_name)
self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size)
self.return_indices = return_indices
def construct(self, input_x):
return self.adaptive_max_pool2d(input_x)
output = self.adaptive_max_pool2d(input_x)
if self.return_indices:
return output
return output[0]
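
The cell now always receives the (values, indices) pair from the primitive and decides itself what to surface. Expected usage, assuming `from mindspore import nn` and a 4-D float Tensor x (illustrative, mirroring the docstring examples elsewhere in this commit):

>>> pool = nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
>>> out, indices = pool(x)    # both tensors, identical shape
>>> pool = nn.AdaptiveMaxPool2d((2, 2))
>>> out = pool(x)             # values only; indices are dropped inside construct
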
class AdaptiveMaxPool3d(Cell):

View File

@@ -1516,7 +1516,6 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
nchw_index = 4
chw_reverse_index = -3
hw_size = 2
return_indices = prim.return_indices
output_size = prim.output_size
@constexpr
@@ -1554,20 +1553,14 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
x_ori_shape = F.shape(x)
x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
output_shape = get_output_shape(x_ori_shape, output_size)
if return_indices:
out, indices = prim(x)
out = F.reshape(out, output_shape)
indices = F.reshape(indices, output_shape)
return (out, 0), (indices, 0)
out = prim(x)
out = F.reshape(out, output_shape)
return out, 0
# for the case of CHW
if return_indices:
out, indices = prim(x)
out = F.reshape(out, output_shape)
indices = F.reshape(indices, output_shape)
return (out, 0), (indices, 0)
out = prim(x)
return out, 0
# for the case of CHW
out, indices = prim(x)
return (out, 0), (indices, 0)
return vmap_rule
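
The simplified rule leans on the primitive's fixed two-output signature: leading batch dims are folded into the pseudo-batch axis, the primitive runs once, and both results are reshaped back. Approximately, in NumPy terms (reusing the NCHW sketch above rather than the actual F.shape/F.reshape calls):

def vmap_adaptive_max_pool2d(x, out_h, out_w):
    lead = x.shape[:-3]                        # batch dims ahead of (C, H, W)
    flat = x.reshape((-1,) + x.shape[-3:])     # fold them into a pseudo-N axis
    out, idx = adaptive_max_pool2d_nchw(flat, out_h, out_w)
    shape = lead + out.shape[-3:]
    return out.reshape(shape), idx.reshape(shape)
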

View File

@@ -623,8 +623,11 @@ def adaptive_max_pool2d(input_x, output_size, return_indices=False):
[[8. 9.]]
[[8. 9.]]]]
"""
_adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size, return_indices)
return _adaptive_max_pool2d(input_x)
validator.check_value_type("return_indices", return_indices, bool, "adaptive_max_pool2d")
_adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size)
out = _adaptive_max_pool2d(input_x)
output = out if return_indices else out[0]
return output
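
At the functional layer the same pattern applies: the cached primitive always yields the pair, and return_indices only selects how much of it the caller sees. Illustrative calls, assuming a 4-D float Tensor x:

>>> out = ops.adaptive_max_pool2d(x, (2, 2))             # Tensor of pooled values
>>> out, idx = ops.adaptive_max_pool2d(x, (2, 2), True)  # (values, int64 indices)
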
def adaptive_max_pool3d(x, output_size, return_indices=False):

View File

@@ -366,7 +366,7 @@ class AdaptiveMaxPool2D(Primitive):
... [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32)
>>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((None, 2))
>>> output = adaptive_max_pool_2d(input_x)
>>> print(output)
>>> print(output[0])
[[[[2. 3.]
[5. 6.]
[8. 9.]]
@@ -379,7 +379,7 @@ class AdaptiveMaxPool2D(Primitive):
>>> # case 2: output_size=2
>>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D(2)
>>> output = adaptive_max_pool_2d(input_x)
>>> print(output)
>>> print(output[0])
[[[[5. 6.]
[8. 9.]]
[[5. 6.]
@@ -389,17 +389,16 @@ class AdaptiveMaxPool2D(Primitive):
>>> # case 3: output_size=(1, 2)
>>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((1, 2))
>>> output = adaptive_max_pool_2d(input_x)
>>> print(output)
>>> print(output[0])
[[[[8. 9.]]
[[8. 9.]]
[[8. 9.]]]]
"""
@prim_attr_register
def __init__(self, output_size, return_indices=False):
def __init__(self, output_size):
"""Initialize AdaptiveMaxPool2D."""
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
validator.check_value_type("return_indices", return_indices, [bool], self.name)
if isinstance(output_size, tuple):
validator.check_int(len(output_size), 2, Rel.EQ,
'length of output_size', self.name)
@@ -409,7 +408,6 @@ class AdaptiveMaxPool2D(Primitive):
for size in self.output_size:
validator.check_number("output_size", size, -1, Rel.GE, None)
self.add_prim_attr('output_size', self.output_size)
self.add_prim_attr('return_indices', return_indices)
class AdaptiveMaxPool3D(Primitive):

View File

@@ -23,7 +23,7 @@ class Net(nn.Cell):
def __init__(self, output_size):
super(Net, self).__init__()
self.op = P.AdaptiveMaxPool2D(output_size=output_size, return_indices=True)
self.op = P.AdaptiveMaxPool2D(output_size=output_size)
def construct(self, x):
return self.op(x)

View File

@@ -151,31 +151,6 @@ def test_net_nn():
assert output.asnumpy().shape == expect_shape
def test_tensor_interface_pynative():
"""
Feature: test adaptivemaxpool2d op.
Description: test the ops in tensor interface in pynative mode.
Expectation: expect correct shape result.
"""
x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32)
y = x.adaptive_max_pool2d((3, 5), True)
expect_shape = (1, 32, 3, 5)
assert y[1].asnumpy().shape == expect_shape
def test_tensor_interface_graph():
"""
Feature: test adaptivemaxpool2d op.
Description: test the ops in tensor interface in graph mode.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32)
y = x.adaptive_max_pool2d((3, 5))
expect_shape = (1, 32, 3, 5)
assert y.asnumpy().shape == expect_shape
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard

View File

@@ -1247,7 +1247,7 @@ def test_adaptive_max_pool2d():
input_x = Tensor(np.array([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32)
net = Net(ops.AdaptiveMaxPool2D((None, 2), True))
net = Net(ops.AdaptiveMaxPool2D((None, 2)))
grad = GradNet(net)
grad.compile(input_x)