Develop op MaxPoolWithArgMax

This commit is contained in:
buxue 2020-04-02 11:58:45 +08:00
parent 22cc03a54a
commit 1d3bb0b731
39 changed files with 544 additions and 518 deletions

View File

@ -148,8 +148,6 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
} }
std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = { std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
{"MaxPoolWithArgmax", TbeAdapter::MaxPoolWithArgmaxAttrJsonPass},
{"MaxPoolGradWithArgmax", TbeAdapter::MaxPoolGradWithArgmaxAttrJsonPass},
{"Conv2D", TbeAdapter::Conv2DAttrJsonPass}, {"Conv2D", TbeAdapter::Conv2DAttrJsonPass},
{"Conv2DBackpropFilter", TbeAdapter::Conv2DBackpropFilterAttrJsonPass}, {"Conv2DBackpropFilter", TbeAdapter::Conv2DBackpropFilterAttrJsonPass},
{"Conv2DBackpropInput", TbeAdapter::Conv2DBackpropInputAttrJsonPass}, {"Conv2DBackpropInput", TbeAdapter::Conv2DBackpropInputAttrJsonPass},
@ -170,48 +168,6 @@ bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node,
return false; return false;
} }
// Builds the attribute JSON entries for a MaxPoolWithArgmax node.
//
// For every attribute described in `op_info_attrs`, one JSON object is appended to `attrs_json`:
//   - "name":  the attribute name taken from the op info;
//   - "valid": true when the node's primitive carries that attribute, false otherwise;
//   - "value": only present when valid. "pad_mode" is emitted as an upper-cased string; any other
//              attribute is read as a single int `d` and emitted as the 4-element list {1, d, d, 1}.
//
// anf_node      - node whose primitive supplies the attribute values; must not be null.
// op_info_attrs - attribute descriptors to emit, in output order; no element may be null.
// attrs_json    - output container that receives one object per descriptor; must not be null.
void TbeAdapter::MaxPoolWithArgmaxAttrJsonPass(
  const mindspore::AnfNodePtr &anf_node, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  const size_t attr_num = op_info_attrs.size();
  for (size_t i = 0; i < attr_num; ++i) {
    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
    const std::string attr_name = op_info_attrs[i]->name();
    nlohmann::json attr_obj;
    // Fetch the attribute once and reuse the result for both the presence test and the conversion.
    const auto value = primitive->GetAttr(attr_name);
    if (value == nullptr) {
      attr_obj["valid"] = false;
    } else {
      if (attr_name == "pad_mode") {
        std::string attr_value = GetValue<std::string>(value);
        // toupper must receive a value representable as unsigned char; passing a negative plain
        // char is undefined behaviour, so convert through unsigned char first.
        (void)std::transform(attr_value.begin(), attr_value.end(), attr_value.begin(),
                             [](unsigned char ch) { return static_cast<char>(::toupper(ch)); });
        attr_obj["value"] = attr_value;
      } else {
        // Scalar attribute expanded to a 4-element list with the value in the two middle slots.
        const int data = GetValue<int>(value);
        attr_obj["value"] = std::vector<int>{1, data, data, 1};
      }
      attr_obj["valid"] = true;
    }
    attr_obj["name"] = attr_name;
    attrs_json->push_back(attr_obj);
  }
}
// JSON attribute pass for MaxPoolGradWithArgmax.
// Forwards all three arguments unchanged to MaxPoolWithArgmaxAttrJsonPass, so the gradient op's
// attribute JSON is produced by exactly the same logic as the forward op's.
void TbeAdapter::MaxPoolGradWithArgmaxAttrJsonPass(
const mindspore::AnfNodePtr &anf_node, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json) {
MaxPoolWithArgmaxAttrJsonPass(anf_node, op_info_attrs, attrs_json);
}
void TbeAdapter::Conv2DAttrJsonPass(const mindspore::AnfNodePtr &anf_node, void TbeAdapter::Conv2DAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs, const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json) { nlohmann::json *attrs_json) {

View File

@ -161,6 +161,7 @@ const char kNameTopK[] = "TopK";
const char kNameSoftmaxGrad[] = "SoftmaxGrad"; const char kNameSoftmaxGrad[] = "SoftmaxGrad";
const char kNameMaxPool[] = "MaxPool"; const char kNameMaxPool[] = "MaxPool";
const char kNameAvgPool[] = "AvgPool"; const char kNameAvgPool[] = "AvgPool";
const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax";
const char kNameBatchNorm[] = "BatchNorm"; const char kNameBatchNorm[] = "BatchNorm";
const char kNameBatchNormGrad[] = "BatchNormGrad"; const char kNameBatchNormGrad[] = "BatchNormGrad";
const char kNameROIAlign[] = "ROIAlign"; const char kNameROIAlign[] = "ROIAlign";
@ -199,6 +200,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)}, {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)},
{string(kNameMaxPool), ADPT_DESC(MaxPool)}, {string(kNameMaxPool), ADPT_DESC(MaxPool)},
{string(kNameAvgPool), ADPT_DESC(AvgPool)}, {string(kNameAvgPool), ADPT_DESC(AvgPool)},
{string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
{string(kNameTopK), ADPT_DESC(TopK)}, {string(kNameTopK), ADPT_DESC(TopK)},
{string(kNamePack), ADPT_DESC(Pack)}, {string(kNamePack), ADPT_DESC(Pack)},
{string(kNameSplitD), ADPT_DESC(SplitD)}, {string(kNameSplitD), ADPT_DESC(SplitD)},

39
mindspore/ccsrc/transform/op_declare.cc Executable file → Normal file
View File

@ -192,8 +192,7 @@ ATTR_MAP(PRelu) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}};
// PReluGrad // PReluGrad
INPUT_MAP(PReluGrad) = { INPUT_MAP(PReluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}};
{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}};
ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}}; OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}};
@ -702,24 +701,30 @@ ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}}; OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}};
// MaxPoolWithArgmax // MaxPoolWithArgmax
INPUT_MAP(MaxPoolWithArgmax) = {{1, INPUT_DESC(x)}};
ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}};
// MaxPoolGradWithArgmax
INPUT_MAP(MaxPoolGradWithArgmax) = { INPUT_MAP(MaxPoolGradWithArgmax) = {
{1, INPUT_DESC(x)}, {1, INPUT_DESC(x)},
{2, INPUT_DESC(argmax)}, {2, INPUT_DESC(grad)},
{3, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)},
}; };
ATTR_MAP(MaxPoolGradWithArgmax) = {{"pad_mode", ATTR_DESC(padding, AnyTraits<std::string>())}, ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"window", ATTR_DESC(ksize, "window", AnyTraits<std::vector<int64_t>>())}, {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "stride", AnyTraits<std::vector<int64_t>>())}}; {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
// Conv2D // Conv2D
INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
ATTR_MAP(Conv2D) = { ATTR_MAP(Conv2D) = {{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, {"group", ATTR_DESC(groups, AnyTraits<int>())}};
{"group", ATTR_DESC(groups, AnyTraits<int>())}
};
OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}};
// Conv2DBackpropInputD // Conv2DBackpropInputD
@ -731,8 +736,7 @@ ATTR_MAP(Conv2DBackpropInputD) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int>())} {"group", ATTR_DESC(groups, AnyTraits<int>())}};
};
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
// Conv2DBackpropFilterD // Conv2DBackpropFilterD
@ -744,8 +748,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())}, {"dilation", ATTR_DESC(dilations, "pad", AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int>())} {"group", ATTR_DESC(groups, AnyTraits<int>())}};
};
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
// DepthwiseConv2D // DepthwiseConv2D

View File

@ -88,8 +88,10 @@ DECLARE_OP_ADAPTER(FusedBatchNormGrad)
DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad) DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad)
DECLARE_OP_ADAPTER(BiasAddGrad) DECLARE_OP_ADAPTER(BiasAddGrad)
DECLARE_OP_USE_OUTPUT(BiasAddGrad) DECLARE_OP_USE_OUTPUT(BiasAddGrad)
DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax)
DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax) DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax)
DECLARE_OP_USE_ENUM(MaxPoolGradWithArgmax) DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
DECLARE_OP_ADAPTER(Conv2D) DECLARE_OP_ADAPTER(Conv2D)
DECLARE_OP_USE_ENUM(Conv2D) DECLARE_OP_USE_ENUM(Conv2D)
DECLARE_OP_USE_OUTPUT(Conv2D) DECLARE_OP_USE_OUTPUT(Conv2D)

View File

@ -168,7 +168,7 @@ class ResNet(nn.Cell):
self.conv1 = _conv7x7(3, 64, stride=2) self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64) self.bn1 = _bn(64)
self.relu = P.ReLU() self.relu = P.ReLU()
self.maxpool = P.MaxPoolWithArgmax(pad_mode='same', window=3, stride=2) self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
self.layer1 = self._make_layer(block, self.layer1 = self._make_layer(block,
layer_nums[0], layer_nums[0],

View File

@ -13,33 +13,49 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""pooling""" """pooling"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ... import context
from ..cell import Cell from ..cell import Cell
class _PoolNd(Cell): class _PoolNd(Cell):
"""N-D AvgPool""" """N-D AvgPool"""
def __init__(self, def __init__(self, kernel_size, stride, pad_mode):
kernel_size, name = self.__class__.__name__
stride,
pad_mode,
padding=0,
pool=None):
super(_PoolNd, self).__init__() super(_PoolNd, self).__init__()
self.kernel_size = kernel_size validator.check_type('kernel_size', kernel_size, [int, tuple])
self.stride = stride validator.check_type('stride', stride, [int, tuple])
self.pad_mode = pad_mode self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
self.padding = validator.check_integer('padding', padding, 0, Rel.GE)
self.pool = pool
if self.pool is None:
raise NotImplementedError
def construct(self, x): if isinstance(kernel_size, int):
return self.pool(x) validator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
else:
if (len(kernel_size) != 2 or
(not isinstance(kernel_size[0], int)) or
(not isinstance(kernel_size[1], int)) or
kernel_size[0] <= 0 or
kernel_size[1] <= 0):
raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or'
f'a tuple of two positive int numbers, but got {kernel_size}')
self.kernel_size = kernel_size
if isinstance(stride, int):
validator.check_integer("stride", stride, 1, Rel.GE)
else:
if (len(stride) != 2 or
(not isinstance(stride[0], int)) or
(not isinstance(stride[1], int)) or
stride[0] <= 0 or
stride[1] <= 0):
raise ValueError(f'The stride passed to cell {name} should be an positive int number or'
f'a tuple of two positive int numbers, but got {stride}')
self.stride = stride
def construct(self, *inputs):
pass
def extend_repr(self): def extend_repr(self):
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__) return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
@ -63,19 +79,23 @@ class MaxPool2d(_PoolNd):
pad_mode for training only supports "same" and "valid". pad_mode for training only supports "same" and "valid".
Args: Args:
kernel_size (int): Size of the window to take a max over. Default 1. kernel_size (Union[int, tuple[int]]): The size of kernel used to take the max value,
stride (int): Stride size of the window. Default: 1. is an int number that represents height and width are both kernel_size,
pad_mode (str): Select the mode of the pad. The optional values are or a tuple of two int numbers that represent height and width respectively.
"same" and "valid". Default: "valid". Default: 1.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
- same: Adopts the way of completion. Output height and width will be the same as - same: Adopts the way of completion. Output height and width will be the same as
the input. Total number of padding will be calculated for horizontal and vertical the input. Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the direction and evenly distributed to top and bottom, left and right if possible.
last extra padding will be done from the bottom and the right side. Otherwise, the last extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return - valid: Adopts the way of discarding. The possibly largest height and width of output
without padding. Extra pixels will be discarded. will be return without padding. Extra pixels will be discarded.
padding (int): Implicit zero padding to be added on both sides. Default: 0.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -103,31 +123,22 @@ class MaxPool2d(_PoolNd):
[[7. 8.] [[7. 8.]
[8. 8.]]]] [8. 8.]]]]
""" """
def __init__(self,
kernel_size=1,
stride=1,
pad_mode="VALID",
padding=0):
max_pool = P.MaxPool(ksize=kernel_size,
strides=stride,
padding=pad_mode)
self.is_autodiff_backend = False
if self.is_autodiff_backend:
# At present, pad mode of max pool is not unified, so it is a temporarily avoided def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
pad_mode = validator.check_string('pad_mode', pad_mode.lower(), ['valid', 'same']) super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode)
self.max_pool = P.MaxPool(ksize=self.kernel_size,
max_pool = P.MaxPoolWithArgmax(window=kernel_size, strides=self.stride,
stride=stride, padding=self.pad_mode)
pad_mode=pad_mode, self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
pad=padding) strides=self.stride,
super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, padding, max_pool) padding=self.pad_mode)
self.is_tbe = context.get_context("device_target") == "Ascend"
def construct(self, x): def construct(self, x):
if self.is_autodiff_backend: if self.is_tbe and self.training:
out = self.pool(x)[0] out = self.max_pool_with_arg_max(x)[0]
else: else:
out = self.pool(x) out = self.max_pool(x)
return out return out
@ -149,19 +160,24 @@ class AvgPool2d(_PoolNd):
pad_mode for training only supports "same" and "valid". pad_mode for training only supports "same" and "valid".
Args: Args:
kernel_size (int): Size of the window to take a max over. Default: 1. kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value,
stride (int): Stride size of the window. Default: 1. is an int number that represents height and width are both kernel_size,
pad_mode (str): Select the mode of the pad. The optional values are or a tuple of two int numbers that represent height and width respectively.
"same", "valid". Default: "valid". Default: 1.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
- same: Adopts the way of completion. Output height and width will be the same as - same: Adopts the way of completion. Output height and width will be the same as
the input. Total number of padding will be calculated for horizontal and vertical the input. Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the direction and evenly distributed to top and bottom, left and right if possible.
last extra padding will be done from the bottom and the right side. Otherwise, the last extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possibly largest height and width of output
will be return without padding. Extra pixels will be discarded.
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return
without padding. Extra pixels will be discarded.
padding (int): Implicit zero padding to be added on both sides. Default: 0.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -170,7 +186,7 @@ class AvgPool2d(_PoolNd):
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples: Examples:
>>> pool = AvgPool2d(kernel_size=3, stride=1) >>> pool = AvgPool2d(kernel_size=3, stride=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
[[[[5. 5. 9. 9.] [[[[5. 5. 9. 9.]
[8. 4. 3. 0.] [8. 4. 3. 0.]
@ -189,12 +205,15 @@ class AvgPool2d(_PoolNd):
[[4.2222223 4.5555553] [[4.2222223 4.5555553]
[3.2222223 4.5555553]]]] [3.2222223 4.5555553]]]]
""" """
def __init__(self, def __init__(self,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
pad_mode="VALID", pad_mode="valid"):
padding=0): super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode)
avg_pool = P.AvgPool(ksize=kernel_size, self.avg_pool = P.AvgPool(ksize=self.kernel_size,
strides=stride, strides=self.stride,
padding=pad_mode) padding=self.pad_mode)
super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, padding, avg_pool)
def construct(self, x):
return self.avg_pool(x)

View File

@ -76,14 +76,9 @@ def get_bprop_depthwise_conv2d_native(self):
def get_bprop_max_pool_with_argmax(self): def get_bprop_max_pool_with_argmax(self):
"""Grad definition for `MaxPoolWithArgmax` operation.""" """Grad definition for `MaxPoolWithArgmax` operation."""
maxpool_grad = G.MaxPoolGradWithArgmax( maxpool_grad = G.MaxPoolGradWithArgmax(
pad_mode=self.pad_mode, ksize=self.ksize,
window=self.window, strides=self.strides,
pad=self.pad, padding=self.padding,)
stride=self.stride,
data_mode=self.data_mode,
ceil_mode=self.ceil_mode,
alpha=self.alpha,
beta=self.beta)
def bprop(x, out, dout): def bprop(x, out, dout):
dx = maxpool_grad(x, dout[0], out[1]) dx = maxpool_grad(x, dout[0], out[1])

View File

@ -28,19 +28,19 @@ from mindspore.ops.op_info_register import op_info_register
"partial_flag": true, "partial_flag": true,
"attr": [ "attr": [
{ {
"name": "window", "name": "ksize",
"param_type": "required", "param_type": "required",
"type": "listInt", "type": "listInt",
"value": "all" "value": "all"
}, },
{ {
"name": "stride", "name": "strides",
"param_type": "required", "param_type": "required",
"type": "listInt", "type": "listInt",
"value": "all" "value": "all"
}, },
{ {
"name": "pad_mode", "name": "padding",
"param_type": "required", "param_type": "required",
"type": "str", "type": "str",
"value": "all" "value": "all"

View File

@ -28,19 +28,19 @@ from mindspore.ops.op_info_register import op_info_register
"partial_flag": true, "partial_flag": true,
"attr": [ "attr": [
{ {
"name": "window", "name": "ksize",
"param_type": "required", "param_type": "required",
"type": "listInt", "type": "listInt",
"value": "all" "value": "all"
}, },
{ {
"name": "stride", "name": "strides",
"param_type": "required", "param_type": "required",
"type": "listInt", "type": "listInt",
"value": "all" "value": "all"
}, },
{ {
"name": "pad_mode", "name": "padding",
"param_type": "required", "param_type": "required",
"type": "str", "type": "str",
"value": "all" "value": "all"

View File

@ -15,7 +15,6 @@
"""Operators for gradients.""" """Operators for gradients."""
import math
from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@ -340,59 +339,60 @@ class _PoolGrad(PrimitiveWithInfer):
"""Gradients of the max/avg pool operation.""" """Gradients of the max/avg pool operation."""
@prim_attr_register @prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"): def __init__(self, ksize, strides, padding="VALID"):
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
self.ksize = ksize
self.strides = strides
self.padding = padding
self.ksize = validator.check_type('ksize', self.ksize, [int, tuple]) validator.check_type('ksize', ksize, [int, tuple])
self.strides = validator.check_type('strides', self.strides, [int, tuple]) validator.check_type('strides', strides, [int, tuple])
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
validator.check_type('padding', self.padding, [str])
self.padding = validator.check_string('padding', self.padding, ['VALID', 'SAME'])
self.add_prim_attr("padding", self.padding) self.add_prim_attr("padding", self.padding)
self.add_prim_attr('data_format', "NCHW") self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax:
self.add_prim_attr('data_format', "NCHW")
if isinstance(self.ksize, int): if isinstance(ksize, int):
self.pool_h = validator.check_integer("ksize", self.ksize, 1, Rel.GE) validator.check_integer("ksize", ksize, 1, Rel.GE)
self.pool_w = self.pool_h if self.is_maxpoolgradwithargmax:
self.add_prim_attr("ksize", (1, 1, self.ksize, self.ksize)) self.ksize = (1, ksize, ksize, 1)
elif isinstance(self.ksize, tuple): else:
if (len(self.ksize) != 2 and len(self.ksize) != 4): self.ksize = (1, 1, ksize, ksize)
raise ValueError('Attr \'ksize\' of \'Pool\' Op passed ' +
str(self.ksize)+', should be a int or a tuple of length 2 or 4.')
for ksize_val in self.ksize:
if (not isinstance(ksize_val, int)) or (ksize_val <= 0):
raise ValueError('Each value of attr \'ksize\' of \'MaxPool\' Op passed ' +
str(self.ksize)+', should be int and greater than 0.')
self.pool_h = self.ksize[-2]
self.pool_w = self.ksize[-1]
self.add_prim_attr("ksize", (1, 1, self.ksize[-2], self.ksize[-1]))
if isinstance(self.strides, int):
self.stride_h = validator.check_integer("strides", self.strides, 1, Rel.GE)
self.stride_w = self.stride_h
self.add_prim_attr("strides", (1, 1, self.strides, self.strides))
elif isinstance(self.strides, tuple):
if (len(self.strides) != 2 and len(self.strides) != 4):
raise ValueError('Attr \'strides\' of \'MaxPool\' Op passed ' +
str(self.strides)+', should be a int or a tuple of length 2 or 4.')
for stride_val in self.strides:
if (not isinstance(stride_val, int)) or (stride_val <= 0):
raise ValueError('Each value of attr \'strides\' of \'MaxPool\' Op passed ' +
str(self.strides)+', should be int and greater than 0.')
self.stride_h = self.strides[-2]
self.stride_w = self.strides[-1]
self.add_prim_attr("strides", (1, 1, self.strides[-2], self.strides[-1]))
if self.padding == "VALID":
self.pad = 0
elif self.padding == "SAME":
self.pad = math.floor((self.pool_h - 1) / 2)
else: else:
raise ValueError('The padding should be str and must be SAME or VALID,' ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number"
' but got {}.'.format(self.padding)) f"or a tuple of two or four positive int numbers, but got {ksize}")
if len(ksize) != 2 and len(ksize) != 4:
raise ksize_error
for ksize_val in ksize:
if not isinstance(ksize_val, int) or (ksize_val <= 0):
raise ksize_error
if len(ksize) == 2 and self.is_maxpoolgradwithargmax:
self.ksize = (1, ksize[0], ksize[1], 1)
elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax:
self.ksize = (1, 1, ksize[0], ksize[1])
else:
self.ksize = ksize
self.add_prim_attr("ksize", self.ksize)
if isinstance(strides, int):
validator.check_integer("strides", strides, 1, Rel.GE)
if self.is_maxpoolgradwithargmax:
self.strides = (1, strides, strides, 1)
else:
self.strides = (1, 1, strides, strides)
else:
strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number"
f"or a tuple of two or four positive int numbers, but got {strides}")
if len(strides) != 2 and len(strides) != 4:
raise strides_error
for strides_val in strides:
if not isinstance(strides_val, int) or (strides_val <= 0):
raise strides_error
if len(strides) == 2 and self.is_maxpoolgradwithargmax:
self.strides = (1, strides[0], strides[1], 1)
elif len(strides) == 2 and not self.is_maxpoolgradwithargmax:
self.strides = (1, 1, strides[0], strides[1])
else:
self.strides = strides
self.add_prim_attr("strides", self.strides)
class AvgPoolGrad(_PoolGrad): class AvgPoolGrad(_PoolGrad):
@ -451,28 +451,13 @@ class MaximumGrad(Primitive):
raise NotImplementedError raise NotImplementedError
class MaxPoolGradWithArgmax(PrimitiveWithInfer): class MaxPoolGradWithArgmax(_PoolGrad):
"""Computes the gradients of MaxPoolWithArgmax.""" """Computes the gradients of MaxPoolWithArgmax."""
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self, ksize=1, strides=1, padding="VALID",):
pad_mode="valid",
window=0,
pad=0,
stride=1,
data_mode=1,
ceil_mode=0,
alpha=1.0,
beta=0.0):
self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)
self.window = window
self.pool_h = self.pool_w = window
self.pad = pad
self.pad_mode = pad_mode
self.stride = stride
self.data_mode = data_mode
self.ceil_mode = ceil_mode
def infer_shape(self, x_shape, grad_shape, argmax_shape): def infer_shape(self, x_shape, grad_shape, argmax_shape):
if not grad_shape: if not grad_shape:

View File

@ -682,186 +682,83 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
return x_dtype return x_dtype
class MaxPoolWithArgmax(PrimitiveWithInfer):
    r"""
    Max pooling over the input Tensor that returns both the pooled values and an argmax output.

    The input is expected to be 4-D (`infer_shape` checks the rank) and pooling is applied over the
    last two dimensions. With kernel size :math:`ks = (h_{ker}, w_{ker})` and stride
    :math:`s = (s_0, s_1)`:

    .. math::
        \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)

    Args:
        pad_mode (str): One of "valid", "same" or "pad". Default: "valid".
        window (Union[int, tuple[int]]): Kernel size; must be positive. Default: 1.
            NOTE(review): a 2-tuple passes validation here, but the "same" padding computation and
            `infer_shape` both do integer arithmetic on `self.window`, so only an int works end to end.
        pad (int): Padding amount, only consulted when `pad_mode` is "pad". Default: 0.
        stride (int): Stride of the window; checked against 1 with `Rel.GE`. Default: 1.
        data_mode (int): Checked against 1 with `Rel.EQ`. Default: 1.
        ceil_mode (int): Checked against 0 with `Rel.EQ`. Default: 0.
        alpha (Union[int, float]): Type-checked only. Default: 1.0.
        beta (Union[int, float]): Type-checked only. Default: 0.0.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tuple of 2 Tensor, the maxpool result and where max values from.

        - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
        - **mask** (Tensor) - Max values' index represented by the mask.
    """

    @prim_attr_register
    def __init__(self,
                 pad_mode="valid",
                 window=1,
                 pad=0,
                 stride=1,
                 data_mode=1,
                 ceil_mode=0,
                 alpha=1.0,
                 beta=0.0):
        self.init_prim_io_names(inputs=['x'], outputs=['output', 'argmax'])
        self.window = validator.check_type('window', window, [int, tuple])

        # Reject a non-positive int, or a tuple that is not exactly two positive ints.
        invalid_int = isinstance(window, int) and window <= 0
        invalid_tuple = isinstance(window, tuple) and (len(window) != 2 or
                                                       (not isinstance(window[0], int)) or
                                                       (not isinstance(window[1], int)) or
                                                       window[0] <= 0 or window[1] <= 0)
        if invalid_int or invalid_tuple:
            raise ValueError('Attr \'window\' of \'MaxPoolWithArgmax\' Op passed '
                             + str(self.window)+', should be a int or tuple and greater than 0.')
        self.pool_h = self.pool_w = window

        # Derive the padding amount from the padding mode.
        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'])
        if self.pad_mode == "valid":
            self.pad = 0
        elif self.pad_mode == "same":
            self.pad = math.floor((self.window - 1) / 2)
        elif self.pad_mode == "pad":
            self.pad = validator.check_integer('pad', pad, 0, Rel.GE)

        self.data_mode = validator.check_integer('data_mode', data_mode, 1, Rel.EQ)
        self.ceil_mode = validator.check_integer('ceil_mode', ceil_mode, 0, Rel.EQ)
        self.stride = validator.check_integer('stride', stride, 1, Rel.GE)
        self.alpha = validator.check_type('alpha', alpha, [int, float])
        self.beta = validator.check_type('beta', beta, [int, float])
        # The argmax layout in infer_shape differs when GE is disabled and the target is Ascend.
        self.is_tbe = not context.get_context("enable_ge") and context.get_context("device_target") == "Ascend"

    def infer_shape(self, x_shape):
        """Return `(out_shape, argmax_shape)` for a 4-D input shape."""
        validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ)
        pad = self.pad
        h_input = x_shape[2]
        w_input = x_shape[3]
        h_out = math.floor((h_input + 2 * pad - (self.window - 1) - 1) / self.stride + 1)
        w_out = math.floor((w_input + 2 * pad - (self.window - 1) - 1) / self.stride + 1)
        out_shape = [x_shape[0], x_shape[1], h_out, w_out]
        if any(shape_value <= 0 for shape_value in out_shape):
            raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.")

        if not self.is_tbe:
            # Same list object is returned for both outputs, as before.
            argmax_shape = out_shape
            return out_shape, argmax_shape

        # TBE layout: first two dims follow the input, then window*window,
        # then ceil(h_out * w_out / 16) + 1.
        argmax_shape = [
            x_shape[0],
            x_shape[1],
            self.window * self.window,
            math.ceil(out_shape[2] * out_shape[3] / 16) + 1,
        ]
        return out_shape, argmax_shape

    def infer_dtype(self, x_dtype):
        """Return `(x_dtype, int32)` after checking the input is float16 or float32."""
        validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32))
        return x_dtype, mstype.int32
class _Pool(PrimitiveWithInfer): class _Pool(PrimitiveWithInfer):
r""" r"""
Performs max/avg pooling operation. Performs max/avg pooling operation.
Args: Args:
ksize (Union[int, tuple[int]]): The size of the window to take a max over, that should be a tuple ksize (Union[int, tuple[int]]): The size of the kernel, that should be a tuple
of two `int` for width and height. Default: 1. of two `int` for height and width. Default: 1.
stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for strides (Union[int, tuple[int]]): The stride of the window, that should be
width and height. Default: 1. a tuple of two `int` for height and width. Default: 1.
padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
""" """
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="valid"):
    """Validate the pooling arguments and register them as primitive attributes.

    Args:
        ksize (Union[int, tuple[int]]): Kernel size; a positive int, or a tuple of
            two positive ints (height, width). Default: 1.
        strides (Union[int, tuple[int]]): Stride; a positive int, or a tuple of
            two positive ints (height, width). Default: 1.
        padding (str): "same" or "valid", not case sensitive. Default: "valid".

    Raises:
        ValueError: If a tuple `ksize`/`strides` does not hold exactly two
            positive ints.

    The stored `self.ksize` / `self.strides` are 4-tuples: `(1, 1, h, w)` for
    every operator except "MaxPoolWithArgmax", which uses `(1, h, w, 1)`.
    """
    self.init_prim_io_names(inputs=['x'], outputs=['output'])
    validator.check_type('ksize', ksize, [int, tuple])
    validator.check_type('strides', strides, [int, tuple])
    # Check the type before calling `.upper()` so a non-str padding is rejected
    # by the validator instead of failing with an unrelated AttributeError.
    validator.check_type('padding', padding, [str])
    self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
    self.add_prim_attr("padding", self.padding)
    self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
    if not self.is_maxpoolwithargmax:
        self.add_prim_attr('data_format', "NCHW")

    def _to_4d(arg_name, value):
        # Normalise an int / 2-tuple argument to the 4-tuple layout of this operator.
        if isinstance(value, int):
            validator.check_integer(arg_name, value, 1, Rel.GE)
            height = width = value
        else:
            if (len(value) != 2 or
                    (not isinstance(value[0], int)) or
                    (not isinstance(value[1], int)) or
                    value[0] <= 0 or
                    value[1] <= 0):
                # Note the trailing space before the line break: without it the two
                # adjacent literals are glued together as "...number ora tuple...".
                raise ValueError(f"The '{arg_name}' passed to operator {self.name} should be a positive int "
                                 f"number or a tuple of two positive int numbers, but got {value}")
            height, width = value[0], value[1]
        if self.is_maxpoolwithargmax:
            return (1, height, width, 1)
        return (1, 1, height, width)

    self.ksize = _to_4d("ksize", ksize)
    self.add_prim_attr("ksize", self.ksize)

    self.strides = _to_4d("strides", strides)
    self.add_prim_attr("strides", self.strides)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ)
