!420 Fix mixed precision usage when const value passed as argument
Merge pull request !420 from amongo/FixMixedPrecisionForConstParam
This commit is contained in:
commit
cc08fdfd9a
|
@ -68,9 +68,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
|
|||
return param;
|
||||
}
|
||||
auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
|
||||
auto partial =
|
||||
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(cast_helper), NewValueNode(dst_type)});
|
||||
auto cast = func_graph->NewCNode({NewValueNode(prim::kCompositeHyperMap), partial, param});
|
||||
auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
|
||||
return cast;
|
||||
}
|
||||
|
||||
|
|
|
@ -307,3 +307,12 @@ def _mixed_precision_cast_helper_2(type_, x):
|
|||
if F.issubclass_(F.dtype(x), mstype.float_):
|
||||
return P.Cast()(x, type_)
|
||||
return x
|
||||
|
||||
@_mp_cast_helper.register("TypeType", "Tuple")
|
||||
@core
|
||||
def _mixed_precision_cast_helper_3(type_, x):
|
||||
"""if x is a tuple"""
|
||||
t = ()
|
||||
for item in x:
|
||||
t = t + (_mp_cast_helper(type_, item),)
|
||||
return t
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.nn import Cell
|
|||
from mindspore.ops import operations as P
|
||||
import mindspore.ops.composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
def test_parser_three_default_mixed_args_subnet():
|
||||
|
||||
|
@ -227,3 +227,43 @@ def test_net_vargs_expand():
|
|||
|
||||
net.set_train()
|
||||
net(x, y, sens)
|
||||
|
||||
|
||||
def test_mixed_precision_const_parameter():
|
||||
class NetLoss(Cell):
|
||||
def __init__(self):
|
||||
super(NetLoss, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.up_sample1 = P.ResizeBilinear((14, 14))
|
||||
self.up_sample2 = P.ResizeBilinear((28, 28))
|
||||
self.up_sample3 = P.ResizeBilinear((36, 36))
|
||||
def construct(self, x, y, z, *args):
|
||||
ret = 0
|
||||
if args[0] == self.shape(z)[2]:
|
||||
if args[0] == 14:
|
||||
ret = self.up_sample1(y) + x
|
||||
elif args[0] == 28:
|
||||
ret = self.up_sample2(y) - x
|
||||
else:
|
||||
ret = x / y
|
||||
else:
|
||||
ret = x * y
|
||||
ret = ret * z
|
||||
return ret
|
||||
class NetMain(Cell):
|
||||
def __init__(self, loss_fn):
|
||||
super(NetMain, self).__init__()
|
||||
self.loss_fn = loss_fn
|
||||
self.shape = P.Shape()
|
||||
def construct(self, x, y, z):
|
||||
size_x = self.shape(x)[2]
|
||||
size_y = self.shape(y)[2]
|
||||
ret = self.loss_fn(x, y, z, size_x, size_y)
|
||||
return ret
|
||||
loss_fn = NetLoss()
|
||||
net = NetMain(loss_fn)
|
||||
net.add_flags_recursive(fp32=True)
|
||||
x = Tensor(np.ones((1, 3, 28, 28), np.float32))
|
||||
y = Tensor(np.ones((1, 3, 14, 14), np.float32))
|
||||
z = Tensor(np.ones((1, 3, 28, 28), np.float32))
|
||||
out = net(x, y, z)
|
Loading…
Reference in New Issue