forked from mindspore-Ecosystem/mindspore
!21616 modify ms function cache key
Merge pull request !21616 from chujinjin/modify_ms_function_cache_key
This commit is contained in:
commit
a646eb27ab
|
@ -123,15 +123,10 @@ class _MindSporeFunction:
|
|||
|
||||
def __init__(self, fn, input_signature=None, obj=None):
|
||||
self.fn = fn
|
||||
self.save_graphs = context.get_context("save_graphs")
|
||||
self.save_graphs_path = context.get_context("save_graphs_path")
|
||||
self.input_signature = input_signature
|
||||
self.obj = None
|
||||
self.identify_obj = None
|
||||
if hasattr(obj, fn.__name__):
|
||||
self.obj = obj
|
||||
elif obj is not None:
|
||||
self.identify_obj = obj
|
||||
self._executor = Executor_.get_instance()
|
||||
|
||||
def build_data_init_graph(self, graph_name):
|
||||
|
@ -145,16 +140,8 @@ class _MindSporeFunction:
|
|||
init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
|
||||
_exec_init_graph(self.obj, init_phase)
|
||||
|
||||
def compile(self, arguments_dict, method_name):
|
||||
def compile(self, args_list, arg_names, method_name):
|
||||
"""Returns pipeline for the given args."""
|
||||
args_list = tuple(arguments_dict.values())
|
||||
arg_names = tuple(arguments_dict.keys())
|
||||
|
||||
# remove first self parameter when fn is a method
|
||||
if self.obj is not None:
|
||||
args_list = args_list[1:]
|
||||
arg_names = arg_names[1:]
|
||||
|
||||
# verify the signature for both function and method
|
||||
if self.input_signature is not None:
|
||||
signatures = []
|
||||
|
@ -167,18 +154,20 @@ class _MindSporeFunction:
|
|||
raise ValueError("Inputs is incompatible with input signature!")
|
||||
|
||||
dic = dict(zip(arg_names, args_list))
|
||||
generate_name = self.fn.__module__ + "." + self.fn.__name__
|
||||
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
|
||||
str(self.fn.__code__.co_firstlineno)
|
||||
self.fn.__parse_method__ = method_name
|
||||
|
||||
# replace key with obj info and object ext info when fn is a method
|
||||
if self.obj is not None:
|
||||
# add key with obj
|
||||
identify = ""
|
||||
if self.obj is None:
|
||||
identify = str(id(self.fn))
|
||||
else:
|
||||
self.obj.__parse_method__ = method_name
|
||||
generate_name = self.obj.__module__ + "."
|
||||
if self.obj.__class__.__name__ != "ClipByNorm":
|
||||
generate_name = generate_name + str(self.obj.create_time) + '.' + self.fn.__name__
|
||||
if self.identify_obj is not None:
|
||||
generate_name = generate_name + str(id(self.identify_obj))
|
||||
generate_name = self.obj.__module__ + "." + generate_name
|
||||
identify = str(self.obj.create_time) + "_" + str(id(self.obj)) + '_' + str(id(self.fn))
|
||||
|
||||
generate_name = generate_name + "." + identify
|
||||
key = generate_key(generate_name, dic)
|
||||
phase = str(key[1]) + generate_name
|
||||
if key not in ms_compile_cache.keys():
|
||||
|
@ -191,9 +180,7 @@ class _MindSporeFunction:
|
|||
raise RuntimeError("Executor compile failed.")
|
||||
if context.get_context("enable_ge"):
|
||||
self.build_data_init_graph(phase)
|
||||
# since function can be redefined, we only cache class method pipeline
|
||||
if self.obj is not None or self.identify_obj is not None:
|
||||
ms_compile_cache[key] = phase
|
||||
ms_compile_cache[key] = phase
|
||||
return phase
|
||||
|
||||
return ms_compile_cache[key]
|
||||
|
@ -206,10 +193,12 @@ class _MindSporeFunction:
|
|||
raise RuntimeError('Process function parameter is failure')
|
||||
|
||||
args_list = tuple(arguments_dict.values())
|
||||
arg_names = tuple(arguments_dict.keys())
|
||||
if self.obj is not None:
|
||||
args_list = args_list[1:]
|
||||
arg_names = arg_names[1:]
|
||||
|
||||
phase = self.compile(arguments_dict, parse_method)
|
||||
phase = self.compile(args_list, arg_names, parse_method)
|
||||
|
||||
if context.get_context("precompile_only"):
|
||||
return None
|
||||
|
@ -233,7 +222,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
|
|||
|
||||
Args:
|
||||
fn (Function): The Python function that will be run as a graph. Default: None.
|
||||
obj (Object): The Python Object that provides the information for identifying the compiled function.Default:
|
||||
obj (Object): The python object that provides the information for identifying the compiled function. Default:
|
||||
None.
|
||||
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
||||
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
|
||||
|
@ -286,7 +275,10 @@ def ms_function(fn=None, obj=None, input_signature=None):
|
|||
@wraps(func)
|
||||
def staging_specialize(*args):
|
||||
input_args = args
|
||||
process_obj = obj
|
||||
if obj is not None:
|
||||
logger.warning("Obj is no longer in use, and the function's own object has been used to \
|
||||
distinguish whether it has been compiled.")
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
|
||||
input_args = args[1:]
|
||||
process_obj = args[0]
|
||||
|
|
|
@ -358,11 +358,11 @@ class GradOperation(GradOperation_):
|
|||
grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
if self.get_by_list:
|
||||
@ms_function(obj=fn)
|
||||
@ms_function
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
else:
|
||||
@ms_function(obj=fn)
|
||||
@ms_function
|
||||
def after_grad(*args):
|
||||
return grad_(fn)(*args)
|
||||
else:
|
||||
|
|
|
@ -24,7 +24,6 @@ import mindspore._c_expression as _c_expression
|
|||
from mindspore import ParameterTuple
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from .block_util import get_output_cell, gen_net, gen_grad_net, \
|
||||
get_uniform_with_shape, set_block_phase, get_output_reduce_cell, set_block_param_with_rand
|
||||
|
@ -36,7 +35,7 @@ class _GradChecker:
|
|||
|
||||
Arguments:
|
||||
fn: The function under test.
|
||||
gfn: The hight order function to compute the derivative function.
|
||||
gfn: The high order function to compute the derivative function.
|
||||
args: The point in the function's domain where we want
|
||||
to estimate the gradient.
|
||||
|
||||
|
@ -119,7 +118,6 @@ class _GradChecker:
|
|||
def func_backward_pynative(*inputs):
|
||||
net = gen_grad_net(f, grad_wraper, len(inputs) - 1, inputs[-1])
|
||||
|
||||
@ms_function
|
||||
def _func_pynative(*inputs):
|
||||
return net(*inputs)
|
||||
|
||||
|
@ -130,7 +128,6 @@ class _GradChecker:
|
|||
def func_forward_pynative(*inputs):
|
||||
net = gen_net(f, len(inputs))
|
||||
|
||||
@ms_function
|
||||
def _func_pynative(*inputs):
|
||||
return net(*inputs)
|
||||
|
||||
|
|
|
@ -19,11 +19,12 @@ import mindspore.nn as nn
|
|||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.utils.check_gradient import (
|
||||
ms_function, check_jacobian, Tensor, NNGradChecker,
|
||||
check_jacobian, Tensor, NNGradChecker,
|
||||
OperationGradChecker, check_gradient)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
|
Loading…
Reference in New Issue