h_input = x_shape[2] batch, channel, input_h, input_w = x_shape
w_input = x_shape[3] if self.is_maxpoolwithargmax:
if self.padding == "VALID": _, kernel_h, kernel_w, _ = self.ksize
h_out = math.ceil((h_input - (self.pool_h - 1)) / self.stride_h) _, stride_h, stride_w, _ = self.strides
w_out = math.ceil((w_input - (self.pool_w - 1)) / self.stride_w)
elif self.padding == "SAME":
h_out = math.ceil(h_input / self.stride_h)
w_out = math.ceil(w_input / self.stride_w)
else: else:
raise ValueError('The padding should be str and must be SAME or VALID,' _, _, kernel_h, kernel_w = self.ksize
' but got {}.'.format(self.padding)) _, _, stride_h, stride_w = self.strides
if self.padding == "VALID":
out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
elif self.padding == "SAME":
out_h = math.ceil(input_h / stride_h)
out_w = math.ceil(input_w / stride_w)
else:
raise ValueError(f"The padding of operator {self.name} should be a str and must be 'SAME' or 'VALID', "
f"but got {self.padding}.")
out_shape = [batch, channel, out_h, out_w]
out_shape = [x_shape[0], x_shape[1], h_out, w_out]
for shape_value in out_shape: for shape_value in out_shape:
if shape_value <= 0: if shape_value <= 0:
raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.")
@ -887,11 +784,22 @@ class MaxPool(_Pool):
\text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
Args: Args:
ksize (Union[int, tuple[int]]): The size of the window to take a max over, that should be a tuple ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
of two `int` for width and height. Default: 1. is an int number that represents height and width are both ksize, or a tuple
stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for of two int numbers that represent height and width respectively. Default: 1.
width and height. Default: 1. strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
- same: Adopts the way of completion. Output height and width will be the same as
the input. Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible.
Otherwise, the last extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possibly largest height and width of output
will be return without padding. Extra pixels will be discarded.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -901,10 +809,83 @@ class MaxPool(_Pool):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"): def __init__(self, ksize=1, strides=1, padding="valid"):
super(MaxPool, self).__init__(ksize, strides, padding) super(MaxPool, self).__init__(ksize, strides, padding)
class MaxPoolWithArgmax(_Pool):
    r"""
    Max pooling that returns both the pooled values and an index mask.

    For an input of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})` the maximum is taken over
    windows in the :math:`(H_{in}, W_{in})` dimensions. With kernel size
    :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`:

    .. math::
        \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)

    Args:
        ksize (Union[int, tuple[int]]): Kernel size. An int is used for both height and width;
            a tuple of two ints gives height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): Kernel step. An int is used for both height and width;
            a tuple of two ints gives the height and width steps respectively. Default: 1.
        padding (str): Pad mode, "same" or "valid", not case sensitive. Default: "valid".

            - same: output height/width is :math:`\lceil in / stride \rceil`.

            - valid: output height/width is :math:`\lceil (in - (kernel - 1)) / stride \rceil`.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tuple of 2 Tensors.

        - **output** (Tensor) - Pooling result of shape :math:`(N, C, H_{out}, W_{out})`,
          same dtype as the input.
        - **mask** (Tensor) - Mask locating the maxima, dtype uint16. When the device target is
          "Ascend" its shape is
          :math:`(N, C, h_{ker} \times w_{ker}, \lceil H_{out} \times W_{out} / 16 \rceil + 1)`;
          otherwise it has the same shape as **output**.
    """

    def __init__(self, ksize=1, strides=1, padding="valid"):
        super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding)
        # The Ascend target uses a different mask layout (see infer_shape).
        self.is_tbe = context.get_context("device_target") == "Ascend"

    def infer_shape(self, x_shape):
        """Return ``(out_shape, argmax_shape)`` for the given 4-entry input shape."""
        out_shape = _Pool.infer_shape(self, x_shape)
        if not self.is_tbe:
            # Both outputs share the pooled shape (same list object).
            return out_shape, out_shape

        # self.ksize is stored as (1, kernel_h, kernel_w, 1) for this operator.
        window_area = self.ksize[1] * self.ksize[2]
        packed = math.ceil(out_shape[2] * out_shape[3] / 16) + 1
        argmax_shape = [x_shape[0], x_shape[1], window_area, packed]
        return out_shape, argmax_shape

    def infer_dtype(self, x_dtype):
        """Return ``(x_dtype, mstype.uint16)`` after checking x_dtype is float16 or float32."""
        validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32))
        return x_dtype, mstype.uint16
