forked from mindspore-Ecosystem/mindspore
convert subgraph
This commit is contained in:
parent
7b5b4837ff
commit
a27ce973ad
|
@ -28,6 +28,7 @@
|
||||||
#include "utils/config_manager.h"
|
#include "utils/config_manager.h"
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "./common.h"
|
#include "./common.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace transform {
|
namespace transform {
|
||||||
|
@ -205,6 +206,7 @@ const char kNameRange[] = "Range";
|
||||||
const char kNameSquareSumAll[] = "SquareSumAll";
|
const char kNameSquareSumAll[] = "SquareSumAll";
|
||||||
const char kNameAscendQuant[] = "AscendQuant";
|
const char kNameAscendQuant[] = "AscendQuant";
|
||||||
const char kNameAscendDequant[] = "AscendDequant";
|
const char kNameAscendDequant[] = "AscendDequant";
|
||||||
|
const char kNameCase[] = "Case";
|
||||||
|
|
||||||
// -----------------OpAdapter initialization--------------
|
// -----------------OpAdapter initialization--------------
|
||||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||||
|
@ -411,7 +413,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameRange), ADPT_DESC(RangeD)},
|
{string(kNameRange), ADPT_DESC(RangeD)},
|
||||||
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)},
|
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)},
|
||||||
{string(kNameAscendQuant), ADPT_DESC(AscendQuant)},
|
{string(kNameAscendQuant), ADPT_DESC(AscendQuant)},
|
||||||
{string(kNameAscendDequant), ADPT_DESC(AscendDequant)}};
|
{string(kNameAscendDequant), ADPT_DESC(AscendDequant)},
|
||||||
|
{string(kNameCase), ADPT_DESC(Case)}};
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||||
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
|
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
|
||||||
|
@ -433,13 +436,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) {
|
||||||
return kPrimTypeUnknown;
|
return kPrimTypeUnknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsCaseNode(const CNodePtr node) {
|
||||||
|
if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
|
||||||
|
GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
|
||||||
|
if (IsCaseNode(cnode)) {
|
||||||
|
return string(kNameCase);
|
||||||
|
}
|
||||||
|
auto name = GetCNodeFuncName(cnode);
|
||||||
|
if (name == "switch_layer") {
|
||||||
|
name = "";
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
|
OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
|
||||||
std::string name = kNameCustomOp;
|
std::string name = kNameCustomOp;
|
||||||
if (!IsCustomCNode(cnode)) {
|
if (!IsCustomCNode(cnode)) {
|
||||||
name = GetCNodeFuncName(cnode);
|
name = GetCNodeTargetFuncName(cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto it_adpt = get_adpt_map().find(name);
|
auto it_adpt = get_adpt_map().find(name);
|
||||||
|
@ -957,7 +979,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
|
||||||
auto c = anf_out->cast<CNodePtr>();
|
auto c = anf_out->cast<CNodePtr>();
|
||||||
std::string name = "";
|
std::string name = "";
|
||||||
if (anf_out->isa<CNode>()) {
|
if (anf_out->isa<CNode>()) {
|
||||||
name = GetCNodeFuncName(c);
|
name = GetCNodeTargetFuncName(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (name == "make_tuple") {
|
if (name == "make_tuple") {
|
||||||
|
@ -1029,6 +1051,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (!IsCaseNode(cnode)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<AnfNodePtr> case_inputs;
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||||
|
case_inputs.emplace_back(cnode->input(i));
|
||||||
|
}
|
||||||
|
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
|
||||||
|
auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
|
||||||
|
|
||||||
|
for (size_t i = 1; i < bnode->inputs().size(); i++) {
|
||||||
|
auto branch_node = bnode->input(i)->cast<CNodePtr>();
|
||||||
|
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
|
||||||
|
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
|
||||||
|
case_inputs.emplace_back(branch_node->input(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 1; i < bnode->inputs().size(); i++) {
|
||||||
|
ProcessSubgraph(bnode->input(i), case_inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 1; i < bnode->inputs().size(); i++) {
|
||||||
|
branches->emplace_back(branches_map_[bnode->input(i).get()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_cache_.find(node.get()) == op_cache_.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAdapterPtr adpt = FindAdapter(node, training_);
|
||||||
|
if (nullptr == adpt) {
|
||||||
|
MS_LOG(DEBUG) << "Not found adapter";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
OperatorPtr op = Convert(node);
|
||||||
|
adpt->setSubgraph(op, 0, branches);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) {
|
||||||
|
std::vector<AnfNodePtr> case_inputs;
|
||||||
|
for (size_t i = 1; i < node->inputs().size(); i++) {
|
||||||
|
case_inputs.emplace_back(node->input(i));
|
||||||
|
}
|
||||||
|
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
|
||||||
|
auto bnode = input_node->input(2)->cast<CNodePtr>();
|
||||||
|
|
||||||
|
for (size_t i = 1; i < bnode->inputs().size(); i++) {
|
||||||
|
auto branch_node = bnode->input(i)->cast<CNodePtr>();
|
||||||
|
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
|
||||||
|
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
|
||||||
|
case_inputs.emplace_back(branch_node->input(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t case_index = 1;
|
||||||
|
const size_t make_tuple_index = 2;
|
||||||
|
|
||||||
|
AnfNodePtr case_index_iter = input_node->input(case_index);
|
||||||
|
AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index);
|
||||||
|
auto make_tuple_node = make_tuple_iter->cast<CNodePtr>();
|
||||||
|
std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < case_inputs.size(); i++) {
|
||||||
|
auto item = case_inputs[i];
|
||||||
|
auto op = Convert(item);
|
||||||
|
if (op != nullptr) {
|
||||||
|
tuple_items->emplace_back(OutHandler(op, ""));
|
||||||
|
} else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
|
||||||
|
tuple_items->push_back(out_handle_cache_[item.get()]);
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items;
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>();
|
||||||
|
case_input_items->emplace_back(case_index_iter);
|
||||||
|
case_input_items->emplace_back(make_tuple_iter);
|
||||||
|
case_input_handle_cache_[node.get()] = case_input_items;
|
||||||
|
}
|
||||||
|
|
||||||
DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
SetupDatasetIterGetNextNode(dataset_iter_getnext_);
|
SetupDatasetIterGetNextNode(dataset_iter_getnext_);
|
||||||
|
|
||||||
|
@ -1036,6 +1151,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Case node set input.
|
||||||
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
||||||
|
for (auto &it : nodes) {
|
||||||
|
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
|
||||||
|
auto node = it->cast<CNodePtr>();
|
||||||
|
auto input_node = node->input(0)->cast<CNodePtr>();
|
||||||
|
GetCaseNodeInput(node, input_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// update tuple_out_handle_cache_
|
// update tuple_out_handle_cache_
|
||||||
for (auto it : tuple_out_handle_cache_) {
|
for (auto it : tuple_out_handle_cache_) {
|
||||||
std::size_t len = it.second->size();
|
std::size_t len = it.second->size();
|
||||||
|
@ -1056,10 +1181,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
|
|
||||||
// set up dependices
|
// set up dependices
|
||||||
MS_LOG(DEBUG) << "set up dependices";
|
MS_LOG(DEBUG) << "set up dependices";
|
||||||
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
||||||
for (auto &it : nodes) {
|
for (auto &it : nodes) {
|
||||||
SetNodeInput(it);
|
SetNodeInput(it);
|
||||||
SetOpControlInput(it);
|
SetOpControlInput(it);
|
||||||
|
SetSubgraph(it);
|
||||||
UpdateOpDesc(it);
|
UpdateOpDesc(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1075,6 +1201,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
inputs.push_back(*dataset_iter_getnext_);
|
inputs.push_back(*dataset_iter_getnext_);
|
||||||
} else {
|
} else {
|
||||||
auto params = anf_graph_->parameters();
|
auto params = anf_graph_->parameters();
|
||||||
|
if (use_inputs_) {
|
||||||
|
params = inputs_;
|
||||||
|
auto anf_params = anf_graph_->parameters();
|
||||||
|
for (size_t i = 0; i < params.size(); i++) {
|
||||||
|
for (size_t j = 0; j < anf_params.size(); j++) {
|
||||||
|
if (params[i]->ToString() == anf_params[j]->ToString()) {
|
||||||
|
params[i] = anf_params[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto &it : params) {
|
for (auto &it : params) {
|
||||||
auto name = std::static_pointer_cast<Parameter>(it)->name();
|
auto name = std::static_pointer_cast<Parameter>(it)->name();
|
||||||
|
@ -1185,10 +1323,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa
|
||||||
|
|
||||||
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
||||||
OperatorPtr src = Convert(node);
|
OperatorPtr src = Convert(node);
|
||||||
|
int case_flag = 0;
|
||||||
auto &inputs = node->inputs();
|
auto &inputs = node->inputs();
|
||||||
for (size_t i = 1; i < inputs.size(); i++) {
|
size_t input_size = inputs.size();
|
||||||
|
if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
|
||||||
|
case_flag = 1;
|
||||||
|
input_size = case_input_handle_cache_[node.get()]->size() + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 1; i < input_size; i++) {
|
||||||
auto pred = inputs[i];
|
auto pred = inputs[i];
|
||||||
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") {
|
if (case_flag != 0) {
|
||||||
|
pred = case_input_handle_cache_[node.get()]->at(i - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") {
|
||||||
pred = pred->cast<CNodePtr>()->input(1);
|
pred = pred->cast<CNodePtr>()->input(1);
|
||||||
}
|
}
|
||||||
// skip the None input
|
// skip the None input
|
||||||
|
@ -1196,7 +1345,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
||||||
std::string c_name = GetCNodeFuncName(node);
|
std::string c_name = GetCNodeTargetFuncName(node);
|
||||||
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
||||||
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
||||||
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
||||||
|
@ -1220,7 +1369,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
||||||
if (it != out_handle_cache_.end()) {
|
if (it != out_handle_cache_.end()) {
|
||||||
int ret = adpt->setInput(src, SizeToInt(i), it->second);
|
int ret = adpt->setInput(src, SizeToInt(i), it->second);
|
||||||
if (ret == 0) {
|
if (ret == 0) {
|
||||||
if (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
|
if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
|
||||||
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
|
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
|
||||||
<< ":" << i << endl;
|
<< ":" << i << endl;
|
||||||
} else if (pred->isa<Parameter>()) {
|
} else if (pred->isa<Parameter>()) {
|
||||||
|
@ -1278,6 +1427,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
|
||||||
DfGraphConvertor::SetOpInput(adpt, cnode);
|
DfGraphConvertor::SetOpInput(adpt, cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
|
||||||
|
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
|
||||||
|
DfGraphConvertor convertor(anf_graph);
|
||||||
|
convertor.use_inputs_ = true;
|
||||||
|
convertor.inputs_ = inputs;
|
||||||
|
(void)convertor.ConvertAllNode().BuildGraph();
|
||||||
|
std::string name = graph_node->ToString() + "_ge_graph.dot";
|
||||||
|
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||||
|
convertor.DrawComputeGraph(name);
|
||||||
|
}
|
||||||
|
branches_map_[node.get()] = *(convertor.df_graph_);
|
||||||
|
}
|
||||||
|
|
||||||
// Update GE op's shape and type info
|
// Update GE op's shape and type info
|
||||||
void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
|
void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
|
||||||
if (nullptr == node || !node->isa<CNode>()) {
|
if (nullptr == node || !node->isa<CNode>()) {
|
||||||
|
@ -1348,6 +1514,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size();
|
||||||
tuple_out_handle_cache_[node.get()] = tuple_items;
|
tuple_out_handle_cache_[node.get()] = tuple_items;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1711,6 +1878,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (name == "" && GetCNodeFuncName(node) == "switch_layer") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (name == "Partial") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
||||||
if (name == "make_tuple") {
|
if (name == "make_tuple") {
|
||||||
ConvertMakeTuple(node);
|
ConvertMakeTuple(node);
|
||||||
|
@ -1732,7 +1907,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
||||||
}
|
}
|
||||||
|
|
||||||
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
|
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
|
||||||
std::string name = GetCNodeFuncName(node);
|
std::string name = GetCNodeTargetFuncName(node);
|
||||||
if (!CheckCNode(name, node)) {
|
if (!CheckCNode(name, node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1879,7 +2054,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
|
||||||
}
|
}
|
||||||
|
|
||||||
compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
|
compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
|
||||||
<< ":" << GetCNodeFuncName(node) << "\"</td></tr>" << endl;
|
<< ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
|
||||||
|
|
||||||
// print attrs' values
|
// print attrs' values
|
||||||
auto atts = adpt->GetAttrsFromDrawGraph();
|
auto atts = adpt->GetAttrsFromDrawGraph();
|
||||||
|
|
|
@ -201,6 +201,7 @@ class DfGraphConvertor {
|
||||||
OperatorPtr ConvertParameter(AnfNodePtr node);
|
OperatorPtr ConvertParameter(AnfNodePtr node);
|
||||||
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
||||||
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
||||||
|
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
|
||||||
void ConvertTupleGetItem(const CNodePtr node);
|
void ConvertTupleGetItem(const CNodePtr node);
|
||||||
void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
|
void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
|
||||||
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
||||||
|
@ -217,6 +218,8 @@ class DfGraphConvertor {
|
||||||
void SetNodeInput(AnfNodePtr node);
|
void SetNodeInput(AnfNodePtr node);
|
||||||
void SetOpControlInput(const AnfNodePtr node);
|
void SetOpControlInput(const AnfNodePtr node);
|
||||||
void UpdateOpDesc(AnfNodePtr node);
|
void UpdateOpDesc(AnfNodePtr node);
|
||||||
|
void SetSubgraph(AnfNodePtr node);
|
||||||
|
void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs);
|
||||||
void BuildSaveCheckpointGraph();
|
void BuildSaveCheckpointGraph();
|
||||||
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
|
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
|
||||||
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
|
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
|
||||||
|
@ -228,22 +231,26 @@ class DfGraphConvertor {
|
||||||
std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
|
std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
|
||||||
std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
|
std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
|
||||||
std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
|
std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
|
||||||
|
std::unordered_map<AnfNode *, DfGraph> branches_map_;
|
||||||
std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
|
std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
|
||||||
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_;
|
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_;
|
||||||
/* record "tuple_getitem"<->"out_handler" mapping */
|
/* record "tuple_getitem"<->"out_handler" mapping */
|
||||||
std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
|
std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
|
||||||
/* record "make_tuple"<->"out_handler vector" mapping */
|
/* record "make_tuple"<->"out_handler vector" mapping */
|
||||||
std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
|
std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
|
||||||
|
std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_;
|
||||||
std::unordered_map<std::string, AnfNodePtr> params_;
|
std::unordered_map<std::string, AnfNodePtr> params_;
|
||||||
std::unordered_map<std::string, OperatorPtr> vars_;
|
std::unordered_map<std::string, OperatorPtr> vars_;
|
||||||
std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
|
std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
|
||||||
std::vector<OperatorPtr> graph_const_inputs_;
|
std::vector<OperatorPtr> graph_const_inputs_;
|
||||||
std::vector<OperatorPtr> init_ops_;
|
std::vector<OperatorPtr> init_ops_;
|
||||||
std::vector<OperatorPtr> broadcast_ops_;
|
std::vector<OperatorPtr> broadcast_ops_;
|
||||||
|
std::vector<AnfNodePtr> inputs_;
|
||||||
OperatorPtr dataset_iter_getnext_;
|
OperatorPtr dataset_iter_getnext_;
|
||||||
Status error_ = SUCCESS;
|
Status error_ = SUCCESS;
|
||||||
bool training_ = false;
|
bool training_ = false;
|
||||||
bool distribute_ = false;
|
bool distribute_ = false;
|
||||||
|
bool use_inputs_ = false;
|
||||||
};
|
};
|
||||||
} // namespace transform
|
} // namespace transform
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -164,6 +164,25 @@ class OpAdapter : public BaseOpAdapter {
|
||||||
const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
|
const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
|
||||||
const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
|
const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
|
||||||
const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; }
|
const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; }
|
||||||
|
const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
|
||||||
|
|
||||||
|
Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) {
|
||||||
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
auto it = dyn_subgraph_map_.find(index);
|
||||||
|
if (it != dyn_subgraph_map_.end()) {
|
||||||
|
auto size = branches->size();
|
||||||
|
it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size));
|
||||||
|
for (size_t i = 0; i < size; i++) {
|
||||||
|
it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i]));
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
return NOT_FOUND;
|
||||||
|
}
|
||||||
|
|
||||||
|
int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override {
|
||||||
|
return static_cast<int>(SetOpSubgraphFunc(op, index, branches));
|
||||||
|
}
|
||||||
|
|
||||||
Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
|
Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
|
||||||
MS_EXCEPTION_IF_NULL(op);
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
@ -855,6 +874,7 @@ class OpAdapter : public BaseOpAdapter {
|
||||||
static const std::unordered_map<int, DynInputDesc> dyn_input_map_;
|
static const std::unordered_map<int, DynInputDesc> dyn_input_map_;
|
||||||
static const std::unordered_map<int, OutputDesc> output_map_;
|
static const std::unordered_map<int, OutputDesc> output_map_;
|
||||||
static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;
|
static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;
|
||||||
|
static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_;
|
||||||
static const std::unordered_map<std::string, AttrDesc> attr_map_;
|
static const std::unordered_map<std::string, AttrDesc> attr_map_;
|
||||||
static const std::unordered_map<std::string, int> enum_map_;
|
static const std::unordered_map<std::string, int> enum_map_;
|
||||||
// convert input from anf graph to Attr in Operators
|
// convert input from anf graph to Attr in Operators
|
||||||
|
@ -874,6 +894,8 @@ const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
|
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
|
||||||
|
template <typename T>
|
||||||
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
|
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
|
const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
|
||||||
|
|
|
@ -88,6 +88,8 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr
|
||||||
using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
|
using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
|
||||||
using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
|
using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
|
||||||
using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
|
using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
|
||||||
|
using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;
|
||||||
|
using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;
|
||||||
|
|
||||||
struct AttrDesc {
|
struct AttrDesc {
|
||||||
std::string name;
|
std::string name;
|
||||||
|
@ -108,6 +110,12 @@ struct DynInputDesc {
|
||||||
DynInputHandleFunc set_handle;
|
DynInputHandleFunc set_handle;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct DynSubGraphDesc {
|
||||||
|
std::string name;
|
||||||
|
CreateDynSubGraphFunc create_dyn_subgraph;
|
||||||
|
DynSubGraphFunc set_subgraph;
|
||||||
|
};
|
||||||
|
|
||||||
struct OutputDesc {
|
struct OutputDesc {
|
||||||
std::string name;
|
std::string name;
|
||||||
UpdateOutputDescFunc update_out_desc;
|
UpdateOutputDescFunc update_out_desc;
|
||||||
|
@ -123,6 +131,7 @@ class BaseOpAdapter {
|
||||||
virtual ~BaseOpAdapter() {}
|
virtual ~BaseOpAdapter() {}
|
||||||
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
|
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
|
||||||
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
|
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
|
||||||
|
virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0;
|
||||||
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
|
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
|
||||||
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
|
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
|
||||||
virtual int setInput(const OperatorPtr &op, int index,
|
virtual int setInput(const OperatorPtr &op, int index,
|
||||||
|
@ -146,6 +155,7 @@ class BaseOpAdapter {
|
||||||
virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
|
virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
|
||||||
virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
|
virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
|
||||||
virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
|
virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
|
||||||
|
virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
|
||||||
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
|
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
|
||||||
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
|
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
|
||||||
void clearAttrVect() { attrs_vec_.clear(); }
|
void clearAttrVect() { attrs_vec_.clear(); }
|
||||||
|
|
|
@ -64,6 +64,22 @@ namespace transform {
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DYN_SUBGRAPH_MAP(T) \
|
||||||
|
template <> \
|
||||||
|
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_
|
||||||
|
#define DYN_SUBGRAPH_DESC(name) \
|
||||||
|
{ \
|
||||||
|
#name, \
|
||||||
|
[](const OperatorPtr op, unsigned int num) { \
|
||||||
|
auto p = std::static_pointer_cast<OpType>(op); \
|
||||||
|
(void)p->create_dynamic_subgraph_##name(num); \
|
||||||
|
}, \
|
||||||
|
[](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \
|
||||||
|
auto p = std::static_pointer_cast<OpType>(op); \
|
||||||
|
(void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
#define ATTR_MAP(T) \
|
#define ATTR_MAP(T) \
|
||||||
template <> \
|
template <> \
|
||||||
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_
|
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_
|
||||||
|
@ -841,6 +857,13 @@ INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
|
||||||
ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// Case
|
||||||
|
INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}};
|
||||||
|
DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
|
||||||
|
ATTR_MAP(Case) = EMPTY_ATTR_MAP;
|
||||||
|
DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
|
||||||
|
DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
|
||||||
|
|
||||||
// Reciprocal
|
// Reciprocal
|
||||||
INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP;
|
||||||
|
|
|
@ -46,6 +46,10 @@ namespace transform {
|
||||||
template <> \
|
template <> \
|
||||||
const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
|
const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
|
||||||
|
|
||||||
|
#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \
|
||||||
|
template <> \
|
||||||
|
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
|
||||||
|
|
||||||
#define DECLARE_OP_USE_DYN_OUTPUT(T) \
|
#define DECLARE_OP_USE_DYN_OUTPUT(T) \
|
||||||
template <> \
|
template <> \
|
||||||
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
|
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
|
||||||
|
@ -232,6 +236,10 @@ DECLARE_OP_USE_OUTPUT(RealDiv)
|
||||||
DECLARE_OP_ADAPTER(Cast)
|
DECLARE_OP_ADAPTER(Cast)
|
||||||
DECLARE_OP_USE_INPUT_ATTR(Cast)
|
DECLARE_OP_USE_INPUT_ATTR(Cast)
|
||||||
DECLARE_OP_USE_OUTPUT(Cast)
|
DECLARE_OP_USE_OUTPUT(Cast)
|
||||||
|
DECLARE_OP_ADAPTER(Case)
|
||||||
|
DECLARE_OP_USE_DYN_INPUT(Case)
|
||||||
|
DECLARE_OP_USE_DYN_SUBGRAPH(Case)
|
||||||
|
DECLARE_OP_USE_DYN_OUTPUT(Case)
|
||||||
DECLARE_OP_ADAPTER(Reciprocal)
|
DECLARE_OP_ADAPTER(Reciprocal)
|
||||||
DECLARE_OP_USE_OUTPUT(Reciprocal)
|
DECLARE_OP_USE_OUTPUT(Reciprocal)
|
||||||
DECLARE_OP_ADAPTER(Neg)
|
DECLARE_OP_ADAPTER(Neg)
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Test case."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(1, 3, 3)
|
||||||
|
self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True)
|
||||||
|
self.layers = (self.conv1, self.conv2)
|
||||||
|
|
||||||
|
def construct(self, x, index):
|
||||||
|
x = self.layers[index](x)
|
||||||
|
y = self.conv1(x)
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
def test_case():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
net = Net()
|
||||||
|
data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32)
|
||||||
|
idx = Tensor(1, mindspore.int32)
|
||||||
|
net(data, idx)
|
Loading…
Reference in New Issue