From 300dd2971cb69c3dfbddfc47777f0b2e957597a1 Mon Sep 17 00:00:00 2001
From: hongxing
Date: Fri, 26 Jun 2020 13:59:16 +0200
Subject: [PATCH] merge master code to r0.5

---
 .../auto_parallel/rec_core/rec_cost.cc        | 125 +++++-----
 .../auto_parallel/rec_core/rec_cost.h         |  54 ++---
 .../rec_core/rec_generate_strategy.cc         | 214 +++++++++++++-----
 .../rec_core/rec_generate_strategy.h          |  42 ++--
 .../auto_parallel/rec_core/rec_graph.h        |   4 +-
 .../auto_parallel/rec_core/rec_parse_graph.cc |  44 ++--
 .../auto_parallel/rec_core/rec_parse_graph.h  |  19 +-
 .../auto_parallel/rec_core/rec_partition.cc   |  55 +++--
 .../auto_parallel/rec_core/rec_partition.h    |   8 +-
 9 files changed, 340 insertions(+), 225 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
index bb252466082..9fb79ceee42 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
@@ -28,7 +28,6 @@
 
 namespace mindspore {
 namespace parallel {
-#define DOUBLE_MAX (std::numeric_limits<double>::max)()
 
 // Compute redistributed cost
 double CostRedis(const Graph::NodeType &node,
@@ -621,75 +620,50 @@ StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec
       break;
 
     default:
-      MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
+      MS_LOG(EXCEPTION) << "Failure: Common failed.";
   }
   return str;
 }
 
-// Get weight for BN
-double CostBatchNorm::GetMinCostIn(const OperatorRec &op) {
-  int tensor = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
-
-  std::vector<double> cost_in;
-  cost_in.push_back(StrDimB(tensor) * 1.2);
-  cost_in.push_back(DOUBLE_MAX);
-  cost_in.push_back(StrDimH(tensor) * 1.2);
-  cost_in.push_back(StrDimW(tensor) * 1.2);
-
-  return *min_element(cost_in.begin(), cost_in.end());
-}
-
-// Get optimal strategy for BN
-StrategyRec CostBatchNorm::GetOptimalStr(const Graph::NodeType &node,
-                                         const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                         const Graph &graph) {
+// Get optimal strategy for BatchParallel OPs
+StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;
-
-  int tensor_filter_n = static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
-  int tensor_filter_c = static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
-  int tensor_filter_h = static_cast<int>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
-  int tensor_filter_w = static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
-
-  int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;
-
-  int output_tensor_h = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
-  int output_tensor_w = static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
-  int output_tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
+  int tensor_n = static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
+  int tensor_c = static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
+  int tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
+  int tensor_w = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
 
   std::vector<double> cost_op;
-  std::vector<std::vector<float>> mode;
 
-  if (output_tensor_n < 2 || output_tensor_n % 2 != 0) {
+  if (tensor_n < 2 || tensor_n % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
+    cost_op.push_back(cost_in_);
   }
 
-  cost_op.push_back(DOUBLE_MAX);
-
-  if (output_tensor_h < 2 || output_tensor_h % 2 != 0) {
+  if (tensor_c < 2 || tensor_c % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimH(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, graph));
+    cost_op.push_back(cost_in_);
   }
 
-  if (output_tensor_w < 2 || output_tensor_w % 2 != 0) {
+  if (tensor_h < 2 || tensor_h % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimW(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{1, 1, 1, 0.5}, {1, 1, 1, 1}, {1, 1, 1, 0.5}}, graph));
+    cost_op.push_back(cost_in_);
+  }
+
+  if (tensor_w < 2 || tensor_w % 2 != 0) {
+    cost_op.push_back(DOUBLE_MAX);
+  } else {
+    cost_op.push_back(cost_in_);
   }
 
   return ChoseStr(cost_op, node.apply.str);
 }
 
-// Chose strategy for BatchNorm
-StrategyRec CostBatchNorm::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
+// Chose strategy for BatchParallel op
+StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
   uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
     return str;
@@ -700,36 +674,75 @@ StrategyRec CostBatchNorm::ChoseStr(const std::vector<double> &cost_op, Strategy
       str.inputTensor[0].str_n /= 2.0;
       str.outputTensor.str_n /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_b_;
+      str.cost = str.cost + cost_in_;
       break;
 
     case 1:
       str.inputTensor[0].str_c /= 2.0;
-      str.inputTensor[1].str_c /= 2.0;
-      str.inputTensor[2].str_c /= 2.0;
-      str.inputTensor[3].str_c /= 2.0;
-      str.inputTensor[4].str_c /= 2.0;
       str.outputTensor.str_c /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_c_;
+      str.cost = str.cost + cost_in_;
       break;
 
     case 2:
       str.inputTensor[0].str_h /= 2.0;
       str.outputTensor.str_h /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_h_;
+      str.cost = str.cost + cost_in_;
       break;
 
     case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
-     str.cost = str.cost + cost_in_w_;
+     str.cost = str.cost + cost_in_;
      break;
 
    default:
-     MS_LOG(EXCEPTION) << "Failure: CostBatchNorm failed.";
+     MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed.";
+  }
+  return str;
+}
+
+// Chose strategy for CostSoftmaxCrossEntropyWithLogits
+StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
+  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
+  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
+    return str;
+  }
+
+  switch (min_position) {
+    case 0:
+      str.inputTensor[0].str_n /= 2.0;
+      str.inputTensor[1].str_n /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    case 1:
+      str.inputTensor[0].str_c /= 2.0;
+      str.inputTensor[1].str_c /= 2.0;
+      str.cut_counter += 1;
+      str.cost = str.cost + cost_in_;
+      break;
+
+    case 2:
+      str.inputTensor[0].str_h /= 2.0;
+
str.inputTensor[1].str_h /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed."; } return str; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h index c45c81aca09..fb4fc27164c 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h @@ -28,6 +28,8 @@ namespace mindspore { namespace parallel { +#define DOUBLE_MAX (std::numeric_limits::max)() + double CostRedis(const Graph::NodeType &node, const std::vector> &node_name_to_strategy, const std::vector> &mode, const Graph &graph); @@ -195,7 +197,6 @@ class CostTensorAdd : public CostCommon { }; // all the following operation are element-wise and have the same cost -class CostOneHot : public CostCommon {}; class CostReLU : public CostCommon {}; class CostLog : public CostCommon {}; class CostExp : public CostCommon {}; @@ -206,50 +207,27 @@ class CostDiv : public CostCommon {}; class CostSqueeze : public CostCommon {}; class CostCast : public CostCommon {}; -// class BatchNorm is used to compute the cost of BatchNorm operator. -class CostBatchNorm { +// class BatchParallel is used to compute the cost of BatchParallel operator. +class CostBatchParallel { public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); + virtual StrategyRec GetOptimalStr(const Graph::NodeType &node); - double GetMinCostIn(const OperatorRec &op); + virtual double GetMaxCostIn() const { return DOUBLE_MAX; } - private: - double StrDimB(int32_t Tensor) { - cost_in_b_ = (static_cast(Tensor) * 4.0) / 2.0; + protected: + virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - return cost_in_b_; - } + double cost_in_ = 0; +}; // class BatchParallel is used to compute the cost of BatchParallel operator. - double StrDimC() { - cost_in_c_ = 0.0; - - return cost_in_c_; - } - - double StrDimH(int32_t Tensor) { - cost_in_h_ = (static_cast(Tensor) * 4.0) / 2.0; - - return cost_in_h_; - } - - double StrDimW(int32_t Tensor) { - cost_in_w_ = (static_cast(Tensor) * 4.0) / 2.0; - - return cost_in_w_; - } +class CostBatchNorm : public CostBatchParallel {}; +class CostOneHot : public CostBatchParallel {}; +class CostPRelu : public CostBatchParallel {}; +class CostSoftmax : public CostBatchParallel {}; +class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel { StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_b_ = 0; - - double cost_in_c_ = 0; - - double cost_in_h_ = 0; - - double cost_in_w_ = 0; -}; // class BatchNorm is used to compute the cost of BatchNorm operator. 
+}; } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index f04dcf15163..b8a57ae9970 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -28,10 +28,10 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(std::shared_ptr graph, const std::vector> &ops, - const std::shared_ptr>> eli_list, +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list) { + const std::shared_ptr> &index_list) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(eli_list); MS_EXCEPTION_IF_NULL(index_list); @@ -127,25 +127,6 @@ std::vector> PrepareMatMul(const std::shared_ptr &gr return strategies; } -std::vector> PreparePReLU(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); - strategies[1][0] = 1; - return strategies; -} - -std::vector> PrepareBatchNorm(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); - for (size_t i = 1; i < strategies.size(); i++) { - strategies[i][0] = strategies[0][1]; - } - strategies[1][0] = 1; - return strategies; -} - std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { std::vector> strategies; strategies.push_back(*s); @@ -155,10 +136,29 @@ std::vector> PrepareBiasAdd(const std::shared_ptr> PrepareOneHot(const std::shared_ptr> &s) { - std::vector> strategies; +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); + + int32_t axis = -1; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = iter->second->cast()->value(); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; + } + } + if (axis == -1) { + strategies[0][0] = strategies[0][1]; + strategies[0][1] = 1; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + } + std::vector s_empty = {}; - strategies.push_back(*s); strategies.push_back(s_empty); strategies.push_back(s_empty); return strategies; @@ -170,6 +170,32 @@ std::vector> PrepareGatherV2(const std::shared_ptr> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + int32_t axis = 0; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = iter->second->cast()->value(); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int."; + } + } + + int32_t axis_index = axis; + if (axis < 0) { + size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + axis_index = static_cast(input_dim) + axis; + } 
+ + s[IntToSize(axis_index)] = 1; + + std::vector> strategies; + strategies.push_back(s); + return strategies; +} + std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops) { @@ -209,7 +235,7 @@ std::vector> MakeRecSearchStrategy(const std::shared_ptrname() << ": Tensor's output size is unexcepted."; } strategies.push_back(s); } @@ -229,7 +255,7 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr StrategyPtr origin_strategy = ops[iter_ops]->strategy(); std::vector> strategies; size_t max_device_num = g_device_manager->DeviceNum(); - size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; + size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; @@ -244,8 +270,10 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr } else { s.push_back(1); } + } else if (input_size == 0) { + s = {}; } else { - MS_LOG(ERROR) << "Tensor's shape is unknown."; + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; } } strategies.push_back(s); @@ -285,25 +313,20 @@ std::vector> PrepareStrategy(const std::shared_ptr & if (type == MATMUL) { return PrepareMatMul(graph, ops, iter_graph, iter_ops); - } else if (type == PRELU) { - return PreparePReLU(graph, ops, iter_graph, iter_ops); - } else if (type == BATCH_NORM) { - return PrepareBatchNorm(graph, ops, iter_graph, iter_ops); - } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS || - type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { - return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); + } else if (type == ONEHOT) { + return PrepareOneHot(graph, ops, iter_graph, iter_ops); } else { return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); } } -void GeneratePartitionedOperatorStrategy(const std::shared_ptr graph, +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr> index_list) { + const std::shared_ptr> &index_list) { for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { std::vector> strategies; size_t iter_graph = index_list->at(iter_ops); - if (iter_graph != SIZE_MAX) { + if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); } StrategyPtr sp = std::make_shared(0, strategies); @@ -328,7 +351,7 @@ size_t FindIndexOfOperatorIncoming(const std::vector> & return incoming_op_index; } -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr graph, +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_ops, const size_t iter_graph) { std::vector s; @@ -348,7 +371,7 @@ std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptrnodes[iter_graph].tensor_parm.tensor_str.str_h); s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); } else { - MS_LOG(ERROR) << "Tensor's shape is unknown."; + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; } break; } @@ -358,7 +381,8 @@ std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr PrepareIncomingOperatorInputStrategy(const std::vector> &ops, 
const size_t incoming_op_index) { std::vector s; - if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) { + if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || + ops[incoming_op_index]->type() == TRANSPOSE) { return s; } auto strategy = ops[incoming_op_index]->selected_strategy(); @@ -426,13 +450,23 @@ std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, const size_t iter_ops) { + bool keepdims = false; + auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); + if (keep_dims_iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; + } + MS_EXCEPTION_IF_NULL(keep_dims_iter->second); + if (!keep_dims_iter->second->isa()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; + } + keepdims = keep_dims_iter->second->cast()->value(); + return keepdims; +} + std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { std::vector dim_list; - bool keep_dims; - if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa()) { - MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl; - } - keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast()->value(); + bool keep_dims = GetKeepDims(ops, iter_ops); if (keep_dims != false) { return dim_list; } @@ -478,6 +512,62 @@ std::vector ModifyStrategyIfReduceIncoming(const std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { + std::vector dim_list; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; + } + auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto attr_axis = GetValue>(iter->second); + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (iter->second->isa()) { + int axis = GetValue(iter->second); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + return dim_list; +} + +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + bool keepdims = GetKeepDims(ops, incoming_op_index); + if (keepdims) { + return s; + } + + std::vector s_Arg; + std::vector axis_list; + for (size_t i = 0; i < s.size(); i++) { + axis_list.push_back(i); + } + + auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); + for (auto axis : dim_list) { + auto it = find(axis_list.begin(), axis_list.end(), axis); + if (it == axis_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." 
<< std::endl; + } + axis_list.erase(it); + } + + for (size_t i = 0; i < (size_t)axis_list.size(); i++) { + s_Arg.push_back(s[axis_list[i]]); + } + return s_Arg; +} + std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, const size_t iter_ops, const size_t incoming_op_index) { std::vector s; @@ -490,6 +580,9 @@ std::vector CopyIncomingOperatorInputStrategy(const std::vectortype() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); } + if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { + s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); + } } return s; } @@ -513,12 +606,12 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect if (ops[iter_ops]->type() == BIAS_ADD) { return PrepareBiasAdd(s_ptr); } - if (ops[iter_ops]->type() == ONEHOT) { - return PrepareOneHot(s_ptr); - } if (ops[iter_ops]->type() == GATHERV2) { return PrepareGatherV2(s_ptr); } + if (ops[iter_ops]->type() == L2_NORMALIZE) { + return PrepareL2Normalize(ops, iter_ops, basic_stra); + } for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { @@ -544,11 +637,11 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect return stra; } -void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr graph, +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, const std::vector> &ops, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list, - const std::shared_ptr> no_stra_op_list) { + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { if (no_stra_op_list->size() == 0) { return; } @@ -559,7 +652,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr grap std::vector> stra; std::vector s; size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); - if (incoming_op_index != SIZE_MAX && ops[iter_ops]->type() != ONEHOT) { + if (incoming_op_index != SIZE_MAX) { auto iter_graph = index_list->at(incoming_op_index); if (iter_graph != SIZE_MAX) { s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph); @@ -617,7 +710,8 @@ std::vector CopyOutgoingOperatorInputStrategy(const std::vector s; if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || - ops[iter_ops]->type() == GATHERV2) { + ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || + ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { return s; } @@ -640,7 +734,7 @@ std::vector CopyOutgoingOperatorInputStrategy(const std::vectorselected_strategy()->GetInputDim()[iter_op_inputs].size(); ++k) { + for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) { s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]); } } @@ -649,7 +743,7 @@ std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, const std::vector> &input_tensor_names, - const std::shared_ptr> no_stra_op_list) { + const std::shared_ptr> &no_stra_op_list) { if (no_stra_op_list->size() == 0) { return; } @@ -679,16 +773,16 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector graph, +void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, const std::vector> 
&ops, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list, - const std::shared_ptr> no_stra_op_list) { + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { if (no_stra_op_list->size() == 0) { return; } - size_t no_stra_op_list_size; + size_t no_stra_op_list_size = no_stra_op_list->size(); do { no_stra_op_list_size = no_stra_op_list->size(); GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 1e5d4d95d0f..1e8080f2b74 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -27,22 +27,20 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(std::shared_ptr graph, const std::vector> &ops, - const std::shared_ptr>> eli_list, +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list); + const std::shared_ptr> &index_list); std::vector> PrepareMatMul(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops); -std::vector> PreparePReLU(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareBatchNorm(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); std::vector> PrepareBiasAdd(const std::shared_ptr> &s); -std::vector> PrepareOneHot(const std::shared_ptr> &s); +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); std::vector> PrepareGatherV2(const std::shared_ptr> &s); +std::vector> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s); std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops); @@ -52,12 +50,12 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr std::vector> PrepareStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph, const size_t iter_ops); -void GeneratePartitionedOperatorStrategy(const std::shared_ptr graph, +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr> index_list); + const std::shared_ptr> &index_list); size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, const size_t iter_ops); -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr graph, +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_ops, const size_t iter_graph); std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, @@ -65,19 +63,23 @@ std::vector PrepareIncomingOperatorInputStrategy(const std::vector GetAxisList(const std::vector> &ops, const int iter_ops); std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, const size_t incoming_op_index, std::vector s); +bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, const size_t 
incoming_op_index, std::vector s); +std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, const size_t iter_ops, const size_t incoming_op_index); std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, const size_t iter_ops, std::vector basic_stra); -void GenerateEliminatedOperatorStrategyForward(std::shared_ptr graph, +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, const std::vector> &ops, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list, - const std::shared_ptr> no_stra_op_list); + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list); std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, const size_t iter_ops, std::vector s); std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, @@ -85,12 +87,12 @@ std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, const std::vector> &input_tensor_names, - const std::shared_ptr> no_stra_op_list); -void GenerateRemainingOperatorStrategy(const std::shared_ptr graph, + const std::shared_ptr> &no_stra_op_list); +void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, const std::vector> &ops, const std::vector> &input_tensor_names, - const std::shared_ptr> index_list, - const std::shared_ptr> no_stra_op_list); + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h index d578bd82ef1..9007218d152 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h @@ -38,6 +38,7 @@ enum OperatorType { kRecBiasAdd, kRecSoftmax, kRecSparseSoftmaxCrossEntropyWithLogits, + kRecSoftmaxCrossEntropyWithLogits, kRecOneHot, kRecLog, kRecExp, @@ -49,7 +50,8 @@ enum OperatorType { kRecCast, kRecReduce, kRecPReLU, - kRecGatherV2 + kRecGatherV2, + kRecArgWithValue }; enum InfoType { kApplication, kConstant }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 2aa9bddcc1e..c0412e9108a 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -40,7 +40,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { return tensor; } -Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops) { +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { Graph::NodeType NewOp; NewOp.name = ops[iter_ops]->name(); NewOp.info = InfoType::kApplication; @@ -140,7 +140,7 @@ std::shared_ptr ParseGraph(const std::vector> &input_tensor_names, std::shared_ptr graph) { +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph) { for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); @@ -163,8 +163,8 @@ size_t 
GetIndexInInputTensorNames(const std::vector> &i return SIZE_MAX; } -void Eliminate_Aux(const size_t node_index, const std::shared_ptr graph, - const std::shared_ptr>> eli_list) { +void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list) { std::vector eli; eli.push_back(node_index); for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) { @@ -211,18 +211,18 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr graph, } } -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - const std::shared_ptr>> eli_list, - const std::shared_ptr> index_list) { +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list) { MS_EXCEPTION_IF_NULL(graph); - const std::set type_list = { - OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, - OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, - OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, - OperatorType::kRecCast, OperatorType::kRecReshape, OperatorType::kRecGatherV2}; + static const std::set elementwise_type = { + OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, + OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, + OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, + OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { auto type = graph->nodes[node_index].apply.op_type; - if (type_list.find(type) != type_list.end()) { + if (elementwise_type.find(type) != elementwise_type.end()) { Eliminate_Aux(node_index, graph, eli_list); } } @@ -250,12 +250,22 @@ std::shared_ptr EliminateGraph(const std::shared_ptr graph, new_graph->nodes.push_back(graph->nodes[i]); auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; - for (size_t j = 0; j < node_in->size(); j++) { - node_in->at(j) = index_list->at(node_in->at(j)); + for (size_t j = node_in->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_in->erase(node_in->begin() + j - 1); + } else { + node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); + } } auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; - for (size_t j = 0; j < node_out->size(); j++) { - node_out->at(j) = index_list->at(node_out->at(j)); + for (size_t j = node_out->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_out->erase(node_out->begin() + j - 1); + } else { + node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); + } } } return new_graph; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index f39546dffc3..53abefd1c86 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -47,6 +47,8 @@ const std::map DictOpType{ {REDUCE_MIN, OperatorType::kRecReduce}, {REDUCE_MEAN, OperatorType::kRecReduce}, {GATHERV2, OperatorType::kRecGatherV2}, + {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, + {ARGMINWITHVALUE, 
OperatorType::kRecArgWithValue}, {RELU, OperatorType::kRecReLU}, {"ReLU6", OperatorType::kRecReLU}, @@ -59,6 +61,7 @@ const std::map DictOpType{ {PRELU, OperatorType::kRecPReLU}, + {TRANSPOSE, OperatorType::kRecElmWiseOp}, {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp}, @@ -67,7 +70,7 @@ const std::map DictOpType{ {REAL_DIV, OperatorType::kRecElmWiseOp}, {SOFTMAX, OperatorType::kRecSoftmax}, {LOG_SOFTMAX, OperatorType::kRecSoftmax}, - {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax}, + {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, {SQRT, OperatorType::kRecElmWiseOp}, {NEG, OperatorType::kRecElmWiseOp}, {POW, OperatorType::kRecElmWiseOp}, @@ -107,7 +110,7 @@ const std::map DictOpType{ const TensorParam MakeTensor(int n, int c, int h, int w); -Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops); +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops); OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor); @@ -118,17 +121,17 @@ TensorParam Complete2DInputs(const std::vector> &o std::shared_ptr ParseGraph(const std::vector> &ops, const std::vector> &input_tensor_names); -void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph); +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph); size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, const std::string &input_name); -void Eliminate_Aux(const size_t node_index, const std::shared_ptr graph, - const std::shared_ptr>> eli_list); +void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list); -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - const std::shared_ptr>> eli_list, - const std::shared_ptr> index_list); +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc index 6bcc63c1465..d5200f54d8c 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc @@ -68,19 +68,24 @@ double GetWeights(const Graph::NodeType &node) { auto cost_ptr = std::make_shared(); return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog || - op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd || - op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul || - op.op_type == OperatorType::kRecDiv || op.op_type == OperatorType::kRecSqueeze || - op.op_type == OperatorType::kRecCast) { + } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp || + op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub || + op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv || + op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) { // For element-wise op auto cost_ptr = std::make_shared(); return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU || - 
op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecSoftmax || - op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { - // For unprocessed type + } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot || + op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax || + op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits || + op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { + // For BatchParallel op + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMaxCostIn(); + } else if (op.op_type == OperatorType::kRecUnkownType) { + // For Unkown type return 0.0; } else { MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed."; @@ -88,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) { } // Sort all the nodes by their weights -std::vector SortByWeight(const std::shared_ptr graph) { +std::vector SortByWeight(const std::shared_ptr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector> weight_to_node_index; @@ -119,7 +124,7 @@ std::vector SortByWeight(const std::shared_ptr graph) { // Get optimal strategy to partition the target node StrategyRec PartitionNode(const Graph::NodeType &node, const std::vector> &node_name_to_strategy, - std::shared_ptr graph) { + const std::shared_ptr &graph) { bool enable_conv_chw_partition = false; MS_EXCEPTION_IF_NULL(graph); @@ -158,19 +163,26 @@ StrategyRec PartitionNode(const Graph::NodeType &node, auto cost_ptr = std::make_shared(); return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecOneHot || node.apply.op_type == OperatorType::kRecLog || - node.apply.op_type == OperatorType::kRecExp || node.apply.op_type == OperatorType::kRecAdd || - node.apply.op_type == OperatorType::kRecSub || node.apply.op_type == OperatorType::kRecMul || - node.apply.op_type == OperatorType::kRecDiv || node.apply.op_type == OperatorType::kRecSqueeze || - node.apply.op_type == OperatorType::kRecCast) { + } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp || + node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub || + node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv || + node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) { // For element-wise op auto cost_ptr = std::make_shared(); return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU || - node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecSoftmax || + } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot || + node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax || node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { - // For unprocessed type + // For BatchParallel type + auto cost_ptr = std::make_shared(); + return cost_ptr->GetOptimalStr(node); + } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { + // For SoftmaxCrossEntropyWithLogits type + auto cost_ptr = std::make_shared(); + return cost_ptr->GetOptimalStr(node); + } else if (node.apply.op_type == OperatorType::kRecUnkownType) { + // For Unkown type StrategyRec default_strategy; return 
default_strategy; } else { @@ -179,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node, } // Parttion graph into all devices. -Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr graph) { +Status PartitionForAllDevices(const size_t num_device, const double device_memory, + const std::shared_ptr &graph) { if (num_device < 1) { MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << "."; } @@ -249,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { return Node; } -Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr graph) { +Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph) { MS_EXCEPTION_IF_NULL(graph); if (num_device == 0) { MS_LOG(EXCEPTION) << "Failure: device number is 0."; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h index b2fbeddebd8..c98f3317f85 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h @@ -32,19 +32,19 @@ namespace mindspore { namespace parallel { -std::vector SortByWeight(const std::shared_ptr graph); +std::vector SortByWeight(const std::shared_ptr &graph); double GetWeights(const Graph::NodeType &node); StrategyRec PartitionNode(const Graph::NodeType &node, const std::vector> &node_name_to_strategy, - std::shared_ptr graph); + const std::shared_ptr &graph); -Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr graph); +Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr &graph); Graph::NodeType ApplyStrToTensor(Graph::NodeType Node); -Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr graph); +Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph); size_t GetDataTypeSize(const TensorType &type); } // namespace parallel
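
Reviewer note (not part of the patch): the central change above replaces the per-operator BatchNorm cost model with a shared CostBatchParallel model. For batch-parallel operators (BatchNorm, OneHot, PReLU, Softmax, SoftmaxCrossEntropyWithLogits), GetOptimalStr assigns DOUBLE_MAX to any tensor dimension that is smaller than 2 or odd, and ChoseStr halves the cheapest remaining dimension. The standalone C++ sketch below mirrors only that selection logic; the names TensorStr and CutOneDimension are hypothetical and are not the MindSpore classes themselves.

// Minimal sketch of the cut-dimension selection used by CostBatchParallel in this patch.
// Assumptions: NCHW shape as a plain vector<int>, partition factors as plain doubles.
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

struct TensorStr {  // per-dimension partition factors; 1.0 means "not cut"
  double n = 1.0, c = 1.0, h = 1.0, w = 1.0;
};

// Halve the cheapest dimension that is still divisible by 2; return false if none qualifies.
bool CutOneDimension(const std::vector<int> &shape /* {n, c, h, w} */, double cost_in, TensorStr *str) {
  const double kDoubleMax = (std::numeric_limits<double>::max)();
  std::vector<double> cost_op;
  for (int dim : shape) {
    // A dimension can only be cut if it is at least 2 and even; otherwise it gets an "infinite" cost.
    cost_op.push_back((dim < 2 || dim % 2 != 0) ? kDoubleMax : cost_in);
  }
  size_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (kDoubleMax - 0.1)) {
    return false;  // nothing left to cut
  }
  switch (min_position) {
    case 0: str->n /= 2.0; break;
    case 1: str->c /= 2.0; break;
    case 2: str->h /= 2.0; break;
    case 3: str->w /= 2.0; break;
    default: break;
  }
  return true;
}

int main() {
  TensorStr str;
  std::vector<int> shape = {32, 3, 224, 224};  // hypothetical NCHW shape; C=3 is odd, so it cannot be cut
  if (CutOneDimension(shape, /*cost_in=*/0.0, &str)) {
    // With equal costs, the first cuttable dimension wins, so the batch dimension is halved: str_n = 0.5.
    std::cout << "str_n=" << str.n << " str_c=" << str.c << " str_h=" << str.h << " str_w=" << str.w << "\n";
  }
  return 0;
}

Cutting at most one dimension per call fits the recursive partition loop in rec_partition.cc, which applies PartitionNode repeatedly until the graph is partitioned across all devices.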