class AvgPool(_Pool): class AvgPool(_Pool):
r""" r"""
Average pooling operation. Average pooling operation.
@ -919,11 +900,22 @@ class AvgPool(_Pool):
\text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
Args: Args:
ksize (Union[int, tuple[int]]): The size of the window to take a average over, that should be a tuple ksize (Union[int, tuple[int]]): The size of kernel used to take the average value,
of two `int` for width and height. Default: 1. is an int number that represents height and width are both ksize, or a tuple
stride (Union[int, tuple[int]]): The stride of the window, that should be a tuple of two `int` for of two int numbers that represent height and width respectively. Default: 1.
width and height. Default: 1. strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
padding (str): The optional values for pad mode "SAME", "VALID". Default: "VALID". the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
- same: Adopts the way of completion. Output height and width will be the same as
the input. Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible.
Otherwise, the last extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possibly largest height and width of output
will be return without padding. Extra pixels will be discarded.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -933,7 +925,7 @@ class AvgPool(_Pool):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"): def __init__(self, ksize=1, strides=1, padding="valid"):
if context.get_context("device_target") == "GPU": if context.get_context("device_target") == "GPU":
self.target = "GPU" self.target = "GPU"
else: else:

View File

@ -103,7 +103,7 @@ class ResNet50(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='valid') self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
self.layer1 = self.MakeLayer( self.layer1 = self.MakeLayer(
block, 3, in_channels=64, out_channels=256, stride=1) block, 3, in_channels=64, out_channels=256, stride=1)

View File

@ -21,6 +21,7 @@ import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import Tensor from mindspore import Tensor
class LeNet(nn.Cell): class LeNet(nn.Cell):
def __init__(self): def __init__(self):
super(LeNet, self).__init__() super(LeNet, self).__init__()
@ -50,8 +51,10 @@ class LeNet(nn.Cell):
output = self.fc3(output) output = self.fc3(output)
return output return output
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def train(net, data, label): def train(net, data, label):
learning_rate = 0.01 learning_rate = 0.01
momentum = 0.9 momentum = 0.9
@ -67,11 +70,12 @@ def train(net, data, label):
print("+++++++++++++++++++++++++++") print("+++++++++++++++++++++++++++")
assert res assert res
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_lenet(): def test_lenet():
data = Tensor(np.ones([32, 1 ,32, 32]).astype(np.float32) * 0.01) data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32)) label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet() net = LeNet()
train(net, data, label) train(net, data, label)

View File

@ -38,7 +38,7 @@ class AlexNet(nn.Cell):
self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode="same") self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode="same")
self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode="same") self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode="same")
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2,pad_mode="valid",padding=0) self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc1 = nn.Dense(6*6*256, 4096) self.fc1 = nn.Dense(6*6*256, 4096)
self.fc2 = nn.Dense(4096, 4096) self.fc2 = nn.Dense(4096, 4096)

View File

@ -20,26 +20,29 @@ import numpy as np
import mindspore.context as context import mindspore.context as context
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
context.set_context(device_target="Ascend") context.set_context(device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", self.maxpool = P.MaxPoolWithArgmax(padding="same",
window=3, ksize=3,
stride=2) strides=2)
self.x = Parameter(initializer( self.x = Parameter(initializer(
'normal', [1, 64, 112, 112]), name='w') 'normal', [1, 64, 112, 112]), name='w')
self.add = P.TensorAdd() self.add = P.TensorAdd()
@ms_function @ms_function
def construct(self): def construct(self):
output = self.maxpool(self.x) output = self.maxpool(self.x)
return output[0] return output[0]
def test_net(): def test_net():
x = np.random.randn(1,64,112,112).astype(np.float32) x = np.random.randn(1, 64, 112, 112).astype(np.float32)
maxpool = Net() maxpool = Net()
output = maxpool() output = maxpool()
print("***********output output*********") print("***********output output*********")

View File

@ -37,9 +37,9 @@ class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", self.maxpool = P.MaxPoolWithArgmax(padding="same",
window=3, ksize=3,
stride=2) strides=2)
@ms_function @ms_function
def construct(self, x): def construct(self, x):

View File

@ -267,7 +267,7 @@ class ResNet(nn.Cell):
self.bn1 = bn_with_initialize(64) self.bn1 = bn_with_initialize(64)
self.relu = P.ReLU() self.relu = P.ReLU()
self.maxpool = P.MaxPoolWithArgmax(window=3, stride=2, pad_mode="same") self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
self.layer1 = MakeLayer0(block, layer_num[0], in_channels=64, out_channels=256, stride=1) self.layer1 = MakeLayer0(block, layer_num[0], in_channels=64, out_channels=256, stride=1)
self.layer2 = MakeLayer1(block, layer_num[1], in_channels=256, out_channels=512, stride=2) self.layer2 = MakeLayer1(block, layer_num[1], in_channels=256, out_channels=512, stride=2)

View File

@ -21,7 +21,7 @@ addn = P.AddN()
add = P.TensorAdd() add = P.TensorAdd()
sub = P.Sub() sub = P.Sub()
mul = P.Mul() mul = P.Mul()
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
four2five = Primitive('Four2Five') four2five = Primitive('Four2Five')
five2four = Primitive('Five2Four') five2four = Primitive('Five2Four')

View File

@ -17,7 +17,7 @@ from mindspore.ops import Primitive
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd() add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
transdata = Primitive("TransData") transdata = Primitive("TransData")

View File

@ -21,7 +21,7 @@ addn = P.AddN()
add = P.TensorAdd() add = P.TensorAdd()
sub = P.Sub() sub = P.Sub()
mul = P.Mul() mul = P.Mul()
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
cast = Primitive('Cast') cast = Primitive('Cast')

View File

@ -17,7 +17,7 @@ from mindspore.ops import Primitive
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd() add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
four2five = Primitive('Four2Five') four2five = Primitive('Four2Five')
five2four = Primitive('Five2Four') five2four = Primitive('Five2Four')

View File

@ -17,7 +17,7 @@ from mindspore.ops import Primitive
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd() add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
transdata = Primitive("TransData") transdata = Primitive("TransData")
Transpose = P.Transpose() Transpose = P.Transpose()

View File

@ -22,7 +22,7 @@ add = P.TensorAdd()
reshape = P.Reshape() reshape = P.Reshape()
cast = P.Cast() cast = P.Cast()
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2) max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
def test_addn_cast(x, y, z): def test_addn_cast(x, y, z):
sum = addn((x, y)) sum = addn((x, y))

View File

@ -107,7 +107,7 @@ class ResNet18(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='pad') self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self.MakeLayer( self.layer1 = self.MakeLayer(
block, 2, in_channels=64, out_channels=256, stride=1) block, 2, in_channels=64, out_channels=256, stride=1)
@ -176,7 +176,7 @@ class ResNet9(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='same') self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self.MakeLayer( self.layer1 = self.MakeLayer(
block, 1, in_channels=64, out_channels=256, stride=1) block, 1, in_channels=64, out_channels=256, stride=1)

View File

@ -189,7 +189,7 @@ class ResNet50(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, weight_init=weight_conv) self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, weight_init=weight_conv)
self.bn1 = bn_with_initialize(64) self.bn1 = bn_with_initialize(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.layer1 = MakeLayer3( self.layer1 = MakeLayer3(
block, in_channels=64, out_channels=256, stride=1) block, in_channels=64, out_channels=256, stride=1)

View File

@ -23,12 +23,10 @@ class MaxNet(nn.Cell):
"""MaxNet definition""" """MaxNet definition"""
def __init__(self, def __init__(self,
kernel_size, kernel_size,
stride=None, stride=None):
padding=0):
super(MaxNet, self).__init__() super(MaxNet, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size, self.maxpool = nn.MaxPool2d(kernel_size,
stride, stride)
padding=padding)
def construct(self, input_x): def construct(self, input_x):
return self.maxpool(input_x) return self.maxpool(input_x)

View File

@ -106,7 +106,7 @@ class ResNet18(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, pad_mode='pad') self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self.MakeLayer( self.layer1 = self.MakeLayer(
block, 2, in_channels=64, out_channels=256, stride=1) block, 2, in_channels=64, out_channels=256, stride=1)
@ -175,7 +175,7 @@ class ResNet9(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.layer1 = self.MakeLayer( self.layer1 = self.MakeLayer(
block, 1, in_channels=64, out_channels=256, stride=1) block, 1, in_channels=64, out_channels=256, stride=1)

View File

@ -87,7 +87,7 @@ class ConvNet(nn.Cell):
self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3) self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
self.bn1 = nn.BatchNorm2d(ConvNet.output_ch) self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="pad", padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc = nn.Dense( self.fc = nn.Dense(
int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)), int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)),

