diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
index 595edb7c43c..b38104ac969 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -574,7 +574,8 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
     MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
     func_graph = value->cast<FuncGraphPtr>();
     if (!func_graph->dropped()) {
-      if (pipeline::GetJitLevel() == "o0") {
+      bool forbid_reuse = py::hasattr(obj, PYTHON_FUNCTION_FORBID_REUSE);
+      if (forbid_reuse || pipeline::GetJitLevel() == "o0") {
         return BasicClone(func_graph);
       }
       return func_graph;
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index 162a42ef150..977e841a574 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -1031,17 +1031,20 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
     MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
     return ret_abstract;
   }
-  // Try to get infer result from evaluator cache.
-  auto eval_result = evaluator_cache_mgr_->GetValue(args);
-  if (eval_result != nullptr) {
-    return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
+  auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
+  if (!forbid_reuse) {
+    // Try to get infer result from evaluator cache.
+    EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args);
+    if (eval_result != nullptr) {
+      return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
+    }
   }
   // In pynative mode (engine == nullptr), it is difficult to set added_attrs to
   // python object by C++ code, so we disable global eval cache in pynative mode.
-  const bool enable_global_cache = (engine != nullptr);
+  const bool enable_global_cache = (engine != nullptr && !forbid_reuse);
   if (enable_global_cache) {
     // Try to get infer result from global primitive eval cache.
-    eval_result = eval_cache_->Get(prim_py_, args);
+    EvalResultPtr eval_result = eval_cache_->Get(prim_py_, args);
     if (eval_result != nullptr) {
       // Global cache hit.
       evaluator_cache_mgr_->SetValue(args, eval_result);
@@ -1059,7 +1062,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
   MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
   auto res_abs = PyInferRes2Abstract(prim_py_, output);
   MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
-  eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
+  EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
   // Save result to global primitive eval cache.
   if (enable_global_cache) {
     eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc
index 350f9df29aa..04ce1482ef5 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.cc
+++ b/mindspore/ccsrc/pybind_api/export_flags.cc
@@ -21,4 +21,5 @@ const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
 const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
 const char PYTHON_MS_CLASS[] = "__ms_class__";
 const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
+const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h
index 117e22df1a8..8a6e3c5bddc 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.h
+++ b/mindspore/ccsrc/pybind_api/export_flags.h
@@ -24,6 +24,7 @@ extern const char PYTHON_CELL_AS_LIST[];
 extern const char PYTHON_DATACLASS_FIELDS[];
 extern const char PYTHON_MS_CLASS[];
 extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
+extern const char PYTHON_FUNCTION_FORBID_REUSE[];
 }  // namespace mindspore

 #endif  // PYBIND_API_EXPORT_FLAGS_H_
diff --git a/mindspore/core/utils/flags.h b/mindspore/core/utils/flags.h
index 27d1437f745..ae536e9c6f0 100644
--- a/mindspore/core/utils/flags.h
+++ b/mindspore/core/utils/flags.h
@@ -28,6 +28,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_HIDDEN[] = "side_effect_hidden";
 inline const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[] = "side_effect_exception";
 inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
 inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
+inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
 inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
 inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py
index 2771354e9c6..f8c2542e423 100644
--- a/mindspore/python/mindspore/common/api.py
+++ b/mindspore/python/mindspore/common/api.py
@@ -508,6 +508,13 @@ def ms_class(cls):
     return cls


+def _function_forbid_reuse(func):
+    if not inspect.isfunction(func):
+        raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
+    setattr(func, '__function_forbid_reuse__', True)
+    return func
+
+
 def is_pynative_parallel():
     run_mode = context.get_context('mode')
     parallel_mode = context.get_auto_parallel_context('parallel_mode')
diff --git a/mindspore/python/mindspore/ops/composite/random_ops.py b/mindspore/python/mindspore/ops/composite/random_ops.py
index 7331fe393e3..5adb995f2cd 100644
--- a/mindspore/python/mindspore/ops/composite/random_ops.py
+++ b/mindspore/python/mindspore/ops/composite/random_ops.py
@@ -19,14 +19,16 @@ from .. import functional as F
 from .multitype_ops import _constexpr_utils as const_utils
 from ...common import dtype as mstype
 from ...common.seed import _get_graph_seed
+from ...common.api import _function_forbid_reuse


-@constexpr
+@constexpr(reuse_result=False)
 def _get_seed(op_seed, kernel_name):
     "Get the graph-level seed."
     return _get_graph_seed(op_seed, kernel_name)


+@_function_forbid_reuse
 def normal(shape, mean, stddev, seed=None):
     """
     Generates random numbers according to the Normal (or Gaussian) random number distribution.
@@ -85,6 +87,7 @@ def normal(shape, mean, stddev, seed=None):
     return value


+@_function_forbid_reuse
 def laplace(shape, mean, lambda_param, seed=None):
     r"""
     Generates random numbers according to the Laplace random number distribution.
@@ -132,6 +135,7 @@ def laplace(shape, mean, lambda_param, seed=None):
     return value


+@_function_forbid_reuse
 def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
     """
     Generates random numbers according to the Uniform random number distribution.
@@ -205,6 +209,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
     return value


+@_function_forbid_reuse
 def gamma(shape, alpha, beta, seed=None):
     """
     Generates random numbers according to the Gamma random number distribution.
@@ -283,6 +288,7 @@ def gamma(shape, alpha, beta, seed=None):
     return value


+@_function_forbid_reuse
 def poisson(shape, mean, seed=None):
     r"""
     Generates random numbers according to the Poisson random number distribution.
@@ -334,6 +340,7 @@ def poisson(shape, mean, seed=None):
     return value


+@_function_forbid_reuse
 def multinomial(inputs, num_sample, replacement=True, seed=None):
     r"""
     Returns a tensor sampled from the multinomial probability distribution located in the corresponding
diff --git a/mindspore/python/mindspore/ops/primitive.py b/mindspore/python/mindspore/ops/primitive.py
index b138b98bda5..ceb52ff0c9e 100644
--- a/mindspore/python/mindspore/ops/primitive.py
+++ b/mindspore/python/mindspore/ops/primitive.py
@@ -690,7 +690,7 @@ def prim_attr_register(fn):
     return deco


-def constexpr(fn=None, get_instance=True, name=None):
+def constexpr(fn=None, get_instance=True, name=None, reuse_result=True):
     """
     Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
     to compute constant value using the constants in the constructor.
@@ -700,6 +700,8 @@
         get_instance (bool): If true, return the instance of operator,
             otherwise return the operator class. Default: True.
         name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
+        reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
+            otherwise the operator will always be executed. Default: True.

     Examples:
         >>> from mindspore.ops import constexpr
@@ -732,6 +734,8 @@
                 op_name = name if name else fn.__name__
                 PrimitiveWithInfer.__init__(self, op_name)
                 self.set_const_prim(True)
+                if not reuse_result:
+                    self.add_prim_attr('forbid_reuse_result', True)

             def infer_value(self, *args):
                 return fn(*args)
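
For context, a minimal usage sketch (not part of the patch). It assumes the patch is applied and otherwise uses only existing MindSpore APIs; the TwoSamples cell, shapes, and seed values are made up for illustration. With ops.uniform now marked @_function_forbid_reuse and _get_seed built via constexpr(reuse_result=False), the intent is that the two calls below are parsed and inferred independently rather than answered from a shared cached result, so each produces its own draw.

import mindspore as ms
from mindspore import nn, ops, Tensor
from mindspore.common import dtype as mstype

ms.set_context(mode=ms.GRAPH_MODE)


class TwoSamples(nn.Cell):
    """Illustrative cell: calls ops.uniform twice in one compiled graph.

    With the forbid-reuse flag from this patch, the second call is not
    served from the first call's cached parse/infer result.
    """

    def construct(self, minval, maxval):
        a = ops.uniform((2, 2), minval, maxval, seed=1)
        b = ops.uniform((2, 2), minval, maxval, seed=2)
        return a, b


net = TwoSamples()
minval = Tensor(0.0, mstype.float32)
maxval = Tensor(1.0, mstype.float32)
a, b = net(minval, maxval)
print(a, b)  # expected: two independent draws rather than one reused result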