convert subgraph

This commit is contained in:
changzherui 2020-06-14 12:01:01 +08:00
parent 7b5b4837ff
commit a27ce973ad
7 changed files with 296 additions and 10 deletions

View File

@ -28,6 +28,7 @@
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "utils/convert_utils.h" #include "utils/convert_utils.h"
#include "./common.h" #include "./common.h"
#include "utils/context/ms_context.h"
namespace mindspore { namespace mindspore {
namespace transform { namespace transform {
@ -205,6 +206,7 @@ const char kNameRange[] = "Range";
const char kNameSquareSumAll[] = "SquareSumAll"; const char kNameSquareSumAll[] = "SquareSumAll";
const char kNameAscendQuant[] = "AscendQuant"; const char kNameAscendQuant[] = "AscendQuant";
const char kNameAscendDequant[] = "AscendDequant"; const char kNameAscendDequant[] = "AscendDequant";
const char kNameCase[] = "Case";
// -----------------OpAdapter initialization-------------- // -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@ -411,7 +413,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameRange), ADPT_DESC(RangeD)}, {string(kNameRange), ADPT_DESC(RangeD)},
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)},
{string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, {string(kNameAscendQuant), ADPT_DESC(AscendQuant)},
{string(kNameAscendDequant), ADPT_DESC(AscendDequant)}}; {string(kNameAscendDequant), ADPT_DESC(AscendDequant)},
{string(kNameCase), ADPT_DESC(Case)}};
#ifdef ENABLE_GE #ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print); adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
@ -433,13 +436,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) {
return kPrimTypeUnknown; return kPrimTypeUnknown;
} }
bool IsCaseNode(const CNodePtr node) {
if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
return true;
}
return false;
}
std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
if (IsCaseNode(cnode)) {
return string(kNameCase);
}
auto name = GetCNodeFuncName(cnode);
if (name == "switch_layer") {
name = "";
}
return name;
}
OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
std::string name = kNameCustomOp; std::string name = kNameCustomOp;
if (!IsCustomCNode(cnode)) { if (!IsCustomCNode(cnode)) {
name = GetCNodeFuncName(cnode); name = GetCNodeTargetFuncName(cnode);
} }
auto it_adpt = get_adpt_map().find(name); auto it_adpt = get_adpt_map().find(name);
@ -957,7 +979,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
auto c = anf_out->cast<CNodePtr>(); auto c = anf_out->cast<CNodePtr>();
std::string name = ""; std::string name = "";
if (anf_out->isa<CNode>()) { if (anf_out->isa<CNode>()) {
name = GetCNodeFuncName(c); name = GetCNodeTargetFuncName(c);
} }
if (name == "make_tuple") { if (name == "make_tuple") {
@ -1029,6 +1051,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
return; return;
} }
void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
if (!node->isa<CNode>()) {
return;
}
auto cnode = node->cast<CNodePtr>();
if (!IsCaseNode(cnode)) {
return;
}
std::vector<AnfNodePtr> case_inputs;
for (size_t i = 1; i < cnode->inputs().size(); i++) {
case_inputs.emplace_back(cnode->input(i));
}
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
for (size_t i = 1; i < bnode->inputs().size(); i++) {
auto branch_node = bnode->input(i)->cast<CNodePtr>();
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
case_inputs.emplace_back(branch_node->input(j));
}
}
}
for (size_t i = 1; i < bnode->inputs().size(); i++) {
ProcessSubgraph(bnode->input(i), case_inputs);
}
for (size_t i = 1; i < bnode->inputs().size(); i++) {
branches->emplace_back(branches_map_[bnode->input(i).get()]);
}
if (op_cache_.find(node.get()) == op_cache_.end()) {
return;
}
OpAdapterPtr adpt = FindAdapter(node, training_);
if (nullptr == adpt) {
MS_LOG(DEBUG) << "Not found adapter";
return;
}
OperatorPtr op = Convert(node);
adpt->setSubgraph(op, 0, branches);
return;
}
void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) {
std::vector<AnfNodePtr> case_inputs;
for (size_t i = 1; i < node->inputs().size(); i++) {
case_inputs.emplace_back(node->input(i));
}
std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
auto bnode = input_node->input(2)->cast<CNodePtr>();
for (size_t i = 1; i < bnode->inputs().size(); i++) {
auto branch_node = bnode->input(i)->cast<CNodePtr>();
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
case_inputs.emplace_back(branch_node->input(j));
}
}
}
const size_t case_index = 1;
const size_t make_tuple_index = 2;
AnfNodePtr case_index_iter = input_node->input(case_index);
AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index);
auto make_tuple_node = make_tuple_iter->cast<CNodePtr>();
std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
for (size_t i = 0; i < case_inputs.size(); i++) {
auto item = case_inputs[i];
auto op = Convert(item);
if (op != nullptr) {
tuple_items->emplace_back(OutHandler(op, ""));
} else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
tuple_items->push_back(out_handle_cache_[item.get()]);
} else {
MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString();
continue;
}
}
tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items;
std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>();
case_input_items->emplace_back(case_index_iter);
case_input_items->emplace_back(make_tuple_iter);
case_input_handle_cache_[node.get()] = case_input_items;
}
DfGraphConvertor &DfGraphConvertor::BuildGraph() { DfGraphConvertor &DfGraphConvertor::BuildGraph() {
SetupDatasetIterGetNextNode(dataset_iter_getnext_); SetupDatasetIterGetNextNode(dataset_iter_getnext_);
@ -1036,6 +1151,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
return *this; return *this;
} }
// Case node set input.
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
for (auto &it : nodes) {
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
auto node = it->cast<CNodePtr>();
auto input_node = node->input(0)->cast<CNodePtr>();
GetCaseNodeInput(node, input_node);
}
}
// update tuple_out_handle_cache_ // update tuple_out_handle_cache_
for (auto it : tuple_out_handle_cache_) { for (auto it : tuple_out_handle_cache_) {
std::size_t len = it.second->size(); std::size_t len = it.second->size();
@ -1056,10 +1181,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
// set up dependices // set up dependices
MS_LOG(DEBUG) << "set up dependices"; MS_LOG(DEBUG) << "set up dependices";
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); nodes = ::mindspore::TopoSort(anf_graph_->get_return());
for (auto &it : nodes) { for (auto &it : nodes) {
SetNodeInput(it); SetNodeInput(it);
SetOpControlInput(it); SetOpControlInput(it);
SetSubgraph(it);
UpdateOpDesc(it); UpdateOpDesc(it);
} }
@ -1075,6 +1201,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
inputs.push_back(*dataset_iter_getnext_); inputs.push_back(*dataset_iter_getnext_);
} else { } else {
auto params = anf_graph_->parameters(); auto params = anf_graph_->parameters();
if (use_inputs_) {
params = inputs_;
auto anf_params = anf_graph_->parameters();
for (size_t i = 0; i < params.size(); i++) {
for (size_t j = 0; j < anf_params.size(); j++) {
if (params[i]->ToString() == anf_params[j]->ToString()) {
params[i] = anf_params[j];
}
}
}
}
int index = 0; int index = 0;
for (auto &it : params) { for (auto &it : params) {
auto name = std::static_pointer_cast<Parameter>(it)->name(); auto name = std::static_pointer_cast<Parameter>(it)->name();
@ -1185,10 +1323,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
OperatorPtr src = Convert(node); OperatorPtr src = Convert(node);
int case_flag = 0;
auto &inputs = node->inputs(); auto &inputs = node->inputs();
for (size_t i = 1; i < inputs.size(); i++) { size_t input_size = inputs.size();
if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
case_flag = 1;
input_size = case_input_handle_cache_[node.get()]->size() + 1;
}
for (size_t i = 1; i < input_size; i++) {
auto pred = inputs[i]; auto pred = inputs[i];
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") { if (case_flag != 0) {
pred = case_input_handle_cache_[node.get()]->at(i - 1);
}
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "Depend") {
pred = pred->cast<CNodePtr>()->input(1); pred = pred->cast<CNodePtr>()->input(1);
} }
// skip the None input // skip the None input
@ -1196,7 +1345,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
continue; continue;
} }
// transform "Const" op to "Variable" op when the next node is "Assign" op. // transform "Const" op to "Variable" op when the next node is "Assign" op.
std::string c_name = GetCNodeFuncName(node); std::string c_name = GetCNodeTargetFuncName(node);
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
std::string name = std::static_pointer_cast<Parameter>(pred)->name(); std::string name = std::static_pointer_cast<Parameter>(pred)->name();
@ -1220,7 +1369,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
if (it != out_handle_cache_.end()) { if (it != out_handle_cache_.end()) {
int ret = adpt->setInput(src, SizeToInt(i), it->second); int ret = adpt->setInput(src, SizeToInt(i), it->second);
if (ret == 0) { if (ret == 0) {
if (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") { if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == "tuple_getitem") {
compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
<< ":" << i << endl; << ":" << i << endl;
} else if (pred->isa<Parameter>()) { } else if (pred->isa<Parameter>()) {
@ -1278,6 +1427,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
DfGraphConvertor::SetOpInput(adpt, cnode); DfGraphConvertor::SetOpInput(adpt, cnode);
} }
void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
return;
}
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
DfGraphConvertor convertor(anf_graph);
convertor.use_inputs_ = true;
convertor.inputs_ = inputs;
(void)convertor.ConvertAllNode().BuildGraph();
std::string name = graph_node->ToString() + "_ge_graph.dot";
if (MsContext::GetInstance()->save_graphs_flag()) {
convertor.DrawComputeGraph(name);
}
branches_map_[node.get()] = *(convertor.df_graph_);
}
// Update GE op's shape and type info // Update GE op's shape and type info
void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
if (nullptr == node || !node->isa<CNode>()) { if (nullptr == node || !node->isa<CNode>()) {
@ -1348,6 +1514,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
} }
} }
MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size();
tuple_out_handle_cache_[node.get()] = tuple_items; tuple_out_handle_cache_[node.get()] = tuple_items;
} }
@ -1711,6 +1878,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return false; return false;
} }
if (name == "" && GetCNodeFuncName(node) == "switch_layer") {
return false;
}
if (name == "Partial") {
return false;
}
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
if (name == "make_tuple") { if (name == "make_tuple") {
ConvertMakeTuple(node); ConvertMakeTuple(node);
@ -1732,7 +1907,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
} }
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
std::string name = GetCNodeFuncName(node); std::string name = GetCNodeTargetFuncName(node);
if (!CheckCNode(name, node)) { if (!CheckCNode(name, node)) {
return nullptr; return nullptr;
} }
@ -1879,7 +2054,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
} }
compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString() compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
<< ":" << GetCNodeFuncName(node) << "\"</td></tr>" << endl; << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
// print attrs' values // print attrs' values
auto atts = adpt->GetAttrsFromDrawGraph(); auto atts = adpt->GetAttrsFromDrawGraph();

