forked from mindspore-Ecosystem/mindspore
!23039 [Fallback] Fix previous block parameters problem and add test cases.
Merge pull request !23039 from 张清华/opt_fallback
This commit is contained in:
commit
0071667155
|
@ -76,6 +76,7 @@ parse_expr_statement_white_list = (
|
||||||
"append",
|
"append",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_builtin_function_or_method_type = type(abs)
|
||||||
|
|
||||||
def create_slice_obj(start, end, step):
|
def create_slice_obj(start, end, step):
|
||||||
"""Create slice object"""
|
"""Create slice object"""
|
||||||
|
@ -248,6 +249,7 @@ def get_obj_id(obj):
|
||||||
|
|
||||||
def get_obj_type(obj):
|
def get_obj_type(obj):
|
||||||
"""Get the obj type."""
|
"""Get the obj type."""
|
||||||
|
logger.debug("Get object type: %r", obj)
|
||||||
obj_type = RESOLVE_TYPE_INVALID
|
obj_type = RESOLVE_TYPE_INVALID
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj_type = RESOLVE_TYPE_NONE
|
obj_type = RESOLVE_TYPE_NONE
|
||||||
|
@ -529,9 +531,9 @@ class Parser:
|
||||||
# Used to resolve mindspore builtin ops namespace.
|
# Used to resolve mindspore builtin ops namespace.
|
||||||
self.ms_common_ns = CellNamespace('mindspore.common')
|
self.ms_common_ns = CellNamespace('mindspore.common')
|
||||||
self.ms_ops_ns = CellNamespace('mindspore.ops')
|
self.ms_ops_ns = CellNamespace('mindspore.ops')
|
||||||
self.ms_ops_c = CellNamespace('mindspore.ops.composite')
|
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
|
||||||
self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops')
|
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
|
||||||
self.ms_ops_p = CellNamespace('mindspore.ops.operations')
|
self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
|
||||||
# Used to resolve the function's globals namespace.
|
# Used to resolve the function's globals namespace.
|
||||||
self.global_namespace = CellNamespace(fn.__module__)
|
self.global_namespace = CellNamespace(fn.__module__)
|
||||||
self.function_module = fn.__module__
|
self.function_module = fn.__module__
|
||||||
|
@ -567,6 +569,11 @@ class Parser:
|
||||||
logger.error("Fn type is invalid")
|
logger.error("Fn type is invalid")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
def is_unsupported_namespace(self, value):
|
||||||
|
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
|
||||||
|
logger.debug(f'`{value}` unsupported: {unsupported}.')
|
||||||
|
return unsupported
|
||||||
|
|
||||||
def get_namespace_symbol(self, var: str):
|
def get_namespace_symbol(self, var: str):
|
||||||
"""Get symbol type and namespace and symbol."""
|
"""Get symbol type and namespace and symbol."""
|
||||||
if var in self.closure_namespace:
|
if var in self.closure_namespace:
|
||||||
|
@ -575,7 +582,7 @@ class Parser:
|
||||||
if var in self.global_namespace:
|
if var in self.global_namespace:
|
||||||
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
|
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
|
||||||
value = self.global_namespace[var]
|
value = self.global_namespace[var]
|
||||||
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
|
if self.is_unsupported_namespace(value):
|
||||||
error_info = f"The builtin function '{var}' is not supported in graph mode."
|
error_info = f"The builtin function '{var}' is not supported in graph mode."
|
||||||
return None, var, error_info
|
return None, var, error_info
|
||||||
return self.global_namespace, var
|
return self.global_namespace, var
|
||||||
|
@ -604,6 +611,11 @@ class Parser:
|
||||||
logger.debug(f'Found `{name}` in mindspore root namespace.')
|
logger.debug(f'Found `{name}` in mindspore root namespace.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Check `Tensor` namespace.
|
||||||
|
if value == Tensor:
|
||||||
|
logger.debug(f'Not support `{name}`.')
|
||||||
|
return False
|
||||||
|
|
||||||
# Check `builtins` namespace.
|
# Check `builtins` namespace.
|
||||||
if hasattr(value, '__module__'): # Not types.ModuleType
|
if hasattr(value, '__module__'): # Not types.ModuleType
|
||||||
mod = value.__module__
|
mod = value.__module__
|
||||||
|
@ -613,25 +625,29 @@ class Parser:
|
||||||
|
|
||||||
# We suppose it's supported if not a Module.
|
# We suppose it's supported if not a Module.
|
||||||
if not isinstance(value, types.ModuleType):
|
if not isinstance(value, types.ModuleType):
|
||||||
|
logger.debug(f'Found `{name}`, not a module.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check supported Module namespace.
|
# Check supported Module namespace.
|
||||||
rightmost_name = name.split('.')[-1]
|
rightmost_name = name.split('.')[-1]
|
||||||
# By now, we don't check `self.ms_common_ns`.
|
|
||||||
if rightmost_name in self.ms_ops_ns:
|
if rightmost_name in self.ms_ops_ns:
|
||||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.')
|
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.')
|
||||||
return True
|
return True
|
||||||
if rightmost_name in self.ms_ops_c:
|
if rightmost_name in self.ms_ops_c_ns:
|
||||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.')
|
logger.debug(f'Found `{name}`({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.')
|
||||||
return True
|
return True
|
||||||
if rightmost_name in self.ms_ops_c_multitype:
|
if rightmost_name in self.ms_ops_c_multitype_ns:
|
||||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.')
|
logger.debug(
|
||||||
|
f'Found `{name}`({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.')
|
||||||
return True
|
return True
|
||||||
if rightmost_name in self.ms_ops_p:
|
if rightmost_name in self.ms_ops_p_ns:
|
||||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.')
|
logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.')
|
||||||
|
return True
|
||||||
|
if rightmost_name in self.ms_common_ns:
|
||||||
|
logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_common_ns)}.')
|
||||||
return True
|
return True
|
||||||
if rightmost_name in trope_ns:
|
if rightmost_name in trope_ns:
|
||||||
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.')
|
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {str(trope_ns)}.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.error(f'Not found `{name}` in mindspore supported namespace.')
|
logger.error(f'Not found `{name}` in mindspore supported namespace.')
|
||||||
|
@ -648,14 +664,16 @@ class Parser:
|
||||||
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
|
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
|
||||||
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
|
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
|
||||||
# To check if allowed to support.
|
# To check if allowed to support.
|
||||||
|
if self.is_unsupported_namespace(value):
|
||||||
|
return self.global_namespace, var, value
|
||||||
if self.is_unsupported_builtin_type(value):
|
if self.is_unsupported_builtin_type(value):
|
||||||
return self.global_namespace, var, value
|
return self.global_namespace, var, value
|
||||||
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
|
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
|
||||||
return self.global_namespace, var, value
|
return self.global_namespace, var, value
|
||||||
return self.global_namespace, var
|
return self.global_namespace, var
|
||||||
|
|
||||||
error_info = f"The symbol '{var}' is not supported in graph mode."
|
error_info = f"The name '{var}' is not defined, or not supported in graph mode."
|
||||||
logger.debug(error_info)
|
logger.debug(f'error info: {error_info}')
|
||||||
return None, var, error_info
|
return None, var, error_info
|
||||||
|
|
||||||
def analyze_super(self, class_type_node, subclass_instance):
|
def analyze_super(self, class_type_node, subclass_instance):
|
||||||
|
|
|
@ -284,7 +284,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!fn || py::isinstance<py::none>(fn)) {
|
if (!fn || py::isinstance<py::none>(fn)) {
|
||||||
MS_LOG(ERROR) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
|
MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
func_graph = parse::ParsePythonCode(fn);
|
func_graph = parse::ParsePythonCode(fn);
|
||||||
|
|
|
@ -579,10 +579,18 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
|
||||||
|
|
||||||
// Get obj detail type
|
// Get obj detail type
|
||||||
ResolveTypeDef GetObjType(const py::object &obj) {
|
ResolveTypeDef GetObjType(const py::object &obj) {
|
||||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
try {
|
||||||
auto obj_type =
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||||
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
|
auto obj_type =
|
||||||
return obj_type;
|
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
|
||||||
|
return obj_type;
|
||||||
|
} catch (const py::error_already_set &ex) {
|
||||||
|
MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
|
||||||
|
std::rethrow_exception(std::current_exception());
|
||||||
|
} catch (const py::type_error &ex) {
|
||||||
|
MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
|
||||||
|
std::rethrow_exception(std::current_exception());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get class instance detail type.
|
// Get class instance detail type.
|
||||||
|
|
|
@ -70,7 +70,7 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
|
||||||
// Write variable records the variable name to corresponding node
|
// Write variable records the variable name to corresponding node
|
||||||
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
|
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
|
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
|
||||||
<< node->DebugString();
|
<< node->DebugString();
|
||||||
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
|
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
|
||||||
if (!is_new_name) {
|
if (!is_new_name) {
|
||||||
|
@ -97,7 +97,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
||||||
|
|
||||||
// Read variable from predecessors
|
// Read variable from predecessors
|
||||||
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
||||||
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block id: " << func_graph_->debug_info()->debug_id();
|
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
|
||||||
// Get var node if it is found
|
// Get var node if it is found
|
||||||
auto found = assigned_vars_.find(var);
|
auto found = assigned_vars_.find(var);
|
||||||
if (found != assigned_vars_.end()) {
|
if (found != assigned_vars_.end()) {
|
||||||
|
@ -117,7 +117,12 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
||||||
if (prev_blocks_.size() == 1) {
|
if (prev_blocks_.size() == 1) {
|
||||||
auto block = prev_blocks_[0];
|
auto block = prev_blocks_[0];
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
return block->ReadVariable(var);
|
auto res = block->ReadVariable(var);
|
||||||
|
MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
|
||||||
|
<< ",\nCurrent: " << py::str(global_py_params())
|
||||||
|
<< "\nInsert: " << py::str(block->global_py_params());
|
||||||
|
CopyGlobalPyParam(block->global_py_params());
|
||||||
|
return res;
|
||||||
} else if (prev_blocks_.empty()) {
|
} else if (prev_blocks_.empty()) {
|
||||||
// Get namespace and make Resolve
|
// Get namespace and make Resolve
|
||||||
auto it = var_to_resolve_.find(var);
|
auto it = var_to_resolve_.find(var);
|
||||||
|
@ -190,13 +195,14 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
|
||||||
}
|
}
|
||||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
|
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
|
||||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
|
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
|
||||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
|
MS_LOG(DEBUG) << "[" << func_graph()->ToString() << "] name_space: " << name_space->ToString()
|
||||||
<< ", unsupported: " << unsupported;
|
<< ", symbol: " << symbol->ToString() << ", unsupported: " << unsupported;
|
||||||
auto resolved_node = MakeResolve(name_space, symbol);
|
auto resolved_node = MakeResolve(name_space, symbol);
|
||||||
if (unsupported) {
|
if (unsupported) {
|
||||||
resolved_node->set_interpret(true);
|
resolved_node->set_interpret(true);
|
||||||
AddGlobalPyParam(symbol->name(), py_obj);
|
AddGlobalPyParam(symbol->name(), py_obj);
|
||||||
MS_LOG(INFO) << "Added global python symblol: {" << symbol->name() << " : " << py::str(py_obj) << "}";
|
MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
|
||||||
|
<< py::str(py_obj) << "}";
|
||||||
}
|
}
|
||||||
return resolved_node;
|
return resolved_node;
|
||||||
}
|
}
|
||||||
|
@ -218,7 +224,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
||||||
|
|
||||||
// The fallback feature is enabled in default.
|
// The fallback feature is enabled in default.
|
||||||
// Not support change the flag during the process is alive.
|
// Not support change the flag during the process is alive.
|
||||||
static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true);
|
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||||
if (!use_fallback) {
|
if (!use_fallback) {
|
||||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||||
return HandleNamespaceInfo(namespace_info);
|
return HandleNamespaceInfo(namespace_info);
|
||||||
|
@ -268,12 +274,12 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
|
TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
|
||||||
std::string var = phi_nodes_[phi];
|
std::string var = phi_nodes_[phi];
|
||||||
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
|
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
|
||||||
<< " for var " << var;
|
<< " for var `" << var << "`";
|
||||||
auto removable = CollectRemovablePhi(phi);
|
auto removable = CollectRemovablePhi(phi);
|
||||||
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
|
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
|
||||||
if (removable) {
|
if (removable) {
|
||||||
MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||||
<< " var " << var;
|
<< " var `" << var << "`";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto &pred : prev_blocks_) {
|
for (auto &pred : prev_blocks_) {
|
||||||
|
@ -402,12 +408,10 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
|
||||||
|
|
||||||
// Perform a jump from this block to target block
|
// Perform a jump from this block to target block
|
||||||
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
|
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
|
||||||
MS_LOG(DEBUG) << "Jump from " << func_graph_->debug_info()->debug_id() << " to "
|
MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
|
||||||
<< target_block->func_graph()->debug_info()->debug_id();
|
|
||||||
MS_EXCEPTION_IF_NULL(target_block);
|
MS_EXCEPTION_IF_NULL(target_block);
|
||||||
if (is_dead_block_) {
|
if (is_dead_block_) {
|
||||||
MS_LOG(DEBUG) << "Dead code block should not jump to other block! Block id:"
|
MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
|
||||||
<< func_graph_->debug_info()->debug_id();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (func_graph_->get_return() != nullptr) {
|
if (func_graph_->get_return() != nullptr) {
|
||||||
|
|
|
@ -49,6 +49,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
||||||
virtual ~FunctionBlock() = default;
|
virtual ~FunctionBlock() = default;
|
||||||
|
|
||||||
FuncGraphPtr func_graph() { return func_graph_; }
|
FuncGraphPtr func_graph() { return func_graph_; }
|
||||||
|
std::string ToString() const { return func_graph_->ToString(); }
|
||||||
void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
|
void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
|
||||||
AnfNodePtr ReadVariable(const std::string &var_name);
|
AnfNodePtr ReadVariable(const std::string &var_name);
|
||||||
void AddPrevBlock(const FunctionBlockPtr &block);
|
void AddPrevBlock(const FunctionBlockPtr &block);
|
||||||
|
@ -85,6 +86,13 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
||||||
py::dict &global_py_params() { return global_py_params_; }
|
py::dict &global_py_params() { return global_py_params_; }
|
||||||
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
|
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
|
||||||
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
|
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
|
||||||
|
void CopyGlobalPyParam(const py::dict &symbols) {
|
||||||
|
for (auto ¶m : symbols) {
|
||||||
|
if (!global_py_params_.contains(param.first)) {
|
||||||
|
global_py_params_[param.first] = param.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
|
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
|
||||||
return {local_py_params_keys_, local_py_params_values_};
|
return {local_py_params_keys_, local_py_params_values_};
|
||||||
|
|
|
@ -167,7 +167,7 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
|
||||||
py::str desc =
|
py::str desc =
|
||||||
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
|
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
|
||||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". "
|
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". "
|
||||||
<< "Func graph id: " << func_graph->debug_info()->debug_id();
|
<< "FuncGraph: " << func_graph->ToString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -733,7 +733,9 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
||||||
} else {
|
} else {
|
||||||
auto kw_key_c = kw_key.cast<std::string>();
|
auto kw_key_c = kw_key.cast<std::string>();
|
||||||
keys.push_back(NewValueNode(kw_key_c));
|
keys.push_back(NewValueNode(kw_key_c));
|
||||||
values.push_back(ParseExprNode(block, kw_value));
|
auto node = ParseExprNode(block, kw_value);
|
||||||
|
node = HandleInterpret(block, node, kw_value);
|
||||||
|
values.push_back(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto keys_tuple = GenerateMakeTuple(block, keys);
|
auto keys_tuple = GenerateMakeTuple(block, keys);
|
||||||
|
@ -1070,20 +1072,21 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
|
||||||
MS_LOG(DEBUG) << "Process ast AugAssign";
|
MS_LOG(DEBUG) << "Process ast AugAssign";
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
MS_EXCEPTION_IF_NULL(ast_);
|
MS_EXCEPTION_IF_NULL(ast_);
|
||||||
py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
|
|
||||||
py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
|
py::object target_object = python_adapter::GetPyObjAttr(node, "target");
|
||||||
py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
|
py::object op_object = python_adapter::GetPyObjAttr(node, "op");
|
||||||
|
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||||
AnfNodePtr target_node = nullptr;
|
AnfNodePtr target_node = nullptr;
|
||||||
AnfNodePtr op_node = block->MakeResolveAstOp(op_obj);
|
AnfNodePtr op_node = block->MakeResolveAstOp(op_object);
|
||||||
AnfNodePtr value_node = ParseExprNode(block, value_obj);
|
AnfNodePtr value_node = ParseExprNode(block, value_object);
|
||||||
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
|
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
|
||||||
|
|
||||||
if (ast_type == AST_SUB_TYPE_NAME) {
|
if (ast_type == AST_SUB_TYPE_NAME) {
|
||||||
target_node = ParseName(block, target_obj);
|
target_node = ParseName(block, target_object);
|
||||||
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
|
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
|
||||||
target_node = ParseSubscript(block, target_obj);
|
target_node = ParseSubscript(block, target_object);
|
||||||
} else if (ast_->IsClassMember(target_obj)) {
|
} else if (ast_->IsClassMember(target_object)) {
|
||||||
target_node = ParseAttribute(block, target_obj);
|
target_node = ParseAttribute(block, target_object);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Not supported augassign";
|
MS_LOG(EXCEPTION) << "Not supported augassign";
|
||||||
}
|
}
|
||||||
|
@ -1091,8 +1094,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
|
||||||
MS_LOG(EXCEPTION) << "Can not get target node ";
|
MS_LOG(EXCEPTION) << "Can not get target node ";
|
||||||
}
|
}
|
||||||
CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
|
CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
|
||||||
WriteAssignVars(block, target_obj, augassign_app);
|
WriteAssignVars(block, target_object, augassign_app);
|
||||||
|
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
// Process global declaration such as 'global x';
|
// Process global declaration such as 'global x';
|
||||||
|
@ -1668,11 +1670,11 @@ AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_obj,
|
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
|
||||||
const AnfNodePtr &assigned_node) {
|
const AnfNodePtr &assigned_node) {
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
MS_EXCEPTION_IF_NULL(assigned_node);
|
MS_EXCEPTION_IF_NULL(assigned_node);
|
||||||
py::str name = python_adapter::GetPyObjAttr(target_obj, "id");
|
py::str name = python_adapter::GetPyObjAttr(target_object, "id");
|
||||||
std::string name_id = name;
|
std::string name_id = name;
|
||||||
assigned_node->debug_info()->set_name(name_id);
|
assigned_node->debug_info()->set_name(name_id);
|
||||||
// Set the debug name of the constant graph
|
// Set the debug name of the constant graph
|
||||||
|
@ -1683,16 +1685,16 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
|
||||||
fg->debug_info()->set_name(name_id);
|
fg->debug_info()->set_name(name_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Assign name: " << name_id << " to node: " << assigned_node;
|
MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
|
||||||
block->AddLocalPyParam(name_id, assigned_node);
|
block->AddLocalPyParam(name_id, assigned_node);
|
||||||
block->WriteVariable(name_id, assigned_node);
|
block->WriteVariable(name_id, assigned_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_obj,
|
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
|
||||||
const AnfNodePtr &assigned_node) {
|
const AnfNodePtr &assigned_node) {
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||||
py::list items = python_adapter::GetPyObjAttr(target_obj, "elts");
|
py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
|
||||||
for (size_t i = 0; i < items.size(); i++) {
|
for (size_t i = 0; i < items.size(); i++) {
|
||||||
// Use the Primitive replace the operation resolve node (getitem),
|
// Use the Primitive replace the operation resolve node (getitem),
|
||||||
// because the getitem will eventually be converted to Primitive node
|
// because the getitem will eventually be converted to Primitive node
|
||||||
|
@ -1704,13 +1706,13 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_obj,
|
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
|
||||||
const AnfNodePtr &assigned_node) {
|
const AnfNodePtr &assigned_node) {
|
||||||
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
|
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
|
||||||
AnfNodePtr target_node = ParseExprNode(block, target_obj);
|
AnfNodePtr target_node = ParseExprNode(block, target_object);
|
||||||
MS_EXCEPTION_IF_NULL(target_node);
|
MS_EXCEPTION_IF_NULL(target_node);
|
||||||
|
|
||||||
auto attr_name = target_obj.attr("attr").cast<std::string>();
|
auto attr_name = target_object.attr("attr").cast<std::string>();
|
||||||
std::string var_name = "self." + attr_name;
|
std::string var_name = "self." + attr_name;
|
||||||
|
|
||||||
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
|
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
|
||||||
|
@ -1733,12 +1735,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
|
||||||
block->SetStateAssign(target_node, assigned_node);
|
block->SetStateAssign(target_node, assigned_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_obj,
|
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
|
||||||
const AnfNodePtr &assigned_node) {
|
const AnfNodePtr &assigned_node) {
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
|
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
|
||||||
py::object value_obj = python_adapter::GetPyObjAttr(target_obj, "value");
|
py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
|
||||||
py::object slice_obj = python_adapter::GetPyObjAttr(target_obj, "slice");
|
py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
|
||||||
AnfNodePtr value_node = ParseExprNode(block, value_obj);
|
AnfNodePtr value_node = ParseExprNode(block, value_obj);
|
||||||
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
|
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
|
||||||
CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
|
CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
|
||||||
|
@ -1776,19 +1778,19 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
|
||||||
block->WriteVariable(var_name, setitem_app);
|
block->WriteVariable(var_name, setitem_app);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_obj,
|
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
|
||||||
const AnfNodePtr &value_node) {
|
const AnfNodePtr &value_node) {
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
MS_LOG(DEBUG) << "Process WriteAssignVars";
|
MS_LOG(DEBUG) << "Process WriteAssignVars";
|
||||||
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
|
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
|
||||||
if (ast_type == AST_SUB_TYPE_NAME) {
|
if (ast_type == AST_SUB_TYPE_NAME) {
|
||||||
HandleAssignName(block, target_obj, value_node);
|
HandleAssignName(block, target_object, value_node);
|
||||||
} else if (ast_type == AST_SUB_TYPE_TUPLE) {
|
} else if (ast_type == AST_SUB_TYPE_TUPLE) {
|
||||||
HandleAssignTuple(block, target_obj, value_node);
|
HandleAssignTuple(block, target_object, value_node);
|
||||||
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
|
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
|
||||||
HandleAssignSubscript(block, target_obj, value_node);
|
HandleAssignSubscript(block, target_object, value_node);
|
||||||
} else if (ast_->IsClassMember(target_obj)) {
|
} else if (ast_->IsClassMember(target_object)) {
|
||||||
HandleAssignClassMember(block, target_obj, value_node);
|
HandleAssignClassMember(block, target_object, value_node);
|
||||||
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
|
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
|
||||||
MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
|
MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
|
||||||
<< trace::GetDebugInfo(value_node->debug_info());
|
<< trace::GetDebugInfo(value_node->debug_info());
|
||||||
|
@ -1802,7 +1804,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
||||||
const py::object &value_object) {
|
const py::object &value_object) {
|
||||||
// The fallback feature is enabled in default.
|
// The fallback feature is enabled in default.
|
||||||
// Not support change the flag during the process is alive.
|
// Not support change the flag during the process is alive.
|
||||||
static const auto use_fallback = (support_fallback() != "1" ? false : true);
|
static const auto use_fallback = (support_fallback() == "1");
|
||||||
if (!use_fallback) {
|
if (!use_fallback) {
|
||||||
return value_node;
|
return value_node;
|
||||||
}
|
}
|
||||||
|
@ -1810,9 +1812,10 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
||||||
AnfNodePtr interpreted_node = value_node;
|
AnfNodePtr interpreted_node = value_node;
|
||||||
if (value_node->interpret()) {
|
if (value_node->interpret()) {
|
||||||
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
|
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
|
||||||
MS_LOG(INFO) << "script_text: " << script_text << ", value_node: " << value_node->DebugString(2);
|
|
||||||
// Prepare global parameters.
|
|
||||||
py::dict global_dict = block->global_py_params();
|
py::dict global_dict = block->global_py_params();
|
||||||
|
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
|
||||||
|
<< ", value_node: " << value_node->DebugString(3) << ", global_dict: " << py::str(global_dict);
|
||||||
|
// Prepare global parameters.
|
||||||
ValuePtr globals_converted_value = nullptr;
|
ValuePtr globals_converted_value = nullptr;
|
||||||
if (!ConvertData(global_dict, &globals_converted_value)) {
|
if (!ConvertData(global_dict, &globals_converted_value)) {
|
||||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||||
|
|
|
@ -60,7 +60,7 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
|
||||||
// The fallback feature is enabled in default.
|
// The fallback feature is enabled in default.
|
||||||
// Not support change the flag during the process is alive.
|
// Not support change the flag during the process is alive.
|
||||||
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||||
static const auto use_fallback = (support_fallback != "1" ? false : true);
|
static const auto use_fallback = (support_fallback == "1");
|
||||||
if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
|
if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
|
||||||
return abs_scalar;
|
return abs_scalar;
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,6 +112,16 @@ class PyObjectWrapper : public Named {
|
||||||
py::object obj_;
|
py::object obj_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// InterpretedObject class wrappers interpreted python object.
|
||||||
|
class InterpretedObject : public PyObjectWrapper {
|
||||||
|
public:
|
||||||
|
explicit InterpretedObject(const py::object &obj, const std::string &name = "Interpreted object")
|
||||||
|
: PyObjectWrapper(obj, name) {}
|
||||||
|
~InterpretedObject() override = default;
|
||||||
|
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
|
||||||
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
};
|
||||||
|
|
||||||
// ClassObject class wrappers dataclass
|
// ClassObject class wrappers dataclass
|
||||||
class ClassObject : public PyObjectWrapper {
|
class ClassObject : public PyObjectWrapper {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -207,7 +207,7 @@ class MS_CORE_API AnfNode : public Base {
|
||||||
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
|
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
|
||||||
|
|
||||||
bool interpret() { return interpret_; }
|
bool interpret() { return interpret_; }
|
||||||
void set_interpret(const bool interpret) { interpret_ = interpret; }
|
void set_interpret(const bool &interpret) { interpret_ = interpret; }
|
||||||
|
|
||||||
AnfNodePtr interpreted_node() { return interpreted_node_; }
|
AnfNodePtr interpreted_node() { return interpreted_node_; }
|
||||||
void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
|
void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
|
||||||
|
|
|
@ -176,17 +176,23 @@ class MS_CORE_API ValueDictionary : public Value {
|
||||||
keys.push_back(kv.first);
|
keys.push_back(kv.first);
|
||||||
values.push_back(kv.second);
|
values.push_back(kv.second);
|
||||||
}
|
}
|
||||||
buffer << "(Dict: "
|
buffer << "dict: {keys: (";
|
||||||
<< " keys:(";
|
for (size_t i = 0; i < keys.size(); i++) {
|
||||||
for (const auto &key : keys) {
|
buffer << keys[i];
|
||||||
buffer << key << ", ";
|
if (i != keys.size() - 1) {
|
||||||
|
buffer << ", ";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
buffer << ") values:(";
|
buffer << "), values: (";
|
||||||
for (const auto &value : values) {
|
for (size_t i = 0; i < values.size(); i++) {
|
||||||
|
const auto &value = values[i];
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
buffer << value->ToString() << ", ";
|
buffer << value->ToString();
|
||||||
|
if (i != values.size() - 1) {
|
||||||
|
buffer << ", ";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
buffer << ")";
|
buffer << ")}";
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
|
|
@ -18,8 +18,10 @@ import numpy as np
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor, ms_function, context
|
from mindspore import Tensor, ms_function, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.common._monad as monad
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
@ -38,16 +40,15 @@ def test_increment():
|
||||||
|
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def np_fallback_func():
|
def use_monad(x, y):
|
||||||
array_x = [2, 3, 4, 5]
|
res = P.Mul()(x, y)
|
||||||
np_x = np.array(array_x).astype(np.float32)
|
res = F.depend(res, monad.U)
|
||||||
me_x = Tensor(np_x)
|
return res
|
||||||
me_x = me_x + me_x
|
|
||||||
return me_x
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
|
def test_use_monad():
|
||||||
def test_np_fallback_func():
|
x = Tensor(1.0, mstype.float32)
|
||||||
print(np_fallback_func())
|
y = Tensor(1.0, mstype.float32)
|
||||||
|
print(use_monad(x, y))
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
|
@ -64,3 +65,74 @@ class Net(nn.Cell):
|
||||||
def test_builtins_len():
|
def test_builtins_len():
|
||||||
net = Net()
|
net = Net()
|
||||||
net()
|
net()
|
||||||
|
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def np_fallback_func():
|
||||||
|
array_x = tuple([2, 3, 4, 5])
|
||||||
|
np_x = np.array(array_x).astype(np.float32)
|
||||||
|
me_x = Tensor(np_x)
|
||||||
|
me_x = me_x + me_x
|
||||||
|
return me_x
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
|
def test_np_fallback_func():
|
||||||
|
print(np_fallback_func())
|
||||||
|
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def div_mod_func(x, y):
|
||||||
|
a = divmod(x, y)
|
||||||
|
return Tensor(a)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
|
def test_div_mod_func():
|
||||||
|
print(div_mod_func(8, 3)) # (2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
# NameError: name 'Tensor' is not defined.
|
||||||
|
@ms_function
|
||||||
|
def select_func(cond, x, y):
|
||||||
|
if isinstance(cond, (tuple, list)):
|
||||||
|
output = y
|
||||||
|
elif isinstance(cond, Tensor):
|
||||||
|
output = F.select(cond, x, y)
|
||||||
|
else:
|
||||||
|
output = x
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_select_func():
|
||||||
|
cond = Tensor([True, False])
|
||||||
|
x = Tensor([2, 3], mstype.float32)
|
||||||
|
y = Tensor([1, 2], mstype.float32)
|
||||||
|
print(select_func(cond, x, y))
|
||||||
|
|
||||||
|
|
||||||
|
# Not interpret 'Tensor'.
|
||||||
|
@ms_function
|
||||||
|
def select_func2(cond, x, y):
|
||||||
|
if isinstance(cond, (tuple, list)):
|
||||||
|
output = y
|
||||||
|
if isinstance(cond, Tensor):
|
||||||
|
output = F.select(cond, x, y)
|
||||||
|
else:
|
||||||
|
output = x
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_select_func2():
|
||||||
|
cond = Tensor([True, False])
|
||||||
|
x = Tensor([2, 3], mstype.float32)
|
||||||
|
y = Tensor([1, 2], mstype.float32)
|
||||||
|
print(select_func2(cond, x, y))
|
||||||
|
|
||||||
|
|
||||||
|
# NameError: name 'Tensor' is not defined.
|
||||||
|
@ms_function
|
||||||
|
def slice_func(a, b):
|
||||||
|
a[1:3, ::] = b
|
||||||
|
return a
|
||||||
|
|
||||||
|
def test_slice_func():
|
||||||
|
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
|
||||||
|
b = Tensor([1], dtype=mstype.float32)
|
||||||
|
print(slice_func(a, b))
|
||||||
|
|
Loading…
Reference in New Issue