forked from mindspore-Ecosystem/mindspore
!32292 Resolve the duplicate seed problem of random operator.
Merge pull request !32292 from 张清华/opt
This commit is contained in:
commit
96681e20b7
|
@ -574,7 +574,8 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
|||
MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
|
||||
func_graph = value->cast<FuncGraphPtr>();
|
||||
if (!func_graph->dropped()) {
|
||||
if (pipeline::GetJitLevel() == "o0") {
|
||||
bool forbid_reuse = py::hasattr(obj, PYTHON_FUNCTION_FORBID_REUSE);
|
||||
if (forbid_reuse || pipeline::GetJitLevel() == "o0") {
|
||||
return BasicClone(func_graph);
|
||||
}
|
||||
return func_graph;
|
||||
|
|
|
@ -1031,17 +1031,20 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
|
|||
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
// Try to get infer result from evaluator cache.
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args);
|
||||
if (eval_result != nullptr) {
|
||||
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
|
||||
auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
|
||||
if (!forbid_reuse) {
|
||||
// Try to get infer result from evaluator cache.
|
||||
EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args);
|
||||
if (eval_result != nullptr) {
|
||||
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
|
||||
}
|
||||
}
|
||||
// In pynative mode (engine == nullptr), it is difficult to set added_attrs to
|
||||
// python object by C++ code, so we disable global eval cache in pynative mode.
|
||||
const bool enable_global_cache = (engine != nullptr);
|
||||
const bool enable_global_cache = (engine != nullptr && !forbid_reuse);
|
||||
if (enable_global_cache) {
|
||||
// Try to get infer result from global primitive eval cache.
|
||||
eval_result = eval_cache_->Get(prim_py_, args);
|
||||
EvalResultPtr eval_result = eval_cache_->Get(prim_py_, args);
|
||||
if (eval_result != nullptr) {
|
||||
// Global cache hit.
|
||||
evaluator_cache_mgr_->SetValue(args, eval_result);
|
||||
|
@ -1059,7 +1062,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
|
|||
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
||||
auto res_abs = PyInferRes2Abstract(prim_py_, output);
|
||||
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
|
||||
eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
|
||||
EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
|
||||
// Save result to global primitive eval cache.
|
||||
if (enable_global_cache) {
|
||||
eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
|
||||
|
|
|
@ -21,4 +21,5 @@ const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
|||
const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
|
||||
const char PYTHON_MS_CLASS[] = "__ms_class__";
|
||||
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
|
||||
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@ extern const char PYTHON_CELL_AS_LIST[];
|
|||
extern const char PYTHON_DATACLASS_FIELDS[];
|
||||
extern const char PYTHON_MS_CLASS[];
|
||||
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
|
||||
extern const char PYTHON_FUNCTION_FORBID_REUSE[];
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PYBIND_API_EXPORT_FLAGS_H_
|
||||
|
|
|
@ -28,6 +28,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_HIDDEN[] = "side_effect_hidden";
|
|||
inline const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[] = "side_effect_exception";
|
||||
inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
|
||||
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
|
||||
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
|
||||
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
|
||||
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
|
||||
|
||||
|
|
|
@ -508,6 +508,13 @@ def ms_class(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def _function_forbid_reuse(func):
|
||||
if not inspect.isfunction(func):
|
||||
raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
|
||||
setattr(func, '__function_forbid_reuse__', True)
|
||||
return func
|
||||
|
||||
|
||||
def is_pynative_parallel():
|
||||
run_mode = context.get_context('mode')
|
||||
parallel_mode = context.get_auto_parallel_context('parallel_mode')
|
||||
|
|
|
@ -19,14 +19,16 @@ from .. import functional as F
|
|||
from .multitype_ops import _constexpr_utils as const_utils
|
||||
from ...common import dtype as mstype
|
||||
from ...common.seed import _get_graph_seed
|
||||
from ...common.api import _function_forbid_reuse
|
||||
|
||||
|
||||
@constexpr
|
||||
@constexpr(reuse_result=False)
|
||||
def _get_seed(op_seed, kernel_name):
|
||||
"Get the graph-level seed."
|
||||
return _get_graph_seed(op_seed, kernel_name)
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def normal(shape, mean, stddev, seed=None):
|
||||
"""
|
||||
Generates random numbers according to the Normal (or Gaussian) random number distribution.
|
||||
|
@ -85,6 +87,7 @@ def normal(shape, mean, stddev, seed=None):
|
|||
return value
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def laplace(shape, mean, lambda_param, seed=None):
|
||||
r"""
|
||||
Generates random numbers according to the Laplace random number distribution.
|
||||
|
@ -132,6 +135,7 @@ def laplace(shape, mean, lambda_param, seed=None):
|
|||
return value
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
|
||||
"""
|
||||
Generates random numbers according to the Uniform random number distribution.
|
||||
|
@ -205,6 +209,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
|
|||
return value
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def gamma(shape, alpha, beta, seed=None):
|
||||
"""
|
||||
Generates random numbers according to the Gamma random number distribution.
|
||||
|
@ -283,6 +288,7 @@ def gamma(shape, alpha, beta, seed=None):
|
|||
return value
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def poisson(shape, mean, seed=None):
|
||||
r"""
|
||||
Generates random numbers according to the Poisson random number distribution.
|
||||
|
@ -334,6 +340,7 @@ def poisson(shape, mean, seed=None):
|
|||
return value
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def multinomial(inputs, num_sample, replacement=True, seed=None):
|
||||
r"""
|
||||
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
|
||||
|
|
|
@ -690,7 +690,7 @@ def prim_attr_register(fn):
|
|||
return deco
|
||||
|
||||
|
||||
def constexpr(fn=None, get_instance=True, name=None):
|
||||
def constexpr(fn=None, get_instance=True, name=None, reuse_result=True):
|
||||
"""
|
||||
Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
|
||||
to compute constant value using the constants in the constructor.
|
||||
|
@ -700,6 +700,8 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
get_instance (bool): If true, return the instance of operator,
|
||||
otherwise return the operator class. Default: True.
|
||||
name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
|
||||
reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
|
||||
otherwise the operator will always be executed. Default: True.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import constexpr
|
||||
|
@ -732,6 +734,8 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
op_name = name if name else fn.__name__
|
||||
PrimitiveWithInfer.__init__(self, op_name)
|
||||
self.set_const_prim(True)
|
||||
if not reuse_result:
|
||||
self.add_prim_attr('forbid_reuse_result', True)
|
||||
|
||||
def infer_value(self, *args):
|
||||
return fn(*args)
|
||||
|
|
Loading…
Reference in New Issue