forked from mindspore-Ecosystem/mindspore
add conv3d object
parent 07dbe082c5
commit c45eee9b94
@@ -37,9 +37,9 @@ class Rel(Enum):
    GE = 6  # >=
    # scalar range check
    INC_NEITHER = 7  # (), include neither
    INC_LEFT = 8  # [), include left
    INC_RIGHT = 9  # (], include right
    INC_BOTH = 10  # [], include both
    # collection in, not in
    IN = 11
    NOT_IN = 12
@@ -92,6 +92,41 @@ rel_strs = {
}


def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False,
                           ret_five=False, greater_zero=True):
    """
    Checks whether an argument is a positive int or a tuple with 3 or 5 (when allow_five is True) positive int elements.
    """

    def _raise_message():
        raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be a positive int number or a tuple of three "
                         f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")

    def _get_return_value():
        if isinstance(arg_value, int):
            ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
        elif len(arg_value) == 3:
            ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
        elif len(arg_value) == 5:
            if not allow_five:
                _raise_message()
            ret = arg_value if ret_five else (arg_value[1], arg_value[2], arg_value[3])
        else:
            _raise_message()
        return ret

    Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    ret_value = _get_return_value()
    for item in ret_value:
        if isinstance(item, int) and not isinstance(item, bool):
            if greater_zero and item > 0:
                continue
            if not greater_zero and item >= 0:
                continue
        _raise_message()
    return tuple(ret_value)


def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
    """
    Check argument integer.
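Reviewer note (not part of the diff): a quick sketch of how the new helper normalizes its argument. The call patterns mirror the Conv3D primitives added further down, and the expected results follow directly from _get_return_value above.

# Hypothetical calls to the helper added above.
_check_3d_int_or_tuple('kernel_size', 3, 'Conv3D')                                        # -> (3, 3, 3)
_check_3d_int_or_tuple('stride', 2, 'Conv3D', allow_five=True, ret_five=True)             # -> (1, 1, 2, 2, 2)
_check_3d_int_or_tuple('dilation', (1, 2, 2), 'Conv3D', allow_five=True, ret_five=True)   # -> (1, 1, 1, 2, 2)
_check_3d_int_or_tuple('stride', 0, 'Conv3D')                                             # raises ValueError (greater_zero=True)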
@@ -428,6 +463,7 @@ class Validator:
    @staticmethod
    def check_types_same_and_valid(args, valid_values, prim_name):
        """Checks whether the types of inputs are the same and valid."""

        def _check_type_valid(arg):
            arg_key, arg_val = arg
            elem_type = arg_val
@@ -494,6 +530,7 @@ class Validator:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1

        reduce(_check_types_same, map(_check_argument_type, args.items()))

    @staticmethod
@@ -625,6 +662,7 @@ def args_type_check(*type_args, **type_kwargs):
                    if value is not None and not isinstance(value, bound_types[name]):
                        raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
            return func(*args, **kwargs)

        return wrapper

    return type_check
@@ -29,6 +29,8 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(prim::kPrimAvgPoolGradVm->name(), {0});
  Register(prim::kPrimConv2DBackpropInput->name(), {2});
  Register(prim::kPrimConv2DBackpropFilter->name(), {2});
  Register(prim::kPrimConv3DBackpropInput->name(), {2});
  Register(prim::kPrimConv3DBackpropFilter->name(), {2});
  Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1});
  Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0});
  Register(prim::kPrimReshape->name(), {1});
@@ -31,6 +31,7 @@ namespace session {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr size_t k5dDims = 5;
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
                                                       prim::kPrimAssignSub->name()};
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
@@ -383,7 +384,8 @@ void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &forma
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) {
    auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i);
    MS_EXCEPTION_IF_NULL(in_node);
    if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) {
    if ((in_node->isa<Parameter>() || in_node->isa<ValueNode>()) &&
        AnfAlgo::GetOutputInferShape(in_node, 0).size() == k5dDims) {
      ReSetParameterValueNodeFormatAndType(in_node, format);
    }
  }
@@ -291,10 +291,7 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
  std::vector<size_t> device_shape;
  const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
  device_shape.push_back(shape[2]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[3]);
  device_shape.push_back(shape[4]);
  device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
  device_shape.push_back(N1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
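For context (not in the diff): judging by the hunk header counts (10 old lines, 7 new), this change collapses the leading dimensions of the FRACTAL_Z_3D device shape into a single product. A small Python sketch of the resulting shape, assuming the host shape is NCDHW and kCubeSize is 16 as elsewhere in this file:

def fracz_3d_device_shape(shape, cube=16):
    # shape is (N, C, D, H, W); C1 and N1 are the cube-aligned channel and batch tiles
    n, c, d, h, w = shape
    c1 = (c + cube - 1) // cube
    n1 = (n + cube - 1) // cube
    return [d * c1 * h * w, n1, cube, cube]

fracz_3d_device_shape([32, 3, 4, 3, 3])  # -> [36, 2, 16, 16]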
@@ -148,6 +148,8 @@ inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"
inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad");
inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput");
inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
  std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
@@ -18,6 +18,7 @@ import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
from mindspore.ops.operations import nn_ops as nps
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
@@ -60,6 +61,27 @@ def get_bprop_conv2d(self):
    return bprop


@bprop_getters.register(nps.Conv3D)
def get_bprop_conv3d(self):
    """Grad definition for `Conv3D` operation."""
    input_grad = nps.Conv3DBackpropInput(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    filter_grad = G.Conv3DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(w, dout, get_shape(x))
        dw = filter_grad(x, dout, get_shape(w))
        return dx, dw

    return bprop


@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation."""
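Reviewer note (not part of the diff): a minimal sketch of exercising the new bprop through GradOperation, reusing the shapes from the ('Conv3D', ...) test case added at the end of this commit; the wrapper class Net is illustrative only.

import numpy as np
import mindspore
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops.operations import nn_ops as nps

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv3d = nps.Conv3D(out_channel=32, kernel_size=(4, 3, 3))

    def construct(self, x, w):
        return self.conv3d(x, w)

x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float16)
w = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float16)
# get_all=True returns gradients for both inputs: dx matches x's shape, dw matches w's shape
dx, dw = C.GradOperation(get_all=True)(Net())(x, w)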
@@ -347,3 +347,7 @@ from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe
from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe
from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe
from .fake_quant_with_min_max_vars_per_channel_gradient import _fake_quant_with_min_max_vars_per_channel_gradient_tbe
from .conv3d import _conv3d_tbe
from .conv3d_backprop_input import _conv3d_backprop_input_tbe
from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe
from .conv3d_transpose import _conv3d_transpose_tbe
@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Conv3D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv3d_op_info = TBERegOp("Conv3D") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv3d.so") \
    .compute_cost(10) \
    .kernel_name("conv3d") \
    .partial_flag(True) \
    .attr("strides", "required", "listInt", "all") \
    .attr("pads", "required", "listInt", "all") \
    .attr("dilations", "required", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("data_format", "optional", "str", "all") \
    .attr("offset_x", "optional", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "filter", False, "required", "all") \
    .input(2, "bias", False, "optional", "all") \
    .input(3, "offset_w", False, "optional", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default,
                  DataType.I8_Default, DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(conv3d_op_info)
def _conv3d_tbe():
    """Conv3D TBE register"""
    return
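Reviewer note (not part of the diff): the attribute names declared in this TBE registration line up with the add_prim_attr calls made by the new Conv3D primitive in nn_ops.py further down; a cross-reference for readers, stated as comments only.

# TBE attr        <- set by Conv3D via add_prim_attr
# "strides"       <- self.add_prim_attr('strides', self.stride)
# "pads"          <- self.add_prim_attr('pads', ...)  (filled in during infer_shape)
# "dilations"     <- self.add_prim_attr('dilations', self.dilation)
# "groups"        <- self.add_prim_attr('groups', self.group)
# "data_format"   <- self.add_prim_attr('data_format', self.format)
# "offset_x"      <- self.add_prim_attr('offset_x', 0)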
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Conv3DBackpropFilter op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv3d_backprop_filter_op_info = TBERegOp("Conv3DBackpropFilter") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv3d_backprop_filter_d.so") \
    .compute_cost(10) \
    .kernel_name("conv3d_backprop_filter_d") \
    .partial_flag(True) \
    .attr("filter_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pads", "required", "listInt", "all") \
    .attr("dilations", "required", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("data_format", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "out_backprop", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_FRACTAL_Z_3D) \
    .get_op_info()


@op_info_register(conv3d_backprop_filter_op_info)
def _conv3d_backprop_filter_tbe():
    """Conv3DBackpropFilter TBE register"""
    return
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Conv3DBackpropInput op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv3d_backprop_input_op_info = TBERegOp("Conv3DBackpropInput") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv3d_backprop_input_d.so") \
    .compute_cost(10) \
    .kernel_name("conv3d_backprop_input_d") \
    .partial_flag(True) \
    .attr("input_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pads", "required", "listInt", "all") \
    .attr("dilations", "required", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("data_format", "optional", "str", "all") \
    .input(0, "filter", False, "required", "all") \
    .input(1, "out_backprop", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(conv3d_backprop_input_op_info)
def _conv3d_backprop_input_tbe():
    """Conv3DBackpropInput TBE register"""
    return
@@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Conv3DTranspose op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv3d_transpose_op_info = TBERegOp("Conv3DTranspose") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv3d_transpose_d.so") \
    .compute_cost(10) \
    .kernel_name("conv3d_transpose_d") \
    .partial_flag(True) \
    .attr("input_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pads", "required", "listInt", "all") \
    .attr("dilations", "optional", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("data_format", "optional", "str", "all") \
    .attr("output_padding", "optional", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "filter", False, "required", "all") \
    .input(2, "bias", False, "optional", "all") \
    .input(3, "offset_w", False, "optional", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.I8_Default,
                  DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(conv3d_transpose_op_info)
def _conv3d_transpose_tbe():
    """Conv3DTranspose TBE register"""
    return
@@ -133,6 +133,24 @@ trans_data_op_info = TBERegOp("TransData") \
    .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \
    .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \
    .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \
    .dtype_format(DataType.F16_NDHWC, DataType.F16_NDC1HWC0) \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDHWC) \
    .dtype_format(DataType.F16_DHWCN, DataType.F16_FRACTAL_Z_3D) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_DHWCN) \
    .dtype_format(DataType.F16_NCDHW, DataType.F16_NDC1HWC0) \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NCDHW) \
    .dtype_format(DataType.F16_NCDHW, DataType.F16_FRACTAL_Z_3D) \
    .dtype_format(DataType.F32_NCDHW, DataType.F32_FRACTAL_Z_3D) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NCDHW) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NCDHW) \
    .dtype_format(DataType.F16_NDHWC, DataType.F16_FRACTAL_Z_3D) \
    .dtype_format(DataType.F32_NDHWC, DataType.F32_FRACTAL_Z_3D) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDHWC) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NDHWC) \
    .dtype_format(DataType.F32_DHWCN, DataType.F32_FRACTAL_Z_3D) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_DHWCN) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDHWC) \
    .dtype_format(DataType.F32_NDHWC, DataType.F32_NDC1HWC0) \
    .get_op_info()
@@ -110,7 +110,7 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
    Clips tensor values by the ratio of the sum of their norms.

    Note:
        input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error.
        Input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error.

    Args:
        x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
@@ -631,6 +631,10 @@ class DataType:
    F16_NHWC = ("float16", "NHWC")
    F16_HWCN = ("float16", "HWCN")
    F16_NDHWC = ("float16", "NDHWC")
    F16_NCDHW = ("float16", "NCDHW")
    F16_DHWCN = ("float16", "DHWCN")
    F16_NDC1HWC0 = ("float16", "NDC1HWC0")
    F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
    F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")

    F32_None = ("float32", "")
@@ -643,6 +647,10 @@
    F32_NHWC = ("float32", "NHWC")
    F32_HWCN = ("float32", "HWCN")
    F32_NDHWC = ("float32", "NDHWC")
    F32_NCDHW = ("float32", "NCDHW")
    F32_DHWCN = ("float32", "DHWCN")
    F32_NDC1HWC0 = ("float32", "NDC1HWC0")
    F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
    F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")

    F64_None = ("float64", "")
@@ -14,8 +14,9 @@
# ============================================================================

"""Operators for gradients."""
import math
from functools import partial

from mindspore._checkparam import _check_3d_int_or_tuple
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
@@ -288,6 +289,140 @@ class ConcatOffset(PrimitiveWithInfer):
        return out


class Conv3DBackpropFilter(PrimitiveWithInfer):
    """
    Computes the gradients of convolution 3D with respect to the filter.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
            head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
            integers, the paddings of head, tail, top, bottom, left and right are equal to pad[0], pad[1], pad[2],
            pad[3], pad[4] and pad[5] correspondingly.
        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.
        data_format (str): The optional value for data format. Currently only supports 'NCDHW'.

    Inputs:
        - **x** (Tensor) - The input of the convolution with shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
        - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
          data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
        - **w_size** (Tensor) - A tuple that describes the shape of the weight, which conforms to the format
          :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`.

    Outputs:
        Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
        >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
        >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
        >>> conv3d_backprop_filter = G.Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2))
        >>> output = conv3d_backprop_filter(x, dout, F.shape(w))
        >>> print(output.shape)
        (32, 32, 4, 6, 2)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 mode=1,
                 stride=(1, 1, 1, 1, 1),
                 dilation=(1, 1, 1, 1, 1),
                 group=1,
                 data_format="NCDHW"):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)
        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('dilations', self.dilation)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.pad_list = pad
        self.add_prim_attr('pads', self.pad_list)

        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")

    def __infer__(self, x, doutput, w_size):
        w_size_v = w_size['value']
        validator.check_value_type('w_size', w_size_v, [tuple], self.name)
        for i, dim_len in enumerate(w_size_v):
            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
        args = {"x": x['dtype'], "doutput": doutput['dtype']}
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)

        validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
        validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name)
        validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name)

        # infer shape
        x_shape = x['shape']
        dout_shape = doutput['shape']
        kernel_d = self.kernel_size[0]
        kernel_h = self.kernel_size[1]
        kernel_w = self.kernel_size[2]
        stride_d = self.stride[2]
        stride_h = self.stride[3]
        stride_w = self.stride[4]
        dilation_d = self.dilation[2]
        dilation_h = self.dilation[3]
        dilation_w = self.dilation[4]
        # The pad_mode is valid by default. If pad_mode is not valid or same, then pad.
        if self.pad_mode == "valid":
            self.pad_list = (0, 0, 0, 0, 0, 0)
        if self.pad_mode == "same":
            pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2])
            pad_head = math.floor(pad_needed_d / 2)
            pad_tail = pad_needed_d - pad_head

            pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top

            pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
            self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)

        self.add_prim_attr('pads', self.pad_list)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': mstype.float32,
        }
        return out


class Conv2DBackpropFilter(PrimitiveWithInfer):
    """
    Computes the gradients of convolution with respect to the filter.
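Reviewer note (not part of the diff): the "same" branch in __infer__ above follows the usual TF-style padding arithmetic. A worked check in plain Python, borrowing the depth numbers from the docstring example (input depth 13, dout depth 10, kernel depth 4, dilation 1):

pad_needed_d = max(0, (10 - 1) * 1 + 1 * (4 - 1) + 1 - 13)   # stride 1 -> 0, no depth padding needed
pad_needed_d = max(0, (10 - 1) * 2 + 1 * (4 - 1) + 1 - 13)   # stride 2 -> 9
pad_head = pad_needed_d // 2                                  # 4
pad_tail = pad_needed_d - pad_head                            # 5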
@@ -18,6 +18,7 @@
import math
import operator
from functools import reduce, partial
from mindspore._checkparam import _check_3d_int_or_tuple
import numpy as np
from ... import context
from .. import signature as sig
@@ -3683,6 +3684,7 @@ class AdamNoUpdateParam(PrimitiveWithInfer):
         [-0.00013441 -0.00013441 -0.00013441]]

    """

    @prim_attr_register
    def __init__(self, use_locking=False, use_nesterov=False):
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
@@ -6526,3 +6528,295 @@ class LRN(PrimitiveWithInfer):
    def infer_shape(self, x_shape):
        validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name)
        return x_shape


class Conv3D(PrimitiveWithInfer):
    r"""
    3D convolution layer.

    Applies a 3D convolution over an input tensor which is typically of shape
    :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
    Each sample in the batch has shape :math:`(C_{in}, D_{in}, H_{in}, W_{in})`.

    If the 'pad_mode' is set to "valid", the output depth, height and width will be
    :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} -
    (\text{ks_d} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor`,
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
            head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
            integers, the paddings of head, tail, top, bottom, left and right are equal to pad[0], pad[1], pad[2],
            pad[3], pad[4] and pad[5] correspondingly.
        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.
        data_format (str): The optional value for data format. Currently only supports "NCDHW".

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
        - **weight** (Tensor) - If the kernel size is :math:`(K_1, K_2, K_3)`, the shape is
          :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`.

    Outputs:
        Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
        >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float32)
        >>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3))
        >>> output = conv3d(input, weight)
        >>> print(output.shape)
        (16, 32, 7, 30, 30)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1,
                 group=1,
                 data_format="NCDHW"):
        """Initialize Conv3D"""
        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)
        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('dilations', self.dilation)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.padding = pad
        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)

        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.add_prim_attr('offset_x', 0)

    def infer_shape(self, x_shape, w_shape, b_shape=None):
        validator.check_equal_int(len(w_shape), 5, "weight rank", self.name)
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        if b_shape is not None:
            raise ValueError("Bias currently only support None.")
        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:]', tuple(w_shape[2:]), Rel.EQ, self.name)

        kernel_size_d = w_shape[2]
        kernel_size_h = w_shape[3]
        kernel_size_w = w_shape[4]

        stride_d = self.stride[2]
        stride_h = self.stride[3]
        stride_w = self.stride[4]

        dilation_d = self.dilation[2]
        dilation_h = self.dilation[3]
        dilation_w = self.dilation[4]

        if self.pad_mode == "valid":
            d_out = math.ceil((x_shape[2] - dilation_d * (kernel_size_d - 1)) / stride_d)
            h_out = math.ceil((x_shape[3] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[4] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0, 0, 0

        elif self.pad_mode == "same":
            d_out = math.ceil(x_shape[2] / stride_d)
            h_out = math.ceil(x_shape[3] / stride_h)
            w_out = math.ceil(x_shape[4] / stride_w)

            pad_needed_d = max(0, (d_out - 1) * stride_d + dilation_d * (kernel_size_d - 1) + 1 - x_shape[2])
            pad_head = math.floor(pad_needed_d / 2)
            pad_tail = pad_needed_d - pad_head

            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[3])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top

            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[4])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left

        elif self.pad_mode == 'pad':
            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.padding
            d_out = 1 + (x_shape[2] + pad_head + pad_tail - kernel_size_d - (kernel_size_d - 1)
                         * (dilation_d - 1)) / stride_d
            h_out = 1 + (x_shape[3] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1)
                         * (dilation_h - 1)) / stride_h
            w_out = 1 + (x_shape[4] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1)
                         * (dilation_w - 1)) / stride_w
            d_out = math.floor(d_out)
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)

        self.pad_list = [pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pads', (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right))
        out_channel = self.out_channel
        out_shape = [x_shape[0], out_channel, d_out, h_out, w_out]
        _check_shape('output', out_shape, self.name)
        return out_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
        args = {'x': x_dtype, 'w': w_dtype}
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
        return x_dtype


class Conv3DBackpropInput(PrimitiveWithInfer):
    """
    Computes the gradients of convolution 3D with respect to the input.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
            head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
            integers, the paddings of head, tail, top, bottom, left and right are equal to pad[0], pad[1], pad[2],
            pad[3], pad[4] and pad[5] correspondingly.
        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.
        data_format (str): The optional value for data format. Currently only supports 'NCDHW'.

    Inputs:
        - **weight** (Tensor) - If the kernel size is :math:`(K_1, K_2, K_3)`, the shape is
          :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`.
        - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
          data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
        - **input_size** (Tensor) - A tuple that describes the shape of the input, which conforms to the format
          :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float32)
        >>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float32)
        >>> x = Tensor(np.ones([16, 32, 13, 37, 33]))
        >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=32, kernel_size=(4, 6, 2))
        >>> output = conv3d_backprop_input(weight, dout, F.shape(x))
        >>> print(output.shape)
        (16, 32, 13, 37, 33)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 mode=1,
                 stride=1,
                 dilation=1,
                 group=1,
                 data_format="NCDHW"):
        """Initialize Conv3DBackpropInput"""
        self.init_prim_io_names(inputs=['filter', 'out_backprop', 'input_size'], outputs=['y'])
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)
        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('dilations', self.dilation)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.pad_list = pad

        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)
        self.add_prim_attr('io_format', "NCDHW")

    def __infer__(self, w, doutput, x_size):
        x_size_v = x_size['value']
        validator.check_value_type('x_size', x_size_v, [tuple], self.name)
        for i, dim_len in enumerate(x_size_v):
            validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
        args = {'doutput': doutput['dtype'], 'w': w['dtype']}
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
        validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
        validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name)
        validator.check("input_size's batch", x_size_v[0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name)

        # infer shape
        dout_shape = doutput['shape']
        kernel_d = self.kernel_size[0]
        kernel_h = self.kernel_size[1]
        kernel_w = self.kernel_size[2]
        stride_d = self.stride[2]
        stride_h = self.stride[3]
        stride_w = self.stride[4]
        dilation_d = self.dilation[2]
        dilation_h = self.dilation[3]
        dilation_w = self.dilation[4]
        # The pad_mode is valid by default. If pad_mode is not valid or same, then pad.
        if self.pad_mode == "valid":
            self.pad_list = (0, 0, 0, 0, 0, 0)
        if self.pad_mode == "same":
            pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_size_v[2])
            pad_head = math.floor(pad_needed_d / 2)
            pad_tail = pad_needed_d - pad_head

            pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[3])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top

            pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[4])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
            self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)

        self.add_prim_attr('pads', self.pad_list)
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': doutput['dtype'],
        }
        return out
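Reviewer note (not part of the diff): the "valid" branch of Conv3D.infer_shape above reproduces the output shape printed in the docstring example; a quick check in plain Python:

import math
x_shape = [16, 3, 10, 32, 32]          # N, C, D, H, W from the Conv3D docstring example
kernel = (4, 3, 3)                     # D, H, W kernel sizes
stride = dilation = (1, 1, 1, 1, 1)    # as normalized by _check_3d_int_or_tuple with ret_five=True
d_out = math.ceil((x_shape[2] - dilation[2] * (kernel[0] - 1)) / stride[2])   # 7
h_out = math.ceil((x_shape[3] - dilation[3] * (kernel[1] - 1)) / stride[3])   # 30
w_out = math.ceil((x_shape[4] - dilation[4] * (kernel[2] - 1)) / stride[4])   # 30
print([x_shape[0], 32, d_out, h_out, w_out])   # [16, 32, 7, 30, 30]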
@@ -27,6 +27,7 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations import nn_ops as nps
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@@ -288,6 +289,7 @@ class CountNonZero(nn.Cell):
        self.axis = axis
        self.keep_dims = keep_dims
        self.dtype = dtype

    def construct(self, input_x):
        nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype)
        return nonzero_num
@@ -423,6 +425,50 @@ class ScatterDiv(nn.Cell):
        return out


class Conv3D(nn.Cell):
    """Conv3D net definition"""

    def __init__(self, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format):
        super(Conv3D, self).__init__()
        self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, pad_mode=pad_mode,
                               pad=pad, stride=stride, dilation=dilation, group=group, data_format=data_format)

    def construct(self, x, w):
        out = self.conv(x, w)
        return out


class Conv3DBackpropInput(nn.Cell):
    """Conv3DBackpropInput net definition"""

    def __init__(self, input_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group,
                 data_format):
        super(Conv3DBackpropInput, self).__init__()
        self.conv = nps.Conv3DBackpropInput(out_channel, kernel_size, pad_mode=pad_mode,
                                            pad=pad, mode=mode, stride=stride, dilation=dilation,
                                            group=group, data_format=data_format)
        self.x_size = input_shape

    def construct(self, w, doutput):
        ms_out = self.conv(w, doutput, self.x_size)
        return ms_out


class Conv3DBackpropFilter(nn.Cell):
    """Conv3DBackpropFilter net definition"""

    def __init__(self, w_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format):
        super(Conv3DBackpropFilter, self).__init__()
        self.conv = G.Conv3DBackpropFilter(out_channel, kernel_size, pad_mode=pad_mode,
                                           pad=pad, mode=mode, stride=stride, dilation=dilation,
                                           group=group, data_format=data_format)
        self.w_size = w_shape

    def construct(self, x, doutput):
        ms_out = self.conv(x, doutput, self.w_size)
        return ms_out


class ApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(ApplyFtrlNet, self).__init__()
@@ -1180,6 +1226,24 @@ test_case_math_ops = [
        'block': Moments(axis=(), keep_dims=False),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
        'skip': ['backward']}),
    ('Conv3D', {
        'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid', pad=0,
                        stride=1, dilation=1, group=1, data_format="NCDHW"),
        'desc_inputs': [Tensor(np.random.random((16, 3, 10, 32, 32)).astype(np.float16)),
                        Tensor(np.random.random((32, 3, 4, 3, 3)).astype(np.float16))],
        'skip': ['backward']}),
    ('Conv3DBackpropInput', {
        'block': Conv3DBackpropInput(input_shape=(16, 32, 13, 37, 33), out_channel=32, kernel_size=(4, 6, 2), mode=1,
                                     pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
        'desc_inputs': [Tensor(np.random.random((32, 32, 4, 6, 2)).astype(np.float16)),
                        Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))],
        'skip': ['backward']}),
    ('Conv3DBackpropFilter', {
        'block': Conv3DBackpropFilter(w_shape=(32, 32, 4, 6, 2), out_channel=32, kernel_size=(4, 6, 2), mode=1,
                                      pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
        'desc_inputs': [Tensor(np.random.random((16, 32, 13, 37, 33)).astype(np.float16)),
                        Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))],
        'skip': ['backward']}),
    ('CountNonZero', {
        'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],