forked from mindspore-Ecosystem/mindspore
!32546 [feat] [assistant] [I4CRKO] [I4CRKP] add nn operator AvgPoolV1 and AvgPoolGradV1
Merge pull request !32546 from Seeker98/0405_a
This commit is contained in:
commit
a3e00bd637
|
@ -87,6 +87,8 @@ constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
|
|||
constexpr auto kPriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
|
||||
constexpr auto kMaxPoolV1 = "MaxPoolV1";
|
||||
constexpr auto kMaxPoolGradV1 = "MaxPoolGradV1";
|
||||
constexpr auto kAvgPoolV1 = "AvgPoolV1";
|
||||
constexpr auto kAvgPoolGradV1 = "AvgPoolGradV1";
|
||||
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
|
||||
kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
|
||||
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
|
||||
|
@ -110,7 +112,9 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
|
|||
{kStack, "Pack"},
|
||||
{kUnstack, "Unpack"},
|
||||
{kGather, "GatherV2"},
|
||||
{kSampleDistortedBoundingBoxV2, "SampleDistortedBoundingBoxExt2"}};
|
||||
{kSampleDistortedBoundingBoxV2, "SampleDistortedBoundingBoxExt2"},
|
||||
{kAvgPoolV1, "AvgPool"},
|
||||
{kAvgPoolGradV1, "AvgPoolGrad"}};
|
||||
struct AicpuParamHead {
|
||||
uint32_t length; // Total length: include cunstom message
|
||||
uint32_t ioAddrNum; // Input and output address number
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/avg_pool_v1.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t kFormatNCHWIndN = 0;
|
||||
constexpr int64_t kFormatNCHWIndC = 1;
|
||||
constexpr int64_t kFormatNCHWIndH = 2;
|
||||
constexpr int64_t kFormatNCHWIndW = 3;
|
||||
constexpr int64_t kFormatNHWCIndN = 0;
|
||||
constexpr int64_t kFormatNHWCIndH = 1;
|
||||
constexpr int64_t kFormatNHWCIndW = 2;
|
||||
constexpr int64_t kFormatNHWCIndC = 3;
|
||||
abstract::ShapePtr AvgPoolV1InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
|
||||
const int64_t x_size = 4;
|
||||
const int64_t attr_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_size, op_name);
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
|
||||
|
||||
int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
|
||||
int64_t kernel_h = 0, kernel_w = 0;
|
||||
int64_t stride_h = 0, stride_w = 0;
|
||||
|
||||
if (format == NHWC) {
|
||||
batch = in_shape[kFormatNHWCIndN];
|
||||
channel = in_shape[kFormatNHWCIndC];
|
||||
in_h = in_shape[kFormatNHWCIndH];
|
||||
in_w = in_shape[kFormatNHWCIndW];
|
||||
kernel_h = kernel_size[kFormatNHWCIndH];
|
||||
kernel_w = kernel_size[kFormatNHWCIndW];
|
||||
stride_h = strides[kFormatNHWCIndH];
|
||||
stride_w = strides[kFormatNHWCIndW];
|
||||
} else if (format == NCHW) {
|
||||
batch = in_shape[kFormatNCHWIndN];
|
||||
channel = in_shape[kFormatNCHWIndC];
|
||||
in_h = in_shape[kFormatNCHWIndH];
|
||||
in_w = in_shape[kFormatNCHWIndW];
|
||||
kernel_h = kernel_size[kFormatNCHWIndH];
|
||||
kernel_w = kernel_size[kFormatNCHWIndW];
|
||||
stride_h = strides[kFormatNCHWIndH];
|
||||
stride_w = strides[kFormatNCHWIndW];
|
||||
}
|
||||
int64_t out_h = abstract::Shape::SHP_ANY;
|
||||
int64_t out_w = abstract::Shape::SHP_ANY;
|
||||
if (pad_mode == VALID) {
|
||||
out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
|
||||
} else if (pad_mode == SAME) {
|
||||
out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
|
||||
}
|
||||
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
||||
|
||||
// Process attr mapping problems from mindspore to ai_cpu
|
||||
// kernel_size -> ksize
|
||||
// pad_mode -> padding
|
||||
if (format == NHWC) {
|
||||
std::vector<int64_t> ksize_NHWC = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
|
||||
(void)primitive->AddAttr("ksize", MakeValue(ksize_NHWC));
|
||||
(void)primitive->AddAttr("data_format", MakeValue("NHWC"));
|
||||
} else if (format == NCHW) {
|
||||
std::vector<int64_t> ksize_NCHW = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
|
||||
(void)primitive->AddAttr("ksize", MakeValue(ksize_NCHW));
|
||||
(void)primitive->AddAttr("data_format", MakeValue("NCHW"));
|
||||
}
|
||||
if (pad_mode == VALID) {
|
||||
(void)primitive->AddAttr("padding", MakeValue("VALID"));
|
||||
} else if (pad_mode == SAME) {
|
||||
(void)primitive->AddAttr("padding", MakeValue("SAME"));
|
||||
}
|
||||
|
||||
if (NHWC == format) {
|
||||
out_shape = {batch, out_h, out_w, channel};
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr AvgPoolV1InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto name = prim->name();
|
||||
const std::set<TypePtr> avgpool_v1_valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto input_type = input_args[kInputIndex0]->BuildType();
|
||||
auto inferred_type = CheckAndConvertUtils::CheckTypeValid("value", input_type, avgpool_v1_valid_types, name);
|
||||
return inferred_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AvgPoolV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto avgpool_v1_infer_type = AvgPoolV1InferType(primitive, input_args);
|
||||
auto avgpool_v1_infer_shape = AvgPoolV1InferShape(primitive, input_args)->shape();
|
||||
return std::make_shared<abstract::AbstractTensor>(avgpool_v1_infer_type, avgpool_v1_infer_shape);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(AvgPoolV1, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolV1, prim::kPrimAvgPoolV1, AvgPoolV1Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_AVG_POOL_V1_H_
|
||||
#define MINDSPORE_CORE_OPS_AVG_POOL_V1_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAvgPoolV1 = "AvgPoolV1";
|
||||
/// \brief Average pooling operation. Refer to Python API @ref mindspore.ops.AvgPoolV1 for more details.
|
||||
class MIND_API AvgPoolV1 : public BaseOperator {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
AvgPoolV1() : BaseOperator(kNameAvgPoolV1) { InitIOName({"value"}, {"output"}); }
|
||||
MIND_API_BASE_MEMBER(AvgPoolV1);
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr AvgPoolV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAvgPoolV1Ptr = std::shared_ptr<AvgPoolV1>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_AVG_POOL_V1_H_
|
|
@ -496,6 +496,8 @@ GVAR_DEF(PrimitivePtr, kPrimAvgPoolGrad, std::make_shared<Primitive>("AvgPoolGra
|
|||
GVAR_DEF(PrimitivePtr, kPrimAvgPool3DGrad, std::make_shared<Primitive>("AvgPool3DGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAvgPoolGradVm, std::make_shared<Primitive>("AvgPoolGradVm"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAvgPoolGradGe, std::make_shared<Primitive>("AvgPoolGradGe"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAvgPoolV1, std::make_shared<Primitive>("AvgPoolV1"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAvgPoolGradV1, std::make_shared<Primitive>("AvgPoolGradV1"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimFusedSparseAdam, std::make_shared<Primitive>("FusedSparseAdam"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimFusedBatchNorm, std::make_shared<Primitive>("FusedBatchNorm"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConv2D, std::make_shared<Primitive>("Conv2D"));
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/avg_pool_grad_v1.h"
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::ShapePtr AvgPoolGradV1InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
|
||||
std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
|
||||
auto pad_mode_value = (primitive->GetAttr(kPadMode));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
|
||||
if (format == NHWC) {
|
||||
std::vector<int64_t> ksize_NHWC = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
|
||||
(void)primitive->AddAttr("ksize", MakeValue(ksize_NHWC));
|
||||
(void)primitive->DelAttr("data_format");
|
||||
(void)primitive->AddAttr("data_format", MakeValue("NHWC"));
|
||||
} else if (format == NCHW) {
|
||||
std::vector<int64_t> ksize_NCHW = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
|
||||
(void)primitive->AddAttr("ksize", MakeValue(ksize_NCHW));
|
||||
(void)primitive->DelAttr("data_format");
|
||||
(void)primitive->AddAttr("data_format", MakeValue("NCHW"));
|
||||
}
|
||||
if (pad_mode == VALID) {
|
||||
(void)primitive->AddAttr("padding", MakeValue("VALID"));
|
||||
} else if (pad_mode == SAME) {
|
||||
(void)primitive->AddAttr("padding", MakeValue("SAME"));
|
||||
}
|
||||
|
||||
auto orig_input_shape = input_args[0]->BuildValue();
|
||||
auto orig_input_shape_tensor = orig_input_shape->cast<tensor::TensorPtr>();
|
||||
auto orig_input_shape_tensor_data_ptr = orig_input_shape_tensor->data_c();
|
||||
int32_t *orig_input_shape_ptr = static_cast<int32_t *>(orig_input_shape_tensor_data_ptr);
|
||||
|
||||
std::vector<int64_t> orig_shape = {orig_input_shape_ptr[0], orig_input_shape_ptr[1], orig_input_shape_ptr[2],
|
||||
orig_input_shape_ptr[3]};
|
||||
|
||||
return std::make_shared<abstract::Shape>(orig_shape);
|
||||
}
|
||||
TypePtr AvgPoolGradV1InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto name = prim->name();
|
||||
auto orig_input_shape_type = input_args[0]->BuildType();
|
||||
auto input_grad_type = input_args[1]->BuildType();
|
||||
const std::set<TypePtr> orig_input_shape_valid_type = {kInt32};
|
||||
const std::set<TypePtr> input_grad_valid_type = {kFloat16, kFloat32, kFloat64};
|
||||
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("orig_input_shape", orig_input_shape_type, orig_input_shape_valid_type,
|
||||
name);
|
||||
auto inferred_type = CheckAndConvertUtils::CheckTensorTypeValid("grad", input_grad_type, input_grad_valid_type, name);
|
||||
return inferred_type;
|
||||
}
|
||||
|
||||
AbstractBasePtr AvgPoolGradV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t num_inputs = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, num_inputs, primitive->name());
|
||||
auto avgpoolgradv1_infer_type = AvgPoolGradV1InferType(primitive, input_args);
|
||||
auto avgpoolgradv1_infer_shape = AvgPoolGradV1InferShape(primitive, input_args)->shape();
|
||||
return std::make_shared<abstract::AbstractTensor>(avgpoolgradv1_infer_type, avgpoolgradv1_infer_shape);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(AvgPoolGradV1, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGradV1, prim::kPrimAvgPoolGradV1, AvgPoolGradV1Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_AVG_POOL_GRAD_V1_H_
|
||||
#define MINDSPORE_CORE_OPS_AVG_POOL_GRAD_V1_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAvgPoolGradV1 = "AvgPoolGradV1";
|
||||
class MIND_API AvgPoolGradV1 : public BaseOperator {
|
||||
public:
|
||||
AvgPoolGradV1() : BaseOperator(kNameAvgPoolGradV1) { InitIOName({"orig_input_shape", "grad"}, {"output"}); }
|
||||
MIND_API_BASE_MEMBER(AvgPoolGradV1);
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr AvgPoolGradV1Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAvgPoolGradV1Ptr = std::shared_ptr<AvgPoolGradV1>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_AVG_POOL_GRAD_V1_H_
|
|
@ -127,6 +127,8 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
|
|||
{"AvgPoolGradVm", FormatAndPadUpperAttrMap},
|
||||
{"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
|
||||
{"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
|
||||
{"AvgPoolV1", FormatAndPadUpperAttrMap},
|
||||
{"AvgPoolGradV1", FormatAndPadUpperAttrMap},
|
||||
{"MaxPoolGrad", FormatAndPadUpperAttrMap},
|
||||
{"MaxPoolGradV1", FormatAndPadUpperAttrMap},
|
||||
{"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
|
||||
|
|
|
@ -32,6 +32,8 @@ from ..operations._grad_ops import FractionalAvgPoolGrad
|
|||
from ..operations.nn_ops import NthElement
|
||||
from ..operations.nn_ops import PSROIPooling
|
||||
from ..operations._grad_ops import PSROIPoolingGrad
|
||||
from ..operations.nn_ops import AvgPoolV1
|
||||
from ..operations._grad_ops import AvgPoolGradV1
|
||||
from ..operations.nn_ops import MaxPoolV1
|
||||
from ..operations._grad_ops import MaxPoolGradV1
|
||||
|
||||
|
@ -186,6 +188,25 @@ def get_bprop_p_s_r_o_i_pooling(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(AvgPoolV1)
|
||||
def get_bprop_avg_pool_v1_grad(self):
|
||||
"""Grad definition for `AvgPoolV1` operation."""
|
||||
avgpool_grad_v1 = AvgPoolGradV1(
|
||||
kernel_size=self.kernel_size,
|
||||
strides=self.strides,
|
||||
pad_mode=self.pad_mode,
|
||||
data_format=self.format)
|
||||
to_arr = P.TupleToArray()
|
||||
get_shape = P.Shape()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
orig_input_shape = to_arr(get_shape(x))
|
||||
dx = avgpool_grad_v1(orig_input_shape, dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(MaxPoolV1)
|
||||
def get_bprop_max_pool_v1_grad(self):
|
||||
"""Grad definition for `MaxPoolV1` operation."""
|
||||
|
|
|
@ -178,6 +178,8 @@ from .trace import _trace_aicpu
|
|||
from .tracegrad import _tracegrad_aicpu
|
||||
from .zeta import _zeta_aicpu
|
||||
from .adjust_hue import _adjust_hue_aicpu
|
||||
from .avgpool_v1 import _avgpool_v1_aicpu
|
||||
from .avgpool_grad_v1 import _avgpool_grad_v1_aicpu
|
||||
from .maxpool_v1 import _maxpool_v1_aicpu
|
||||
from .maxpool_grad_v1 import _maxpool_grad_v1_aicpu
|
||||
from .dense_to_csr_sparse_matrix import _dense_to_csr_sparse_matrix_aicpu
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AvgPoolGradV1 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
avgpool_grad_v1_op_info = AiCPURegOp("AvgPoolGradV1") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "orig_input_shape", "required") \
|
||||
.input(1, "input_grad", "required") \
|
||||
.output(0, "out_grad", "required") \
|
||||
.attr("ksize", "listInt") \
|
||||
.attr("strides", "listInt") \
|
||||
.attr("padding", "str") \
|
||||
.attr("data_format", "str", "NHWC") \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(avgpool_grad_v1_op_info)
|
||||
def _avgpool_grad_v1_aicpu():
|
||||
"""AvgPoolGradV1 aicpu register"""
|
||||
return
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AvgPoolV1 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
avgpool_v1_op_info = AiCPURegOp("AvgPoolV1") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.attr("ksize", "listInt") \
|
||||
.attr("strides", "listInt") \
|
||||
.attr("padding", "str") \
|
||||
.attr("data_format", "str", "NHWC") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(avgpool_v1_op_info)
|
||||
def _avgpool_v1_aicpu():
|
||||
"""AvgPoolV1 aicpu register"""
|
||||
return
|
|
@ -820,6 +820,58 @@ class AvgPoolGrad(_PoolGrad):
|
|||
return x1_dtype
|
||||
|
||||
|
||||
class AvgPoolGradV1(Primitive):
|
||||
"""Gradients of the AvgPoolV1 operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
self.pad_mode = validator.check_string(
|
||||
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
||||
self.add_prim_attr("pad_mode", self.pad_mode)
|
||||
self.format = validator.check_string(
|
||||
data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
def _avgpoolgrad_check_int_or_tuple(argname, argval):
|
||||
validator.check_value_type(argname, argval, (int, tuple), self.name)
|
||||
errormsg = ValueError(f"For '{self.name}' the '{argname}' should be an positive int number "
|
||||
f"or a tuple of two or four positive int numbers, but got {argval}")
|
||||
if isinstance(argval, int):
|
||||
ret = (1, 1, argval, argval)
|
||||
elif len(argval) == 2:
|
||||
ret = (1, 1, argval[0], argval[1])
|
||||
elif len(argval) == 4:
|
||||
ret = argval
|
||||
else:
|
||||
raise errormsg
|
||||
# whether all elements of tuple are positive integers?
|
||||
for it in ret:
|
||||
if not isinstance(it, int) or it <= 0:
|
||||
raise errormsg
|
||||
return ret
|
||||
|
||||
self.kernel_size = _avgpoolgrad_check_int_or_tuple(
|
||||
"kernel_size", kernel_size)
|
||||
self.strides = _avgpoolgrad_check_int_or_tuple("strides", strides)
|
||||
|
||||
self.kernel_size_adapt = self.kernel_size if self.format == "NCHW" else (
|
||||
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
|
||||
self.strides_adapt = self.strides if self.format == "NCHW" else (
|
||||
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
|
||||
|
||||
# If length of some attrs is 4 we regard it as legal, either by using the op directly,
|
||||
# or passed from an instance of forward op AvgPoolV1.
|
||||
if len(self.kernel_size) == 4:
|
||||
self.kernel_size_adapt = self.kernel_size
|
||||
if len(self.strides) == 4:
|
||||
self.strides_adapt = self.strides
|
||||
|
||||
self.add_prim_attr("kernel_size", self.kernel_size_adapt)
|
||||
self.add_prim_attr("strides", self.strides_adapt)
|
||||
|
||||
|
||||
class AdaptiveAvgPool2DGrad(PrimitiveWithInfer):
|
||||
"""Gradients of the adaptive avg pool 2D operation."""
|
||||
|
||||
|
|
|
@ -2049,6 +2049,101 @@ class AvgPool(_Pool):
|
|||
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
|
||||
|
||||
class AvgPoolV1(Primitive):
|
||||
r"""
|
||||
Average-pooling operation.
|
||||
|
||||
Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D planes.
|
||||
Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPoolV1 outputs
|
||||
regional average in the :math:`(H_{in}, W_{in})`-dimension. Given window size
|
||||
:math:`ks = (h_{ker}, w_{ker})` and strides :math:`s = (s_0, s_1)`, the operation is as follows.
|
||||
|
||||
.. math::
|
||||
\text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
|
||||
\text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
|
||||
|
||||
.. warning::
|
||||
- Only single input and single output are supported.
|
||||
- Global average pooling is supported.
|
||||
- The height of "kernel_size" and the weight of "kernel_size" are positive integers within the range [1, 255].
|
||||
ksize_h * ksize_w < 256.
|
||||
- Due to instruction restrictions, the values of "strides_h" and "strides_w" are
|
||||
positive integers within the range [1, 64).
|
||||
|
||||
Args:
|
||||
kernel_size (Union[int, tuple[int]]): The size of the kernel used to take the average value,
|
||||
is an integer that represents height and width of the kernel, or a tuple
|
||||
of two integers that represent height and width respectively. Default: 1.
|
||||
strides (Union[int, tuple[int]]): The distance of kernel moving, an integer that represents
|
||||
the height and width of movement are both strides, or a tuple of two integers that
|
||||
represent height and width of movement, respectively. Default: 1.
|
||||
pad_mode (str): The optional value for pad mode, should be one of "same" or "valid".
|
||||
Default: "valid".
|
||||
|
||||
- same: Adopts the way of completion. The height and width of output will be the same as
|
||||
the input. The total number of padding will be calculated horizontally and vertically,
|
||||
and evenly distributed to top and bottom, left and right if possible.
|
||||
Otherwise, the last extra padding will be done from bottom and right.
|
||||
|
||||
- valid: Adopts the way of discarding. The largest possible height and width of output
|
||||
will be returned without padding. Extra pixels will be discarded.
|
||||
data_format (str): The format of input and output data. Should be 'NHWC' or 'NCHW'.
|
||||
Default: 'NCHW'.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `kernel_size` or `strides` is neither int nor tuple.
|
||||
ValueError: If `pad_mode` is neither 'valid' nor 'same' with not case sensitive.
|
||||
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
|
||||
ValueError: If `kernel_size` or `strides` is less than 1.
|
||||
ValueError: If length of shape of `x` is not equal to 4.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.arange(1 * 2 * 4 * 4).reshape((1, 2, 4, 4)), mindspore.float64)
|
||||
>>> avgpoolv1_op = ops.AvgPoolV1(pad_mode="VALID", kernel_size=3, strides=1)
|
||||
>>> _output = avgpoolv1_op(x)
|
||||
>>> print(_output)
|
||||
[[[[ 5. 6.]
|
||||
[ 9. 10.]]
|
||||
[[21. 22.]
|
||||
[25. 26.]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize AvgPoolV1."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
||||
self.pad_mode = validator.check_string(
|
||||
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
||||
self.add_prim_attr("pad_mode", self.pad_mode)
|
||||
self.format = validator.check_string(
|
||||
data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
self.kernel_size = _check_positive_int_or_tuple(
|
||||
"kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
|
||||
self.strides = _check_positive_int_or_tuple(
|
||||
"strides", strides, self.name, allow_four=False, ret_four=True)
|
||||
|
||||
# adapt data_format
|
||||
self.kernel_size_adapted = self.kernel_size if self.format == "NCHW" else (
|
||||
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
|
||||
self.add_prim_attr("kernel_size", self.kernel_size_adapted)
|
||||
self.strides_adapted = self.strides if self.format == "NCHW" else (
|
||||
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
|
||||
self.add_prim_attr("strides", self.strides_adapted)
|
||||
|
||||
|
||||
class Conv2DBackpropInput(Primitive):
|
||||
r"""
|
||||
The Conv2DBackpropInput interface is deprecated, please refer to :class:`mindspore.ops.Conv2DTranspose` if you
|
||||
|
|
|
@ -55,6 +55,8 @@ from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad
|
|||
from mindspore.ops.operations.nn_ops import GridSampler2D
|
||||
from mindspore.ops.operations.nn_ops import NthElement
|
||||
from mindspore.ops.operations.nn_ops import PSROIPooling
|
||||
from mindspore.ops.operations.nn_ops import AvgPoolV1
|
||||
from mindspore.ops.operations._grad_ops import AvgPoolGradV1
|
||||
from mindspore.ops.operations.nn_ops import MaxPoolV1
|
||||
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
|
||||
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix
|
||||
|
@ -1159,6 +1161,18 @@ class NthElementNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class AvgPoolGradV1Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AvgPoolGradV1Net, self).__init__()
|
||||
self.avgpool_grad_v1 = AvgPoolGradV1(kernel_size=2, strides=2, pad_mode="VALID")
|
||||
self.to_arr = P.TupleToArray()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, orig_input, grad):
|
||||
orig_input_shape = self.to_arr(self.shape(orig_input))
|
||||
return self.avgpool_grad_v1(orig_input_shape, grad)
|
||||
|
||||
|
||||
test_case_math_ops = [
|
||||
('Cross', {
|
||||
'block': P.Cross(dim=1),
|
||||
|
@ -2117,6 +2131,15 @@ test_case_nn_ops = [
|
|||
'block': P.AvgPool(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
|
||||
'desc_inputs': [[100, 3, 28, 28]],
|
||||
'desc_bprop': [[100, 3, 14, 14]]}),
|
||||
('AvgPoolV1', {
|
||||
'block': AvgPoolV1(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
|
||||
'desc_inputs': [[100, 3, 28, 28]],
|
||||
'desc_bprop': [[100, 3, 14, 14]]}),
|
||||
('AvgPoolGradV1', {
|
||||
'block': AvgPoolGradV1Net(),
|
||||
'desc_inputs': [[100, 3, 28, 28], [100, 3, 14, 14]],
|
||||
'desc_bprop': [[100, 3, 28, 28]],
|
||||
'skip': ['backward']}),
|
||||
('AvgPool3D_1', {
|
||||
'block': P.AvgPool3D(kernel_size=2, strides=2, pad_mode="VALID"),
|
||||
'desc_inputs': [[10, 3, 28, 28, 28]],
|
||||
|
|
Loading…
Reference in New Issue