add aicpu ops MaxPoolV1 and MaxPoolGradV1

Seeker 2022-05-06 20:10:02 +08:00
parent 8b1af5b007
commit aa0dc155fb
15 changed files with 573 additions and 0 deletions

View File

@@ -297,6 +297,10 @@ void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDe
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
// when the op name differs between MindSpore and AiCPU
if (auto iter = kOpNameToAicpuOpNameMap.find(op_name); iter != kOpNameToAicpuOpNameMap.end()) {
op_name = iter->second;
}
// set op name
proto->set_op(op_name);
// set inputs tensor

View File

@@ -81,6 +81,8 @@ constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
constexpr auto kPriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
constexpr auto kMaxPoolV1 = "MaxPoolV1";
constexpr auto kMaxPoolGradV1 = "MaxPoolGradV1";
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
@@ -97,6 +99,8 @@ const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
const std::set<std::string> kDynamicInputOps{
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName,
kStackPushOpName, kStackPopOpName, kDynamicStitch, kPriorityReplayBufferPush, kPriorityReplayBufferSample};
const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{{kMaxPoolV1, "MaxPool"},
{kMaxPoolGradV1, "MaxPoolGrad"}};
struct AicpuParamHead {
uint32_t length; // Total length: includes custom message
uint32_t ioAddrNum; // Input and output address number

View File

@@ -466,6 +466,8 @@ GVAR_DEF(PrimitivePtr, kPrimPSROIPoolingGrad, std::make_shared<Primitive>("PSROI
GVAR_DEF(PrimitivePtr, kPrimROIPooling, std::make_shared<Primitive>("ROIPooling"));
GVAR_DEF(PrimitivePtr, kPrimMaxPool, std::make_shared<Primitive>("MaxPool"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGrad, std::make_shared<Primitive>("MaxPoolGrad"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolV1, std::make_shared<Primitive>("MaxPoolV1"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradV1, std::make_shared<Primitive>("MaxPoolGradV1"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolWithArgmax, std::make_shared<Primitive>("MaxPoolWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradWithArgmax, std::make_shared<Primitive>("MaxPoolGradWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimApplyCenteredRMSProp, std::make_shared<Primitive>("ApplyCenteredRMSProp"));

View File

@@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/max_pool_grad_v1.h"
#include <set>
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
abstract::ShapePtr MaxPoolGradV1InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode_value = (primitive->GetAttr(kPadMode));
auto pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
if (format == NHWC) {
std::vector<int64_t> ksize_NHWC = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
(void)primitive->AddAttr("ksize", MakeValue(ksize_NHWC));
(void)primitive->DelAttr("data_format");
(void)primitive->AddAttr("data_format", MakeValue("NHWC"));
} else if (format == NCHW) {
std::vector<int64_t> ksize_NCHW = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
(void)primitive->AddAttr("ksize", MakeValue(ksize_NCHW));
(void)primitive->DelAttr("data_format");
(void)primitive->AddAttr("data_format", MakeValue("NCHW"));
}
if (pad_mode == VALID) {
(void)primitive->AddAttr("padding", MakeValue("VALID"));
} else if (pad_mode == SAME) {
(void)primitive->AddAttr("padding", MakeValue("SAME"));
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr MaxPoolGradV1InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto name = prim->name();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32,
kFloat64, kUInt8, kUInt16, kUInt32, kUInt64};
auto orig_input_type = input_args[0]->BuildType();
auto orig_output_type = input_args[1]->BuildType();
auto grad_type = input_args[2]->BuildType();
auto inferred_type = CheckAndConvertUtils::CheckTensorTypeValid("orig_input", orig_input_type, valid_types, name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("orig_output", orig_output_type, valid_types, name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("grad", grad_type, valid_types, name);
return inferred_type;
}
AbstractBasePtr MaxPoolGradV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto maxpoolgradv1_infer_type = MaxPoolGradV1InferType(primitive, input_args);
auto maxpoolgradv1_infer_shape = MaxPoolGradV1InferShape(primitive, input_args)->shape();
return std::make_shared<abstract::AbstractTensor>(maxpoolgradv1_infer_type, maxpoolgradv1_infer_shape);
}
MIND_API_OPERATOR_IMPL(MaxPoolGradV1, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGradV1, prim::kPrimMaxPoolGradV1, MaxPoolGradV1Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
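For reference, the AddAttr calls in MaxPoolGradV1InferShape above rename MindSpore attributes to the TF-style names the AiCPU MaxPoolGrad kernel expects. A minimal illustrative sketch of that mapping (names assumed for clarity, not part of this commit):

# Illustrative only: MindSpore attr name -> AiCPU/TF-style attr name,
# as materialized by the AddAttr calls in the infer function above.
MS_TO_AICPU_ATTRS = {
    "kernel_size": "ksize",    # list of 4 ints, ordered per data_format
    "pad_mode": "padding",     # "VALID" / "SAME"
    "format": "data_format",   # "NCHW" / "NHWC"
}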

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAX_POOL_GRAD_V1_H_
#define MINDSPORE_CORE_OPS_MAX_POOL_GRAD_V1_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxPoolGradV1 = "MaxPoolGradV1";
class MS_CORE_API MaxPoolGradV1 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MaxPoolGradV1);
MaxPoolGradV1() : BaseOperator(kNameMaxPoolGradV1) { InitIOName({"orig_input", "orig_output", "grad"}, {"output"}); }
};
AbstractBasePtr MaxPoolGradV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMaxPoolGradV1Ptr = std::shared_ptr<MaxPoolGradV1>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAX_POOL_GRAD_V1_H_

View File

@@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/max_pool_v1.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kFormatNCHWIndexN = 0;
constexpr int64_t kFormatNCHWIndexC = 1;
constexpr int64_t kFormatNCHWIndexH = 2;
constexpr int64_t kFormatNCHWIndexW = 3;
constexpr int64_t kFormatNHWCIndexN = 0;
constexpr int64_t kFormatNHWCIndexH = 1;
constexpr int64_t kFormatNHWCIndexW = 2;
constexpr int64_t kFormatNHWCIndexC = 3;
abstract::ShapePtr MaxPoolV1InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
const int64_t x_rank = 4;
const int64_t attr_size = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
int64_t kernel_h = 0, kernel_w = 0;
int64_t stride_h = 0, stride_w = 0;
if (format == NHWC) {
batch = in_shape[kFormatNHWCIndexN];
channel = in_shape[kFormatNHWCIndexC];
in_h = in_shape[kFormatNHWCIndexH];
in_w = in_shape[kFormatNHWCIndexW];
kernel_h = kernel_size[kFormatNHWCIndexH];
kernel_w = kernel_size[kFormatNHWCIndexW];
stride_h = strides[kFormatNHWCIndexH];
stride_w = strides[kFormatNHWCIndexW];
} else if (format == NCHW) {
batch = in_shape[kFormatNCHWIndexN];
channel = in_shape[kFormatNCHWIndexC];
in_h = in_shape[kFormatNCHWIndexH];
in_w = in_shape[kFormatNCHWIndexW];
kernel_h = kernel_size[kFormatNCHWIndexH];
kernel_w = kernel_size[kFormatNCHWIndexW];
stride_h = strides[kFormatNCHWIndexH];
stride_w = strides[kFormatNCHWIndexW];
}
int64_t out_h = abstract::Shape::SHP_ANY;
int64_t out_w = abstract::Shape::SHP_ANY;
if (pad_mode == VALID) {
out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
} else if (pad_mode == SAME) {
out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
}
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
// Process attr mapping problems from mindspore to ai_cpu
// kernel_size -> ksize
// pad_mode -> padding
if (format == NHWC) {
std::vector<int64_t> ksize_NHWC = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
(void)primitive->AddAttr("ksize", MakeValue(ksize_NHWC));
(void)primitive->AddAttr("data_format", MakeValue("NHWC"));
} else if (format == NCHW) {
std::vector<int64_t> ksize_NCHW = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
(void)primitive->AddAttr("ksize", MakeValue(ksize_NCHW));
(void)primitive->AddAttr("data_format", MakeValue("NCHW"));
}
if (pad_mode == VALID) {
(void)primitive->AddAttr("padding", MakeValue("VALID"));
} else if (pad_mode == SAME) {
(void)primitive->AddAttr("padding", MakeValue("SAME"));
}
if (NHWC == format) {
out_shape = {batch, out_h, out_w, channel};
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr MaxPoolV1InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto name = prim->name();
const std::set<TypePtr> maxpoolv1_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8,
kUInt16, kFloat16, kFloat32, kFloat64};
auto input_type = input_args[0]->BuildType();
auto inferred_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, maxpoolv1_valid_types, name);
return inferred_type;
}
} // namespace
AbstractBasePtr MaxPoolV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto maxpoolv1_infer_type = MaxPoolV1InferType(primitive, input_args);
auto maxpoolv1_infer_shape = MaxPoolV1InferShape(primitive, input_args)->shape();
return std::make_shared<abstract::AbstractTensor>(maxpoolv1_infer_type, maxpoolv1_infer_shape);
}
MIND_API_OPERATOR_IMPL(MaxPoolV1, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolV1, prim::kPrimMaxPoolV1, MaxPoolV1Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
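For reference, the spatial output sizes computed in MaxPoolV1InferShape follow the usual VALID/SAME pooling formulas:

$$H_{out} = \left\lceil \frac{H_{in} - (k_h - 1)}{s_h} \right\rceil,\qquad W_{out} = \left\lceil \frac{W_{in} - (k_w - 1)}{s_w} \right\rceil \quad \text{(VALID)}$$

$$H_{out} = \left\lceil \frac{H_{in}}{s_h} \right\rceil,\qquad W_{out} = \left\lceil \frac{W_{in}}{s_w} \right\rceil \quad \text{(SAME)}$$

For example, H_in = 28, k_h = 2, s_h = 2 with VALID padding gives ceil(27 / 2) = 14, matching the MaxPoolV1 test case added at the end of this commit.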

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAX_POOL_V1_H_
#define MINDSPORE_CORE_OPS_MAX_POOL_V1_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxPoolV1 = "MaxPoolV1";
/// \brief Max pooling operation. Refer to Python API @ref mindspore.ops.MaxPoolV1 for more details.
class MIND_API MaxPoolV1 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MaxPoolV1);
/// \brief Constructor.
MaxPoolV1() : BaseOperator(kNameMaxPoolV1) { InitIOName({"x"}, {"output"}); }
};
AbstractBasePtr MaxPoolV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMaxPoolV1Ptr = std::shared_ptr<MaxPoolV1>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAX_POOL_V1_H_

View File

@@ -120,6 +120,7 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
{"DepthwiseConv2dNativeBackpropInput", FormatAndPadAttrMap},
{"DepthwiseConv2dNativeBackpropFilter", FormatAndPadAttrMap},
{"AvgPool", FormatAndPadUpperAttrMap},
{"MaxPoolV1", FormatAndPadUpperAttrMap},
{"MaxPool", FormatAndPadUpperAttrMap},
{"MaxPoolWithArgmax", FormatAndPadUpperAttrMap},
{"AvgPoolGrad", FormatAndPadUpperAttrMap},
@@ -127,6 +128,7 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
{"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
{"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
{"MaxPoolGrad", FormatAndPadUpperAttrMap},
{"MaxPoolGradV1", FormatAndPadUpperAttrMap},
{"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
{"MaxPoolGradWithArgmax", FormatAndPadUpperAttrMap},
{"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},

View File

@@ -31,6 +31,8 @@ from ..operations._grad_ops import FractionalAvgPoolGrad
from ..operations.nn_ops import NthElement
from ..operations.nn_ops import PSROIPooling
from ..operations._grad_ops import PSROIPoolingGrad
from ..operations.nn_ops import MaxPoolV1
from ..operations._grad_ops import MaxPoolGradV1
@bprop_getters.register(P.CTCLossV2)
@@ -181,3 +183,19 @@ def get_bprop_p_s_r_o_i_pooling(self):
return (dx, zeros_like(rois))
return bprop
@bprop_getters.register(MaxPoolV1)
def get_bprop_max_pool_v1_grad(self):
"""Grad definition for `MaxPoolV1` operation."""
maxpool_grad_v1 = MaxPoolGradV1(
kernel_size=self.kernel_size,
strides=self.strides,
pad_mode=self.pad_mode,
data_format=self.format)
def bprop(x, out, dout):
dx = maxpool_grad_v1(x, out, dout)
return (dx,)
return bprop
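A minimal sketch of how the registered bprop is exercised, assuming an Ascend target (the AiCPU kernels above are only registered there) and PyNative mode; the wrapper function forward is illustrative, not part of this commit:

import numpy as np
from mindspore import Tensor, ops
from mindspore.ops.operations.nn_ops import MaxPoolV1

x = Tensor(np.arange(1 * 3 * 6 * 6, dtype=np.float32).reshape(1, 3, 6, 6))
maxpool = MaxPoolV1(kernel_size=2, strides=2, pad_mode="VALID")

def forward(inp):
    return maxpool(inp)

# GradOperation dispatches to the bprop registered above, which calls
# MaxPoolGradV1(orig_input, orig_output, grad) to build the gradient.
dx = ops.GradOperation()(forward)(x)
print(dx.shape)  # (1, 3, 6, 6), same shape as x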

View File

@@ -167,3 +167,5 @@ from .trace import _trace_aicpu
from .tracegrad import _tracegrad_aicpu
from .zeta import _zeta_aicpu
from .adjust_hue import _adjust_hue_aicpu
from .maxpool_v1 import _maxpool_v1_aicpu
from .maxpool_grad_v1 import _maxpool_grad_v1_aicpu

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaxPoolGradV1 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
maxpool_gradv1_op_info = AiCPURegOp("MaxPoolGradV1") \
.fusion_type("OPAQUE") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.input(2, "grad", "required") \
.output(0, "y", "required") \
.attr("ksize", "listInt") \
.attr("strides", "listInt") \
.attr("padding", "str") \
.attr("data_format", "str", "NHWC") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.get_op_info()
@op_info_register(maxpool_gradv1_op_info)
def _maxpool_grad_v1_aicpu():
"""MaxPoolGrad aicpu register"""
return

View File

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaxPoolV1 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
maxpoolv1_op_info = AiCPURegOp("MaxPoolV1") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.attr("ksize", "listInt") \
.attr("strides", "listInt") \
.attr("padding", "str") \
.attr("data_format", "str", "NHWC") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.get_op_info()
@op_info_register(maxpoolv1_op_info)
def _maxpool_v1_aicpu():
"""MaxPool aicpu register"""
return

View File

@@ -867,6 +867,59 @@ class MaxPoolGrad(_PoolGrad):
return x1_dtype
class MaxPoolGradV1(Primitive):
"""Performs gradients of the MaxPoolV1 operation."""
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
self.init_prim_io_names(
inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.pad_mode = validator.check_string(
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
self.add_prim_attr("pad_mode", self.pad_mode)
self.format = validator.check_string(
data_format, ['NCHW', 'NHWC'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
def _grad_check_int_or_tuple(arg_name, arg_val):
validator.check_value_type(
arg_name, arg_val, (int, tuple), self.name)
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be a positive int number "
f"or a tuple of two or four positive int numbers, but got {arg_val}")
if isinstance(arg_val, int):
ret = (1, 1, arg_val, arg_val)
elif len(arg_val) == 2:
ret = (1, 1, arg_val[0], arg_val[1])
elif len(arg_val) == 4:
ret = arg_val
else:
raise error_msg
# whether all elements of tuple are positive integers
for item in ret:
if not isinstance(item, int) or item <= 0:
raise error_msg
return ret
self.kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
self.strides = _grad_check_int_or_tuple("strides", strides)
kernel_size_adapted = self.kernel_size if self.format == 'NCHW' else (
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
strides_adapted = self.strides if self.format == 'NCHW' else (
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
if len(kernel_size) == 4:
kernel_size_adapted = kernel_size
if len(strides) == 4:
strides_adapted = strides
self.add_prim_attr("kernel_size", kernel_size_adapted)
self.add_prim_attr("strides", strides_adapted)
class MaxPoolGradGrad(_PoolGrad):
r"""
Performs gradients of the MaxPoolGrad operation.

View File

@@ -1641,6 +1641,97 @@ class MaxPool(_Pool):
super(MaxPool, self).__init__(kernel_size, strides, pad_mode, data_format)
class MaxPoolV1(Primitive):
r"""
Max pooling operation.
Applies a 2D max pooling over an input Tensor, which can be regarded as a composition of 2D planes.
Typically, the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPoolV1
outputs regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
:math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_h, s_w)`, the operation is as follows.
.. math::
\text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
\text{input}(N_i, C_j, s_h \times h + m, s_w \times w + n)
Args:
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the max value,
is an integer that represents height and width of the kernel, or a tuple
of two integers that represent height and width respectively. Default: 1.
strides (Union[int, tuple[int]]): The distance of kernel moving, an integer that represents
the height and width of movement are both strides, or a tuple of two integers that
represent height and width of movement, respectively. Default: 1.
pad_mode (str): The optional value for pad mode is "same" or "valid".
Default: "valid".
- same: Adopts the way of completion. The height and width of the output will be the same as
the input. The number of padding will be calculated in horizontal and vertical
directions, and evenly distributed to top and bottom, left and right if possible.
Otherwise, the extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possible largest height and width of the
output will be returned without padding. Extra pixels will be discarded.
data_format (str): The optional value for data format is 'NCHW' or 'NHWC'.
Default: 'NCHW'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Raises:
TypeError: If `kernel_size` or `strides` is neither int nor tuple.
ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
ValueError: If `kernel_size` or `strides` is less than 1.
ValueError: If the length of the shape of `x` is not equal to 4.
Supported Platforms:
``Ascend``
Examples:
>>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
>>> maxpoolv1_op = ops.MaxPoolV1(pad_mode="VALID", kernel_size=2, strides=1)
>>> output_ = maxpoolv1_op(x)
>>> print(output_)
[[[[ 5. 6. 7.]
[ 9. 10. 11.]]
[[17. 18. 19.]
[21. 22. 23.]]
[[29. 30. 31.]
[33. 34. 35.]]]]
"""
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
"""Initialize MaxPoolV1."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
self.pad_mode = validator.check_string(
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
self.add_prim_attr("pad_mode", self.pad_mode)
self.format = validator.check_string(
data_format, ['NCHW', 'NHWC'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.kernel_size = _check_positive_int_or_tuple(
"kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
self.strides = _check_positive_int_or_tuple(
"strides", strides, self.name, allow_four=False, ret_four=True)
kernel_size_adapted = self.kernel_size if self.format == 'NCHW' else (
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
strides_adapted = self.strides if self.format == 'NCHW' else (
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
self.add_prim_attr("kernel_size", kernel_size_adapted)
self.add_prim_attr("strides", strides_adapted)
class MaxPoolWithArgmax(_Pool):
r"""
Performs max pooling on the input Tensor and returns both max values and indices.

View File

@@ -51,6 +51,8 @@ from mindspore.ops.operations.nn_ops import FractionalAvgPool
from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad
from mindspore.ops.operations.nn_ops import NthElement
from mindspore.ops.operations.nn_ops import PSROIPooling
from mindspore.ops.operations.nn_ops import MaxPoolV1
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
from mindspore._c_expression import security
@@ -2076,6 +2078,15 @@ test_case_nn_ops = [
'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]],
'desc_bprop': [[3, 4, 6, 6]],
'skip': ['backward']}),
('MaxPoolV1', {
'block': MaxPoolV1(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
'desc_inputs': [[100, 3, 28, 28]],
'desc_bprop': [[100, 3, 14, 14]]}),
('MaxPoolGradV1', {
'block': MaxPoolGradV1(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]],
'desc_bprop': [[3, 4, 6, 6]],
'skip': ['backward']}),
('MaxPool3D', {
'block': P.MaxPool3D(kernel_size=2, strides=2, pad_mode="VALID"),
'desc_inputs': [[100, 3, 28, 28, 28]],