improve rec-prog str generator

This commit is contained in:
ch-l 2020-04-16 22:30:04 +02:00 committed by klchai
parent 7874e96b9b
commit c71234f383
3 changed files with 163 additions and 141 deletions

View File

@ -27,44 +27,27 @@
namespace mindspore {
namespace parallel {
void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
MaskNoSupportedOps(graph);
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
MS_EXCEPTION_IF_NULL(graph);
if (mask_special_ops) {
MaskSpecialOps(graph);
}
for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
auto type = ops[iter_ops]->type();
size_t iter_nodes = index_list->at(ops_nodes_list->at(iter_ops));
std::vector<std::vector<int32_t>> stra;
iter_nodes = IterNodes(ops_nodes_list, index_list, eli_list, iter_ops, iter_nodes);
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
std::vector<int32_t> s = PrepareStrategy(graph, ops, type, iter_ops, iter_nodes, iter_op_inputs);
stra.push_back(s);
stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs));
}
StrategyPtr sp = std::make_shared<Strategy>(0, stra);
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
}
}
size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
size_t iter_nodes) {
if (iter_nodes > SIZE_MAX / 2) {
for (size_t iter_eli = 0; iter_eli < eli_list->size(); iter_eli++) {
if (eli_list->at(iter_eli)[0] == ops_nodes_list->at(iter_ops)) {
iter_nodes = index_list->at(eli_list->at(iter_eli)[1]);
break;
}
}
}
return iter_nodes;
}
void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s) {
auto attrs = ops[iter_ops]->attrs();
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
auto attrs = ops[iter_nodes]->attrs();
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
if (transpose_a && (iter_op_inputs == 0)) {
@ -77,10 +60,12 @@ void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::sh
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
}
return s;
}
void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, size_t iter_op_inputs,
std::vector<int32_t> s) {
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
size_t iter_op_inputs) {
std::vector<int32_t> s;
if (iter_op_inputs == 0) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
@ -92,20 +77,24 @@ void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes,
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
}
return s;
}
void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s) {
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (iter_op_inputs == 0) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
}
return s;
}
void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s) {
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (iter_op_inputs == 0) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
@ -114,97 +103,133 @@ void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, cons
} else {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
}
return s;
}
void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s) {
if (iter_op_inputs == 0) {
s.push_back(g_device_manager->DeviceNum());
s.push_back(1);
} else {
s.push_back(g_device_manager->DeviceNum());
}
}
void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs, std::vector<int32_t> s) {
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
if (iter_op_inputs == 0) {
for (size_t j = 0; j < origin_strategy->GetInputDim()[0].size(); j++) {
s.push_back(1);
}
} else {
for (size_t k = 0; k < origin_strategy->GetInputDim()[iter_op_inputs].size(); k++) {
s.push_back(1);
}
}
}
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs) {
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (type == MATMUL) {
PrepareMatMul(graph, ops, iter_ops, iter_nodes, iter_op_inputs, s);
} else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_c));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (type == CONV2D) {
PrepareConv2D(graph, iter_nodes, iter_op_inputs, s);
} else if (type == BIAS_ADD) {
PrepareBiasAdd(graph, iter_nodes, iter_op_inputs, s);
} else if (type == RESHAPE) {
if (iter_op_inputs == 0) {
s.push_back(g_device_manager->DeviceNum());
s.push_back(1);
s.push_back(1);
s.push_back(1);
s.push_back(1);
} else if (type == RELU) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_c));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_w));
} else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
PrepareBN(graph, iter_nodes, iter_op_inputs, s);
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
PrepareSparse(iter_op_inputs, s);
} else {
RefillOrigin(ops, iter_ops, iter_op_inputs, s);
s.push_back(g_device_manager->DeviceNum());
}
return s;
}
void MaskNoSupportedOps(const std::shared_ptr<Graph> graph) {
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size())
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
s.push_back(1);
}
return s;
}
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
return s;
}
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
if (iter_op_inputs >= origin_strategy->GetInputDim().size())
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
if (dim == 0 && input_size == 4) {
size_t max_device_num = g_device_manager->DeviceNum();
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
s.push_back(std::min(max_device_num, target_tensor_batch));
} else {
s.push_back(1);
}
}
return s;
}
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs) {
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
auto type = ops[iter_ops]->type();
if (type == MATMUL) {
return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
} else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
} else if (type == CONV2D) {
return PrepareConv2D(graph, iter_ops, iter_op_inputs);
} else if (type == BIAS_ADD) {
return PrepareBiasAdd(graph, iter_ops, iter_op_inputs);
} else if (type == RESHAPE) {
return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
} else if (type == RELU) {
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
} else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
return PrepareBN(graph, iter_ops, iter_op_inputs);
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
return PrepareSparse(iter_op_inputs);
} else {
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
}
}
void MaskSpecialOps(std::shared_ptr<Graph> graph) {
size_t iter_nodes = graph->nodes.size();
for (size_t i = 0; i < iter_nodes; i++) {
if (0 == graph->nodes[i].info) {
Graph::NodeType &node = graph->nodes[i];
Graph::NodeType &node = graph->nodes[i];
if (node.apply.op_type == 1) { // For Convolution
// cover input tensor strategy
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.apply.arguments[0].tensor_str.str_c = 1;
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
// cover filter tensor strategy
node.apply.arguments[1].tensor_str.str_n = 1;
node.apply.arguments[1].tensor_str.str_c = 1;
node.apply.arguments[1].tensor_str.str_h = 1;
node.apply.arguments[1].tensor_str.str_w = 1;
} else if (node.apply.op_type == 8) { // For BN
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.apply.arguments[0].tensor_str.str_c = 1;
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
// cover 1-d argument blobs
node.apply.arguments[1].tensor_str.str_w = 1;
node.apply.arguments[2].tensor_str.str_w = 1;
node.apply.arguments[3].tensor_str.str_w = 1;
node.apply.arguments[4].tensor_str.str_w = 1;
} else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits
node.tensor_parm.tensor_str.str_h = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.tensor_parm.tensor_str.str_w = 1;
}
if (node.apply.op_type == 1) { // For Convolution
// cover input tensor strategy
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.apply.arguments[0].tensor_str.str_c = 1;
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
// cover filter tensor strategy
node.apply.arguments[1].tensor_str.str_n = 1;
node.apply.arguments[1].tensor_str.str_c = 1;
node.apply.arguments[1].tensor_str.str_h = 1;
node.apply.arguments[1].tensor_str.str_w = 1;
} else if (node.apply.op_type == 8) { // For BN
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.apply.arguments[0].tensor_str.str_c = 1;
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
// cover 1-d argument blobs
node.apply.arguments[1].tensor_str.str_n = 1;
node.apply.arguments[2].tensor_str.str_c = 1;
node.apply.arguments[3].tensor_str.str_h = 1;
node.apply.arguments[4].tensor_str.str_w = 1;
} else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits
node.tensor_parm.tensor_str.str_h = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.tensor_parm.tensor_str.str_w = 1;
}
}
}

View File

@ -27,29 +27,28 @@
namespace mindspore {
namespace parallel {
void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, std::vector<int32_t> s);
void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s);
void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s);
void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
std::vector<int32_t> s);
void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s);
void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs, std::vector<int32_t> s);
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs);
size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
size_t iter_nodes);
void MaskNoSupportedOps(const std::shared_ptr<Graph> graph);
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops);
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs);
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs);
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs);
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_op_inputs);
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs);
void MaskSpecialOps(std::shared_ptr<Graph> graph);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

View File

@ -931,12 +931,9 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
}
std::shared_ptr<std::vector<size_t>> ops_nodes_list(new std::vector<size_t>);
std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names, ops_nodes_list);
graph = EliminateGraph(graph, eli_list, index_list);
size_t num_device = g_device_manager->DeviceNum();
if (PartitionForAllDevices(num_device, graph) == SUCCESS) {
MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
@ -945,7 +942,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
return FAILED;
}
GenerateStrategy(graph, ops, ops_nodes_list, index_list, eli_list);
bool mask_special_ops = true;
GenerateStrategy(graph, mask_special_ops, ops);
if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
MS_LOG(INFO) << "Init selected strategy succeeded.";