forked from mindspore-Ecosystem/mindspore
ImageGradients check 4d
fix DiagPart constant folding issue fix argmin output type check fix atan2 doc error fix remove FusedBatchNorm and its grad
This commit is contained in:
parent
f748d02c05
commit
66e7a36846
|
@ -174,6 +174,8 @@ const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad")
|
|||
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
|
||||
const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
|
||||
const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
|
||||
const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
|
||||
|
|
|
@ -175,6 +175,8 @@ extern const PrimitivePtr kPrimTanhGrad;
|
|||
extern const PrimitivePtr kPrimPooling;
|
||||
extern const PrimitivePtr kPrimPoolingGrad;
|
||||
extern const PrimitivePtr kPrimFusedBatchNorm;
|
||||
extern const PrimitivePtr kPrimBatchNorm;
|
||||
extern const PrimitivePtr kPrimBatchNormGrad;
|
||||
extern const PrimitivePtr kPrimConv2D;
|
||||
extern const PrimitivePtr kPrimMaxPool;
|
||||
extern const PrimitivePtr kPrimMaxPoolGrad;
|
||||
|
|
|
@ -221,7 +221,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{prim::kPrimAssign->name(), ADPT_DESC(Assign)},
|
||||
{prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
|
||||
{prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
|
||||
{prim::kPrimFusedBatchNormGrad->name(), ADPT_DESC(FusedBatchNormGrad)},
|
||||
{prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)},
|
||||
{prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)},
|
||||
{prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)},
|
||||
|
@ -229,7 +228,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)},
|
||||
{prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)},
|
||||
{prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)},
|
||||
{prim::kPrimFusedBatchNorm->name(), ADPT_DESC(FusedBatchNorm, BatchNorm)},
|
||||
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
|
||||
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
|
||||
{string(kNameReshape), ADPT_DESC(Reshape)},
|
||||
|
|
|
@ -703,28 +703,6 @@ INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
|
|||
ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}};
|
||||
|
||||
// FusedBatchNorm
|
||||
INPUT_MAP(FusedBatchNorm) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(variance)}};
|
||||
ATTR_MAP(FusedBatchNorm) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
|
||||
{"momentum", ATTR_DESC(moving_average_fraction, AnyTraits<float>())},
|
||||
{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(FusedBatchNorm) = {{0, OUTPUT_DESC(y)},
|
||||
{1, OUTPUT_DESC(running_mean)},
|
||||
{2, OUTPUT_DESC(running_variance)},
|
||||
{3, OUTPUT_DESC(save_mean)},
|
||||
{4, OUTPUT_DESC(save_inv_variance)}};
|
||||
|
||||
// FusedBatchNromGrad
|
||||
INPUT_MAP(FusedBatchNormGrad) = {{1, INPUT_DESC(dy)},
|
||||
{2, INPUT_DESC(x)},
|
||||
{3, INPUT_DESC(scale)},
|
||||
{4, INPUT_DESC(save_mean)},
|
||||
{5, INPUT_DESC(save_inv_variance)}};
|
||||
ATTR_MAP(FusedBatchNormGrad) = {{"momentum", ATTR_DESC(momentum, AnyTraits<float>())},
|
||||
{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(FusedBatchNormGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(bn_scale)}, {2, OUTPUT_DESC(bn_bias)}};
|
||||
|
||||
// BiasAddGrad
|
||||
INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
|
|
|
@ -82,10 +82,6 @@ DECLARE_OP_USE_OUTPUT(HcomAllGather)
|
|||
DECLARE_OP_ADAPTER(Variable)
|
||||
DECLARE_OP_ADAPTER(ReluGrad)
|
||||
DECLARE_OP_USE_OUTPUT(ReluGrad)
|
||||
DECLARE_OP_ADAPTER(FusedBatchNorm)
|
||||
DECLARE_OP_USE_OUTPUT(FusedBatchNorm)
|
||||
DECLARE_OP_ADAPTER(FusedBatchNormGrad)
|
||||
DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad)
|
||||
DECLARE_OP_ADAPTER(BiasAddGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BiasAddGrad)
|
||||
DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
|
||||
|
|
|
@ -58,6 +58,7 @@ class ImageGradients(Cell):
|
|||
super(ImageGradients, self).__init__()
|
||||
|
||||
def construct(self, images):
|
||||
_check_input_4d(F.shape(images), "images", self.cls_name)
|
||||
batch_size, depth, height, width = P.Shape()(images)
|
||||
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
|
||||
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
||||
|
@ -151,8 +152,8 @@ class SSIM(Cell):
|
|||
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
|
||||
|
||||
def construct(self, img1, img2):
|
||||
_check_input_4d(F.shape(img1), "img1", "SSIM")
|
||||
_check_input_4d(F.shape(img2), "img2", "SSIM")
|
||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
|
@ -244,8 +245,8 @@ class PSNR(Cell):
|
|||
self.max_val = max_val
|
||||
|
||||
def construct(self, img1, img2):
|
||||
_check_input_4d(F.shape(img1), "img1", "PSNR")
|
||||
_check_input_4d(F.shape(img2), "img2", "PSNR")
|
||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
||||
P.SameTypeShape()(img1, img2)
|
||||
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
|
||||
img1 = _convert_img_dtype_to_float32(img1, self.max_val)
|
||||
|
|
|
@ -1016,6 +1016,7 @@ class Argmin(PrimitiveWithInfer):
|
|||
"""init Argmin"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
|
||||
self.axis = axis
|
||||
self.add_prim_attr('output_type', output_type)
|
||||
|
||||
|
@ -1726,7 +1727,9 @@ class Diag(PrimitiveWithInfer):
|
|||
def infer_value(self, x):
|
||||
if x is None:
|
||||
return None
|
||||
validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
|
||||
# do constant-folding only when x rank is 1
|
||||
if len(x.shape()) != 1:
|
||||
return None
|
||||
ret = np.diag(x.asnumpy())
|
||||
return Tensor(ret)
|
||||
|
||||
|
@ -1752,7 +1755,7 @@ class DiagPart(PrimitiveWithInfer):
|
|||
>>> [0, 0, 3, 0],
|
||||
>>> [0, 0, 0, 4]])
|
||||
>>> diag_part = P.DiagPart()
|
||||
>>> diag_part(x)
|
||||
>>> diag_part(input_x)
|
||||
[1, 2, 3, 4]
|
||||
"""
|
||||
|
||||
|
@ -1776,7 +1779,9 @@ class DiagPart(PrimitiveWithInfer):
|
|||
def infer_value(self, x):
|
||||
if x is None:
|
||||
return None
|
||||
validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
|
||||
# do constant-folding only when x rank is 2
|
||||
if len(x.shape()) != 2:
|
||||
return None
|
||||
ret = np.diag(x.asnumpy())
|
||||
return Tensor(ret)
|
||||
|
||||
|
|
|
@ -2037,7 +2037,7 @@ class Atan2(_MathBinaryOp):
|
|||
r"""
|
||||
Returns arctangent of input_x/input_y element-wise.
|
||||
|
||||
It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})`
|
||||
It returns :math:`\theta\ \in\ [-\pi, \pi]`
|
||||
such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
|
||||
|
||||
Inputs:
|
||||
|
|
|
@ -147,13 +147,13 @@ TEST_F(TestConvert, TestReluOps) {
|
|||
}
|
||||
|
||||
TEST_F(TestConvert, TestConvertBatchNorm) {
|
||||
PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
|
||||
fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
|
||||
fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
|
||||
PrimitivePtr batch_norm = prim::kPrimBatchNorm;
|
||||
batch_norm->AddAttr("epsilon", MakeValue(0.001f));
|
||||
batch_norm->AddAttr("momentum", MakeValue(0.1f));
|
||||
|
||||
FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(fused_batch_norm));
|
||||
inputs.push_back(NewValueNode(batch_norm));
|
||||
for (unsigned int i = 0; i < 5; i++) {
|
||||
inputs.push_back(anf_graph->add_parameter());
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test image gradients """
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -47,3 +48,10 @@ def test_compile_multi_channel():
|
|||
[[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype)
|
||||
net = Net()
|
||||
_executor.compile(net, image)
|
||||
|
||||
def test_invalid_5d_input():
|
||||
dtype = mstype.float32
|
||||
image = Tensor(np.random.random([4, 1, 16, 16, 1]), dtype=dtype)
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(net, image)
|
|
@ -14,16 +14,15 @@
|
|||
# ============================================================================
|
||||
""" test array ops """
|
||||
import functools
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import prim_attr_register
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer
|
||||
from mindspore.common.dtype import get_py_obj_dtype
|
||||
from mindspore._c_expression import signature_dtype as sig_dtype
|
||||
from mindspore._c_expression import signature_rw as sig_rw
|
||||
from mindspore._c_expression import signature_kind as sig_kind
|
||||
|
@ -96,6 +95,17 @@ def test_select():
|
|||
expect = np.array([[1, 8, 9], [10, 5, 6]])
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
def test_argmin_invalid_output_type():
|
||||
P.Argmin(-1, mstype.int64)
|
||||
P.Argmin(-1, mstype.int32)
|
||||
with pytest.raises(TypeError):
|
||||
P.Argmin(-1, mstype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
P.Argmin(-1, mstype.float64)
|
||||
with pytest.raises(TypeError):
|
||||
P.Argmin(-1, mstype.uint8)
|
||||
with pytest.raises(TypeError):
|
||||
P.Argmin(-1, mstype.bool_)
|
||||
|
||||
class CustomOP(PrimitiveWithInfer):
|
||||
__mindspore_signature__ = (sig_dtype.T, sig_dtype.T, sig_dtype.T1,
|
||||
|
|
|
@ -17,6 +17,7 @@ import functools
|
|||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
|
||||
from mindspore import Tensor
|
||||
|
|
Loading…
Reference in New Issue