forked from mindspore-Ecosystem/mindspore
support mix precision in pynative
This commit is contained in:
parent
8f4bab4e75
commit
09b437285b
|
@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple
|
||||||
from .._c_expression import init_backend
|
from .._c_expression import init_backend
|
||||||
from ..ops.primitive import Primitive
|
from ..ops.primitive import Primitive
|
||||||
from ..ops.operations import HookBackward
|
from ..ops.operations import HookBackward
|
||||||
|
from ..ops.functional import cast
|
||||||
from ..parallel._tensor import _load_tensor_by_layout
|
from ..parallel._tensor import _load_tensor_by_layout
|
||||||
from ..common.tensor import Tensor
|
from ..common.tensor import Tensor
|
||||||
|
|
||||||
|
@ -60,6 +61,7 @@ class Cell:
|
||||||
def __init__(self, auto_prefix=True, flags=None):
|
def __init__(self, auto_prefix=True, flags=None):
|
||||||
self._params = OrderedDict()
|
self._params = OrderedDict()
|
||||||
self._cells = OrderedDict()
|
self._cells = OrderedDict()
|
||||||
|
self._params_list = OrderedDict()
|
||||||
self.training = False
|
self.training = False
|
||||||
self.requires_grad = False
|
self.requires_grad = False
|
||||||
self.pynative = False
|
self.pynative = False
|
||||||
|
@ -188,11 +190,22 @@ class Cell:
|
||||||
if '_params' in self.__dict__:
|
if '_params' in self.__dict__:
|
||||||
params = self.__dict__['_params']
|
params = self.__dict__['_params']
|
||||||
if name in params:
|
if name in params:
|
||||||
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||||
|
return self.cast_param(params[name])
|
||||||
return params[name]
|
return params[name]
|
||||||
if '_cells' in self.__dict__:
|
if '_cells' in self.__dict__:
|
||||||
cells = self.__dict__['_cells']
|
cells = self.__dict__['_cells']
|
||||||
if name in cells:
|
if name in cells:
|
||||||
return cells[name]
|
return cells[name]
|
||||||
|
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
|
||||||
|
params_list = self.__dict__['_params_list']
|
||||||
|
if name in params_list:
|
||||||
|
para_list = params_list[name]
|
||||||
|
cast_list = list()
|
||||||
|
for para in para_list:
|
||||||
|
cast_list.append(self.cast_param(para))
|
||||||
|
para_list = ParameterTuple(cast_list)
|
||||||
|
return para_list
|
||||||
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@ -225,10 +238,21 @@ class Cell:
|
||||||
cell.set_grad(True)
|
cell.set_grad(True)
|
||||||
else:
|
else:
|
||||||
_pynative_exec.set_grad_flag(False)
|
_pynative_exec.set_grad_flag(False)
|
||||||
if self.enable_hook:
|
cast_inputs = list()
|
||||||
output = self._hook_construct(*inputs)
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
||||||
|
for item in inputs:
|
||||||
|
cast_inputs.append(cast(item, mstype.float16))
|
||||||
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
||||||
|
for item in inputs:
|
||||||
|
cast_inputs.append(cast(item, mstype.float32))
|
||||||
|
if cast_inputs:
|
||||||
|
cast_inputs = tuple(cast_inputs)
|
||||||
else:
|
else:
|
||||||
output = self.construct(*inputs)
|
cast_inputs = inputs
|
||||||
|
if self.enable_hook:
|
||||||
|
output = self._hook_construct(*cast_inputs)
|
||||||
|
else:
|
||||||
|
output = self.construct(*cast_inputs)
|
||||||
if isinstance(output, Parameter):
|
if isinstance(output, Parameter):
|
||||||
output = output.data
|
output = output.data
|
||||||
if self.requires_grad is True:
|
if self.requires_grad is True:
|
||||||
|
@ -241,6 +265,7 @@ class Cell:
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
cells = self.__dict__.get('_cells')
|
cells = self.__dict__.get('_cells')
|
||||||
params = self.__dict__.get('_params')
|
params = self.__dict__.get('_params')
|
||||||
|
params_list = self.__dict__.get('_params_list')
|
||||||
if isinstance(value, Parameter):
|
if isinstance(value, Parameter):
|
||||||
if params is None:
|
if params is None:
|
||||||
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
||||||
|
@ -256,6 +281,11 @@ class Cell:
|
||||||
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
||||||
for item in value:
|
for item in value:
|
||||||
self.insert_param_to_cell(item.name, item, check_name=False)
|
self.insert_param_to_cell(item.name, item, check_name=False)
|
||||||
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||||
|
if name in self.__dict__:
|
||||||
|
del self.__dict__[name]
|
||||||
|
params_list[name] = value
|
||||||
|
else:
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
elif isinstance(value, Cell):
|
elif isinstance(value, Cell):
|
||||||
if cells is None:
|
if cells is None:
|
||||||
|
@ -458,6 +488,19 @@ class Cell:
|
||||||
raise TypeError("The type of parameter should be 'Parameter' if not None.")
|
raise TypeError("The type of parameter should be 'Parameter' if not None.")
|
||||||
self._params[param_name] = param
|
self._params[param_name] = param
|
||||||
|
|
||||||
|
def cast_param(self, param):
|
||||||
|
"""
|
||||||
|
Cast parameter according to auto mix precison level in pynative mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param (Parameter): The parameter to cast.
|
||||||
|
"""
|
||||||
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
||||||
|
return cast(param, mstype.float16)
|
||||||
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
||||||
|
return cast(param, mstype.float32)
|
||||||
|
return param
|
||||||
|
|
||||||
def insert_child_to_cell(self, child_name, child):
|
def insert_child_to_cell(self, child_name, child):
|
||||||
"""
|
"""
|
||||||
Adds a child cell to the current cell.
|
Adds a child cell to the current cell.
|
||||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops import composite as C
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.common.parameter import ParameterTuple, Parameter
|
from mindspore.common.parameter import ParameterTuple, Parameter
|
||||||
|
|
||||||
context.set_context(device_target='CPU')
|
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||||
|
|
||||||
|
|
||||||
class LstmNet(nn.Cell):
|
class LstmNet(nn.Cell):
|
||||||
|
|
Loading…
Reference in New Issue