forked from mindspore-Ecosystem/mindspore
improve rec-prog str generator
This commit is contained in:
parent
7874e96b9b
commit
c71234f383
|
@ -27,44 +27,27 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
|
||||
const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
|
||||
const std::shared_ptr<std::vector<size_t>> index_list,
|
||||
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
|
||||
MaskNoSupportedOps(graph);
|
||||
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (mask_special_ops) {
|
||||
MaskSpecialOps(graph);
|
||||
}
|
||||
for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
|
||||
auto type = ops[iter_ops]->type();
|
||||
size_t iter_nodes = index_list->at(ops_nodes_list->at(iter_ops));
|
||||
std::vector<std::vector<int32_t>> stra;
|
||||
iter_nodes = IterNodes(ops_nodes_list, index_list, eli_list, iter_ops, iter_nodes);
|
||||
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
|
||||
std::vector<int32_t> s = PrepareStrategy(graph, ops, type, iter_ops, iter_nodes, iter_op_inputs);
|
||||
stra.push_back(s);
|
||||
stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs));
|
||||
}
|
||||
StrategyPtr sp = std::make_shared<Strategy>(0, stra);
|
||||
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
|
||||
}
|
||||
}
|
||||
|
||||
size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
|
||||
const std::shared_ptr<std::vector<size_t>> index_list,
|
||||
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
|
||||
size_t iter_nodes) {
|
||||
if (iter_nodes > SIZE_MAX / 2) {
|
||||
for (size_t iter_eli = 0; iter_eli < eli_list->size(); iter_eli++) {
|
||||
if (eli_list->at(iter_eli)[0] == ops_nodes_list->at(iter_ops)) {
|
||||
iter_nodes = index_list->at(eli_list->at(iter_eli)[1]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return iter_nodes;
|
||||
}
|
||||
|
||||
void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s) {
|
||||
auto attrs = ops[iter_ops]->attrs();
|
||||
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
auto attrs = ops[iter_nodes]->attrs();
|
||||
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
|
||||
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
|
||||
if (transpose_a && (iter_op_inputs == 0)) {
|
||||
|
@ -77,10 +60,12 @@ void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::sh
|
|||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, size_t iter_op_inputs,
|
||||
std::vector<int32_t> s) {
|
||||
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (iter_op_inputs == 0) {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
|
||||
|
@ -92,20 +77,24 @@ void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes,
|
|||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s) {
|
||||
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (iter_op_inputs == 0) {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
|
||||
} else {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s) {
|
||||
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (iter_op_inputs == 0) {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c));
|
||||
|
@ -114,97 +103,133 @@ void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, cons
|
|||
} else {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s) {
|
||||
if (iter_op_inputs == 0) {
|
||||
s.push_back(g_device_manager->DeviceNum());
|
||||
s.push_back(1);
|
||||
} else {
|
||||
s.push_back(g_device_manager->DeviceNum());
|
||||
}
|
||||
}
|
||||
|
||||
void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs, std::vector<int32_t> s) {
|
||||
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
|
||||
if (iter_op_inputs == 0) {
|
||||
for (size_t j = 0; j < origin_strategy->GetInputDim()[0].size(); j++) {
|
||||
s.push_back(1);
|
||||
}
|
||||
} else {
|
||||
for (size_t k = 0; k < origin_strategy->GetInputDim()[iter_op_inputs].size(); k++) {
|
||||
s.push_back(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
|
||||
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (type == MATMUL) {
|
||||
PrepareMatMul(graph, ops, iter_ops, iter_nodes, iter_op_inputs, s);
|
||||
} else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_n));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_c));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w));
|
||||
} else if (type == CONV2D) {
|
||||
PrepareConv2D(graph, iter_nodes, iter_op_inputs, s);
|
||||
} else if (type == BIAS_ADD) {
|
||||
PrepareBiasAdd(graph, iter_nodes, iter_op_inputs, s);
|
||||
} else if (type == RESHAPE) {
|
||||
if (iter_op_inputs == 0) {
|
||||
s.push_back(g_device_manager->DeviceNum());
|
||||
s.push_back(1);
|
||||
s.push_back(1);
|
||||
s.push_back(1);
|
||||
s.push_back(1);
|
||||
} else if (type == RELU) {
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_n));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_c));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_w));
|
||||
} else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
|
||||
PrepareBN(graph, iter_nodes, iter_op_inputs, s);
|
||||
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
PrepareSparse(iter_op_inputs, s);
|
||||
} else {
|
||||
RefillOrigin(ops, iter_ops, iter_op_inputs, s);
|
||||
s.push_back(g_device_manager->DeviceNum());
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void MaskNoSupportedOps(const std::shared_ptr<Graph> graph) {
|
||||
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (ops.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
|
||||
}
|
||||
if (iter_ops >= ops.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
|
||||
}
|
||||
if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size())
|
||||
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
|
||||
size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
|
||||
for (size_t dim = 0; dim < input_size; dim++) {
|
||||
s.push_back(1);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
|
||||
const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
|
||||
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
|
||||
return s;
|
||||
}
|
||||
|
||||
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t iter_op_inputs) {
|
||||
std::vector<int32_t> s;
|
||||
if (ops.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
|
||||
}
|
||||
if (iter_ops >= ops.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
|
||||
}
|
||||
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
|
||||
if (iter_op_inputs >= origin_strategy->GetInputDim().size())
|
||||
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
|
||||
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
|
||||
for (size_t dim = 0; dim < input_size; dim++) {
|
||||
if (dim == 0 && input_size == 4) {
|
||||
size_t max_device_num = g_device_manager->DeviceNum();
|
||||
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
|
||||
s.push_back(std::min(max_device_num, target_tensor_batch));
|
||||
} else {
|
||||
s.push_back(1);
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs) {
|
||||
if (ops.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
|
||||
}
|
||||
if (iter_ops >= ops.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
|
||||
}
|
||||
auto type = ops[iter_ops]->type();
|
||||
if (type == MATMUL) {
|
||||
return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
|
||||
} else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
|
||||
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
|
||||
} else if (type == CONV2D) {
|
||||
return PrepareConv2D(graph, iter_ops, iter_op_inputs);
|
||||
} else if (type == BIAS_ADD) {
|
||||
return PrepareBiasAdd(graph, iter_ops, iter_op_inputs);
|
||||
} else if (type == RESHAPE) {
|
||||
return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
|
||||
} else if (type == RELU) {
|
||||
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
|
||||
} else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
|
||||
return PrepareBN(graph, iter_ops, iter_op_inputs);
|
||||
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
return PrepareSparse(iter_op_inputs);
|
||||
} else {
|
||||
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
|
||||
}
|
||||
}
|
||||
|
||||
void MaskSpecialOps(std::shared_ptr<Graph> graph) {
|
||||
size_t iter_nodes = graph->nodes.size();
|
||||
for (size_t i = 0; i < iter_nodes; i++) {
|
||||
if (0 == graph->nodes[i].info) {
|
||||
Graph::NodeType &node = graph->nodes[i];
|
||||
Graph::NodeType &node = graph->nodes[i];
|
||||
|
||||
if (node.apply.op_type == 1) { // For Convolution
|
||||
// cover input tensor strategy
|
||||
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.apply.arguments[0].tensor_str.str_c = 1;
|
||||
node.apply.arguments[0].tensor_str.str_h = 1;
|
||||
node.apply.arguments[0].tensor_str.str_w = 1;
|
||||
// cover filter tensor strategy
|
||||
node.apply.arguments[1].tensor_str.str_n = 1;
|
||||
node.apply.arguments[1].tensor_str.str_c = 1;
|
||||
node.apply.arguments[1].tensor_str.str_h = 1;
|
||||
node.apply.arguments[1].tensor_str.str_w = 1;
|
||||
} else if (node.apply.op_type == 8) { // For BN
|
||||
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.apply.arguments[0].tensor_str.str_c = 1;
|
||||
node.apply.arguments[0].tensor_str.str_h = 1;
|
||||
node.apply.arguments[0].tensor_str.str_w = 1;
|
||||
// cover 1-d argument blobs
|
||||
node.apply.arguments[1].tensor_str.str_w = 1;
|
||||
node.apply.arguments[2].tensor_str.str_w = 1;
|
||||
node.apply.arguments[3].tensor_str.str_w = 1;
|
||||
node.apply.arguments[4].tensor_str.str_w = 1;
|
||||
} else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits
|
||||
node.tensor_parm.tensor_str.str_h = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.tensor_parm.tensor_str.str_w = 1;
|
||||
}
|
||||
if (node.apply.op_type == 1) { // For Convolution
|
||||
// cover input tensor strategy
|
||||
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.apply.arguments[0].tensor_str.str_c = 1;
|
||||
node.apply.arguments[0].tensor_str.str_h = 1;
|
||||
node.apply.arguments[0].tensor_str.str_w = 1;
|
||||
// cover filter tensor strategy
|
||||
node.apply.arguments[1].tensor_str.str_n = 1;
|
||||
node.apply.arguments[1].tensor_str.str_c = 1;
|
||||
node.apply.arguments[1].tensor_str.str_h = 1;
|
||||
node.apply.arguments[1].tensor_str.str_w = 1;
|
||||
} else if (node.apply.op_type == 8) { // For BN
|
||||
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.apply.arguments[0].tensor_str.str_c = 1;
|
||||
node.apply.arguments[0].tensor_str.str_h = 1;
|
||||
node.apply.arguments[0].tensor_str.str_w = 1;
|
||||
// cover 1-d argument blobs
|
||||
node.apply.arguments[1].tensor_str.str_n = 1;
|
||||
node.apply.arguments[2].tensor_str.str_c = 1;
|
||||
node.apply.arguments[3].tensor_str.str_h = 1;
|
||||
node.apply.arguments[4].tensor_str.str_w = 1;
|
||||
} else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits
|
||||
node.tensor_parm.tensor_str.str_h = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
|
||||
node.tensor_parm.tensor_str.str_w = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,29 +27,28 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
void GenerateStrategy(const std::shared_ptr<Graph> graph, std::vector<std::shared_ptr<OperatorInfo>> ops,
|
||||
const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
|
||||
const std::shared_ptr<std::vector<size_t>> index_list,
|
||||
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
|
||||
void PrepareMatMul(const std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, std::vector<int32_t> s);
|
||||
void PrepareConv2D(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s);
|
||||
void PrepareBiasAdd(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s);
|
||||
void PrepareBN(const std::shared_ptr<Graph> graph, const size_t iter_nodes, const size_t iter_op_inputs,
|
||||
std::vector<int32_t> s);
|
||||
void PrepareSparse(const size_t iter_op_inputs, std::vector<int32_t> s);
|
||||
void RefillOrigin(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs, std::vector<int32_t> s);
|
||||
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::string &type,
|
||||
const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs);
|
||||
size_t IterNodes(const std::shared_ptr<std::vector<size_t>> ops_nodes_list,
|
||||
const std::shared_ptr<std::vector<size_t>> index_list,
|
||||
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const size_t iter_ops,
|
||||
size_t iter_nodes);
|
||||
void MaskNoSupportedOps(const std::shared_ptr<Graph> graph);
|
||||
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops);
|
||||
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs);
|
||||
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
|
||||
const size_t iter_op_inputs);
|
||||
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, const size_t iter_op_inputs);
|
||||
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_op_inputs);
|
||||
void MaskSpecialOps(std::shared_ptr<Graph> graph);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
|
||||
|
|
|
@ -931,12 +931,9 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
|
|||
}
|
||||
|
||||
std::shared_ptr<std::vector<size_t>> ops_nodes_list(new std::vector<size_t>);
|
||||
std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
|
||||
std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
|
||||
|
||||
std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names, ops_nodes_list);
|
||||
|
||||
graph = EliminateGraph(graph, eli_list, index_list);
|
||||
size_t num_device = g_device_manager->DeviceNum();
|
||||
if (PartitionForAllDevices(num_device, graph) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
|
||||
|
@ -945,7 +942,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
GenerateStrategy(graph, ops, ops_nodes_list, index_list, eli_list);
|
||||
bool mask_special_ops = true;
|
||||
GenerateStrategy(graph, mask_special_ops, ops);
|
||||
|
||||
if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
|
||||
MS_LOG(INFO) << "Init selected strategy succeeded.";
|
||||
|
|
Loading…
Reference in New Issue