!21616 modify ms function cache key

Merge pull request !21616 from chujinjin/modify_ms_function_cache_key
i-robot 2021-08-26 07:36:50 +00:00 committed by Gitee
commit a646eb27ab
4 changed files with 25 additions and 35 deletions

View File

@@ -123,15 +123,10 @@ class _MindSporeFunction:
     def __init__(self, fn, input_signature=None, obj=None):
         self.fn = fn
         self.save_graphs = context.get_context("save_graphs")
         self.save_graphs_path = context.get_context("save_graphs_path")
         self.input_signature = input_signature
         self.obj = None
-        self.identify_obj = None
         if hasattr(obj, fn.__name__):
             self.obj = obj
-        elif obj is not None:
-            self.identify_obj = obj
         self._executor = Executor_.get_instance()

     def build_data_init_graph(self, graph_name):
@@ -145,16 +140,8 @@ class _MindSporeFunction:
         init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
         _exec_init_graph(self.obj, init_phase)

-    def compile(self, arguments_dict, method_name):
+    def compile(self, args_list, arg_names, method_name):
         """Returns pipeline for the given args."""
-        args_list = tuple(arguments_dict.values())
-        arg_names = tuple(arguments_dict.keys())
-
-        # remove first self parameter when fn is a method
-        if self.obj is not None:
-            args_list = args_list[1:]
-            arg_names = arg_names[1:]
-
         # verify the signature for both function and method
         if self.input_signature is not None:
             signatures = []
@@ -167,18 +154,20 @@ class _MindSporeFunction:
                 raise ValueError("Inputs is incompatible with input signature!")

         dic = dict(zip(arg_names, args_list))
-        generate_name = self.fn.__module__ + "." + self.fn.__name__
+        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
+            str(self.fn.__code__.co_firstlineno)
         self.fn.__parse_method__ = method_name

-        # replace key with obj info and object ext info when fn is a method
-        if self.obj is not None:
+        # add key with obj
+        identify = ""
+        if self.obj is None:
+            identify = str(id(self.fn))
+        else:
             self.obj.__parse_method__ = method_name
-            generate_name = self.obj.__module__ + "."
-            if self.obj.__class__.__name__ != "ClipByNorm":
-                generate_name = generate_name + str(self.obj.create_time) + '.' + self.fn.__name__
-
-        if self.identify_obj is not None:
-            generate_name = generate_name + str(id(self.identify_obj))
+            generate_name = self.obj.__module__ + "." + generate_name
+            identify = str(self.obj.create_time) + "_" + str(id(self.obj)) + '_' + str(id(self.fn))
+
+        generate_name = generate_name + "." + identify
         key = generate_key(generate_name, dic)
         phase = str(key[1]) + generate_name
         if key not in ms_compile_cache.keys():
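
The shape of the new key is easier to see outside the diff. Below is a minimal standalone sketch of the naming scheme (cache_name is a hypothetical helper; the real compile() additionally hashes the argument dict via generate_key, and create_time is MindSpore's per-Cell construction counter):

    def cache_name(fn, obj=None):
        # The base name now pins down fn's exact definition site, so two
        # functions that merely share a name can no longer collide.
        name = (fn.__module__ + "." + fn.__name__ + "." +
                fn.__code__.co_filename + "." + str(fn.__code__.co_firstlineno))
        if obj is None:
            identify = str(id(fn))
        else:
            name = obj.__module__ + "." + name
            # create_time, the object's id and the function's id together
            # identify one bound method on one specific instance.
            identify = str(obj.create_time) + "_" + str(id(obj)) + "_" + str(id(fn))
        return name + "." + identify
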
@@ -191,9 +180,7 @@ class _MindSporeFunction:
                 raise RuntimeError("Executor compile failed.")

             if context.get_context("enable_ge"):
                 self.build_data_init_graph(phase)
-            # since function can be redefined, we only cache class method pipeline
-            if self.obj is not None or self.identify_obj is not None:
-                ms_compile_cache[key] = phase
+            ms_compile_cache[key] = phase
             return phase
         return ms_compile_cache[key]
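
Caching becomes unconditional here because redefining a function now changes its key, so a stale pipeline can no longer be served for a new definition with the same name. A hypothetical demonstration, reusing the cache_name sketch above:

    def f(x):
        return x + 1
    key1 = cache_name(f)

    def f(x):          # redefinition: new code object, new first line number
        return x - 1
    key2 = cache_name(f)

    assert key1 != key2   # distinct definition sites -> distinct cache entries
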
@@ -206,10 +193,12 @@ class _MindSporeFunction:
             raise RuntimeError('Process function parameter is failure')

         args_list = tuple(arguments_dict.values())
         arg_names = tuple(arguments_dict.keys())
-
-        phase = self.compile(arguments_dict, parse_method)
+        if self.obj is not None:
+            args_list = args_list[1:]
+            arg_names = arg_names[1:]
+        phase = self.compile(args_list, arg_names, parse_method)

         if context.get_context("precompile_only"):
             return None
@@ -233,7 +222,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
     Args:
         fn (Function): The Python function that will be run as a graph. Default: None.
-        obj (Object): The Python Object that provides the information for identifying the compiled function.Default:
+        obj (Object): The python object that provides the information for identifying the compiled function. Default:
             None.
         input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
             will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
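
For orientation, typical decorator usage that this docstring covers (an illustrative sketch; shapes and dtypes are arbitrary):

    import numpy as np
    from mindspore import Tensor, ms_function

    @ms_function
    def tensor_add(x, y):
        return x + y

    out = tensor_add(Tensor(np.ones((2, 2), np.float32)),
                     Tensor(np.ones((2, 2), np.float32)))
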
@@ -286,7 +275,10 @@ def ms_function(fn=None, obj=None, input_signature=None):
         @wraps(func)
         def staging_specialize(*args):
             input_args = args
-            process_obj = obj
+            if obj is not None:
+                logger.warning("Obj is no longer in use, and the function's own object has been used to \
+                    distinguish whether it has been compiled.")
+            process_obj = None
             if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
                 input_args = args[1:]
                 process_obj = args[0]
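
In short, the obj argument is now deprecated: passing it only triggers the warning above, and it no longer feeds the cache key. Illustrative before/after usage:

    # before this change, callers disambiguated compiled copies by hand:
    #     @ms_function(obj=some_cell)
    #     def run(x): ...
    # after it, the plain form is sufficient, because the function's own
    # definition site and id are already part of the cache key:
    @ms_function
    def run(x):
        return x * 2
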

View File

@@ -358,11 +358,11 @@ class GradOperation(GradOperation_):
             grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
             if context.get_context("mode") == context.GRAPH_MODE:
                 if self.get_by_list:
-                    @ms_function(obj=fn)
+                    @ms_function
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @ms_function(obj=fn)
+                    @ms_function
                     def after_grad(*args):
                         return grad_(fn)(*args)
             else:
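
This is the first beneficiary of the new key: every call builds a fresh after_grad closure, and since id(fn) is embedded in the key, each closure gets its own cache entry without the manual obj=fn hint. A pure-Python illustration of why the definition site alone would not suffice (no MindSpore needed; make_closure is hypothetical):

    def make_closure(fn):
        def after_grad(*args):
            return fn(*args)
        return after_grad

    g1 = make_closure(lambda x: x + 1)
    g2 = make_closure(lambda x: x - 1)
    # same definition site (same co_filename / co_firstlineno), different ids:
    assert g1.__code__.co_firstlineno == g2.__code__.co_firstlineno
    assert id(g1) != id(g2)   # while both closures are alive
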

View File

@@ -24,7 +24,6 @@ import mindspore._c_expression as _c_expression
 from mindspore import ParameterTuple
 from mindspore import Tensor
 from mindspore import context
-from mindspore.common.api import ms_function
 from mindspore.ops.composite import GradOperation
 from .block_util import get_output_cell, gen_net, gen_grad_net, \
     get_uniform_with_shape, set_block_phase, get_output_reduce_cell, set_block_param_with_rand
@@ -36,7 +35,7 @@ class _GradChecker:
     Arguments:
         fn: The function under test.
-        gfn: The hight order function to compute the derivative function.
+        gfn: The high order function to compute the derivative function.
         args: The point in the function's domain where we want
             to estimate the gradient.
@@ -119,7 +118,6 @@ class _GradChecker:
         def func_backward_pynative(*inputs):
             net = gen_grad_net(f, grad_wraper, len(inputs) - 1, inputs[-1])

-            @ms_function
             def _func_pynative(*inputs):
                 return net(*inputs)
@@ -130,7 +128,6 @@ class _GradChecker:
         def func_forward_pynative(*inputs):
             net = gen_net(f, len(inputs))

-            @ms_function
             def _func_pynative(*inputs):
                 return net(*inputs)

View File

@@ -19,11 +19,12 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.api import ms_function
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
-    ms_function, check_jacobian, Tensor, NNGradChecker,
+    check_jacobian, Tensor, NNGradChecker,
     OperationGradChecker, check_gradient)

 context.set_context(mode=context.PYNATIVE_MODE)