fix adaptive_max_pool2d_grad
forked from mindspore-Ecosystem/mindspore; parent 6625b98253, commit 4017e6f137
@@ -1,7 +1,7 @@
 mindspore.ops.AdaptiveMaxPool2D
 ===============================
 
-.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size, return_indices=False)
+.. py:class:: mindspore.ops.AdaptiveMaxPool2D(output_size)
 
 2D adaptive max pooling.
 
@@ -39,7 +39,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
   // (H_out, W_out)
   auto output_size = kernel_ptr->output_size();
   (void)std::copy(output_size.begin(), output_size.end(), std::back_inserter(attr_output_size_));
-  attr_return_indices_ = kernel_ptr->return_indices();
   return MatchKernelFunc(base_operator, inputs, outputs);
 }
 
@@ -80,33 +79,6 @@ bool AdaptiveMaxPool2dCpuKernelMod::ResizedOutputSize() {
   return true;
 }
 
-bool AdaptiveMaxPool2dCpuKernelMod::UpdateOutputSizeList(const std::vector<KernelTensorPtr> &outputs,
-                                                         size_t input_type_size) {
-  output_size_list_.clear();
-  // If return_indices is true, the outputs num should be 2, otherwise should be 1.
-  if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!attr_return_indices_)) ||
-      (outputs.size() == ops::kOutputSizeAttrSize && attr_return_indices_)) {
-    MS_EXCEPTION_IF_NULL(outputs[0]);
-    auto output_shape = outputs[0]->GetShapeVector();
-    size_t output_number = 1;
-    for (size_t i = 0; i < output_shape.size(); i++) {
-      output_number *= static_cast<size_t>(output_shape[i]);
-    }
-    // N * C * H * W * type_size
-    auto output_mem_size = output_number * input_type_size;
-    output_size_list_.push_back(output_mem_size);
-    if (outputs.size() == ops::kOutputSizeAttrSize) {
-      output_size_list_.push_back(output_number * sizeof(int64_t));
-    }
-    return true;
-  }
-
-  MS_LOG(ERROR) << "For primitive[AdaptiveMaxPool2D], the number of outputs should be 2 when return_indices is True,"
-                   " or that should be 1 when return_indices is False, but got the number of outputs : "
-                << outputs.size() << ", and return_indices: " << attr_return_indices_;
-  return false;
-}
-
 int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs,
@@ -135,11 +107,6 @@ int AdaptiveMaxPool2dCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
     return KRET_RESIZE_FAILED;
   }
 
-  size_t input_type_size = abstract::TypeIdSize(inputs.at(0)->GetDtype());
-  if (!UpdateOutputSizeList(outputs, input_type_size)) {
-    return KRET_RESIZE_FAILED;
-  }
-
   return KRET_OK;
 }
 
@@ -171,10 +138,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
                                                  const std::vector<AddressPtr> &outputs) {
   T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
   T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
-  int64_t *indices_addr = nullptr;
-  if (outputs.size() > 1) {
-    indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
-  }
+  int64_t *indices_addr = GetDeviceAddress<int64_t>(outputs, kIndex1);
 
   auto task = [this, &input_addr, &output_addr, &indices_addr](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
@@ -182,7 +146,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
       size_t output_offset = i * output_hw_;
       T *input_ptr = input_addr + input_offset;
       T *output_ptr = output_addr + output_offset;
-      int64_t *indices_ptr = (indices_addr == nullptr) ? indices_addr : indices_addr + output_offset;
+      int64_t *indices_ptr = indices_addr + output_offset;
 
       for (size_t oh_index = 0; oh_index < output_height_; ++oh_index) {
         size_t h_begin = start_index(oh_index, output_height_, input_height_);
@@ -205,9 +169,7 @@ bool AdaptiveMaxPool2dCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &
           ComputeLocalMax(&max_indice, &max_val, lw, input_width_, input_ptr);
           size_t output_index = oh_index * output_width_ + ow_index;
           output_ptr[output_index] = max_val;
-          if (indices_addr != nullptr) {
-            indices_ptr[output_index] = SizeToLong(max_indice);
-          }
+          indices_ptr[output_index] = SizeToLong(max_indice);
         }
       }
     }
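The CPU kernel above now writes indices unconditionally; the pooling regions themselves come from the start_index/end_index helpers seen in the context lines. A minimal Python sketch of that window arithmetic, the usual floor/ceil split for adaptive pooling (helper names mirror the C++, but this is an illustration, not the kernel source):

def start_index(out_idx, out_len, in_len):
    # First input row/column covered by output cell out_idx.
    return (out_idx * in_len) // out_len

def end_index(out_idx, out_len, in_len):
    # One past the last covered input row/column (ceiling division).
    return -(-((out_idx + 1) * in_len) // out_len)

# Pooling H=5 down to H_out=2 gives row ranges [0, 3) and [2, 5).
print([(start_index(i, 2, 5), end_index(i, 2, 5)) for i in range(2)])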
@@ -69,9 +69,7 @@ __global__ void AdaptiveMaxPool2DKernel(const size_t size, const size_t input_he
         sub_input_ptr += input_width;
       }
       output_ptr[oh * output_width + ow] = max;
-      if (indices_data != nullptr) {
-        indices_ptr[oh * output_width + ow] = indice;
-      }
+      indices_ptr[oh * output_width + ow] = indice;
     }
   }
 }
@@ -67,9 +67,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     int64_t *indices_addr = nullptr;
-    if (outputs.size() > 1) {
-      indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
-    }
+    indices_addr = GetDeviceAddress<int64_t>(outputs, 1);
 
     ApplyAdaptiveMaxPool2D(size_, input_height_, input_width_, output_height_, output_width_, input_addr, output_addr,
                            indices_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -90,6 +88,10 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
 
   bool InitSize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                 const std::vector<KernelTensorPtr> &outputs) {
+    int ret = KernelMod::Resize(base_operator, inputs, outputs);
+    if (ret != KRET_OK) {
+      return ret;
+    }
     auto kernel_ptr = std::dynamic_pointer_cast<ops::AdaptiveMaxPool2D>(base_operator);
     if (kernel_ptr == nullptr) {
       MS_EXCEPTION(ValueError)
@@ -116,9 +118,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
       return false;
     }
 
-    input_size_list_.clear();
-    output_size_list_.clear();
-
     input_height_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize]);
     input_width_ = static_cast<size_t>(input_shape[len_ - ops::kOutputSizeAttrSize + 1]);
     size_ = static_cast<size_t>(len_ == ops::kFormatCHWShapeSize ? input_shape[0] : input_shape[0] * input_shape[1]);
@@ -126,7 +125,6 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
     for (size_t i = 0; i < len_; i++) {
      input_size_ *= input_shape[i];
     }
-    input_size_list_.push_back(input_size_);
 
     auto output_size = kernel_ptr->output_size();
     if (output_size.size() == ops::kOutputSizeAttrSize) {
@@ -139,27 +137,7 @@ class AdaptiveMaxPool2DKernelMod : public NativeGpuKernelMod {
                                << output_size.size();
       return false;
     }
-
-    size_t output_num = 1;
-    // If return indices is true, the outputs num should be 2, otherwise should be 1.
-    if ((outputs.size() == ops::kOutputSizeAttrSize - 1 && (!kernel_ptr->return_indices())) ||
-        (outputs.size() == ops::kOutputSizeAttrSize && kernel_ptr->return_indices())) {
-      MS_EXCEPTION_IF_NULL(outputs[0]);
-      auto output_shape = outputs[0]->GetShapeVector();
-      for (size_t i = 0; i < output_shape.size(); i++) {
-        output_num *= output_shape[i];
-      }
-      output_size_ = output_num * sizeof(T);
-      output_size_list_.push_back(output_size_);
-      if (outputs.size() == ops::kOutputSizeAttrSize) {
-        output_size_list_.push_back(output_num * sizeof(int64_t));
-      }
-      return true;
-    }
-
-    MS_EXCEPTION(ValueError) << "For primitive[AdaptiveMaxPool2D], the size of attr[output_size] should be 2, but got:"
-                             << output_size.size();
-    return false;
+    return true;
  }
 
  private:
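Both kernels drop their hand-rolled output-size bookkeeping because the base KernelMod::Resize call added above can derive both buffers from the inferred shapes: the primitive now always produces a values tensor and an int64 indices tensor of the same shape. A hedged NumPy sketch of that size arithmetic (function and variable names are illustrative, not MindSpore internals):

import numpy as np

def output_buffer_sizes(out_shape, value_dtype=np.float32):
    # N * C * H_out * W_out elements per output tensor.
    n = int(np.prod(out_shape))
    return [n * np.dtype(value_dtype).itemsize,   # values buffer
            n * np.dtype(np.int64).itemsize]      # indices buffer

print(output_buffer_sizes((1, 3, 2, 2)))  # [48, 96]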
@@ -33,12 +33,6 @@ std::vector<int64_t> AdaptiveMaxPool2D::output_size() const {
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-bool AdaptiveMaxPool2D::return_indices() const {
-  auto value_ptr = GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(value_ptr);
-  return GetValue<bool>(value_ptr);
-}
-
 namespace {
 abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive,
                                                    const std::vector<AbstractBasePtr> &input_args) {
@@ -81,18 +75,9 @@ abstract::BaseShapePtr AdaptiveMaxPool2DInferShape(const PrimitivePtr &primitive
       }
     }
   }
 
-  const auto &return_indices_ptr = primitive->GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(return_indices_ptr);
-  const auto &return_indices = GetValue<bool>(return_indices_ptr);
   auto in_shape = std::make_shared<abstract::Shape>(in_shape_vector);
-
-  // If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
-  // as the max value.
-  if (return_indices) {
-    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
-  }
-  return in_shape;
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{in_shape, in_shape});
 }
 
 TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -107,17 +92,8 @@ TypePtr AdaptiveMaxPool2DInferType(const PrimitivePtr &prim, const std::vector<A
   auto input_type =
     CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim->name());
 
-  const auto &return_indices_ptr = prim->GetAttr("return_indices");
-  MS_EXCEPTION_IF_NULL(return_indices_ptr);
-  const auto &return_indices = GetValue<bool>(return_indices_ptr);
-
-  // If return indices is true, need to output the indices corresponding to the max value, whose shape is the same
-  // as the max value.
-  if (return_indices) {
-    auto indices_type = kInt64;
-    return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
-  }
-  return input_type;
+  auto indices_type = kInt64;
+  return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, indices_type});
 }
 }  // namespace
 
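With the return_indices attribute gone, AdaptiveMaxPool2DInferShape and AdaptiveMaxPool2DInferType above always report a two-element tuple: pooled values in the input dtype plus int64 indices of the same shape. A minimal usage sketch of that contract (assuming a MindSpore build that includes this change):

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
# The primitive itself now always yields (values, indices).
out, indices = ops.AdaptiveMaxPool2D((2, 2))(x)
print(out.shape, indices.shape)  # (1, 1, 2, 2) (1, 1, 2, 2); indices are int64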
@@ -38,7 +38,6 @@ class MIND_API AdaptiveMaxPool2D : public BaseOperator {
   MIND_API_BASE_MEMBER(AdaptiveMaxPool2D);
   AdaptiveMaxPool2D() : BaseOperator(kAdaptiveMaxPool2D) { InitIOName({"input_x"}, {"output"}); }
   std::vector<int64_t> output_size() const;
-  bool return_indices() const;
 };
 
 abstract::AbstractBasePtr AdaptiveMaxPool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -1188,10 +1188,15 @@ class AdaptiveMaxPool2d(Cell):
     def __init__(self, output_size, return_indices=False):
         """Initialize AdaptiveMaxPool2d."""
         super(AdaptiveMaxPool2d, self).__init__()
-        self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size, return_indices)
+        validator.check_value_type('return_indices', return_indices, [bool], self.cls_name)
+        self.adaptive_max_pool2d = AdaptiveMaxPool2D(output_size)
+        self.return_indices = return_indices
 
     def construct(self, input_x):
-        return self.adaptive_max_pool2d(input_x)
+        output = self.adaptive_max_pool2d(input_x)
+        if self.return_indices:
+            return output
+        return output[0]
 
 
 class AdaptiveMaxPool3d(Cell):
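The Cell now owns the return_indices switch: it always runs the two-output primitive and discards the indices unless asked. A short usage sketch consistent with the code above (shapes are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

x = Tensor(np.random.randn(1, 3, 9, 9).astype(np.float32))
# return_indices=True passes the primitive's (values, indices) tuple through.
out, indices = nn.AdaptiveMaxPool2d((3, 3), return_indices=True)(x)
# Default keeps only the pooled values, shape (1, 3, 3, 3).
out_only = nn.AdaptiveMaxPool2d((3, 3))(x)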
@@ -1516,7 +1516,6 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
     nchw_index = 4
     chw_reverse_index = -3
     hw_size = 2
-    return_indices = prim.return_indices
     output_size = prim.output_size
 
     @constexpr
@@ -1554,20 +1553,14 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
             x_ori_shape = F.shape(x)
             x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
             output_shape = get_output_shape(x_ori_shape, output_size)
-            if return_indices:
-                out, indices = prim(x)
-                out = F.reshape(out, output_shape)
-                indices = F.reshape(indices, output_shape)
-                return (out, 0), (indices, 0)
-            out = prim(x)
-            out = F.reshape(out, output_shape)
-            return out, 0
-        # for the case of CHW
-        if return_indices:
             out, indices = prim(x)
+            out = F.reshape(out, output_shape)
+            indices = F.reshape(indices, output_shape)
             return (out, 0), (indices, 0)
-        out = prim(x)
-        return out, 0
+
+        # for the case of CHW
+        out, indices = prim(x)
+        return (out, 0), (indices, 0)
 
     return vmap_rule
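The vmap rule keeps its batch-folding trick: extra leading dimensions are merged into N before the NCHW primitive runs, then both results are reshaped back. A plain NumPy stand-in for that reshape round trip (not the MindSpore code; the pooled result is faked by slicing):

import numpy as np

x = np.random.randn(5, 2, 3, 9, 9)            # (vmap_dim, N, C, H, W)
folded = x.reshape((-1,) + x.shape[-3:])       # fold batch dims: (10, 3, 9, 9)
pooled = folded[..., :3, :3]                   # stand-in for a 3x3 pooled output
# Restore the original leading dims around the new spatial dims.
restored = pooled.reshape(x.shape[:-2] + pooled.shape[-2:])
print(restored.shape)                          # (5, 2, 3, 3, 3)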
@@ -623,8 +623,11 @@ def adaptive_max_pool2d(input_x, output_size, return_indices=False):
          [[8. 9.]]
          [[8. 9.]]]]
     """
-    _adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size, return_indices)
-    return _adaptive_max_pool2d(input_x)
+    validator.check_value_type("return_indices", return_indices, bool, "adaptive_max_pool2d")
+    _adaptive_max_pool2d = _get_cache_prim(NN_OPS.AdaptiveMaxPool2D)(output_size)
+    out = _adaptive_max_pool2d(input_x)
+    output = out if return_indices else out[0]
+    return output
 
 
 def adaptive_max_pool3d(x, output_size, return_indices=False):
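Same pattern in the functional wrapper: the primitive always returns (values, indices), and return_indices only chooses what the caller sees. A hedged usage sketch:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.random.randn(1, 32, 9, 9).astype(np.float32))
values = ops.adaptive_max_pool2d(x, (3, 5))                       # values only
values, idx = ops.adaptive_max_pool2d(x, (3, 5), return_indices=True)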
@@ -366,7 +366,7 @@ class AdaptiveMaxPool2D(Primitive):
         ...                       [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32)
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((None, 2))
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[2. 3.]
            [5. 6.]
            [8. 9.]]
@@ -379,7 +379,7 @@ class AdaptiveMaxPool2D(Primitive):
         >>> # case 2: output_size=2
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D(2)
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[5. 6.]
            [8. 9.]]
          [[5. 6.]
@@ -389,17 +389,16 @@ class AdaptiveMaxPool2D(Primitive):
         >>> # case 3: output_size=(1, 2)
         >>> adaptive_max_pool_2d = ops.AdaptiveMaxPool2D((1, 2))
         >>> output = adaptive_max_pool_2d(input_x)
-        >>> print(output)
+        >>> print(output[0])
         [[[[8. 9.]]
          [[8. 9.]]
          [[8. 9.]]]]
     """
 
     @prim_attr_register
-    def __init__(self, output_size, return_indices=False):
+    def __init__(self, output_size):
         """Initialize AdaptiveMaxPool2D."""
         validator.check_value_type("output_size", output_size, [int, tuple], self.name)
-        validator.check_value_type("return_indices", return_indices, [bool], self.name)
         if isinstance(output_size, tuple):
             validator.check_int(len(output_size), 2, Rel.EQ,
                                 'length of output_size', self.name)
@@ -409,7 +408,6 @@ class AdaptiveMaxPool2D(Primitive):
         for size in self.output_size:
             validator.check_number("output_size", size, -1, Rel.GE, None)
         self.add_prim_attr('output_size', self.output_size)
-        self.add_prim_attr('return_indices', return_indices)
 
 
 class AdaptiveMaxPool3D(Primitive):
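The surviving validation accepts sizes down to -1, which, together with the (None, 2) doctest above, suggests None entries in the tuple mean "keep that input dimension". A hedged illustration, assuming that normalization:

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.random.randn(1, 3, 5, 7).astype(np.float32))
# None for H keeps the input height; W is pooled down to 2.
out, _ = ops.AdaptiveMaxPool2D((None, 2))(x)
print(out.shape)  # (1, 3, 5, 2)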
@@ -23,7 +23,7 @@ class Net(nn.Cell):
 
     def __init__(self, output_size):
         super(Net, self).__init__()
-        self.op = P.AdaptiveMaxPool2D(output_size=output_size, return_indices=True)
+        self.op = P.AdaptiveMaxPool2D(output_size=output_size)
 
     def construct(self, x):
         return self.op(x)
@@ -151,31 +151,6 @@ def test_net_nn():
     assert output.asnumpy().shape == expect_shape
 
 
-def test_tensor_interface_pynative():
-    """
-    Feature: test adaptivemaxpool2d op.
-    Description: test the ops in tensor interface in pynative mode.
-    Expectation: expect correct shape result.
-    """
-    x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32)
-    y = x.adaptive_max_pool2d((3, 5), True)
-    expect_shape = (1, 32, 3, 5)
-    assert y[1].asnumpy().shape == expect_shape
-
-
-def test_tensor_interface_graph():
-    """
-    Feature: test adaptivemaxpool2d op.
-    Description: test the ops in tensor interface in graph mode.
-    Expectation: expect correct shape result.
-    """
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    x = Tensor(np.random.randn(1, 32, 9, 9), mindspore.float32)
-    y = x.adaptive_max_pool2d((3, 5))
-    expect_shape = (1, 32, 3, 5)
-    assert y.asnumpy().shape == expect_shape
-
-
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -1247,7 +1247,7 @@ def test_adaptive_max_pool2d():
     input_x = Tensor(np.array([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                                 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                                 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]), mindspore.float32)
-    net = Net(ops.AdaptiveMaxPool2D((None, 2), True))
+    net = Net(ops.AdaptiveMaxPool2D((None, 2)))
     grad = GradNet(net)
     grad.compile(input_x)