ImageGradients: check for 4d input

fix DiagPart constant folding issue

fix argmin output type check

fix atan2 doc error

remove FusedBatchNorm and its grad
zhaozhenlong 2020-04-28 10:49:36 +08:00
parent f748d02c05
commit 66e7a36846
12 changed files with 44 additions and 43 deletions

@@ -174,6 +174,8 @@ const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad")
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");

@@ -175,6 +175,8 @@ extern const PrimitivePtr kPrimTanhGrad;
extern const PrimitivePtr kPrimPooling;
extern const PrimitivePtr kPrimPoolingGrad;
extern const PrimitivePtr kPrimFusedBatchNorm;
extern const PrimitivePtr kPrimBatchNorm;
extern const PrimitivePtr kPrimBatchNormGrad;
extern const PrimitivePtr kPrimConv2D;
extern const PrimitivePtr kPrimMaxPool;
extern const PrimitivePtr kPrimMaxPoolGrad;

@@ -221,7 +221,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimAssign->name(), ADPT_DESC(Assign)},
{prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
{prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
{prim::kPrimFusedBatchNormGrad->name(), ADPT_DESC(FusedBatchNormGrad)},
{prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)},
{prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)},
{prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)},
@@ -229,7 +228,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)},
{prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)},
{prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)},
{prim::kPrimFusedBatchNorm->name(), ADPT_DESC(FusedBatchNorm, BatchNorm)},
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)},

@@ -703,28 +703,6 @@ INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}};
// FusedBatchNorm
INPUT_MAP(FusedBatchNorm) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(variance)}};
ATTR_MAP(FusedBatchNorm) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
{"momentum", ATTR_DESC(moving_average_fraction, AnyTraits<float>())},
{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}};
OUTPUT_MAP(FusedBatchNorm) = {{0, OUTPUT_DESC(y)},
{1, OUTPUT_DESC(running_mean)},
{2, OUTPUT_DESC(running_variance)},
{3, OUTPUT_DESC(save_mean)},
{4, OUTPUT_DESC(save_inv_variance)}};
// FusedBatchNormGrad
INPUT_MAP(FusedBatchNormGrad) = {{1, INPUT_DESC(dy)},
{2, INPUT_DESC(x)},
{3, INPUT_DESC(scale)},
{4, INPUT_DESC(save_mean)},
{5, INPUT_DESC(save_inv_variance)}};
ATTR_MAP(FusedBatchNormGrad) = {{"momentum", ATTR_DESC(momentum, AnyTraits<float>())},
{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}};
OUTPUT_MAP(FusedBatchNormGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(bn_scale)}, {2, OUTPUT_DESC(bn_bias)}};
// BiasAddGrad
INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};

@@ -82,10 +82,6 @@ DECLARE_OP_USE_OUTPUT(HcomAllGather)
DECLARE_OP_ADAPTER(Variable)
DECLARE_OP_ADAPTER(ReluGrad)
DECLARE_OP_USE_OUTPUT(ReluGrad)
DECLARE_OP_ADAPTER(FusedBatchNorm)
DECLARE_OP_USE_OUTPUT(FusedBatchNorm)
DECLARE_OP_ADAPTER(FusedBatchNormGrad)
DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad)
DECLARE_OP_ADAPTER(BiasAddGrad)
DECLARE_OP_USE_OUTPUT(BiasAddGrad)
DECLARE_OP_ADAPTER(MaxPoolWithArgmax)

@@ -58,6 +58,7 @@ class ImageGradients(Cell):
super(ImageGradients, self).__init__()
def construct(self, images):
_check_input_4d(F.shape(images), "images", self.cls_name)
batch_size, depth, height, width = P.Shape()(images)
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
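
For context on the check added above: construct unpacks the image shape as (batch, depth, height, width), so only a 4d NCHW tensor is meaningful here. A minimal numpy sketch of the same differencing (an illustration of the idea, not the MindSpore implementation):

import numpy as np

images = np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3)              # (N, C, H, W)
dy = images[:, :, 1:, :] - images[:, :, :-1, :]                          # vertical differences
dy = np.concatenate([dy, np.zeros((1, 1, 1, 3), np.float32)], axis=2)    # zero-pad the last row
dx = images[:, :, :, 1:] - images[:, :, :, :-1]                          # horizontal differences
dx = np.concatenate([dx, np.zeros((1, 1, 3, 1), np.float32)], axis=3)    # zero-pad the last column
# a 5d input such as shape (4, 1, 16, 16, 1) cannot be unpacked into
# (batch, depth, height, width), which is what _check_input_4d now rejects
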
@@ -151,8 +152,8 @@
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "SSIM")
_check_input_4d(F.shape(img2), "img2", "SSIM")
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
@@ -244,8 +245,8 @@
self.max_val = max_val
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", "PSNR")
_check_input_4d(F.shape(img2), "img2", "PSNR")
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
P.SameTypeShape()(img1, img2)
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
img1 = _convert_img_dtype_to_float32(img1, self.max_val)

@@ -1016,6 +1016,7 @@ class Argmin(PrimitiveWithInfer):
"""init Argmin"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name)
validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
self.axis = axis
self.add_prim_attr('output_type', output_type)
@@ -1726,7 +1727,9 @@ class Diag(PrimitiveWithInfer):
def infer_value(self, x):
if x is None:
return None
validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
# do constant-folding only when x rank is 1
if len(x.shape()) != 1:
return None
ret = np.diag(x.asnumpy())
return Tensor(ret)
@@ -1752,7 +1755,7 @@ class DiagPart(PrimitiveWithInfer):
>>> [0, 0, 3, 0],
>>> [0, 0, 0, 4]])
>>> diag_part = P.DiagPart()
>>> diag_part(x)
>>> diag_part(input_x)
[1, 2, 3, 4]
"""
@@ -1776,7 +1779,9 @@
def infer_value(self, x):
if x is None:
return None
validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
# do constant-folding only when x rank is 2
if len(x.shape()) != 2:
return None
ret = np.diag(x.asnumpy())
return Tensor(ret)
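
Both Diag and DiagPart now attempt constant folding only when the input already has the rank that np.diag expects, and return None otherwise so folding is silently skipped instead of raising during value inference. A short numpy sketch of the behaviour this relies on (illustration only):

import numpy as np

np.diag(np.array([1, 2, 3]))           # rank-1 input -> 3x3 diagonal matrix (what Diag folds)
np.diag(np.array([[1, 0], [0, 2]]))    # rank-2 input -> array([1, 2])       (what DiagPart folds)
# for any other rank np.diag raises, so infer_value returns None here and the
# rank is left to the regular shape checks (assumption based on the commit message)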

@@ -2037,7 +2037,7 @@ class Atan2(_MathBinaryOp):
r"""
Returns arctangent of input_x/input_y element-wise.
It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})`
It returns :math:`\theta\ \in\ [-\pi, \pi]`
such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
Inputs:
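
A quick sanity check of the corrected range, using Python's math.atan2 purely as a reference (the assumption being that P.Atan2 follows the usual atan2 convention): for a point in the third quadrant the angle lies outside (-pi/2, pi/2), which is what the old docstring got wrong.

import math

theta = math.atan2(-1.0, -1.0)                    # ~ -3*pi/4 ~ -2.356
assert -math.pi <= theta <= math.pi               # within the corrected range
assert not (-math.pi / 2 < theta < math.pi / 2)   # outside the previously documented range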


@@ -147,13 +147,13 @@ TEST_F(TestConvert, TestReluOps) {
}
TEST_F(TestConvert, TestConvertBatchNorm) {
PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
PrimitivePtr batch_norm = prim::kPrimBatchNorm;
batch_norm->AddAttr("epsilon", MakeValue(0.001f));
batch_norm->AddAttr("momentum", MakeValue(0.1f));
FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(fused_batch_norm));
inputs.push_back(NewValueNode(batch_norm));
for (unsigned int i = 0; i < 5; i++) {
inputs.push_back(anf_graph->add_parameter());
}

@@ -14,6 +14,7 @@
# ============================================================================
""" test image gradients """
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
@@ -47,3 +48,10 @@ def test_compile_multi_channel():
[[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype)
net = Net()
_executor.compile(net, image)
def test_invalid_5d_input():
dtype = mstype.float32
image = Tensor(np.random.random([4, 1, 16, 16, 1]), dtype=dtype)
net = Net()
with pytest.raises(ValueError):
_executor.compile(net, image)

@@ -14,16 +14,15 @@
# ============================================================================
""" test array ops """
import functools
import pytest
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import prim_attr_register
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer
from mindspore.common.dtype import get_py_obj_dtype
from mindspore._c_expression import signature_dtype as sig_dtype
from mindspore._c_expression import signature_rw as sig_rw
from mindspore._c_expression import signature_kind as sig_kind
@@ -96,6 +95,17 @@ def test_select():
expect = np.array([[1, 8, 9], [10, 5, 6]])
assert np.all(output.asnumpy() == expect)
def test_argmin_invalid_output_type():
P.Argmin(-1, mstype.int64)
P.Argmin(-1, mstype.int32)
with pytest.raises(TypeError):
P.Argmin(-1, mstype.float32)
with pytest.raises(TypeError):
P.Argmin(-1, mstype.float64)
with pytest.raises(TypeError):
P.Argmin(-1, mstype.uint8)
with pytest.raises(TypeError):
P.Argmin(-1, mstype.bool_)
class CustomOP(PrimitiveWithInfer):
__mindspore_signature__ = (sig_dtype.T, sig_dtype.T, sig_dtype.T1,

@@ -17,6 +17,7 @@ import functools
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore import Tensor