forked from mindspore-Ecosystem/mindspore
!38449 code warning clean.
Merge pull request !38449 from Margaret_wangrui/codecheck
commit 5da23cf328
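The hunks below apply a small set of recurring static-analysis ("code warning") fixes rather than functional changes: discarded insert/emplace/Replace return values are cast to (void), query-only member functions are marked const, range-for loops iterate by const reference, constructors move assignments into member-initializer lists, and count() results are compared against 0 explicitly. The standalone sketch below only illustrates those patterns in isolation; it is hypothetical code, not from this repository, and the names Registry, Add, Contains and Dump are made up, not MindSpore APIs.

// Minimal, self-contained illustration (assumed names, not MindSpore code) of the
// lint patterns applied throughout this diff.
#include <iostream>
#include <map>
#include <string>

class Registry {
 public:
  // Member-initializer list instead of assignments in the constructor body.
  Registry() : matured_(false) {}

  // Explicitly discard the emplace() return value to silence
  // "return value not used" warnings.
  void Add(const std::string &key, int value) { (void)table_.emplace(key, value); }

  // Query-only methods are marked const; count() is compared with 0
  // instead of being used as an implicit bool.
  bool Contains(const std::string &key) const { return table_.count(key) != 0; }

  // Iterate by const reference to avoid copies and accidental mutation.
  void Dump() const {
    for (const auto &entry : table_) {
      std::cout << entry.first << " = " << entry.second << "\n";
    }
  }

 private:
  std::map<std::string, int> table_;
  bool matured_;
};

int main() {
  Registry reg;
  reg.Add("alpha", 1);
  std::cout << std::boolalpha << reg.Contains("alpha") << "\n";  // prints: true
  reg.Dump();
  return 0;
}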
@@ -18,13 +18,8 @@
#include "pipeline/jit/parse/data_converter.h"
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/pipeline.h"
#include "include/common/utils/python_adapter.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "ir/func_graph_cloner.h"

@@ -37,7 +32,9 @@ namespace mindspore {
namespace parse {
namespace {
struct PyDataToValueRegister {
PyDataToValueRegister() { python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue); }
PyDataToValueRegister() noexcept {
python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue);
}
} callback_register;
} // namespace
using Tensor = mindspore::tensor::Tensor;

@@ -628,8 +625,8 @@ const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs()
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }

bool GetObjectValue(const std::string &obj_key, ValuePtr *data) {
if (object_map_.count(obj_key)) {
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
if (object_map_.count(obj_key) != 0) {
*data = object_map_[obj_key];
return true;
}
@@ -18,13 +18,10 @@
#include "pipeline/jit/parse/function_block.h"

#include <string>
#include <memory>
#include <algorithm>
#include <unordered_set>
#include <queue>

#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/data_converter.h"

@@ -37,10 +34,8 @@ namespace mindspore {
namespace py = pybind11;

namespace parse {
FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
func_graph_ = std::make_shared<FuncGraph>();
matured_ = false;
}
FunctionBlock::FunctionBlock(const Parser &parser)
: func_graph_(std::make_shared<FuncGraph>()), parser_(parser), matured_(false) {}

void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }

@@ -137,7 +132,7 @@ std::pair<AnfNodePtr, bool> FunctionBlock::FindPredInterpretNode(const std::stri
while (!block_queue.empty()) {
const auto cur_block = block_queue.front();
block_queue.pop();
visited_block.insert(cur_block);
(void)visited_block.insert(cur_block);
auto pred_node = cur_block->ReadLocalVariable(var_name);
if (pred_node != nullptr) {
has_found = true;

@@ -650,7 +645,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
std::vector<AnfNodePtr> states;
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
constexpr int recursive_level = 2;
for (auto &node : isolated_nodes_) {
for (const auto &node : isolated_nodes_) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
<< func_graph_->ToString();
@@ -61,7 +61,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
void Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
const AnfNodePtr &false_block_call);

@@ -76,9 +76,9 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
AnfNodePtr MakeResolveSymbol(const std::string &value);
AnfNodePtr MakeResolveOperation(const std::string &value);
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
AnfNodePtr GetResolveNode(const py::tuple &namespace_info);
AnfNodePtr HandleNamespaceInfo(const py::tuple &namespace_info);
AnfNodePtr HandleBuiltinNamespaceInfo(const py::tuple &namespace_info);
AnfNodePtr GetResolveNode(const py::tuple &info);
AnfNodePtr HandleNamespaceInfo(const py::tuple &info);
AnfNodePtr HandleBuiltinNamespaceInfo(const py::tuple &info);
AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
const mindspore::HashMap<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }

@@ -111,11 +111,11 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// Call this method to update or add a variable.
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
auto key_iter = local_py_params_keys_.find(name);
const auto key_iter = local_py_params_keys_.find(name);
if (key_iter == local_py_params_keys_.end()) {
MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(name, node));
(void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
(void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(name, node));
} else {
// Find the same position in 'values', and update the node.
MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> "

@@ -131,10 +131,10 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
}
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
const std::string &cur_key_name = iter->first;
auto key_iter = local_py_params_keys_.find(cur_key_name);
const auto key_iter = local_py_params_keys_.find(cur_key_name);
if (key_iter == local_py_params_keys_.end()) {
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
(void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
(void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
} else {
// The local variable is already in the current block. This means the current block has multiples previous
@@ -26,7 +26,6 @@
#include <unordered_set>
#include <unordered_map>
#include "utils/hash_map.h"
#include "pybind_api/pybind_patch.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"

@@ -68,9 +67,8 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();

Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
errcode_ = PARSE_SUCCESS;
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast)
: ast_(ast), errcode_(PARSE_SUCCESS), support_fallback_(common::GetEnv("MS_DEV_ENABLE_FALLBACK")) {
BuildMethodMap();
}

@@ -386,10 +384,10 @@ bool CheckMiddleGraphOutputPyInterpret(
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
return false;
}

contain_py_interpret |=
bool exist_interpret =
std::any_of(middle_graph_output_cnode->inputs().cbegin() + 1, middle_graph_output_cnode->inputs().cend(),
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimPyInterpret); });
contain_py_interpret |= exist_interpret;
if (contain_py_interpret) {
return true;
}

@@ -718,7 +716,7 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
// Call the process function
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
if (stmt_method_map_.count(node_name)) {
if (stmt_method_map_.count(node_name) != 0) {
auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
TraceManager::ClearParseOrResolveDebugInfo();
return stmt_block;

@@ -742,7 +740,7 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
// Call the process function
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
if (expr_method_map_.count(node_name)) {
if (expr_method_map_.count(node_name) != 0) {
auto expr_node = (this->*expr_method_map_[node_name])(block, node);
TraceManager::ClearParseOrResolveDebugInfo();
return expr_node;

@@ -821,7 +819,7 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
}

void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
const FunctionBlockPtr &false_block) const {
MS_EXCEPTION_IF_NULL(true_block);
MS_EXCEPTION_IF_NULL(false_block);
true_block->AddPrevBlock(pre_block);

@@ -1289,7 +1287,7 @@ void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
}
}

AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) {
AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const {
std::string var_name = "self.";
std::string attr_name = node.attr("attr").cast<std::string>();
(void)var_name.append(attr_name);

@@ -1857,7 +1855,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
if (transform_for_half_unroll_call) {
// Lift the if branches in for statement.
if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
(void)if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
}

if (after_block->prev_blocks().empty()) {
@@ -2056,7 +2054,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
after_block->UpdateGlobalPyParam(block->global_py_params());
}

header_block->ConditionalJump(cond_node, body_block, after_block);
(void)header_block->ConditionalJump(cond_node, body_block, after_block);

// Parse loop body statements with loop context.
LoopContext loop_context{&loops_, header_block, loop_var_inc};

@@ -2148,7 +2146,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
if (use_fallback) {
after_block->UpdateGlobalPyParam(block->global_py_params());
}
header_block->ConditionalJump(cond_node, body_block, after_block);
(void)header_block->ConditionalJump(cond_node, body_block, after_block);

// Generate the body of the for statement
FunctionBlockPtr rolled_body_block =

@@ -2333,7 +2331,7 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
list_after_block->func_graph()->set_output(list_param);

// Run the branches.
list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
(void)list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);

top_block->Mature();
list_header_block->Mature();

@@ -2377,7 +2375,7 @@ AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, con
// We don't want to create a header graph, where to get and wrap the result of Switch().
// So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
(void)list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
// Output is Switch() result, i.e. updated list.
auto switch_apply_node = list_body_block->func_graph()->output();
auto ifs_new_list = switch_apply_node;

@@ -2452,11 +2450,11 @@ AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::
return value_node;
}

void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) {
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ,
const AnfNodePtr &assigned_node) const {
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(assigned_node);
py::str name = python_adapter::GetPyObjAttr(target_object, "id");
py::str name = python_adapter::GetPyObjAttr(targ, "id");
std::string name_id = name;
MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
assigned_node->debug_info()->set_name(name_id);

@@ -2473,11 +2471,10 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
block->WriteVariable(name_id, assigned_node);
}

void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) {
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
py::list items = python_adapter::GetPyObjAttr(targ, "elts");
for (size_t i = 0; i < items.size(); i++) {
// Use the Primitive replace the operation resolve node (getitem),
// because the getitem will eventually be converted to Primitive node
@@ -2490,14 +2487,14 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
}
}

void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
const AnfNodePtr &assigned_node) {
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
AnfNodePtr target_node = ParseExprNode(block, target_object);
target_node = HandleInterpret(block, target_node, target_object);
AnfNodePtr target_node = ParseExprNode(block, targ);
target_node = HandleInterpret(block, target_node, targ);
MS_EXCEPTION_IF_NULL(target_node);

auto attr_name = target_object.attr("attr").cast<std::string>();
auto attr_name = targ.attr("attr").cast<std::string>();
std::string var_name = "self." + attr_name;

// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type

@@ -2521,12 +2518,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
block->SetStateAssign(target_node, assigned_node);
}

void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
AnfNodePtr value_node = ParseExprNode(block, value_obj);
value_node = HandleInterpret(block, value_node, value_obj);
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);

@@ -2595,7 +2592,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}

void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) const {
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (!use_fallback) {

@@ -2621,7 +2618,8 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::
}

bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
const std::map<std::string, AnfNodePtr> &local_keys,
const FuncGraphPtr &func_graph) const {
MS_EXCEPTION_IF_NULL(func_graph);
// Check global parameters.
if (global_dict.contains(script_text)) {

@@ -2841,7 +2839,7 @@ FunctionBlockPtr Parser::ParseAssert(const FunctionBlockPtr &block, const py::ob
true_block->Jump(after_block, {});
false_block = MakeAssertErrorBlock(false_block, node);
block->ConditionalJump(bool_node, true_block, false_block);
(void)block->ConditionalJump(bool_node, true_block, false_block);

after_block->Mature();
return after_block;

@@ -2876,7 +2874,7 @@ void Parser::RemoveUnnecessaryPhis() {
for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
auto phi = phis[LongToSize(idx)];
auto new_node = FindPhis(removable_phis, phi);
manager->Replace(phi, new_node);
(void)manager->Replace(phi, new_node);
}
// Remove the parameter
for (FunctionBlockPtr &block : func_block_list_) {

@@ -3124,7 +3122,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
func_graph->set_has_kwarg(current_graph->has_kwarg());
func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
// Copy all default values
for (auto &d : current_graph->parameter_default_value()) {
for (const auto &d : current_graph->parameter_default_value()) {
func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
}
@@ -164,17 +164,17 @@ class Parser {
// Process a variable name
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
// Process NoneType
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseNone(const FunctionBlockPtr &, const py::object &);
// Process Ellipsis
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &, const py::object &);
// Process an integer or float number
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseNum(const FunctionBlockPtr &, const py::object &node);
// Process a string variable
AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseStr(const FunctionBlockPtr &, const py::object &node);
// Process a Constant
AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseConstant(const FunctionBlockPtr &, const py::object &node);
// Process a name
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &, const py::object &node);
// Process a function call
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
// Process function 'super'

@@ -219,7 +219,7 @@ class Parser {
std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node);
void ParseStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes);
FunctionBlockPtr MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const;

// Transform tail call to parallel call.
void TransformParallelCall();

@@ -231,9 +231,9 @@ class Parser {
// Check if script_text is in global/local params.
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) const;
// Set the interpret flag for the node calling the interpret node.
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) const;
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
// Make interpret node.
AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);

@@ -250,28 +250,28 @@ class Parser {
const py::object &value_object);

// Generate argument nodes for ast function node
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
// Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
// Parse ast function node
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
FunctionBlockPtr ParseDefFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
// Parse lambda function node
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
FunctionBlockPtr ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
// Parse ast statements
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &nodes);
// Parse one ast statement node
FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
// Parse an ast expression node
AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);

void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
const FunctionBlockPtr &falseBlock);
void MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) const;
void RemoveUnnecessaryPhis();
// Write a new var
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &value_node);

// Assign value to single variable name
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) const;

// Assign value to tuple
void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
@@ -289,18 +289,12 @@ class Parser {
// Process a bool operation value list
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);

CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
const AnfNodePtr &op_iter);

CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
const AnfNodePtr &op_hasnext);

FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);

void ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context);

void ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context);
AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
const ArgsContext &args_context) const;
ScopePtr GetScopeForParseFunction();
// Check the value is subscript is reference type
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -69,7 +69,7 @@ void DoExecNonInputGraph(const std::string &phase) {
}
}

Status CreateSessionAndGraphRunner(bool is_training = true) {
void CreateSessionAndGraphRunner(bool is_training = true) {
std::shared_ptr<ge::Session> sess = transform::GetGeSession();
if (sess == nullptr) {
transform::SessionOptions options;

@@ -98,7 +98,6 @@ Status CreateSessionAndGraphRunner(bool is_training = true) {
options.sess_ptr = sess;
auto graph_runner = transform::NewGraphRunner(options);
transform::SetGraphRunner(graph_runner);
return Status::SUCCESS;
}

bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size,

@@ -140,10 +139,7 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
return false;
}

if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
return false;
}
CreateSessionAndGraphRunner(training);

MS_LOG(INFO) << "DoExecNonInputGraph:" << phase;
DoExecNonInputGraph(phase);

@@ -277,10 +273,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
(void)setenv("GE_TRAIN", "0", 1);
}

if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
return nullptr;
}
CreateSessionAndGraphRunner(training);

return anf_graph;
}

@@ -325,7 +318,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
for (size_t i = 0; i < size; i++) {
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
}
return std::move(tp);
return tp;
}

py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) {

@@ -356,7 +349,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
for (size_t i = 1; i < size; i++) {
tp[i - 1] = StructureOutput(input_list[i], data, count);
}
return std::move(tp);
return tp;
}
if (output_c->IsApply(prim::kPrimDepend)) {
return StructureOutput(output_c->input(1), data, count);

@@ -372,7 +365,7 @@ void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me
TypeId me_type = cnode_data->BuildType()->type_id();
if (me_type == kObjectTypeTensorType) {
me_type = dyn_cast<TensorType>(cnode_data->BuildType())->element()->type_id();
me_types->emplace_back(me_type);
(void)me_types->emplace_back(me_type);
}
return;
}
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -115,7 +115,7 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
auto graph_input_index = func_graph->get_inputs()[index];
auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
abstract_list.emplace_back(virtual_dataset_abstract);
(void)abstract_list.emplace_back(virtual_dataset_abstract);
virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
}
}

@@ -135,17 +135,17 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
for (auto &anf_param : root->parameters()) {
auto param = anf_param->cast<ParameterPtr>();
if (!param->has_default()) {
input_parameters.insert(anf_param);
(void)input_parameters.insert(anf_param);
}
}
for (auto input_parameter : input_parameters) {
for (const auto &input_parameter : input_parameters) {
auto node_users_map = root->manager()->node_users();
auto node_users = node_users_map[input_parameter];
for (auto node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
(IsValueNode<FuncGraph>(cnode->inputs()[0]) && !root->has_flag(parallel::kTraining))) {
graph_sets.insert(cnode->func_graph());
(void)graph_sets.insert(cnode->func_graph());
}
}
}

@@ -166,7 +166,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
} else {
fun_graph = node->func_graph();
}
graph_sets.insert(fun_graph);
(void)graph_sets.insert(fun_graph);
}
}
return graph_sets;

@@ -175,7 +175,7 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
MS_EXCEPTION_IF_NULL(root);
std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(root, all_nodes);
for (auto forward_graph : forward_graph_set) {
for (const auto &forward_graph : forward_graph_set) {
FuncGraphManagerPtr manager = forward_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();

@@ -187,7 +187,7 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
continue;
}
auto node_users = node_user_map[graph_inputs[index]];
for (auto node_user : node_users) {
for (const auto &node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
if (!IsValueNode<Primitive>(cnode->inputs()[0]) && !IsValueNode<FuncGraph>(cnode->inputs()[0])) {

@@ -202,12 +202,12 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
if (!is_match) {
continue;
}
size_t node_input_index = node_input_iter - graph_inputs.begin();
size_t node_input_index = LongToSize(node_input_iter - graph_inputs.begin());
if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
parameter_index_map[node_input_index] =
CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
}
manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
manager->SetEdge(cnode, SizeToInt(input_index), parameter_index_map[node_input_index]);
manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
}
}
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,7 +27,8 @@ namespace pipeline {
|
|||
using HashCache = mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>>;
|
||||
using HashValue = mindspore::HashMap<AnfNodePtr, std::size_t>;
|
||||
|
||||
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
|
||||
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
|
||||
HashValue *const hash_value);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@@ -373,7 +373,7 @@ class SideEffectFinder {
}

// Gets branch graph from a switch cnode at given input index.
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) const {
MS_EXCEPTION_IF_NULL(cnode);
return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
}

@@ -406,7 +406,7 @@ class SideEffectFinder {
// Add monad parameter to switch branch graphs.
void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
const AbstractBasePtr &abs) {
const AbstractBasePtr &abs) const {
for (auto &branch : branches) {
(void)AddMonadParameter(branch, name, abs);
}

@@ -472,7 +472,7 @@ class SideEffectFinder {
}
}

void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) {
void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) const {
for (size_t i = caller->size() - 1; i > 0; --i) {
auto &input = caller->input(i);
if (HasAbstractUMonad(input)) {

@@ -531,7 +531,7 @@ class SideEffectFinder {
}

// Get graphs from a tuple of funcs make node for switch_layer.
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
MS_EXCEPTION_IF_NULL(make_tuple);
auto &inputs = make_tuple->inputs();
if (inputs.size() <= 1) {

@@ -794,7 +794,7 @@ class SideEffectFinder {
return {EffectInfo::kDetected, false, false, false};
}

int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr &para) {
int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr &para) const {
int parameter_index = 0;
for (auto &parameter : func_graph->parameters()) {
if (para == parameter) {

@@ -1075,7 +1075,7 @@ class SideEffectFinder {
}
}

void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) const {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(monad);
auto monad_abs = monad->ToAbstract();

@@ -1211,7 +1211,7 @@ class AutoMonadConverter {
}

// Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
bool IsNoEliminateNode(const CNodePtr &cnode) {
bool IsNoEliminateNode(const CNodePtr &cnode) const {
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -18,11 +18,9 @@
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_AUTO_MONAD_H_

#include <string>
#include <memory>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "base/effect_info.h"

namespace mindspore::pipeline {
@@ -86,13 +86,13 @@ void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
auto effect_info = GetPrimEffectInfo(primitive);
if (effect_info.memory || effect_info.io) {
MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
side_effect_nodes->emplace_back(node);
(void)side_effect_nodes->emplace_back(node);
}
}
}

void BaseFuncGraphEvaluator::CheckSideEffectNodes(const AbstractBasePtr &abstract,
const std::vector<AnfNodePtr> &side_effect_nodes) {
const std::vector<AnfNodePtr> &side_effect_nodes) const {
if (!side_effect_nodes.empty()) {
ValuePtr val = abstract->BuildValue();
if (!val->isa<AnyValue>()) {
@@ -116,8 +116,7 @@ class TrivialPrimEvaluator : public PrimEvaluator {
: PrimEvaluator(id), eval_cache_(AnalysisResultCacheMgr::GetInstance().prim_eval_cache()) {}
~TrivialPrimEvaluator() override = default;
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list) = 0;

protected:

@@ -141,8 +140,7 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~SymbolicPrimEvaluator() override = default;
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
};

@@ -224,7 +222,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
}

void CollectSideEffectNodes(const AnfNodePtr &node, std::vector<AnfNodePtr> *side_effect_nodes);
void CheckSideEffectNodes(const AbstractBasePtr &abstract, const std::vector<AnfNodePtr> &side_effect_nodes);
void CheckSideEffectNodes(const AbstractBasePtr &abstract, const std::vector<AnfNodePtr> &side_effect_nodes) const;

protected:
AnalysisContextPtr parent_context_;

@@ -359,8 +357,7 @@ class JEvaluator : public Evaluator {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

private:

@@ -391,8 +388,7 @@ class TaylorEvaluator : public Evaluator {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

private:

@@ -425,8 +421,7 @@ class ShardEvaluator : public Evaluator {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}

EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;

std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

@@ -463,8 +458,7 @@ class VmapEvaluator : public Evaluator {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

private:
@@ -102,7 +102,7 @@ class OrderEnforcer {
}
}

bool HasLoadInput(const CNodePtr &cnode) {
bool HasLoadInput(const CNodePtr &cnode) const {
auto &inputs = cnode->inputs();
return std::any_of(inputs.begin() + 1, inputs.end(),
[](const AnfNodePtr &input) { return IsPrimitiveCNode(input, prim::kPrimLoad); });

@@ -192,7 +192,7 @@ class OrderEnforcer {
AddInputEdges(update_state_cnode, maketuple_users);
}

bool IsRef(const AnfNodePtr &node) {
bool IsRef(const AnfNodePtr &node) const {
auto &abs = node->abstract();
return abs != nullptr && abs->isa<abstract::AbstractRefTensor>();
}

@@ -201,7 +201,7 @@ class OrderEnforcer {
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
}

bool IsSpecialParallelPrimitive(const AnfNodePtr &node) {
bool IsSpecialParallelPrimitive(const AnfNodePtr &node) const {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);

@@ -247,7 +247,7 @@ class OrderEnforcer {
}
}

bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) const {
MS_EXCEPTION_IF_NULL(update_state);
const size_t attach_index = 2;
const size_t input_size = update_state->inputs().size();

@@ -410,7 +410,7 @@ class OrderEnforcer {
return ref_key->value();
}

std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) {
std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) const {
std::vector<CNodePtr> need_insert_loads;
for (auto &node : check_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {

@@ -424,7 +424,7 @@ class OrderEnforcer {
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
const std::map<std::string, std::vector<CNodePtr>> &loads_map3,
const std::set<CNodePtr> &call_lodes) {
const std::set<CNodePtr> &call_lodes) const {
std::vector<CNodePtr> need_insert_loads;
for (auto &refkey_load : loads_map1) {
auto &loads = refkey_load.second;

@@ -456,7 +456,7 @@ class OrderEnforcer {
return need_insert_loads;
}

bool CheckLoadInput(const AnfNodePtr &input) {
bool CheckLoadInput(const AnfNodePtr &input) const {
return IsPrimitiveCNode(input, prim::kPrimCall) || IsPrimitiveCNode(input, prim::kPrimPartial) ||
(input->isa<CNode>() && (IsValueNode<FuncGraph>(input->cast<CNodePtr>()->input(0)) ||
IsPrimitiveCNode(input->cast<CNodePtr>()->input(0), prim::kPrimSwitch) ||

@@ -549,7 +549,7 @@ void OrderEnforce(const FuncGraphPtr &func_graph) {
OrderEnforcer enforcer(func_graph);
enforcer.Run();
auto fg_used_total = func_graph->func_graphs_used_total();
for (auto &fg : fg_used_total) {
for (const auto &fg : fg_used_total) {
OrderEnforcer fg_enforcer(fg);
fg_enforcer.Run();
}