!2530 [CT][ME][parallel] One-hot runs failed in RP-search mode

Merge pull request !2530 from Chong/zc
This commit is contained in:
mindspore-ci-bot 2020-06-24 14:47:15 +08:00 committed by Gitee
commit 7345d7471b
5 changed files with 157 additions and 60 deletions

View File

@ -28,10 +28,10 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(eli_list); MS_EXCEPTION_IF_NULL(eli_list);
MS_EXCEPTION_IF_NULL(index_list); MS_EXCEPTION_IF_NULL(index_list);
@ -140,10 +140,24 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) { const size_t iter_graph, const size_t iter_ops) {
std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
int32_t axis = -1;
auto iter = ops[iter_ops]->attrs().find(AXIS);
if (iter != ops[iter_ops]->attrs().end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int32Imm>()) {
axis = iter->second->cast<Int32ImmPtr>()->value();
} else {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int.";
}
}
if (axis == -1) {
strategies[0][0] = strategies[0][1]; strategies[0][0] = strategies[0][1];
strategies[0][1] = 1; strategies[0][1] = 1;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
}
std::vector<int32_t> s_empty = {}; std::vector<int32_t> s_empty = {};
strategies.push_back(s_empty); strategies.push_back(s_empty);
strategies.push_back(s_empty); strategies.push_back(s_empty);
@ -221,7 +235,7 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Gr
} else if (output_size == 0) { } else if (output_size == 0) {
s = {}; s = {};
} else { } else {
MS_LOG(ERROR) << "Tensor's output size is unexcepted."; MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted.";
} }
strategies.push_back(s); strategies.push_back(s);
} }
@ -241,7 +255,7 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); StrategyPtr origin_strategy = ops[iter_ops]->strategy();
std::vector<std::vector<int32_t>> strategies; std::vector<std::vector<int32_t>> strategies;
size_t max_device_num = g_device_manager->DeviceNum(); size_t max_device_num = g_device_manager->DeviceNum();
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@ -256,8 +270,10 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
} else { } else {
s.push_back(1); s.push_back(1);
} }
} else if (input_size == 0) {
s = {};
} else { } else {
MS_LOG(ERROR) << "Tensor's shape is unknown."; MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
} }
} }
strategies.push_back(s); strategies.push_back(s);
@ -304,13 +320,13 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
} }
} }
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
std::vector<std::vector<int32_t>> strategies; std::vector<std::vector<int32_t>> strategies;
size_t iter_graph = index_list->at(iter_ops); size_t iter_graph = index_list->at(iter_ops);
if (iter_graph != SIZE_MAX) { if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
} }
StrategyPtr sp = std::make_shared<Strategy>(0, strategies); StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
@ -335,7 +351,7 @@ size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &
return incoming_op_index; return incoming_op_index;
} }
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph, std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph) { const size_t iter_ops, const size_t iter_graph) {
std::vector<int32_t> s; std::vector<int32_t> s;
@ -354,8 +370,10 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
} else if (input_stra_dim == 0) {
s = {};
} else { } else {
MS_LOG(ERROR) << "Tensor's shape is unknown."; MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
} }
break; break;
} }
@ -365,7 +383,8 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index) { const size_t incoming_op_index) {
std::vector<int32_t> s; std::vector<int32_t> s;
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) { if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 ||
ops[incoming_op_index]->type() == TRANSPOSE) {
return s; return s;
} }
auto strategy = ops[incoming_op_index]->selected_strategy(); auto strategy = ops[incoming_op_index]->selected_strategy();
@ -433,13 +452,23 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
return s_Squeeze; return s_Squeeze;
} }
bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
bool keepdims = false;
auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS);
if (keep_dims_iter == ops[iter_ops]->attrs().end()) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims.";
}
MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
if (!keep_dims_iter->second->isa<BoolImm>()) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool.";
}
keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value();
return keepdims;
}
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) { std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
std::vector<int32_t> dim_list; std::vector<int32_t> dim_list;
bool keep_dims; bool keep_dims = GetKeepDims(ops, iter_ops);
if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa<BoolImm>()) {
MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl;
}
keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast<BoolImmPtr>()->value();
if (keep_dims != false) { if (keep_dims != false) {
return dim_list; return dim_list;
} }
@ -485,6 +514,62 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
return s_Reduce; return s_Reduce;
} }
std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
std::vector<int32_t> dim_list;
auto iter = ops[iter_ops]->attrs().find(AXIS);
if (iter == ops[iter_ops]->attrs().end()) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis.";
}
auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<ValueTuple>()) {
auto attr_axis = GetValue<std::vector<int>>(iter->second);
if (attr_axis.empty()) {
for (size_t i = 0; i < input_dim; ++i) {
dim_list.push_back(SizeToInt(i));
}
} else {
for (auto &axis : attr_axis) {
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
}
}
} else if (iter->second->isa<Int32Imm>()) {
int axis = GetValue<int>(iter->second);
axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
} else {
MS_LOG(EXCEPTION) << "Axis type is invalid.";
}
return dim_list;
}
std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s) {
bool keepdims = GetKeepDims(ops, incoming_op_index);
if (keepdims) {
return s;
}
std::vector<int32_t> s_Arg;
std::vector<int32_t> axis_list;
for (size_t i = 0; i < s.size(); i++) {
axis_list.push_back(i);
}
auto dim_list = GetDimListFromAttrs(ops, incoming_op_index);
for (auto axis : dim_list) {
auto it = find(axis_list.begin(), axis_list.end(), axis);
if (it == axis_list.end()) {
MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
}
axis_list.erase(it);
}
for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
s_Arg.push_back(s[axis_list[i]]);
}
return s_Arg;
}
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t incoming_op_index) { const size_t iter_ops, const size_t incoming_op_index) {
std::vector<int32_t> s; std::vector<int32_t> s;
@ -497,6 +582,9 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
} }
if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s);
}
} }
return s; return s;
} }
@ -551,11 +639,11 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
return stra; return stra;
} }
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph, void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
if (no_stra_op_list->size() == 0) { if (no_stra_op_list->size() == 0) {
return; return;
} }
@ -624,7 +712,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
std::vector<int32_t> s; std::vector<int32_t> s;
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
ops[iter_ops]->type() == GATHERV2) { ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE ||
ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
return s; return s;
} }
@ -656,7 +745,7 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
if (no_stra_op_list->size() == 0) { if (no_stra_op_list->size() == 0) {
return; return;
} }
@ -686,16 +775,16 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_pt
} }
} }
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph, void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) { const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
if (no_stra_op_list->size() == 0) { if (no_stra_op_list->size() == 0) {
return; return;
} }
size_t no_stra_op_list_size; size_t no_stra_op_list_size = no_stra_op_list->size();
do { do {
no_stra_op_list_size = no_stra_op_list->size(); no_stra_op_list_size = no_stra_op_list->size();
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);

View File

@ -27,10 +27,10 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph, std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
@ -50,12 +50,12 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph, std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops); const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph, std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph); const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@ -63,19 +63,23 @@ std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std:
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops); std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, std::vector<int32_t> s);
bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops); std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t incoming_op_index); const size_t iter_ops, const size_t incoming_op_index);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_ops,
std::vector<int32_t> basic_stra); std::vector<int32_t> basic_stra);
void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph, void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s); const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@ -83,12 +87,12 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
const size_t iter_ops); const size_t iter_ops);
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph, void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

