fix batchnorm issue in pynative auto mix precision

This commit is contained in:
kingfo 2020-07-29 12:03:07 +08:00
parent e3fe1d76ca
commit fc92598881
8 changed files with 59 additions and 20 deletions

View File

@ -15,6 +15,7 @@
"""builtin_operations"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@ -174,11 +175,10 @@ def stop_gradient(x):
return x
hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
if isinstance(x, tuple):
res = list()
for item in x:
res.append(F.cast(item, dst_type))
return tuple(res)
return F.cast(x, dst_type)
def cast_inner(data):
return F.cast(data, dst_type)
return hyper_map(cast_inner, x)

View File

@ -61,6 +61,7 @@ class Parameter:
self._is_init = False
self._sliced = False
self.is_param_ps = False
self._cast_type = None
self.init_in_server = False
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
@ -103,6 +104,16 @@ class Parameter:
raise ValueError("The type of the name should be `str` or `None`.")
self._value.name = name_
@property
def cast_type(self):
return self._cast_type
@cast_type.setter
def cast_type(self, dst_type):
if dst_type not in (mstype.float16, mstype.float32, None):
raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
self._cast_type = dst_type
@property
def sliced(self):
"""Get slice status of the parameter."""

View File

@ -274,7 +274,7 @@ class SparseTensor:
Outputs:
SparseTensor, composed of `indices`, `values`, `dense_shape`.
Examples:
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self, dense_shape):
>>> super(Net, self).__init__()

View File

@ -240,12 +240,13 @@ class Cell:
else:
_pynative_exec.set_grad_flag(False)
cast_inputs = list()
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
elif self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if cast_inputs:
cast_inputs = tuple(cast_inputs)
else:
@ -285,6 +286,8 @@ class Cell:
if context.get_context("mode") == context.PYNATIVE_MODE:
if name in self.__dict__:
del self.__dict__[name]
if name in params:
del params[name]
params_list[name] = value
else:
object.__setattr__(self, name, value)
@ -496,10 +499,13 @@ class Cell:
Args:
param (Parameter): The parameter to cast.
"""
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16)
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32)
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
param.cast_type = mstype.float16
elif self._mindspore_flags.get('fp32'):
param.cast_type = mstype.float32
else:
param.cast_type = None
return param
def insert_child_to_cell(self, child_name, child):

View File

@ -182,3 +182,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)

View File

@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self._update_parameter = True
class BNTrainingReduce(PrimitiveWithInfer):

View File

@ -18,6 +18,8 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from .._c_expression import Primitive_, real_run_op, prim_type
from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind
@ -49,6 +51,7 @@ class Primitive(Primitive_):
self.name = name
self.attrs = {}
self.init_attrs = {"name": name}
self._update_parameter = False
Primitive_.__init__(self, name, self)
if hasattr(self.__class__, '__mindspore_signature__'):
sig = self._fill_signature(self.__class__.__mindspore_signature__)
@ -189,6 +192,11 @@ class Primitive(Primitive_):
# for checking output number with kernel implementation
self.add_prim_attr("output_names", outputs)
@property
def update_parameter(self):
""" Whether the primitive will update the value of parameter."""
return self._update_parameter
class PrimitiveWithInfer(Primitive):
"""
@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
"""Single op execution function supported by ge in PyNative mode."""
output = real_run_op(obj, op_name, args)
cast = tensor_operator_registry.get("cast")
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = list()
for arg in args:
if isinstance(arg, Parameter):
if arg.cast_type:
cast_args.append(cast(arg, arg.cast_type))
else:
cast_args.append(arg)
else:
cast_args.append(arg)
output = real_run_op(obj, op_name, tuple(cast_args))
if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1:

View File

@ -428,5 +428,4 @@ def test_pynative_resnet50():
cost_time = end_time - start_time
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
if step > 1:
assert cost_time < 0.31
assert cost_time < 0.32