!3295 support mix precision in pynative mode
Merge pull request !3295 from wangqiuliang/support-mix-precision
commit 9351739b3d
@@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
 from ..ops.operations import HookBackward
+from ..ops.functional import cast
 from ..parallel._tensor import _load_tensor_by_layout
 from ..common.tensor import Tensor

@@ -60,6 +61,7 @@ class Cell:
     def __init__(self, auto_prefix=True, flags=None):
         self._params = OrderedDict()
         self._cells = OrderedDict()
+        self._params_list = OrderedDict()
         self.training = False
         self.requires_grad = False
         self.pynative = False

@@ -188,11 +190,22 @@ class Cell:
         if '_params' in self.__dict__:
             params = self.__dict__['_params']
             if name in params:
+                if context.get_context("mode") == context.PYNATIVE_MODE:
+                    return self.cast_param(params[name])
                 return params[name]
         if '_cells' in self.__dict__:
             cells = self.__dict__['_cells']
             if name in cells:
                 return cells[name]
+        if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
+            params_list = self.__dict__['_params_list']
+            if name in params_list:
+                para_list = params_list[name]
+                cast_list = list()
+                for para in para_list:
+                    cast_list.append(self.cast_param(para))
+                para_list = ParameterTuple(cast_list)
+                return para_list
         raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
 
     def __del__(self):

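A minimal sketch of the lookup-and-cast pattern this hunk adds to __getattr__, in plain Python/NumPy rather than MindSpore (ToyCell and its flag dict are illustrative names, not part of the commit): parameter access falls through to __getattr__, which casts the stored value on every read according to a precision flag.

# Illustrative sketch only -- not MindSpore code.
from collections import OrderedDict
import numpy as np

class ToyCell:
    def __init__(self):
        self._params = OrderedDict()
        self._flags = {}          # stands in for _mindspore_flags

    def cast_param(self, value):
        # Cast according to the precision flag, mirroring Cell.cast_param in this commit.
        if self._flags.get('fp16'):
            return value.astype(np.float16)
        if self._flags.get('fp32'):
            return value.astype(np.float32)
        return value

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails.
        params = self.__dict__.get('_params', {})
        if name in params:
            return self.cast_param(params[name])
        raise AttributeError(name)

net = ToyCell()
net._params['weight'] = np.ones(3, dtype=np.float32)
net._flags['fp16'] = True
print(net.weight.dtype)   # float16: the cast happens on each access

Because the cast happens at read time, the stored array keeps float32; only the value handed to the computation is converted.
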
@@ -225,10 +238,21 @@ class Cell:
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
-        if self.enable_hook:
-            output = self._hook_construct(*inputs)
+        cast_inputs = list()
+        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
+            for item in inputs:
+                cast_inputs.append(cast(item, mstype.float16))
+        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
+            for item in inputs:
+                cast_inputs.append(cast(item, mstype.float32))
+        if cast_inputs:
+            cast_inputs = tuple(cast_inputs)
         else:
-            output = self.construct(*inputs)
+            cast_inputs = inputs
+        if self.enable_hook:
+            output = self._hook_construct(*cast_inputs)
+        else:
+            output = self.construct(*cast_inputs)
         if isinstance(output, Parameter):
             output = output.data
         if self.requires_grad is True:

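For context, a hedged usage sketch of how the input casting above would be triggered from user code. It assumes Cell.to_float() sets the 'fp16' flag that __call__ checks (as it does for graph-mode mixed precision) and that the configured backend supports PyNative execution; it is an illustration, not part of this commit.

# Hedged sketch: assumes to_float() sets the 'fp16' flag and PyNative mode is available.
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = Net().to_float(mstype.float16)            # flag the cell (and children) as fp16
out = net(Tensor(np.ones((1, 4), np.float32)))  # __call__ casts the input to float16 before construct
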
@@ -241,6 +265,7 @@ class Cell:
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
+        params_list = self.__dict__.get('_params_list')
         if isinstance(value, Parameter):
             if params is None:
                 raise AttributeError("Can not assign params before Cell.__init__() call.")

@@ -256,7 +281,12 @@ class Cell:
                 raise AttributeError("Can not assign params before Cell.__init__() call.")
             for item in value:
                 self.insert_param_to_cell(item.name, item, check_name=False)
-            object.__setattr__(self, name, value)
+            if context.get_context("mode") == context.PYNATIVE_MODE:
+                if name in self.__dict__:
+                    del self.__dict__[name]
+                params_list[name] = value
+            else:
+                object.__setattr__(self, name, value)
         elif isinstance(value, Cell):
             if cells is None:
                 raise AttributeError("Can not assign cells before Cell.__init__() call.")

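The del self.__dict__[name] added here is what makes the PyNative path work: Python only falls back to __getattr__ when ordinary attribute lookup fails, so a ParameterTuple left in the instance dict would never be re-cast on access. A small standalone demonstration of that lookup rule (plain Python, illustrative only):

# __getattr__ is consulted only when the name is absent from the instance __dict__.
class Demo:
    def __getattr__(self, name):
        return "resolved by __getattr__"

d = Demo()
d.x = "resolved by __dict__"
print(d.x)             # resolved by __dict__ (never reaches __getattr__)
del d.__dict__['x']
print(d.x)             # resolved by __getattr__

Storing the tuple in _params_list instead of __dict__ therefore routes every later access through __getattr__, where cast_param can apply the current precision flags.
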
@@ -458,6 +488,19 @@ class Cell:
             raise TypeError("The type of parameter should be 'Parameter' if not None.")
         self._params[param_name] = param
 
+    def cast_param(self, param):
+        """
+        Cast parameter according to auto mix precison level in pynative mode.
+
+        Args:
+            param (Parameter): The parameter to cast.
+        """
+        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
+            return cast(param, mstype.float16)
+        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
+            return cast(param, mstype.float32)
+        return param
+
     def insert_child_to_cell(self, child_name, child):
         """
         Adds a child cell to the current cell.

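A note on cast_param: cast() returns a converted tensor rather than changing the parameter in place, so the Parameter stored in _params keeps its float32 data as a full-precision master copy while the computation sees float16. A plain-NumPy sketch of that master-copy arrangement (an illustration of the general mixed-precision idea, not MindSpore code):

# Illustrative sketch only: cast-on-read plus a full-precision master copy.
import numpy as np

master = np.ones(4, dtype=np.float32)       # stored parameter, full precision
compute = master.astype(np.float16)         # what the forward pass sees after the cast
grad = np.full(4, 0.25, dtype=np.float16)   # gradient produced in half precision
master -= 0.1 * grad.astype(np.float32)     # update applied to the float32 master copy
print(master.dtype, compute.dtype)          # float32 float16
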
@@ -23,7 +23,7 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import ParameterTuple, Parameter
 
-context.set_context(device_target='CPU')
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 
 
 class LstmNet(nn.Cell):