View File

@ -201,6 +201,7 @@ class DfGraphConvertor {
OperatorPtr ConvertParameter(AnfNodePtr node); OperatorPtr ConvertParameter(AnfNodePtr node);
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
OperatorPtr ConvertValueNode(ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node);
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
void ConvertTupleGetItem(const CNodePtr node); void ConvertTupleGetItem(const CNodePtr node);
void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
@ -217,6 +218,8 @@ class DfGraphConvertor {
void SetNodeInput(AnfNodePtr node); void SetNodeInput(AnfNodePtr node);
void SetOpControlInput(const AnfNodePtr node); void SetOpControlInput(const AnfNodePtr node);
void UpdateOpDesc(AnfNodePtr node); void UpdateOpDesc(AnfNodePtr node);
void SetSubgraph(AnfNodePtr node);
void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs);
void BuildSaveCheckpointGraph(); void BuildSaveCheckpointGraph();
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
@ -228,22 +231,26 @@ class DfGraphConvertor {
std::shared_ptr<DfGraph> save_ckp_graph_{nullptr}; std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr}; std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
std::unordered_map<AnfNode *, DfGraph> branches_map_;
std::unordered_map<AnfNode *, OperatorPtr> op_cache_; std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_;
/* record "tuple_getitem"<->"out_handler" mapping */ /* record "tuple_getitem"<->"out_handler" mapping */
std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
/* record "make_tuple"<->"out_handler vector" mapping */ /* record "make_tuple"<->"out_handler vector" mapping */
std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_; std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_;
std::unordered_map<std::string, AnfNodePtr> params_; std::unordered_map<std::string, AnfNodePtr> params_;
std::unordered_map<std::string, OperatorPtr> vars_; std::unordered_map<std::string, OperatorPtr> vars_;
std::vector<std::pair<ge::Operator, std::string>> graph_outputs_; std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
std::vector<OperatorPtr> graph_const_inputs_; std::vector<OperatorPtr> graph_const_inputs_;
std::vector<OperatorPtr> init_ops_; std::vector<OperatorPtr> init_ops_;
std::vector<OperatorPtr> broadcast_ops_; std::vector<OperatorPtr> broadcast_ops_;
std::vector<AnfNodePtr> inputs_;
OperatorPtr dataset_iter_getnext_; OperatorPtr dataset_iter_getnext_;
Status error_ = SUCCESS; Status error_ = SUCCESS;
bool training_ = false; bool training_ = false;
bool distribute_ = false; bool distribute_ = false;
bool use_inputs_ = false;
}; };
} // namespace transform } // namespace transform
} // namespace mindspore } // namespace mindspore

