forked from mindspore-Ecosystem/mindspore
!4592 fix optimizer tuple inpus issue
Merge pull request !4592 from wangqiuliang/resolve-optimizer-tuple-inputs-issue
This commit is contained in:
commit
0f753ee29d
|
@ -223,6 +223,15 @@ class Cell:
|
|||
else:
|
||||
object.__delattr__(self, name)
|
||||
|
||||
def cast_inputs(self, inputs, dst_type):
|
||||
res = list()
|
||||
for item in inputs:
|
||||
if isinstance(item, tuple):
|
||||
res.append(self.cast_inputs(item, dst_type))
|
||||
else:
|
||||
res.append(cast(item, dst_type))
|
||||
return tuple(res)
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
if kwargs:
|
||||
|
@ -250,14 +259,10 @@ class Cell:
|
|||
cast_inputs = list()
|
||||
if hasattr(self, "_mindspore_flags"):
|
||||
if self._mindspore_flags.get('fp16'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float16))
|
||||
cast_inputs = self.cast_inputs(inputs, mstype.float16)
|
||||
if self._mindspore_flags.get('fp32'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float32))
|
||||
if cast_inputs:
|
||||
cast_inputs = tuple(cast_inputs)
|
||||
else:
|
||||
cast_inputs = self.cast_inputs(inputs, mstype.float32)
|
||||
if not cast_inputs:
|
||||
cast_inputs = inputs
|
||||
if self.enable_hook:
|
||||
output = self._hook_construct(*cast_inputs, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue