!16153 pynative refactoring to optimize performance

From: @chujinjin
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-05-20 11:16:00 +08:00 committed by Gitee
commit 789b63b501
43 changed files with 3675 additions and 1844 deletions

View File

@ -1694,13 +1694,17 @@ class ExecuteOrderGenerator {
}
}
}
// Search all nodes for parameter write assigns.
std::set<AnfNodePtr> refed_parameters;
for (auto &item : search_list) {
auto &node = item.first;
std::set<AnfNodePtr> refed_parameters;
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
(void)refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
}
}
// Search all nodes for parameter write assigns.
for (auto &item : search_list) {
auto &node = item.first;
for (auto &in : node->inputs()) {
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
visit_node = validate_ref_parameter(visit_node);

View File

@ -290,10 +290,10 @@ void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_r
task_cond_var_.notify_all();
if (sync && !sync_run_task_finished_) {
std::unique_lock<std::mutex> lock(task_mutex_);
if (long_run) {
if (sync && long_run) {
mindspore::ScopedLongRunning long_running;
sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
} else {
} else if (sync) {
sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
}
}

View File

@ -108,11 +108,6 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value, &tensors);
if (!tensors.empty()) {
if (tensors.size() != AnfAlgo::GetOutputTensorNum(value_node)) {
MS_LOG(EXCEPTION) << "The size of tensors converted from value [" << tensors.size()
<< "] is not equal to output size of value node [" << AnfAlgo::GetOutputTensorNum(value_node)
<< "]";
}
device_formats->clear();
device_types->clear();
for (const auto &tensor : tensors) {
@ -692,7 +687,7 @@ AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
if (RemoveValueNodeFromGraph(value_node)) {
if (!RemoveValueNodeFromGraph(value_node)) {
MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
}
return make_tuple;

View File

@ -361,13 +361,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = Manage({fg, func_graph}, false);
auto need_replace_forward = pynative::PynativeExecutor::GetInstance()->need_replace_forward();
auto forward_value = GenNewTensor(manager, equivdout, forward, need_replace_forward);
if (!need_replace_forward) {
cnode_morph->clear_inputs_value();
MS_LOG(DEBUG) << "No need replace forward result";
return;
}
auto forward_value = GenNewTensor(manager, equivdout, forward, true);
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
auto value_node = NewValueNode(forward_value);
value_node->set_has_new_value(true);
@ -406,7 +400,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
}
auto out_node = c_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(out_node);
out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), true));
}
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {

View File

@ -147,6 +147,7 @@ class KPrim {
bprop_registry_meta_.clear();
bprop_registry_.clear();
}
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
private:
FuncGraphPtr GetBprop(const PrimitivePtr &prim);

View File

@ -38,7 +38,6 @@
namespace mindspore {
namespace ad {
using PatternListType = std::initializer_list<BaseRef>;
KPrim g_k_prims;
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
@ -76,6 +75,23 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
return func_graph;
}
FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
FuncGraphPtr bprop_fg = nullptr;
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}
if (bprop_fg == nullptr) {
bprop_fg = GetBprop(prim);
if (bprop_fg != nullptr) {
// Cache the bprop funcgraph
bprop_registry_[prim] = bprop_fg;
}
}
return bprop_fg;
}
FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
std::string func_name = "_fprop_" + prim->name();

File diff suppressed because it is too large

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace ad {
class KPynativeCell {
public:
virtual ~KPynativeCell() = default;
virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) = 0;
// Grad for a cell which may have a user-passed forward propagate FuncGraph.
// c_node: CNode that contains the construct function graph of the cell (index 0) and the formal input parameters of
// that cell.
// op_args: the argument values of the input parameters.
// out: the op result.
// fprop_fg: user-defined fprop FuncGraph whose output is the bprop_fg.
// The bprop_fg should have the prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &fprop_fg) = 0;
};
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
// bprop_fg: user-defined back propagate funcgraph, or the back propagate funcgraph of a primitive; it is passed in
// right after parsing and has the prototype:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
// c_node: CNode that contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the argument values of the input parameters.
// out: the op result.
// return: the returned funcgraph has the same prototype.
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out);
// Start building the back propagate funcgraph for this cell.
// cell_inputs: the input parameter list of this cell, excluding the weights;
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
const std::vector<ValuePtr> &input_param_values);
// Return the back propagate funcgraph for this cell.
// weights: the weight parameters used in this cell.
// grad_inputs: return sensitivities for the input parameters;
// grad_weights: return sensitivities for the weights;
// has_sens_arg: the caller will pass sens args;
// return: the returned funcgraph has the prototype:
// if has_sens_arg is true:
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ...) bprop_fg(input1, input2, ..., weight0, weight1,
// ..., sens_out)
// else:
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ...) bprop_fg(input1, input2, ..., weight0, weight1, ...)
// if build_formal_param is true:
// each cnode in the primal funcgraph is replaced by a formal cnode
// else:
// each cnode in the primal funcgraph is replaced by a value node
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
bool grad_weights, bool has_sens_arg = false, bool build_formal_param = false);
// Grad for each operation.
// c_node: CNode that contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the argument values of the input parameters.
// out: the op result.
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out);
// Grad for a cell which may have a user-defined back propagate function.
// c_node: CNode that contains the construct function graph of the cell (index 0) and the formal input parameters of
// that cell.
// op_args: the argument values of the input parameters.
// out: the op result.
// bprop_fg: user-defined back propagate funcgraph; it should be passed in right after parsing.
// It should have the prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out, const FuncGraphPtr &bprop_fg);
// Clear all static resources used in the grad process
void ClearKPynativeCellStaticRes();
} // namespace ad
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
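To make the intended call sequence of this new API concrete, here is a minimal hedged sketch; the driver function and its arguments are illustrative and not part of this commit:

// Hypothetical driver, assuming kpynative.h is included. In the real executor
// GradPynativeOp is called once per executed primitive.
namespace mindspore {
FuncGraphPtr BuildBpropForCell(const AnfNodePtrList &cell_inputs,
                               const std::vector<ValuePtr> &input_param_values,
                               const CNodePtr &c_node, const ValuePtrList &op_args,
                               const ValuePtr &out, const AnfNodePtrList &weights) {
  // Open a recording session for this cell.
  ad::KPynativeCellPtr k_cell = ad::GradPynativeCellBegin(cell_inputs, input_param_values);
  // Record one executed op together with its argument values and result.
  if (!ad::GradPynativeOp(k_cell, c_node, op_args, out)) {
    return nullptr;
  }
  // Finalize: gradients w.r.t. inputs and weights, no explicit sens argument.
  return ad::GradPynativeCellEnd(k_cell, weights, /*grad_inputs=*/true, /*grad_weights=*/true);
}
}  // namespace mindspore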

View File

@ -49,6 +49,7 @@
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
namespace mindspore {
namespace opt {
@ -241,6 +242,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// switch_layer defer inline
switch_layer_defer_inline_ =
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
}
ResolveIRPassLib::ResolveIRPassLib() {

View File

@ -149,6 +149,9 @@ class OptimizeIRPassLib {
// Pynative Eliminate
SubstitutionPtr pynative_eliminate_;
// Eliminate getattr on bool scalars
SubstitutionPtr bool_scalar_eliminate_;
};
// the collection of irpass for resolve action

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
#include "frontend/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr BoolScalarEliminate::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
if (!cnode->IsApply(prim::kPrimGetAttr)) {
return nullptr;
}
auto vnode = cnode->input(1)->cast<ValueNodePtr>();
if (vnode == nullptr) {
return nullptr;
}
if (!vnode->value()->isa<BoolImm>()) {
return nullptr;
}
auto res = optimizer->resource();
auto manager = res->manager();
auto &node_users = manager->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return nullptr;
}
AnfNodeIndexSet node_idx_set = iter->second;
for (auto &item : node_idx_set) {
manager->Replace(item.first, vnode);
}
return nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H
#include "ir/func_graph.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
namespace mindspore {
namespace opt {
namespace irpass {
class BoolScalarEliminate : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H

View File

@ -89,15 +89,12 @@ class InlinerBase : public AnfVisitor {
: use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>()) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() < 1) {
return nullptr;
}
auto &inputs = cnode->inputs();
// G
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
@ -109,19 +106,7 @@ class InlinerBase : public AnfVisitor {
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
// All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
// Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
bool is_match = false;
for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_.
is_match = true;
for (auto &criterion : criterions) { // Each criterion in 'criterion group'.
if (!criterion(this, fg, node)) {
is_match = false;
break;
}
}
if (is_match) {
break;
}
}
bool is_match = ApplyCriterions(node, fg);
if (!is_match) {
return nullptr;
}
@ -137,22 +122,9 @@ class InlinerBase : public AnfVisitor {
if (IsUniqueUse(nullptr, fg, nullptr)) {
// For the single used fg, including non-after and after not matched above,
// we move the whole fg nodes.
if (use_move_) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
// The other branch calling the last after block.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
// Check if the parameters have changed.
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
if (param_simplified_caller != nullptr) {
return param_simplified_caller;
}
auto ret_node = InlineForUniqueUse(node, fg, args, inputs);
if (ret_node != nullptr) {
return ret_node;
}
} else {
// We don't expand a multiply-used intermediate after block, except the last one.
@ -171,6 +143,45 @@ class InlinerBase : public AnfVisitor {
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
const std::vector<AnfNodePtr> &inputs) {
if (use_move_) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
// The other branch calling the last after block.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
// Check if the parameters have changed.
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
if (param_simplified_caller != nullptr) {
return param_simplified_caller;
}
}
return nullptr;
}
bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
bool is_match = false;
for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_.
is_match = true;
for (auto &criterion : criterions) { // Each criterion in 'criterion group'.
if (!criterion(this, fg, node)) {
is_match = false;
break;
}
}
if (is_match) {
break;
}
}
return is_match;
}
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
const FuncGraphPtr &fg) {
auto params = fg->parameters();
@ -289,9 +300,9 @@ class InlinerBase : public AnfVisitor {
};
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
auto &cnodes = fg->func_graph_cnodes_index();
const auto &users = fg->func_graph_cnodes_index();
int64_t n_use = std::accumulate(
cnodes.begin(), cnodes.end(), 0,
users.begin(), users.end(), 0,
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
return n_use == 1;
}
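ApplyCriterions above implements the AND-within-group, OR-across-groups matching described in the comment. A standalone hedged sketch of the same pattern, with plain integer predicates instead of the inliner's criterion functions:

#include <functional>
#include <vector>

// Each inner vector is a criterion group: every predicate in a group must
// hold (AND); the overall match succeeds if any group holds (OR).
bool MatchAnyGroup(int value, const std::vector<std::vector<std::function<bool(int)>>> &groups) {
  for (const auto &group : groups) {
    bool is_match = true;
    for (const auto &criterion : group) {
      if (!criterion(value)) {
        is_match = false;
        break;
      }
    }
    if (is_match) {
      return true;
    }
  }
  return false;
}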

View File

@ -154,6 +154,7 @@ class TupleListGetitemConstEliminator : public AnfVisitor {
if (is_match_) {
auto out = NewValueNode((*tuple_)[id_]);
out->set_has_new_value(has_new_value_);
out->set_abstract((*tuple_)[id_]->ToAbstract());
return out;
}
return nullptr;

View File

@ -126,8 +126,8 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
return nullptr;
}
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
bool change, size_t seen) {
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
std::vector<AnfNodePtr> *todo, bool change, size_t seen) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
}
@ -164,16 +164,17 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
#endif
FuncGraphManagerPtr manager = optimizer->manager();
auto seen = NewSeenGeneration();
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.clear();
// 1024 is for the initial capacity of vector
std::vector<AnfNodePtr> todo;
todo.reserve(1024);
todo.emplace_back(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
size_t node_idx = 0;
while (node_idx < todo.size()) {
AnfNodePtr node = todo[node_idx];
node_idx++;
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
continue;
@ -206,16 +207,17 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
#endif
FuncGraphManagerPtr manager = optimizer->manager();
auto seen = NewSeenGeneration();
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.clear();
// 1024 is for the initial capacity of vector
std::vector<AnfNodePtr> todo(0);
todo.reserve(1024);
todo.emplace_back(root_node);
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
size_t node_idx = 0;
while (node_idx < todo.size()) {
AnfNodePtr node = todo[node_idx];
node_idx++;
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
continue;
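The deque-to-vector swap above is a general worklist idiom: instead of popping the front, an index cursor walks a growing std::vector, so visited nodes stay in place, appends are cheap, and the storage is contiguous. A standalone hedged sketch with int nodes in place of AnfNodePtr:

#include <vector>

// Index-cursor worklist traversal; new work discovered mid-visit is appended.
void VisitAll(std::vector<int> todo) {
  size_t node_idx = 0;
  while (node_idx < todo.size()) {
    int node = todo[node_idx];
    node_idx++;
    if (node < 3) {
      todo.push_back(node + 10);  // enqueue newly discovered nodes
    }
  }
}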

View File

@ -8,6 +8,7 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"pipeline_split.cc"
"parse/*.cc"
"static_analysis/*.cc"
"prim_bprop_optimizer.cc"
)

View File

@ -522,6 +522,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
MS_LOG(EXCEPTION) << "TaskEmit args error";
}
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
auto context_ptr = MsContext::GetInstance();
std::string backend = MsContext::GetInstance()->backend_policy();

View File

@ -531,6 +531,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph);
}
return func_graph;
}
namespace data_converter {

View File

@ -0,0 +1,254 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unordered_set>
#include <set>
#include <vector>
#include <string>
#include <memory>
#include "pipeline/jit/parse/parse_dynamic.h"
#include "mindspore/core/ir/cell.h"
namespace mindspore::parse {
static std::unordered_set<std::string> cell_input_args_;
static const std::set<std::string> ignore_judge_dynamic_cell = {
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type) {
MS_EXCEPTION_IF_NULL(ast);
if (py::isinstance<py::none>(node)) {
MS_LOG(DEBUG) << "Get none type node!";
return "";
}
auto node_type = ast->GetNodeType(node);
MS_EXCEPTION_IF_NULL(node_type);
// Check node type
parse::AstMainType node_main_type = node_type->main_type();
if (node_main_type != type) {
MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
return "";
}
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
return node_name;
}
void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
MS_EXCEPTION_IF_NULL(ast);
py::list args = ast->GetArgs(fn_node);
for (size_t i = 1; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
MS_LOG(DEBUG) << "Input arg name: " << arg_name;
cell_input_args_.emplace(arg_name);
}
}
bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse if/while expr";
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
if (comparators_node.empty()) {
MS_LOG(DEBUG) << "Get comparators node failed!";
return false;
}
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
// e.g. `while self.a > self.b`, where self.a or self.b changes in the loop body
if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
std::string left_variable;
if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
}
auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
std::string right_variable;
if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
right_variable =
py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
}
return ParseBodyContext(ast, node, {left_variable, right_variable});
}
// if a[0]
if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
}
MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
return true;
}
}
// if flag:
if (node_name == parse::NAMED_PRIMITIVE_NAME) {
std::string id = py::cast<std::string>(test_node.attr("id"));
if (cell_input_args_.find(id) != cell_input_args_.end()) {
return true;
}
}
return false;
}
bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse assign expr";
py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
if (node_name == parse::NAMED_PRIMITIVE_CALL) {
py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
if (py::isinstance<py::none>(value_in_slice_node)) {
MS_LOG(DEBUG) << "Parse value node is none!";
return false;
}
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
std::string id;
if (py::hasattr(value_in_slice_node, "id")) {
id = py::cast<std::string>(value_in_slice_node.attr("id"));
}
if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() ||
(!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) {
return true;
}
}
}
return false;
}
bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim) {
MS_LOG(DEBUG) << "Parse augassign expr";
bool ret = false;
if (compare_prim.empty()) {
return ret;
}
py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
if (py::isinstance<py::none>(target_node)) {
MS_LOG(DEBUG) << "Parse target node is none!";
return ret;
}
py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
if (py::isinstance<py::none>(value_node)) {
MS_LOG(DEBUG) << "Parse value node is none!";
return ret;
}
std::string assign_prim;
if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
}
auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
if (iter != compare_prim.end()) {
ret = true;
}
return ret;
}
bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse for expr";
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
if (py::isinstance<py::none>(body_node)) {
MS_LOG(DEBUG) << "Parse body of for expression is none!";
return false;
}
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
size_t count = LongToSize(pcount);
MS_LOG(DEBUG) << "The for nodes count in body is " << count;
for (size_t i = 0; i < count; ++i) {
auto it = py::cast<py::list>(body_node)[i];
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
return true;
}
}
return false;
}
bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim) {
MS_EXCEPTION_IF_NULL(ast);
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
if (py::isinstance<py::none>(func_obj)) {
MS_LOG(DEBUG) << "Parse body of cell is none!";
return false;
}
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
size_t count = IntToSize(pcount);
MS_LOG(DEBUG) << "The nodes count in body is " << count;
bool ret = false;
for (size_t i = 0; i < count; ++i) {
auto node = py::cast<py::list>(func_obj)[i];
const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
ret = ParseAssignExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
ret = ParseAugAssignExprNode(ast, node, compare_prim);
} else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
ret = ParseForExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
ret = ParseIfWhileExprNode(ast, node);
}
if (ret) {
MS_LOG(INFO) << "Current cell is dynamic!";
break;
}
}
return ret;
}
std::string DynamicParser::GetCellInfo(const py::object &cell) {
if (py::isinstance<Cell>(cell)) {
auto c_cell = py::cast<CellPtr>(cell);
MS_EXCEPTION_IF_NULL(c_cell);
auto cell_info = c_cell->ToString();
return cell_info;
}
return "";
}
bool DynamicParser::IsDynamicCell(const py::object &cell) {
std::string cell_info = GetCellInfo(cell);
if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
return false;
}
// Use AST parsing to check whether the cell's construct will change between runs
auto ast = std::make_shared<parse::ParseAst>(cell);
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed";
return false;
}
py::object fn_node = ast->GetAstNode();
// Get the names of the input args to initialize the dynamic variables
ParseInputArgs(ast, fn_node);
// parse body context
bool ret = false;
ret = ParseBodyContext(ast, fn_node);
cell_input_args_.clear();
return ret;
}
} // namespace mindspore::parse

View File

@ -0,0 +1,52 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_
#include <vector>
#include <memory>
#include <string>
#include "pipeline/jit/parse/parse.h"
namespace mindspore::parse {
class DynamicParser {
public:
DynamicParser() = default;
~DynamicParser() = default;
// Check cell struct
static bool IsDynamicCell(const py::object &cell);
private:
static std::string GetCellInfo(const py::object &cell);
static void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
static bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim = {});
static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim = {});
static bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type);
};
} // namespace mindspore::parse
#endif
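A hedged usage sketch of the new parser; the wrapper function is illustrative and not part of this commit:

// Hypothetical call site, assuming parse_dynamic.h is included and `cell` is
// the Python Cell object handed over by the frontend.
bool ShouldRebuildGraph(const py::object &cell) {
  // True when construct() mutates attributes it also branches on, so a
  // previously compiled graph may no longer be valid for this cell.
  return mindspore::parse::DynamicParser::IsDynamicCell(cell);
}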

View File

@ -41,6 +41,7 @@
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
@ -75,6 +76,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
return true;
}
bool TransformTopGraphPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Transform top graph error.";
}
FuncGraphPtr func_graph = res->func_graph();
if (opt::FuncGraphHasTupleInput(func_graph)) {
opt::GraphTupleParamTransform graph_trans;
func_graph = graph_trans(func_graph, res->manager());
res->set_func_graph(func_graph);
AbstractBasePtrList abs_spec_list;
auto &params = func_graph->parameters();
std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
[](AnfNodePtr node) { return node->abstract(); });
res->set_args_spec(abs_spec_list);
}
return true;
}
bool CleanAfterOptAPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
@ -93,6 +112,104 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
return true;
}
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
irpass.pynative_eliminate_,
});
opt::irpass::ResolveIRPassLib resolve_irpass;
opt::OptPassConfig resolver_prim = opt::OptPassConfig({
resolve_irpass.resolver_resolve_and_getattr_,
resolve_irpass.resolver_resolve_,
resolve_irpass.resolver_getattr_,
});
opt::OptPassConfig switch_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
});
opt::OptPassConfig inline_opt = opt::OptPassConfig({
irpass.inline_,
});
opt::OptPassConfig bool_scalar_eliminate = opt::OptPassConfig({
irpass.bool_scalar_eliminate_,
});
OptPassGroupMap map({{"ad_eliminate_", pynative_eliminate},
{"ad_resolver_prim", resolver_prim},
{"ad_inline_", inline_opt},
{"bool_scalar_eliminate_", bool_scalar_eliminate},
{"ad_switch_simplify_", switch_simplify}});
auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", res, map);
FuncGraphPtr func_graph = res->func_graph();
WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"))[&prim_bprop_opt_step_1, &func_graph]() {
func_graph = prim_bprop_opt_step_1->step(func_graph, true);
};
return func_graph;
}
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
opt::OptPassConfig switch_simplify_ = opt::OptPassConfig({
irpass.switch_simplify_,
irpass.reduce_eliminate_,
});
opt::OptPassConfig inline_ = opt::OptPassConfig({
irpass.inline_,
});
opt::OptPassConfig zero_like_fill_zero_ = opt::OptPassConfig({
irpass.zero_like_fill_zero_,
});
auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
return ReAutoMonad(root);
};
OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
{"ad_inline_", inline_},
{"ad_switch_simplify_", switch_simplify_},
{"ad_zero_like_fill_zero_", zero_like_fill_zero_},
{"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});
auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", res, map);
FuncGraphPtr func_graph = res->func_graph();
WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"))[&prim_bprop_opt_step_2, &func_graph]() {
func_graph = prim_bprop_opt_step_2->step(func_graph, true);
};
return func_graph;
}
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (!TransformTopGraphPass(res)) {
MS_LOG(EXCEPTION) << "Run TransformTopGraphPass failed";
}
opt::irpass::OptimizeIRPassLib irpass;
opt::OptPassConfig bg_final_opt_ = opt::OptPassConfig({
irpass.inline_,
irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.depend_value_elim_,
irpass.reshape_eliminate_,
irpass.switch_simplify_,
});
OptPassGroupMap map({{"ad_final_opt_", bg_final_opt_}});
if (pynative::PynativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
}
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
FuncGraphPtr func_graph = res->func_graph();
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
func_graph = bprop_graph_final_opt->step(func_graph, true);
};
return func_graph;
}
namespace {
bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
@ -472,24 +589,6 @@ bool CconvPass(const ResourcePtr &res) {
return true;
}
bool TransformTopGraphPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Transform top graph error.";
}
FuncGraphPtr func_graph = res->func_graph();
if (opt::FuncGraphHasTupleInput(func_graph)) {
opt::GraphTupleParamTransform graph_trans;
func_graph = graph_trans(func_graph, res->manager());
res->set_func_graph(func_graph);
AbstractBasePtrList abs_spec_list;
auto &params = func_graph->parameters();
std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
[](AnfNodePtr node) { return node->abstract(); });
res->set_args_spec(abs_spec_list);
}
return true;
}
bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {

View File

@ -24,6 +24,12 @@
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace opt {
namespace irpass {
class OptimizeIRPassLib;
} // namespace irpass
} // namespace opt
namespace pipeline {
using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
@ -40,6 +46,9 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res);
} // namespace pipeline
} // namespace mindspore

View File

@ -49,6 +49,7 @@
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/constants.h"
#include "ps/util.h"
@ -1166,6 +1167,8 @@ void ClearResAtexit() {
}
#endif
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();
PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
abstract::ClearPrimEvaluatorMap();
compile::ClearConvertCache();

View File

@ -0,0 +1,360 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
namespace pipeline {
void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
// args_value_using_info_ also covers out, hence the +1.
if (args_value_using_info_.size() != op_args.size() + 1) {
MS_LOG(EXCEPTION) << "param size:" << args_value_using_info_.size()
<< " of bp_graph:" << opt_func_graph_->ToString()
<< " does not match input arguments num:" << op_args.size();
}
ValuePtrList new_args(op_args);
new_args.emplace_back(out);
TryFreeOneValue(new_args, args_value_using_info_);
}
void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
const std::vector<ParamUsingInfo> &param_info_vec) {
if (param_info_vec.size() != op_args.size()) {
MS_LOG(EXCEPTION) << "param size :" << param_info_vec.size() << " of bp_graph:" << opt_func_graph_->ToString()
<< " not match input arguments num:" << op_args.size();
}
for (size_t i = 0; i < op_args.size(); ++i) {
if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
auto value = op_args[i]->cast<tensor::TensorPtr>();
value->set_device_address(nullptr);
} else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
auto value = op_args[i]->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value);
TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
}
}
}
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
if (analysis_finish_flg_) {
return;
}
MS_EXCEPTION_IF_NULL(opt_func_graph_);
auto &params = opt_func_graph_->parameters();
const auto &node_users = manager->node_users();
args_value_using_info_.resize(params.size() - 1);
// Analyze the value-using flag of every parameter except dout
for (size_t i = 0; i < params.size() - 1; ++i) {
auto &param = params[i];
auto &arg_info = args_value_using_info_[i];
ArgInfoRefresh(param, &arg_info);
AnalysisNodeUsingInfo(node_users, param, &arg_info);
}
analysis_finish_flg_ = true;
}
void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info) const {
auto iter = node_users.find(param);
if (iter == node_users.end()) {
arg_info->using_flg_ = false;
return;
}
// A tensor (non-tuple) parameter is used directly
if (!arg_info->tuple_flg_) {
arg_info->using_flg_ = true;
return;
}
// Special handling for tuple parameters: possibly only some items are used
const auto &users_info = iter->second;
for (auto &user_info : users_info) {
auto user_node = user_info.first;
arg_info->using_flg_ = true;
MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
for (auto &sub_info : arg_info->sub_using_info_) {
sub_info.using_flg_ = true;
}
} else {
AalysisForTupleGetItem(node_users, param, arg_info, user_node);
}
}
}
void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
MS_EXCEPTION_IF_NULL(arg_info);
auto cnode = user_node->cast<CNodePtr>();
if (cnode->size() != 3) {
MS_LOG(EXCEPTION) << "TupleGetItem Node:" << user_node->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
<< "input size is:" << cnode->size();
}
auto idx_node = cnode->input(2);
if (!idx_node->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
<< " unexpected used by node:" << user_node->ToString()
<< " TupleGetItem idx node:" << idx_node->ToString();
}
auto vnode = idx_node->cast<ValueNodePtr>();
auto value_ptr = vnode->value();
if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
<< " unexpected used by node:" << user_node->ToString()
<< " TupleGetItem idx node:" << idx_node->ToString() << " idx Value :" << value_ptr;
}
auto idx = value_ptr->cast<Int64ImmPtr>()->value();
arg_info->sub_using_info_[idx].using_flg_ = true;
ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));
if (arg_info->tuple_flg_) {
AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
}
}
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info) const {
MS_EXCEPTION_IF_NULL(arg_info);
auto abs = param->abstract();
if (abs->isa<abstract::AbstractTensor>()) {
arg_info->tuple_flg_ = false;
MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTensor";
} else if (abs->isa<abstract::AbstractTuple>()) {
auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTuple";
arg_info->tuple_flg_ = true;
arg_info->tuple_size_ = abs_tuple->size();
arg_info->sub_using_info_.resize(abs_tuple->size());
} else {
arg_info->tuple_flg_ = false;
}
}
PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
static PrimBpropOptimizer g_prim_bprop_opt;
return g_prim_bprop_opt;
}
void PrimBpropOptimizer::Clear() {
prim_bprop_cache_.clear();
tuple_list_bprop_cache_.clear();
}
// bprop_fg has the signature:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the prim (input 0) and the input parameters of that prim;
// op_args contains the argument values of the input parameters; each may be a tensor or a tuple;
// out contains the output of c_node;
FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
const ValuePtrList &op_args, const ValuePtr &out) {
MS_EXCEPTION_IF_NULL(bprop_fg);
MS_EXCEPTION_IF_NULL(c_node);
MS_EXCEPTION_IF_NULL(out);
auto &inputs = c_node->inputs();
if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
MS_LOG(EXCEPTION) << "The parameters num " << inputs.size() - 1 << " not match arguments num " << op_args.size()
<< ", CNode:" << c_node->ToString() << " grap:" << bprop_fg->ToString();
}
if (!IsValueNode<Primitive>(inputs[0])) {
MS_LOG(EXCEPTION) << "CNode:" << c_node->ToString()
<< " not a primitive node, input_0 is:" << inputs[0]->ToString();
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();
// kPrimHookBackward
bool hookback_flg = IsPrimitiveEquals(prim, prim::kPrimHookBackward);
if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
}
return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
}
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
const ValuePtr &out, const PrimitivePtr &prim) {
abstract::AbstractBasePtrList abs_list;
ArgsToAbs(prim, op_args, &abs_list);
PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
PrimBpropOptGraphInfoPtr level_1_graph_info;
ECacheQrtRes cache_res = GetOptBpfgFromCache(prim, abs_list, &level_2_graph_info, &level_1_graph_info);
MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
if (cache_res == E_LEVEL_2) {
MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
level_2_graph_info->TryFreeArgsValue(op_args, out);
return BasicClone(level_2_graph_info->opt_func_graph());
}
// do step1 opt
if (cache_res == E_NOT_FOUND) {
bprop_fg->debug_info()->set_name(prim->ToString());
level_1_graph_info = PrimBpropOptStep1(bprop_fg);
prim_bprop_cache_[prim] = level_1_graph_info;
}
FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);
// do step2 opt
auto new_abs_list = AddOutToAbsList(out, abs_list);
level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
level_2_graph_info->TryFreeArgsValue(op_args, out);
return BasicClone(level_2_graph_info->opt_func_graph());
}
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
abstract::AbstractBasePtrList abs_list;
ArgsToAbs(prim, op_args, &abs_list);
if (!hook_flg) {
auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
if (iter != tuple_list_bprop_cache_.end()) {
return BasicClone(iter->second);
}
}
// do step1 opt
bprop_fg->debug_info()->set_name(prim->ToString());
auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);
// do step2 opt
auto new_abs_list = AddOutToAbsList(out, abs_list);
auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
level_2_graph_info->TryFreeArgsValue(op_args, out);
if (!hook_flg) {
tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
}
return level_2_graph_info->opt_func_graph();
}
PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) {
opt::irpass::OptimizeIRPassLib irpass;
auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
auto graph_for_cache = BasicClone(bprop_fg);
prim_bprop_opt_res->set_func_graph(graph_for_cache);
prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
return level_1_graph_info;
}
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
const abstract::AbstractBasePtrList &abs_list_input) {
auto &params = bprop_fg->parameters();
if (abs_list_input.size() != params.size()) {
MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size();
}
for (size_t i = 0; i < abs_list_input.size(); i++) {
params[i]->set_abstract(abs_list_input[i]);
}
}
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) {
opt::irpass::OptimizeIRPassLib irpass;
BindAbsToParameters(bprop_fg, abs_list_input);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
auto manager = resource->manager();
resource->set_func_graph(bprop_fg);
manager->AddFuncGraph(bprop_fg);
auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
level_2_graph_info->AnalysisArgUsingInfo(manager);
return level_2_graph_info;
}
FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto after_opt_bg = BpropGraphFinalOptPass(res);
return after_opt_bg;
}
ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
const abstract::AbstractBasePtrList &abs_list,
PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
PrimBpropOptGraphInfoPtr *level_1_graph_info) {
auto attrs_ = prim->attrs();
for (auto &item : attrs_) {
MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
}
auto iter = prim_bprop_cache_.find(prim);
if (iter == prim_bprop_cache_.end()) {
return E_NOT_FOUND;
}
*level_1_graph_info = iter->second;
auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
return E_LEVEL_1;
}
*level_2_graph_info = second_iter->second;
return E_LEVEL_2;
}
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
abstract::AbstractBasePtrList *abs_list) {
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
for (size_t i = 0; i < op_args.size(); ++i) {
bool is_const_input =
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
auto &arg_value = op_args[i];
auto arg_abs = arg_value->ToAbstract();
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
arg_abs = arg_abs->Broaden(config);
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
}
(*abs_list).emplace_back(arg_abs);
}
}
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
const abstract::AbstractBasePtrList &abs_list) {
if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments.";
}
abstract::AbstractBasePtrList new_abs_list(abs_list);
auto out_abs = out->ToAbstract();
auto config = abstract::AbstractBase::kBroadenTensorOnly;
out_abs = out_abs->Broaden(config);
new_abs_list.emplace_back(out_abs);  // out
new_abs_list.emplace_back(out_abs);  // dout shares the same abstract as out
return new_abs_list;
}
} // namespace pipeline
} // namespace mindspore
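TryFreeArgsValue above releases device memory for arguments that the optimized bprop graph provably never reads. A standalone hedged sketch of that idea, with simplified stand-ins for MindSpore's Tensor and ParamUsingInfo:

#include <memory>
#include <vector>

struct FakeTensor {
  std::shared_ptr<int> device_address;  // stands in for the real DeviceAddress
};

// Free the device memory of every argument whose value is marked unused
// (mirrors the using_flg_ bookkeeping of ParamUsingInfo).
void TryFreeUnusedArgs(std::vector<FakeTensor> *args, const std::vector<bool> &used) {
  for (size_t i = 0; i < args->size() && i < used.size(); ++i) {
    if (!used[i]) {
      (*args)[i].device_address = nullptr;
    }
  }
}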

View File

@ -0,0 +1,185 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
#define MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
#include <vector>
#include <utility>
#include <unordered_map>
#include <memory>
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace pipeline {
struct PrimBpropOptGraphInfo;
class PrimBpropOptGraphLevel2Info;
struct PrimitiveTotalEqual;
struct PrimitiveTupleListHasher;
struct PrimitiveTupleListEqual;
using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;
using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;
using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;
using PrimBpropLevel2Cache =
std::unordered_map<abstract::AbstractBasePtrList, PrimBpropOptGraphLevel2InfoPtr, abstract::AbstractBasePtrListHasher,
abstract::AbstractBasePtrListEqual>;
using PrimTupleListCache =
std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>;
struct PrimitiveTupleListHasher {
bool operator()(const TupleListKey &key) const {
abstract::AbstractBasePtrListHasher hasher;
return hasher(key.second);
}
};
struct PrimitiveTupleListEqual {
bool operator()(TupleListKey const &t1, TupleListKey const &t2) const {
MS_EXCEPTION_IF_NULL(t1.first);
MS_EXCEPTION_IF_NULL(t2.first);
if (!(*t1.first == *t2.first)) {
return false;
}
abstract::AbstractBasePtrListEqual cmp;
return cmp(t1.second, t2.second);
}
};
struct PrimitiveTotalEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return *t1 == *t2;
}
};
enum ECacheQrtRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 };
struct PrimBpropOptGraphInfo {
// the level-1 opt func_graph without infer; no shape/type info provided
FuncGraphPtr opt_func_graph_;
// the opt func_graph after infer; the level-2 func_graph cache
PrimBpropLevel2Cache graph_level_2_cache_;
};
struct ParamUsingInfo {
bool using_flg_{false};
bool tuple_flg_{false};
size_t tuple_size_{0};
std::vector<ParamUsingInfo> sub_using_info_;
};
class PrimBpropOptGraphLevel2Info {
public:
explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {}
const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; }
void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out);
void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager);
private:
void ArgInfoRefresh(const std::shared_ptr<AnfNode> &param, ParamUsingInfo *arg_info) const;
void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info) const;
void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> &param_info_vec);
void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const;
private:
// The level-2 optimized func_graph
FuncGraphPtr opt_func_graph_;
// Indicates whether each argument value is used; unused arguments can have their device memory freed
std::vector<ParamUsingInfo> args_value_using_info_;
bool analysis_finish_flg_{false};
};
class PrimBpropOptimizer {
public:
~PrimBpropOptimizer() = default;
void Clear();
static PrimBpropOptimizer &GetPrimBpropOptimizerInst();
// bprop_fg has the signature:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the primitive (input 0) and the input parameters of that primitive;
// op_args contains the argument list for each input parameter; each argument may be a tensor or a tuple;
// out is the output of c_node.
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out);
// Do inline optimization for the final bprop graph
FuncGraphPtr BpropGraphFinalOpt(const ResourcePtr &res);
private:
PrimBpropOptimizer() = default;
ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list,
PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
PrimBpropOptGraphInfoPtr *level_1_graph_info);
// Convert tensor args to abstract values
void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list);
// Add out and dout to the abstract list
abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, const abstract::AbstractBasePtrList &abs_list);
// Optimize without input info; no inference
PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg);
// Optimize with input info
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg,
const abstract::AbstractBasePtrList &abs_list_input);
void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input);
FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
const PrimitivePtr &prim);
FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
const PrimitivePtr &prim, bool hook_flg);
private:
// Cache of optimized bprop graphs
PrimBpropCache prim_bprop_cache_;
PrimTupleListCache tuple_list_bprop_cache_;
};
} // namespace pipeline
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
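
The interplay of the two cache levels declared above is easiest to see in isolation. Below is a minimal standalone sketch of the lookup flow suggested by GetOptBpfgFromCache and ECacheQrtRes; Primitive is mocked as a string, the abstract list as a single key string, and the names BpropCache, Level1Entry, and abs_key are illustrative, not MindSpore APIs.

// Standalone sketch of the two-level bprop cache lookup declared above.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Graph {};  // stands in for FuncGraphPtr
using GraphPtr = std::shared_ptr<Graph>;

enum CacheResult { NOT_FOUND, LEVEL_1, LEVEL_2 };

struct Level1Entry {
  GraphPtr opt_graph;  // optimized without shape/type info
  // level-2 cache: one specialized graph per input-abstract signature
  std::unordered_map<std::string, GraphPtr> level2;
};

class BpropCache {
 public:
  CacheResult Lookup(const std::string &prim, const std::string &abs_key, GraphPtr *level1, GraphPtr *level2) {
    auto it = cache_.find(prim);
    if (it == cache_.end()) {
      return NOT_FOUND;  // caller runs the level-1 pass and inserts
    }
    *level1 = it->second.opt_graph;
    auto it2 = it->second.level2.find(abs_key);
    if (it2 == it->second.level2.end()) {
      return LEVEL_1;  // caller specializes with inference and inserts
    }
    *level2 = it2->second;
    return LEVEL_2;  // fully cached, reuse directly
  }
  void InsertLevel1(const std::string &prim, GraphPtr g) { cache_[prim].opt_graph = std::move(g); }
  void InsertLevel2(const std::string &prim, const std::string &abs_key, GraphPtr g) {
    cache_[prim].level2[abs_key] = std::move(g);
  }

 private:
  std::unordered_map<std::string, Level1Entry> cache_;
};

int main() {
  BpropCache cache;
  GraphPtr l1, l2;
  std::cout << cache.Lookup("MatMul", "f32[2,2],f32[2,2]", &l1, &l2) << "\n";  // 0: NOT_FOUND
  cache.InsertLevel1("MatMul", std::make_shared<Graph>());
  std::cout << cache.Lookup("MatMul", "f32[2,2],f32[2,2]", &l1, &l2) << "\n";  // 1: LEVEL_1
  cache.InsertLevel2("MatMul", "f32[2,2],f32[2,2]", std::make_shared<Graph>());
  std::cout << cache.Lookup("MatMul", "f32[2,2],f32[2,2]", &l1, &l2) << "\n";  // 2: LEVEL_2
  return 0;
}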

View File

@ -50,28 +50,22 @@ enum PynativeStatusCode {
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
bool is_dynamic_shape = false;
bool is_mixed_precision_cast = false;
size_t next_input_index = 0;
std::string op_name;
std::string op_index;
std::string op_info;
std::string next_op_name = "";
PrimitivePyPtr py_primitive;
AbstractBasePtr abstract;
py::list op_inputs;
std::vector<int64_t> inputs_mask;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
const std::set<std::string> ignore_infer_prim = {"mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> ignore_judge_dynamic_cell = {
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup",
"Transpose"};
} // namespace pynative

File diff suppressed because it is too large

View File

@ -36,11 +36,12 @@
#include "utils/ms_context.h"
#include "ir/anf.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/composite/composite.h"
namespace mindspore {
namespace pynative {
namespace mindspore::pynative {
namespace py = pybind11;
using cell_id = std::string;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
@ -52,8 +53,8 @@ struct PrimAbsInfo {
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
using OpIndexWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensor = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
py::object RunOp(const py::args &args);
@ -64,108 +65,80 @@ struct GraphInfo {
AnfNodePtr output;
OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weights
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
std::vector<std::string> objects;
GraphInfo() = default;
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class CellInfo {
public:
CellInfo() = default;
~CellInfo() = default;
CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr forward_graph, std::string cellid, std::string bprop_id)
: is_custom_bprop_(custom_bprop),
is_dynamic_(has_dynamic),
fg_(std::move(forward_graph)),
cell_id_(std::move(cellid)),
bprop_cell_id_(std::move(bprop_id)) {}
bool is_custom_bprop() const { return is_custom_bprop_; }
void set_is_custom_bprop(bool is_custom_bprop) { is_custom_bprop_ = is_custom_bprop; }
bool is_dynamic() const { return is_dynamic_; }
void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
size_t call_times() const { return call_times_; }
void set_call_times(size_t call_times) { call_times_ = call_times; }
FuncGraphPtr fg() const { return fg_; }
void set_fg(FuncGraphPtr fg) { fg_ = std::move(fg); }
std::string &cell_id() { return cell_id_; }
void set_cell_id(std::string cell_id) { cell_id_ = std::move(cell_id); }
std::string &bprop_cell_id() { return bprop_cell_id_; }
std::vector<std::string> &cell_ops_info() { return cell_ops_info_; }
private:
bool is_custom_bprop_{false}; // Custom bprop
bool is_dynamic_{false}; // Set by has_dynamic_cell
size_t call_times_{0};
FuncGraphPtr fg_{nullptr}; // Forward graph
std::string cell_id_;
std::string bprop_cell_id_;
std::vector<std::string> cell_ops_info_; // All ops info
};
using CellInfoPtr = std::shared_ptr<CellInfo>;
class TopCellInfo {
public:
TopCellInfo() = default;
~TopCellInfo() = default;
TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
: is_topest_(topest), resource_(std::move(r)), df_builder_(std::move(df)), cell_id_(std::move(cellid)) {}
TopCellInfo(bool topest, size_t grad_order, ResourcePtr r, FuncGraphPtr df, std::string cellid)
: is_topest_(topest),
grad_order_(grad_order),
resource_(std::move(r)),
df_builder_(std::move(df)),
cell_id_(std::move(cellid)) {}
bool is_grad() const { return is_grad_; }
void set_is_grad(bool is_grad) { is_grad_ = is_grad; }
bool is_init_kpynative() const { return is_init_kpynative_; }
void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
bool is_topest() const { return is_topest_; }
size_t grad_order() const { return grad_order_; }
void set_grad_order(size_t grad_order) { grad_order_ = grad_order; }
bool is_dynamic() const { return is_dynamic_; }
void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
bool vm_compiled() const { return vm_compiled_; }
void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
bool need_grad() const { return need_grad_; }
void set_need_grad(bool need_grad) { need_grad_ = need_grad; }
bool has_dynamic_cell() const { return has_dynamic_cell_; }
bool is_real_dynamic() const { return is_real_dynamic_; }
void set_is_real_dynamic(bool is_real_dynamic) { is_real_dynamic_ = is_real_dynamic; }
bool ms_function_flag() const { return ms_function_flag_; }
void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
bool need_compile_graph() const { return need_compile_graph_; }
void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
bool forward_already_run() const { return forward_already_run_; }
void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
ResourcePtr resource() { return resource_; }
FuncGraphPtr df_builder() { return df_builder_; }
size_t op_num() const { return op_num_; }
void set_op_num(size_t op_num) { op_num_ = op_num; }
std::string &cell_id() { return cell_id_; }
std::string &sens_id() { return sens_id_; }
void set_sens_id(std::string sens_id) { sens_id_ = std::move(sens_id); }
std::string &weights_id() { return weights_id_; }
void set_weights_id(std::string weights_id) { weights_id_ = std::move(weights_id); }
std::string &input_args_id() { return input_args_id_; }
std::string all_op_info() const { return all_op_info_; }
void set_all_op_info(std::string all_op_info) { all_op_info_ = std::move(all_op_info); }
void set_input_args_id(std::string input_args_id) { input_args_id_ = std::move(input_args_id); }
std::vector<CellInfoPtr> &cell_graph_list() { return cell_graph_list_; }
void set_cell_graph_list(const std::vector<CellInfoPtr> &cell_graph_list) { cell_graph_list_ = cell_graph_list; }
std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; }
bool IsSubCell(const std::string &cell_id) const;
OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
void set_graph_info_map(const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map) {
graph_info_map_ = graph_info_map;
}
void clear() {
cell_graph_list_.clear();
graph_info_map_.clear();
OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; }
ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; }
void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
k_pynative_cell_ptr_ = k_pynative_cell_ptr;
}
void clear();
private:
bool is_grad_{false}; // Whether the derivative has been computed
bool is_topest_{false};
bool is_dynamic_{false};
bool vm_compiled_{false};
bool need_grad_{true};
bool has_dynamic_cell_{false};
bool is_real_dynamic_{false};
bool ms_function_flag_{false};
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
size_t op_num_{0};
size_t grad_order_{0};
ResourcePtr resource_{nullptr};
FuncGraphPtr df_builder_{nullptr};
ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
std::string cell_id_;
std::string sens_id_;
std::string weights_id_;
std::string input_args_id_;
std::vector<CellInfoPtr> cell_graph_list_;
std::string all_op_info_;
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
std::unordered_set<std::string> sub_cell_list_;
OpInfoWithTensorId op_info_with_tensor_id_;
TensorIdWithTensorObject tensor_id_with_tensor_object_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
class DynamicAnalysis;
using DynamicAnalysisPtr = std::shared_ptr<DynamicAnalysis>;
class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;
@ -174,30 +147,6 @@ class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;
class DynamicAnalysis {
public:
DynamicAnalysis() = default;
~DynamicAnalysis() = default;
// Check cell struct
bool IsDynamicCell(const py::object &cell);
private:
std::string GetCellInfo(const py::object &cell);
void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim = {});
bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim = {});
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type);
std::unordered_set<std::string> cell_input_args_;
};
class GradExecutor {
public:
GradExecutor() = default;
@ -227,112 +176,74 @@ class GradExecutor {
FuncGraphPtr curr_g() const;
TopCellInfoPtr top_cell() const;
bool TopCellIsDynamic();
void CheckNeedCompileGraph();
TopCellInfoPtr GetTopCell(const string &cell_id) const;
bool need_renormalize() const { return need_renormalize_; }
void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
bool grad_flag() const { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
bool in_grad_process() const { return in_grad_process_; }
std::string top_cell_id() { return top_cell()->cell_id(); }
bool grad_is_running() const { return grad_is_running_; }
bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
std::string GetCellId(const py::object &obj, const py::args &args);
TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
std::stack<std::string> &cell_stack() { return cell_stack_; }
std::vector<TopCellInfoPtr> &top_cell_list() { return top_cell_list_; }
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real);
void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out);
void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &fprop_g, const py::object &out,
const py::args &args, const std::string &graph_phase);
void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
CNodePtr *ms_function_cnode);
void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource);
py::object CheckGraph(const py::object &cell, const py::args &args);
void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase);
bool need_construct_graph() const { return !graph_stack_.empty() && grad_flag_; }
void set_dynamic_analysis(DynamicAnalysisPtr dynamic_analysis) { dynamic_analysis_ = std::move(dynamic_analysis); }
std::stack<FuncGraphPtr> &graph_stack() { return graph_stack_; }
std::vector<TopCellInfoPtr> &top_cell_list() { return top_cell_list_; }
bool need_replace_forward() const { return need_replace_forward_; }
std::stack<std::string> &cell_op_info_stack() { return cell_op_info_stack_; }
std::unordered_map<std::string, size_t> &op_index_map() { return op_index_map_; }
std::unordered_map<std::string, std::string> &obj_to_forward_id() { return obj_to_forward_id_; }
void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
void ClearGrad(const py::object &cell, const py::args &args);
void ClearRes();
void ClearCellRes(const std::string &cell_id = "");
private:
ForwardExecutorPtr forward() const;
DynamicAnalysisPtr dynamic_analysis() const;
bool grad_running() const { return grad_is_running_; }
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
// Higher derivative
bool IsNestedGrad() const;
void AddNestedGradOrder() { ++grad_order_; }
void SubNestedGradOrder();
void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
const std::string &cell_id);
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
// Dynamic
bool CheckDynamicCell(const std::string &cell_id);
bool CheckRealDynamicCell(const std::string &cell_id);
void ClearDynamicTopRes(const std::string &cell_id);
void PushCurrentGraphToStack();
void PopGraphStack();
void PushCurrentCellOpInfoToStack();
void PopCurrentCellOpInfoFromStack();
std::string GetCellOpInfo();
void ReplaceCellOpInfoByCellId(const std::string &cell_id);
void SwitchTopcell();
size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
void MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
const ResourcePtr &resource, const py::object &out);
void PushCellStack(const std::string &cell_id);
void PopCellStack();
void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
TopCellInfoPtr PopHighOrderGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = "");
bool IsFirstGradStep();
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
bool IsTopGraph(const std::string &cell_id);
bool IsTopestGraph(const std::string &cell_id);
bool IsBpropGraph(const std::string &cell_id);
bool IsGradBefore(const std::string &cell_id);
bool CheckCellGraph(const std::string &cell_id);
bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
bool is_grad);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
bool CheckCellChanged(const std::string &cell_id);
void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args);
void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest);
void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
const std::string &out_id, const py::args &args);
bool EndBpropGraph(const string &cell_id);
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
const std::string &cell_id, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
py::args *forward_args = nullptr);
void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
void SetTopCellTensorId(const std::string &cell_id);
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size, const std::string &cell_id);
FuncGraphPtr GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
const std::vector<int64_t> &index_sequence, bool is_param = false);
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
// Memory cleanup between steps
void ClearResidualRes(const std::string &cell_id);
void ClearCnodeRes(const AnfNodePtr &node);
void CleanPreMemoryInValueNode();
void SaveTensorsInValueNode(const ResourcePtr &resource);
void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
top_cell()->graph_info_map()[g]->objects.push_back(obj);
}
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
@ -346,33 +257,33 @@ class GradExecutor {
const std::vector<int64_t> &index) {
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
}
void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out);
void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
private:
size_t grad_order_{0};
bool grad_flag_{false};
bool in_bprop_process_{false};
bool in_grad_process_{false};
bool has_dynamic_cell_{false};
bool need_replace_forward_{true};
bool need_renormalize_{false};
bool grad_is_running_{false};
int custom_bprop_cell_count_{0};
size_t grad_order_{0};
// Only set for higher-order grad
FuncGraphPtr curr_g_{nullptr};
// For clearing the previous top cell's resources
TopCellInfoPtr pre_top_cell_{nullptr};
TopCellInfoPtr top_cell_{nullptr};
std::unordered_map<std::string, size_t> op_index_map_;
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
std::unordered_map<std::string, std::string> obj_to_forward_id_;
// Records forward graphs; the bottom is the top graph
std::stack<FuncGraphPtr> graph_stack_;
// Records the op info of every cell; the bottom is the op info of the top cell
std::stack<std::string> cell_op_info_stack_;
// Records forward cells; the bottom is the top cell
std::stack<std::string> cell_stack_;
// For higher-order grad of bprop
std::stack<bool> bprop_grad_stack_;
std::vector<std::string> bprop_cell_list_;
// For higher-order grad
std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>> high_order_stack_;
// Use vector to keep order
std::vector<TopCellInfoPtr> top_cell_list_;
// Records all top cells that have been run
std::map<cell_id, TopCellInfoPtr> already_run_top_cell_;
ForwardExecutorWeakPtr forward_executor_;
DynamicAnalysisPtr dynamic_analysis_;
};
class ForwardExecutor {
@ -388,13 +299,9 @@ class ForwardExecutor {
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
std::unordered_map<std::string, OpIndexWithTensorId> &cell_op_index_with_tensor_id() {
return cell_op_index_with_tensor_id_;
}
std::unordered_map<std::string, TensorIdWithTensor> &cell_tensor_id_with_tensor() {
return cell_tensor_id_with_tensor_;
}
void ClearRes();
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
private:
GradExecutorPtr grad() const;
@ -404,17 +311,14 @@ class ForwardExecutor {
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
PynativeStatusCode *status);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
bool *is_find);
// Update the abstract and device address info of value nodes and tensors in the bprop graph
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
py::object GetOpOutputObject(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
const AnfNodePtr &CNode, bool out_abstract_existed);
// Mixed precision
void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
@ -429,9 +333,6 @@ class ForwardExecutor {
GradExecutorWeakPtr grad_executor_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
// Used for running ops and replacing forward results of the grad graph
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
// Used to cache the cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
};
@ -444,7 +345,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
forward_executor_ = std::make_shared<ForwardExecutor>();
grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
grad_executor_->set_dynamic_analysis(std::make_shared<DynamicAnalysis>());
forward_executor_->set_grad_executor(grad_executor_);
}
return executor_;
@ -459,6 +359,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
ForwardExecutorPtr forward_executor();
void set_grad_flag(bool flag);
std::string graph_phase() const { return graph_phase_; }
void set_graph_phase(std::string graph_phase) { graph_phase_ = graph_phase; }
void GradMsFunction(const py::object &out, const py::args &args);
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
@ -468,9 +371,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// Used by graph cleanup
bool GetIsDynamicCell();
bool need_replace_forward() { return grad_executor()->need_replace_forward(); }
// Called when a Cell is destructed
void ClearCell(const std::string &flag = "");
void ClearCell(const std::string &cell_id);
void ClearGrad(const py::object &cell, const py::args &args);
// Called on abnormal exit
void ClearRes();
@ -480,6 +382,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
private:
PynativeExecutor() = default;
// The graph phase is used to obtain the backend graph compiled by ms_function
std::string graph_phase_;
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static ForwardExecutorPtr forward_executor_;
@ -493,7 +397,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
} // namespace pynative
} // namespace mindspore
} // namespace mindspore::pynative
#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
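
The declarations above (RecordGradOpInfo, all_op_info(), CheckNeedCompileGraph) suggest that each forward run of a top cell is fingerprinted by the sequence of executed ops, and the grad graph is rebuilt only when that fingerprint changes. Below is a minimal standalone sketch of that bookkeeping, under that assumption; OpTraceCache and its members are illustrative names, not MindSpore APIs.

// Standalone sketch: fingerprint a run by its op trace; recompile only on change.
#include <iostream>
#include <string>
#include <unordered_map>

class OpTraceCache {
 public:
  // Append one "op_name-index" token per executed op in the current run.
  void Record(const std::string &op_name) { current_ += op_name + "-" + std::to_string(op_num_++) + ";"; }
  // Returns true when the cell's op trace changed, i.e. the graph must be rebuilt.
  bool NeedCompile(const std::string &cell_id) {
    auto it = prev_.find(cell_id);
    bool changed = (it == prev_.end()) || (it->second != current_);
    prev_[cell_id] = current_;
    current_.clear();
    op_num_ = 0;
    return changed;
  }

 private:
  int op_num_ = 0;
  std::string current_;
  std::unordered_map<std::string, std::string> prev_;
};

int main() {
  OpTraceCache trace;
  trace.Record("MatMul");
  trace.Record("ReLU");
  std::cout << trace.NeedCompile("net") << "\n";  // 1: first run always compiles
  trace.Record("MatMul");
  trace.Record("ReLU");
  std::cout << trace.NeedCompile("net") << "\n";  // 0: same trace, reuse graph
  return 0;
}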

View File

@ -288,6 +288,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
auto primitive_py = primitive->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(primitive_py);
this->set_hook(primitive_py->hook());
if (primitive_py->HasAttr(kBpropAttrName)) {
this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
}
}
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {

View File

@ -292,6 +292,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
}
// Clear the output addresses of the graph.
ClearOutputAddress(inputs, value_nodes, execution_order);
graph_output_map_.erase(graph_id);
}
void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {

View File

@ -286,13 +286,15 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
if (element->isa<tensor::Tensor>()) {
auto tensor = element->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
tensors->push_back(tensor);
tensors->emplace_back(tensor);
} else if (element->isa<ValueTuple>()) {
TensorValueToTensor(element, tensors);
}
}
} else if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
tensors->push_back(tensor);
tensors->emplace_back(tensor);
}
}
} // namespace mindspore

View File

@ -107,7 +107,7 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
result.graph_id = graph_id;
graph_id_map_[graph_id] = result;
if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
if (!pynative_mode) {
(void)g_ConvertCache.emplace(segment, result);
}
return result;

View File

@ -214,7 +214,10 @@ class _MindSporeFunction:
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
return self._executor(tuple(new_inputs), phase)
output = self._executor(tuple(new_inputs), phase)
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_exec.set_graph_phase(phase)
return output
def ms_function(fn=None, obj=None, input_signature=None):
@ -272,10 +275,15 @@ def ms_function(fn=None, obj=None, input_signature=None):
def wrap_mindspore(func):
@wraps(func)
def staging_specialize(*args):
input_args = args
process_obj = obj
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
input_args = args[1:]
process_obj = args[0]
return _MindSporeFunction(func, input_signature, process_obj)(*args)
out = _MindSporeFunction(func, input_signature, process_obj)(*args)
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_exec.grad_ms_function(out, *input_args)
return out
return staging_specialize
@ -376,6 +384,12 @@ class _PynativeExecutor:
def sync(self):
self._executor.sync()
def grad_ms_function(self, output, *args):
self._executor.grad_ms_function(output, *args)
def set_graph_phase(self, phase):
self._executor.set_graph_phase(phase)
def set_grad_flag(self, flag):
self._executor.set_grad_flag(flag)

View File

@ -498,7 +498,13 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());
auto input_type = primitive->GetAttr("dst_type")->cast<TypePtr>();
auto attr = primitive->GetAttr("dst_type");
if (attr == nullptr) {
attr = args_spec_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(attr);
primitive->set_attr("dst_type", attr);
}
auto input_type = attr->cast<TypePtr>();
auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape()->shape());
return ret;
}
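
This fallback pairs with the dynamic_shape_const_input_to_attr set earlier in this change, which lists Cast among the ops whose constant attribute may instead arrive as an input. Below is a standalone sketch of the attr-or-input pattern with MindSpore's types mocked; Primitive, InferDstType, and ValuePtr here are simplified stand-ins, not the real APIs.

// Standalone sketch: read a constant from the attribute if present, otherwise
// from the input operand, and cache it back as an attribute for later queries.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

using ValuePtr = std::shared_ptr<std::string>;  // stands in for a TypePtr value

struct Primitive {
  std::unordered_map<std::string, ValuePtr> attrs;
  ValuePtr GetAttr(const std::string &name) const {
    auto it = attrs.find(name);
    return it == attrs.end() ? nullptr : it->second;
  }
  void SetAttr(const std::string &name, ValuePtr v) { attrs[name] = std::move(v); }
};

ValuePtr InferDstType(Primitive *prim, const std::vector<ValuePtr> &inputs) {
  ValuePtr attr = prim->GetAttr("dst_type");
  if (attr == nullptr) {
    attr = inputs.at(1);              // fall back to the second operand
    prim->SetAttr("dst_type", attr);  // cache it for subsequent inference
  }
  return attr;
}

int main() {
  Primitive cast;
  std::vector<ValuePtr> inputs = {std::make_shared<std::string>("x"),
                                  std::make_shared<std::string>("float16")};
  std::cout << *InferDstType(&cast, inputs) << "\n";  // float16 (from input)
  std::cout << *InferDstType(&cast, inputs) << "\n";  // float16 (now from attr)
  return 0;
}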

View File

@ -165,8 +165,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
}
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
<< index_value->ToString();
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString();
}
auto idx_v = GetValue<int64_t>(index_value);
std::size_t nelems = queue->elements().size();

View File

@ -41,8 +41,8 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
return res;
}
size_t seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo(1024);
todo.clear();
std::vector<AnfNodePtr> todo;
todo.reserve(1024);
todo.push_back(root);
while (!todo.empty()) {
@ -95,14 +95,15 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) {
std::deque<CNodePtr> todo(1024);
todo.clear();
std::vector<CNodePtr> todo;
todo.reserve(1024);
todo.insert(todo.end(), starts.begin(), starts.end());
std::vector<CNodePtr> sorted_nodes;
auto seen = NewSeenGeneration();
while (!todo.empty()) {
CNodePtr top = todo.front();
todo.pop_front();
size_t top_idx = 0;
while (top_idx < todo.size()) {
CNodePtr top = todo[top_idx];
top_idx++;
sorted_nodes.push_back(top);
auto inputs = top->inputs();
for (auto &item : inputs) {
@ -121,13 +122,14 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &s
// Search for the first cnode that matches the predicate inside this graph only
CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate) {
std::deque<CNodePtr> todo(1024);
todo.clear();
std::vector<CNodePtr> todo;
todo.reserve(1024);
todo.insert(todo.end(), starts.begin(), starts.end());
auto seen = NewSeenGeneration();
while (!todo.empty()) {
CNodePtr top = todo.front();
todo.pop_front();
size_t top_idx = 0;
while (top_idx < todo.size()) {
CNodePtr top = todo[top_idx];
top_idx++;
if (match_predicate(top)) {
return top;
}
@ -147,13 +149,16 @@ CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const Matc
}
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
std::deque<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo;
todo.reserve(128);
todo.push_back(root);
std::vector<FuncGraphPtr> sorted;
sorted.reserve(128);
auto seen = NewSeenGeneration();
while (!todo.empty()) {
FuncGraphPtr top = todo.front();
todo.pop_front();
size_t top_idx = 0;
while (top_idx < todo.size()) {
FuncGraphPtr top = todo[top_idx];
top_idx++;
sorted.push_back(top);
auto used = top->func_graphs_used();
for (auto &item : used) {
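
The three refactors above all replace a std::deque used as a FIFO with a std::vector plus a read cursor: pop_front becomes an index increment, elements stay contiguous, and the vector doubles as the traversal record. A minimal standalone sketch of the pattern follows; the adjacency list and hash-set visited marker are simplified stand-ins for AnfNode graphs and seen generations.

// Standalone sketch of the queue-as-vector BFS pattern used above: instead of
// deque::pop_front, keep a read cursor into a vector that only grows.
#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  // Tiny adjacency list: node -> neighbors.
  std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  std::vector<int> todo;
  todo.reserve(16);
  std::unordered_set<int> seen;
  todo.push_back(0);
  seen.insert(0);
  size_t top_idx = 0;
  while (top_idx < todo.size()) {
    int top = todo[top_idx];
    ++top_idx;  // "pop" by advancing the cursor; no element is moved
    std::cout << top << " ";
    for (int next : adj[top]) {
      if (seen.insert(next).second) {
        todo.push_back(next);  // may reallocate, but amortized O(1)
      }
    }
  }
  std::cout << "\n";  // BFS order: 0 1 2 3
  return 0;
}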

View File

@ -44,7 +44,7 @@ class Primitive : public Named {
Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs);
Primitive(const Primitive &prim);
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract();
abstract::AbstractBasePtr ToAbstract() override;
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
void BeginRecordAddAttr() {

View File

@ -25,7 +25,7 @@ void ScopeManager::EnterScope(const ScopePtr &scope) {
}
void ScopeManager::LeaveScope(const ScopePtr &scope) noexcept {
if (scope != kDefaultScope) {
if (scope != kDefaultScope && !scope_stack_.empty()) {
scope_stack_.pop();
}
}
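
Popping an empty std::stack is undefined behavior, so an unbalanced LeaveScope could previously corrupt the stack; the added emptiness check makes such a call a no-op instead. Below is a standalone sketch of the guarded pattern; ScopeStack and its members are illustrative names, not the real ScopeManager API.

// Standalone sketch of the defensive pop added above.
#include <iostream>
#include <stack>
#include <string>

class ScopeStack {
 public:
  void Enter(const std::string &scope) { stack_.push(scope); }
  void Leave() {
    if (!stack_.empty()) {  // the guard: silently ignore an unbalanced Leave
      stack_.pop();
    }
  }
  std::string Current() const { return stack_.empty() ? "Default" : stack_.top(); }

 private:
  std::stack<std::string> stack_;
};

int main() {
  ScopeStack scopes;
  scopes.Enter("Cell/conv1");
  std::cout << scopes.Current() << "\n";  // Cell/conv1
  scopes.Leave();
  scopes.Leave();  // unbalanced: would be UB without the emptiness check
  std::cout << scopes.Current() << "\n";  // Default
  return 0;
}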

View File

@ -165,7 +165,7 @@ std::string GraphDebugInfo::debug_name() {
LocationPtr GraphDebugInfo::location() {
// The function may have a decorator, which is included in its location
if (deco_loc_ != nullptr) {
if (deco_loc_ != nullptr && DebugInfo::location() != nullptr) {
LocationPtr loc = std::make_shared<Location>(*DebugInfo::location());
loc->set_line(loc->line() + (deco_loc_->line_end() - deco_loc_->line() + 1));
return loc;

View File

@ -348,15 +348,9 @@ class Cell(Cell_):
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")
origin_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(self, *inputs, **kwargs)
for cell in self.cells():
origin_grad.append(cell.requires_grad)
cell.set_grad(True)
else:
_pynative_exec.set_grad_flag(False)
_pynative_exec.new_graph(self, *inputs, **kwargs)
cast_inputs = list()
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
@ -368,10 +362,7 @@ class Cell(Cell_):
output = self.run_construct(cast_inputs, kwargs)
if isinstance(output, Parameter):
output = output.data
if self.requires_grad is True:
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
return output
def _add_attr(self, name, value):

View File

@ -491,7 +491,7 @@ def get_bprop_log(self):
def bprop(x, out, dout):
g = reciprocal(x)
dx = g * dout
return dx, 0
return (dx,)
return bprop
@ -505,7 +505,7 @@ def get_bprop_log1p(self):
x_1p = x + 1
g = reciprocal(x_1p)
dx = g * dout
return dx, 0
return (dx,)
return bprop

View File

@ -296,5 +296,18 @@ def _add_rowtensor_tensor(x, y):
"""
return x + y
@_add_backward.register("None", "None")
def _add_nonetensor_tensor(x, y):
"""
Adds None and None.
Args:
x (None): x
y (None): y
Returns:
None.
"""
return x + y
hyper_add = base.HyperMap(_add_backward)

View File

@ -0,0 +1,128 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>
#include "frontend/optimizer/ad/kpynative.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "debug/anf_ir_utils.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace ad {
class TestKPynative : public UT::Common {
public:
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
protected:
AbstractBasePtr BuildArg() {
std::vector<int64_t> shp = {2, 2};
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
auto abstract = tensor->ToAbstract();
return abstract;
}
FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
auto g = std::make_shared<FuncGraph>();
auto x = g->add_parameter();
auto y = g->add_parameter();
x->set_abstract(BuildArg());
y->set_abstract(BuildArg());
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
c_node->set_abstract(BuildArg());
g->set_output(c_node);
return g;
}
// a = x * y
// b = stop_gradient(a)
// c = b * y
// d = a * c
// return d
FuncGraphPtr BuildStopGradient(const std::string &testCase) {
auto g = std::make_shared<FuncGraph>();
auto x = g->add_parameter();
x->debug_info()->set_name("x");
auto y = g->add_parameter();
y->debug_info()->set_name("y");
x->set_abstract(BuildArg());
y->set_abstract(BuildArg());
auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
a_node->set_abstract(BuildArg());
auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
b_node->set_abstract(BuildArg());
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
c_node->set_abstract(BuildArg());
auto d_node =
g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
d_node->set_abstract(BuildArg());
g->set_output(d_node);
return g;
}
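// Replays the primal graph node by node through GradPynativeOp, then finalizes the bprop graph.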
FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
auto input_params = primal_fg->parameters();
std::vector<ValuePtr> input_param_values;
std::for_each(input_params.begin(), input_params.end(),
[&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
auto node_list = TopoSort(primal_fg->output());
for (auto node : node_list) {
if (node->isa<CNode>()) {
auto c_node = node->cast<CNodePtr>();
auto out = c_node->abstract()->GetValueTrack();
ValuePtrList args;
for (size_t i = 1; i < c_node->inputs().size(); ++i) {
args.push_back(c_node->input(i)->abstract()->GetValueTrack());
}
GradPynativeOp(k_pynative_cell, c_node, args, out);
}
}
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true);
return bprop_fg;
}
};
TEST_F(TestKPynative, test_simple_add) {
auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
resource->manager()->KeepRoots({primal_fg});
ExportIR(primal_fg->ToString() + ".dat", primal_fg);
auto bprop_fg = BuildBpropFuncGraph(primal_fg);
resource->manager()->KeepRoots({bprop_fg});
ExportIR(bprop_fg->ToString() + ".dat", bprop_fg);
}
TEST_F(TestKPynative, test_stop_gradient) {
auto primal_fg = BuildStopGradient("test_stop_gradient");
resource->manager()->KeepRoots({primal_fg});
ExportIR(primal_fg->ToString() + ".dat", primal_fg);
auto bprop_fg = BuildBpropFuncGraph(primal_fg);
resource->manager()->KeepRoots({bprop_fg});
ExportIR(bprop_fg->ToString() + ".dat", bprop_fg);
}
} // namespace ad
} // namespace mindspore