forked from mindspore-Ecosystem/mindspore
!1938 [AutoParallel] limit GatherV2, BN and Softmax to data parallel
Merge pull request !1938 from Chong/ReID
This commit is contained in:
commit
c1c683eea8
|
@ -135,24 +135,51 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
|
|||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s) {
|
||||
std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
for (size_t i = 1; i < strategies.size(); i++) {
|
||||
strategies[i][0] = strategies[0][1];
|
||||
}
|
||||
strategies[1][0] = 1;
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
|
||||
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
|
||||
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
|
||||
std::vector<std::vector<int32_t>> strategies;
|
||||
strategies.push_back(s);
|
||||
strategies.push_back(*s);
|
||||
std::vector<int32_t> s_biasadd;
|
||||
s_biasadd.push_back(s[1]);
|
||||
s_biasadd.push_back(s->at(1));
|
||||
strategies.push_back(s_biasadd);
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s) {
|
||||
std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s) {
|
||||
std::vector<std::vector<int32_t>> strategies;
|
||||
std::vector<int32_t> s_empty = {};
|
||||
strategies.push_back(s);
|
||||
strategies.push_back(*s);
|
||||
strategies.push_back(s_empty);
|
||||
strategies.push_back(s_empty);
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
|
||||
std::vector<std::vector<int32_t>> strategies;
|
||||
strategies.push_back(*s);
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
|
@ -270,6 +297,12 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
|
|||
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == PRELU) {
|
||||
return PreparePReLU(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == BATCH_NORM) {
|
||||
return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
} else {
|
||||
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
||||
}
|
||||
|
@ -336,7 +369,7 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
|
|||
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t incoming_op_index) {
|
||||
std::vector<int32_t> s;
|
||||
if (ops[incoming_op_index]->type() == RESHAPE) {
|
||||
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) {
|
||||
return s;
|
||||
}
|
||||
auto strategy = ops[incoming_op_index]->selected_strategy();
|
||||
|
@ -456,11 +489,6 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
|
|||
return s_Reduce;
|
||||
}
|
||||
|
||||
std::vector<int32_t> ModifyStrategyIfSoftmaxIncoming(std::vector<int32_t> s) {
|
||||
s.pop_back();
|
||||
return s;
|
||||
}
|
||||
|
||||
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t incoming_op_index) {
|
||||
std::vector<int32_t> s;
|
||||
|
@ -474,9 +502,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
|
|||
ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
|
||||
s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
|
||||
}
|
||||
if (ops[incoming_op_index]->type() == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
s = ModifyStrategyIfSoftmaxIncoming(s);
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
@ -496,11 +521,15 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
|
|||
return stra;
|
||||
}
|
||||
|
||||
auto s_ptr = std::make_shared<std::vector<int32_t>>(basic_stra);
|
||||
if (ops[iter_ops]->type() == BIAS_ADD) {
|
||||
return PrepareBiasAdd(basic_stra);
|
||||
return PrepareBiasAdd(s_ptr);
|
||||
}
|
||||
if (ops[iter_ops]->type() == ONEHOT) {
|
||||
return PrepareOneHot(basic_stra);
|
||||
return PrepareOneHot(s_ptr);
|
||||
}
|
||||
if (ops[iter_ops]->type() == GATHERV2) {
|
||||
return PrepareGatherV2(s_ptr);
|
||||
}
|
||||
|
||||
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
|
||||
|
@ -599,7 +628,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
|
|||
const size_t iter_ops) {
|
||||
std::vector<int32_t> s;
|
||||
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
|
||||
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE) {
|
||||
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
|
||||
ops[iter_ops]->type() == GATHERV2) {
|
||||
return s;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,15 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
|
|||
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
|
||||
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
|
||||
std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
|
||||
std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s);
|
||||
std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
|
||||
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
|
@ -64,11 +71,11 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
|
|||
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
|
||||
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t incoming_op_index, std::vector<int32_t> s);
|
||||
std::vector<int32_t> ModifyStrategyIfSoftmaxIncoming(std::vector<int32_t> s);
|
||||
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t incoming_op_index);
|
||||
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, std::vector<int32_t> s);
|
||||
const size_t iter_ops,
|
||||
std::vector<int32_t> basic_stra);
|
||||
void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const std::vector<std::vector<std::string>> &input_tensor_names,
|
||||
|
|
|
@ -48,7 +48,8 @@ enum OperatorType {
|
|||
kRecSqueeze,
|
||||
kRecCast,
|
||||
kRecReduce,
|
||||
kRecPReLU
|
||||
kRecPReLU,
|
||||
kRecGatherV2
|
||||
};
|
||||
|
||||
enum InfoType { kApplication, kConstant };
|
||||
|
|
|
@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
|
|||
OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
|
||||
OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
|
||||
OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
|
||||
OperatorType::kRecCast, OperatorType::kRecReshape};
|
||||
OperatorType::kRecCast, OperatorType::kRecReshape, OperatorType::kRecGatherV2};
|
||||
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
|
||||
auto type = graph->nodes[node_index].apply.op_type;
|
||||
if (type_list.find(type) != type_list.end()) {
|
||||
|
|
|
@ -46,6 +46,7 @@ const std::map<std::string, OperatorType> DictOpType{
|
|||
{REDUCE_MAX, OperatorType::kRecReduce},
|
||||
{REDUCE_MIN, OperatorType::kRecReduce},
|
||||
{REDUCE_MEAN, OperatorType::kRecReduce},
|
||||
{GATHERV2, OperatorType::kRecGatherV2},
|
||||
|
||||
{RELU, OperatorType::kRecReLU},
|
||||
{"ReLU6", OperatorType::kRecReLU},
|
||||
|
@ -63,9 +64,9 @@ const std::map<std::string, OperatorType> DictOpType{
|
|||
{MUL, OperatorType::kRecElmWiseOp},
|
||||
{DIV, OperatorType::kRecElmWiseOp},
|
||||
{REAL_DIV, OperatorType::kRecElmWiseOp},
|
||||
{SOFTMAX, OperatorType::kRecElmWiseOp},
|
||||
{LOG_SOFTMAX, OperatorType::kRecElmWiseOp},
|
||||
{SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecElmWiseOp},
|
||||
{SOFTMAX, OperatorType::kRecSoftmax},
|
||||
{LOG_SOFTMAX, OperatorType::kRecSoftmax},
|
||||
{SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax},
|
||||
{SQRT, OperatorType::kRecElmWiseOp},
|
||||
{NEG, OperatorType::kRecElmWiseOp},
|
||||
{POW, OperatorType::kRecElmWiseOp},
|
||||
|
|
|
@ -53,9 +53,8 @@ double GetWeights(const Graph::NodeType &node) {
|
|||
auto cost_ptr = std::make_shared<CostTensorAdd>();
|
||||
|
||||
return cost_ptr->GetMinCostIn();
|
||||
} else if (op.op_type == OperatorType::kRecReLU || op.op_type == OperatorType::kRecSoftmax ||
|
||||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For Activation and Softmax
|
||||
} else if (op.op_type == OperatorType::kRecReLU) {
|
||||
// For Activation
|
||||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetMinCostIn();
|
||||
|
@ -69,11 +68,6 @@ double GetWeights(const Graph::NodeType &node) {
|
|||
auto cost_ptr = std::make_shared<CostBiasAdd>();
|
||||
|
||||
return cost_ptr->GetMinCostIn();
|
||||
} else if (op.op_type == OperatorType::kRecBatchNorm) {
|
||||
// For BatchNorm
|
||||
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
||||
|
||||
return cost_ptr->GetMinCostIn(op);
|
||||
} else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
|
||||
op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
|
||||
op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
|
||||
|
@ -83,8 +77,10 @@ double GetWeights(const Graph::NodeType &node) {
|
|||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetMinCostIn();
|
||||
} else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
|
||||
// For unknown type
|
||||
} else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU ||
|
||||
op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecSoftmax ||
|
||||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For unprocessed type
|
||||
return 0.0;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
|
||||
|
@ -147,9 +143,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|||
auto cost_ptr = std::make_shared<CostTensorAdd>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
} else if (node.apply.op_type == OperatorType::kRecReLU || node.apply.op_type == OperatorType::kRecSoftmax ||
|
||||
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For Softmax & Activation
|
||||
} else if (node.apply.op_type == OperatorType::kRecReLU) {
|
||||
// For Activation
|
||||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
|
@ -162,11 +157,6 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|||
// For BiasAdd
|
||||
auto cost_ptr = std::make_shared<CostBiasAdd>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
} else if (node.apply.op_type == OperatorType::kRecBatchNorm) {
|
||||
// For BatchNorm
|
||||
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
} else if (node.apply.op_type == OperatorType::kRecOneHot || node.apply.op_type == OperatorType::kRecLog ||
|
||||
node.apply.op_type == OperatorType::kRecExp || node.apply.op_type == OperatorType::kRecAdd ||
|
||||
|
@ -177,8 +167,10 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
} else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
|
||||
// For unknown type
|
||||
} else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU ||
|
||||
node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecSoftmax ||
|
||||
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For unprocessed type
|
||||
StrategyRec default_strategy;
|
||||
return default_strategy;
|
||||
} else {
|
||||
|
|
|
@ -410,9 +410,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (!IsAutoParallelCareNode(cnode)) {
|
||||
// Needed by rec_parser
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
if (prim->name() == TUPLE_GETITEM) {
|
||||
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
|
||||
if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
|
||||
auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
|
||||
if (prev_cnode != nullptr) {
|
||||
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -473,9 +475,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (!IsAutoParallelCareNode(cnode)) {
|
||||
// Needed by rec_parser
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
if (prim->name() == TUPLE_GETITEM) {
|
||||
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
|
||||
if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
|
||||
auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
|
||||
if (prev_cnode != nullptr) {
|
||||
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -1100,6 +1104,26 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
|
|||
return input_tensor_names;
|
||||
}
|
||||
|
||||
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
|
||||
auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
|
||||
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
|
||||
while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
|
||||
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
|
||||
if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
|
||||
return nullptr;
|
||||
}
|
||||
prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
|
||||
}
|
||||
return prev_cnode;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
|
||||
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
|
||||
|
|
|
@ -57,6 +57,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
|
|||
|
||||
std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
|
||||
std::vector<std::vector<std::string>> input_tensor_names);
|
||||
|
||||
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_
|
||||
|
|
Loading…
Reference in New Issue