!16153 pynative refactoring to optimize performance
From: @chujinjin
commit 789b63b501
@@ -1694,13 +1694,17 @@ class ExecuteOrderGenerator {
      }
    }
  }
  // Search all nodes for parameter write assigns.
  std::set<AnfNodePtr> refed_parameters;
  for (auto &item : search_list) {
    auto &node = item.first;
    std::set<AnfNodePtr> refed_parameters;
    for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
      (void)refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
    }
  }
  // Search all nodes for parameter write assigns.
  for (auto &item : search_list) {
    auto &node = item.first;
    for (auto &in : node->inputs()) {
      auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
      visit_node = validate_ref_parameter(visit_node);
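The reshaped hunk collects the ref-written parameters into a single refed_parameters set before the per-node input search runs. A minimal, self-contained sketch of the std::multimap equal_range idiom it relies on; toy std::string keys stand in for AnfNodePtr, so this is not MindSpore code:

#include <iostream>
#include <map>
#include <set>
#include <string>

int main() {
  // Toy stand-in for ref_multimap: each "assign node" maps to the
  // parameters it references, possibly several entries per key.
  std::multimap<std::string, std::string> ref_multimap{
      {"assign1", "paramA"}, {"assign1", "paramB"}, {"assign2", "paramA"}};

  std::set<std::string> refed_parameters;
  for (const auto &node : {"assign1", "assign2"}) {
    // equal_range yields the [begin, end) run of entries for one key.
    for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
      (void)refed_parameters.insert(iter->second);
    }
  }
  std::cout << refed_parameters.size() << " distinct refed parameters\n";  // prints 2
}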
@@ -290,10 +290,10 @@ void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_r
  task_cond_var_.notify_all();
  if (sync && !sync_run_task_finished_) {
    std::unique_lock<std::mutex> lock(task_mutex_);
    if (long_run) {
    if (sync && long_run) {
      mindspore::ScopedLongRunning long_running;
      sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
    } else {
    } else if (sync) {
      sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
    }
  }
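The rewritten branch keys both waits on sync and reserves the ScopedLongRunning guard for long-running tasks only. A self-contained sketch of that predicate-wait shape; ScopedLongRunningStub is a hypothetical stand-in for mindspore::ScopedLongRunning:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

std::mutex task_mutex_;
std::condition_variable sync_cond_var_;
bool sync_run_task_finished_ = false;

// Hypothetical stand-in: an RAII guard that marks the current thread as
// blocked in a long-running wait, as the real ScopedLongRunning does.
struct ScopedLongRunningStub {
  ScopedLongRunningStub() { std::cout << "enter long-running section\n"; }
  ~ScopedLongRunningStub() { std::cout << "leave long-running section\n"; }
};

void WaitForTask(bool sync, bool long_run) {
  std::unique_lock<std::mutex> lock(task_mutex_);
  if (sync && long_run) {
    // Only a synchronous, long-running wait pays the guard's overhead.
    ScopedLongRunningStub long_running;
    sync_cond_var_.wait(lock, [] { return sync_run_task_finished_; });
  } else if (sync) {
    sync_cond_var_.wait(lock, [] { return sync_run_task_finished_; });
  }  // asynchronous callers fall through without blocking
}

int main() {
  std::thread worker([] {
    {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
    }
    sync_cond_var_.notify_all();
  });
  WaitForTask(/*sync=*/true, /*long_run=*/true);
  worker.join();
}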
@@ -108,11 +108,6 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
  std::vector<tensor::TensorPtr> tensors;
  TensorValueToTensor(value, &tensors);
  if (!tensors.empty()) {
    if (tensors.size() != AnfAlgo::GetOutputTensorNum(value_node)) {
      MS_LOG(EXCEPTION) << "The size of tensors converted from value [" << tensors.size()
                        << "] is not equal to output size of value node [" << AnfAlgo::GetOutputTensorNum(value_node)
                        << "]";
    }
    device_formats->clear();
    device_types->clear();
    for (const auto &tensor : tensors) {
@@ -692,7 +687,7 @@ AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
      if (RemoveValueNodeFromGraph(value_node)) {
      if (!RemoveValueNodeFromGraph(value_node)) {
        MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
      }
      return make_tuple;
@@ -361,13 +361,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
  auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = Manage({fg, func_graph}, false);
  auto need_replace_forward = pynative::PynativeExecutor::GetInstance()->need_replace_forward();
  auto forward_value = GenNewTensor(manager, equivdout, forward, need_replace_forward);
  if (!need_replace_forward) {
    cnode_morph->clear_inputs_value();
    MS_LOG(DEBUG) << "No need replace forward result";
    return;
  }
  auto forward_value = GenNewTensor(manager, equivdout, forward, true);
  MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
  auto value_node = NewValueNode(forward_value);
  value_node->set_has_new_value(true);
@@ -406,7 +400,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
    }
    auto out_node = c_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(out_node);
    out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
    out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), true));
  }

bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
@@ -147,6 +147,7 @@ class KPrim {
    bprop_registry_meta_.clear();
    bprop_registry_.clear();
  }
  FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);

 private:
  FuncGraphPtr GetBprop(const PrimitivePtr &prim);
@@ -38,7 +38,6 @@

namespace mindspore {
namespace ad {
using PatternListType = std::initializer_list<BaseRef>;
KPrim g_k_prims;

FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
@@ -76,6 +75,23 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
  return func_graph;
}

FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
  FuncGraphPtr bprop_fg = nullptr;
  auto iter = bprop_registry_.find(prim);
  if (iter != bprop_registry_.end()) {
    bprop_fg = iter->second;
  }

  if (bprop_fg == nullptr) {
    bprop_fg = GetBprop(prim);
    if (bprop_fg != nullptr) {
      // Set bprop_fg graph cache
      bprop_registry_[prim] = bprop_fg;
    }
  }
  return bprop_fg;
}

FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  std::string func_name = "_fprop_" + prim->name();
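GetPossibleBprop is a memoizing lookup over bprop_registry_: return the cached graph if present, otherwise build it via GetBprop and cache the result so the Python bprop is not re-parsed per call. The same pattern in isolation; Key and Graph are hypothetical stand-ins for PrimitivePtr and FuncGraphPtr:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

using Key = std::string;                       // stands in for PrimitivePtr
using Graph = std::shared_ptr<std::string>;    // stands in for FuncGraphPtr

std::unordered_map<Key, Graph> registry_;

Graph BuildGraph(const Key &key) {  // plays the role of KPrim::GetBprop
  std::cout << "building bprop for " << key << "\n";
  return std::make_shared<std::string>("bprop_of_" + key);
}

Graph GetPossibleBprop(const Key &key) {
  auto iter = registry_.find(key);
  if (iter != registry_.end()) {
    return iter->second;  // cache hit: skip the expensive build
  }
  auto graph = BuildGraph(key);
  if (graph != nullptr) {
    registry_[key] = graph;  // set graph cache
  }
  return graph;
}

int main() {
  GetPossibleBprop("MatMul");  // builds
  GetPossibleBprop("MatMul");  // served from cache
}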
(File diff suppressed because it is too large.)
@@ -0,0 +1,97 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_

#include <memory>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"

namespace mindspore {
namespace ad {
class KPynativeCell {
 public:
  virtual ~KPynativeCell() = default;
  virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) = 0;
  // Grad for cell which may have a user-passed front propagate FuncGraph.
  // c_node: CNode which contains the construct function graph of the cell (index 0) and the formal input
  //   parameters of that cell.
  // op_args: the arguments list of each input parameter.
  // out: the op result.
  // fprop_fg: user-defined front propagate funcgraph whose output is the bprop_fg.
  //   Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
  virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
                                  const FuncGraphPtr &fprop_fg) = 0;
};

using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;

// bprop_fg: user-defined back propagate funcgraph or the back propagate funcgraph of a primitive; it should be
//   passed in just after being parsed. It will have prototype:
//   (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
// c_node: CNode which contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameter.
// out: the op result.
// return: the returned funcgraph should have the same prototype.
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
                                    const ValuePtr &out);

// Start building the back propagate funcgraph for this cell.
// cell_inputs: the input parameter list of this cell except the weights.
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
                                       const std::vector<ValuePtr> &input_param_values);

// Return the back propagate funcgraph for this cell.
// weights: weight parameters used in this cell.
// grad_inputs: return sensitivity for input parameters.
// grad_weights: return sensitivity for weights.
// has_sens_arg: caller will pass sens args.
// return: the returned funcgraph will have prototype:
//   if has_sens_arg is true:
//     (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ...) bprop_fg(input1, input2, ..., weight0,
//     weight1, ..., sens_out)
//   else:
//     (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ...) bprop_fg(input1, input2, ..., weight0,
//     weight1, ...)
//   if build_formal_param is true:
//     each cnode in the primal funcgraph is replaced by a formal cnode
//   else:
//     each cnode in the primal funcgraph is replaced by a value node
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
                                 bool grad_weights, bool has_sens_arg = false, bool build_formal_param = false);

// Grad for each operation.
// c_node: CNode which contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameter.
// out: the op result.
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
                    const ValuePtr &out);

// Grad for cell which may have a user-defined back propagate function.
// c_node: CNode which contains the construct function graph of the cell (index 0) and the formal input
//   parameters of that cell.
// op_args: the arguments list of each input parameter.
// out: the op result.
// bprop_fg: user-defined back propagate funcgraph; it should be passed in just after being parsed.
//   Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
                           const ValuePtr &out, const FuncGraphPtr &bprop_fg);

// Clear all static resources that are used in the grad process.
void ClearKPynativeCellStaticRes();
}  // namespace ad
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
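Read together, these declarations suggest a per-op recording flow: begin a cell, record each executed op's grad, then finalize the bprop funcgraph. A hedged usage sketch assembled only from the functions declared above; how cell_inputs, the per-op CNodes, and their argument/result values are produced is assumed to happen elsewhere in the PyNative executor:

// Sketch only: assumes the caller has already prepared cell_inputs,
// input_param_values, weights, and the per-op c_node/op_args/out values.
#include "frontend/optimizer/ad/kpynative.h"

namespace mindspore::ad {
FuncGraphPtr RecordAndGrad(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values,
                           const std::vector<CNodePtr> &op_nodes, const std::vector<ValuePtrList> &op_args_list,
                           const std::vector<ValuePtr> &op_outs, const AnfNodePtrList &weights) {
  // 1. Start building the back propagate funcgraph for the cell.
  KPynativeCellPtr k_cell = GradPynativeCellBegin(cell_inputs, input_param_values);

  // 2. Record the grad of each executed op as it runs.
  for (size_t i = 0; i < op_nodes.size(); ++i) {
    if (!GradPynativeOp(k_cell, op_nodes[i], op_args_list[i], op_outs[i])) {
      return nullptr;
    }
  }

  // 3. Finalize: request input and weight sensitivities, no explicit sens arg.
  return GradPynativeCellEnd(k_cell, weights, /*grad_inputs=*/true, /*grad_weights=*/true);
}
}  // namespace mindspore::ad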
@@ -49,6 +49,7 @@
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"

namespace mindspore {
namespace opt {
@@ -241,6 +242,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // switch_layer defer inline
  switch_layer_defer_inline_ =
    MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);

  bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
}

ResolveIRPassLib::ResolveIRPassLib() {
@@ -149,6 +149,9 @@ class OptimizeIRPassLib {

  // Pynative Eliminate
  SubstitutionPtr pynative_eliminate_;

  // Eliminate getattr bool scalar
  SubstitutionPtr bool_scalar_eliminate_;
};

// the collection of irpass for resolve action
@@ -0,0 +1,59 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
#include "frontend/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
namespace irpass {

AnfNodePtr BoolScalarEliminate::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return nullptr;
  }

  if (!cnode->IsApply(prim::kPrimGetAttr)) {
    return nullptr;
  }

  auto vnode = cnode->input(1)->cast<ValueNodePtr>();
  if (vnode == nullptr) {
    return nullptr;
  }

  if (!vnode->value()->isa<BoolImm>()) {
    return nullptr;
  }

  auto res = optimizer->resource();
  auto manager = res->manager();
  auto &node_users = manager->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return nullptr;
  }

  AnfNodeIndexSet node_idx_set = iter->second;
  for (auto &item : node_idx_set) {
    manager->Replace(item.first, vnode);
  }
  return nullptr;
}
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H

#include "ir/func_graph.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"

namespace mindspore {
namespace opt {
namespace irpass {

class BoolScalarEliminate : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H
@@ -89,15 +89,12 @@ class InlinerBase : public AnfVisitor {
      : use_move_(use_move), criterions_(criterions) {}
  ~InlinerBase() override = default;
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>()) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr || cnode->size() < 1) {
      return nullptr;
    }

    auto &inputs = cnode->inputs();
    // G
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
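The refactor folds the isa<CNode> check and the subsequent cast into a single null-checked dyn_cast. MindSpore ships its own dyn_cast utility; a toy sketch of the same idiom using std::dynamic_pointer_cast as a stand-in:

#include <iostream>
#include <memory>

struct AnfNode { virtual ~AnfNode() = default; };
struct CNode : AnfNode { size_t size() const { return 3; } };

// dyn_cast-style helper: one null-checked downcast instead of isa<> + cast<>.
template <typename T>
std::shared_ptr<T> dyn_cast_sketch(const std::shared_ptr<AnfNode> &node) {
  return std::dynamic_pointer_cast<T>(node);
}

int main() {
  std::shared_ptr<AnfNode> node = std::make_shared<CNode>();
  auto cnode = dyn_cast_sketch<CNode>(node);
  if (cnode == nullptr || cnode->size() < 1) {
    return 1;  // bail out early, as the inliner does
  }
  std::cout << "cnode with " << cnode->size() << " inputs\n";
}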
@@ -109,19 +106,7 @@ class InlinerBase : public AnfVisitor {
    // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
    // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
    // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
    bool is_match = false;
    for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
      is_match = true;
      for (auto &criterion : criterions) {  // Each criterion in 'criterion group'.
        if (!criterion(this, fg, node)) {
          is_match = false;
          break;
        }
      }
      if (is_match) {
        break;
      }
    }
    bool is_match = ApplyCriterions(node, fg);
    if (!is_match) {
      return nullptr;
    }
@@ -137,22 +122,9 @@ class InlinerBase : public AnfVisitor {
    if (IsUniqueUse(nullptr, fg, nullptr)) {
      // For the single used fg, including non-after and after not matched above,
      // we move the whole fg nodes.
      if (use_move_) {
        auto mng = fg->manager();
        MS_EXCEPTION_IF_NULL(mng);
        ReplaceParams(mng, args, fg);
        auto out_node = fg->output();
        mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
        return out_node;
      }

      // The other branch calling the last after block.
      if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
        // Check if parameters' changed.
        auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
        if (param_simplified_caller != nullptr) {
          return param_simplified_caller;
        }
      auto ret_node = InlineForUniqueUse(node, fg, args, inputs);
      if (ret_node != nullptr) {
        return ret_node;
      }
    } else {
      // We don't expand the middle multiple used after block, except the last one.
@@ -171,6 +143,45 @@ class InlinerBase : public AnfVisitor {
    return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
  }

  AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                                const std::vector<AnfNodePtr> &inputs) {
    if (use_move_) {
      auto mng = fg->manager();
      MS_EXCEPTION_IF_NULL(mng);
      ReplaceParams(mng, args, fg);
      auto out_node = fg->output();
      mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
      return out_node;
    }

    // The other branch calling the last after block.
    if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
      // Check if parameters' changed.
      auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
      if (param_simplified_caller != nullptr) {
        return param_simplified_caller;
      }
    }
    return nullptr;
  }

  bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
    bool is_match = false;
    for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
      is_match = true;
      for (auto &criterion : criterions) {  // Each criterion in 'criterion group'.
        if (!criterion(this, fg, node)) {
          is_match = false;
          break;
        }
      }
      if (is_match) {
        break;
      }
    }
    return is_match;
  }

  void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
                     const FuncGraphPtr &fg) {
    auto params = fg->parameters();
@@ -289,9 +300,9 @@ class InlinerBase : public AnfVisitor {
  };

  bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
    auto &cnodes = fg->func_graph_cnodes_index();
    const auto &users = fg->func_graph_cnodes_index();
    int64_t n_use = std::accumulate(
      cnodes.begin(), cnodes.end(), 0,
      users.begin(), users.end(), 0,
      [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
    return n_use == 1;
  }
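IsUniqueUse sums per-call-site use counts from func_graph_cnodes_index() and asks whether the total is exactly one. The std::accumulate idiom in isolation; a toy map stands in for the real index:

#include <iostream>
#include <map>
#include <numeric>
#include <string>

int main() {
  // Toy stand-in for fg->func_graph_cnodes_index(): call site -> use count.
  std::map<std::string, int64_t> users{{"call@graph_a", 1}};

  int64_t n_use = std::accumulate(
      users.begin(), users.end(), int64_t{0},
      [](int64_t sum, const std::pair<const std::string, int64_t> &item) { return sum + item.second; });

  std::cout << std::boolalpha << (n_use == 1) << "\n";  // unique use -> safe to move nodes
}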
@@ -154,6 +154,7 @@ class TupleListGetitemConstEliminator : public AnfVisitor {
    if (is_match_) {
      auto out = NewValueNode((*tuple_)[id_]);
      out->set_has_new_value(has_new_value_);
      out->set_abstract((*tuple_)[id_]->ToAbstract());
      return out;
    }
    return nullptr;
@@ -126,8 +126,8 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
  return nullptr;
}

static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
                                   bool change, size_t seen) {
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
                                   std::vector<AnfNodePtr> *todo, bool change, size_t seen) {
  if (IsValueNode<FuncGraph>(node)) {
    (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
  }
@@ -164,16 +164,17 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
#endif
  FuncGraphManagerPtr manager = optimizer->manager();
  auto seen = NewSeenGeneration();
  // 1024 is for the initial capacity of deque
  std::deque<AnfNodePtr> todo(1024);
  todo.clear();
  // 1024 is for the initial capacity of vector
  std::vector<AnfNodePtr> todo;
  todo.reserve(1024);
  todo.emplace_back(func_graph->output());
  bool changes = false;

  auto &all_nodes = manager->all_nodes();
  while (!todo.empty()) {
    AnfNodePtr node = todo.front();
    todo.pop_front();
  size_t node_idx = 0;
  while (node_idx < todo.size()) {
    AnfNodePtr node = todo[node_idx];
    node_idx++;

    if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
      continue;
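Both traversals replace a std::deque consumed with pop_front by a std::vector scanned with a moving index: one reserved allocation, no per-pop bookkeeping, and newly discovered nodes are appended to the same buffer. A runnable sketch of the idiom:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Index-scanned worklist: the vector only grows, and node_idx marks the
  // consumed prefix, so there is no pop_front churn as with std::deque.
  std::vector<std::string> todo;
  todo.reserve(1024);  // one up-front allocation for the common case
  todo.emplace_back("root");

  size_t node_idx = 0;
  while (node_idx < todo.size()) {
    std::string node = todo[node_idx];
    node_idx++;
    std::cout << "visit " << node << "\n";
    if (node == "root") {  // discovering children appends to the same vector
      todo.emplace_back("child1");
      todo.emplace_back("child2");
    }
  }
}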
@@ -206,16 +207,17 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
#endif
  FuncGraphManagerPtr manager = optimizer->manager();
  auto seen = NewSeenGeneration();
  // 1024 is for the initial capacity of deque
  std::deque<AnfNodePtr> todo(1024);
  todo.clear();
  // 1024 is for the initial capacity of vector
  std::vector<AnfNodePtr> todo(0);
  todo.reserve(1024);
  todo.emplace_back(root_node);
  bool changes = false;

  auto &all_nodes = manager->all_nodes();
  while (!todo.empty()) {
    AnfNodePtr node = todo.front();
    todo.pop_front();
  size_t node_idx = 0;
  while (node_idx < todo.size()) {
    AnfNodePtr node = todo[node_idx];
    node_idx++;

    if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
      continue;
@@ -8,6 +8,7 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "pipeline_split.cc"
  "parse/*.cc"
  "static_analysis/*.cc"
  "prim_bprop_optimizer.cc"
)
@@ -522,6 +522,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
    MS_LOG(EXCEPTION) << "TaskEmit args error";
  }
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
  auto context_ptr = MsContext::GetInstance();
  std::string backend = MsContext::GetInstance()->backend_policy();
@@ -531,6 +531,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
    MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
    data_converter::SetObjGraphValue(obj_key, func_graph);
  }

  return func_graph;
}
namespace data_converter {
@@ -0,0 +1,254 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <unordered_set>
#include <set>
#include <vector>
#include <string>
#include <memory>
#include "pipeline/jit/parse/parse_dynamic.h"
#include "mindspore/core/ir/cell.h"

namespace mindspore::parse {
static std::unordered_set<std::string> cell_input_args_;
static const std::set<std::string> ignore_judge_dynamic_cell = {
  "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
  "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
                                                                parse::NAMED_PRIMITIVE_NAMECONSTANT,
                                                                parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};

std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                         parse::AstMainType type) {
  MS_EXCEPTION_IF_NULL(ast);
  if (py::isinstance<py::none>(node)) {
    MS_LOG(DEBUG) << "Get none type node!";
    return "";
  }
  auto node_type = ast->GetNodeType(node);
  MS_EXCEPTION_IF_NULL(node_type);
  // Check node type
  parse::AstMainType node_main_type = node_type->main_type();
  if (node_main_type != type) {
    MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
    return "";
  }
  std::string node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name;
  return node_name;
}

void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
  MS_EXCEPTION_IF_NULL(ast);
  py::list args = ast->GetArgs(fn_node);
  for (size_t i = 1; i < args.size(); i++) {
    std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
    MS_LOG(DEBUG) << "Input arg name: " << arg_name;
    cell_input_args_.emplace(arg_name);
  }
}

bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse if/while expr";
  py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
  const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
    py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
    py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
    if (comparators_node.empty()) {
      MS_LOG(DEBUG) << "Get comparators node failed!";
      return false;
    }
    auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
    auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
    // while self.a > self.b and changed self.a or self.b
    if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
      auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      std::string left_variable;
      if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
        left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
      }
      auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
      std::string right_variable;
      if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
        right_variable =
          py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
      }
      return ParseBodyContext(ast, node, {left_variable, right_variable});
    }
    // if a[0]
    if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
      left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
    }
    MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
    if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
        unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
      return true;
    }
  }
  // if flag:
  if (node_name == parse::NAMED_PRIMITIVE_NAME) {
    std::string id = py::cast<std::string>(test_node.attr("id"));
    if (cell_input_args_.find(id) != cell_input_args_.end()) {
      return true;
    }
  }
  return false;
}

bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse assign expr";
  py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
  const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
  if (node_name == parse::NAMED_PRIMITIVE_CALL) {
    py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
    const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
    if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
      py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
      py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
      if (py::isinstance<py::none>(value_in_slice_node)) {
        MS_LOG(DEBUG) << "Parse value node is none!";
        return false;
      }
      const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
      std::string id;
      if (py::hasattr(value_in_slice_node, "id")) {
        id = py::cast<std::string>(value_in_slice_node.attr("id"));
      }
      if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() ||
          (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) {
        return true;
      }
    }
  }
  return false;
}

bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                           const std::vector<std::string> &compare_prim) {
  MS_LOG(DEBUG) << "Parse augassign expr";
  bool ret = false;
  if (compare_prim.empty()) {
    return ret;
  }
  py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
  if (py::isinstance<py::none>(target_node)) {
    MS_LOG(DEBUG) << "Parse target node is none!";
    return ret;
  }
  py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
  if (py::isinstance<py::none>(value_node)) {
    MS_LOG(DEBUG) << "Parse value node is none!";
    return ret;
  }
  std::string assign_prim;
  if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
    assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
  }
  auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
  if (iter != compare_prim.end()) {
    ret = true;
  }
  return ret;
}

bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
  MS_LOG(DEBUG) << "Parse for expr";
  py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(body_node)) {
    MS_LOG(DEBUG) << "Parse body of for expression is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
  size_t count = LongToSize(pcount);
  MS_LOG(DEBUG) << "The for nodes count in body is " << count;
  for (size_t i = 0; i < count; ++i) {
    auto it = py::cast<py::list>(body_node)[i];
    const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
      return true;
    }
  }
  return false;
}

bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
                                     const std::vector<std::string> &compare_prim) {
  MS_EXCEPTION_IF_NULL(ast);
  py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
  if (py::isinstance<py::none>(func_obj)) {
    MS_LOG(DEBUG) << "Parse body of cell is none!";
    return false;
  }
  py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
  size_t count = IntToSize(pcount);
  MS_LOG(DEBUG) << "The nodes count in body is " << count;
  bool ret = false;
  for (size_t i = 0; i < count; ++i) {
    auto node = py::cast<py::list>(func_obj)[i];
    const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
    if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
      ret = ParseAssignExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
      ret = ParseAugAssignExprNode(ast, node, compare_prim);
    } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
      ret = ParseForExprNode(ast, node);
    } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
      ret = ParseIfWhileExprNode(ast, node);
    }
    if (ret) {
      MS_LOG(INFO) << "Current cell is dynamic!";
      break;
    }
  }
  return ret;
}

std::string DynamicParser::GetCellInfo(const py::object &cell) {
  if (py::isinstance<Cell>(cell)) {
    auto c_cell = py::cast<CellPtr>(cell);
    MS_EXCEPTION_IF_NULL(c_cell);
    auto cell_info = c_cell->ToString();
    return cell_info;
  }
  return "";
}

bool DynamicParser::IsDynamicCell(const py::object &cell) {
  std::string cell_info = GetCellInfo(cell);
  if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
    return false;
  }
  // Use ast parsing to check whether the construct of the cell will be changed
  auto ast = std::make_shared<parse::ParseAst>(cell);
  bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
  if (!success) {
    MS_LOG(ERROR) << "Parse code to ast tree failed";
    return false;
  }
  py::object fn_node = ast->GetAstNode();
  // Get the names of the input args as the initializer of dynamic_variables
  ParseInputArgs(ast, fn_node);
  // Parse body context
  bool ret = false;
  ret = ParseBodyContext(ast, fn_node);
  cell_input_args_.clear();
  return ret;
}
}  // namespace mindspore::parse
@@ -0,0 +1,52 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_

#include <vector>
#include <memory>
#include <string>
#include "pipeline/jit/parse/parse.h"

namespace mindspore::parse {

class DynamicParser {
 public:
  DynamicParser() = default;
  ~DynamicParser() = default;

  // Check cell struct
  static bool IsDynamicCell(const py::object &cell);

 private:
  static std::string GetCellInfo(const py::object &cell);
  static void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
  static bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
                               const std::vector<std::string> &compare_prim = {});
  static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                     const std::vector<std::string> &compare_prim = {});
  static bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  static std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                                   parse::AstMainType type);
};
}  // namespace mindspore::parse

#endif
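A hedged sketch of how the new entry point might be called; obtaining a pybind11 handle to an nn.Cell instance, and the wrapper name, are assumptions:

// Sketch only: assumes a py::object handle to an nn.Cell instance is available.
#include "pipeline/jit/parse/parse_dynamic.h"

namespace mindspore::parse {
bool ShouldRecompileEveryStep(const py::object &cell) {
  // A "dynamic" cell changes its construct() control flow between runs
  // (e.g. branches on inputs), so its graph cannot simply be cached.
  return DynamicParser::IsDynamicCell(cell);
}
}  // namespace mindspore::parse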
@@ -41,6 +41,7 @@
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
@@ -75,6 +76,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
  return true;
}

bool TransformTopGraphPass(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "Transform top graph error.";
  }
  FuncGraphPtr func_graph = res->func_graph();
  if (opt::FuncGraphHasTupleInput(func_graph)) {
    opt::GraphTupleParamTransform graph_trans;
    func_graph = graph_trans(func_graph, res->manager());
    res->set_func_graph(func_graph);
    AbstractBasePtrList abs_spec_list;
    auto &params = func_graph->parameters();
    std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
                   [](AnfNodePtr node) { return node->abstract(); });
    res->set_args_spec(abs_spec_list);
  }
  return true;
}

bool CleanAfterOptAPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res->func_graph());
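TransformTopGraphPass refreshes the argument spec by projecting each parameter to its abstract with std::transform and a back_inserter. The collection idiom in isolation, with toy types in place of AnfNodePtr and AbstractBasePtr:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>

int main() {
  // Toy stand-ins: each parameter carries an "abstract"; collect them in one
  // pass, as the pass above does before res->set_args_spec().
  struct Param { std::shared_ptr<std::string> abstract; };
  std::vector<Param> params{{std::make_shared<std::string>("Tensor[2,3]")},
                            {std::make_shared<std::string>("Tuple")}};

  std::vector<std::shared_ptr<std::string>> abs_spec_list;
  std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
                 [](const Param &p) { return p.abstract; });

  std::cout << abs_spec_list.size() << " abstracts collected\n";
}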
@@ -93,6 +112,104 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
  return true;
}

FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
  opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
    irpass.pynative_eliminate_,
  });
  opt::irpass::ResolveIRPassLib resolve_irpass;
  opt::OptPassConfig resolver_prim = opt::OptPassConfig({
    resolve_irpass.resolver_resolve_and_getattr_,
    resolve_irpass.resolver_resolve_,
    resolve_irpass.resolver_getattr_,
  });

  opt::OptPassConfig switch_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
  });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });

  opt::OptPassConfig bool_scalar_eliminate = opt::OptPassConfig({
    irpass.bool_scalar_eliminate_,
  });

  OptPassGroupMap map({{"ad_eliminate_", pynative_eliminate},
                       {"ad_resolver_prim", resolver_prim},
                       {"ad_inline_", inline_opt},
                       {"bool_scalar_eliminate_", bool_scalar_eliminate},
                       {"ad_switch_simplify_", switch_simplify}});

  auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"))[&prim_bprop_opt_step_1, &func_graph]() {
    func_graph = prim_bprop_opt_step_1->step(func_graph, true);
  };
  return func_graph;
}

FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
  opt::OptPassConfig switch_simplify_ = opt::OptPassConfig({
    irpass.switch_simplify_,
    irpass.reduce_eliminate_,
  });

  opt::OptPassConfig inline_ = opt::OptPassConfig({
    irpass.inline_,
  });

  opt::OptPassConfig zero_like_fill_zero_ = opt::OptPassConfig({
    irpass.zero_like_fill_zero_,
  });

  auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
    return ReAutoMonad(root);
  };
  OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
                       {"ad_inline_", inline_},
                       {"ad_switch_simplify_", switch_simplify_},
                       {"ad_zero_like_fill_zero_", zero_like_fill_zero_},
                       {"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});

  auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"))[&prim_bprop_opt_step_2, &func_graph]() {
    func_graph = prim_bprop_opt_step_2->step(func_graph, true);
  };
  return func_graph;
}

FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  if (!TransformTopGraphPass(res)) {
    MS_LOG(EXCEPTION) << "Run TransformTopGraphPass failed";
  }

  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig bg_final_opt_ = opt::OptPassConfig({
    irpass.inline_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.depend_value_elim_,
    irpass.reshape_eliminate_,
    irpass.switch_simplify_,
  });
  OptPassGroupMap map({{"ad_final_opt_", bg_final_opt_}});
  if (pynative::PynativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
    map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
  }

  auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
    func_graph = bprop_graph_final_opt->step(func_graph, true);
  };

  return func_graph;
}

namespace {
bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
@@ -472,24 +589,6 @@ bool CconvPass(const ResourcePtr &res) {
  return true;
}

bool TransformTopGraphPass(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "Transform top graph error.";
  }
  FuncGraphPtr func_graph = res->func_graph();
  if (opt::FuncGraphHasTupleInput(func_graph)) {
    opt::GraphTupleParamTransform graph_trans;
    func_graph = graph_trans(func_graph, res->manager());
    res->set_func_graph(func_graph);
    AbstractBasePtrList abs_spec_list;
    auto &params = func_graph->parameters();
    std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
                   [](AnfNodePtr node) { return node->abstract(); });
    res->set_args_spec(abs_spec_list);
  }
  return true;
}

bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }

void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
@@ -24,6 +24,12 @@
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace opt {
namespace irpass {
class OptimizeIRPassLib;
}  // namespace irpass
}  // namespace opt

namespace pipeline {
using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
@@ -40,6 +46,9 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res);
}  // namespace pipeline
}  // namespace mindspore
@@ -49,6 +49,7 @@
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/constants.h"
#include "ps/util.h"
@@ -1166,6 +1167,8 @@ void ClearResAtexit() {
  }
#endif
  ad::g_k_prims.clear();
  ad::ClearKPynativeCellStaticRes();
  PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

  abstract::ClearPrimEvaluatorMap();
  compile::ClearConvertCache();
@@ -0,0 +1,360 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "pipeline/jit/pass.h"

namespace mindspore {
namespace pipeline {

void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
  // args_value_using_info_ contains out
  if (args_value_using_info_.size() != op_args.size() + 1) {
    MS_LOG(EXCEPTION) << "param size :" << args_value_using_info_.size()
                      << " of bp_graph:" << opt_func_graph_->ToString()
                      << " not match input arguments num:" << op_args.size();
  }

  ValuePtrList new_args(op_args);
  new_args.emplace_back(out);
  TryFreeOneValue(new_args, args_value_using_info_);
}

void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
                                                  const std::vector<ParamUsingInfo> &param_info_vec) {
  if (param_info_vec.size() != op_args.size()) {
    MS_LOG(EXCEPTION) << "param size :" << param_info_vec.size() << " of bp_graph:" << opt_func_graph_->ToString()
                      << " not match input arguments num:" << op_args.size();
  }

  for (size_t i = 0; i < op_args.size(); ++i) {
    if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
      auto value = op_args[i]->cast<tensor::TensorPtr>();
      value->set_device_address(nullptr);
    } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
      auto value = op_args[i]->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value);
      TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
    }
  }
}

void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
  if (analysis_finish_flg_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(opt_func_graph_);
  auto &params = opt_func_graph_->parameters();
  const auto &node_users = manager->node_users();
  args_value_using_info_.resize(params.size() - 1);
  // analysis value using flg except dout
  for (size_t i = 0; i < params.size() - 1; ++i) {
    auto &param = params[i];
    auto &arg_info = args_value_using_info_[i];
    ArgInfoRefresh(param, &arg_info);
    AnalysisNodeUsingInfo(node_users, param, &arg_info);
  }
  analysis_finish_flg_ = true;
}

void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
                                                        const std::shared_ptr<AnfNode> &param,
                                                        ParamUsingInfo *arg_info) const {
  auto iter = node_users.find(param);

  if (iter == node_users.end()) {
    arg_info->using_flg_ = false;
    return;
  }

  // tensor return directly
  if (!arg_info->tuple_flg_) {
    arg_info->using_flg_ = true;
    return;
  }

  // specific process for tuple parameter, may only partial items used
  const auto &users_info = iter->second;
  for (auto &user_info : users_info) {
    auto user_node = user_info.first;
    arg_info->using_flg_ = true;
    MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
    if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
      for (auto &sub_info : arg_info->sub_using_info_) {
        sub_info.using_flg_ = true;
      }
    } else {
      AalysisForTupleGetItem(node_users, param, arg_info, user_node);
    }
  }
}

void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
                                                         const std::shared_ptr<AnfNode> &param,
                                                         ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  auto cnode = user_node->cast<CNodePtr>();
  if (cnode->size() != 3) {
    MS_LOG(EXCEPTION) << "TupleGetItem Node:" << user_node->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                      << "input size is:" << cnode->size();
  }
  auto idx_node = cnode->input(2);
  if (!idx_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                      << " unexpected used by node:" << user_node->ToString()
                      << " TupleGetItem idx node:" << idx_node->ToString();
  }

  auto vnode = idx_node->cast<ValueNodePtr>();
  auto value_ptr = vnode->value();
  if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                      << " unexpected used by node:" << user_node->ToString()
                      << " TupleGetItem idx node:" << idx_node->ToString() << " idx Value :" << value_ptr;
  }

  auto idx = value_ptr->cast<Int64ImmPtr>()->value();
  arg_info->sub_using_info_[idx].using_flg_ = true;
  ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));

  if (arg_info->tuple_flg_) {
    AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
  }
}

void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
                                                 ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  auto abs = param->abstract();
  if (abs->isa<abstract::AbstractTensor>()) {
    arg_info->tuple_flg_ = false;
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTensor";
  } else if (abs->isa<abstract::AbstractTuple>()) {
    auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTuple";
    arg_info->tuple_flg_ = true;
    arg_info->tuple_size_ = abs_tuple->size();
    arg_info->sub_using_info_.resize(abs_tuple->size());
  } else {
    arg_info->tuple_flg_ = false;
  }
}

PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
  static PrimBpropOptimizer g_prim_bprop_opt;
  return g_prim_bprop_opt;
}

void PrimBpropOptimizer::Clear() {
  prim_bprop_cache_.clear();
  tuple_list_bprop_cache_.clear();
}

// bprop_fg has the signature:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the prim (input 0) and the input parameters of that prim;
// op_args contains the arguments list of each input parameter, which may be a tensor or a tuple;
// out contains the out of c_node;
FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
                                                        const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  MS_EXCEPTION_IF_NULL(c_node);
  MS_EXCEPTION_IF_NULL(out);
  auto &inputs = c_node->inputs();
  if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
    MS_LOG(EXCEPTION) << "The parameters num " << inputs.size() - 1 << " not match arguments num " << op_args.size()
                      << ", CNode:" << c_node->ToString() << " graph:" << bprop_fg->ToString();
  }

  if (!IsValueNode<Primitive>(inputs[0])) {
    MS_LOG(EXCEPTION) << "CNode:" << c_node->ToString()
                      << " not a primitive node, input_0 is:" << inputs[0]->ToString();
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
  MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();

  // kPrimHookBackward
  bool hookback_flg = IsPrimitiveEquals(prim, prim::kPrimHookBackward);
  if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
}

FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim) {
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

  PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
  PrimBpropOptGraphInfoPtr level_1_graph_info;
  ECacheQrtRes cache_res = GetOptBpfgFromCache(prim, abs_list, &level_2_graph_info, &level_1_graph_info);

  MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
  }

  // do step1 opt
  if (cache_res == E_NOT_FOUND) {
    bprop_fg->debug_info()->set_name(prim->ToString());
    level_1_graph_info = PrimBpropOptStep1(bprop_fg);
    prim_bprop_cache_[prim] = level_1_graph_info;
  }
  FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);

  // do step2 opt
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
  level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  return BasicClone(level_2_graph_info->opt_func_graph());
}

FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                 const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);
  if (!hook_flg) {
    auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
    if (iter != tuple_list_bprop_cache_.end()) {
      return BasicClone(iter->second);
    }
  }

  // do step1 opt
  bprop_fg->debug_info()->set_name(prim->ToString());
  auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);

  // do step2 opt
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);

  if (!hook_flg) {
    tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) {
  opt::irpass::OptimizeIRPassLib irpass;
  auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
  auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
  auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
  auto graph_for_cache = BasicClone(bprop_fg);
  prim_bprop_opt_res->set_func_graph(graph_for_cache);
  prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
  auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
  level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
  return level_1_graph_info;
}

void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
                                             const abstract::AbstractBasePtrList &abs_list_input) {
  auto &params = bprop_fg->parameters();
  if (abs_list_input.size() != params.size()) {
    MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size();
  }

  for (size_t i = 0; i < abs_list_input.size(); i++) {
    params[i]->set_abstract(abs_list_input[i]);
  }
}

PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) {
  opt::irpass::OptimizeIRPassLib irpass;
  BindAbsToParameters(bprop_fg, abs_list_input);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto manager = resource->manager();
  resource->set_func_graph(bprop_fg);
  manager->AddFuncGraph(bprop_fg);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
  auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
  level_2_graph_info->AnalysisArgUsingInfo(manager);
  return level_2_graph_info;
}

FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  auto after_opt_bg = BpropGraphFinalOptPass(res);
  return after_opt_bg;
}

ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
|
||||
const abstract::AbstractBasePtrList &abs_list,
|
||||
PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
|
||||
PrimBpropOptGraphInfoPtr *level_1_graph_info) {
|
||||
auto attrs_ = prim->attrs();
|
||||
for (auto &item : attrs_) {
|
||||
MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
|
||||
}
|
||||
|
||||
auto iter = prim_bprop_cache_.find(prim);
|
||||
if (iter == prim_bprop_cache_.end()) {
|
||||
return E_NOT_FOUND;
|
||||
}
|
||||
|
||||
*level_1_graph_info = iter->second;
|
||||
auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
|
||||
if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
|
||||
return E_LEVEL_1;
|
||||
}
|
||||
*level_2_graph_info = second_iter->second;
|
||||
return E_LEVEL_2;
|
||||
}
|
||||
|
||||
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
|
||||
abstract::AbstractBasePtrList *abs_list) {
|
||||
auto const_input_index = prim->get_const_input_indexes();
|
||||
bool have_const_input = !const_input_index.empty();
|
||||
bool is_const_prim = prim->is_const_prim();
|
||||
for (size_t i = 0; i < op_args.size(); ++i) {
|
||||
bool is_const_input =
|
||||
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
|
||||
auto &arg_value = op_args[i];
|
||||
auto arg_abs = arg_value->ToAbstract();
|
||||
if (!is_const_prim && !is_const_input) {
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
arg_abs = arg_abs->Broaden(config);
|
||||
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
||||
}
|
||||
(*abs_list).emplace_back(arg_abs);
|
||||
}
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
|
||||
const abstract::AbstractBasePtrList &abs_list) {
|
||||
if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments.";
|
||||
}
|
||||
abstract::AbstractBasePtrList new_abs_list(abs_list);
|
||||
auto out_abs = out->ToAbstract();
|
||||
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
||||
out_abs = out_abs->Broaden(config);
|
||||
new_abs_list.emplace_back(out_abs);
|
||||
new_abs_list.emplace_back(out_abs);
|
||||
return new_abs_list;
|
||||
}
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
|
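The two-level lookup above is the core of this refactor: level 1 is keyed by the primitive (attributes included, via PrimitiveTotalEqual), level 2 by the argument abstracts. A minimal self-contained sketch of the same lookup shape — plain C++ with strings standing in for PrimitivePtr and the abstract list, not MindSpore code:

#include <iostream>
#include <map>
#include <string>

enum CacheResult { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 };

struct TwoLevelCache {
  // prim name -> (abstract signature -> optimized graph id)
  std::map<std::string, std::map<std::string, int>> cache;

  CacheResult Lookup(const std::string &prim, const std::string &abs_sig, int *graph_id) {
    auto level1 = cache.find(prim);
    if (level1 == cache.end()) {
      return E_NOT_FOUND;  // step1 + step2 opt must both run
    }
    auto level2 = level1->second.find(abs_sig);
    if (level2 == level1->second.end()) {
      return E_LEVEL_1;  // step1 result reusable; step2 must run for this signature
    }
    *graph_id = level2->second;
    return E_LEVEL_2;  // fully optimized graph can be cloned and returned
  }
};

int main() {
  TwoLevelCache c;
  c.cache["Mul"]["f32[2,2],f32[2,2]"] = 42;
  int id = 0;
  std::cout << c.Lookup("Mul", "f32[2,2],f32[2,2]", &id) << " id=" << id << "\n";  // E_LEVEL_2 id=42
  std::cout << c.Lookup("Mul", "f32[4]", &id) << "\n";                             // E_LEVEL_1
  std::cout << c.Lookup("Add", "f32[4]", &id) << "\n";                             // E_NOT_FOUND
}

Note that the real code clones on every E_LEVEL_2 hit (BasicClone above), which keeps the cached graph immutable while callers mutate their own copy.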
@@ -0,0 +1,185 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
#define MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H

#include <vector>
#include <utility>
#include <unordered_map>
#include <memory>

#include "frontend/optimizer/irpass.h"
#include "ir/func_graph.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace pipeline {
struct PrimBpropOptGraphInfo;

class PrimBpropOptGraphLevel2Info;

struct PrimitiveTotalEqual;

struct PrimitiveTupleListHasher;

struct PrimitiveTupleListEqual;

using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;

using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;

using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;

using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;

using PrimBpropLevel2Cache =
  std::unordered_map<abstract::AbstractBasePtrList, PrimBpropOptGraphLevel2InfoPtr, abstract::AbstractBasePtrListHasher,
                     abstract::AbstractBasePtrListEqual>;

using PrimTupleListCache =
  std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>;

struct PrimitiveTupleListHasher {
  bool operator()(const TupleListKey &key) const {
    abstract::AbstractBasePtrListHasher hasher;
    return hasher(key.second);
  }
};

struct PrimitiveTupleListEqual {
  bool operator()(TupleListKey const &t1, TupleListKey const &t2) const {
    MS_EXCEPTION_IF_NULL(t1.first);
    MS_EXCEPTION_IF_NULL(t2.first);

    if (!(*t1.first == *t2.first)) {
      return false;
    }
    abstract::AbstractBasePtrListEqual cmp;
    return cmp(t1.second, t2.second);
  }
};

struct PrimitiveTotalEqual {
  bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
    MS_EXCEPTION_IF_NULL(t1);
    MS_EXCEPTION_IF_NULL(t2);
    return *t1 == *t2;
  }
};

enum ECacheQrtRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 };

struct PrimBpropOptGraphInfo {
  // the level-1 optimized func_graph without infer; no shape/type info is provided
  FuncGraphPtr opt_func_graph_;
  // the optimized func_graph after infer; level-2 func_graph cache
  PrimBpropLevel2Cache graph_level_2_cache_;
};

struct ParamUsingInfo {
  bool using_flg_{false};
  bool tuple_flg_{false};
  size_t tuple_size_;
  std::vector<ParamUsingInfo> sub_using_info_;
};

class PrimBpropOptGraphLevel2Info {
 public:
  explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {}

  const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; }

  void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out);

  void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager);

 private:
  void ArgInfoRefresh(const std::shared_ptr<AnfNode> &param, ParamUsingInfo *arg_info) const;

  void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
                             ParamUsingInfo *arg_info) const;

  void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> &param_info_vec);

  void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
                              ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const;

 private:
  // the level-2 optimized func_graph
  FuncGraphPtr opt_func_graph_;
  // indicates whether each argument's value is used; if not used, its device memory can be freed
  std::vector<ParamUsingInfo> args_value_using_info_;
  bool analysis_finish_flg_{false};
};

class PrimBpropOptimizer {
 public:
  ~PrimBpropOptimizer() = default;

  void Clear();

  static PrimBpropOptimizer &GetPrimBpropOptimizerInst();

  // bprop_fg has the signature:
  // (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
  // c_node contains the prim (input 0) and the input parameters of that prim;
  // op_args contains the argument list of each input parameter; each argument may be a tensor or a tuple;
  // out contains the output of c_node;
  FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
                                      const ValuePtr &out);

  // do inline opt for the final bprop graph
  FuncGraphPtr BpropGraphFinalOpt(const ResourcePtr &res);

 private:
  PrimBpropOptimizer() = default;

  ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list,
                                   PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
                                   PrimBpropOptGraphInfoPtr *level_1_graph_info);

  // convert tensor args to abstract values
  void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list);

  // add out && dout to the abs list
  abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, const abstract::AbstractBasePtrList &abs_list);

  // do opt without input info, no infer
  PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg);

  // do opt with input info
  PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg,
                                                   const abstract::AbstractBasePtrList &abs_list_input);

  void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input);

  FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
                                    const PrimitivePtr &prim);

  FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
                               const PrimitivePtr &prim, bool hook_flg);

 private:
  // cache of optimized bprop graphs
  PrimBpropCache prim_bprop_cache_;
  PrimTupleListCache tuple_list_bprop_cache_;
};

}  // namespace pipeline
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
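ParamUsingInfo and TryFreeArgsValue exist so that arguments the optimized bprop graph never reads can release their device memory early. A hedged sketch of that idea, with std::shared_ptr<std::vector<float>> standing in for tensors (TryFreeArgs and ParamUsing are illustrative names, not MindSpore APIs):

#include <iostream>
#include <memory>
#include <vector>

struct ParamUsing {
  bool using_flg{false};  // set by a usage analysis over the graph
};

// Release our reference to every argument the graph does not use.
void TryFreeArgs(std::vector<std::shared_ptr<std::vector<float>>> *args,
                 const std::vector<ParamUsing> &info) {
  for (size_t i = 0; i < args->size() && i < info.size(); ++i) {
    if (!info[i].using_flg) {
      (*args)[i].reset();  // drop the reference; memory can be reclaimed
    }
  }
}

int main() {
  std::vector<std::shared_ptr<std::vector<float>>> args = {
    std::make_shared<std::vector<float>>(4, 1.0f),
    std::make_shared<std::vector<float>>(4, 2.0f)};
  std::vector<ParamUsing> info(2);
  info[0].using_flg = true;  // only the first argument is read by the bprop
  TryFreeArgs(&args, info);
  std::cout << (args[1] == nullptr ? "arg1 freed" : "arg1 kept") << "\n";
}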
@@ -50,28 +50,22 @@ enum PynativeStatusCode {
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };

struct OpExecInfo {
  bool is_dynamic_shape = false;
  bool is_mixed_precision_cast = false;
  size_t next_input_index = 0;
  std::string op_name;
  std::string op_index;
  std::string op_info;
  std::string next_op_name = "";
  PrimitivePyPtr py_primitive;
  AbstractBasePtr abstract;

  py::list op_inputs;
  std::vector<int64_t> inputs_mask;
  bool is_dynamic_shape = false;
  std::string next_op_name = "";
  bool is_mixed_precision_cast = false;
  size_t next_input_index = 0;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;

const std::set<std::string> ignore_infer_prim = {"mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> ignore_judge_dynamic_cell = {
  "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
  "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
                                                         parse::NAMED_PRIMITIVE_NAMECONSTANT,
                                                         parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup",
                                                                 "Transpose"};
}  // namespace pynative
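These lookup sets gate per-op behavior in the forward executor. A small illustrative sketch of how such gating is typically consulted — the NeedInfer helper and its policy are assumptions for illustration, not code from this commit:

#include <iostream>
#include <set>
#include <string>

bool NeedInfer(const std::string &op_name, bool abs_cached) {
  static const std::set<std::string> ignore_infer_prim = {"mixed_precision_cast"};
  static const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
  if (ignore_infer_prim.count(op_name) > 0) {
    return false;  // infer is always skipped for these ops
  }
  if (force_infer_prim.count(op_name) > 0) {
    return true;  // output abstract depends on input values, so always re-infer
  }
  return !abs_cached;  // otherwise infer only on a cache miss
}

int main() {
  std::cout << NeedInfer("TopK", true) << "\n";                   // 1
  std::cout << NeedInfer("mixed_precision_cast", false) << "\n";  // 0
  std::cout << NeedInfer("Add", true) << "\n";                    // 0
}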
File diff suppressed because it is too large
@@ -36,11 +36,12 @@
#include "utils/ms_context.h"
#include "ir/anf.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/composite/composite.h"

namespace mindspore {
namespace pynative {
namespace mindspore::pynative {
namespace py = pybind11;
using cell_id = std::string;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;

@@ -52,8 +53,8 @@ struct PrimAbsInfo {

using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
using OpIndexWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensor = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;

py::object RunOp(const py::args &args);

@@ -64,108 +65,80 @@ struct GraphInfo {
  AnfNodePtr output;
  OrderedMap<std::string, ParameterPtr> params;  // hold input parameters and cell weights
  std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
  std::vector<std::string> objects;
  GraphInfo() = default;
  explicit GraphInfo(std::string id) : cell_id(std::move(id)) {}
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;

class CellInfo {
 public:
  CellInfo() = default;
  ~CellInfo() = default;
  CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr forward_graph, std::string cellid, std::string bprop_id)
      : is_custom_bprop_(custom_bprop),
        is_dynamic_(has_dynamic),
        fg_(std::move(forward_graph)),
        cell_id_(std::move(cellid)),
        bprop_cell_id_(std::move(bprop_id)) {}

  bool is_custom_bprop() const { return is_custom_bprop_; }
  void set_is_custom_bprop(bool is_custom_bprop) { is_custom_bprop_ = is_custom_bprop; }
  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
  size_t call_times() const { return call_times_; }
  void set_call_times(size_t call_times) { call_times_ = call_times; }
  FuncGraphPtr fg() const { return fg_; }
  void set_fg(FuncGraphPtr fg) { fg_ = std::move(fg); }
  std::string &cell_id() { return cell_id_; }
  void set_cell_id(std::string cell_id) { cell_id_ = std::move(cell_id); }
  std::string &bprop_cell_id() { return bprop_cell_id_; }
  std::vector<std::string> &cell_ops_info() { return cell_ops_info_; }

 private:
  bool is_custom_bprop_{false};  // Custom bprop
  bool is_dynamic_{false};       // Set by has_dynamic_cell
  size_t call_times_{0};
  FuncGraphPtr fg_{nullptr};  // Forward graph
  std::string cell_id_;
  std::string bprop_cell_id_;
  std::vector<std::string> cell_ops_info_;  // All ops info
};
using CellInfoPtr = std::shared_ptr<CellInfo>;

class TopCellInfo {
 public:
  TopCellInfo() = default;
  ~TopCellInfo() = default;
  TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
      : is_topest_(topest), resource_(std::move(r)), df_builder_(std::move(df)), cell_id_(std::move(cellid)) {}
  TopCellInfo(bool topest, size_t grad_order, ResourcePtr r, FuncGraphPtr df, std::string cellid)
      : is_topest_(topest),
        grad_order_(grad_order),
        resource_(std::move(r)),
        df_builder_(std::move(df)),
        cell_id_(std::move(cellid)) {}

  bool is_grad() const { return is_grad_; }
  void set_is_grad(bool is_grad) { is_grad_ = is_grad; }
  bool is_init_kpynative() const { return is_init_kpynative_; }
  void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
  bool is_topest() const { return is_topest_; }
  size_t grad_order() const { return grad_order_; }
  void set_grad_order(size_t grad_order) { grad_order_ = grad_order; }
  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
  bool vm_compiled() const { return vm_compiled_; }
  void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
  bool need_grad() const { return need_grad_; }
  void set_need_grad(bool need_grad) { need_grad_ = need_grad; }
  bool has_dynamic_cell() const { return has_dynamic_cell_; }
  bool is_real_dynamic() const { return is_real_dynamic_; }
  void set_is_real_dynamic(bool is_real_dynamic) { is_real_dynamic_ = is_real_dynamic; }
  bool ms_function_flag() const { return ms_function_flag_; }
  void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
  bool need_compile_graph() const { return need_compile_graph_; }
  void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  bool forward_already_run() const { return forward_already_run_; }
  void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  ResourcePtr resource() { return resource_; }
  FuncGraphPtr df_builder() { return df_builder_; }
  size_t op_num() const { return op_num_; }
  void set_op_num(size_t op_num) { op_num_ = op_num; }
  std::string &cell_id() { return cell_id_; }
  std::string &sens_id() { return sens_id_; }
  void set_sens_id(std::string sens_id) { sens_id_ = std::move(sens_id); }
  std::string &weights_id() { return weights_id_; }
  void set_weights_id(std::string weights_id) { weights_id_ = std::move(weights_id); }
  std::string &input_args_id() { return input_args_id_; }
  std::string all_op_info() const { return all_op_info_; }
  void set_all_op_info(std::string all_op_info) { all_op_info_ = std::move(all_op_info); }
  void set_input_args_id(std::string input_args_id) { input_args_id_ = std::move(input_args_id); }
  std::vector<CellInfoPtr> &cell_graph_list() { return cell_graph_list_; }
  void set_cell_graph_list(const std::vector<CellInfoPtr> &cell_graph_list) { cell_graph_list_ = cell_graph_list; }
  std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; }
  bool IsSubCell(const std::string &cell_id) const;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
  void set_graph_info_map(const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map) {
    graph_info_map_ = graph_info_map;
  }
  void clear() {
    cell_graph_list_.clear();
    graph_info_map_.clear();
  OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
  TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; }
  ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; }
  void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
    k_pynative_cell_ptr_ = k_pynative_cell_ptr;
  }
  void clear();

 private:
  bool is_grad_{false};  // Derivative is calculated
  bool is_topest_{false};
  bool is_dynamic_{false};
  bool vm_compiled_{false};
  bool need_grad_{true};
  bool has_dynamic_cell_{false};
  bool is_real_dynamic_{false};
  bool ms_function_flag_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  size_t op_num_{0};
  size_t grad_order_{0};
  ResourcePtr resource_{nullptr};
  FuncGraphPtr df_builder_{nullptr};
  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
  std::string cell_id_;
  std::string sens_id_;
  std::string weights_id_;
  std::string input_args_id_;
  std::vector<CellInfoPtr> cell_graph_list_;
  std::string all_op_info_;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
  std::unordered_set<std::string> sub_cell_list_;
  OpInfoWithTensorId op_info_with_tensor_id_;
  TensorIdWithTensorObject tensor_id_with_tensor_object_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;

class DynamicAnalysis;
using DynamicAnalysisPtr = std::shared_ptr<DynamicAnalysis>;

class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;
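ForwardExecutor and GradExecutor reference each other, which is why each side of the link declared here (ForwardExecutorWeakPtr, and GradExecutorWeakPtr below) is a std::weak_ptr. A minimal sketch of why the weak leg matters:

#include <iostream>
#include <memory>

struct Grad;
struct Forward {
  std::weak_ptr<Grad> grad;  // weak: breaks the ownership cycle
  ~Forward() { std::cout << "Forward destroyed\n"; }
};
struct Grad {
  std::weak_ptr<Forward> forward;
  ~Grad() { std::cout << "Grad destroyed\n"; }
};

int main() {
  auto f = std::make_shared<Forward>();
  auto g = std::make_shared<Grad>();
  f->grad = g;
  g->forward = f;
  return 0;  // both destructors run; with shared_ptr members the cycle would leak
}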
@@ -174,30 +147,6 @@ class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;

class DynamicAnalysis {
 public:
  DynamicAnalysis() = default;
  ~DynamicAnalysis() = default;

  // Check cell struct
  bool IsDynamicCell(const py::object &cell);

 private:
  std::string GetCellInfo(const py::object &cell);
  void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
  bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
                        const std::vector<std::string> &compare_prim = {});
  bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                              const std::vector<std::string> &compare_prim = {});
  bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
  std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                            parse::AstMainType type);

  std::unordered_set<std::string> cell_input_args_;
};

class GradExecutor {
 public:
  GradExecutor() = default;

@@ -227,112 +176,74 @@ class GradExecutor {

  FuncGraphPtr curr_g() const;
  TopCellInfoPtr top_cell() const;
  bool TopCellIsDynamic();
  void CheckNeedCompileGraph();
  TopCellInfoPtr GetTopCell(const string &cell_id) const;
  bool need_renormalize() const { return need_renormalize_; }
  void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  bool grad_flag() const { return grad_flag_; }
  void set_grad_flag(bool flag) { grad_flag_ = flag; }
  bool in_grad_process() const { return in_grad_process_; }
  std::string top_cell_id() { return top_cell()->cell_id(); }
  bool grad_is_running() const { return grad_is_running_; }
  bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
  AnfNodePtr GetInput(const py::object &obj, bool op_mask);
  std::string GetCellId(const py::object &obj, const py::args &args);
  TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
  std::stack<std::string> &cell_stack() { return cell_stack_; }
  std::vector<TopCellInfoPtr> &top_cell_list() { return top_cell_list_; }
  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
  bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
  void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
  void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real);
  void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out);
  void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &fprop_g, const py::object &out,
                                const py::args &args, const std::string &graph_phase);
  void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
                              const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
                              CNodePtr *ms_function_cnode);
  void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
  void SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource);
  py::object CheckGraph(const py::object &cell, const py::args &args);
  void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase);
  bool need_construct_graph() const { return !graph_stack_.empty() && grad_flag_; }
  void set_dynamic_analysis(DynamicAnalysisPtr dynamic_analysis) { dynamic_analysis_ = std::move(dynamic_analysis); }
  std::stack<FuncGraphPtr> &graph_stack() { return graph_stack_; }
  std::vector<TopCellInfoPtr> &top_cell_list() { return top_cell_list_; }
  bool need_replace_forward() const { return need_replace_forward_; }
  std::stack<std::string> &cell_op_info_stack() { return cell_op_info_stack_; }
  std::unordered_map<std::string, size_t> &op_index_map() { return op_index_map_; }
  std::unordered_map<std::string, std::string> &obj_to_forward_id() { return obj_to_forward_id_; }
  void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
  void ClearGrad(const py::object &cell, const py::args &args);
  void ClearRes();
  void ClearCellRes(const std::string &cell_id = "");

 private:
  ForwardExecutorPtr forward() const;
  DynamicAnalysisPtr dynamic_analysis() const;
  bool grad_running() const { return grad_is_running_; }
  void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
  void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }

  // Higher derivative
  bool IsNestedGrad() const;
  void AddNestedGradOrder() { ++grad_order_; }
  void SubNestedGradOrder();
  void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
                          const std::string &cell_id);
  void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
  void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
                       const py::object &out, bool has_sens);
  void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
  bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);

  // Dynamic
  bool CheckDynamicCell(const std::string &cell_id);
  bool CheckRealDynamicCell(const std::string &cell_id);
  void ClearDynamicTopRes(const std::string &cell_id);

  void PushCurrentGraphToStack();
  void PopGraphStack();
  void PushCurrentCellOpInfoToStack();
  void PopCurrentCellOpInfoFromStack();
  std::string GetCellOpInfo();
  void ReplaceCellOpInfoByCellId(const std::string &cell_id);
  void SwitchTopcell();
  size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
  void MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
                       const ResourcePtr &resource, const py::object &out);
  void PushCellStack(const std::string &cell_id);
  void PopCellStack();
  void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
  TopCellInfoPtr PopHighOrderGraphStack();

  FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
  ResourcePtr GetResource(const std::string &cell_id = "");
  bool IsFirstGradStep();
  bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
  bool IsTopGraph(const std::string &cell_id);
  bool IsTopestGraph(const std::string &cell_id);
  bool IsBpropGraph(const std::string &cell_id);
  bool IsGradBefore(const std::string &cell_id);
  bool CheckCellGraph(const std::string &cell_id);
  bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
                            bool is_grad);
  void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                       bool need_cloned = false, bool is_grad = false);
  bool CheckCellChanged(const std::string &cell_id);
  void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
  void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
  void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args);
  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
  void MakeNewTopGraph(const string &cell_id, const py::args &args);
  void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest);
  void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
  void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                       const std::string &out_id, const py::args &args);
  bool EndBpropGraph(const string &cell_id);
  FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
                             const std::string &cell_id, const py::args &args);
  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
                            py::object *sens = nullptr);
  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
                            py::args *forward_args = nullptr);
  void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                    const py::args &args);
  void SetTopCellTensorId(const std::string &cell_id);
  bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
  void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
  void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
                    size_t arg_size, const std::string &cell_id);
  FuncGraphPtr GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
                             const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
  void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
  AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);

  // Memory clean between steps
  void ClearResidualRes(const std::string &cell_id);
  void ClearCnodeRes(const AnfNodePtr &node);
  void CleanPreMemoryInValueNode();
  void SaveTensorsInValueNode(const ResourcePtr &resource);
  void SaveAllValueNodeTensors(const FuncGraphPtr &graph);

  void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
    top_cell()->graph_info_map()[g]->objects.push_back(obj);
  }
  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                  bool is_param = false);
  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {

@@ -346,33 +257,33 @@ class GradExecutor {
                                     const std::vector<int64_t> &index) {
    top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
  }
  void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out);
  void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);

 private:
  size_t grad_order_{0};
  bool grad_flag_{false};
  bool in_bprop_process_{false};
  bool in_grad_process_{false};
  bool has_dynamic_cell_{false};
  bool need_replace_forward_{true};
  bool need_renormalize_{false};
  bool grad_is_running_{false};
  int custom_bprop_cell_count_{0};
  size_t grad_order_{0};

  // Only set in high grad
  FuncGraphPtr curr_g_{nullptr};
  // For clearing the previous top cell's resources
  TopCellInfoPtr pre_top_cell_{nullptr};
  TopCellInfoPtr top_cell_{nullptr};
  std::unordered_map<std::string, size_t> op_index_map_;
  std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
  std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
  std::unordered_map<std::string, std::string> obj_to_forward_id_;

  // Records forward graphs; the bottom is the top graph
  std::stack<FuncGraphPtr> graph_stack_;
  // Records the op info of every cell; the bottom is the op info of the top cell
  std::stack<std::string> cell_op_info_stack_;

  // Records forward cells; the bottom is the top cell
  std::stack<std::string> cell_stack_;
  // For high grad of bprop
  std::stack<bool> bprop_grad_stack_;
  std::vector<std::string> bprop_cell_list_;
  // For high grad order
  std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>> high_order_stack_;
  // Use vector to keep order
  std::vector<TopCellInfoPtr> top_cell_list_;
  // Record all top cells that have been run
  std::map<cell_id, TopCellInfoPtr> already_run_top_cell_;
  // Use vector to keep order
  ForwardExecutorWeakPtr forward_executor_;
  DynamicAnalysisPtr dynamic_analysis_;
};

class ForwardExecutor {

@@ -388,13 +299,9 @@ class ForwardExecutor {
  OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
  std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
  std::unordered_map<std::string, OpIndexWithTensorId> &cell_op_index_with_tensor_id() {
    return cell_op_index_with_tensor_id_;
  }
  std::unordered_map<std::string, TensorIdWithTensor> &cell_tensor_id_with_tensor() {
    return cell_tensor_id_with_tensor_;
  }
  void ClearRes();
  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                       abstract::AbstractBasePtrList *args_spec_list);

 private:
  GradExecutorPtr grad() const;

@@ -404,17 +311,14 @@ class ForwardExecutor {
  py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                    PynativeStatusCode *status);
  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                       abstract::AbstractBasePtrList *args_spec_list);
  void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
                   abstract::AbstractBasePtrList *args_spec_list);
  abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                            const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                           bool *is_find);
  // Update the abstract and device address info of value nodes and tensors in the bprop graph
  void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);

  py::object GetOpOutputObject(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                               const AnfNodePtr &CNode, bool out_abstract_existed);
  // Mix precision
  void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
  py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);

@@ -429,9 +333,6 @@ class ForwardExecutor {
  GradExecutorWeakPtr grad_executor_;
  std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
  // Used for runop and replacing the forward result in the grad graph
  std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
  std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
  // Used to cache the cast struct
  std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
};

@@ -444,7 +345,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
      forward_executor_ = std::make_shared<ForwardExecutor>();
      grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
      grad_executor_->set_dynamic_analysis(std::make_shared<DynamicAnalysis>());
      forward_executor_->set_grad_executor(grad_executor_);
    }
    return executor_;

@@ -459,6 +359,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  ForwardExecutorPtr forward_executor();

  void set_grad_flag(bool flag);
  std::string graph_phase() const { return graph_phase_; }
  void set_graph_phase(std::string graph_phase) { graph_phase_ = graph_phase; }
  void GradMsFunction(const py::object &out, const py::args &args);
  void NewGraph(const py::object &cell, const py::args &args);
  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);

@@ -468,9 +371,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {

  // Used by graph clean
  bool GetIsDynamicCell();
  bool need_replace_forward() { return grad_executor()->need_replace_forward(); }
  // Called on Cell destruction
  void ClearCell(const std::string &flag = "");
  void ClearCell(const std::string &cell_id);
  void ClearGrad(const py::object &cell, const py::args &args);
  // Abnormal existed
  void ClearRes();

@@ -480,6 +382,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
 private:
  PynativeExecutor() = default;

  // The graph phase is used to obtain the backend graph that is compiled by ms_function
  std::string graph_phase_;
  static std::shared_ptr<PynativeExecutor> executor_;
  static std::mutex instance_lock_;
  static ForwardExecutorPtr forward_executor_;

@@ -493,7 +397,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
};

using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
}  // namespace pynative
}  // namespace mindspore
}  // namespace mindspore::pynative

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
@@ -288,6 +288,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
  auto primitive_py = primitive->cast<PrimitivePyPtr>();
  MS_EXCEPTION_IF_NULL(primitive_py);
  this->set_hook(primitive_py->hook());
  if (primitive_py->HasAttr(kBpropAttrName)) {
    this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
  }
}

BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {

@@ -292,6 +292,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
  }
  // Clear the output address of the graph.
  ClearOutputAddress(inputs, value_nodes, execution_order);

  graph_output_map_.erase(graph_id);
}

void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
@@ -286,13 +286,15 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
      if (element->isa<tensor::Tensor>()) {
        auto tensor = element->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        tensors->push_back(tensor);
        tensors->emplace_back(tensor);
      } else if (element->isa<ValueTuple>()) {
        TensorValueToTensor(element, tensors);
      }
    }
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensors->push_back(tensor);
    tensors->emplace_back(tensor);
  }
}
}  // namespace mindspore
@@ -107,7 +107,7 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
  result.graph_id = graph_id;

  graph_id_map_[graph_id] = result;
  if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
  if (!pynative_mode) {
    (void)g_ConvertCache.emplace(segment, result);
  }
  return result;
@@ -214,7 +214,10 @@ class _MindSporeFunction:
                new_inputs.append(i)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)
        return self._executor(tuple(new_inputs), phase)
        output = self._executor(tuple(new_inputs), phase)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            _pynative_exec.set_graph_phase(phase)
        return output


def ms_function(fn=None, obj=None, input_signature=None):

@@ -272,10 +275,15 @@ def ms_function(fn=None, obj=None, input_signature=None):
    def wrap_mindspore(func):
        @wraps(func)
        def staging_specialize(*args):
            input_args = args
            process_obj = obj
            if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
                input_args = args[1:]
                process_obj = args[0]
            return _MindSporeFunction(func, input_signature, process_obj)(*args)
            out = _MindSporeFunction(func, input_signature, process_obj)(*args)
            if context.get_context("mode") == context.PYNATIVE_MODE:
                _pynative_exec.grad_ms_function(out, *input_args)
            return out

        return staging_specialize

@@ -376,6 +384,12 @@ class _PynativeExecutor:
    def sync(self):
        self._executor.sync()

    def grad_ms_function(self, output, *args):
        self._executor.grad_ms_function(output, *args)

    def set_graph_phase(self, phase):
        self._executor.set_graph_phase(phase)

    def set_grad_flag(self, flag):
        self._executor.set_grad_flag(flag)
@@ -498,7 +498,13 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());
  auto input_type = primitive->GetAttr("dst_type")->cast<TypePtr>();
  auto attr = primitive->GetAttr("dst_type");
  if (attr == nullptr) {
    attr = args_spec_list[1]->BuildValue();
    MS_EXCEPTION_IF_NULL(attr);
    primitive->set_attr("dst_type", attr);
  }
  auto input_type = attr->cast<TypePtr>();
  auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape()->shape());
  return ret;
}
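The change above makes InferImplCast tolerate a missing "dst_type" attribute by deriving it from the second input once and caching it back on the primitive. The same fallback-and-cache shape in a self-contained sketch (a std::map stands in for the primitive's attribute table; not the MindSpore API):

#include <iostream>
#include <map>
#include <string>

std::string DstType(std::map<std::string, std::string> *attrs, const std::string &input_value) {
  auto it = attrs->find("dst_type");
  if (it == attrs->end()) {
    // derive from the input and cache it so later calls skip this path
    it = attrs->emplace("dst_type", input_value).first;
  }
  return it->second;
}

int main() {
  std::map<std::string, std::string> attrs;
  std::cout << DstType(&attrs, "float16") << "\n";  // derived and cached
  std::cout << DstType(&attrs, "int32") << "\n";    // cached value wins
}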
@@ -165,8 +165,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
    if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
      return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
    }
    MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
                             << index_value->ToString();
    MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString();
  }
  auto idx_v = GetValue<int64_t>(index_value);
  std::size_t nelems = queue->elements().size();
@@ -41,8 +41,8 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
    return res;
  }
  size_t seen = NewSeenGeneration();
  std::deque<AnfNodePtr> todo(1024);
  todo.clear();
  std::vector<AnfNodePtr> todo;
  todo.reserve(1024);
  todo.push_back(root);

  while (!todo.empty()) {

@@ -95,14 +95,15 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c

// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) {
  std::deque<CNodePtr> todo(1024);
  todo.clear();
  std::vector<CNodePtr> todo;
  todo.reserve(1024);
  todo.insert(todo.end(), starts.begin(), starts.end());
  std::vector<CNodePtr> sorted_nodes;
  auto seen = NewSeenGeneration();
  while (!todo.empty()) {
    CNodePtr top = todo.front();
    todo.pop_front();
  size_t top_idx = 0;
  while (top_idx < todo.size()) {
    CNodePtr top = todo[top_idx];
    top_idx++;
    sorted_nodes.push_back(top);
    auto inputs = top->inputs();
    for (auto &item : inputs) {

@@ -121,13 +122,14 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &s

// search the cnode match the predicate inside this graph only
CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate) {
  std::deque<CNodePtr> todo(1024);
  todo.clear();
  std::vector<CNodePtr> todo;
  todo.reserve(1024);
  todo.insert(todo.end(), starts.begin(), starts.end());
  auto seen = NewSeenGeneration();
  while (!todo.empty()) {
    CNodePtr top = todo.front();
    todo.pop_front();
  size_t top_idx = 0;
  while (top_idx < todo.size()) {
    CNodePtr top = todo[top_idx];
    top_idx++;
    if (match_predicate(top)) {
      return top;
    }

@@ -147,13 +149,16 @@ CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const Matc
}

std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
  std::deque<FuncGraphPtr> todo;
  std::vector<FuncGraphPtr> todo;
  todo.reserve(128);
  todo.push_back(root);
  std::vector<FuncGraphPtr> sorted;
  sorted.reserve(128);
  auto seen = NewSeenGeneration();
  while (!todo.empty()) {
    FuncGraphPtr top = todo.front();
    todo.pop_front();
  size_t top_idx = 0;
  while (top_idx < todo.size()) {
    FuncGraphPtr top = todo[top_idx];
    top_idx++;
    sorted.push_back(top);
    auto used = top->func_graphs_used();
    for (auto &item : used) {
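All three traversals above switch from a std::deque with pop_front() to a std::vector walked by an index cursor: visited nodes stay in the vector, so the loop does no per-node deallocation and the reserved capacity is reused. The same pattern in a runnable sketch:

#include <cstdio>
#include <vector>

int main() {
  // adjacency list for a small DAG: node -> successors
  std::vector<std::vector<int>> succ = {{1, 2}, {3}, {3}, {}};
  std::vector<bool> seen(succ.size(), false);

  std::vector<int> todo;
  todo.reserve(succ.size());
  todo.push_back(0);
  seen[0] = true;

  size_t top_idx = 0;  // read cursor replaces pop_front()
  while (top_idx < todo.size()) {
    int top = todo[top_idx];
    ++top_idx;
    std::printf("visit %d\n", top);
    for (int next : succ[top]) {
      if (!seen[next]) {
        seen[next] = true;
        todo.push_back(next);  // appended nodes extend the same buffer
      }
    }
  }
  return 0;
}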
@@ -44,7 +44,7 @@ class Primitive : public Named {
  Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs);
  Primitive(const Primitive &prim);
  MS_DECLARE_PARENT(Primitive, Named);
  abstract::AbstractBasePtr ToAbstract();
  abstract::AbstractBasePtr ToAbstract() override;
  abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
  std::string ToString() const override { return name(); }
  void BeginRecordAddAttr() {

@@ -25,7 +25,7 @@ void ScopeManager::EnterScope(const ScopePtr &scope) {
}

void ScopeManager::LeaveScope(const ScopePtr &scope) noexcept {
  if (scope != kDefaultScope) {
  if (scope != kDefaultScope && !scope_stack_.empty()) {
    scope_stack_.pop();
  }
}
@@ -165,7 +165,7 @@ std::string GraphDebugInfo::debug_name() {

LocationPtr GraphDebugInfo::location() {
  // function may have decorator which is included in its location
  if (deco_loc_ != nullptr) {
  if (deco_loc_ != nullptr && DebugInfo::location() != nullptr) {
    LocationPtr loc = std::make_shared<Location>(*DebugInfo::location());
    loc->set_line(loc->line() + (deco_loc_->line_end() - deco_loc_->line() + 1));
    return loc;
@@ -348,15 +348,9 @@ class Cell(Cell_):
        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("cell inputs should not be numpy array.")
        origin_grad = []
        if self.requires_grad is True:
            _pynative_exec.set_grad_flag(True)
            _pynative_exec.new_graph(self, *inputs, **kwargs)
            for cell in self.cells():
                origin_grad.append(cell.requires_grad)
                cell.set_grad(True)
        else:
            _pynative_exec.set_grad_flag(False)
        _pynative_exec.new_graph(self, *inputs, **kwargs)
        cast_inputs = list()
        if hasattr(self, "_mindspore_flags"):
            if self._mindspore_flags.get('fp16'):

@@ -368,10 +362,7 @@ class Cell(Cell_):
        output = self.run_construct(cast_inputs, kwargs)
        if isinstance(output, Parameter):
            output = output.data
        if self.requires_grad is True:
            _pynative_exec.end_graph(self, output, *inputs, **kwargs)
            for i, cell in enumerate(self.cells()):
                cell.set_grad(origin_grad[i])
        _pynative_exec.end_graph(self, output, *inputs, **kwargs)
        return output

    def _add_attr(self, name, value):
@@ -491,7 +491,7 @@ def get_bprop_log(self):
    def bprop(x, out, dout):
        g = reciprocal(x)
        dx = g * dout
        return dx, 0
        return (dx,)

    return bprop

@@ -505,7 +505,7 @@ def get_bprop_log1p(self):
        x_1p = x + 1
        g = reciprocal(x_1p)
        dx = g * dout
        return dx, 0
        return (dx,)

    return bprop
@@ -296,5 +296,18 @@ def _add_rowtensor_tensor(x, y):
    """
    return x + y

@_add_backward.register("None", "None")
def _add_nonetensor_tensor(x, y):
    """
    Adds None and None.

    Args:
        x (None): x
        y (None): y

    Returns:
        None.
    """
    return x + y

hyper_add = base.HyperMap(_add_backward)
@@ -0,0 +1,128 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <unordered_map>

#include "frontend/optimizer/ad/kpynative.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "debug/anf_ir_utils.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace ad {
class TestKPynative : public UT::Common {
 public:
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();

 protected:
  AbstractBasePtr BuildArg() {
    std::vector<int64_t> shp = {2, 2};
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
    auto abstract = tensor->ToAbstract();
    return abstract;
  }

  FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    auto y = g->add_parameter();
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    c_node->set_abstract(BuildArg());
    g->set_output(c_node);
    return g;
  }

  // a = x * y
  // b = stop_gradient(a)
  // c = b * y
  // d = a * c
  // return d
  FuncGraphPtr BuildStopGradient(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    x->debug_info()->set_name("x");
    auto y = g->add_parameter();
    y->debug_info()->set_name("y");
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    a_node->set_abstract(BuildArg());
    auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
    b_node->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
    c_node->set_abstract(BuildArg());
    auto d_node =
      g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
    d_node->set_abstract(BuildArg());
    g->set_output(d_node);
    return g;
  }

  FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
    auto input_params = primal_fg->parameters();
    std::vector<ValuePtr> input_param_values;
    std::for_each(input_params.begin(), input_params.end(),
                  [&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
    auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
    auto node_list = TopoSort(primal_fg->output());
    for (auto node : node_list) {
      if (node->isa<CNode>()) {
        auto c_node = node->cast<CNodePtr>();
        auto out = c_node->abstract()->GetValueTrack();
        ValuePtrList args;
        for (size_t i = 1; i < c_node->inputs().size(); ++i) {
          args.push_back(c_node->input(i)->abstract()->GetValueTrack());
        }
        GradPynativeOp(k_pynative_cell, c_node, args, out);
      }
    }
    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true);
    return bprop_fg;
  }
};

TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});
  ExportIR(primal_fg->ToString() + ".dat", primal_fg);

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});

  ExportIR(bprop_fg->ToString() + ".dat", bprop_fg);
}

TEST_F(TestKPynative, test_stop_gradient) {
  auto primal_fg = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_fg});
  ExportIR(primal_fg->ToString() + ".dat", primal_fg);

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});

  ExportIR(bprop_fg->ToString() + ".dat", bprop_fg);
}
}  // namespace ad
}  // namespace mindspore