!420 Fix mixed precision usage when a const value is passed as an argument

Merge pull request !420 from amongo/FixMixedPrecisionForConstParam

commit cc08fdfd9a
mindspore-ci-bot, 2020-04-18 10:31:33 +08:00, committed by Gitee
3 changed files with 51 additions and 4 deletions


@@ -68,9 +68,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
     return param;
   }
   auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
-  auto partial =
-    func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(cast_helper), NewValueNode(dst_type)});
-  auto cast = func_graph->NewCNode({NewValueNode(prim::kCompositeHyperMap), partial, param});
+  auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
   return cast;
 }

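The C++ hunk above is the core of the fix: instead of wrapping _mp_cast_helper in a Partial and mapping it over the parameter with HyperMap (which breaks when the parameter is a constant scalar rather than a tensor), the graph now calls _mp_cast_helper directly with the destination type and the parameter, and the multitype helper dispatches on the argument's actual type. A rough pure-Python sketch of that dispatch follows; the function name and branches are illustrative stand-ins, not MindSpore internals.

# Illustrative sketch only: mimics the type dispatch _mp_cast_helper
# performs inside the graph, not MindSpore's real API.
def mp_cast_helper_sketch(dst_type, value):
    if isinstance(value, tuple):
        # recurse element-wise, mirroring the new "Tuple" overload below
        return tuple(mp_cast_helper_sketch(dst_type, v) for v in value)
    if hasattr(value, "astype"):
        # tensor-like values are cast to the destination float type
        return value.astype(dst_type)
    # plain constants (e.g. an int taken from a static shape) pass through
    return value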

@@ -307,3 +307,12 @@ def _mixed_precision_cast_helper_2(type_, x):
     if F.issubclass_(F.dtype(x), mstype.float_):
         return P.Cast()(x, type_)
     return x
+
+@_mp_cast_helper.register("TypeType", "Tuple")
+@core
+def _mixed_precision_cast_helper_3(type_, x):
+    """if x is a tuple"""
+    t = ()
+    for item in x:
+        t = t + (_mp_cast_helper(type_, item),)
+    return t
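The new _mixed_precision_cast_helper_3 overload handles tuple inputs by rebuilding the tuple with every element sent back through _mp_cast_helper, so float tensors inside the tuple get cast while non-float members are left alone. A hedged numpy analogue of that behavior (plain Python, not graph code):

import numpy as np

# numpy stand-in for the Tuple overload: float arrays are cast to dst_type,
# everything else is returned unchanged, mirroring helper_2's float_ check.
def cast_tuple(dst_type, xs):
    t = ()
    for item in xs:
        if isinstance(item, np.ndarray) and np.issubdtype(item.dtype, np.floating):
            t = t + (item.astype(dst_type),)
        else:
            t = t + (item,)
    return t

mixed = (np.ones((2, 2), np.float16), np.arange(3))
casted = cast_tuple(np.float32, mixed)
print(casted[0].dtype, casted[1].dtype)  # float32, int dtype untouched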


@@ -19,7 +19,7 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 import mindspore.ops.composite as C
 
-context.set_context(mode=context.GRAPH_MODE)
+context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
 
 
 def test_parser_three_default_mixed_args_subnet():
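Turning on save_graphs makes the compile pipeline dump intermediate IR while the tests build their graphs, which is how a spurious cast inserted around a constant parameter can be inspected. A minimal usage sketch; the dump file naming and location are version-dependent assumptions, not guaranteed by this diff:

import glob
from mindspore import context

# save_graphs=True asks the compiler to write IR snapshots during
# compilation; the *.ir glob below is an assumption for illustration.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
# ... compile and run a Cell here ...
print(glob.glob("*.ir"))  # inspect which pipeline stages were dumped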
@@ -227,3 +227,43 @@ def test_net_vargs_expand():
     net.set_train()
     net(x, y, sens)
+
+
+def test_mixed_precision_const_parameter():
+    class NetLoss(Cell):
+        def __init__(self):
+            super(NetLoss, self).__init__()
+            self.shape = P.Shape()
+            self.up_sample1 = P.ResizeBilinear((14, 14))
+            self.up_sample2 = P.ResizeBilinear((28, 28))
+            self.up_sample3 = P.ResizeBilinear((36, 36))
+        def construct(self, x, y, z, *args):
+            ret = 0
+            if args[0] == self.shape(z)[2]:
+                if args[0] == 14:
+                    ret = self.up_sample1(y) + x
+                elif args[0] == 28:
+                    ret = self.up_sample2(y) - x
+                else:
+                    ret = x / y
+            else:
+                ret = x * y
+            ret = ret * z
+            return ret
+    class NetMain(Cell):
+        def __init__(self, loss_fn):
+            super(NetMain, self).__init__()
+            self.loss_fn = loss_fn
+            self.shape = P.Shape()
+        def construct(self, x, y, z):
+            size_x = self.shape(x)[2]
+            size_y = self.shape(y)[2]
+            ret = self.loss_fn(x, y, z, size_x, size_y)
+            return ret
+    loss_fn = NetLoss()
+    net = NetMain(loss_fn)
+    net.add_flags_recursive(fp32=True)
+    x = Tensor(np.ones((1, 3, 28, 28), np.float32))
+    y = Tensor(np.ones((1, 3, 14, 14), np.float32))
+    z = Tensor(np.ones((1, 3, 28, 28), np.float32))
+    out = net(x, y, z)
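In this regression test the input shapes are static, so the compiler folds size_x and size_y into the constants 28 and 14 before they reach loss_fn's varargs; with fp32=True flagged on every cell, the mixed-precision pass runs each call argument through the cast helper, which therefore must tolerate non-tensor constants. A schematic of what the flag implies at a call site, in plain Python with made-up names (the real pass rewrites the graph, not Python calls):

import numpy as np

# Schematic only: every argument is routed through a cast step; constants
# such as size_x == 28 must pass through untouched, which is what !420 fixes.
def call_with_fp32_casts(fn, *args):
    def cast(a):
        return a.astype("float32") if hasattr(a, "astype") else a
    return fn(*(cast(a) for a in args))

out = call_with_fp32_casts(lambda x, s: x * s, np.ones(2, np.float16), 28)
print(out.dtype)  # float32; the int 28 was forwarded as-is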