View File

@ -164,6 +164,25 @@ class OpAdapter : public BaseOpAdapter {
const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; } const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; } const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; } const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; }
const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) {
MS_EXCEPTION_IF_NULL(op);
auto it = dyn_subgraph_map_.find(index);
if (it != dyn_subgraph_map_.end()) {
auto size = branches->size();
it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size));
for (size_t i = 0; i < size; i++) {
it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i]));
}
return SUCCESS;
}
return NOT_FOUND;
}
int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override {
return static_cast<int>(SetOpSubgraphFunc(op, index, branches));
}
Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
@ -855,6 +874,7 @@ class OpAdapter : public BaseOpAdapter {
static const std::unordered_map<int, DynInputDesc> dyn_input_map_; static const std::unordered_map<int, DynInputDesc> dyn_input_map_;
static const std::unordered_map<int, OutputDesc> output_map_; static const std::unordered_map<int, OutputDesc> output_map_;
static const std::unordered_map<int, DynOutputDesc> dyn_output_map_; static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;
static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_;
static const std::unordered_map<std::string, AttrDesc> attr_map_; static const std::unordered_map<std::string, AttrDesc> attr_map_;
static const std::unordered_map<std::string, int> enum_map_; static const std::unordered_map<std::string, int> enum_map_;
// convert input from anf graph to Attr in Operators // convert input from anf graph to Attr in Operators
@ -874,6 +894,8 @@ const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
template <typename T> template <typename T>
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
template <typename T> template <typename T>
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
template <typename T>
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_; const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
template <typename T> template <typename T>
const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_; const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;

View File

@ -88,6 +88,8 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr
using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>; using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>; using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>; using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;
using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;
struct AttrDesc { struct AttrDesc {
std::string name; std::string name;
@ -108,6 +110,12 @@ struct DynInputDesc {
DynInputHandleFunc set_handle; DynInputHandleFunc set_handle;
}; };
struct DynSubGraphDesc {
std::string name;
CreateDynSubGraphFunc create_dyn_subgraph;
DynSubGraphFunc set_subgraph;
};
struct OutputDesc { struct OutputDesc {
std::string name; std::string name;
UpdateOutputDescFunc update_out_desc; UpdateOutputDescFunc update_out_desc;
@ -123,6 +131,7 @@ class BaseOpAdapter {
virtual ~BaseOpAdapter() {} virtual ~BaseOpAdapter() {}
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); } virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
virtual int setInput(const OperatorPtr &op, int index, virtual int setInput(const OperatorPtr &op, int index,
@ -146,6 +155,7 @@ class BaseOpAdapter {
virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0; virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0; virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0; virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; } const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
void clearAttrVect() { attrs_vec_.clear(); } void clearAttrVect() { attrs_vec_.clear(); }

View File

@ -64,6 +64,22 @@ namespace transform {
} \ } \
} }
#define DYN_SUBGRAPH_MAP(T) \
template <> \
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_
#define DYN_SUBGRAPH_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, unsigned int num) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->create_dynamic_subgraph_##name(num); \
}, \
[](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \
} \
}
#define ATTR_MAP(T) \ #define ATTR_MAP(T) \
template <> \ template <> \
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_ const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_
@ -841,6 +857,13 @@ INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
ATTR_MAP(Cast) = EMPTY_ATTR_MAP; ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
// Case
INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}};
DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
ATTR_MAP(Case) = EMPTY_ATTR_MAP;
DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
// Reciprocal // Reciprocal
INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}}; INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP; ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP;

