forked from mindspore-Ecosystem/mindspore
!33644 add PSROIPoolingV2 op
Merge pull request !33644 from guozhijian/add_PSROIPoolingV2
This commit is contained in:
commit
02f58183c5
|
@ -452,6 +452,7 @@ GVAR_DEF(PrimitivePtr, kPrimBesselJ1, std::make_shared<Primitive>("BesselJ1"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimTanhGrad, std::make_shared<Primitive>("TanhGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPooling, std::make_shared<Primitive>("Pooling"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPoolingGrad, std::make_shared<Primitive>("PoolingGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPSROIPooling, std::make_shared<Primitive>("PSROIPooling"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPSROIPoolingGrad, std::make_shared<Primitive>("PSROIPoolingGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimROIPooling, std::make_shared<Primitive>("ROIPooling"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaxPool, std::make_shared<Primitive>("MaxPool"));
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/ps_roi_pooling.h"
|
||||
#include <set>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Infers the output shape of PSROIPooling.
//
// input_args[0] is the feature map and input_args[1] is the rois tensor. The 'group_size' and 'output_dim'
// attributes are read from the primitive. The returned shape is
// (rois_shape[0] * rois_shape[2], output_dim, group_size, group_size).
//
// An exception is raised when:
//   - fewer than two inputs are given, or any input is null;
//   - the feature map is not 4-D;
//   - 'group_size' is outside the open interval (0, 128);
//   - the channel dimension of the feature map differs from group_size * group_size * output_dim;
//   - the rois tensor is not 3-D.
abstract::ShapePtr PSROIPoolingInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);

  const auto prim_name = primitive->name();
  constexpr int64_t kMinInputNum = 2;
  (void)CheckAndConvertUtils::CheckInteger("input numbers", static_cast<int64_t>(input_args.size()), kGreaterEqual,
                                           kMinInputNum, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }

  constexpr size_t kFeatureRank = 4;
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  if (x_shape.size() != kFeatureRank) {
    MS_LOG(EXCEPTION) << "For '" << prim_name
                      << "', input x shape must be 4d(NCHW), but got: " << x_shape.size();
  }

  auto group_size_ptr = primitive->GetAttr("group_size");
  MS_EXCEPTION_IF_NULL(group_size_ptr);
  const auto group_size = GetValue<int64_t>(group_size_ptr);

  // group_size must lie in the open interval (0, 128).
  constexpr int64_t kGroupSizeUpperBound = 128;
  if (group_size <= 0 || group_size >= kGroupSizeUpperBound) {
    MS_LOG(EXCEPTION) << "For '" << prim_name
                      << "', 'group_size' should be in the range (0, 128), but got: " << group_size;
  }

  auto output_dim_ptr = primitive->GetAttr("output_dim");
  MS_EXCEPTION_IF_NULL(output_dim_ptr);
  const auto output_dim = GetValue<int64_t>(output_dim_ptr);

  // The channel (second) dimension of the input must be exactly group_size * group_size * output_dim.
  // A bare integer division would truncate and accept channel counts that are not an exact multiple, so the
  // remainder is checked too. Dividing instead of multiplying also avoids overflow for a huge output_dim.
  const int64_t bin_num = group_size * group_size;
  if (x_shape[1] % bin_num != 0 || x_shape[1] / bin_num != output_dim) {
    MS_LOG(EXCEPTION) << "For '" << prim_name << "', the second dimension(" << x_shape[1]
                      << ") of the input x is illegal, it is not equal to group_size(" << group_size
                      << ") * group_size(" << group_size << ") * output_dim(" << output_dim << ").";
  }

  // rois must be exactly 3-D; dimensions 0 and 2 are multiplied to form the leading output dimension.
  constexpr size_t kRoisRank = 3;
  auto rois_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  if (rois_shape.size() != kRoisRank) {
    MS_LOG(EXCEPTION) << "For '" << prim_name
                      << "', the dimension of 'rois' should be equal 3, but got: " << rois_shape.size();
  }

  std::vector<int64_t> ret_shape({rois_shape[0] * rois_shape[2], output_dim, group_size, group_size});

  return std::make_shared<abstract::Shape>(ret_shape);
}
|
||||
|
||||
// Infers the output type of PSROIPooling: the output has the same type as the first input (input_args[0]).
//
// PSROIPoolingInfer calls this function before the shape inference, so the number of inputs is verified here
// before input_args[0] is read; otherwise an empty argument list would be indexed out of bounds.
TypePtr PSROIPoolingInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  constexpr int64_t kMinInputNum = 2;
  (void)CheckAndConvertUtils::CheckInteger("input numbers", static_cast<int64_t>(input_args.size()), kGreaterEqual,
                                           kMinInputNum, prim->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  return input_args[0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(PSROIPooling, BaseOperator);
|
||||
AbstractBasePtr PSROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto infertype = PSROIPoolingInferType(primitive, input_args);
|
||||
auto infershape = PSROIPoolingInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(PSROIPooling, prim::kPrimPSROIPooling, PSROIPoolingInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_PS_ROI_POOLING_H_
|
||||
#define MINDSPORE_CORE_OPS_PS_ROI_POOLING_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Name under which the PSROIPooling operator is created.
constexpr auto kNamePSROIPooling = "PSROIPooling";
|
||||
|
||||
/// \brief Operator class for PSROIPooling, derived from BaseOperator.
class PSROIPooling : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(PSROIPooling);
|
||||
/// \brief Constructor. Creates the operator with the name kNamePSROIPooling and names its inputs
/// "input" and "rois" and its output "output".
PSROIPooling() : BaseOperator(kNamePSROIPooling) { InitIOName({"input", "rois"}, {"output"}); }
|
||||
};
|
||||
|
||||
/// \brief Infer function of PSROIPooling (declaration only; the definition is not in this header).
///
/// \param[in] primitive The primitive carrying the operator attributes.
/// \param[in] input_args The abstract values of the operator inputs.
///
/// \return The abstract value describing the operator output.
abstract::AbstractBasePtr PSROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_PS_ROI_POOLING_H_
|
|
@ -569,4 +569,5 @@ from .expm1_ds import _expm1_ds_tbe
|
|||
from .deformable_offsets import _deformable_offsets_tbe
|
||||
from .parallel_resize_bilinear import _parallel_resize_bilinear_op_info_tbe
|
||||
from .parallel_resize_bilinear_grad import _parallel_resize_bilinear_grad_op_info_tbe
|
||||
from .p_s_r_o_i_pooling import _p_s_r_o_i_pooling_tbe
|
||||
from .p_s_r_o_i_pooling_grad import _p_s_r_o_i_pooling_grad_tbe
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""PSROIPooling op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
# TBE registration info for PSROIPooling: three required attributes, two required inputs (x, rois),
# one required output (y), and float16 / float32 dtype-format combinations.
p_s_r_o_i_pooling_op_info = (
    TBERegOp("PSROIPooling")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("p_s_r_o_i_pooling_v2.so")
    .compute_cost(10)
    .kernel_name("p_s_r_o_i_pooling_v2")
    .partial_flag(True)
    .attr("output_dim", "required", "int", "all")
    .attr("group_size", "required", "int", "all")
    .attr("spatial_scale", "required", "float", "all")
    .input(0, "x", False, "required", "all")
    .input(1, "rois", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD)
    .dtype_format(DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD)
    .get_op_info()
)


@op_info_register(p_s_r_o_i_pooling_op_info)
def _p_s_r_o_i_pooling_tbe():
    """Register the PSROIPooling TBE op info; the function body itself does nothing."""
    return
|
|
@ -9328,3 +9328,84 @@ class NthElement(Primitive):
|
|||
self.add_prim_attr("reverse", self.reverse)
|
||||
self.init_prim_io_names(inputs=['input', 'n'],
|
||||
outputs=['output'])
|
||||
|
||||
|
||||
class PSROIPooling(Primitive):
    r"""
    Position Sensitive ROI-Pooling.

    Args:
        spatial_scale (float): A scaling factor that maps the box coordinates to the input coordinates.
            For example, if your boxes are defined on the scale of a 224x224 image and
            your input is a 112x112 feature map (resulting from a 0.5x scaling of the original
            image), you'll want to set this to 0.5.
        group_size (int): The height and the width of the pooled output, i.e. the output of every roi is
            `group_size` x `group_size`. Must be positive; shape inference additionally requires it to be
            less than 128.
        output_dim (int): The size of the second dimension of the output after the pooling is performed.
            Must be positive.

    Inputs:
        - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`. With data type is
          float16 or float32. This formula should hold: :math:`(C == output\_dim * group\_size * group\_size)`.
        - **rois** (Tensor) - The shape is `(batch, 5, rois_n)`. With data type of float16 or float32.
          The size of first dimension `batch` is batch_size. The size of the second dimension must be `5`.
          The size of third dimension `rois_n` is the number of rois. The value of `rois` like:
          (index, x1, y1, x2, y2). The first element of `rois_n` is the index of the `rois`. And the box coordinates
          in (x1, y1, x2, y2) format where the regions will be taken from. The coordinate must satisfy
          0 <= x1 < x2 and 0 <= y1 < y2.

    Outputs:
        - **out** (Tensor) - The result after pooling, with shape
          `(rois.shape[0] * rois.shape[2], output_dim, group_size, group_size)` and the same data type
          as `features`.

    Raises:
        TypeError: If `spatial_scale` is not a float.
        TypeError: If `group_size` or `output_dim` is not an int.
        ValueError: If `group_size` or `output_dim` is not positive.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import nn_ops
        >>> features = np.random.randn(4, 3 * 7 * 7, 80, 48)
        >>> features = Tensor.from_numpy(features).astype(mindspore.float32)
        >>> rois = Tensor.from_numpy(
        ...     np.array([[[0.0000],
        ...                [150.3563],
        ...                [200.1320],
        ...                [579.3563],
        ...                [602.3452]],
        ...               [[1.0000],
        ...                [657.1263],
        ...                [302.8564],
        ...                [762.4214],
        ...                [567.9854]],
        ...               [[2.0000],
        ...                [321.3122],
        ...                [232.2410],
        ...                [679.0281],
        ...                [587.6346]],
        ...               [[3.0000],
        ...                [664.1630],
        ...                [387.4919],
        ...                [778.7322],
        ...                [562.7321]]])).astype(mindspore.float32)
        >>> psROIPooling = nn_ops.PSROIPooling(spatial_scale=1.0/16, output_dim=3,
        ...                                    group_size=7)
        >>> out = psROIPooling(features, rois)
        >>> print(out.shape)
        (4, 3, 7, 7)
        >>> print(out.dtype)
        Float32
    """

    @prim_attr_register
    def __init__(self, spatial_scale, group_size, output_dim):
        """Initialize PSROIPooling"""
        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
        validator.check_value_type("group_size", group_size, [int], self.name)
        validator.check_value_type("output_dim", output_dim, [int], self.name)
        # Reject non-positive sizes at construction time instead of deferring the failure to shape inference.
        validator.check_positive_int(group_size, "group_size", self.name)
        validator.check_positive_int(output_dim, "output_dim", self.name)
        self.spatial_scale = spatial_scale
        self.group_size = group_size
        self.output_dim = output_dim

        self.add_prim_attr('spatial_scale', self.spatial_scale)
        self.add_prim_attr('group_size', self.group_size)
        self.add_prim_attr('output_dim', self.output_dim)
|
||||
|
|
|
@ -48,6 +48,7 @@ from mindspore.ops.operations._grad_ops import FractionalMaxPool3DGradWithFixedK
|
|||
from mindspore.ops.operations.nn_ops import FractionalAvgPool
|
||||
from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad
|
||||
from mindspore.ops.operations.nn_ops import NthElement
|
||||
from mindspore.ops.operations.nn_ops import PSROIPooling
|
||||
from mindspore.nn.layer import normalization
|
||||
from mindspore.ops.operations.array_ops import RightShift
|
||||
from mindspore._c_expression import security
|
||||
|
@ -3376,6 +3377,31 @@ test_case_other_ops = [
|
|||
Tensor(np.random.randint(0, 4, size=(4)).astype(np.int32))
|
||||
],
|
||||
'skip': ['backward']}),
|
||||
('PSROIPooling', {
|
||||
'block': PSROIPooling(1.0/16, 7, 3),
|
||||
'desc_inputs': [Tensor(np.random.randint(0, 255, size=(4, 3 * 7 * 7, 80, 48)).astype(np.float32)),
|
||||
Tensor(np.array([[[0.0000],
|
||||
[150.3563],
|
||||
[200.1320],
|
||||
[579.3563],
|
||||
[602.3452]],
|
||||
[[1.0000],
|
||||
[657.1263],
|
||||
[302.8564],
|
||||
[762.4214],
|
||||
[567.9854]],
|
||||
[[2.0000],
|
||||
[321.3122],
|
||||
[232.2410],
|
||||
[679.0281],
|
||||
[587.6346]],
|
||||
[[3.0000],
|
||||
[664.1630],
|
||||
[387.4919],
|
||||
[778.7322],
|
||||
[562.7321]]]).astype(np.float32))
|
||||
],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
test_case_quant_ops = [
|
||||
|
|
Loading…
Reference in New Issue