View File

@ -50,7 +50,8 @@ enum OperatorType {
kRecCast, kRecCast,
kRecReduce, kRecReduce,
kRecPReLU, kRecPReLU,
kRecGatherV2 kRecGatherV2,
kRecArgWithValue
}; };
enum InfoType { kApplication, kConstant }; enum InfoType { kApplication, kConstant };

View File

@ -163,8 +163,8 @@ size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &i
return SIZE_MAX; return SIZE_MAX;
} }
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph, void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) { const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
std::vector<size_t> eli; std::vector<size_t> eli;
eli.push_back(node_index); eli.push_back(node_index);
for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) { for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
@ -211,18 +211,18 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
} }
} }
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph, std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
const std::set<OperatorType> type_list = { const std::set<OperatorType> elementwise_type = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2}; OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type; auto type = graph->nodes[node_index].apply.op_type;
if (type_list.find(type) != type_list.end()) { if (elementwise_type.find(type) != elementwise_type.end()) {
Eliminate_Aux(node_index, graph, eli_list); Eliminate_Aux(node_index, graph, eli_list);
} }
} }

View File

@ -47,6 +47,8 @@ const std::map<std::string, OperatorType> DictOpType{
{REDUCE_MIN, OperatorType::kRecReduce}, {REDUCE_MIN, OperatorType::kRecReduce},
{REDUCE_MEAN, OperatorType::kRecReduce}, {REDUCE_MEAN, OperatorType::kRecReduce},
{GATHERV2, OperatorType::kRecGatherV2}, {GATHERV2, OperatorType::kRecGatherV2},
{ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
{ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
{RELU, OperatorType::kRecReLU}, {RELU, OperatorType::kRecReLU},
{"ReLU6", OperatorType::kRecReLU}, {"ReLU6", OperatorType::kRecReLU},
@ -59,6 +61,7 @@ const std::map<std::string, OperatorType> DictOpType{
{PRELU, OperatorType::kRecPReLU}, {PRELU, OperatorType::kRecPReLU},
{TRANSPOSE, OperatorType::kRecElmWiseOp},
{L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
{TENSOR_ADD, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},
@ -124,12 +127,12 @@ void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, s
size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names, size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
const std::string &input_name); const std::string &input_name);
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph, void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list); const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph, std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_