diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 43973fabc21..bfc076e42b3 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -123,15 +123,10 @@ class _MindSporeFunction: def __init__(self, fn, input_signature=None, obj=None): self.fn = fn - self.save_graphs = context.get_context("save_graphs") - self.save_graphs_path = context.get_context("save_graphs_path") self.input_signature = input_signature self.obj = None - self.identify_obj = None if hasattr(obj, fn.__name__): self.obj = obj - elif obj is not None: - self.identify_obj = obj self._executor = Executor_.get_instance() def build_data_init_graph(self, graph_name): @@ -145,16 +140,8 @@ class _MindSporeFunction: init_phase = "init_subgraph" + graph_name[graph_name.find("."):] _exec_init_graph(self.obj, init_phase) - def compile(self, arguments_dict, method_name): + def compile(self, args_list, arg_names, method_name): """Returns pipeline for the given args.""" - args_list = tuple(arguments_dict.values()) - arg_names = tuple(arguments_dict.keys()) - - # remove first self parameter when fn is a method - if self.obj is not None: - args_list = args_list[1:] - arg_names = arg_names[1:] - # verify the signature for both function and method if self.input_signature is not None: signatures = [] @@ -167,18 +154,20 @@ class _MindSporeFunction: raise ValueError("Inputs is incompatible with input signature!") dic = dict(zip(arg_names, args_list)) - generate_name = self.fn.__module__ + "." + self.fn.__name__ + generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \ + str(self.fn.__code__.co_firstlineno) self.fn.__parse_method__ = method_name - # replace key with obj info and object ext info when fn is a method - if self.obj is not None: + # add key with obj + identify = "" + if self.obj is None: + identify = str(id(self.fn)) + else: self.obj.__parse_method__ = method_name - generate_name = self.obj.__module__ + "." - if self.obj.__class__.__name__ != "ClipByNorm": - generate_name = generate_name + str(self.obj.create_time) + '.' + self.fn.__name__ - if self.identify_obj is not None: - generate_name = generate_name + str(id(self.identify_obj)) + generate_name = self.obj.__module__ + "." + generate_name + identify = str(self.obj.create_time) + "_" + str(id(self.obj)) + '_' + str(id(self.fn)) + generate_name = generate_name + "." + identify key = generate_key(generate_name, dic) phase = str(key[1]) + generate_name if key not in ms_compile_cache.keys(): @@ -191,9 +180,7 @@ class _MindSporeFunction: raise RuntimeError("Executor compile failed.") if context.get_context("enable_ge"): self.build_data_init_graph(phase) - # since function can be redefined, we only cache class method pipeline - if self.obj is not None or self.identify_obj is not None: - ms_compile_cache[key] = phase + ms_compile_cache[key] = phase return phase return ms_compile_cache[key] @@ -206,10 +193,12 @@ class _MindSporeFunction: raise RuntimeError('Process function parameter is failure') args_list = tuple(arguments_dict.values()) + arg_names = tuple(arguments_dict.keys()) if self.obj is not None: args_list = args_list[1:] + arg_names = arg_names[1:] - phase = self.compile(arguments_dict, parse_method) + phase = self.compile(args_list, arg_names, parse_method) if context.get_context("precompile_only"): return None @@ -233,7 +222,7 @@ def ms_function(fn=None, obj=None, input_signature=None): Args: fn (Function): The Python function that will be run as a graph. Default: None. - obj (Object): The Python Object that provides the information for identifying the compiled function.Default: + obj (Object): The python object that provides the information for identifying the compiled function. Default: None. input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`. @@ -286,7 +275,10 @@ def ms_function(fn=None, obj=None, input_signature=None): @wraps(func) def staging_specialize(*args): input_args = args - process_obj = obj + if obj is not None: + logger.warning("Obj is no longer in use, and the function's own object has been used to \ + distinguish whether it has been compiled.") + process_obj = None if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__): input_args = args[1:] process_obj = args[0] diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 45d28bd5e0d..9265058f8bd 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -358,11 +358,11 @@ class GradOperation(GradOperation_): grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param) if context.get_context("mode") == context.GRAPH_MODE: if self.get_by_list: - @ms_function(obj=fn) + @ms_function def after_grad(*args): return grad_(fn, weights)(*args) else: - @ms_function(obj=fn) + @ms_function def after_grad(*args): return grad_(fn)(*args) else: diff --git a/tests/mindspore_test_framework/utils/check_gradient.py b/tests/mindspore_test_framework/utils/check_gradient.py index 61ce116d031..8a5efbfe933 100644 --- a/tests/mindspore_test_framework/utils/check_gradient.py +++ b/tests/mindspore_test_framework/utils/check_gradient.py @@ -24,7 +24,6 @@ import mindspore._c_expression as _c_expression from mindspore import ParameterTuple from mindspore import Tensor from mindspore import context -from mindspore.common.api import ms_function from mindspore.ops.composite import GradOperation from .block_util import get_output_cell, gen_net, gen_grad_net, \ get_uniform_with_shape, set_block_phase, get_output_reduce_cell, set_block_param_with_rand @@ -36,7 +35,7 @@ class _GradChecker: Arguments: fn: The function under test. - gfn: The hight order function to compute the derivative function. + gfn: The high order function to compute the derivative function. args: The point in the function's domain where we want to estimate the gradient. @@ -119,7 +118,6 @@ class _GradChecker: def func_backward_pynative(*inputs): net = gen_grad_net(f, grad_wraper, len(inputs) - 1, inputs[-1]) - @ms_function def _func_pynative(*inputs): return net(*inputs) @@ -130,7 +128,6 @@ class _GradChecker: def func_forward_pynative(*inputs): net = gen_net(f, len(inputs)) - @ms_function def _func_pynative(*inputs): return net(*inputs) diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index 97166546410..86d4d7ef12b 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -19,11 +19,12 @@ import mindspore.nn as nn from mindspore import context from mindspore.common import dtype as mstype from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common.api import ms_function from mindspore.ops import composite as C from mindspore.ops import operations as P from ..ut_filter import non_graph_engine from ....mindspore_test_framework.utils.check_gradient import ( - ms_function, check_jacobian, Tensor, NNGradChecker, + check_jacobian, Tensor, NNGradChecker, OperationGradChecker, check_gradient) context.set_context(mode=context.PYNATIVE_MODE)