From 7029bc5dd30ef8e848e594eef8b1413648e53e0e Mon Sep 17 00:00:00 2001
From: hongxing
Date: Tue, 23 Jun 2020 16:30:32 +0200
Subject: [PATCH] fix onehot axis

---
 .../rec_core/rec_generate_strategy.cc         | 149 ++++++++++++++----
 .../rec_core/rec_generate_strategy.h          |  30 ++--
 .../auto_parallel/rec_core/rec_graph.h        |   3 +-
 .../auto_parallel/rec_core/rec_parse_graph.cc |  22 +--
 .../auto_parallel/rec_core/rec_parse_graph.h  |  13 +-
 5 files changed, 157 insertions(+), 60 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 19e07aae025..630833f4a69 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -28,10 +28,10 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> index_list) {
+                      const std::shared_ptr<std::vector<size_t>> &index_list) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(eli_list);
   MS_EXCEPTION_IF_NULL(index_list);
@@ -140,10 +140,24 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops) {
   std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
-  strategies[0][0] = strategies[0][1];
-  strategies[0][1] = 1;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
+
+  int32_t axis = -1;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter != ops[iter_ops]->attrs().end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int32Imm>()) {
+      axis = iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int.";
+    }
+  }
+  if (axis == -1) {
+    strategies[0][0] = strategies[0][1];
+    strategies[0][1] = 1;
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
+  }
+
   std::vector<int32_t> s_empty = {};
   strategies.push_back(s_empty);
   strategies.push_back(s_empty);
@@ -221,7 +235,7 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<G
     } else {
-      MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted.";
     }
     strategies.push_back(s);
   }
@@ -241,7 +255,7 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_pt
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
   std::vector<std::vector<int32_t>> strategies;
   size_t max_device_num = g_device_manager->DeviceNum();
-  size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
+  size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@@ -256,8 +270,10 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_pt
       } else {
         s.push_back(1);
       }
+    } else if (input_size == 0) {
+      s = {};
     } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
      }
    }
    strategies.push_back(s);
@@ -304,13 +320,13 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
   }
 }
 
-void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                         const std::shared_ptr<std::vector<size_t>> index_list) {
+                                         const std::shared_ptr<std::vector<size_t>> &index_list) {
   for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
     std::vector<std::vector<int32_t>> strategies;
     size_t iter_graph = index_list->at(iter_ops);
-    if (iter_graph != SIZE_MAX) {
+    if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
       strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
     }
     StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
@@ -335,7 +351,7 @@ size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &
   return incoming_op_index;
 }
 
-std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
+std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph) {
   std::vector<int32_t> s;
@@ -354,8 +370,10 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<G
         s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);
         s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
         s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
+      } else if (input_stra_dim == 0) {
+        s = {};
       } else {
-        MS_LOG(ERROR) << "Tensor's shape is unknown.";
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
       }
       break;
     }
@@ -365,7 +383,8 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<G
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                           const size_t incoming_op_index) {
   std::vector<int32_t> s;
-  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) {
+  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 ||
+      ops[incoming_op_index]->type() == TRANSPOSE) {
     return s;
   }
   auto strategy = ops[incoming_op_index]->selected_strategy();
@@ -433,13 +452,23 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::sha
   return s_Squeeze;
 }
 
+bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
+  bool keepdims = false;
+  auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS);
+  if (keep_dims_iter == ops[iter_ops]->attrs().end()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims.";
+  }
+  MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
+  if (!keep_dims_iter->second->isa<BoolImm>()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool.";
+  }
+  keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value();
+  return keepdims;
+}
+
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
   std::vector<int32_t> dim_list;
-  bool keep_dims;
-  if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa<BoolImm>()) {
-    MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl;
-  }
-  keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast<BoolImmPtr>()->value();
+  bool keep_dims = GetKeepDims(ops, iter_ops);
   if (keep_dims != false) {
     return dim_list;
   }
@@ -485,6 +514,62 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shar
   return s_Reduce;
 }
 
+std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
+  std::vector<int32_t> dim_list;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter == ops[iter_ops]->attrs().end()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis.";
+  }
+  auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
+  MS_EXCEPTION_IF_NULL(iter->second);
+  if (iter->second->isa<ValueTuple>()) {
+    auto attr_axis = GetValue<std::vector<int>>(iter->second);
+    if (attr_axis.empty()) {
+      for (size_t i = 0; i < input_dim; ++i) {
+        dim_list.push_back(SizeToInt(i));
+      }
+    } else {
+      for (auto &axis : attr_axis) {
+        axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
+      }
+    }
+  } else if (iter->second->isa<Int32Imm>()) {
+    int axis = GetValue<int>(iter->second);
+    axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
+  } else {
+    MS_LOG(EXCEPTION) << "Axis type is invalid.";
+  }
+  return dim_list;
+}
+
+std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t incoming_op_index, std::vector<int32_t> s) {
+  bool keepdims = GetKeepDims(ops, incoming_op_index);
+  if (keepdims) {
+    return s;
+  }
+
+  std::vector<int32_t> s_Arg;
+  std::vector<int32_t> axis_list;
+  for (size_t i = 0; i < s.size(); i++) {
+    axis_list.push_back(i);
+  }
+
+  auto dim_list = GetDimListFromAttrs(ops, incoming_op_index);
+  for (auto axis : dim_list) {
+    auto it = find(axis_list.begin(), axis_list.end(), axis);
+    if (it == axis_list.end()) {
+      MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
+    }
+    axis_list.erase(it);
+  }
+
+  for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
+    s_Arg.push_back(s[axis_list[i]]);
+  }
+  return s_Arg;
+}
+
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index) {
   std::vector<int32_t> s;
@@ -497,6 +582,9 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::s
                ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
       s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
     }
+    if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
+      s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s);
+    }
   }
   return s;
 }
@@ -551,11 +639,11 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   return stra;
 }
 
-void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph,
+void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
-                                               const std::shared_ptr<std::vector<size_t>> index_list,
-                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                               const std::shared_ptr<std::vector<size_t>> &index_list,
+                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
@@ -624,7 +712,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::s
   std::vector<int32_t> s;
   if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
       ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
-      ops[iter_ops]->type() == GATHERV2) {
+      ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE ||
+      ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
     return s;
   }
@@ -656,7 +745,7 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::s
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
-                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
@@ -686,16 +775,16 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_p
   }
 }
 
-void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                        const std::vector<std::vector<std::string>> &input_tensor_names,
-                                       const std::shared_ptr<std::vector<size_t>> index_list,
-                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                       const std::shared_ptr<std::vector<size_t>> &index_list,
+                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
 
-  size_t no_stra_op_list_size;
+  size_t no_stra_op_list_size = no_stra_op_list->size();
   do {
     no_stra_op_list_size = no_stra_op_list->size();
     GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index c9604b449f6..1e8080f2b74 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -27,10 +27,10 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> index_list);
+                      const std::shared_ptr<std::vector<size_t>> &index_list);
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
@@ -50,12 +50,12 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_pt
 std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                   const size_t iter_graph, const size_t iter_ops);
-void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                         const std::shared_ptr<std::vector<size_t>> index_list);
+                                         const std::shared_ptr<std::vector<size_t>> &index_list);
 size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
                                    const size_t iter_ops);
-std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
+std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph);
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@@ -63,19 +63,23 @@ std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std
 std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
 std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t incoming_op_index, std::vector<int32_t> s);
+bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t incoming_op_index, std::vector<int32_t> s);
+std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
+std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops, std::vector<int32_t> basic_stra);
-void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
+void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
-                                               const std::shared_ptr<std::vector<size_t>> index_list,
-                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                               const std::shared_ptr<std::vector<size_t>> &index_list,
+                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@@ -83,12 +87,12 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::s
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
-                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
-void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                        const std::vector<std::vector<std::string>> &input_tensor_names,
-                                       const std::shared_ptr<std::vector<size_t>> index_list,
-                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                       const std::shared_ptr<std::vector<size_t>> &index_list,
+                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
index 647b857e161..9007218d152 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -50,7 +50,8 @@ enum OperatorType {
   kRecCast,
   kRecReduce,
   kRecPReLU,
-  kRecGatherV2
+  kRecGatherV2,
+  kRecArgWithValue
 };
 
 enum InfoType { kApplication, kConstant };
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index 3e4eafe0a4c..58884be9db7 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -163,8 +163,8 @@ size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &i
   return SIZE_MAX;
 }
 
-void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
-                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
   std::vector<size_t> eli;
   eli.push_back(node_index);
   for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
@@ -211,18 +211,18 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
   }
 }
 
-std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
-                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
-                                      const std::shared_ptr<std::vector<size_t>> index_list) {
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> &index_list) {
   MS_EXCEPTION_IF_NULL(graph);
-  const std::set<OperatorType> type_list = {
-    OperatorType::kRecReLU,      OperatorType::kRecLog,     OperatorType::kRecExp,    OperatorType::kRecAdd,
-    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,    OperatorType::kRecMul,
-    OperatorType::kRecDiv,       OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
-    OperatorType::kRecReshape,   OperatorType::kRecGatherV2};
+  const std::set<OperatorType> elementwise_type = {
+    OperatorType::kRecReLU,      OperatorType::kRecLog,     OperatorType::kRecExp,    OperatorType::kRecAdd,
+    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,    OperatorType::kRecMul,
+    OperatorType::kRecDiv,       OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
+    OperatorType::kRecReshape,   OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
-    if (type_list.find(type) != type_list.end()) {
+    if (elementwise_type.find(type) != elementwise_type.end()) {
       Eliminate_Aux(node_index, graph, eli_list);
     }
   }
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index 1be8e4c7963..a696e883327 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -47,6 +47,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {REDUCE_MIN, OperatorType::kRecReduce},
   {REDUCE_MEAN, OperatorType::kRecReduce},
   {GATHERV2, OperatorType::kRecGatherV2},
+  {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
+  {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
 
   {RELU, OperatorType::kRecReLU},
   {"ReLU6", OperatorType::kRecReLU},
@@ -59,6 +61,7 @@ const std::map<std::string, OperatorType> DictOpType{
 
   {PRELU, OperatorType::kRecPReLU},
 
+  {TRANSPOSE, OperatorType::kRecElmWiseOp},
   {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
@@ -123,12 +126,12 @@ void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, s
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);
 
-void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
-                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
 
-std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
-                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
-                                      const std::shared_ptr<std::vector<size_t>> index_list);
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> &index_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
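
Note (not part of the patch): the following is a minimal, self-contained sketch of the decision the new axis guard in PrepareOneHot makes. The searched (batch, classes) cut is only rotated when OneHot expands along the last axis (axis == -1), because the appended class dimension must stay whole; for any other axis the searched strategy is kept unchanged. The names Strategy and PrepareOneHotSketch are illustrative stand-ins, not MindSpore API.

#include <cstdint>
#include <iostream>
#include <vector>

using Strategy = std::vector<int32_t>;  // device cuts per output dimension

// Mirrors the patched PrepareOneHot: move the class-axis cut onto the batch
// axis and leave 1 behind, but only when the depth axis is appended last.
Strategy PrepareOneHotSketch(Strategy s, int32_t axis) {
  if (axis == -1) {
    s[0] = s[1];
    s[1] = 1;
  }
  return s;
}

int main() {
  const Strategy searched = {2, 4};                         // from the recursive search
  const Strategy last = PrepareOneHotSketch(searched, -1);  // -> {4, 1}
  const Strategy other = PrepareOneHotSketch(searched, 0);  // -> {2, 4}, kept as searched
  std::cout << last[0] << "," << last[1] << " " << other[0] << "," << other[1] << "\n";
  return 0;
}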
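Note (not part of the patch): a second sketch, under the same caveats, of the pruning step ModifyStrategyIfArgIncoming adds. When an ArgMaxWithValue/ArgMinWithValue producer has keep_dims == false, the axes it reduced no longer exist in its output, so the matching entries must be dropped from the strategy a consumer inherits. drop_reduced_dims is a hypothetical standalone version of that step; dim_list plays the role of GetDimListFromAttrs's result (axes already normalized to non-negative indexes).

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Keep only the strategy entries whose axes survive the Arg reduction.
std::vector<int32_t> drop_reduced_dims(const std::vector<int32_t> &s,
                                       const std::vector<int32_t> &dim_list) {
  // Start from every axis index of the producer's input ...
  std::vector<int32_t> axis_list(s.size());
  for (size_t i = 0; i < s.size(); ++i) axis_list[i] = static_cast<int32_t>(i);
  // ... erase each axis the Arg operator reduced away ...
  for (int32_t axis : dim_list) {
    auto it = std::find(axis_list.begin(), axis_list.end(), axis);
    if (it == axis_list.end()) throw std::invalid_argument("axis out of range");
    axis_list.erase(it);
  }
  // ... and project the strategy onto the surviving axes.
  std::vector<int32_t> pruned;
  for (int32_t axis : axis_list) pruned.push_back(s[axis]);
  return pruned;  // e.g. s = {2, 4, 1}, dim_list = {1}  ->  {2, 1}
}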