forked from mindspore-Ecosystem/mindspore

Updating the redistribution cost in D-Rec cost model

parent 5a54daa3ff
commit 4d90a1dbd2

@@ -169,4 +169,9 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_f
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
+<<<<<<< HEAD
 mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::Run
+=======
+mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::GetWeights
+mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
+>>>>>>> Updating the redistribution cost in D-Rec cost model

@@ -576,7 +576,7 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == LAYER_NORM) {
     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
-  } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == DROPOUT) || (type == BATCH_MATMUL)) {
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else if (type == VIRTUAL_DATA_SET) {
     if (ParallelContext::GetInstance()->full_batch()) {

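For reference, the branch above dispatches on the operator type and returns whatever strategy the matching Prepare*/Make* helper produces. Below is a minimal, self-contained sketch of that dispatch pattern; the stub helpers and the dispatch table are invented stand-ins, not the rec_generate_strategy.cc API.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the strategy-preparation helpers; the real
// functions take the graph and operator lists and return a strategy object.
std::string PrepareMatMulStub() { return "matmul strategy"; }
std::string MakeDataParallelStub() { return "data-parallel strategy"; }

int main() {
  // Map from operator type to the helper that prepares its strategy.
  std::unordered_map<std::string, std::function<std::string()>> dispatch = {
      {"MatMul", PrepareMatMulStub},
      {"SparseSoftmaxCrossEntropyWithLogits", MakeDataParallelStub},
  };
  for (const auto &type : {"MatMul", "SparseSoftmaxCrossEntropyWithLogits"}) {
    std::cout << type << " -> " << dispatch.at(type)() << "\n";
  }
  return 0;
}
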
@@ -866,9 +866,7 @@ Dimensions PrepareIncompingArithmeticOpeartorInputStrategy(const std::vector<std
 Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t incoming_op_index) {
   Dimensions s;
   if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) {
     return s;
   }

   if (ops[incoming_op_index]->type() == GATHERV2) {
     auto pos = ops[incoming_op_index]->name().find("Info");
     if (pos == std::string::npos) {

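The GATHERV2 branch above inspects the operator name for an "Info" marker before deciding how to derive the input strategy. The hunk only shows the find("Info")/npos check; what follows is a small standalone sketch of that kind of name check, where StripInfoSuffix is a hypothetical helper rather than a MindSpore function.

#include <iostream>
#include <string>

// Minimal sketch of the name check: operator names are assumed to carry an
// "Info" suffix (e.g. "GatherV2Info"); if the marker is missing, the name is
// left unchanged, mirroring the npos branch in the hunk above.
std::string StripInfoSuffix(const std::string &op_name) {
  auto pos = op_name.find("Info");
  if (pos == std::string::npos) {
    return op_name;  // no "Info" marker found; keep the name unchanged
  }
  return op_name.substr(0, pos);
}

int main() {
  std::cout << StripInfoSuffix("GatherV2Info") << "\n";  // prints "GatherV2"
  std::cout << StripInfoSuffix("ReLU-op12") << "\n";     // prints "ReLU-op12"
  return 0;
}
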
@@ -15,7 +15,6 @@
  */

 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"

 #include <algorithm>
 #include <cmath>
 #include <memory>

@@ -85,7 +84,7 @@ double GetWeights(const Graph::NodeType &node) {

     return cost_ptr->GetMaxCostIn();
   } else if (op.op_type == OperatorType::kRecUnkownType) {
-    // For Unkown type
+    // For Unknown type
     return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";

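GetWeights falls back to a weight of 0.0 for kRecUnkownType and raises an exception for any type it does not recognize. A compact standalone sketch of that fallback pattern, with an invented OpKind enum and placeholder cost in place of the real OperatorType and cost objects:

#include <iostream>
#include <stdexcept>

// Hypothetical operator kinds standing in for the real OperatorType values.
enum class OpKind { kMatMul, kUnknown, kUnsupported };

// Known ops delegate to a cost object, unknown ops cost nothing, and anything
// else is treated as an error, mirroring the structure of GetWeights above.
double GetWeightSketch(OpKind kind) {
  if (kind == OpKind::kMatMul) {
    return 42.0;  // placeholder for cost_ptr->GetMaxCostIn()
  } else if (kind == OpKind::kUnknown) {
    return 0.0;   // unknown types do not influence the partition order
  }
  throw std::runtime_error("GetWeightSketch: unsupported operator type");
}

int main() {
  std::cout << GetWeightSketch(OpKind::kMatMul) << "\n";
  std::cout << GetWeightSketch(OpKind::kUnknown) << "\n";
  return 0;
}
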
@@ -183,7 +182,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>();
     return cost_ptr->GetOptimalStr(node);
   } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
-    // For Unkown type
+    // For Unknown type
     StrategyRec default_strategy;
     return default_strategy;
   } else {

@@ -191,7 +190,21 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
   }
 }

-// Parttion graph into all devices.
+StrategyRec GetOneLoopStrategy(size_t op_inputs_num, StrategyRec old_str, StrategyRec new_str) {
+  for (size_t i = 0; i < op_inputs_num; i++) {
+    if (old_str.inputTensor[i].str_n != 0 && old_str.inputTensor[i].str_c != 0 && old_str.inputTensor[i].str_h != 0 &&
+        old_str.inputTensor[i].str_w != 0) {
+      new_str.inputTensor[i].str_n = new_str.inputTensor[i].str_n / old_str.inputTensor[i].str_n;
+      new_str.inputTensor[i].str_c = new_str.inputTensor[i].str_c / old_str.inputTensor[i].str_c;
+      new_str.inputTensor[i].str_h = new_str.inputTensor[i].str_h / old_str.inputTensor[i].str_h;
+      new_str.inputTensor[i].str_w = new_str.inputTensor[i].str_w / old_str.inputTensor[i].str_w;
+    }
+  }
+
+  return new_str;
+}
+
+// Partition graph into all devices.
 Status PartitionForAllDevices(const size_t num_device, const double device_memory,
                               const std::shared_ptr<Graph> &graph) {
   if (num_device < 1) {

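The new GetOneLoopStrategy divides each cut factor of new_str by the corresponding factor of old_str (skipping inputs whose old factors contain a zero), which isolates the strategy chosen in the current partitioning loop. Below is a standalone sketch of that arithmetic, assuming the str_* fields are multiplicative cut fractions accumulated across loops; CutFactors is an invented stand-in for the real tensor-strategy type.

#include <iostream>

// Simplified stand-in for the per-tensor cut factors carried by StrategyRec
// (the real inputTensor entries hold str_n/str_c/str_h/str_w fields).
struct CutFactors {
  float n, c, h, w;
};

// Dividing the strategy accumulated after this loop by the one accumulated
// before it yields the cut applied by this loop alone; a zero in the old
// factors leaves the new strategy untouched, as in the hunk above.
CutFactors OneLoopCut(const CutFactors &old_cut, CutFactors new_cut) {
  if (old_cut.n != 0 && old_cut.c != 0 && old_cut.h != 0 && old_cut.w != 0) {
    new_cut.n /= old_cut.n;
    new_cut.c /= old_cut.c;
    new_cut.h /= old_cut.h;
    new_cut.w /= old_cut.w;
  }
  return new_cut;
}

int main() {
  // Before this loop the tensor was already cut in half along h; afterwards
  // the accumulated cut along h is 1/4, so this loop alone contributed 1/2.
  CutFactors before{1.0f, 1.0f, 0.5f, 1.0f};
  CutFactors after{1.0f, 1.0f, 0.25f, 1.0f};
  std::cout << "one-loop cut along h: " << OneLoopCut(before, after).h << "\n";
  return 0;
}
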
@@ -227,15 +240,21 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor

     Graph::NodeType &node_ptr = graph->nodes[index];

+    // 2-parts partitioning StrategyRec of the last loop
+    StrategyRec old_str = graph->nodes[index].apply.str;
+
     // Serch optimal strategy to cut this operator. And store the result optimal strategy in graph.
     graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph);

+    // Get Current 2-parts partitioning strategy of this loop
+    size_t op_inputs_num = graph->nodes[index].node_in.size();
+    StrategyRec one_loop_strategyrec = GetOneLoopStrategy(op_inputs_num, old_str, graph->nodes[index].apply.str);
+
     // Apply OP Strategy to Tensor Strategy.
     graph->nodes[index] = ApplyStrToTensor(node_ptr);

     // Note down the node name and its strategy in this loop.
-    auto node_name_to_str =
-      std::pair<std::string, StrategyRec>(graph->nodes[index].name, graph->nodes[index].apply.str);
+    auto node_name_to_str = std::pair<std::string, StrategyRec>(graph->nodes[index].name, one_loop_strategyrec);
     node_name_to_strategy.push_back(node_name_to_str);
   }
 }

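The loop above now records one_loop_strategyrec under each node name instead of the accumulated strategy, presumably so that the redistribution cost derived from node_name_to_strategy reflects only the cut taken in the current loop. The following is a self-contained mock of that bookkeeping; MockStrategy, PartitionStep, and the node name are illustrative inventions, not the MindSpore classes.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for one dimension of a recursive-partitioning strategy.
struct MockStrategy {
  float cut;  // accumulated cut fraction on one dimension (1.0 = uncut)
};

// Pretend partitioning step: the optimizer decides to halve the node again.
MockStrategy PartitionStep(const MockStrategy &current) {
  MockStrategy next;
  next.cut = current.cut * 0.5f;
  return next;
}

int main() {
  std::vector<std::pair<std::string, MockStrategy>> name_to_strategy;
  MockStrategy node_strategy;
  node_strategy.cut = 1.0f;
  for (int loop = 0; loop < 2; ++loop) {
    MockStrategy old_strategy = node_strategy;     // strategy before this loop
    node_strategy = PartitionStep(node_strategy);  // strategy after this loop
    MockStrategy one_loop;                         // increment of this loop only
    one_loop.cut = node_strategy.cut / old_strategy.cut;
    name_to_strategy.emplace_back("MatMul-op0", one_loop);
  }
  for (const auto &entry : name_to_strategy) {
    std::cout << entry.first << " one-loop cut: " << entry.second.cut << "\n";
  }
  return 0;
}
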
@@ -46,6 +46,8 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);

 Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);

+StrategyRec GetOneLoopStrategy(size_t op_inputs_num, StrategyRec old_str, StrategyRec new_str);
+
 size_t GetDataTypeSize(const TensorType &type);
 }  // namespace parallel
 }  // namespace mindspore