View File

@ -46,8 +46,7 @@ class MaxNet(nn.Cell):
padding=0): padding=0):
super(MaxNet, self).__init__() super(MaxNet, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size, self.maxpool = nn.MaxPool2d(kernel_size,
stride, stride)
padding=padding)
def construct(self, x): def construct(self, x):
return self.maxpool(x) return self.maxpool(x)

View File

@ -108,6 +108,7 @@ class ResidualBlock(nn.Cell):
class VirtualLossGrad(PrimitiveWithInfer): class VirtualLossGrad(PrimitiveWithInfer):
""" VirtualLossGrad definition """ """ VirtualLossGrad definition """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init VirtualLossGrad""" """init VirtualLossGrad"""
@ -124,6 +125,7 @@ class VirtualLossGrad(PrimitiveWithInfer):
class VirtualLoss(PrimitiveWithInfer): class VirtualLoss(PrimitiveWithInfer):
""" VirtualLoss definition """ """ VirtualLoss definition """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init VirtualLoss""" """init VirtualLoss"""
@ -138,6 +140,7 @@ class VirtualLoss(PrimitiveWithInfer):
# pylint: disable=unused-argument # pylint: disable=unused-argument
dx = loss_grad(x, out, dout) dx = loss_grad(x, out, dout)
return (dx,) return (dx,)
return bprop return bprop
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -149,6 +152,7 @@ class VirtualLoss(PrimitiveWithInfer):
class VirtualNetWithLoss(nn.Cell): class VirtualNetWithLoss(nn.Cell):
""" VirtualNetWithLoss definition """ """ VirtualNetWithLoss definition """
def __init__(self, network): def __init__(self, network):
super(VirtualNetWithLoss, self).__init__() super(VirtualNetWithLoss, self).__init__()
self.loss = VirtualLoss() self.loss = VirtualLoss()
@ -161,6 +165,7 @@ class VirtualNetWithLoss(nn.Cell):
class SoftMaxGrad(nn.Cell): class SoftMaxGrad(nn.Cell):
""" SoftMaxGrad definition """ """ SoftMaxGrad definition """
def __init__(self, network): def __init__(self, network):
super(SoftMaxGrad, self).__init__() super(SoftMaxGrad, self).__init__()
self.network = network self.network = network
@ -171,6 +176,7 @@ class SoftMaxGrad(nn.Cell):
class DropoutGrad(nn.Cell): class DropoutGrad(nn.Cell):
""" DropoutGrad definition """ """ DropoutGrad definition """
def __init__(self, network): def __init__(self, network):
super(DropoutGrad, self).__init__() super(DropoutGrad, self).__init__()
self.network = network self.network = network
@ -181,6 +187,7 @@ class DropoutGrad(nn.Cell):
class ScalarSummaryNet(nn.Cell): class ScalarSummaryNet(nn.Cell):
""" ScalarSummaryNet definition """ """ ScalarSummaryNet definition """
def __init__(self): def __init__(self):
super(ScalarSummaryNet, self).__init__() super(ScalarSummaryNet, self).__init__()
self.summary = P.ScalarSummary() self.summary = P.ScalarSummary()
@ -193,6 +200,7 @@ class ScalarSummaryNet(nn.Cell):
class FusedBatchNormGrad(nn.Cell): class FusedBatchNormGrad(nn.Cell):
""" FusedBatchNormGrad definition """ """ FusedBatchNormGrad definition """
def __init__(self, network): def __init__(self, network):
super(FusedBatchNormGrad, self).__init__() super(FusedBatchNormGrad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
@ -204,6 +212,7 @@ class FusedBatchNormGrad(nn.Cell):
class NetWithLoss(nn.Cell): class NetWithLoss(nn.Cell):
""" NetWithLoss definition """ """ NetWithLoss definition """
def __init__(self, network): def __init__(self, network):
super(NetWithLoss, self).__init__() super(NetWithLoss, self).__init__()
self.loss = P.SmoothL1Loss() self.loss = P.SmoothL1Loss()
@ -216,6 +225,7 @@ class NetWithLoss(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
""" GradWrap definition """ """ GradWrap definition """
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.network = network self.network = network
@ -227,6 +237,7 @@ class Grad(nn.Cell):
class BatchnormNet(nn.Cell): class BatchnormNet(nn.Cell):
""" BatchnormNet definition """ """ BatchnormNet definition """
def __init__(self): def __init__(self):
super(BatchnormNet, self).__init__() super(BatchnormNet, self).__init__()
self.conv1 = nn.Conv2d(3, 4, kernel_size=8, stride=2, pad_mode="pad", padding=3) self.conv1 = nn.Conv2d(3, 4, kernel_size=8, stride=2, pad_mode="pad", padding=3)
@ -247,6 +258,7 @@ class BatchnormNet(nn.Cell):
class NetWithLossClass(nn.Cell): class NetWithLossClass(nn.Cell):
""" NetWithLossClass definition """ """ NetWithLossClass definition """
def __init__(self, network): def __init__(self, network):
super(NetWithLossClass, self).__init__(auto_prefix=False) super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = nn.SoftmaxCrossEntropyWithLogits() self.loss = nn.SoftmaxCrossEntropyWithLogits()
@ -259,12 +271,13 @@ class NetWithLossClass(nn.Cell):
class BlockNet(nn.Cell): class BlockNet(nn.Cell):
""" BlockNet definition """ """ BlockNet definition """
def __init__(self): def __init__(self):
super(BlockNet, self).__init__() super(BlockNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3) self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.block_down_sample = ResidualBlock( self.block_down_sample = ResidualBlock(
64, 256, stride=1, down_sample=True 64, 256, stride=1, down_sample=True
) )
@ -281,6 +294,7 @@ class BlockNet(nn.Cell):
class Conv2dWithBiasNet(nn.Cell): class Conv2dWithBiasNet(nn.Cell):
""" Conv2dWithBiasNet definition """ """ Conv2dWithBiasNet definition """
def __init__(self): def __init__(self):
super(Conv2dWithBiasNet, self).__init__() super(Conv2dWithBiasNet, self).__init__()
self.conv = nn.Conv2d(3, 10, 1, bias_init='zeros') self.conv = nn.Conv2d(3, 10, 1, bias_init='zeros')
@ -292,6 +306,7 @@ class Conv2dWithBiasNet(nn.Cell):
class Conv2dNativeNet(nn.Cell): class Conv2dNativeNet(nn.Cell):
""" Conv2dNativeNet definition """ """ Conv2dNativeNet definition """
def __init__(self): def __init__(self):
super(Conv2dNativeNet, self).__init__() super(Conv2dNativeNet, self).__init__()
self.conv = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3)) self.conv = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3))
@ -309,9 +324,10 @@ class Conv2dNativeNet(nn.Cell):
class MakeRefKeyNet(nn.Cell): class MakeRefKeyNet(nn.Cell):
""" MakeRefKeyNet definition """ """ MakeRefKeyNet definition """
def __init__(self): def __init__(self):
super(MakeRefKeyNet, self).__init__() super(MakeRefKeyNet, self).__init__()
self.y= Parameter(Tensor([1.0], mindspore.float32), name="y") self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
def construct(self, x): def construct(self, x):
key = P.MakeRefKey("y")() key = P.MakeRefKey("y")()
@ -321,6 +337,7 @@ class MakeRefKeyNet(nn.Cell):
class StateNet(nn.Cell): class StateNet(nn.Cell):
""" StateTestTensor definition """ """ StateTestTensor definition """
def __init__(self): def __init__(self):
super(StateNet, self).__init__() super(StateNet, self).__init__()
weight = Tensor(np.ones([2, 1, 2, 2], np.float32)) weight = Tensor(np.ones([2, 1, 2, 2], np.float32))
@ -347,6 +364,24 @@ class ComparisonNet(nn.Cell):
return ret return ret
def test_max_pool_with_arg_max():
    """ test_max_pool_with_arg_max """

    class NetMaxPoolWithArgMax(nn.Cell):
        """ NetMaxPoolWithArgMax definition """

        def __init__(self):
            super(NetMaxPoolWithArgMax, self).__init__()
            self.max_pool_with_arg_max = P.MaxPoolWithArgmax(padding="valid", ksize=2, strides=1)

        def construct(self, x):
            return self.max_pool_with_arg_max(x)

    x = Tensor(np.ones([1, 1, 3, 3], np.float32))
    net = NetMaxPoolWithArgMax()
    # Run the cell in graph mode; save_graphs=True is kept from the original test.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    ret = net(x)
    # Smoke test only: the result is printed, not asserted against expected values.
    print(ret)
