forked from mindspore-Ecosystem/mindspore
[feat][assistant][#I40FFZ] add new Ascend operator ExtractVolumePatches
This commit is contained in:
parent
ee38ffbd3d
commit
2f5758a4a8
|
@ -261,6 +261,7 @@ inline const PrimitivePtr kPrimDiag = std::make_shared<Primitive>(kDiag);
|
|||
inline const PrimitivePtr kPrimDiagPart = std::make_shared<Primitive>(kDiagPart);
|
||||
inline const PrimitivePtr kPrimNonZero = std::make_shared<Primitive>("NonZero");
|
||||
inline const PrimitivePtr kPrimReal = std::make_shared<Primitive>(kReal);
|
||||
inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shared<Primitive>("ExtractVolumePatches");
|
||||
|
||||
// NN
|
||||
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/extract_volume_patches.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ExtractVolumePatchesInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int MAX_SHAPE = 2048;
|
||||
const int d = 2;
|
||||
const int w = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input shape", x_shape.size(), kEqual, 5, primitive->name());
|
||||
auto x_v = x_shape[2] * x_shape[3] * x_shape[4];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_d * x_h * x_w", x_v, kLessEqual, MAX_SHAPE, primitive->name());
|
||||
std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
std::vector<int64_t> strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
(void)CheckAndConvertUtils::CheckInteger("kernel_size_length", kernel_size.size(), kEqual, 5, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides_length", strides.size(), kEqual, 5, primitive->name());
|
||||
auto padding = GetValue<std::string>(primitive->GetAttr(kPadding));
|
||||
for (auto &item : strides) {
|
||||
(void)CheckAndConvertUtils::Check("strides", item, kGreaterThan, "zero", 0, primitive->name());
|
||||
}
|
||||
for (auto &item : kernel_size) {
|
||||
(void)CheckAndConvertUtils::Check("kernel_size", item, kGreaterThan, "zero", 0, primitive->name());
|
||||
}
|
||||
std::vector<int64_t> y_shape(5);
|
||||
int64_t padding_needed = 0;
|
||||
y_shape[0] = x_shape[0];
|
||||
y_shape[1] = x_shape[1] * kernel_size[2] * kernel_size[3] * kernel_size[4];
|
||||
if (padding == "VALID") {
|
||||
for (int i = d; i <= w; ++i) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(
|
||||
"padding = VALID, input[" + std::to_string(i) + "] - kernel_size[" + std::to_string(i) + "]",
|
||||
x_shape[i] - kernel_size[i], kGreaterEqual, 0, primitive->name());
|
||||
y_shape[i] = 1 + (x_shape[i] - kernel_size[i]) / strides[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = d; i <= w; ++i) {
|
||||
y_shape[i] = (x_shape[i] + strides[i] - 1) / strides[i];
|
||||
int64_t output_size = y_shape[i];
|
||||
padding_needed = (output_size - 1) * strides[i] + kernel_size[i] - x_shape[i];
|
||||
(void)CheckAndConvertUtils::CheckInteger("padding = ((input[" + std::to_string(i) + "] + strides[" +
|
||||
std::to_string(i) + "] - 1) / strides[" + std::to_string(i) +
|
||||
"]) - 1) * strides[" + std::to_string(i) + "] + kernel_size[" +
|
||||
std::to_string(i) + "] - input[" + std::to_string(i) + "]",
|
||||
padding_needed, kGreaterEqual, 0, primitive->name());
|
||||
}
|
||||
}
|
||||
if (y_shape[3] != 1 || y_shape[4] != 1) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_w + pad_l + pad_r - kernel_w - stride_w",
|
||||
x_shape[4] + padding_needed - kernel_size[4] - strides[4], kGreaterEqual,
|
||||
0, primitive->name());
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(y_shape);
|
||||
}
|
||||
|
||||
TypePtr ExtractVolumePatchesInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr ExtractVolumePatchesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto type = ExtractVolumePatchesInferType(primitive, input_args);
|
||||
auto shape = ExtractVolumePatchesInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ExtractVolumePatches, prim::kPrimExtractVolumePatches, ExtractVolumePatchesInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_
|
||||
#define MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameExtractVolumePatches = "ExtractVolumePatches";
|
||||
class MS_CORE_API ExtractVolumePatches : public PrimitiveC {
|
||||
public:
|
||||
ExtractVolumePatches() : PrimitiveC(kNameExtractVolumePatches) { InitIOName({"x"}, {"y"}); }
|
||||
~ExtractVolumePatches() = default;
|
||||
MS_DECLARE_PARENT(ExtractVolumePatches, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr ExtractVolumePatchesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimExtractVolumePatchesPtr = std::shared_ptr<ExtractVolumePatches>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_EXTRACT_VOLUME_PATCHES_H_
|
|
@ -105,3 +105,54 @@ def get_bprop_split_v(self):
|
|||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ExtractVolumePatches)
|
||||
def get_bprop_extract_volume_patches(self):
|
||||
"""Generate bprop for ExtractVolumePatches"""
|
||||
extract_volume_patches = P.ExtractVolumePatches(kernel_size=self.kernel_size,
|
||||
strides=self.strides, padding=self.padding)
|
||||
concat = P.Concat(axis=-1)
|
||||
expend_dims = P.ExpandDims()
|
||||
scatter_nd = P.ScatterNd()
|
||||
slice_op = P.Slice()
|
||||
fill = P.Fill()
|
||||
dtype = P.DType()
|
||||
cast = P.Cast()
|
||||
matmul = P.MatMul()
|
||||
_, _, ksize_d, ksize_h, ksize_w = self.kernel_size
|
||||
|
||||
def bprop(x, out, dout):
|
||||
x_shape = P.Shape()(x)
|
||||
x_n, x_c, x_d, x_h, x_w = x_shape
|
||||
x_indices_num = 1 + x_d * x_h * x_w
|
||||
x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float16)
|
||||
x_idx = P.Reshape()(x_idx, (1, 1, x_d, x_h, x_w))
|
||||
x_idx_patched = extract_volume_patches(x_idx)
|
||||
x_idx_patched = P.Transpose()(x_idx_patched, (0, 2, 3, 4, 1))
|
||||
x_idx_patched = cast(x_idx_patched, mstype.int32)
|
||||
|
||||
out_shape = P.Shape()(out)
|
||||
_, _, out_d, out_h, out_w = out_shape
|
||||
out_indices_num = out_d * out_h * out_w * ksize_d * ksize_h * ksize_w
|
||||
out_idx = F.tuple_to_array(range(0, out_indices_num))
|
||||
out_idx = P.Reshape()(out_idx, (1, out_d, out_h, out_w, ksize_d * ksize_h * ksize_w))
|
||||
|
||||
idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
|
||||
idx_map = P.Reshape()(idx_tensor, (-1, 2))
|
||||
sp_shape = (x_indices_num, out_indices_num)
|
||||
sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
||||
sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
|
||||
|
||||
grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
|
||||
grad = P.Reshape()(grad, (x_n, out_d, out_h, out_w, ksize_d,
|
||||
ksize_h, ksize_w, x_c))
|
||||
grad_expended = P.Transpose()(grad, (1, 2, 3, 4, 5, 6, 0, 7))
|
||||
grad_flat = P.Reshape()(grad_expended, (-1, x_n * x_c))
|
||||
|
||||
jac = matmul(sp_tensor, grad_flat)
|
||||
dx = P.Reshape()(jac, (x_d, x_h, x_w, x_n, x_c))
|
||||
dx = P.Transpose()(dx, (3, 4, 0, 1, 2))
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -419,3 +419,4 @@ from .hshrink_grad import _hshrink_grad_tbe
|
|||
from .new_im2col import _new_im2col_tbe
|
||||
from .non_zero_ds import _non_zero_ds_tbe
|
||||
from .trunc import _trunc_tbe
|
||||
from .extract_volume_patches import _extract_volume_patches_tbe
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# #http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ExtractVolumePatches op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
extract_volume_patches_op_info = TBERegOp("ExtractVolumePatches") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("extract_volume_patches.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("extract_volume_patches") \
|
||||
.partial_flag(True) \
|
||||
.attr("kernel_size", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.output(0, "output_y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
|
||||
.dtype_format(DataType.I8_NDC1HWC0, DataType.I8_NDC1HWC0) \
|
||||
.dtype_format(DataType.U8_NDC1HWC0, DataType.U8_NDC1HWC0) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(extract_volume_patches_op_info)
|
||||
def _extract_volume_patches_tbe():
|
||||
"""ExtractVolumePatches TBE register"""
|
||||
return
|
|
@ -638,6 +638,7 @@ class DataType:
|
|||
I8_HWCN = ("int8", "HWCN")
|
||||
I8_NDHWC = ("int8", "NDHWC")
|
||||
I8_ChannelLast = ("int8", "ChannelLast")
|
||||
I8_NDC1HWC0 = ("int8", "NDC1HWC0")
|
||||
|
||||
U8_None = ("uint8", "")
|
||||
U8_Default = ("uint8", "DefaultFormat")
|
||||
|
@ -650,6 +651,7 @@ class DataType:
|
|||
U8_HWCN = ("uint8", "HWCN")
|
||||
U8_NDHWC = ("uint8", "NDHWC")
|
||||
U8_ChannelLast = ("uint8", "ChannelLast")
|
||||
U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
|
||||
|
||||
I16_None = ("int16", "")
|
||||
I16_Default = ("int16", "DefaultFormat")
|
||||
|
@ -799,6 +801,7 @@ class DataType:
|
|||
I8_HWCN = ("int8", "HWCN")
|
||||
I8_NDHWC = ("int8", "NDHWC")
|
||||
I8_ChannelLast = ("int8", "ChannelLast")
|
||||
I8_NDC1HWC0 = ("int8", "NDC1HWC0")
|
||||
|
||||
U8_None = ("uint8", "")
|
||||
U8_Default = ("uint8", "DefaultFormat")
|
||||
|
@ -811,6 +814,7 @@ class DataType:
|
|||
U8_HWCN = ("uint8", "HWCN")
|
||||
U8_NDHWC = ("uint8", "NDHWC")
|
||||
U8_ChannelLast = ("uint8", "ChannelLast")
|
||||
U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
|
||||
|
||||
I16_None = ("int16", "")
|
||||
I16_Default = ("int16", "DefaultFormat")
|
||||
|
|
|
@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
|
||||
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
|
||||
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
|
||||
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements)
|
||||
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
|
||||
|
@ -484,7 +484,8 @@ __all__ = [
|
|||
"Real",
|
||||
"Imag",
|
||||
"Trunc",
|
||||
"Complex"
|
||||
"Complex",
|
||||
"ExtractVolumePatches",
|
||||
]
|
||||
|
||||
__sponge__ = [
|
||||
|
|
|
@ -33,6 +33,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_
|
|||
from .. import signature as sig
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import _check_3d_int_or_tuple
|
||||
from ...common import dtype as mstype
|
||||
from ...common._decorator import deprecated
|
||||
from ...common.parameter import Parameter
|
||||
|
@ -6569,3 +6570,82 @@ class ScatterElements(Primitive):
|
|||
"""Initialize ScatterElements"""
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
|
||||
|
||||
|
||||
class ExtractVolumePatches(Primitive):
|
||||
"""
|
||||
Extract patches from input and put them in the "depth" output dimension. 3D extension of extract_image_patches.
|
||||
|
||||
Args:
|
||||
kernel_size (Union[int, tuple[int], list[int]]): A list of ints which's length is 3 or 5.
|
||||
The size of the sliding window for each dimension of input. Must be: [1, 1, k_d, k_h, k_w] or
|
||||
[k_d, k_h, k_w]. If k_d = k_h = k_w, you can enter an integer.
|
||||
strides (Union[int, tuple[int], list[int]]): A list of ints which's length is 3 or 5.
|
||||
How far the centers of two consecutive patches are in input. Must be: [1, 1, s_d, s_h, s_w] or
|
||||
[s_d, s_h, s_w]. If s_d = s_h = s_w, you can enter an integer.
|
||||
padding (str): A string from: "SAME", "VALID". The type of padding algorithm to use.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - A Tensor. Must be one of the following types: float16, float32.
|
||||
5-D Tensor with shape :math:`(x_n, x_c, x_d, x_h, x_w)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type as input.
|
||||
If padding is VALID, the shape is :math:`(x_n, k_d * k_h * k_w * x_c, 1 + (x_d - k_d) / s_d,
|
||||
1 + (x_h - k_h) / s_h, 1 + (x_w - k_w) / s_w)`; if padding is SAME, the shape is :math:`(
|
||||
x_n, k_d * k_h * k_w * x_c, (x_d + s_d - 1) / s_d, (x_h + s_h - 1) / s_h, (x_w + s_w - 1) / s_w)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of input_x is neither float16 nor float32.
|
||||
TypeError: If kernel_size or strides is not a list, a tuple or a int.
|
||||
TypeError: If input_x is not a tensor.
|
||||
TypeError: If padding is not str.
|
||||
ValueError: If the length of kernel_size is neither 3 nor 5 and kernel_size is not an integer.
|
||||
ValueError: If the length of strides is neither 3 nor 5 and strides is not an integer.
|
||||
ValueError: If padding is neither "VALID" nor "SAME".
|
||||
ValueError: If elements of kernel_size or strides are not positive integer.
|
||||
ValueError: If input_x is not a tensor in dimension 5.
|
||||
ValueError: If input_x's shape has zero.
|
||||
ValueError: If one of kernel_size or strides' first two numbers is not 1.
|
||||
ValueError: If padding = "VALID" and input - kernel_size is less than 0 in d, h or w dimension.
|
||||
ValueError: If padding = "SAME" and :math:`padding_needed = ((input_x + strides - 1) / strides - 1) *
|
||||
strides + kernelz_size - input` is less than 0 in d, h or w dimension.
|
||||
ValueError: If x_h is not 1 or x_w is not 1 and x_w + padding_needed - k_w - s_w is less than 0.
|
||||
ValueError: If x_d * x_h * x_w is greater than 2048.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Example:
|
||||
>>> kernel_size = (1, 1, 2, 2, 2)
|
||||
>>> strides = (1, 1, 1, 1, 1)
|
||||
>>> padding = "VALID"
|
||||
>>> input_x = P.Reshape()(Tensor(np.arange(1, 28), mstype.float16), (1, 1, 3, 3, 3))
|
||||
>>> output_y = P.ExtractVolumePatches(kernel_size, strides, padding)(input_x)
|
||||
>>> print(output_y.shape)
|
||||
(1, 8, 2, 2, 2)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size, strides, padding):
|
||||
validator.check_value_type("kernel_size", kernel_size, (int, list, tuple), self.name)
|
||||
validator.check_value_type("strides", strides, (int, list, tuple), self.name)
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
kernel_size = tuple(kernel_size)
|
||||
if len(kernel_size) == 5:
|
||||
validator.check_int(kernel_size[0], 1, Rel.EQ, "kernel_size[0]", self.name)
|
||||
validator.check_int(kernel_size[1], 1, Rel.EQ, "kernel_size[1]", self.name)
|
||||
if isinstance(strides, (list, tuple)):
|
||||
strides = tuple(strides)
|
||||
if len(strides) == 5:
|
||||
validator.check_int(strides[0], 1, Rel.EQ, "strides[0]", self.name)
|
||||
validator.check_int(strides[1], 1, Rel.EQ, "strides[1]", self.name)
|
||||
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
|
||||
allow_five=True, ret_five=True, greater_zero=True)
|
||||
self.strides = _check_3d_int_or_tuple("strides", strides, self.name,
|
||||
allow_five=True, ret_five=True, greater_zero=True)
|
||||
self.add_prim_attr("kernel_size", self.kernel_size)
|
||||
self.add_prim_attr("strides", self.strides)
|
||||
validator.check_value_type("padding_dtype", padding, (str), self.name)
|
||||
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
|
|
|
@ -2615,6 +2615,11 @@ test_case_array_ops = [
|
|||
'desc_bprop': [(Tensor([[1], [4], [7]]),
|
||||
Tensor([[2, 3], [5, 6], [8, 9]]))],
|
||||
}),
|
||||
('ExtractVolumePatches', {
|
||||
'block': P.ExtractVolumePatches(kernel_size=[1, 1, 2, 2, 2], strides=[1, 1, 1, 1, 1], padding="VALID"),
|
||||
'desc_inputs': [Tensor(np.random.rand(1, 1, 3, 3, 3), mstype.float16)],
|
||||
'desc_bprop': [Tensor(np.random.rand(1, 8, 2, 2, 2), mstype.float16)],
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
|
|
Loading…
Reference in New Issue