View File

@ -46,6 +46,10 @@ namespace transform {
template <> \ template <> \
const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_; const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \
template <> \
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
#define DECLARE_OP_USE_DYN_OUTPUT(T) \ #define DECLARE_OP_USE_DYN_OUTPUT(T) \
template <> \ template <> \
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
@ -232,6 +236,10 @@ DECLARE_OP_USE_OUTPUT(RealDiv)
DECLARE_OP_ADAPTER(Cast) DECLARE_OP_ADAPTER(Cast)
DECLARE_OP_USE_INPUT_ATTR(Cast) DECLARE_OP_USE_INPUT_ATTR(Cast)
DECLARE_OP_USE_OUTPUT(Cast) DECLARE_OP_USE_OUTPUT(Cast)
DECLARE_OP_ADAPTER(Case)
DECLARE_OP_USE_DYN_INPUT(Case)
DECLARE_OP_USE_DYN_SUBGRAPH(Case)
DECLARE_OP_USE_DYN_OUTPUT(Case)
DECLARE_OP_ADAPTER(Reciprocal) DECLARE_OP_ADAPTER(Reciprocal)
DECLARE_OP_USE_OUTPUT(Reciprocal) DECLARE_OP_USE_OUTPUT(Reciprocal)
DECLARE_OP_ADAPTER(Neg) DECLARE_OP_ADAPTER(Neg)

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test case."""
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 3, 3)
self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True)
self.layers = (self.conv1, self.conv2)
def construct(self, x, index):
x = self.layers[index](x)
y = self.conv1(x)
return x + y
def test_case():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = Net()
data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32)
idx = Tensor(1, mindspore.int32)
net(data, idx)