test_cases = [ test_cases = [
('SoftMaxGrad', { ('SoftMaxGrad', {
'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
@ -382,7 +417,7 @@ test_cases = [
'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 64, 4, 4], np.float32))], 'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 64, 4, 4], np.float32))],
}), }),
('Conv2dWithBiasGrad', { ('Conv2dWithBiasGrad', {
'block': Grad(NetWithLossClass(Conv2dWithBiasNet())), 'block': Grad(NetWithLossClass(Conv2dWithBiasNet())),
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 2560], np.float32))], 'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 2560], np.float32))],
}), }),
('Conv2dNativeGrad', { ('Conv2dNativeGrad', {
@ -407,114 +442,93 @@ test_cases = [
}), }),
] ]
test_cases_for_verify_exception = [ test_cases_for_verify_exception = [
('Conv2d_ValueError_1', { ('Conv2d_ValueError_1', {
'block': (lambda _ : P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('Conv2d_ValueError_2', { ('Conv2d_ValueError_2', {
'block': (lambda _ : P.Conv2D(3, 4, mode=-2), {'exception': ValueError}), 'block': (lambda _: P.Conv2D(3, 4, mode=-2), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('MaxPoolWithArgmax_ValueError_1', { ('MaxPoolWithArgmax_ValueError_1', {
'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode='sane'), {'exception': ValueError}), 'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('MaxPoolWithArgmax_ValueError_2', { ('MaxPoolWithArgmax_ValueError_2', {
'block': (lambda _ : P.MaxPoolWithArgmax(data_mode=2), {'exception': ValueError}), 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('MaxPoolWithArgmax_ValueError_3', { ('MaxPoolWithArgmax_ValueError_3', {
'block': (lambda _ : P.MaxPoolWithArgmax(ceil_mode=2), {'exception': ValueError}), 'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('MaxPoolWithArgmax_ValueError_4', { ('MaxPoolWithArgmax_ValueError_4', {
'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode="pad", pad=-1), {'exception': ValueError}), 'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_5', {
'block': (lambda _ : P.MaxPoolWithArgmax(pad_mode="pad", pad='1'), {'exception': ValueError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_6', {
'block': (lambda _ : P.MaxPoolWithArgmax(window='1'), {'exception': ValueError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_7', {
'block': (lambda _ : P.MaxPoolWithArgmax(window=-2), {'exception': ValueError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_8', {
'block': (lambda _ : P.MaxPoolWithArgmax(stride=-1), {'exception': ValueError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_9', {
'block': (lambda _ : P.MaxPoolWithArgmax(alpha='1'), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('FusedBatchNorm_ValueError_1', { ('FusedBatchNorm_ValueError_1', {
'block': (lambda _ : P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('FusedBatchNorm_ValueError_2', { ('FusedBatchNorm_ValueError_2', {
'block': (lambda _ : P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}), 'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('FusedBatchNorm_ValueError_3', { ('FusedBatchNorm_ValueError_3', {
'block': (lambda _ : P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}), 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('FusedBatchNorm_ValueError_4', { ('FusedBatchNorm_ValueError_4', {
'block': (lambda _ : P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}), 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('FusedBatchNorm_ValueError_5', { ('FusedBatchNorm_ValueError_5', {
'block': (lambda _ : P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}), 'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('Softmax_ValueError_1', { ('Softmax_ValueError_1', {
'block': (lambda _ : P.Softmax("1"), {'exception': ValueError}), 'block': (lambda _: P.Softmax("1"), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('Softmax_ValueError_2', { ('Softmax_ValueError_2', {
'block': (lambda _ : P.Softmax(1.1), {'exception': ValueError}), 'block': (lambda _: P.Softmax(1.1), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('Softmax_ValueError_3', { ('Softmax_ValueError_3', {
'block': (lambda _ : P.Softmax(axis="1"), {'exception': ValueError}), 'block': (lambda _: P.Softmax(axis="1"), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('DropoutGenMask_ValueError_1', { ('DropoutGenMask_ValueError_1', {
'block': (lambda _ : P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}), 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('DropoutGenMask_ValueError_2', { ('DropoutGenMask_ValueError_2', {
'block': (lambda _ : P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}), 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('DropoutGenMask_ValueError_3', { ('DropoutGenMask_ValueError_3', {
'block': (lambda _ : P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}), 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('DropoutGenMask_ValueError_4', { ('DropoutGenMask_ValueError_4', {
'block': (lambda _ : P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}), 'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}),
'desc_inputs': [0], 'desc_inputs': [0],
}), }),
('MaxPool2d_ValueError_1', { ('MaxPool2d_ValueError_1', {
'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid", padding=0), {'exception': ValueError}), 'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid"), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}), }),
('MaxPool2d_ValueError_2', { ('MaxPool2d_ValueError_2', {
'block': ( 'block': (
lambda _ : nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid", padding=0), lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
{'exception': ValueError}, {'exception': ValueError},
), ),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}), }),
('MaxPool2d_ValueError_3', { ('MaxPool2d_ValueError_3', {
'block': ( 'block': (
lambda _ : nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid", padding=0), lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
{'exception': ValueError}, {'exception': ValueError},
), ),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
@ -532,4 +546,3 @@ def test_compile():
@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    """Return the exception-verification case table for the decorator's pipeline config."""
    return test_cases_for_verify_exception

View File

@ -571,7 +571,7 @@ test_case_nn_ops = [
'desc_bprop': [[3, 4, 6, 6]], 'desc_bprop': [[3, 4, 6, 6]],
'skip': ['backward']}), 'skip': ['backward']}),
('MaxPoolWithArgmax', { ('MaxPoolWithArgmax', {
'block': P.MaxPoolWithArgmax(window=2, stride=2), 'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
'desc_inputs': [[128, 32, 32, 64]], 'desc_inputs': [[128, 32, 32, 64]],
'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}), 'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}),
('SoftmaxCrossEntropyWithLogits', { ('SoftmaxCrossEntropyWithLogits', {

View File

@ -160,16 +160,16 @@ test_case_check_ops = [
'block': nn.Dense(1, 6, has_bias=False, bias_init=Tensor(np.ones([6]).astype(np.float32))), 'block': nn.Dense(1, 6, has_bias=False, bias_init=Tensor(np.ones([6]).astype(np.float32))),
'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}), 'desc_inputs': [Tensor(np.ones(shape=[6, 1]).astype(np.float32))]}),
('MaxPool2d_1', { ('MaxPool2d_1', {
'block': nn.MaxPool2d(5, pad_mode='same', padding=0), 'block': nn.MaxPool2d(5, pad_mode='same'),
'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
('MaxPool2d_2', { ('MaxPool2d_2', {
'block': nn.MaxPool2d(5, pad_mode='valid', padding=0), 'block': nn.MaxPool2d(5, pad_mode='valid'),
'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
('AvgPool2d_1', { ('AvgPool2d_1', {
'block': nn.AvgPool2d(5, pad_mode='same', padding=0), 'block': nn.AvgPool2d(5, pad_mode='same'),
'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
('AvgPool2d_2', { ('AvgPool2d_2', {
'block': nn.AvgPool2d(5, pad_mode='valid', padding=0), 'block': nn.AvgPool2d(5, pad_mode='valid'),
'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}), 'desc_inputs': [Tensor(np.ones(shape=[5, 5, 8, 8]).astype(np.float32))]}),
('Conv2D_1', { ('Conv2D_1', {
'block': P.Conv2D(1, 6, pad_mode='same', pad=0), 'block': P.Conv2D(1, 6, pad_mode='same', pad=0),

View File

@ -42,12 +42,10 @@ def test_maxpool2d():
""" test_maxpool2d """ """ test_maxpool2d """
kernel_size = 3 kernel_size = 3
stride = 3 stride = 3
padding = 0
max_pool = nn.MaxPool2d(kernel_size, stride, padding=padding) max_pool = nn.MaxPool2d(kernel_size, stride)
assert max_pool.kernel_size == 3 assert max_pool.kernel_size == 3
assert max_pool.stride == 3 assert max_pool.stride == 3
assert max_pool.padding == 0
input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
output = max_pool(input_data) output = max_pool(input_data)
output_np = output.asnumpy() output_np = output.asnumpy()

View File

@ -89,7 +89,7 @@ class ConvNet(nn.Cell):
self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode='pad', padding=3) self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode='pad', padding=3)
self.bn1 = nn.BatchNorm2d(ConvNet.output_ch) self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc = nn.Dense( self.fc = nn.Dense(
int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)), int(ConvNet.image_h*ConvNet.image_w*ConvNet.output_ch/(4*4)),

View File

@ -49,23 +49,14 @@ def test_maxpool2d():
""" test_maxpool2d """ """ test_maxpool2d """
kernel_size = 3 kernel_size = 3
stride = 3 stride = 3
padding = 2
max_pool = nn.MaxPool2d(kernel_size, stride, pad_mode='SAME', padding=padding) max_pool = nn.MaxPool2d(kernel_size, stride, pad_mode='SAME')
assert max_pool.kernel_size == 3 assert max_pool.kernel_size == 3
assert max_pool.stride == 3 assert max_pool.stride == 3
assert max_pool.padding == 2
input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6])*0.1) input_data = Tensor(np.random.randint(0, 255, [1, 3, 6, 6])*0.1)
output = max_pool(input_data) output = max_pool(input_data)
output_np = output.asnumpy() output_np = output.asnumpy()
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64)) assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
def test_maxpool2d_error_padding():
""" test_maxpool2d_error_padding """
kernel_size = 3.5
stride = 3
padding = 1
with pytest.raises(ValueError):
nn.MaxPool2d(kernel_size, stride, padding=padding)

View File

@ -23,7 +23,7 @@ def test_avg_pooling():
[-9., -1., 3., 4.], [-9., -1., 3., 4.],
[1., -1., -3., -6.], [1., -1., -3., -6.],
[-2., -1., -2., -15.]]]]).astype(np.float32) [-2., -1., -2., -15.]]]]).astype(np.float32)
out = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) out = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1)
expect_out = [[[[-4.25, 0.0, 4.25], expect_out = [[[[-4.25, 0.0, 4.25],
[-2.5, -0.5, -0.5], [-2.5, -0.5, -0.5],
[-0.75, -1.75, -6.5]]]] [-0.75, -1.75, -6.5]]]]
@ -37,9 +37,9 @@ def test_avg_pool_grad():
[5, 6, 7, 8], [5, 6, 7, 8],
[9, 10, 11, 12], [9, 10, 11, 12],
[13, 14, 15, 16]]]]).astype(np.float32) [13, 14, 15, 16]]]]).astype(np.float32)
dout = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) dout = vm.avg_pooling(input_data, pool_h=2, pool_w=2, stride=1)
print("vm.avg_pooling dout: ", dout) print("vm.avg_pooling dout: ", dout)
out = vm.avg_pool_grad(dout, input_data.shape, 2, 2, 1, 0) out = vm.avg_pool_grad(dout, input_data.shape, 2, 2, 1)
print("vm.avg_pool_grad: ", out) print("vm.avg_pool_grad: ", out)
assert True assert True
@ -202,7 +202,7 @@ def test_max_pooling():
[-9., -1., 3., 4.], [-9., -1., 3., 4.],
[1., -1., -3., -6.], [1., -1., -3., -6.],
[-2., -1., -2., -15.]]]]).astype(np.float32) [-2., -1., -2., -15.]]]]).astype(np.float32)
out = vm.max_pooling(input_data, pool_h=2, pool_w=2, stride=1, pad=0) out = vm.max_pooling(input_data, pool_h=2, pool_w=2, stride=1)
expect_out = [[[[-1., 3., 9.], expect_out = [[[[-1., 3., 9.],
[1., 3., 4.], [1., 3., 4.],
[1., -1., -2.]]]] [1., -1., -2.]]]]

View File

@ -44,7 +44,7 @@ class Net(nn.Cell):
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros") self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc = nn.Dense(int(224*224*64/16), num_classes) self.fc = nn.Dense(int(224*224*64/16), num_classes)

View File

@ -19,66 +19,82 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm from .vm_interface import vm
# pylint: disable=unused-argument # pylint: disable=unused-argument
@vm_impl_getters.register(P.ScalarSummary)
def vm_impl_scalar_summary(self):
    """Create the vm-mode implementation registered for ScalarSummary."""

    def vm_impl(string_in, scalar):
        """Pass the scalar through unchanged; string_in is accepted but not used."""
        return scalar

    return vm_impl
@vm_impl_getters.register(P.ReLU)
def vm_impl_relu(self):
    """Create the vm-mode implementation registered for ReLU."""

    def vm_impl(x):
        # Convert to numpy, delegate to vm.relu, and wrap the result back in a Tensor.
        return Tensor(vm.relu(x.asnumpy()))

    return vm_impl
@vm_impl_getters.register(P.Flatten)
def vm_impl_flatten(self):
    """Create the vm-mode implementation registered for Flatten."""

    def vm_impl(x):
        # Delegate to vm.flatten_batch on the numpy view of the input.
        flattened = vm.flatten_batch(x.asnumpy())
        return Tensor(flattened)

    return vm_impl
@vm_impl_getters.register(P.Softmax)
def vm_impl_softmax(self):
    """Create the vm-mode implementation registered for Softmax."""

    def vm_impl(x):
        # Delegate to vm.softmax on the numpy view of the input.
        result = vm.softmax(x.asnumpy())
        return Tensor(result)

    return vm_impl
@vm_impl_getters.register(P.LogSoftmax)
def vm_impl_log_softmax(self):
    """Create the vm-mode implementation registered for LogSoftmax."""

    def vm_impl(x):
        # Delegate to vm.logsoftmax on the numpy view of the input.
        result = vm.logsoftmax(x.asnumpy())
        return Tensor(result)

    return vm_impl
@vm_impl_getters.register(P.Tanh)
def vm_impl_tanh(self):
    """Create the vm-mode implementation registered for Tanh."""

    def vm_impl(x):
        # Delegate to vm.tanh on the numpy view of the input.
        result = vm.tanh(x.asnumpy())
        return Tensor(result)

    return vm_impl
@vm_impl_getters.register(P.FusedBatchNorm) @vm_impl_getters.register(P.FusedBatchNorm)
def vm_impl_fused_batch_norm(self): def vm_impl_fused_batch_norm(self):
"""Generate vm_impl function for FusedBatchNorm""" """Generate vm_impl function for FusedBatchNorm"""
def vm_impl(x, scale, b, mean, variance): def vm_impl(x, scale, b, mean, variance):
# pylint: disable=unused-argument # pylint: disable=unused-argument
x = x.asnumpy() x = x.asnumpy()
@ -92,12 +108,14 @@ def vm_impl_fused_batch_norm(self):
momentum=self.momentum) momentum=self.momentum)
return Tensor(out), Tensor(x_mean), Tensor(x_var), \ return Tensor(out), Tensor(x_mean), Tensor(x_var), \
Tensor(running_mean), Tensor(running_var) Tensor(running_mean), Tensor(running_var)
return vm_impl return vm_impl
@vm_impl_getters.register(P.BatchNorm) @vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self): def vm_impl_batch_norm(self):
"""Generate vm_impl function for BatchNorm""" """Generate vm_impl function for BatchNorm"""
def vm_impl(x, scale, b, mean, variance): def vm_impl(x, scale, b, mean, variance):
# pylint: disable=unused-argument # pylint: disable=unused-argument
x = x.asnumpy() x = x.asnumpy()
@ -110,83 +128,106 @@ def vm_impl_batch_norm(self):
eps=self.epsilon) eps=self.epsilon)
return Tensor(out), Tensor(x_mean), Tensor(x_var), \ return Tensor(out), Tensor(x_mean), Tensor(x_var), \
Tensor(running_mean), Tensor(running_var) Tensor(running_mean), Tensor(running_var)
return vm_impl return vm_impl
@vm_impl_getters.register(P.Conv2D)
def vm_impl_conv2d(self):
    """Create the vm-mode implementation registered for Conv2D."""

    def vm_impl(x, w):
        inputs = x.asnumpy()
        kernel = w.asnumpy()
        # This implementation receives no bias input, so None is forwarded to vm.conv2d.
        conv_out = vm.conv2d(inputs, kernel, None, self.stride, self.pad, self.dilation)
        return Tensor(conv_out)

    return vm_impl
@vm_impl_getters.register(G.MaxPoolGradWithArgmax)
def vm_impl_max_pool_grad_with_argmax(self):
    """Generate vm_impl function for MaxPoolGradWithArgmax"""

    def vm_impl(x, dout, argmax):
        """Compute the input gradient by delegating to vm.max_pool_grad_with_argmax.

        Args:
            x: forward input tensor.
            dout: gradient tensor flowing back from the pooled output.
            argmax: argmax tensor produced by the forward MaxPoolWithArgmax.

        Returns:
            Tensor wrapping the array returned by vm.max_pool_grad_with_argmax.
        """
        x = x.asnumpy()
        dout = dout.asnumpy()
        arg_max = argmax.asnumpy()
        # ksize[1]/ksize[2] and strides[1] are read the same way as in the forward
        # MaxPoolWithArgmax vm implementation (window height, window width, stride).
        dx = vm.max_pool_grad_with_argmax(x, dout, arg_max,
                                          self.ksize[1], self.ksize[2], self.strides[1])
        return Tensor(dx)

    return vm_impl
@vm_impl_getters.register(P.MaxPoolWithArgmax)
def vm_impl_max_pool_with_argmax(self):
    """Create the vm-mode implementation registered for MaxPoolWithArgmax."""

    def vm_impl(x):
        data = x.asnumpy()
        # The window size comes from ksize[1]/ksize[2], the stride from strides[1].
        pool_h, pool_w = self.ksize[1], self.ksize[2]
        pooled, indices = vm.max_pool_with_argmax(data, pool_h, pool_w, self.strides[1])
        return Tensor(pooled), Tensor(indices)

    return vm_impl
@vm_impl_getters.register(P.MaxPool)
def vm_impl_max_pool(self):
    """Create the vm-mode implementation registered for MaxPool."""

    def vm_impl(x):
        data = x.asnumpy()
        # The window size is the last two ksize entries; the stride is strides[-2].
        pool_h, pool_w = self.ksize[-2], self.ksize[-1]
        pooled = vm.max_pooling(data, pool_h, pool_w, self.strides[-2])
        return Tensor(pooled)

    return vm_impl
@vm_impl_getters.register(G.MaxPoolGrad)
def vm_impl_max_pool_grad(self):
    """Create the vm-mode implementation registered for MaxPoolGrad."""

    def vm_impl(x, out, dout):
        # The incoming `out` argument is not read; only x and dout feed the computation.
        data = x.asnumpy()
        grad_out = dout.asnumpy()
        pool_h, pool_w = self.ksize[-2], self.ksize[-1]
        grad_in = vm.max_pool_grad(data, grad_out, pool_h, pool_w, self.strides[-2])
        return Tensor(grad_in)

    return vm_impl
@vm_impl_getters.register(P.AvgPool)
def vm_impl_avg_pool(self):
    """Create the vm-mode implementation registered for AvgPool."""

    def vm_impl(x):
        data = x.asnumpy()
        # The window size is the last two ksize entries; the stride is strides[-2].
        pool_h, pool_w = self.ksize[-2], self.ksize[-1]
        pooled = vm.avg_pooling(data, pool_h, pool_w, self.strides[-2])
        return Tensor(pooled)

    return vm_impl
@vm_impl_getters.register(G.AvgPoolGrad)
def vm_impl_avg_pool_grad(self):
    """Create the vm-mode implementation registered for AvgPoolGrad."""

    def vm_impl(dout, origin_shape):
        grad_out = dout.asnumpy()
        # origin_shape is forwarded to vm.avg_pool_grad unchanged.
        pool_h, pool_w = self.ksize[-2], self.ksize[-1]
        grad_in = vm.avg_pool_grad(grad_out, origin_shape, pool_h, pool_w, self.strides[-2])
        return Tensor(grad_in)

    return vm_impl
@vm_impl_getters.register(G.FusedBatchNormGrad) @vm_impl_getters.register(G.FusedBatchNormGrad)
def vm_impl_fused_batch_norm_grad(self): def vm_impl_fused_batch_norm_grad(self):
"""Generate vm_impl function for FusedBatchNormGrad""" """Generate vm_impl function for FusedBatchNormGrad"""
def vm_impl(dy, x, scale, save_mean, save_inv_variance): def vm_impl(dy, x, scale, save_mean, save_inv_variance):
dy = dy.asnumpy() dy = dy.asnumpy()
x = x.asnumpy() x = x.asnumpy()
@ -195,11 +236,14 @@ def vm_impl_fused_batch_norm_grad(self):
save_inv_variance = save_inv_variance.asnumpy() save_inv_variance = save_inv_variance.asnumpy()
dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance) dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
return (Tensor(dx), Tensor(dscale), Tensor(dshift)) return (Tensor(dx), Tensor(dscale), Tensor(dshift))
return vm_impl return vm_impl
@vm_impl_getters.register(G.BatchNormGrad) @vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_fused_batch_norm_grad(self): def vm_impl_fused_batch_norm_grad(self):
"""Generate vm_impl function for BatchNormGrad""" """Generate vm_impl function for BatchNormGrad"""
def vm_impl(dy, x, scale, save_mean, save_inv_variance): def vm_impl(dy, x, scale, save_mean, save_inv_variance):
dy = dy.asnumpy() dy = dy.asnumpy()
x = x.asnumpy() x = x.asnumpy()
@ -208,104 +252,123 @@ def vm_impl_fused_batch_norm_grad(self):
save_inv_variance = save_inv_variance.asnumpy() save_inv_variance = save_inv_variance.asnumpy()
dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance) dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
return (Tensor(dx), Tensor(dscale), Tensor(dshift)) return (Tensor(dx), Tensor(dscale), Tensor(dshift))
return vm_impl return vm_impl
@vm_impl_getters.register(G.ReluGrad)
def vm_impl_relu_grad(self):
    """Create the vm-mode implementation registered for ReluGrad."""

    def vm_impl(y_backprop, x):
        forward_in = x.asnumpy()
        upstream = y_backprop.asnumpy()
        # vm.relu_grad is given a copy of the forward input, as in the original code.
        mask = vm.relu_grad(forward_in.copy())
        return Tensor(mask * upstream)

    return vm_impl
@vm_impl_getters.register(P.Conv2DBackpropInput)
def vm_impl_conv2d_backprop_input(self):
    """Create the vm-mode implementation registered for Conv2DBackpropInput."""

    def vm_impl(dout, w, x_size):
        grad_out = dout.asnumpy()
        kernel = w.asnumpy()
        # x_size is forwarded to vm.conv2d_backprop_input unchanged.
        grad_in = vm.conv2d_backprop_input(grad_out, x_size, kernel, self.stride, self.pad)
        return Tensor(grad_in)

    return vm_impl
@vm_impl_getters.register(G.Conv2DBackpropFilter)
def vm_impl_conv2d_backprop_filter(self):
    """Create the vm-mode implementation registered for Conv2DBackpropFilter."""

    def vm_impl(dout, x, w_size):
        inputs = x.asnumpy()
        grad_out = dout.asnumpy()
        # w_size is forwarded to vm.conv2d_backprop_filter unchanged.
        grad_w = vm.conv2d_backprop_filter(grad_out, inputs, w_size, self.stride, self.pad)
        return Tensor(grad_w)

    return vm_impl
@vm_impl_getters.register(G.FlattenGrad)
def vm_impl_flatten_grad(self):
    """Create the vm-mode implementation registered for FlattenGrad."""

    def vm_impl(dout, x):
        # x is passed to vm.flatten_grad as received (it is not converted to numpy).
        grad = vm.flatten_grad(dout.asnumpy(), x)
        return Tensor(grad)

    return vm_impl
@vm_impl_getters.register(P.BiasAdd)
def vm_impl_bias_add(self):
    """Create the vm-mode implementation registered for BiasAdd."""

    def vm_impl(wx, bias):
        # Plain numpy addition of the two inputs; numpy broadcasting rules apply.
        lhs = wx.asnumpy()
        rhs = bias.asnumpy()
        return Tensor(lhs + rhs)

    return vm_impl
@vm_impl_getters.register(G.BiasAddGrad)
def vm_impl_bias_add_grad(self):
    """Create the vm-mode implementation registered for BiasAddGrad."""

    def vm_impl(dout):
        grad = dout.asnumpy()
        # Sum over every axis except the last one.
        leading_axes = tuple(range(len(np.shape(grad)) - 1))
        return Tensor(np.add.reduce(grad, axis=leading_axes))

    return vm_impl
@vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits)
def vm_impl_softmax_cross_entropy_with_logits(self):
    """Create the vm-mode implementation registered for SoftmaxCrossEntropyWithLogits."""

    def vm_impl(logits, labels):
        logits_np = logits.asnumpy()
        labels_np = labels.asnumpy()
        loss, dx = vm.softmax_cross_entropy_with_logits(logits_np, labels_np)
        # The loss is wrapped with np.array before building the Tensor.
        loss_tensor = Tensor(np.array(loss))
        return (loss_tensor, Tensor(dx))

    return vm_impl
@vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def vm_impl_sparse_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        """Expand sparse labels to one-hot and delegate to the dense cross-entropy kernel.

        Args:
            logits: logits tensor.
            labels: tensor of integer class indices, one per sample
                (assumed 1-D -- TODO confirm against callers).

        Returns:
            A 1-tuple holding the gradient tensor when self.is_grad is set,
            otherwise a 1-tuple holding the loss tensor.
        """
        logits = logits.asnumpy()
        labels = labels.asnumpy()

        # NOTE(review): the class count is inferred from the largest label in the batch,
        # not from the logits shape; confirm this matches the logits' class dimension.
        n_class = labels.max() + 1
        n_sample = labels.shape[0]
        # e.g. 3 samples and 4 classes give a (3, 4) one-hot matrix
        one_hot_label = np.zeros((n_sample, n_class))
        # Set exactly one entry per row: row i gets a 1 in column labels[i].
        # (The previous `[:, labels] = 1` set those columns in every row.)
        one_hot_label[np.arange(n_sample), labels] = 1
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label)
        if self.is_grad:
            return (Tensor(dx),)
        return (Tensor(np.array(loss)),)

    return vm_impl
@vm_impl_getters.register(P.ApplyMomentum) @vm_impl_getters.register(P.ApplyMomentum)
def vm_impl_momentum(self): def vm_impl_momentum(self):
"""Generate vm_impl function for Momentum""" """Generate vm_impl function for Momentum"""
def vm_impl(variable, def vm_impl(variable,
accumulation, accumulation,
learning_rate, learning_rate,
@ -327,19 +390,24 @@ def vm_impl_momentum(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P.ResizeBilinear)
def vm_impl_resize_bilinear(self):
    """Build the vm_impl function for ResizeBilinear.

    Returns:
        function, takes `x` and returns the result of `vm.ResizeBilinear(x)`
        wrapped in a Tensor.
    """

    def vm_impl(x):
        return Tensor(vm.ResizeBilinear(x))

    return vm_impl
@vm_impl_getters.register(G.ResizeBilinearGrad)
def vm_impl_resize_bilinear_grad(self):
    """Build the vm_impl function for ResizeBilinearGrad.

    Returns:
        function, takes `dout` and `original_image` and returns the result of
        `vm.ResizeBilinearGrad(dout, original_image)` wrapped in a Tensor.
    """

    def vm_impl(dout, original_image):
        return Tensor(vm.ResizeBilinearGrad(dout, original_image))

    return vm_impl

View File

@ -19,7 +19,7 @@ from mindspore._checkparam import Rel
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import ParamValidator as validator
def avg_pooling(x, pool_h, pool_w, stride, pad): def avg_pooling(x, pool_h, pool_w, stride):
""" """
Applies average pooling over an input array. Applies average pooling over an input array.
@ -28,26 +28,25 @@ def avg_pooling(x, pool_h, pool_w, stride, pad):
pool_h (int): Height of the pooling window. pool_h (int): Height of the pooling window.
pool_w (int): Width of the pooling window. pool_w (int): Width of the pooling window.
stride (int): The stride of the sliding window. stride (int): The stride of the sliding window.
pad (int): Padding to be added on height and width.
Returns: Returns:
numpy.ndarray, an output array after applying average pooling on input array. numpy.ndarray, an output array after applying average pooling on input array.
""" """
validator.check_integer("stride", stride, 0, Rel.GT) validator.check_integer("stride", stride, 0, Rel.GT)
num, channel, height, width = x.shape num, channel, height, width = x.shape
out_h = (height + 2*pad - pool_h)//stride + 1 out_h = (height - pool_h)//stride + 1
out_w = (width + 2*pad - pool_w)//stride + 1 out_w = (width - pool_w)//stride + 1
col = im2col(x, pool_h, pool_w, stride, pad) col = im2col(x, pool_h, pool_w, stride)
col = col.reshape(-1, pool_h*pool_w) col = col.reshape(-1, pool_h*pool_w)
out = np.mean(col, axis=1) out = np.mean(col, axis=1)
out = out.reshape(num, out_h, out_w, channel).transpose(0, 3, 1, 2) out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
return out return out
def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride, pad): def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride):
""" """
Gets grad of average pooling. Gets grad of average pooling.
@ -57,7 +56,6 @@ def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride, pad):
pool_h (int): Height of the pooling window. pool_h (int): Height of the pooling window.
pool_w (int): Width of the pooling window. pool_w (int): Width of the pooling window.
stride (int): The stride of the sliding window. stride (int): The stride of the sliding window.
pad (int): Padding to be added on height and width.
Returns: Returns:
numpy.ndarray, grad of average pooling. numpy.ndarray, grad of average pooling.
@ -324,38 +322,38 @@ def matmul(x, w, b=None):
return y return y
def max_pooling(x, pool_h, pool_w, stride):
    """
    Applies max pooling over an input array.

    Args:
        x (numpy.ndarray): Input array; its shape is unpacked as
            (num, channel, height, width).
        pool_h (int): Height of the pooling window.
        pool_w (int): Width of the pooling window.
        stride (int): The stride of the sliding window, checked to be greater than 0.

    Returns:
        numpy.ndarray, the window maxima arranged as (num, channel, out_h, out_w).
    """
    validator.check_integer("stride", stride, 0, Rel.GT)
    batch, depth, in_h, in_w = x.shape
    rows = (in_h - pool_h) // stride + 1
    cols = (in_w - pool_w) // stride + 1

    # One row per pooling window, one column per element of the window.
    windows = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
    pooled = np.max(windows, axis=1)
    return pooled.reshape((batch, rows, cols, depth)).transpose(0, 3, 1, 2)
def max_pool_grad(x, dout, pool_h, pool_w, stride):
    """
    Gets grad of max pooling.

    Args:
        x (numpy.ndarray): Input array that was pooled in the forward pass.
        dout (numpy.ndarray): Gradient flowing back from the pooled output; it
            is transposed with axes (0, 2, 3, 1) before use.
        pool_h (int): Height of the pooling window.
        pool_w (int): Width of the pooling window.
        stride (int): The stride of the sliding window.

    Returns:
        The value returned by `col2im` for the scattered gradient, built for `x.shape`.
    """
    grad_out = dout.transpose(0, 2, 3, 1)
    window = pool_h * pool_w

    # Recompute which element of every window held the maximum.
    patches = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
    winners = np.argmax(patches, axis=1)

    # Route each output gradient to the winning position of its window.
    scattered = np.zeros((grad_out.size, window))
    scattered[np.arange(winners.size), winners.flatten()] = grad_out.flatten()
    scattered = scattered.reshape(grad_out.shape + (window,))

    lead = scattered.shape[0] * scattered.shape[1] * scattered.shape[2]
    grad_cols = scattered.reshape(lead, -1)
    return col2im(grad_cols, x.shape, pool_h, pool_w, stride)
def max_pool_grad_with_argmax(x, arg_max, dout, pool_h, pool_w, stride, pad): def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
"""Grad of max pooling with argmax.""" """Grad of max pooling with argmax."""
dout = dout.transpose(0, 2, 3, 1) dout = dout.transpose(0, 2, 3, 1)
pool_size = pool_h * pool_w pool_size = pool_h * pool_w
@ -363,22 +361,22 @@ def max_pool_grad_with_argmax(x, arg_max, dout, pool_h, pool_w, stride, pad):
dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
dmax = dmax.reshape(dout.shape + (pool_size,)) dmax = dmax.reshape(dout.shape + (pool_size,))
dcol = dmax.reshape(dmax.shape[0]*dmax.shape[1]*dmax.shape[2], -1) dcol = dmax.reshape(dmax.shape[0]*dmax.shape[1]*dmax.shape[2], -1)
dx = col2im(dcol, x.shape, pool_h, pool_w, stride, pad) dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
return dx return dx
def max_pool_with_argmax(x, pool_h, pool_w, stride):
    """
    Applies max pooling and also reports where each maximum was found.

    Args:
        x (numpy.ndarray): Input array; its shape is unpacked as
            (num, channel, height, width).
        pool_h (int): Height of the pooling window.
        pool_w (int): Width of the pooling window.
        stride (int): The stride of the sliding window, checked to be greater than 0.

    Returns:
        tuple of two numpy.ndarray, the window maxima and the index of each
        maximum inside its flattened window, both arranged as
        (num, channel, out_h, out_w).
    """
    validator.check_integer("stride", stride, 0, Rel.GT)
    batch, depth, in_h, in_w = x.shape
    rows = (in_h - pool_h) // stride + 1
    cols = (in_w - pool_w) // stride + 1
    nhwc_shape = (batch, rows, cols, depth)

    # One row per pooling window, one column per element of the window.
    windows = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
    pooled = np.max(windows, axis=1)
    winners = np.argmax(windows, axis=1)

    pooled = pooled.reshape(nhwc_shape).transpose(0, 3, 1, 2)
    winners = winners.reshape(nhwc_shape).transpose(0, 3, 1, 2)
    return pooled, winners