forked from mindspore-Ecosystem/mindspore
回退 'Pull Request !38232 : code warning clean.'
This commit is contained in:
parent
27566d9a76
commit
b6102af258
|
@ -18,8 +18,13 @@
|
|||
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "utils/hash_map.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
@ -32,9 +37,7 @@ namespace mindspore {
|
|||
namespace parse {
|
||||
namespace {
|
||||
struct PyDataToValueRegister {
|
||||
PyDataToValueRegister() noexcept {
|
||||
python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue);
|
||||
}
|
||||
PyDataToValueRegister() { python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue); }
|
||||
} callback_register;
|
||||
} // namespace
|
||||
using Tensor = mindspore::tensor::Tensor;
|
||||
|
@ -369,10 +372,10 @@ ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
|
|||
|
||||
ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveTypeDef obj_type) {
|
||||
if (obj_type == RESOLVE_TYPE_NUMPY_INT_NUMBER) {
|
||||
MS_LOG(INFO) << "Convert constant numpy int64_t number:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "Convert constant numpy int64_t number:" << (std::string)py::str(obj);
|
||||
return MakeValue(py::cast<int64_t>(obj));
|
||||
} else if (obj_type == RESOLVE_TYPE_NUMPY_FLOAT_NUMBER) {
|
||||
MS_LOG(INFO) << "Convert constant numpy float number::" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "Convert constant numpy float number::" << (std::string)py::str(obj);
|
||||
return MakeValue(py::cast<float>(obj));
|
||||
}
|
||||
MS_LOG(ERROR) << "Convert numpy number type is invalid, obj: " << py::str(obj);
|
||||
|
@ -381,7 +384,7 @@ ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveTypeDef obj_ty
|
|||
|
||||
ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
|
||||
auto obj_type = data_converter::GetObjType(obj);
|
||||
MS_LOG(DEBUG) << "Converting the object(" << (py::cast<std::string>(obj)) << ") detail type: " << obj_type << " ";
|
||||
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
|
||||
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
|
||||
MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
|
||||
std::string desc = py::str(obj);
|
||||
|
@ -625,8 +628,8 @@ const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs()
|
|||
|
||||
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
|
||||
|
||||
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
|
||||
if (object_map_.count(obj_key) != 0) {
|
||||
bool GetObjectValue(const std::string &obj_key, ValuePtr *data) {
|
||||
if (object_map_.count(obj_key)) {
|
||||
*data = object_map_[obj_key];
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -18,10 +18,13 @@
|
|||
|
||||
#include "pipeline/jit/parse/function_block.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include <queue>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
|
@ -34,8 +37,10 @@ namespace mindspore {
|
|||
namespace py = pybind11;
|
||||
|
||||
namespace parse {
|
||||
FunctionBlock::FunctionBlock(const Parser &parser)
|
||||
: func_graph_(std::make_shared<FuncGraph>()), parser_(parser), matured_(false) {}
|
||||
FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
|
||||
func_graph_ = std::make_shared<FuncGraph>();
|
||||
matured_ = false;
|
||||
}
|
||||
|
||||
void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
|
||||
|
||||
|
@ -132,7 +137,7 @@ std::pair<AnfNodePtr, bool> FunctionBlock::FindPredInterpretNode(const std::stri
|
|||
while (!block_queue.empty()) {
|
||||
const auto cur_block = block_queue.front();
|
||||
block_queue.pop();
|
||||
(void)visited_block.insert(cur_block);
|
||||
visited_block.insert(cur_block);
|
||||
auto pred_node = cur_block->ReadLocalVariable(var_name);
|
||||
if (pred_node != nullptr) {
|
||||
has_found = true;
|
||||
|
@ -370,7 +375,7 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
|||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
|
||||
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? py::cast<std::string>(name_space->obj()) : "null namespace")
|
||||
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
|
||||
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
|
||||
ValueNodePtr module_node = NewValueNode(name_space);
|
||||
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
|
||||
|
@ -645,7 +650,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
|
|||
std::vector<AnfNodePtr> states;
|
||||
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
constexpr int recursive_level = 2;
|
||||
for (const auto &node : isolated_nodes_) {
|
||||
for (auto &node : isolated_nodes_) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
|
||||
<< func_graph_->ToString();
|
||||
|
|
|
@ -61,7 +61,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
void Mature();
|
||||
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
|
||||
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
|
||||
void Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args);
|
||||
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
|
||||
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
|
||||
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
|
||||
const AnfNodePtr &false_block_call);
|
||||
|
@ -76,9 +76,9 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
AnfNodePtr MakeResolveSymbol(const std::string &value);
|
||||
AnfNodePtr MakeResolveOperation(const std::string &value);
|
||||
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
|
||||
AnfNodePtr GetResolveNode(const py::tuple &info);
|
||||
AnfNodePtr HandleNamespaceInfo(const py::tuple &info);
|
||||
AnfNodePtr HandleBuiltinNamespaceInfo(const py::tuple &info);
|
||||
AnfNodePtr GetResolveNode(const py::tuple &namespace_info);
|
||||
AnfNodePtr HandleNamespaceInfo(const py::tuple &namespace_info);
|
||||
AnfNodePtr HandleBuiltinNamespaceInfo(const py::tuple &namespace_info);
|
||||
AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
|
||||
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
|
||||
const mindspore::HashMap<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
|
||||
|
@ -111,11 +111,11 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
|
||||
// Call this method to update or add a variable.
|
||||
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
const auto key_iter = local_py_params_keys_.find(name);
|
||||
auto key_iter = local_py_params_keys_.find(name);
|
||||
if (key_iter == local_py_params_keys_.end()) {
|
||||
MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
|
||||
(void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
|
||||
(void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(name, node));
|
||||
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
|
||||
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(name, node));
|
||||
} else {
|
||||
// Find the same position in 'values', and update the node.
|
||||
MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> "
|
||||
|
@ -131,10 +131,10 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
}
|
||||
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
|
||||
const std::string &cur_key_name = iter->first;
|
||||
const auto key_iter = local_py_params_keys_.find(cur_key_name);
|
||||
auto key_iter = local_py_params_keys_.find(cur_key_name);
|
||||
if (key_iter == local_py_params_keys_.end()) {
|
||||
(void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
|
||||
(void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
|
||||
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
|
||||
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
|
||||
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
} else {
|
||||
// The local variable is already in the current block. This means the current block has multiples previous
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include "utils/hash_map.h"
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
@ -67,8 +68,9 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
|
|||
|
||||
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
|
||||
|
||||
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast)
|
||||
: ast_(ast), errcode_(PARSE_SUCCESS), support_fallback_(common::GetEnv("MS_DEV_ENABLE_FALLBACK")) {
|
||||
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
|
||||
support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
||||
errcode_ = PARSE_SUCCESS;
|
||||
BuildMethodMap();
|
||||
}
|
||||
|
||||
|
@ -384,10 +386,10 @@ bool CheckMiddleGraphOutputPyInterpret(
|
|||
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
return false;
|
||||
}
|
||||
bool exist_interpret =
|
||||
|
||||
contain_py_interpret |=
|
||||
std::any_of(middle_graph_output_cnode->inputs().cbegin() + 1, middle_graph_output_cnode->inputs().cend(),
|
||||
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimPyInterpret); });
|
||||
contain_py_interpret |= exist_interpret;
|
||||
if (contain_py_interpret) {
|
||||
return true;
|
||||
}
|
||||
|
@ -716,7 +718,7 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
|
|||
// Call the process function
|
||||
std::string node_name = node_type->node_name();
|
||||
MS_LOG(DEBUG) << "Ast node is " << node_name;
|
||||
if (stmt_method_map_.count(node_name) != 0) {
|
||||
if (stmt_method_map_.count(node_name)) {
|
||||
auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return stmt_block;
|
||||
|
@ -740,7 +742,7 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
|
|||
// Call the process function
|
||||
std::string node_name = node_type->node_name();
|
||||
MS_LOG(DEBUG) << "Ast node is " << node_name;
|
||||
if (expr_method_map_.count(node_name) != 0) {
|
||||
if (expr_method_map_.count(node_name)) {
|
||||
auto expr_node = (this->*expr_method_map_[node_name])(block, node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return expr_node;
|
||||
|
@ -819,7 +821,7 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
|
|||
}
|
||||
|
||||
void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) const {
|
||||
const FunctionBlockPtr &false_block) {
|
||||
MS_EXCEPTION_IF_NULL(true_block);
|
||||
MS_EXCEPTION_IF_NULL(false_block);
|
||||
true_block->AddPrevBlock(pre_block);
|
||||
|
@ -977,11 +979,11 @@ AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
|
|||
MS_LOG(DEBUG) << "Process ast Num";
|
||||
py::object obj = python_adapter::GetPyObjAttr(node, "n");
|
||||
if (py::isinstance<py::int_>(obj)) {
|
||||
MS_LOG(INFO) << "The Num is int64_t:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
|
||||
auto data = py::cast<int64_t>(obj);
|
||||
return NewValueNode(data);
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
MS_LOG(INFO) << "The Num is float:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
|
||||
auto data = py::cast<float>(obj);
|
||||
return NewValueNode(data);
|
||||
} else {
|
||||
|
@ -1002,26 +1004,26 @@ AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &nod
|
|||
MS_LOG(DEBUG) << "Process ast Constant";
|
||||
py::object obj = python_adapter::GetPyObjAttr(node, "value");
|
||||
if (py::isinstance<py::bool_>(obj)) {
|
||||
MS_LOG(INFO) << "The Constant is bool:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
|
||||
return NewValueNode(py::cast<bool>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
MS_LOG(INFO) << "The Constant is int64_t:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
|
||||
return NewValueNode(py::cast<int64_t>(obj));
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
MS_LOG(INFO) << "The Constant is float:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
|
||||
return NewValueNode(py::cast<float>(obj));
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
MS_LOG(INFO) << "The Constant is string:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
|
||||
return NewValueNode(py::cast<std::string>(obj));
|
||||
} else if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(INFO) << "The Constant is none:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
|
||||
return NewValueNode(kNone);
|
||||
} else if (py::isinstance<py::ellipsis>(obj)) {
|
||||
MS_LOG(INFO) << "The Constance is ellipsis:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
|
||||
return NewValueNode(kEllipsis);
|
||||
} else {
|
||||
// no else actually
|
||||
MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << py::cast<std::string>(obj);
|
||||
MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1029,16 +1031,16 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
|
|||
MS_LOG(DEBUG) << "Process ast NameConstant";
|
||||
py::object obj = python_adapter::GetPyObjAttr(node, "value");
|
||||
if (py::isinstance<py::bool_>(obj)) {
|
||||
MS_LOG(INFO) << "The NameConstant is bool:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
|
||||
auto data = py::cast<bool>(obj);
|
||||
return NewValueNode(data);
|
||||
} else if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(INFO) << "The NameConstant is none:" << py::cast<std::string>(obj);
|
||||
MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
|
||||
return NewValueNode(kNone);
|
||||
}
|
||||
// no else actually
|
||||
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
|
||||
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << py::cast<std::string>(obj);
|
||||
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
|
||||
|
@ -1287,7 +1289,7 @@ void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const {
|
||||
AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) {
|
||||
std::string var_name = "self.";
|
||||
std::string attr_name = node.attr("attr").cast<std::string>();
|
||||
(void)var_name.append(attr_name);
|
||||
|
@ -1855,7 +1857,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
|
||||
if (transform_for_half_unroll_call) {
|
||||
// Lift the if branches in for statement.
|
||||
(void)if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
|
||||
if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
|
||||
}
|
||||
|
||||
if (after_block->prev_blocks().empty()) {
|
||||
|
@ -2054,7 +2056,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
after_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
|
||||
(void)header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
|
||||
// Parse loop body statements with loop context.
|
||||
LoopContext loop_context{&loops_, header_block, loop_var_inc};
|
||||
|
@ -2146,7 +2148,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
if (use_fallback) {
|
||||
after_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
(void)header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
|
||||
// Generate the body of the for statement
|
||||
FunctionBlockPtr rolled_body_block =
|
||||
|
@ -2325,7 +2327,7 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
|
|||
list_after_block->func_graph()->set_output(list_param);
|
||||
|
||||
// Run the branches.
|
||||
(void)list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
|
||||
list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
|
||||
|
||||
top_block->Mature();
|
||||
list_header_block->Mature();
|
||||
|
@ -2369,7 +2371,7 @@ AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, con
|
|||
|
||||
// We don't want to create a header graph, where to get and wrap the result of Switch().
|
||||
// So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
|
||||
(void)list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
|
||||
list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
|
||||
// Output is Switch() result, i.e. updated list.
|
||||
auto switch_apply_node = list_body_block->func_graph()->output();
|
||||
auto ifs_new_list = switch_apply_node;
|
||||
|
@ -2444,11 +2446,11 @@ AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::
|
|||
return value_node;
|
||||
}
|
||||
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ,
|
||||
const AnfNodePtr &assigned_node) const {
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
MS_EXCEPTION_IF_NULL(assigned_node);
|
||||
py::str name = python_adapter::GetPyObjAttr(targ, "id");
|
||||
py::str name = python_adapter::GetPyObjAttr(target_object, "id");
|
||||
std::string name_id = name;
|
||||
MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
|
||||
assigned_node->debug_info()->set_name(name_id);
|
||||
|
@ -2465,10 +2467,11 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
|
|||
block->WriteVariable(name_id, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
|
||||
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
py::list items = python_adapter::GetPyObjAttr(targ, "elts");
|
||||
py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
// Use the Primitive replace the operation resolve node (getitem),
|
||||
// because the getitem will eventually be converted to Primitive node
|
||||
|
@ -2481,14 +2484,14 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
|
|||
}
|
||||
}
|
||||
|
||||
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
|
||||
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
|
||||
AnfNodePtr target_node = ParseExprNode(block, targ);
|
||||
target_node = HandleInterpret(block, target_node, targ);
|
||||
AnfNodePtr target_node = ParseExprNode(block, target_object);
|
||||
target_node = HandleInterpret(block, target_node, target_object);
|
||||
MS_EXCEPTION_IF_NULL(target_node);
|
||||
|
||||
auto attr_name = targ.attr("attr").cast<std::string>();
|
||||
auto attr_name = target_object.attr("attr").cast<std::string>();
|
||||
std::string var_name = "self." + attr_name;
|
||||
|
||||
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
|
||||
|
@ -2512,12 +2515,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
|
|||
block->SetStateAssign(target_node, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
|
||||
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
|
||||
py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
|
||||
py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
|
||||
py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
|
||||
py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_obj);
|
||||
value_node = HandleInterpret(block, value_node, value_obj);
|
||||
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
|
||||
|
@ -2586,7 +2589,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
|
|||
}
|
||||
}
|
||||
|
||||
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) const {
|
||||
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (!use_fallback) {
|
||||
|
@ -2612,8 +2615,7 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::
|
|||
}
|
||||
|
||||
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::map<std::string, AnfNodePtr> &local_keys,
|
||||
const FuncGraphPtr &func_graph) const {
|
||||
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// Check global parameters.
|
||||
if (global_dict.contains(script_text)) {
|
||||
|
@ -2833,7 +2835,7 @@ FunctionBlockPtr Parser::ParseAssert(const FunctionBlockPtr &block, const py::ob
|
|||
|
||||
true_block->Jump(after_block, {});
|
||||
false_block = MakeAssertErrorBlock(false_block, node);
|
||||
(void)block->ConditionalJump(bool_node, true_block, false_block);
|
||||
block->ConditionalJump(bool_node, true_block, false_block);
|
||||
|
||||
after_block->Mature();
|
||||
return after_block;
|
||||
|
@ -2868,7 +2870,7 @@ void Parser::RemoveUnnecessaryPhis() {
|
|||
for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
|
||||
auto phi = phis[LongToSize(idx)];
|
||||
auto new_node = FindPhis(removable_phis, phi);
|
||||
(void)manager->Replace(phi, new_node);
|
||||
manager->Replace(phi, new_node);
|
||||
}
|
||||
// Remove the parameter
|
||||
for (FunctionBlockPtr &block : func_block_list_) {
|
||||
|
@ -3116,7 +3118,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|||
func_graph->set_has_kwarg(current_graph->has_kwarg());
|
||||
func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
|
||||
// Copy all default values
|
||||
for (const auto &d : current_graph->parameter_default_value()) {
|
||||
for (auto &d : current_graph->parameter_default_value()) {
|
||||
func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
|
||||
}
|
||||
|
||||
|
|
|
@ -164,17 +164,17 @@ class Parser {
|
|||
// Process a variable name
|
||||
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process NoneType
|
||||
AnfNodePtr ParseNone(const FunctionBlockPtr &, const py::object &);
|
||||
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process Ellipsis
|
||||
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &, const py::object &);
|
||||
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process an integer or float number
|
||||
AnfNodePtr ParseNum(const FunctionBlockPtr &, const py::object &node);
|
||||
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a string variable
|
||||
AnfNodePtr ParseStr(const FunctionBlockPtr &, const py::object &node);
|
||||
AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a Constant
|
||||
AnfNodePtr ParseConstant(const FunctionBlockPtr &, const py::object &node);
|
||||
AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a name
|
||||
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &, const py::object &node);
|
||||
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a function call
|
||||
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process function 'super'
|
||||
|
@ -219,7 +219,7 @@ class Parser {
|
|||
std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node);
|
||||
void ParseStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes);
|
||||
FunctionBlockPtr MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node);
|
||||
AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const;
|
||||
AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
// Transform tail call to parallel call.
|
||||
void TransformParallelCall();
|
||||
|
@ -231,9 +231,9 @@ class Parser {
|
|||
|
||||
// Check if script_text is in global/local params.
|
||||
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) const;
|
||||
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
|
||||
// Set the interpret flag for the node calling the interpret node.
|
||||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) const;
|
||||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
|
||||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
|
||||
// Make interpret node.
|
||||
AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
|
||||
|
@ -250,28 +250,28 @@ class Parser {
|
|||
const py::object &value_object);
|
||||
|
||||
// Generate argument nodes for ast function node
|
||||
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
|
||||
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// Generate argument default value for ast function node
|
||||
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
|
||||
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// Parse ast function node
|
||||
FunctionBlockPtr ParseDefFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
|
||||
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
// Parse lambda function node
|
||||
FunctionBlockPtr ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
|
||||
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
// Parse ast statements
|
||||
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &nodes);
|
||||
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
|
||||
// Parse one ast statement node
|
||||
FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Parse an ast expression node
|
||||
AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
void MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) const;
|
||||
void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
|
||||
const FunctionBlockPtr &falseBlock);
|
||||
void RemoveUnnecessaryPhis();
|
||||
// Write a new var
|
||||
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &value_node);
|
||||
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
|
||||
|
||||
// Assign value to single variable name
|
||||
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) const;
|
||||
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// Assign value to tuple
|
||||
void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
@ -289,12 +289,18 @@ class Parser {
|
|||
// Process a bool operation value list
|
||||
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
|
||||
|
||||
CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
|
||||
const AnfNodePtr &op_iter);
|
||||
|
||||
CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
|
||||
const AnfNodePtr &op_hasnext);
|
||||
|
||||
FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);
|
||||
|
||||
void ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context);
|
||||
|
||||
void ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context);
|
||||
AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
|
||||
AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
|
||||
const ArgsContext &args_context) const;
|
||||
ScopePtr GetScopeForParseFunction();
|
||||
// Check the value is subscript is reference type
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -69,7 +69,7 @@ void DoExecNonInputGraph(const std::string &phase) {
|
|||
}
|
||||
}
|
||||
|
||||
void CreateSessionAndGraphRunner(bool is_training = true) {
|
||||
Status CreateSessionAndGraphRunner(bool is_training = true) {
|
||||
std::shared_ptr<ge::Session> sess = transform::GetGeSession();
|
||||
if (sess == nullptr) {
|
||||
transform::SessionOptions options;
|
||||
|
@ -98,6 +98,7 @@ void CreateSessionAndGraphRunner(bool is_training = true) {
|
|||
options.sess_ptr = sess;
|
||||
auto graph_runner = transform::NewGraphRunner(options);
|
||||
transform::SetGraphRunner(graph_runner);
|
||||
return Status::SUCCESS;
|
||||
}
|
||||
|
||||
bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
|
@ -139,7 +140,10 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
|
|||
return false;
|
||||
}
|
||||
|
||||
CreateSessionAndGraphRunner(training);
|
||||
if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
|
||||
DoExecNonInputGraph(phase);
|
||||
|
@ -273,7 +277,10 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
|
|||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
}
|
||||
|
||||
CreateSessionAndGraphRunner(training);
|
||||
if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return anf_graph;
|
||||
}
|
||||
|
@ -318,7 +325,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
|
|||
for (size_t i = 0; i < size; i++) {
|
||||
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
|
||||
}
|
||||
return tp;
|
||||
return std::move(tp);
|
||||
}
|
||||
|
||||
py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) {
|
||||
|
@ -349,7 +356,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
|
|||
for (size_t i = 1; i < size; i++) {
|
||||
tp[i - 1] = StructureOutput(input_list[i], data, count);
|
||||
}
|
||||
return tp;
|
||||
return std::move(tp);
|
||||
}
|
||||
if (output_c->IsApply(prim::kPrimDepend)) {
|
||||
return StructureOutput(output_c->input(1), data, count);
|
||||
|
@ -365,7 +372,7 @@ void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me
|
|||
TypeId me_type = cnode_data->BuildType()->type_id();
|
||||
if (me_type == kObjectTypeTensorType) {
|
||||
me_type = dyn_cast<TensorType>(cnode_data->BuildType())->element()->type_id();
|
||||
(void)me_types->emplace_back(me_type);
|
||||
me_types->emplace_back(me_type);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -115,7 +115,7 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
|
|||
auto graph_input_index = func_graph->get_inputs()[index];
|
||||
auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
|
||||
MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
|
||||
(void)abstract_list.emplace_back(virtual_dataset_abstract);
|
||||
abstract_list.emplace_back(virtual_dataset_abstract);
|
||||
virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
|
||||
}
|
||||
}
|
||||
|
@ -135,17 +135,17 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
for (auto &anf_param : root->parameters()) {
|
||||
auto param = anf_param->cast<ParameterPtr>();
|
||||
if (!param->has_default()) {
|
||||
(void)input_parameters.insert(anf_param);
|
||||
input_parameters.insert(anf_param);
|
||||
}
|
||||
}
|
||||
for (const auto &input_parameter : input_parameters) {
|
||||
for (auto input_parameter : input_parameters) {
|
||||
auto node_users_map = root->manager()->node_users();
|
||||
auto node_users = node_users_map[input_parameter];
|
||||
for (auto node_user : node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
|
||||
(IsValueNode<FuncGraph>(cnode->inputs()[0]) && !root->has_flag(parallel::kTraining))) {
|
||||
(void)graph_sets.insert(cnode->func_graph());
|
||||
graph_sets.insert(cnode->func_graph());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -166,7 +166,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
} else {
|
||||
fun_graph = node->func_graph();
|
||||
}
|
||||
(void)graph_sets.insert(fun_graph);
|
||||
graph_sets.insert(fun_graph);
|
||||
}
|
||||
}
|
||||
return graph_sets;
|
||||
|
@ -175,7 +175,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(root, all_nodes);
|
||||
for (const auto &forward_graph : forward_graph_set) {
|
||||
for (auto forward_graph : forward_graph_set) {
|
||||
FuncGraphManagerPtr manager = forward_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
|
||||
|
@ -187,7 +187,7 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
|
|||
continue;
|
||||
}
|
||||
auto node_users = node_user_map[graph_inputs[index]];
|
||||
for (const auto &node_user : node_users) {
|
||||
for (auto node_user : node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
|
||||
if (!IsValueNode<Primitive>(cnode->inputs()[0]) && !IsValueNode<FuncGraph>(cnode->inputs()[0])) {
|
||||
|
@ -202,12 +202,12 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
|
|||
if (!is_match) {
|
||||
continue;
|
||||
}
|
||||
size_t node_input_index = LongToSize(node_input_iter - graph_inputs.begin());
|
||||
size_t node_input_index = node_input_iter - graph_inputs.begin();
|
||||
if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
|
||||
parameter_index_map[node_input_index] =
|
||||
CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
|
||||
}
|
||||
manager->SetEdge(cnode, SizeToInt(input_index), parameter_index_map[node_input_index]);
|
||||
manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
|
||||
manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,8 +27,7 @@ namespace pipeline {
|
|||
using HashCache = mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>>;
|
||||
using HashValue = mindspore::HashMap<AnfNodePtr, std::size_t>;
|
||||
|
||||
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
|
||||
HashValue *const hash_value);
|
||||
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -373,7 +373,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
// Gets branch graph from a switch cnode at given input index.
|
||||
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) const {
|
||||
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
|
||||
}
|
||||
|
@ -406,7 +406,7 @@ class SideEffectFinder {
|
|||
|
||||
// Add monad parameter to switch branch graphs.
|
||||
void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
|
||||
const AbstractBasePtr &abs) const {
|
||||
const AbstractBasePtr &abs) {
|
||||
for (auto &branch : branches) {
|
||||
(void)AddMonadParameter(branch, name, abs);
|
||||
}
|
||||
|
@ -472,7 +472,7 @@ class SideEffectFinder {
|
|||
}
|
||||
}
|
||||
|
||||
void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) const {
|
||||
void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) {
|
||||
for (size_t i = caller->size() - 1; i > 0; --i) {
|
||||
auto &input = caller->input(i);
|
||||
if (HasAbstractUMonad(input)) {
|
||||
|
@ -531,7 +531,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
// Get graphs from a tuple of funcs make node for switch_layer.
|
||||
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
|
||||
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto &inputs = make_tuple->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
|
@ -794,7 +794,7 @@ class SideEffectFinder {
|
|||
return {EffectInfo::kDetected, false, false, false};
|
||||
}
|
||||
|
||||
int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr ¶) const {
|
||||
int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr ¶) {
|
||||
int parameter_index = 0;
|
||||
for (auto ¶meter : func_graph->parameters()) {
|
||||
if (para == parameter) {
|
||||
|
@ -1075,7 +1075,7 @@ class SideEffectFinder {
|
|||
}
|
||||
}
|
||||
|
||||
void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) const {
|
||||
void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(monad);
|
||||
auto monad_abs = monad->ToAbstract();
|
||||
|
@ -1211,7 +1211,7 @@ class AutoMonadConverter {
|
|||
}
|
||||
|
||||
// Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
|
||||
bool IsNoEliminateNode(const CNodePtr &cnode) const {
|
||||
bool IsNoEliminateNode(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr || cnode->size() == 0) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,9 +18,11 @@
|
|||
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_AUTO_MONAD_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "base/effect_info.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
|
||||
|
|
|
@ -86,13 +86,13 @@ void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
|
|||
auto effect_info = GetPrimEffectInfo(primitive);
|
||||
if (effect_info.memory || effect_info.io) {
|
||||
MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
|
||||
(void)side_effect_nodes->emplace_back(node);
|
||||
side_effect_nodes->emplace_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BaseFuncGraphEvaluator::CheckSideEffectNodes(const AbstractBasePtr &abstract,
|
||||
const std::vector<AnfNodePtr> &side_effect_nodes) const {
|
||||
const std::vector<AnfNodePtr> &side_effect_nodes) {
|
||||
if (!side_effect_nodes.empty()) {
|
||||
ValuePtr val = abstract->BuildValue();
|
||||
if (!val->isa<AnyValue>()) {
|
||||
|
|
|
@ -116,7 +116,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
|
|||
: PrimEvaluator(id), eval_cache_(AnalysisResultCacheMgr::GetInstance().prim_eval_cache()) {}
|
||||
~TrivialPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) final;
|
||||
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list) = 0;
|
||||
|
||||
protected:
|
||||
|
@ -140,7 +141,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
|
|||
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||
~SymbolicPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
|
||||
EvalResultPtr Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) final;
|
||||
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
|
||||
};
|
||||
|
||||
|
@ -222,7 +224,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
}
|
||||
|
||||
void CollectSideEffectNodes(const AnfNodePtr &node, std::vector<AnfNodePtr> *side_effect_nodes);
|
||||
void CheckSideEffectNodes(const AbstractBasePtr &abstract, const std::vector<AnfNodePtr> &side_effect_nodes) const;
|
||||
void CheckSideEffectNodes(const AbstractBasePtr &abstract, const std::vector<AnfNodePtr> &side_effect_nodes);
|
||||
|
||||
protected:
|
||||
AnalysisContextPtr parent_context_;
|
||||
|
@ -357,7 +359,8 @@ class JEvaluator : public Evaluator {
|
|||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
|
||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
@ -388,7 +391,8 @@ class TaylorEvaluator : public Evaluator {
|
|||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
|
||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
@ -421,7 +425,8 @@ class ShardEvaluator : public Evaluator {
|
|||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
|
@ -458,7 +463,8 @@ class VmapEvaluator : public Evaluator {
|
|||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
|
||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -102,7 +102,7 @@ class OrderEnforcer {
|
|||
}
|
||||
}
|
||||
|
||||
bool HasLoadInput(const CNodePtr &cnode) const {
|
||||
bool HasLoadInput(const CNodePtr &cnode) {
|
||||
auto &inputs = cnode->inputs();
|
||||
return std::any_of(inputs.begin() + 1, inputs.end(),
|
||||
[](const AnfNodePtr &input) { return IsPrimitiveCNode(input, prim::kPrimLoad); });
|
||||
|
@ -192,7 +192,7 @@ class OrderEnforcer {
|
|||
AddInputEdges(update_state_cnode, maketuple_users);
|
||||
}
|
||||
|
||||
bool IsRef(const AnfNodePtr &node) const {
|
||||
bool IsRef(const AnfNodePtr &node) {
|
||||
auto &abs = node->abstract();
|
||||
return abs != nullptr && abs->isa<abstract::AbstractRefTensor>();
|
||||
}
|
||||
|
@ -201,7 +201,7 @@ class OrderEnforcer {
|
|||
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
|
||||
}
|
||||
|
||||
bool IsSpecialParallelPrimitive(const AnfNodePtr &node) const {
|
||||
bool IsSpecialParallelPrimitive(const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
|
||||
|
@ -247,7 +247,7 @@ class OrderEnforcer {
|
|||
}
|
||||
}
|
||||
|
||||
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) const {
|
||||
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
|
||||
MS_EXCEPTION_IF_NULL(update_state);
|
||||
const size_t attach_index = 2;
|
||||
const size_t input_size = update_state->inputs().size();
|
||||
|
@ -410,7 +410,7 @@ class OrderEnforcer {
|
|||
return ref_key->value();
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) const {
|
||||
std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) {
|
||||
std::vector<CNodePtr> need_insert_loads;
|
||||
for (auto &node : check_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
|
@ -424,7 +424,7 @@ class OrderEnforcer {
|
|||
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map3,
|
||||
const std::set<CNodePtr> &call_lodes) const {
|
||||
const std::set<CNodePtr> &call_lodes) {
|
||||
std::vector<CNodePtr> need_insert_loads;
|
||||
for (auto &refkey_load : loads_map1) {
|
||||
auto &loads = refkey_load.second;
|
||||
|
@ -456,7 +456,7 @@ class OrderEnforcer {
|
|||
return need_insert_loads;
|
||||
}
|
||||
|
||||
bool CheckLoadInput(const AnfNodePtr &input) const {
|
||||
bool CheckLoadInput(const AnfNodePtr &input) {
|
||||
return IsPrimitiveCNode(input, prim::kPrimCall) || IsPrimitiveCNode(input, prim::kPrimPartial) ||
|
||||
(input->isa<CNode>() && (IsValueNode<FuncGraph>(input->cast<CNodePtr>()->input(0)) ||
|
||||
IsPrimitiveCNode(input->cast<CNodePtr>()->input(0), prim::kPrimSwitch) ||
|
||||
|
@ -549,7 +549,7 @@ void OrderEnforce(const FuncGraphPtr &func_graph) {
|
|||
OrderEnforcer enforcer(func_graph);
|
||||
enforcer.Run();
|
||||
auto fg_used_total = func_graph->func_graphs_used_total();
|
||||
for (const auto &fg : fg_used_total) {
|
||||
for (auto &fg : fg_used_total) {
|
||||
OrderEnforcer fg_enforcer(fg);
|
||||
fg_enforcer.Run();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue