!32292 Resolve the duplicate seed problem of random operator.

Merge pull request !32292 from 张清华/opt
This commit is contained in:
i-robot 2022-03-30 17:03:30 +00:00 committed by Gitee
commit 96681e20b7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 35 additions and 10 deletions

View File

@ -574,7 +574,8 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
func_graph = value->cast<FuncGraphPtr>();
if (!func_graph->dropped()) {
if (pipeline::GetJitLevel() == "o0") {
bool forbid_reuse = py::hasattr(obj, PYTHON_FUNCTION_FORBID_REUSE);
if (forbid_reuse || pipeline::GetJitLevel() == "o0") {
return BasicClone(func_graph);
}
return func_graph;

View File

@ -1031,17 +1031,20 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
// Try to get infer result from evaluator cache.
auto eval_result = evaluator_cache_mgr_->GetValue(args);
if (eval_result != nullptr) {
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
if (!forbid_reuse) {
// Try to get infer result from evaluator cache.
EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args);
if (eval_result != nullptr) {
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
}
}
// In pynative mode (engine == nullptr), it is difficult to set added_attrs to
// python object by C++ code, so we disable global eval cache in pynative mode.
const bool enable_global_cache = (engine != nullptr);
const bool enable_global_cache = (engine != nullptr && !forbid_reuse);
if (enable_global_cache) {
// Try to get infer result from global primitive eval cache.
eval_result = eval_cache_->Get(prim_py_, args);
EvalResultPtr eval_result = eval_cache_->Get(prim_py_, args);
if (eval_result != nullptr) {
// Global cache hit.
evaluator_cache_mgr_->SetValue(args, eval_result);
@ -1059,7 +1062,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
auto res_abs = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
// Save result to global primitive eval cache.
if (enable_global_cache) {
eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);

View File

@ -21,4 +21,5 @@ const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
const char PYTHON_MS_CLASS[] = "__ms_class__";
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
} // namespace mindspore

View File

@ -24,6 +24,7 @@ extern const char PYTHON_CELL_AS_LIST[];
extern const char PYTHON_DATACLASS_FIELDS[];
extern const char PYTHON_MS_CLASS[];
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
extern const char PYTHON_FUNCTION_FORBID_REUSE[];
} // namespace mindspore
#endif // PYBIND_API_EXPORT_FLAGS_H_

View File

@ -28,6 +28,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_HIDDEN[] = "side_effect_hidden";
inline const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[] = "side_effect_exception";
inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";

View File

@ -508,6 +508,13 @@ def ms_class(cls):
return cls
def _function_forbid_reuse(func):
if not inspect.isfunction(func):
raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
setattr(func, '__function_forbid_reuse__', True)
return func
def is_pynative_parallel():
run_mode = context.get_context('mode')
parallel_mode = context.get_auto_parallel_context('parallel_mode')

View File

@ -19,14 +19,16 @@ from .. import functional as F
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.seed import _get_graph_seed
from ...common.api import _function_forbid_reuse
@constexpr
@constexpr(reuse_result=False)
def _get_seed(op_seed, kernel_name):
"Get the graph-level seed."
return _get_graph_seed(op_seed, kernel_name)
@_function_forbid_reuse
def normal(shape, mean, stddev, seed=None):
"""
Generates random numbers according to the Normal (or Gaussian) random number distribution.
@ -85,6 +87,7 @@ def normal(shape, mean, stddev, seed=None):
return value
@_function_forbid_reuse
def laplace(shape, mean, lambda_param, seed=None):
r"""
Generates random numbers according to the Laplace random number distribution.
@ -132,6 +135,7 @@ def laplace(shape, mean, lambda_param, seed=None):
return value
@_function_forbid_reuse
def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
"""
Generates random numbers according to the Uniform random number distribution.
@ -205,6 +209,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
return value
@_function_forbid_reuse
def gamma(shape, alpha, beta, seed=None):
"""
Generates random numbers according to the Gamma random number distribution.
@ -283,6 +288,7 @@ def gamma(shape, alpha, beta, seed=None):
return value
@_function_forbid_reuse
def poisson(shape, mean, seed=None):
r"""
Generates random numbers according to the Poisson random number distribution.
@ -334,6 +340,7 @@ def poisson(shape, mean, seed=None):
return value
@_function_forbid_reuse
def multinomial(inputs, num_sample, replacement=True, seed=None):
r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding

View File

@ -690,7 +690,7 @@ def prim_attr_register(fn):
return deco
def constexpr(fn=None, get_instance=True, name=None):
def constexpr(fn=None, get_instance=True, name=None, reuse_result=True):
"""
Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
to compute constant value using the constants in the constructor.
@ -700,6 +700,8 @@ def constexpr(fn=None, get_instance=True, name=None):
get_instance (bool): If true, return the instance of operator,
otherwise return the operator class. Default: True.
name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
otherwise the operator will always be executed. Default: True.
Examples:
>>> from mindspore.ops import constexpr
@ -732,6 +734,8 @@ def constexpr(fn=None, get_instance=True, name=None):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.set_const_prim(True)
if not reuse_result:
self.add_prim_attr('forbid_reuse_result', True)
def infer_value(self, *args):
return fn(*args)