[feat][assistant][#I40FFZ] add new Ascend operator ExtractVolumePatches

hzw 2021-09-24 18:36:03 +08:00
parent ee38ffbd3d
commit 2f5758a4a8
10 changed files with 328 additions and 2 deletions

View File

@@ -261,6 +261,7 @@ inline const PrimitivePtr kPrimDiag = std::make_shared<Primitive>(kDiag);
inline const PrimitivePtr kPrimDiagPart = std::make_shared<Primitive>(kDiagPart);
inline const PrimitivePtr kPrimNonZero = std::make_shared<Primitive>("NonZero");
inline const PrimitivePtr kPrimReal = std::make_shared<Primitive>(kReal);
inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shared<Primitive>("ExtractVolumePatches");
// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");

View File

@@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/extract_volume_patches.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ExtractVolumePatchesInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int MAX_SHAPE = 2048;  // upper bound checked for x_d * x_h * x_w
const int d = 2;             // index of the depth dimension in an NCDHW shape
const int w = 4;             // index of the width dimension in an NCDHW shape
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = x_shape_map[kShape];
(void)CheckAndConvertUtils::CheckInteger("input shape", x_shape.size(), kEqual, 5, primitive->name());
auto x_v = x_shape[2] * x_shape[3] * x_shape[4];
(void)CheckAndConvertUtils::CheckInteger("x_d * x_h * x_w", x_v, kLessEqual, MAX_SHAPE, primitive->name());
std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
std::vector<int64_t> strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
(void)CheckAndConvertUtils::CheckInteger("kernel_size_length", kernel_size.size(), kEqual, 5, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("strides_length", strides.size(), kEqual, 5, primitive->name());
auto padding = GetValue<std::string>(primitive->GetAttr(kPadding));
for (auto &item : strides) {
(void)CheckAndConvertUtils::Check("strides", item, kGreaterThan, "zero", 0, primitive->name());
}
for (auto &item : kernel_size) {
(void)CheckAndConvertUtils::Check("kernel_size", item, kGreaterThan, "zero", 0, primitive->name());
}
std::vector<int64_t> y_shape(5);
int64_t padding_needed = 0;
y_shape[0] = x_shape[0];
y_shape[1] = x_shape[1] * kernel_size[2] * kernel_size[3] * kernel_size[4];
if (padding == "VALID") {
for (int i = d; i <= w; ++i) {
(void)CheckAndConvertUtils::CheckInteger(
"padding = VALID, input[" + std::to_string(i) + "] - kernel_size[" + std::to_string(i) + "]",
x_shape[i] - kernel_size[i], kGreaterEqual, 0, primitive->name());
y_shape[i] = 1 + (x_shape[i] - kernel_size[i]) / strides[i];
}
} else {
for (int i = d; i <= w; ++i) {
y_shape[i] = (x_shape[i] + strides[i] - 1) / strides[i];
int64_t output_size = y_shape[i];
padding_needed = (output_size - 1) * strides[i] + kernel_size[i] - x_shape[i];
(void)CheckAndConvertUtils::CheckInteger("padding = ((input[" + std::to_string(i) + "] + strides[" +
std::to_string(i) + "] - 1) / strides[" + std::to_string(i) +
"]) - 1) * strides[" + std::to_string(i) + "] + kernel_size[" +
std::to_string(i) + "] - input[" + std::to_string(i) + "]",
padding_needed, kGreaterEqual, 0, primitive->name());
}
}
if (y_shape[3] != 1 || y_shape[4] != 1) {
(void)CheckAndConvertUtils::CheckInteger("input_w + pad_l + pad_r - kernel_w - stride_w",
x_shape[4] + padding_needed - kernel_size[4] - strides[4], kGreaterEqual,
0, primitive->name());
}
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr ExtractVolumePatchesInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
}
} // namespace
AbstractBasePtr ExtractVolumePatchesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto type = ExtractVolumePatchesInferType(primitive, input_args);
auto shape = ExtractVolumePatchesInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ExtractVolumePatches, prim::kPrimExtractVolumePatches, ExtractVolumePatchesInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore
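
For reference, the VALID and SAME output-shape rules implemented by ExtractVolumePatchesInferShape above can be sketched in plain Python. This is an illustrative check only, not part of the commit; infer_y_shape is a hypothetical helper name:

def infer_y_shape(x_shape, kernel_size, strides, padding):
    """Hypothetical helper: mirrors the NCDHW rules in ExtractVolumePatchesInferShape."""
    x_n, x_c = x_shape[0], x_shape[1]
    # Output channels collect one value per kernel element.
    y = [x_n, x_c * kernel_size[2] * kernel_size[3] * kernel_size[4], 0, 0, 0]
    for i in (2, 3, 4):  # the d, h, w dimensions
        if padding == "VALID":
            y[i] = 1 + (x_shape[i] - kernel_size[i]) // strides[i]  # requires x >= k
        else:  # "SAME": ceil(x / s)
            y[i] = (x_shape[i] + strides[i] - 1) // strides[i]
    return tuple(y)

print(infer_y_shape((1, 1, 3, 3, 3), (1, 1, 2, 2, 2), (1, 1, 1, 1, 1), "VALID"))  # (1, 8, 2, 2, 2)
print(infer_y_shape((1, 1, 3, 3, 3), (1, 1, 2, 2, 2), (1, 1, 1, 1, 1), "SAME"))   # (1, 8, 3, 3, 3)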

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_
#define MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameExtractVolumePatches = "ExtractVolumePatches";
class MS_CORE_API ExtractVolumePatches : public PrimitiveC {
public:
ExtractVolumePatches() : PrimitiveC(kNameExtractVolumePatches) { InitIOName({"x"}, {"y"}); }
~ExtractVolumePatches() = default;
MS_DECLARE_PARENT(ExtractVolumePatches, PrimitiveC);
};
AbstractBasePtr ExtractVolumePatchesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimExtractVolumePatchesPtr = std::shared_ptr<ExtractVolumePatches>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_

View File

@@ -105,3 +105,54 @@ def get_bprop_split_v(self):
return (dx,)
return bprop
@bprop_getters.register(P.ExtractVolumePatches)
def get_bprop_extract_volume_patches(self):
"""Generate bprop for ExtractVolumePatches"""
extract_volume_patches = P.ExtractVolumePatches(kernel_size=self.kernel_size,
strides=self.strides, padding=self.padding)
concat = P.Concat(axis=-1)
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()
slice_op = P.Slice()
fill = P.Fill()
dtype = P.DType()
cast = P.Cast()
matmul = P.MatMul()
_, _, ksize_d, ksize_h, ksize_w = self.kernel_size
def bprop(x, out, dout):
x_shape = P.Shape()(x)
x_n, x_c, x_d, x_h, x_w = x_shape
x_indices_num = 1 + x_d * x_h * x_w
x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float16)
x_idx = P.Reshape()(x_idx, (1, 1, x_d, x_h, x_w))
x_idx_patched = extract_volume_patches(x_idx)
x_idx_patched = P.Transpose()(x_idx_patched, (0, 2, 3, 4, 1))
x_idx_patched = cast(x_idx_patched, mstype.int32)
out_shape = P.Shape()(out)
_, _, out_d, out_h, out_w = out_shape
out_indices_num = out_d * out_h * out_w * ksize_d * ksize_h * ksize_w
out_idx = F.tuple_to_array(range(0, out_indices_num))
out_idx = P.Reshape()(out_idx, (1, out_d, out_h, out_w, ksize_d * ksize_h * ksize_w))
idx_tensor = concat((expand_dims(x_idx_patched, -1), expand_dims(out_idx, -1)))
idx_map = P.Reshape()(idx_tensor, (-1, 2))
sp_shape = (x_indices_num, out_indices_num)
sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
grad = P.Reshape()(grad, (x_n, out_d, out_h, out_w, ksize_d,
ksize_h, ksize_w, x_c))
grad_expended = P.Transpose()(grad, (1, 2, 3, 4, 5, 6, 0, 7))
grad_flat = P.Reshape()(grad_expended, (-1, x_n * x_c))
jac = matmul(sp_tensor, grad_flat)
dx = P.Reshape()(jac, (x_d, x_h, x_w, x_n, x_c))
dx = P.Transpose()(dx, (3, 4, 0, 1, 2))
return (dx,)
return bprop
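
The bprop above relies on an index-matrix trick: the forward op is run once on a tensor of 1-based input indices, the result pairs every input position with every patch slot it lands in, and ScatterNd turns those pairs into a 0/1 matrix; a single matmul with the flattened output gradient then sums each input element's contributions across all overlapping patches. A 1-D NumPy analogue of the same idea (illustrative only, not part of the commit):

import numpy as np

def extract_patches_1d(x, k, s):
    """1-D "VALID" analogue: stack every length-k window taken at stride s."""
    n = 1 + (len(x) - k) // s
    return np.stack([x[i * s:i * s + k] for i in range(n)])

x = np.arange(5, dtype=np.float64)
k, s = 2, 1
out = extract_patches_1d(x, k, s)                      # 4 windows of length 2

# Run the same op on 1-based indices to record where each element lands.
idx_patched = extract_patches_1d(np.arange(1, len(x) + 1), k, s).ravel()
dout = np.ones(out.size)                               # incoming gradient, flattened

# Scatter a 1 at (input_index, patch_slot), then drop the dummy row 0,
# mirroring the ScatterNd + Slice pair in the bprop above.
sp = np.zeros((len(x) + 1, idx_patched.size))
sp[idx_patched, np.arange(idx_patched.size)] = 1.0
dx = sp[1:] @ dout

print(dx)  # [1. 2. 2. 2. 1.]: interior elements sit in two windows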

View File

@@ -419,3 +419,4 @@ from .hshrink_grad import _hshrink_grad_tbe
from .new_im2col import _new_im2col_tbe
from .non_zero_ds import _non_zero_ds_tbe
from .trunc import _trunc_tbe
from .extract_volume_patches import _extract_volume_patches_tbe

View File

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExtractVolumePatches op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
extract_volume_patches_op_info = TBERegOp("ExtractVolumePatches") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("extract_volume_patches.so") \
.compute_cost(10) \
.kernel_name("extract_volume_patches") \
.partial_flag(True) \
.attr("kernel_size", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.input(0, "input_x", False, "required", "all") \
.output(0, "output_y", False, "required", "all") \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
.dtype_format(DataType.I8_NDC1HWC0, DataType.I8_NDC1HWC0) \
.dtype_format(DataType.U8_NDC1HWC0, DataType.U8_NDC1HWC0) \
.get_op_info()
@op_info_register(extract_volume_patches_op_info)
def _extract_volume_patches_tbe():
"""ExtractVolumePatches TBE register"""
return

View File

@@ -638,6 +638,7 @@ class DataType:
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")
I8_NDC1HWC0 = ("int8", "NDC1HWC0")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
@@ -650,6 +651,7 @@ class DataType:
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")
U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
@@ -799,6 +801,7 @@ class DataType:
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")
I8_NDC1HWC0 = ("int8", "NDC1HWC0")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
@@ -811,6 +814,7 @@ class DataType:
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")
U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")

View File

@@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements)
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
@@ -484,7 +484,8 @@ __all__ = [
"Real",
"Imag",
"Trunc",
"Complex"
"Complex",
"ExtractVolumePatches",
]
__sponge__ = [

View File

@@ -33,6 +33,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from .. import signature as sig
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ..._checkparam import _check_3d_int_or_tuple
from ...common import dtype as mstype
from ...common._decorator import deprecated
from ...common.parameter import Parameter
@@ -6569,3 +6570,82 @@ class ScatterElements(Primitive):
"""Initialize ScatterElements"""
validator.check_value_type("axis", axis, [int], self.name)
self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
class ExtractVolumePatches(Primitive):
"""
Extract patches from the input and put them in the "depth" output dimension. A 3-D extension of extract_image_patches.
Args:
kernel_size (Union[int, tuple[int], list[int]]): A list of ints whose length is 3 or 5.
The size of the sliding window for each dimension of input. Must be [1, 1, k_d, k_h, k_w] or
[k_d, k_h, k_w]. If k_d = k_h = k_w, a single integer may be given.
strides (Union[int, tuple[int], list[int]]): A list of ints whose length is 3 or 5.
How far apart the centers of two consecutive patches are in input. Must be [1, 1, s_d, s_h, s_w] or
[s_d, s_h, s_w]. If s_d = s_h = s_w, a single integer may be given.
padding (str): One of "SAME" or "VALID", the type of padding algorithm to use.
Inputs:
- **input_x** (Tensor) - A Tensor. Must be one of the following types: float16, float32.
5-D Tensor with shape :math:`(x_n, x_c, x_d, x_h, x_w)`.
Outputs:
Tensor, has the same type as input.
If padding is VALID, the shape is :math:`(x_n, k_d * k_h * k_w * x_c, 1 + (x_d - k_d) / s_d,
1 + (x_h - k_h) / s_h, 1 + (x_w - k_w) / s_w)`; if padding is SAME, the shape is :math:`(
x_n, k_d * k_h * k_w * x_c, (x_d + s_d - 1) / s_d, (x_h + s_h - 1) / s_h, (x_w + s_w - 1) / s_w)`.
Raises:
TypeError: If dtype of input_x is neither float16 nor float32.
TypeError: If kernel_size or strides is not a list, a tuple or an int.
TypeError: If input_x is not a tensor.
TypeError: If padding is not a str.
ValueError: If kernel_size is a list or tuple whose length is neither 3 nor 5.
ValueError: If strides is a list or tuple whose length is neither 3 nor 5.
ValueError: If padding is neither "VALID" nor "SAME".
ValueError: If any element of kernel_size or strides is not a positive integer.
ValueError: If input_x is not a 5-D tensor.
ValueError: If any dimension of input_x's shape is zero.
ValueError: If the first two elements of kernel_size or strides are not both 1.
ValueError: If padding = "VALID" and input - kernel_size is less than 0 in the d, h or w dimension.
ValueError: If padding = "SAME" and :math:`padding_needed = ((input + strides - 1) / strides - 1) *
strides + kernel_size - input` is less than 0 in the d, h or w dimension.
ValueError: If out_h is not 1 or out_w is not 1 and x_w + padding_needed - k_w - s_w is less than 0.
ValueError: If x_d * x_h * x_w is greater than 2048.
Supported Platforms:
``Ascend``
Examples:
>>> kernel_size = (1, 1, 2, 2, 2)
>>> strides = (1, 1, 1, 1, 1)
>>> padding = "VALID"
>>> input_x = P.Reshape()(Tensor(np.arange(1, 28), mstype.float16), (1, 1, 3, 3, 3))
>>> output_y = P.ExtractVolumePatches(kernel_size, strides, padding)(input_x)
>>> print(output_y.shape)
(1, 8, 2, 2, 2)
"""
@prim_attr_register
def __init__(self, kernel_size, strides, padding):
validator.check_value_type("kernel_size", kernel_size, (int, list, tuple), self.name)
validator.check_value_type("strides", strides, (int, list, tuple), self.name)
if isinstance(kernel_size, (list, tuple)):
kernel_size = tuple(kernel_size)
if len(kernel_size) == 5:
validator.check_int(kernel_size[0], 1, Rel.EQ, "kernel_size[0]", self.name)
validator.check_int(kernel_size[1], 1, Rel.EQ, "kernel_size[1]", self.name)
if isinstance(strides, (list, tuple)):
strides = tuple(strides)
if len(strides) == 5:
validator.check_int(strides[0], 1, Rel.EQ, "strides[0]", self.name)
validator.check_int(strides[1], 1, Rel.EQ, "strides[1]", self.name)
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
allow_five=True, ret_five=True, greater_zero=True)
self.strides = _check_3d_int_or_tuple("strides", strides, self.name,
allow_five=True, ret_five=True, greater_zero=True)
self.add_prim_attr("kernel_size", self.kernel_size)
self.add_prim_attr("strides", self.strides)
validator.check_value_type("padding_dtype", padding, (str), self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)

View File

@@ -2615,6 +2615,11 @@ test_case_array_ops = [
'desc_bprop': [(Tensor([[1], [4], [7]]),
Tensor([[2, 3], [5, 6], [8, 9]]))],
}),
('ExtractVolumePatches', {
'block': P.ExtractVolumePatches(kernel_size=[1, 1, 2, 2, 2], strides=[1, 1, 1, 1, 1], padding="VALID"),
'desc_inputs': [Tensor(np.random.rand(1, 1, 3, 3, 3), mstype.float16)],
'desc_bprop': [Tensor(np.random.rand(1, 8, 2, 2, 2), mstype.float16)],
}),
]
test_case_other_ops = [