!63201 optimize trace in Cell
Merge pull request !63201 from NaCN/optimize_trace
This commit is contained in:
commit
370abc06b2
|
@ -50,6 +50,8 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/composite/math_ops.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/ops/primitive.py" "assignment-from-none"
|
||||
"mindspore/mindspore/python/mindspore/ops/primitive.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/ops/_tracefunc.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/ops/_tracefunc.py" "function-redefined"
|
||||
"mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "broad-except"
|
||||
|
|
|
@ -43,8 +43,6 @@ from mindspore.ops.operations import _inner_ops as inner
|
|||
from mindspore.parallel.shard import Shard
|
||||
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
||||
from mindspore.common._decorator import deprecated
|
||||
from mindspore._c_expression import PackExpander
|
||||
from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
|
||||
|
||||
|
||||
class Cell(Cell_):
|
||||
|
@ -667,9 +665,6 @@ class Cell(Cell_):
|
|||
args = bound_arguments.args
|
||||
kwargs = bound_arguments.kwargs
|
||||
|
||||
if PackFunc.is_tracing():
|
||||
return self._run_tracefunc(*args, **kwargs)
|
||||
|
||||
if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
|
||||
self.check_names_and_refresh_name()
|
||||
self._is_check_and_refresh = True
|
||||
|
@ -2519,22 +2514,6 @@ class Cell(Cell_):
|
|||
f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
|
||||
f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
|
||||
|
||||
def _run_tracefunc(self, *args, **kwargs):
|
||||
""" Run Packed Cell in Pack."""
|
||||
args = self._mixed_precision_cast(args)
|
||||
need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
|
||||
if not PackFunc.current.is_pynative_mode and need_subgraph:
|
||||
expander = PackExpander.get_instance()
|
||||
args = expander.begin_subgraph(self, *args)
|
||||
args = [_convert_tensor(a) for a in args]
|
||||
output = self._run_construct(args, kwargs)
|
||||
ret = expander.end_subgraph(self, output)
|
||||
output = _convert_tensor(ret)
|
||||
else:
|
||||
with _SetMixedPrecision(self):
|
||||
output = self._run_construct(args, kwargs)
|
||||
return output
|
||||
|
||||
def _mixed_precision_cast(self, inputs):
|
||||
mixed_type = self.get_mixed_precision_type()
|
||||
if mixed_type == MixedPrecisionType.NOTSET:
|
||||
|
|
|
@ -18,6 +18,7 @@ import types
|
|||
import textwrap
|
||||
import inspect
|
||||
import os
|
||||
from mindspore import nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.primitive import _RunOpHook, Primitive
|
||||
from mindspore._c_expression import PackExpander, PackNode
|
||||
|
@ -60,6 +61,45 @@ def _convert_tensor(node):
|
|||
return tuple(_convert_tensor(e) for e in node)
|
||||
return node
|
||||
|
||||
class PackFunc:
|
||||
pass
|
||||
|
||||
def _trace_cell_call(self, *args, **kwargs):
|
||||
""" Run Packed Cell in Pack."""
|
||||
if self.__class__.construct is nn.Cell.construct:
|
||||
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
|
||||
|
||||
if kwargs:
|
||||
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
||||
bound_arguments.apply_defaults()
|
||||
args = bound_arguments.args
|
||||
kwargs = bound_arguments.kwargs
|
||||
args = self._mixed_precision_cast(args)
|
||||
need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
|
||||
if not PackFunc.current.is_pynative_mode and need_subgraph:
|
||||
expander = PackExpander.get_instance()
|
||||
args = expander.begin_subgraph(self, *args)
|
||||
args = [_convert_tensor(a) for a in args]
|
||||
output = self._run_construct(args, kwargs)
|
||||
ret = expander.end_subgraph(self, output)
|
||||
output = _convert_tensor(ret)
|
||||
else:
|
||||
with _SetMixedPrecision(self):
|
||||
output = self._run_construct(args, kwargs)
|
||||
return output
|
||||
|
||||
class _PackHook:
|
||||
"""Hook for trace run"""
|
||||
|
||||
def __init__(self):
|
||||
self.origin_call = nn.Cell.__call__
|
||||
|
||||
def __enter__(self):
|
||||
nn.Cell.__call__ = _trace_cell_call
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
nn.Cell.__call__ = self.origin_call
|
||||
|
||||
class PackFunc(Primitive):
|
||||
"""pack function with lazy expander"""
|
||||
|
@ -114,7 +154,7 @@ class PackFunc(Primitive):
|
|||
return _convert_tensor(ret)
|
||||
|
||||
def _run_op(self, args):
|
||||
with _RunOpHook(PackFunc._trace_run_op):
|
||||
with _RunOpHook(PackFunc._trace_run_op), _PackHook():
|
||||
fun_args = [_convert_tensor(a) for a in args]
|
||||
ret = self.func(*fun_args, **self.kwargs)
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue