modify strategys to strategies

This commit is contained in:
yangzhenzhang 2022-07-08 10:17:45 +08:00
parent df68f7cb92
commit 3b7fc4db29
100 changed files with 461 additions and 461 deletions

View File

@ -1674,7 +1674,7 @@ Status CostGraph::InitReshapeStrategy() {
if (stra.empty()) { if (stra.empty()) {
MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
} }
Strategys stra_inputs = {stra}; Strategies stra_inputs = {stra};
StrategyPtr reshape_stra = StrategyPtr reshape_stra =
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
reshape_info->set_strategy(reshape_stra); reshape_info->set_strategy(reshape_stra);

View File

@ -80,9 +80,9 @@ Dimensions PrepareMatMulStrategy(const std::shared_ptr<Graph> &graph, const size
return s; return s;
} }
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategies PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) { const size_t iter_graph, const size_t iter_ops) {
Strategys strategies; Strategies strategies;
auto attrs = ops[iter_ops]->attrs(); auto attrs = ops[iter_ops]->attrs();
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value(); bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
@ -94,8 +94,8 @@ Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<s
return strategies; return strategies;
} }
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) { Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
Strategys strategies; Strategies strategies;
strategies.push_back(*s); strategies.push_back(*s);
Dimensions s_biasadd; Dimensions s_biasadd;
s_biasadd.push_back(s->at(1)); s_biasadd.push_back(s->at(1));
@ -103,9 +103,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
return strategies; return strategies;
} }
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra) { Dimensions basic_stra) {
Strategys stra; Strategies stra;
auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1)); auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2)); auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
@ -128,9 +128,9 @@ Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &
return stra; return stra;
} }
Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra) { Dimensions basic_stra) {
Strategys strategies; Strategies strategies;
strategies.push_back(basic_stra); strategies.push_back(basic_stra);
std::vector<int64_t> axis_list; std::vector<int64_t> axis_list;
string axis_name = AXIS; string axis_name = AXIS;
@ -173,8 +173,8 @@ Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
return strategies; return strategies;
} }
Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) { Strategies PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
Strategys strategies; Strategies strategies;
// The Dimension size of the first input tensor of OneHot should be 2 even its Shape size is 1. Using the division of // The Dimension size of the first input tensor of OneHot should be 2 even its Shape size is 1. Using the division of
// the number of devices and the partition parts of the first dimension. // the number of devices and the partition parts of the first dimension.
@ -206,8 +206,8 @@ Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, c
return strategies; return strategies;
} }
Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) { Strategies PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
Strategys strategies; Strategies strategies;
auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape(); auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
Dimensions index(output_shape.size() - 1, 0); Dimensions index(output_shape.size() - 1, 0);
@ -302,8 +302,8 @@ Dimensions PrepareGatherV2OutputStrategy(const std::vector<std::shared_ptr<Opera
return strategie; return strategie;
} }
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s) { Dimensions s) {
int64_t axis = 0; int64_t axis = 0;
auto iter = ops[iter_ops]->attrs().find(AXIS); auto iter = ops[iter_ops]->attrs().find(AXIS);
if (iter != ops[iter_ops]->attrs().end()) { if (iter != ops[iter_ops]->attrs().end()) {
@ -323,15 +323,15 @@ Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &o
s[LongToSize(axis_index)] = 1; s[LongToSize(axis_index)] = 1;
Strategys strategies; Strategies strategies;
strategies.push_back(s); strategies.push_back(s);
return strategies; return strategies;
} }
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph, Strategies PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops) { const size_t iter_ops) {
Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); Strategies strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
if (strategies.size() < 1) { if (strategies.size() < 1) {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy."; MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
} }
@ -380,9 +380,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
return strategies; return strategies;
} }
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops) { const size_t iter_ops) {
if (ops.empty()) { if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty."; MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
} }
@ -394,7 +394,7 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
} }
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); StrategyPtr origin_strategy = ops[iter_ops]->strategy();
Strategys strategies; Strategies strategies;
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@ -437,9 +437,9 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
return strategies; return strategies;
} }
Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops) { const size_t iter_ops) {
if (ops.empty()) { if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty."; MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
} }
@ -448,7 +448,7 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
} }
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); StrategyPtr origin_strategy = ops[iter_ops]->strategy();
Strategys strategies; Strategies strategies;
size_t max_device_num = g_device_manager->DeviceNum(); size_t max_device_num = g_device_manager->DeviceNum();
size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
@ -500,9 +500,9 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
return strategies; return strategies;
} }
Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops) { const size_t iter_ops) {
if (ops.empty()) { if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty."; MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
} }
@ -511,7 +511,7 @@ Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
} }
StrategyPtr origin_strategy = ops[iter_ops]->strategy(); StrategyPtr origin_strategy = ops[iter_ops]->strategy();
Strategys strategies; Strategies strategies;
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@ -540,7 +540,7 @@ Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) { void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
StrategyPtr origin_strategy = op->strategy(); StrategyPtr origin_strategy = op->strategy();
Strategys strategies; Strategies strategies;
for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) { for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
Dimensions s; Dimensions s;
@ -561,8 +561,8 @@ void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
op->SetSelectedStrategyAndCost(sp, op->selected_cost()); op->SetSelectedStrategyAndCost(sp, op->selected_cost());
} }
Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategies PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) { const size_t iter_graph, const size_t iter_ops) {
if (ops.empty()) { if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty."; MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
} }
@ -593,7 +593,7 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> &index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
Strategys strategies; Strategies strategies;
size_t iter_graph = index_list->at(iter_ops); size_t iter_graph = index_list->at(iter_ops);
if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
@ -618,7 +618,7 @@ void ModifyParamSharingOpsStrategy(const std::vector<std::shared_ptr<OperatorInf
} else { } else {
continue; continue;
} }
Strategys strategies; Strategies strategies;
Dimensions str1, str2; Dimensions str1, str2;
str1 = str_j; str1 = str_j;
size_t num_device_used = 1; size_t num_device_used = 1;
@ -1107,9 +1107,9 @@ Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<O
return s; return s;
} }
Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra) { Dimensions basic_stra) {
Strategys stra; Strategies stra;
MS_EXCEPTION_IF_NULL(ops[iter_ops]); MS_EXCEPTION_IF_NULL(ops[iter_ops]);
if (iter_ops >= ops.size()) { if (iter_ops >= ops.size()) {
@ -1157,9 +1157,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
} }
// Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc. // Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const Dimensions s) { const Dimensions s) {
Strategys stra; Strategies stra;
size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size(); size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
@ -1256,10 +1256,10 @@ Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
} }
// Check whether the operator can be divided by the current strategy. // Check whether the operator can be divided by the current strategy.
Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const Dimensions basic_stra) { const Dimensions basic_stra) {
Dimensions s_empty = {}; Dimensions s_empty = {};
Strategys stra; Strategies stra;
// For all the input tensors. // For all the input tensors.
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
@ -1302,7 +1302,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &gra
for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
size_t iter_ops = no_stra_op_list->at(iter_list - 1); size_t iter_ops = no_stra_op_list->at(iter_list - 1);
Strategys stra; Strategies stra;
Dimensions s; Dimensions s;
size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
if (incoming_op_index != SIZE_MAX) { if (incoming_op_index != SIZE_MAX) {
@ -1405,7 +1405,7 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_pt
for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
auto iter_ops = no_stra_op_list->at(iter_list - 1); auto iter_ops = no_stra_op_list->at(iter_list - 1);
Strategys stra; Strategies stra;
Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
@ -1444,7 +1444,7 @@ void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) {
auto iter_ops = no_stra_op_list->at(iter_list); auto iter_ops = no_stra_op_list->at(iter_list);
Strategys stra; Strategies stra;
Dimensions s; Dimensions s;
size_t max_dim_num = 0; size_t max_dim_num = 0;

View File

@ -34,38 +34,38 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
const std::vector<std::vector<size_t>> &shared_tensors_ops); const std::vector<std::vector<size_t>> &shared_tensors_ops);
Dimensions PrepareMatMulStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_graph, bool transpose_a, Dimensions PrepareMatMulStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_graph, bool transpose_a,
bool transpose_b, size_t iter_op_inputs); bool transpose_b, size_t iter_op_inputs);
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategies PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s); Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra); Dimensions basic_stra);
Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra); Dimensions basic_stra);
Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategies PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph, Strategies PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops); const size_t iter_ops);
Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategies PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Dimensions PrepareGatherV2OutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions PrepareGatherV2OutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index); const size_t incoming_op_index);
Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions s); Dimensions s);
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops); const size_t iter_ops);
Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategies CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s, Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor); size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor);
Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategies CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops); const size_t iter_ops);
Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph, Strategies MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops); const size_t iter_ops);
void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op); void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategies PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> &index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
@ -97,8 +97,8 @@ Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<Operato
const size_t incoming_op_index, Dimensions s); const size_t incoming_op_index, Dimensions s);
Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t incoming_op_index); const size_t iter_ops, const size_t incoming_op_index);
Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra); Dimensions basic_stra);
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph, void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,

View File

@ -108,7 +108,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
for (auto &element : axis_) { for (auto &element : axis_) {
@ -243,7 +243,7 @@ Status CumOpBase::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
if (input_strategy.size() <= LongToSize(axis_)) { if (input_strategy.size() <= LongToSize(axis_)) {
MS_LOG(ERROR) << "The " << name_ << " input strategy length: " << input_strategy.size() << ", is less ot equal to " MS_LOG(ERROR) << "The " << name_ << " input strategy length: " << input_strategy.size() << ", is less ot equal to "
@ -292,7 +292,7 @@ Status CumOpBase::InferMirrorOps() {
} }
Status ActivationBase::InferDevMatrixShape() { Status ActivationBase::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;

View File

@ -256,8 +256,8 @@ class ExpandDimsInfo : public ActivationOther {
private: private:
int64_t positive_axis_ = -1; int64_t positive_axis_ = -1;
Strategys inputs_strategy_; Strategies inputs_strategy_;
Strategys outputs_strategy_; Strategies outputs_strategy_;
}; };
class SqueezeInfo : public ActivationOther { class SqueezeInfo : public ActivationOther {

View File

@ -27,7 +27,7 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
// The strategy for each input tensor must be equal // The strategy for each input tensor must be equal
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
for (size_t i = 1; i < strategies.size(); ++i) { for (size_t i = 1; i < strategies.size(); ++i) {
if (strategies[i] != strategies[0]) { if (strategies[i] != strategies[0]) {
return FAILED; return FAILED;
@ -39,7 +39,7 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
Status AddNInfo::InferDevMatrixShape() { Status AddNInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear(); dev_matrix_shape_.clear();
Strategys strategies = strategy_->GetInputDim(); Strategies strategies = strategy_->GetInputDim();
if (strategies.empty()) { if (strategies.empty()) {
return SUCCESS; return SUCCESS;
} }
@ -59,7 +59,7 @@ Status AddNInfo::InferTensorMap() {
sub_tensor_map.push_back(dev_size - i - 1); sub_tensor_map.push_back(dev_size - i - 1);
} }
Strategys strategies = strategy_->GetInputDim(); Strategies strategies = strategy_->GetInputDim();
for (size_t i = 0; i < strategies.size(); ++i) { for (size_t i = 0; i < strategies.size(); ++i) {
inputs_tensor_map_.push_back(sub_tensor_map); inputs_tensor_map_.push_back(sub_tensor_map);
} }

View File

@ -53,9 +53,9 @@ Shapes ArithmeticBase::InferExpandShape() {
return input_shapes; return input_shapes;
} }
Strategys ExpandStrategy(const StrategyPtr &strategy) { Strategies ExpandStrategy(const StrategyPtr &strategy) {
Strategys expand_strategy; Strategies expand_strategy;
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
size_t input_a_size = sub_a_strategy.size(); size_t input_a_size = sub_a_strategy.size();
@ -77,7 +77,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Shapes input_shapes = InferExpandShape(); Shapes input_shapes = InferExpandShape();
Strategys expand_strategy = ExpandStrategy(strategy); Strategies expand_strategy = ExpandStrategy(strategy);
Dimensions sub_a_strategy = expand_strategy.at(0); Dimensions sub_a_strategy = expand_strategy.at(0);
Dimensions sub_b_strategy = expand_strategy.at(1); Dimensions sub_b_strategy = expand_strategy.at(1);
Shape input_a_shape = input_shapes.at(0); Shape input_a_shape = input_shapes.at(0);
@ -93,7 +93,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
} }
Status ArithmeticBase::InferDevMatrixShape() { Status ArithmeticBase::InferDevMatrixShape() {
Strategys expand_strategy = ExpandStrategy(strategy_); Strategies expand_strategy = ExpandStrategy(strategy_);
Dimensions sub_a_strategy = expand_strategy.at(0); Dimensions sub_a_strategy = expand_strategy.at(0);
Dimensions sub_b_strategy = expand_strategy.at(1); Dimensions sub_b_strategy = expand_strategy.at(1);
Shape dev_shape; Shape dev_shape;
@ -150,10 +150,10 @@ void ArithmeticBase::ReComputeBatchSplitFlagList() {
Status ArithmeticBase::InferTensorMap() { Status ArithmeticBase::InferTensorMap() {
Shape tensor_map_index; Shape tensor_map_index;
Strategys expand_strategy = ExpandStrategy(strategy_); Strategies expand_strategy = ExpandStrategy(strategy_);
Dimensions sub_a_expand_strategy = expand_strategy.at(0); Dimensions sub_a_expand_strategy = expand_strategy.at(0);
Dimensions sub_b_expand_strategy = expand_strategy.at(1); Dimensions sub_b_expand_strategy = expand_strategy.at(1);
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
for (size_t i = 0; i < sub_a_expand_strategy.size(); ++i) { for (size_t i = 0; i < sub_a_expand_strategy.size(); ++i) {
@ -251,7 +251,7 @@ Status LerpInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
// validate strategy of weight // validate strategy of weight
Strategys expand_strategy = ExpandStrategy(strategy); Strategies expand_strategy = ExpandStrategy(strategy);
Dimensions expand_begin_strategy = expand_strategy.at(0); Dimensions expand_begin_strategy = expand_strategy.at(0);
Dimensions expand_end_strategy = expand_strategy.at(1); Dimensions expand_end_strategy = expand_strategy.at(1);
Dimensions expand_cmp_strategy; Dimensions expand_cmp_strategy;
@ -286,7 +286,7 @@ Status LerpInfo::InferDevMatrixShape() {
} }
dev_matrix_shape_.clear(); dev_matrix_shape_.clear();
Strategys expand_strategy = ExpandStrategy(strategy_); Strategies expand_strategy = ExpandStrategy(strategy_);
Dimensions expand_start_strategy = expand_strategy.at(0); Dimensions expand_start_strategy = expand_strategy.at(0);
Dimensions expand_end_strategy = expand_strategy.at(1); Dimensions expand_end_strategy = expand_strategy.at(1);
auto strategies = strategy_->GetInputDim(); auto strategies = strategy_->GetInputDim();
@ -316,9 +316,9 @@ Status LerpInfo::InferTensorMap() {
return FAILED; return FAILED;
} }
// Generate tensor map for 'weight' // Generate tensor map for 'weight'
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions weight_strategy = stra.at(2); Dimensions weight_strategy = stra.at(2);
Strategys expand_strategy = ExpandStrategy(strategy_); Strategies expand_strategy = ExpandStrategy(strategy_);
Dimensions expand_start_strategy = expand_strategy.at(0); Dimensions expand_start_strategy = expand_strategy.at(0);
Dimensions expand_weight_strategy = ExpandShape(expand_start_strategy, weight_strategy); Dimensions expand_weight_strategy = ExpandShape(expand_start_strategy, weight_strategy);
Shape dev_shape = dev_matrix_shape_; Shape dev_shape = dev_matrix_shape_;

View File

@ -32,7 +32,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
size_t strategy_size = strategy->GetInputNumber(); size_t strategy_size = strategy->GetInputNumber();
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
for (size_t i = 0; i < strategy_size; ++i) { for (size_t i = 0; i < strategy_size; ++i) {
Shape sub_strategy = stra.at(i); Shape sub_strategy = stra.at(i);
size_t strategy_len = sub_strategy.size(); size_t strategy_len = sub_strategy.size();
@ -122,7 +122,7 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_id) {
StrategyPtr sp; StrategyPtr sp;
Strategys strategy; Strategies strategy;
ComputeBatchSplitFlagList(); ComputeBatchSplitFlagList();
for (size_t i = 0; i < inputs_shape_.size(); i++) { for (size_t i = 0; i < inputs_shape_.size(); i++) {
@ -181,7 +181,7 @@ Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra[0][1] != 1) { if (stra[0][1] != 1) {
MS_LOG(ERROR) << name_ << ": The second dimension of the first input can not be split, but got " << stra[0][1]; MS_LOG(ERROR) << name_ << ": The second dimension of the first input can not be split, but got " << stra[0][1];
return FAILED; return FAILED;
@ -195,7 +195,7 @@ Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status CheckValidInfo::InferDevMatrixShape() { Status CheckValidInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
dev_matrix_shape_.push_back(stra[0][0]); dev_matrix_shape_.push_back(stra[0][0]);
return SUCCESS; return SUCCESS;
} }

View File

@ -315,7 +315,7 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
if (first_input_strategy.size() < 2) { if (first_input_strategy.size() < 2) {
MS_LOG(EXCEPTION) << name_ << ": The size of first input strategy can not smaller than 2, but got " MS_LOG(EXCEPTION) << name_ << ": The size of first input strategy can not smaller than 2, but got "

View File

@ -30,7 +30,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
int64_t channel_a_strategy = sub_a_strategy.at(1); int64_t channel_a_strategy = sub_a_strategy.at(1);
@ -43,7 +43,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status BiasAddInfo::InferDevMatrixShape() { Status BiasAddInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
dev_matrix_shape_ = sub_a_strategy; dev_matrix_shape_ = sub_a_strategy;
return SUCCESS; return SUCCESS;
@ -57,7 +57,7 @@ void BiasAddInfo::ReComputeBatchSplitFlagList() {
Status BiasAddInfo::InferTensorMap() { Status BiasAddInfo::InferTensorMap() {
TensorMap sub_a_tensor_map; TensorMap sub_a_tensor_map;
TensorMap sub_b_tensor_map; TensorMap sub_b_tensor_map;
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
size_t sub_a_strategy_size = sub_a_strategy.size(); size_t sub_a_strategy_size = sub_a_strategy.size();
for (size_t i = 0; i < sub_a_strategy_size; ++i) { for (size_t i = 0; i < sub_a_strategy_size; ++i) {
@ -88,7 +88,7 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
MS_LOG(INFO) << name_ << " : Generate strategies success."; MS_LOG(INFO) << name_ << " : Generate strategies success.";
for (auto &sp : sp_vector) { for (auto &sp : sp_vector) {
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input0_strategy = sp->GetInputDim()[0];
tmp_strategy.push_back(input0_strategy); // input0 tmp_strategy.push_back(input0_strategy); // input0

View File

@ -23,7 +23,7 @@ Status BoundingBoxEncodeInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
Dimensions input_a_strategy = strategies[0]; Dimensions input_a_strategy = strategies[0];
Dimensions input_b_strategy = strategies[1]; Dimensions input_b_strategy = strategies[1];
if (input_a_strategy != input_b_strategy) { if (input_a_strategy != input_b_strategy) {
@ -45,7 +45,7 @@ Status BoundingBoxEncodeInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status BoundingBoxEncodeInfo::InferDevMatrixShape() { Status BoundingBoxEncodeInfo::InferDevMatrixShape() {
Strategys strategies = strategy_->GetInputDim(); Strategies strategies = strategy_->GetInputDim();
Dimensions input_a_strategy = strategies.at(0); Dimensions input_a_strategy = strategies.at(0);
dev_matrix_shape_.clear(); dev_matrix_shape_.clear();
@ -106,7 +106,7 @@ Status BoundingBoxEncodeInfo::PrepareStrategy(int64_t stage_id, int64_t split_nu
Dimensions input0_partitions = {split_num, 1}; Dimensions input0_partitions = {split_num, 1};
Dimensions input1_partitions = {split_num, 1}; Dimensions input1_partitions = {split_num, 1};
Strategys strategies = {input0_partitions, input1_partitions}; Strategies strategies = {input0_partitions, input1_partitions};
(*sp) = std::make_shared<Strategy>(stage_id, strategies); (*sp) = std::make_shared<Strategy>(stage_id, strategies);
return SUCCESS; return SUCCESS;
} }

View File

@ -141,7 +141,7 @@ std::vector<StrategyPtr> BroadcastToInfo::GenerateOpStrategies(int64_t stage_id)
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) { for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy); tmp_strategy.push_back(first_input_strategy);

View File

@ -172,7 +172,7 @@ std::vector<StrategyPtr> ConcatInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) { for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy); tmp_strategy.push_back(first_input_strategy);

View File

@ -936,7 +936,7 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
auto search_mode = parallel_context->strategy_search_mode(); auto search_mode = parallel_context->strategy_search_mode();
// generate data parallel strategy when the search mode is not sharding propagation // generate data parallel strategy when the search mode is not sharding propagation
if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) { if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}}; Strategies strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy); StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
sp_vector.push_back(data_parallel_sp); sp_vector.push_back(data_parallel_sp);
return sp_vector; return sp_vector;
@ -960,7 +960,7 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys replace_strategy; Strategies replace_strategy;
Dimensions tmp_strategy = sp->GetInputDim()[0]; Dimensions tmp_strategy = sp->GetInputDim()[0];
if (tmp_strategy.size() != 5) { if (tmp_strategy.size() != 5) {
MS_LOG(EXCEPTION) << name_ << ": The size of first tmp strategy must be 5, but got " << tmp_strategy.size(); MS_LOG(EXCEPTION) << name_ << ": The size of first tmp strategy must be 5, but got " << tmp_strategy.size();

View File

@ -29,7 +29,7 @@ Status CropAndResizeInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
auto x_strategy = strategies.at(0); auto x_strategy = strategies.at(0);
auto boxes_strategy = strategies.at(1); auto boxes_strategy = strategies.at(1);
auto index_strategy = strategies.at(2); auto index_strategy = strategies.at(2);

View File

@ -39,7 +39,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra.size() != 1) { if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
return FAILED; return FAILED;
@ -61,7 +61,7 @@ Status DropoutDoMaskInfo::InferDevMatrixShape() {
return FAILED; return FAILED;
} }
Strategys strategy = strategy_->GetInputDim(); Strategies strategy = strategy_->GetInputDim();
if (strategy.empty()) { if (strategy.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty"; MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED; return FAILED;
@ -110,11 +110,11 @@ std::vector<StrategyPtr> DropoutDoMaskInfo::GenerateOpStrategies(int64_t stage_i
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> DropoutDoMaskInfo::GenerateBatchStrategies() {
Dimensions strategy(inputs_shape_[0].size() - 1, 1); Dimensions strategy(inputs_shape_[0].size() - 1, 1);
(void)strategy.insert(strategy.begin(), stage_device_size_); (void)strategy.insert(strategy.begin(), stage_device_size_);
Strategys strategy_v = {strategy}; Strategies strategy_v = {strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
size_t GetNonMonadInputSize(const CNodePtr &cnode) { size_t GetNonMonadInputSize(const CNodePtr &cnode) {

View File

@ -38,7 +38,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
void ReplaceNodeInputOrAttrs() override; void ReplaceNodeInputOrAttrs() override;

View File

@ -48,7 +48,7 @@ Status DSDMatmulInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }
Strategys stras = strategy->GetInputDim(); Strategies stras = strategy->GetInputDim();
if (stras.size() != DSD_MATMUL_INPUTS_SIZE) { if (stras.size() != DSD_MATMUL_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 3."; MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 3.";
return FAILED; return FAILED;
@ -89,7 +89,7 @@ Status DSDMatmulInfo::CheckStrategy(const StrategyPtr &strategy) {
* device matrix use the strategy0. * device matrix use the strategy0.
*/ */
Status DSDMatmulInfo::InferDevMatrixShape() { Status DSDMatmulInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
input_strategy_ = input_strategy; input_strategy_ = input_strategy;
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
@ -172,7 +172,7 @@ std::vector<StrategyPtr> DSDMatmulInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions input_w1_strategy = sp->GetInputDim()[0]; Dimensions input_w1_strategy = sp->GetInputDim()[0];
Dimensions input_w2_strategy = input_w1_strategy; Dimensions input_w2_strategy = input_w1_strategy;
Dimensions input_v_strategy = {input_w1_strategy[0], input_w1_strategy[1], 1, 1}; Dimensions input_v_strategy = {input_w1_strategy[0], input_w1_strategy[1], 1, 1};

View File

@ -175,7 +175,7 @@ Status GatherInfo::GetAttrs() {
// output's strategy: [a, b, ..., c] or [1, a, b, ..., c] // output's strategy: [a, b, ..., c] or [1, a, b, ..., c]
// dev_matrix: [a, b, ..., c] // dev_matrix: [a, b, ..., c]
// can not support repeated calculation // can not support repeated calculation
Status GatherInfo::CheckManualSplit(const Strategys &strategy) { Status GatherInfo::CheckManualSplit(const Strategies &strategy) {
if (strategy.size() != 2) { if (strategy.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size(); MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
return FAILED; return FAILED;
@ -270,7 +270,7 @@ Status GatherInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
// otherwise return false // otherwise return false
bool GatherInfo::ShardBatchAndAxis(const Strategys &strategy) const { bool GatherInfo::ShardBatchAndAxis(const Strategies &strategy) const {
if (axis_ != 0) { if (axis_ != 0) {
return false; return false;
} }
@ -1015,7 +1015,7 @@ std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> GatherInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
} }
@ -1029,8 +1029,8 @@ std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
for (size_t i = 1; i < inputs_shape_[1].size(); i++) { for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
index_strategy.push_back(1); index_strategy.push_back(1);
} }
Strategys strategy_v = {param_strategy, index_strategy}; Strategies strategy_v = {param_strategy, index_strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -45,7 +45,7 @@ class GatherInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; } const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
const std::vector<int64_t> &index_offsets() const { return index_offsets_; } const std::vector<int64_t> &index_offsets() const { return index_offsets_; }
@ -64,7 +64,7 @@ class GatherInfo : public OperatorInfo {
void InferOutputsTensorMap(); void InferOutputsTensorMap();
void InferTensorMapForManualSplit(); void InferTensorMapForManualSplit();
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit(const Strategys &strategy); Status CheckManualSplit(const Strategies &strategy);
Status CheckSplitAxisStrategy(const StrategyPtr &strategy); Status CheckSplitAxisStrategy(const StrategyPtr &strategy);
void SetAttribute(const StrategyPtr &strategy); void SetAttribute(const StrategyPtr &strategy);
Status GetManualSplitAttr(); Status GetManualSplitAttr();
@ -73,7 +73,7 @@ class GatherInfo : public OperatorInfo {
Status InferBias(); Status InferBias();
Status InferOffset(); Status InferOffset();
Status InferGroup(); Status InferGroup();
bool ShardBatchAndAxis(const Strategys &strategy) const; bool ShardBatchAndAxis(const Strategies &strategy) const;
Shape InferOutputsTensorMapSplitAxis(); Shape InferOutputsTensorMapSplitAxis();
int64_t axis_; int64_t axis_;

View File

@ -206,7 +206,7 @@ std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) { for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy); tmp_strategy.push_back(first_input_strategy);

View File

@ -139,7 +139,7 @@ std::vector<StrategyPtr> GatherNdInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions indices_strategy = sp->GetInputDim()[0]; Dimensions indices_strategy = sp->GetInputDim()[0];
Dimensions input_strategy(inputs_shape_[0].size(), 1); Dimensions input_strategy(inputs_shape_[0].size(), 1);
tmp_strategy.push_back(input_strategy); tmp_strategy.push_back(input_strategy);

View File

@ -105,7 +105,7 @@ Status GetNextInfo::InferDevMatrixShape() {
} }
Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
Strategys stras = strategy->GetInputDim(); Strategies stras = strategy->GetInputDim();
for (Dimensions stra : stras) { for (Dimensions stra : stras) {
if (stra.size() != 0) { if (stra.size() != 0) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
@ -219,7 +219,7 @@ void GetNextInfo::InferReplaceOps() {
Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> GetNextInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> GetNextInfo::GenerateOpStrategies(int64_t stage_id) {
Strategys stra; Strategies stra;
StrategyPtr sp = std::make_shared<Strategy>(stage_id, stra); StrategyPtr sp = std::make_shared<Strategy>(stage_id, stra);
std::vector<StrategyPtr> sp_vector; std::vector<StrategyPtr> sp_vector;
sp_vector.push_back(sp); sp_vector.push_back(sp);

View File

@ -60,7 +60,7 @@ class GetNextInfo : public OperatorInfo {
int64_t output_num_ = 0; int64_t output_num_ = 0;
int64_t shard_num_ = 1; int64_t shard_num_ = 1;
std::string shared_name_; std::string shared_name_;
Strategys dataset_strategy_; Strategies dataset_strategy_;
Shape dev_matrix_shape_origin_; Shape dev_matrix_shape_origin_;
}; };
} // namespace parallel } // namespace parallel

View File

@ -27,7 +27,7 @@ Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_LOG(ERROR) << name_ << ": Invalid strategy"; MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
auto x_strategy = strategies.at(0); auto x_strategy = strategies.at(0);
auto input_v_strategy = strategies.at(1); auto input_v_strategy = strategies.at(1);
if (x_strategy[0] != 1 || input_v_strategy[0] != 1) { if (x_strategy[0] != 1 || input_v_strategy[0] != 1) {

View File

@ -24,7 +24,7 @@ Status IOUInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
if (strategies[0][1] != 1 || strategies[1][1] != 1) { if (strategies[0][1] != 1 || strategies[1][1] != 1) {
MS_LOG(ERROR) << name_ << ": Only supports shard the 0th dimension of each input tensor, but got strategy " MS_LOG(ERROR) << name_ << ": Only supports shard the 0th dimension of each input tensor, but got strategy "
<< StrategyToString(strategies); << StrategyToString(strategies);
@ -34,7 +34,7 @@ Status IOUInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status IOUInfo::InferDevMatrixShape() { Status IOUInfo::InferDevMatrixShape() {
Strategys strategise = strategy_->GetInputDim(); Strategies strategise = strategy_->GetInputDim();
int64_t dev1 = strategise[0][0]; int64_t dev1 = strategise[0][0];
int64_t dev0 = strategise[1][0]; int64_t dev0 = strategise[1][0];

View File

@ -31,7 +31,7 @@ Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
int64_t axis_index = axis_; int64_t axis_index = axis_;
if (axis_ < 0) { if (axis_ < 0) {

View File

@ -61,7 +61,7 @@ Status LayerNormInfo::GetAttrs() {
Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy); MS_EXCEPTION_IF_NULL(strategy);
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra.size() != LAYER_NORM_INPUT_SIZE) { if (stra.size() != LAYER_NORM_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
return FAILED; return FAILED;
@ -116,7 +116,7 @@ Status LayerNormInfo::InferDevMatrixShape() {
MS_LOG(ERROR) << name_ << ": The strategy is null"; MS_LOG(ERROR) << name_ << ": The strategy is null";
return FAILED; return FAILED;
} }
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
if (stra.empty()) { if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty"; MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED; return FAILED;
@ -187,7 +187,7 @@ Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyP
MS_LOG(ERROR) << name_ << ": Invalid strategy"; MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED; return FAILED;
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions input_strategy = sp->GetInputDim()[0]; Dimensions input_strategy = sp->GetInputDim()[0];
Dimensions gamma_strategy = input_strategy; Dimensions gamma_strategy = input_strategy;
(void)gamma_strategy.erase(gamma_strategy.begin(), (void)gamma_strategy.erase(gamma_strategy.begin(),

View File

@ -85,14 +85,14 @@ std::vector<StrategyPtr> LinSpaceInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> LinSpaceInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> LinSpaceInfo::GenerateBatchStrategies() {
if (InferAttrs() != SUCCESS) { if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
} }
int64_t dev_num = g_device_manager->stage_device_num(); int64_t dev_num = g_device_manager->stage_device_num();
Strategys strategies = {Dimensions{dev_num}}; Strategies strategies = {Dimensions{dev_num}};
return std::make_shared<Strategys>(strategies); return std::make_shared<Strategies>(strategies);
} }
int64_t LinSpaceInfo::GetSplitNum() { int64_t LinSpaceInfo::GetSplitNum() {

View File

@ -38,7 +38,7 @@ class LinSpaceInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected: protected:

View File

@ -32,7 +32,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
Dimensions label_strategy = stra.at(1); Dimensions label_strategy = stra.at(1);
if (input_strategy != label_strategy) { if (input_strategy != label_strategy) {
@ -69,7 +69,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
} }
Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
return SUCCESS; return SUCCESS;

View File

@ -43,7 +43,7 @@ namespace parallel {
* Only bs and num_heads can be splited, thus the q[0] should at least be size_per_head, * Only bs and num_heads can be splited, thus the q[0] should at least be size_per_head,
* q[1] should at least be seq_len // 16. The strategy check can use bs/head from attrs. * q[1] should at least be seq_len // 16. The strategy check can use bs/head from attrs.
*/ */
Status MatmulDDSInfo::CheckStrategys(const Strategys &stras) { Status MatmulDDSInfo::CheckStrategys(const Strategies &stras) {
if (stras.size() != MATMUL_DDS_INPUTS_SIZE) { if (stras.size() != MATMUL_DDS_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 4."; MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 4.";
return FAILED; return FAILED;
@ -106,7 +106,7 @@ Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_LOG(ERROR) << name_ << ": Invalid strategy."; MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED; return FAILED;
} }
Strategys stras = strategy->GetInputDim(); Strategies stras = strategy->GetInputDim();
if (CheckStrategys(stras) != SUCCESS) { if (CheckStrategys(stras) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -117,7 +117,7 @@ Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
* device matrix is extended by the strategy0. * device matrix is extended by the strategy0.
*/ */
Status MatmulDDSInfo::InferDevMatrixShape() { Status MatmulDDSInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
input_strategy_ = input_strategy; input_strategy_ = input_strategy;
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
@ -288,7 +288,7 @@ std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions q_strategy = sp->GetInputDim()[0]; Dimensions q_strategy = sp->GetInputDim()[0];
Dimensions k_strategy = q_strategy; Dimensions k_strategy = q_strategy;
Dimensions local_mask_strategy = {1, q_strategy[0], 1, 1}; Dimensions local_mask_strategy = {1, q_strategy[0], 1, 1};

View File

@ -50,7 +50,7 @@ class MatmulDDSInfo : public OperatorInfo {
Status GetAttrs() override; Status GetAttrs() override;
Status InferAsLossDivisor() override { return SUCCESS; } Status InferAsLossDivisor() override { return SUCCESS; }
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckStrategys(const Strategys &stras); Status CheckStrategys(const Strategies &stras);
private: private:
Dimensions input_strategy_; Dimensions input_strategy_;

View File

@ -152,7 +152,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions mat_a_strategy = stra.at(0); Dimensions mat_a_strategy = stra.at(0);
Dimensions mat_b_strategy = stra.at(1); Dimensions mat_b_strategy = stra.at(1);
@ -209,7 +209,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
return FAILED; return FAILED;
} }
Strategys in_stra = strategy_->GetInputDim(); Strategies in_stra = strategy_->GetInputDim();
Dimensions x_strategy = in_stra.at(0); Dimensions x_strategy = in_stra.at(0);
Dimensions w_strategy = in_stra.at(1); Dimensions w_strategy = in_stra.at(1);
@ -222,7 +222,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
in_shard_c = w_strategy[1]; in_shard_c = w_strategy[1];
} }
Strategys out_stra = out_strategy->GetInputDim(); Strategies out_stra = out_strategy->GetInputDim();
Dimensions output_strategy = out_stra[0]; Dimensions output_strategy = out_stra[0];
int64_t out_shard_a_or_ab = output_strategy[0]; int64_t out_shard_a_or_ab = output_strategy[0];
@ -248,7 +248,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
} }
Status MatMulBase::InferDevMatrixShape() { Status MatMulBase::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions mat_a_strategy = stra.at(0); Dimensions mat_a_strategy = stra.at(0);
Dimensions mat_b_strategy = stra.at(1); Dimensions mat_b_strategy = stra.at(1);
@ -460,7 +460,7 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys replace_strategy; Strategies replace_strategy;
Dimensions tmp_strategy = sp->GetInputDim()[0]; Dimensions tmp_strategy = sp->GetInputDim()[0];
Dimensions mat_a_strategy = tmp_strategy; Dimensions mat_a_strategy = tmp_strategy;
mat_a_strategy.pop_back(); mat_a_strategy.pop_back();
@ -483,11 +483,11 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1); Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
(void)batch_strategy.insert(batch_strategy.begin(), stage_device_size_); (void)batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
Strategys strategy_v = {batch_strategy, batch_strategy}; Strategies strategy_v = {batch_strategy, batch_strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

View File

@ -86,7 +86,7 @@ class BatchMatMulInfo : public MatMul {
: MatMul(name, inputs_shape, outputs_shape, attrs) {} : MatMul(name, inputs_shape, outputs_shape, attrs) {}
~BatchMatMulInfo() override = default; ~BatchMatMulInfo() override = default;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -70,7 +70,7 @@ Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status OneHotInfo::InferDevMatrixShape() { Status OneHotInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
if (axis_ == 0) { if (axis_ == 0) {
@ -235,11 +235,11 @@ std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
Dimensions strategy = {stage_device_size_, 1}; Dimensions strategy = {stage_device_size_, 1};
Dimensions empty_strategy; Dimensions empty_strategy;
Strategys strategy_v = {strategy, empty_strategy, empty_strategy}; Strategies strategy_v = {strategy, empty_strategy, empty_strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
Shapes OneHotInfo::InferParamStrategy(const Shapes &default_strategy) { Shapes OneHotInfo::InferParamStrategy(const Shapes &default_strategy) {

View File

@ -39,7 +39,7 @@ class OneHotInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
Shapes InferParamStrategy(const Shapes &default_strategy) override; Shapes InferParamStrategy(const Shapes &default_strategy) override;
protected: protected:

View File

@ -91,7 +91,7 @@ struct OutStrategyValueRegister {
} out_regist; } out_regist;
} // namespace } // namespace
std::string StrategyToString(const Strategys &strategy) { std::string StrategyToString(const Strategies &strategy) {
std::string strategy_str = ""; std::string strategy_str = "";
strategy_str += "("; strategy_str += "(";
for (size_t i = 0; i < strategy.size(); ++i) { for (size_t i = 0; i < strategy.size(); ++i) {
@ -130,7 +130,7 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape
size_t strategy_size = strategy->GetInputNumber(); size_t strategy_size = strategy->GetInputNumber();
size_t inputs_shape_size = inputs_shape.size(); size_t inputs_shape_size = inputs_shape.size();
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (strategy_size != inputs_shape_size) { if (strategy_size != inputs_shape_size) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
<< " is not equal to inputs size: " << inputs_shape_size; << " is not equal to inputs size: " << inputs_shape_size;
@ -790,7 +790,7 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
return slice_shape; return slice_shape;
} }
Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { Status InferSliceShapeByStrategy(const Strategies &strategys, const Shapes &shapes, Shapes *slice_shapes) {
if (slice_shapes == nullptr) { if (slice_shapes == nullptr) {
MS_LOG(ERROR) << "The slice_shapes is null."; MS_LOG(ERROR) << "The slice_shapes is null.";
return FAILED; return FAILED;
@ -829,7 +829,7 @@ Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shape
return SUCCESS; return SUCCESS;
} }
Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, Status OperatorInfo::InferSliceShape(const Strategies &inputs_strategy, const Strategies &outputs_strategy,
Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) {
if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) {
MS_LOG(ERROR) << name_ << ": The slice_shape is null."; MS_LOG(ERROR) << name_ << ": The slice_shape is null.";
@ -1061,8 +1061,8 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op,
succ_edges_ = update_pre_edges; succ_edges_ = update_pre_edges;
} }
std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes, std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
const std::vector<bool> &split_flag_list) { const std::vector<bool> &split_flag_list) {
if (shapes.size() != split_flag_list.size()) { if (shapes.size() != split_flag_list.size()) {
MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : "
<< shapes.size(); << shapes.size();
@ -1070,7 +1070,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
} }
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
int64_t dev_num = g_device_manager->stage_device_num(); int64_t dev_num = g_device_manager->stage_device_num();
Strategys strategy_v; Strategies strategy_v;
for (size_t i = 0; i != shapes.size(); i++) { for (size_t i = 0; i != shapes.size(); i++) {
if (shapes[i].empty()) { if (shapes[i].empty()) {
MS_LOG(INFO) << "Elements of shapes is empty."; MS_LOG(INFO) << "Elements of shapes is empty.";
@ -1084,7 +1084,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
strategy_v.push_back(element); strategy_v.push_back(element);
} }
} }
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
void OperatorInfo::ReComputeBatchSplitFlagList() { void OperatorInfo::ReComputeBatchSplitFlagList() {
@ -1122,12 +1122,12 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
return FAILED; return FAILED;
} }
} }
Strategys stras(inputs_partitions); Strategies stras(inputs_partitions);
(*sp) = std::make_shared<Strategy>(stage_id, stras); (*sp) = std::make_shared<Strategy>(stage_id, stras);
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategies() {
if (inputs_shape_.empty() && InferAttrs() != SUCCESS) { if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
} }
@ -1212,7 +1212,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs
// second, get the correct strategy for input0 // second, get the correct strategy for input0
for (auto &sp : *sp_vector) { for (auto &sp : *sp_vector) {
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input0_strategy = sp->GetInputDim()[0];
size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
@ -1261,7 +1261,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input
// second, get the correct strategy for input1 // second, get the correct strategy for input1
for (auto &sp : *sp_vector) { for (auto &sp : *sp_vector) {
Strategys tmp_strategy; Strategies tmp_strategy;
tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 tmp_strategy.push_back(sp->GetInputDim()[0]); // input0
Dimensions input1_strategy = sp->GetInputDim()[1]; Dimensions input1_strategy = sp->GetInputDim()[1];
@ -1503,7 +1503,7 @@ Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inpu
[stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) { [stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) {
auto sp_strategies = sp->GetInputDim(); auto sp_strategies = sp->GetInputDim();
auto sp_sub_strategy = sp_strategies.at(0); auto sp_sub_strategy = sp_strategies.at(0);
Strategys strategies(splittable_inputs); Strategies strategies(splittable_inputs);
for (size_t i = 0; i < strategies.size(); ++i) { for (size_t i = 0; i < strategies.size(); ++i) {
for (size_t j = 0; j < strategies[i].size(); ++j) { for (size_t j = 0; j < strategies[i].size(); ++j) {
if (splittable_inputs[i][j] == 0) { if (splittable_inputs[i][j] == 0) {

View File

@ -97,7 +97,7 @@ class OperatorInfo {
virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
Shapes GenerateParamStrategy(const Shapes &default_strategy); Shapes GenerateParamStrategy(const Shapes &default_strategy);
virtual std::shared_ptr<Strategys> GenerateBatchStrategies(); virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
virtual void ReComputeBatchSplitFlagList(); virtual void ReComputeBatchSplitFlagList();
void ComputeBatchSplitFlagList(); void ComputeBatchSplitFlagList();
@ -242,7 +242,7 @@ class OperatorInfo {
// The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output
// is used for grad and overload the function. If the output is a scalar, need to override the function too. // is used for grad and overload the function. If the output is a scalar, need to override the function too.
virtual Status InferAsLossDivisor(); virtual Status InferAsLossDivisor();
Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, Status InferSliceShape(const Strategies &inputs_strategy, const Strategies &outputs_strategy,
Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
int64_t GetIntAttr(const std::string &attr_name); int64_t GetIntAttr(const std::string &attr_name);
@ -351,9 +351,9 @@ void AddCommOpParamFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes, std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
const std::vector<bool> &split_flag_list); const std::vector<bool> &split_flag_list);
std::string StrategyToString(const Strategys &strategy); std::string StrategyToString(const Strategies &strategy);
void PrintStrategy(const StrategyPtr &strategy); void PrintStrategy(const StrategyPtr &strategy);
Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape, Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
const Shapes &splittable_inputs, const Shapes &splittable_inputs,

View File

@ -150,7 +150,7 @@ std::vector<StrategyPtr> StackInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) { for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy); tmp_strategy.push_back(first_input_strategy);

View File

@ -37,7 +37,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size."; MS_LOG(ERROR) << name_ << ": Invalid strategy size.";
return FAILED; return FAILED;
@ -53,7 +53,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
* device matrix is same with the strategy matrix * device matrix is same with the strategy matrix
*/ */
Status PReLUInfo::InferDevMatrixShape() { Status PReLUInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
input_strategy_ = input_strategy; input_strategy_ = input_strategy;
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;

View File

@ -36,7 +36,7 @@ Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
Dimensions input_strategy = strategies[0]; Dimensions input_strategy = strategies[0];
auto is_shard = [](int64_t val) -> bool { return val != 1; }; auto is_shard = [](int64_t val) -> bool { return val != 1; };
if (std::any_of(input_strategy.begin(), input_strategy.end(), is_shard)) { if (std::any_of(input_strategy.begin(), input_strategy.end(), is_shard)) {
@ -66,7 +66,7 @@ Status RandomChoiceWithMaskInfo::InferTensorMap() {
std::vector<StrategyPtr> RandomChoiceWithMaskInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> RandomChoiceWithMaskInfo::GenerateOpStrategies(int64_t stage_id) {
Dimensions input_partitions(inputs_shape_[0].size(), 1); Dimensions input_partitions(inputs_shape_[0].size(), 1);
Strategys strategies = {input_partitions}; Strategies strategies = {input_partitions};
std::vector<StrategyPtr> sp_vector; std::vector<StrategyPtr> sp_vector;
(void)sp_vector.emplace_back(std::make_shared<Strategy>(stage_id, strategies)); (void)sp_vector.emplace_back(std::make_shared<Strategy>(stage_id, strategies));
return sp_vector; return sp_vector;

View File

@ -65,7 +65,7 @@ Status RangeInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status RangeInfo::InferDevMatrixShape() { Status RangeInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra[0]; dev_matrix_shape_ = stra[0];
split_num_ = stra[0][0]; split_num_ = stra[0][0];
return SUCCESS; return SUCCESS;

View File

@ -33,7 +33,7 @@ namespace parallel {
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
Status ReduceMethod::InferDevMatrixShape() { Status ReduceMethod::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
@ -400,10 +400,10 @@ Status ReduceMethod::InferTensorInfo() {
// infer slice shape // infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape; Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim(); Strategies inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy = InferOutputStrategy(); Dimensions output_strategy = InferOutputStrategy();
Strategys outputs_strategy = {output_strategy}; Strategies outputs_strategy = {output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -478,7 +478,7 @@ Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<int64_t> dim_list = reduce_dim(); std::vector<int64_t> dim_list = reduce_dim();
MS_ASSERT(dim_list.size() == 1); MS_ASSERT(dim_list.size() == 1);
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
MS_ASSERT(stra.size() == 1); MS_ASSERT(stra.size() == 1);
Shape input_strategy = stra.at(0); Shape input_strategy = stra.at(0);
MS_ASSERT(dim_list.at(0) < input_strategy.size()); MS_ASSERT(dim_list.at(0) < input_strategy.size());
@ -510,10 +510,10 @@ Status ArgMaxWithValueInfo::InferTensorInfo() {
// infer slice shape // infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape; Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim(); Strategies inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy = InferOutputStrategy(); Dimensions output_strategy = InferOutputStrategy();
Strategys outputs_strategy = {output_strategy, output_strategy}; Strategies outputs_strategy = {output_strategy, output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -627,7 +627,7 @@ Status ArgmaxInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<int64_t> dim_list = reduce_dim(); std::vector<int64_t> dim_list = reduce_dim();
MS_ASSERT(dim_list.size() == 1); MS_ASSERT(dim_list.size() == 1);
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
MS_ASSERT(stra.size() == 1); MS_ASSERT(stra.size() == 1);
Shape input_strategy = stra.at(0); Shape input_strategy = stra.at(0);
MS_ASSERT(dim_list.at(0) < input_strategy.size()); MS_ASSERT(dim_list.at(0) < input_strategy.size());
@ -689,7 +689,7 @@ Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
Shape input_a_shape = inputs_shape_.at(0); Shape input_a_shape = inputs_shape_.at(0);
@ -707,7 +707,7 @@ Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status SquareSumAllInfo::InferDevMatrixShape() { Status SquareSumAllInfo::InferDevMatrixShape() {
Strategys strategy = strategy_->GetInputDim(); Strategies strategy = strategy_->GetInputDim();
Dimensions sub_a_strategy = strategy.at(0); Dimensions sub_a_strategy = strategy.at(0);
Shape dev_shape; Shape dev_shape;
for (size_t i = 0; i < sub_a_strategy.size(); ++i) { for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
@ -736,10 +736,10 @@ Status SquareSumAllInfo::InferTensorInfo() {
// infer slice shape // infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape; Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim(); Strategies inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy = InferOutputStrategy(); Dimensions output_strategy = InferOutputStrategy();
Strategys outputs_strategy = {output_strategy, output_strategy}; Strategies outputs_strategy = {output_strategy, output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED; return FAILED;
} }

View File

@ -38,7 +38,7 @@ Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
if (input_strategy[1] != 1) { if (input_strategy[1] != 1) {
MS_LOG(ERROR) << name_ << "The second dimension is not splitable."; MS_LOG(ERROR) << name_ << "The second dimension is not splitable.";
@ -65,7 +65,7 @@ std::vector<StrategyPtr> ReLUV2Info::GenerateOpStrategies(int64_t stage_id) {
} }
Status ReLUV2Info::InferDevMatrixShape() { Status ReLUV2Info::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;

View File

@ -37,7 +37,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStr
* only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
*/ */
Status ReshapeInfo::InferDevMatrixShape() { Status ReshapeInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
input_strategy_ = stra.at(0); input_strategy_ = stra.at(0);
dev_matrix_shape_ = stra.at(0); dev_matrix_shape_ = stra.at(0);
return SUCCESS; return SUCCESS;
@ -195,8 +195,8 @@ Status ReshapeInfo::InferTensorMap() {
* the output tensor strategy is the same as input tensor strategy * the output tensor strategy is the same as input tensor strategy
* only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
*/ */
Strategys ReshapeInfo::GetOutputsStrategy() { Strategies ReshapeInfo::GetOutputsStrategy() {
Strategys outputs_strategy; Strategies outputs_strategy;
Dimensions strategy; Dimensions strategy;
for (size_t j = 0; j < outputs_shape_[0].size(); ++j) { for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
strategy.push_back(1); strategy.push_back(1);
@ -269,8 +269,8 @@ Status ReshapeInfo::InferTensorInfo() {
} }
Shapes inputs_slice_shape, outputs_slice_shape; Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim(); Strategies inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = GetOutputsStrategy(); Strategies outputs_strategy = GetOutputsStrategy();
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -460,7 +460,7 @@ Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<Stra
MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
return FAILED; return FAILED;
} }
Strategys stra_inputs = {stra}; Strategies stra_inputs = {stra};
StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
if (is_next_reshape) { if (is_next_reshape) {
SetOutputLayout(pre_out_tensor_info.tensor_layout()); SetOutputLayout(pre_out_tensor_info.tensor_layout());

View File

@ -87,7 +87,7 @@ class ReshapeInfo : public OperatorInfo {
Status InferDevMatrixShape() override; Status InferDevMatrixShape() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
Status GetAttrs() override; Status GetAttrs() override;
Strategys GetOutputsStrategy(); Strategies GetOutputsStrategy();
private: private:
Status GetParameterInput(); Status GetParameterInput();

View File

@ -40,7 +40,7 @@ Status ROIAlignInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys strategies = strategy->GetInputDim(); Strategies strategies = strategy->GetInputDim();
auto features_strategy = strategies.at(0); auto features_strategy = strategies.at(0);
auto rois_strategy = strategies.at(1); auto rois_strategy = strategies.at(1);
if (features_strategy[2] != 1 || features_strategy[3] != 1) { if (features_strategy[2] != 1 || features_strategy[3] != 1) {

View File

@ -158,7 +158,7 @@ std::vector<StrategyPtr> ScatterUpdateInfo::GenerateOpStrategies(int64_t stage_i
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
Dimensions indices_strategy(inputs_shape_[1].size(), 1); Dimensions indices_strategy(inputs_shape_[1].size(), 1);
// updates_strategy = indices_strategy + input_strategy[1:] // updates_strategy = indices_strategy + input_strategy[1:]

View File

@ -115,7 +115,7 @@ std::vector<StrategyPtr> SelectInfo::GenerateOpStrategies(int64_t stage_id) {
if ((sp == nullptr) || sp->GetInputDim().empty()) { if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
} }
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
for (size_t i = 0; i < inputs_shape_.size(); ++i) { for (size_t i = 0; i < inputs_shape_.size(); ++i) {
tmp_strategy.push_back(first_input_strategy); tmp_strategy.push_back(first_input_strategy);

View File

@ -149,7 +149,7 @@ Status SliceInfo::InferMirrorOps() {
} }
// Note: if the batch dimension is not fully fetched, the batch strategy may not work. // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategys> SliceInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> SliceInfo::GenerateBatchStrategies() {
split_flag_list_ = {true}; split_flag_list_ = {true};
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
} }

View File

@ -40,7 +40,7 @@ class SliceInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
protected: protected:
Status GetAttrs() override; Status GetAttrs() override;

View File

@ -144,7 +144,7 @@ std::vector<StrategyPtr> SplitInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> SplitInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
} }
@ -157,8 +157,8 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
input_strategy[0] = stage_device_size_; input_strategy[0] = stage_device_size_;
} }
} }
Strategys strategy_v = {input_strategy}; Strategies strategy_v = {input_strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
Status SplitInfo::InferAsLossDivisor() { Status SplitInfo::InferAsLossDivisor() {

View File

@ -36,7 +36,7 @@ class SplitInfo : public OperatorInfo {
~SplitInfo() override = default; ~SplitInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
protected: protected:

View File

@ -249,7 +249,7 @@ Status StridedSliceInfo::InferMirrorOps() {
} }
// Note: if the batch dimension is not fully fetched, the batch strategy may not work. // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
split_flag_list_ = {true}; split_flag_list_ = {true};
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
} }

View File

@ -39,7 +39,7 @@ class StridedSliceInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
void ComputeBeginMask(int64_t begin_mask_); void ComputeBeginMask(int64_t begin_mask_);
void ComputeEndMask(int64_t end_mask_); void ComputeEndMask(int64_t end_mask_);
void ComputeEllipsisMask(int64_t ellipsis_mask_); void ComputeEllipsisMask(int64_t ellipsis_mask_);

View File

@ -117,7 +117,7 @@ Status TensorDotInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra.size() != 2) { if (stra.size() != 2) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
return FAILED; return FAILED;
@ -148,7 +148,7 @@ Status TensorDotInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status TensorDotInfo::InferDevMatrixShape() { Status TensorDotInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_a_strategy = stra.at(0); Dimensions input_a_strategy = stra.at(0);
Dimensions input_b_strategy = stra.at(1); Dimensions input_b_strategy = stra.at(1);
@ -306,7 +306,7 @@ Status TensorDotInfo::InferTensorMap() {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> TensorDotInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
} }
@ -339,8 +339,8 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE"; MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
} }
Strategys strategy = {input_a_strategy, input_b_strategy}; Strategies strategy = {input_a_strategy, input_b_strategy};
return std::make_shared<Strategys>(strategy); return std::make_shared<Strategies>(strategy);
} }
std::vector<StrategyPtr> TensorDotInfo::GenerateOpStrategies(int64_t) { std::vector<StrategyPtr> TensorDotInfo::GenerateOpStrategies(int64_t) {

View File

@ -45,7 +45,7 @@ class TensorDotInfo : public OperatorInfo {
~TensorDotInfo() override = default; ~TensorDotInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
size_t input1_shape_size, StrategyPtr *sp); size_t input1_shape_size, StrategyPtr *sp);

View File

@ -179,7 +179,7 @@ void TileInfo::UpdateMultiples() {
void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); } void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> TileInfo::GenerateBatchStrategies() {
if (InferAttrs() != SUCCESS) { if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
} }

View File

@ -39,7 +39,7 @@ class TileInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
void UpdateMultiples(); void UpdateMultiples();
void ReplaceNodeInputOrAttrs() override; void ReplaceNodeInputOrAttrs() override;

View File

@ -29,7 +29,7 @@ Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &st
} }
Status TmpIdentityInfo::InferDevMatrixShape() { Status TmpIdentityInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
return SUCCESS; return SUCCESS;

View File

@ -30,7 +30,7 @@ namespace parallel {
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
Status TransposeInfo::InferDevMatrixShape() { Status TransposeInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
input_strategy_ = stra.at(0); input_strategy_ = stra.at(0);
for (auto &iter : input_strategy_) { for (auto &iter : input_strategy_) {
dev_matrix_shape_.push_back(iter); dev_matrix_shape_.push_back(iter);

View File

@ -171,14 +171,14 @@ std::vector<StrategyPtr> UniformCandidateSamplerInfo::GenerateOpStrategies(int64
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> UniformCandidateSamplerInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> UniformCandidateSamplerInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
} }
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
Dimensions input_strategy(inputs_shape_[0].size(), 1); Dimensions input_strategy(inputs_shape_[0].size(), 1);
Strategys strategy_v = {input_strategy}; Strategies strategy_v = {input_strategy};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {

View File

@ -44,7 +44,7 @@ class UniformCandidateSamplerInfo : public OperatorInfo {
~UniformCandidateSamplerInfo() override = default; ~UniformCandidateSamplerInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
Status InferAsLossDivisor() override; Status InferAsLossDivisor() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

View File

@ -60,7 +60,7 @@ Status UniqueInfo::InferDevMatrixShape() {
} }
Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) { Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
Strategys stras = strategy->GetInputDim(); Strategies stras = strategy->GetInputDim();
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }

View File

@ -73,7 +73,7 @@ Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) { if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
Shape input_a_shape = inputs_shape_.at(0); Shape input_a_shape = inputs_shape_.at(0);
@ -91,7 +91,7 @@ Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status UnsortedSegmentOpInfo::InferDevMatrixShape() { Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra.at(0); dev_matrix_shape_ = stra.at(0);
return SUCCESS; return SUCCESS;
} }
@ -173,7 +173,7 @@ std::vector<StrategyPtr> UnsortedSegmentOpInfo::GenerateOpStrategies(int64_t sta
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed."; MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
} }
for (auto &sp : sp_vector) { for (auto &sp : sp_vector) {
Strategys tmp_strategy; Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0]; Dimensions first_input_strategy = sp->GetInputDim()[0];
Dimensions second_input_strategy; Dimensions second_input_strategy;
for (size_t i = 0; i < inputs_shape_[1].size(); ++i) { for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
@ -214,7 +214,7 @@ Status UnsortedSegmentOpInfo::SetCostUnderStrategy(const StrategyPtr &strategy)
return SetCostUnderStrategyBase(strategy); return SetCostUnderStrategyBase(strategy);
} }
std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() { std::shared_ptr<Strategies> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) { if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is " MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
<< inputs_shape_.size(); << inputs_shape_.size();
@ -233,8 +233,8 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
for (size_t i = 1; i < inputs_shape_[1].size(); i++) { for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
strategy_b.push_back(1); strategy_b.push_back(1);
} }
Strategys strategy_v = {strategy_a, strategy_b}; Strategies strategy_v = {strategy_a, strategy_b};
return std::make_shared<Strategys>(strategy_v); return std::make_shared<Strategies>(strategy_v);
} }
// When the index is splited, the graph should be replaced // When the index is splited, the graph should be replaced

View File

@ -47,7 +47,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategies> GenerateBatchStrategies() override;
protected: protected:
std::string reduce_method_; std::string reduce_method_;

View File

@ -35,7 +35,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra.size() < 1) { if (stra.size() < 1) {
MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1."; MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1.";
return FAILED; return FAILED;
@ -84,7 +84,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status VirtualDatasetInfo::InferDevMatrixShape() { Status VirtualDatasetInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategies stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra[max_size_strategy_dim_]; dev_matrix_shape_ = stra[max_size_strategy_dim_];
return SUCCESS; return SUCCESS;
} }
@ -162,7 +162,7 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
StrategyPtr sp; StrategyPtr sp;
Strategys strategy; Strategies strategy;
if (!ParallelContext::GetInstance()->dataset_strategy().empty()) { if (!ParallelContext::GetInstance()->dataset_strategy().empty()) {
strategy = ParallelContext::GetInstance()->dataset_strategy(); strategy = ParallelContext::GetInstance()->dataset_strategy();
} else { } else {

View File

@ -33,9 +33,9 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Strategys stra = strategy->GetInputDim(); Strategies stra = strategy->GetInputDim();
if (stra.size() != 1) { if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": Strategys size must be 1."; MS_LOG(ERROR) << name_ << ": Strategies size must be 1.";
return FAILED; return FAILED;
} }
Dimensions strategy_first = stra.at(0); Dimensions strategy_first = stra.at(0);
@ -53,7 +53,7 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) {
StrategyPtr sp; StrategyPtr sp;
Strategys strategy; Strategies strategy;
bool full_batch = ParallelContext::GetInstance()->full_batch(); bool full_batch = ParallelContext::GetInstance()->full_batch();
size_t total_dev_num; size_t total_dev_num;
if (full_batch) { if (full_batch) {

View File

@ -1250,7 +1250,7 @@ StrategyPtr ExtractStrategy(const ValuePtr &stra) {
MS_LOG(INFO) << "Extract information: strategy " << stra->ToString(); MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
if (var->size() > 0) { if (var->size() > 0) {
std::vector<ValuePtr> elements = var->value(); std::vector<ValuePtr> elements = var->value();
Strategys strategy; Strategies strategy;
for (uint64_t index = 0; index < elements.size(); ++index) { for (uint64_t index = 0; index < elements.size(); ++index) {
Dimensions dim; Dimensions dim;
if (elements[index]->isa<ValueSequence>()) { if (elements[index]->isa<ValueSequence>()) {
@ -1725,7 +1725,7 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
MS_EXCEPTION_IF_NULL(operator_); MS_EXCEPTION_IF_NULL(operator_);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
StrategyPtr strategyPtr; StrategyPtr strategyPtr;
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies(); std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategies();
MS_EXCEPTION_IF_NULL(strategy_v_ptr); MS_EXCEPTION_IF_NULL(strategy_v_ptr);
strategyPtr = NewStrategy(0, *strategy_v_ptr); strategyPtr = NewStrategy(0, *strategy_v_ptr);
std::vector<ValuePtr> elements; std::vector<ValuePtr> elements;

View File

@ -31,13 +31,13 @@ namespace parallel {
#define MIN_SLICE_NUM 1 #define MIN_SLICE_NUM 1
using Dimensions = Shape; using Dimensions = Shape;
using Strategys = std::vector<Dimensions>; using Strategies = std::vector<Dimensions>;
class Strategy; class Strategy;
using StrategyPtr = std::shared_ptr<Strategy>; using StrategyPtr = std::shared_ptr<Strategy>;
class Strategy { class Strategy {
public: public:
Strategy(int64_t stage, Strategys inputs) Strategy(int64_t stage, Strategies inputs)
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {} : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) { Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
@ -52,14 +52,14 @@ class Strategy {
~Strategy() = default; ~Strategy() = default;
size_t GetInputNumber() const { return inputs_.size(); } size_t GetInputNumber() const { return inputs_.size(); }
Strategys GetInputDim() const { return inputs_; } Strategies GetInputDim() const { return inputs_; }
int64_t GetInputStage() const { return stage_; } int64_t GetInputStage() const { return stage_; }
void ExpandInputDimFromOneToTwo() { void ExpandInputDimFromOneToTwo() {
if (inputs_.size() == 1) { if (inputs_.size() == 1) {
inputs_.push_back(inputs_[0]); inputs_.push_back(inputs_[0]);
} }
} }
void ResetInputs(const Strategys &input) { inputs_ = input; } void ResetInputs(const Strategies &input) { inputs_ = input; }
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; } std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
size_t GetInternalSize() const { return internal_size_; } size_t GetInternalSize() const { return internal_size_; }
@ -103,12 +103,12 @@ class Strategy {
const int64_t stage_; const int64_t stage_;
// The size of Dimensions must be equal to inputs_ tensor dimension. // The size of Dimensions must be equal to inputs_ tensor dimension.
Strategys inputs_; Strategies inputs_;
size_t internal_size_ = 0; size_t internal_size_ = 0;
std::vector<StrategyPtr> internal_stragies_; std::vector<StrategyPtr> internal_stragies_;
}; };
inline StrategyPtr NewStrategy(const int64_t stage, const Strategys &inputs) { inline StrategyPtr NewStrategy(const int64_t stage, const Strategies &inputs) {
return std::make_shared<Strategy>(stage, inputs); return std::make_shared<Strategy>(stage, inputs);
} }
} // namespace parallel } // namespace parallel

View File

@ -123,7 +123,7 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
auto stage = (int64_t)parallel_strategys.stage(); auto stage = (int64_t)parallel_strategys.stage();
size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size()); size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
Strategys strategy_inputs; Strategies strategy_inputs;
for (size_t j = 0; j < strategys_num; j++) { for (size_t j = 0; j < strategys_num; j++) {
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
Dimensions dimension; Dimensions dimension;

View File

@ -523,7 +523,7 @@ py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase)
return mindspore::parallel::GetParallelParameterNameListFromGraph(graph); return mindspore::parallel::GetParallelParameterNameListFromGraph(graph);
} }
void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) { void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy) {
MS_LOG(DEBUG) << "SetCNodeStrategy!"; MS_LOG(DEBUG) << "SetCNodeStrategy!";
stra_dict_[phase_][py::str(name)] = strategy; stra_dict_[phase_][py::str(name)] = strategy;
} }

View File

@ -108,7 +108,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
py::dict GetParallelGraphInfo(const std::string &phase); py::dict GetParallelGraphInfo(const std::string &phase);
py::dict GetCNodeStrategy(const std::string &phase); py::dict GetCNodeStrategy(const std::string &phase);
py::list GetParallelParameterNameList(const std::string &phase); py::list GetParallelParameterNameList(const std::string &phase);
void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy); void SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy);
size_t GetNumOpsInfo(const std::string &phase); size_t GetNumOpsInfo(const std::string &phase);
void SetNumOpsInfo(size_t); void SetNumOpsInfo(size_t);
py::dict GetAllreduceFusion(const std::string &phase); py::dict GetAllreduceFusion(const std::string &phase);

View File

@ -40,7 +40,7 @@ constexpr auto kAnfConvWeight = 2;
constexpr auto kAnfConvBias = 3; constexpr auto kAnfConvBias = 3;
int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) { int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
int split_count = 0; int split_count = 0;
Strategys strategys = strategy.strategys; Strategies strategys = strategy.strategys;
MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR); MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR);
MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR); MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR);
MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR); MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR);
@ -281,7 +281,7 @@ std::shared_ptr<ops::Conv2DFusion> Conv2DInfo::GetNewConvPrimitive(const api::Sh
prim->set_pad_list(conv_prim->get_pad_list()); prim->set_pad_list(conv_prim->get_pad_list());
prim->set_stride(conv_prim->get_stride()); prim->set_stride(conv_prim->get_stride());
prim->set_activation_type(conv_prim->get_activation_type()); prim->set_activation_type(conv_prim->get_activation_type());
Strategys strategys = strategy_.strategys; Strategies strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num; size_t dev_num = strategy_.dev_num;
switch (split_mode_) { switch (split_mode_) {
case SplitH: { case SplitH: {
@ -327,7 +327,7 @@ int Conv2DInfo::ConstructOutputCNodes(const api::SharedPtr<ops::Conv2DFusion> &c
const std::vector<AnfNodePtr> &kernel_split_outputs, const std::vector<AnfNodePtr> &kernel_split_outputs,
const std::vector<AnfNodePtr> &bias_split_outputs) { const std::vector<AnfNodePtr> &bias_split_outputs) {
MS_ASSERT(conv_prim != nullptr); MS_ASSERT(conv_prim != nullptr);
Strategys strategys = strategy_.strategys; Strategies strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num; size_t dev_num = strategy_.dev_num;
int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0); int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0); int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);

View File

@ -131,7 +131,7 @@ int DepthwiseConv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
// for depthwise conv2d, we only split channel && include split feature map, weight && bias // for depthwise conv2d, we only split channel && include split feature map, weight && bias
// so just get the ratio from strategy // so just get the ratio from strategy
int split_count = 0; int split_count = 0;
Strategys strategys = strategy.strategys; Strategies strategys = strategy.strategys;
MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR); MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR);
MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR); MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR);
MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR); MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR);

View File

@ -76,7 +76,7 @@ std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const st
default: default:
return split_strategys; return split_strategys;
} }
opt::Strategys strategys = {split_feature_map, split_weight}; opt::Strategies strategys = {split_feature_map, split_weight};
for (const auto &supported_parallel_op : kParallelOpNames) { for (const auto &supported_parallel_op : kParallelOpNames) {
split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode}; split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
} }

View File

@ -37,7 +37,7 @@ const std::vector<int64_t> kSplitDefaultRatio = {0, 0};
// user's device to split, only split to cpu && gpu, no support npu // user's device to split, only split to cpu && gpu, no support npu
const std::vector<std::string> kSplitDevTypes = {"cpu", "gpu"}; const std::vector<std::string> kSplitDevTypes = {"cpu", "gpu"};
using Strategys = std::vector<std::vector<std::vector<int64_t>>>; using Strategies = std::vector<std::vector<std::vector<int64_t>>>;
constexpr auto kDeviceTypeNone = -1; constexpr auto kDeviceTypeNone = -1;
// strategy format is NHWC-KHWC // strategy format is NHWC-KHWC
@ -71,7 +71,7 @@ enum SplitMode {
}; };
struct SplitStrategy { struct SplitStrategy {
Strategys strategys{}; Strategies strategys{};
std::vector<std::string> dev_types{}; std::vector<std::string> dev_types{};
size_t dev_num{0}; size_t dev_num{0};
SplitMode split_mode_{NoSplit}; SplitMode split_mode_{NoSplit};

View File

@ -64,7 +64,7 @@ void TestActivationInfo::SetUp() {
} }
TEST_F(TestActivationInfo, InferDevMatrixShape1) { TEST_F(TestActivationInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -75,7 +75,7 @@ TEST_F(TestActivationInfo, InferDevMatrixShape1) {
} }
TEST_F(TestActivationInfo, InferSliceShape1) { TEST_F(TestActivationInfo, InferSliceShape1) {
Strategys str = {{2, 4, 8, 16}}; Strategies str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -96,7 +96,7 @@ TEST_F(TestActivationInfo, InferSliceShape1) {
} }
TEST_F(TestActivationInfo, GetTensorLayout1) { TEST_F(TestActivationInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 8, 16}}; Strategies str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -117,7 +117,7 @@ TEST_F(TestActivationInfo, GetTensorLayout1) {
} }
TEST_F(TestActivationInfo, GetForwardOp1) { TEST_F(TestActivationInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -128,7 +128,7 @@ TEST_F(TestActivationInfo, GetForwardOp1) {
} }
TEST_F(TestActivationInfo, GetMirrorOPs1) { TEST_F(TestActivationInfo, GetMirrorOPs1) {
Strategys inputs = {{1, 4, 8, 16}}; Strategies inputs = {{1, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -148,7 +148,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs1) {
} }
TEST_F(TestActivationInfo, GetMirrorOPs2) { TEST_F(TestActivationInfo, GetMirrorOPs2) {
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy, nullptr); activation->Init(strategy, nullptr);
@ -161,7 +161,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs2) {
TEST_F(TestActivationInfo, CheckStrategy1) { TEST_F(TestActivationInfo, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = activation->Init(strategy, nullptr); Status ret = activation->Init(strategy, nullptr);
@ -170,7 +170,7 @@ TEST_F(TestActivationInfo, CheckStrategy1) {
TEST_F(TestActivationInfo, CheckStrategy2) { TEST_F(TestActivationInfo, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = activation->Init(strategy, nullptr); Status ret = activation->Init(strategy, nullptr);

View File

@ -101,7 +101,7 @@ TEST_F(TestActivation, test_softmax_strategies) {
ASSERT_NE(sp, nullptr); ASSERT_NE(sp, nullptr);
Cost cost = *(swc->cost_list[0]); Cost cost = *(swc->cost_list[0]);
Strategys stra = sp->GetInputDim(); Strategies stra = sp->GetInputDim();
ASSERT_GT(stra.size(), 0); ASSERT_GT(stra.size(), 0);
Dimensions input0_stra = stra[0]; Dimensions input0_stra = stra[0];
ASSERT_GT(input0_stra.size(), 2); ASSERT_GT(input0_stra.size(), 2);

View File

@ -63,7 +63,7 @@ void TestGeLUInfo::SetUp() {
} }
TEST_F(TestGeLUInfo, InferDevMatrixShape1) { TEST_F(TestGeLUInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy, nullptr); gelu->Init(strategy, nullptr);
@ -74,7 +74,7 @@ TEST_F(TestGeLUInfo, InferDevMatrixShape1) {
} }
TEST_F(TestGeLUInfo, InferSliceShape1) { TEST_F(TestGeLUInfo, InferSliceShape1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
gelu->Init(strategy, nullptr); gelu->Init(strategy, nullptr);
@ -95,7 +95,7 @@ TEST_F(TestGeLUInfo, InferSliceShape1) {
} }
TEST_F(TestGeLUInfo, GetTensorLayout1) { TEST_F(TestGeLUInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
gelu->Init(strategy, nullptr); gelu->Init(strategy, nullptr);
@ -116,7 +116,7 @@ TEST_F(TestGeLUInfo, GetTensorLayout1) {
} }
TEST_F(TestGeLUInfo, GetForwardOp1) { TEST_F(TestGeLUInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy, nullptr); gelu->Init(strategy, nullptr);
@ -127,7 +127,7 @@ TEST_F(TestGeLUInfo, GetForwardOp1) {
} }
TEST_F(TestGeLUInfo, GetMirrorOPs1) { TEST_F(TestGeLUInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy, nullptr); gelu->Init(strategy, nullptr);
@ -140,7 +140,7 @@ TEST_F(TestGeLUInfo, GetMirrorOPs1) {
TEST_F(TestGeLUInfo, CheckStrategy1) { TEST_F(TestGeLUInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy, nullptr); Status ret = gelu->Init(strategy, nullptr);
@ -149,7 +149,7 @@ TEST_F(TestGeLUInfo, CheckStrategy1) {
TEST_F(TestGeLUInfo, CheckStrategy2) { TEST_F(TestGeLUInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy, nullptr); Status ret = gelu->Init(strategy, nullptr);
@ -158,7 +158,7 @@ TEST_F(TestGeLUInfo, CheckStrategy2) {
TEST_F(TestGeLUInfo, CheckStrategy3) { TEST_F(TestGeLUInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy, nullptr); Status ret = gelu->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestL2NormalizeInfo::SetUp() {
} }
TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) { TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) {
Strategys inputs = {{4, 1, 8}}; Strategies inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);
@ -75,7 +75,7 @@ TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) {
} }
TEST_F(TestL2NormalizeInfo, InferSliceShape1) { TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
Strategys str = {{4, 1, 8}}; Strategies str = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);
@ -96,7 +96,7 @@ TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
} }
TEST_F(TestL2NormalizeInfo, GetTensorLayout1) { TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
Strategys str = {{4, 1, 8}}; Strategies str = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);
@ -117,7 +117,7 @@ TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
} }
TEST_F(TestL2NormalizeInfo, GetForwardOp1) { TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
Strategys inputs = {{4, 1, 8}}; Strategies inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);
@ -128,7 +128,7 @@ TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
} }
TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) { TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
Strategys inputs = {{4, 1, 8}}; Strategies inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);
@ -140,7 +140,7 @@ TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy1) { TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
Strategys inputs = {{4, 1, 8}, {4, 1, 8}}; Strategies inputs = {{4, 1, 8}, {4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy, nullptr); Status ret = norm->Init(strategy, nullptr);
@ -148,7 +148,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy2) { TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
Strategys inputs = {{4, 2, 3}}; Strategies inputs = {{4, 2, 3}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy, nullptr); Status ret = norm->Init(strategy, nullptr);
@ -156,7 +156,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy3) { TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
Strategys inputs = {{4, 2, 3, 4}}; Strategies inputs = {{4, 2, 3, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy, nullptr); Status ret = norm->Init(strategy, nullptr);
@ -164,7 +164,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy4) { TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
Strategys inputs = {{4, 1, 8}}; Strategies inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy, nullptr); Status ret = norm->Init(strategy, nullptr);
@ -172,7 +172,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
} }
TEST_F(TestL2NormalizeInfo, mirror_ops) { TEST_F(TestL2NormalizeInfo, mirror_ops) {
Strategys inputs = {{2, 1, 8}}; Strategies inputs = {{2, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy, nullptr); norm->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestLogSoftmaxInfo::SetUp() {
} }
TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) { TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);
@ -75,7 +75,7 @@ TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) {
} }
TEST_F(TestLogSoftmaxInfo, InferSliceShape1) { TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);
@ -96,7 +96,7 @@ TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
} }
TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) { TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);
@ -117,7 +117,7 @@ TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
} }
TEST_F(TestLogSoftmaxInfo, GetForwardOp1) { TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);
@ -128,7 +128,7 @@ TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
} }
TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);
@ -141,7 +141,7 @@ TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy1) { TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy, nullptr); Status ret = log_softmax->Init(strategy, nullptr);
@ -150,7 +150,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy2) { TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy, nullptr); Status ret = log_softmax->Init(strategy, nullptr);
@ -159,7 +159,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy3) { TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy, nullptr); Status ret = log_softmax->Init(strategy, nullptr);
@ -167,7 +167,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
} }
TEST_F(TestLogSoftmaxInfo, GetDeviceList1) { TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy, nullptr); log_softmax->Init(strategy, nullptr);

View File

@ -97,7 +97,7 @@ void TestMatmulInfo::SetUp() {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape1) { TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -111,7 +111,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape2) { TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}}; Strategies inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -125,7 +125,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape3) { TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
Strategys inputs = {{2, 4, 8, 16}, {1, 16}}; Strategies inputs = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy, nullptr); matmul2->Init(strategy, nullptr);
@ -139,7 +139,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape4) { TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
Strategys inputs = {{2, 4, 8, 8}, {2, 8}}; Strategies inputs = {{2, 4, 8, 8}, {2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy, nullptr); matmul2->Init(strategy, nullptr);
@ -153,7 +153,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape5) { TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
Strategys inputs = {{8, 16}, {2, 4, 1, 16}}; Strategies inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -167,7 +167,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
/// Description: infer dev matrix /// Description: infer dev matrix
/// Expectation: the dev matrix is right /// Expectation: the dev matrix is right
TEST_F(TestMatmulInfo, InferDevMatrixShape6) { TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
Strategys inputs = {{8, 8}, {2, 4, 2, 8}}; Strategies inputs = {{8, 8}, {2, 4, 2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -181,7 +181,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
/// Description: infer tensor map /// Description: infer tensor map
/// Expectation: the tensor map is right /// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap1) { TEST_F(TestMatmulInfo, InferTensorMap1) {
Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -209,7 +209,7 @@ TEST_F(TestMatmulInfo, InferTensorMap1) {
/// Description: infer tensor map /// Description: infer tensor map
/// Expectation: the tensor map is right /// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap2) { TEST_F(TestMatmulInfo, InferTensorMap2) {
Strategys str = {{2, 4, 8, 16}, {1, 16}}; Strategies str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul2->Init(strategy, nullptr); matmul2->Init(strategy, nullptr);
@ -237,7 +237,7 @@ TEST_F(TestMatmulInfo, InferTensorMap2) {
/// Description: infer tensor map /// Description: infer tensor map
/// Expectation: the tensor map is right /// Expectation: the tensor map is right
TEST_F(TestMatmulInfo, InferTensorMap3) { TEST_F(TestMatmulInfo, InferTensorMap3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}}; Strategies str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -265,7 +265,7 @@ TEST_F(TestMatmulInfo, InferTensorMap3) {
/// Description: infer slice shape /// Description: infer slice shape
/// Expectation: the slice shape is right /// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape1) { TEST_F(TestMatmulInfo, InferSliceShape1) {
Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -293,7 +293,7 @@ TEST_F(TestMatmulInfo, InferSliceShape1) {
/// Description: infer slice shape /// Description: infer slice shape
/// Expectation: the slice shape is right /// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape2) { TEST_F(TestMatmulInfo, InferSliceShape2) {
Strategys str = {{2, 4, 8, 16}, {1, 16}}; Strategies str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul2->Init(strategy, nullptr); matmul2->Init(strategy, nullptr);
@ -321,7 +321,7 @@ TEST_F(TestMatmulInfo, InferSliceShape2) {
/// Description: infer slice shape /// Description: infer slice shape
/// Expectation: the slice shape is right /// Expectation: the slice shape is right
TEST_F(TestMatmulInfo, InferSliceShape3) { TEST_F(TestMatmulInfo, InferSliceShape3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}}; Strategies str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -349,7 +349,7 @@ TEST_F(TestMatmulInfo, InferSliceShape3) {
/// Description: get tensor layout /// Description: get tensor layout
/// Expectation: the tensor layout is right /// Expectation: the tensor layout is right
TEST_F(TestMatmulInfo, GetTensorLayout3) { TEST_F(TestMatmulInfo, GetTensorLayout3) {
Strategys str = {{8, 16}, {2, 4, 1, 16}}; Strategies str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -377,7 +377,7 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) {
/// Description: infer forward op /// Description: infer forward op
/// Expectation: the forward op is right /// Expectation: the forward op is right
TEST_F(TestMatmulInfo, GetForwardOp1) { TEST_F(TestMatmulInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -406,7 +406,7 @@ TEST_F(TestMatmulInfo, GetForwardOp1) {
/// Description: infer forward op /// Description: infer forward op
/// Expectation: the forward op is right /// Expectation: the forward op is right
TEST_F(TestMatmulInfo, GetForwardOp2) { TEST_F(TestMatmulInfo, GetForwardOp2) {
Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}}; Strategies inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -419,7 +419,7 @@ TEST_F(TestMatmulInfo, GetForwardOp2) {
/// Description: infer virtual_div op /// Description: infer virtual_div op
/// Expectation: the virtual_div op is right /// Expectation: the virtual_div op is right
TEST_F(TestMatmulInfo, GetVirtualDivOp1) { TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -441,7 +441,7 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
/// Description: infer mirror op /// Description: infer mirror op
/// Expectation: the mirror op is right /// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs1) { TEST_F(TestMatmulInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -463,7 +463,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) {
/// Description: infer mirror op /// Description: infer mirror op
/// Expectation: the mirror op is right /// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs2) { TEST_F(TestMatmulInfo, GetMirrorOPs2) {
Strategys inputs = {{2, 4, 1, 16}, {8, 16}}; Strategies inputs = {{2, 4, 1, 16}, {8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy, nullptr); matmul2->Init(strategy, nullptr);
@ -485,7 +485,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) {
/// Description: infer mirror op /// Description: infer mirror op
/// Expectation: the mirror op is right /// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs3) { TEST_F(TestMatmulInfo, GetMirrorOPs3) {
Strategys inputs = {{8, 16}, {2, 4, 1, 16}}; Strategies inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy, nullptr); matmul3->Init(strategy, nullptr);
@ -506,7 +506,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) {
/// Description: infer mirror op /// Description: infer mirror op
/// Expectation: the mirror op is right /// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, GetMirrorOPs4) { TEST_F(TestMatmulInfo, GetMirrorOPs4) {
Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}}; Strategies inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy, nullptr); matmul1->Init(strategy, nullptr);
@ -519,7 +519,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
/// Description: init twice /// Description: init twice
/// Expectation: the mirror op is right /// Expectation: the mirror op is right
TEST_F(TestMatmulInfo, InitTwice) { TEST_F(TestMatmulInfo, InitTwice) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
// init twice // init twice
@ -544,7 +544,7 @@ TEST_F(TestMatmulInfo, InitTwice) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy1) { TEST_F(TestMatmulInfo, CheckStrategy1) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -556,7 +556,7 @@ TEST_F(TestMatmulInfo, CheckStrategy1) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy2) { TEST_F(TestMatmulInfo, CheckStrategy2) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -568,7 +568,7 @@ TEST_F(TestMatmulInfo, CheckStrategy2) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy3) { TEST_F(TestMatmulInfo, CheckStrategy3) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -580,7 +580,7 @@ TEST_F(TestMatmulInfo, CheckStrategy3) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy4) { TEST_F(TestMatmulInfo, CheckStrategy4) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}}; Strategies inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -592,7 +592,7 @@ TEST_F(TestMatmulInfo, CheckStrategy4) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy5) { TEST_F(TestMatmulInfo, CheckStrategy5) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -604,7 +604,7 @@ TEST_F(TestMatmulInfo, CheckStrategy5) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy6) { TEST_F(TestMatmulInfo, CheckStrategy6) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -616,7 +616,7 @@ TEST_F(TestMatmulInfo, CheckStrategy6) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, CheckStrategy7) { TEST_F(TestMatmulInfo, CheckStrategy7) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy, nullptr); Status ret = matmul1->Init(strategy, nullptr);
@ -628,7 +628,7 @@ TEST_F(TestMatmulInfo, CheckStrategy7) {
/// Expectation: return FAILED /// Expectation: return FAILED
TEST_F(TestMatmulInfo, InitFailed) { TEST_F(TestMatmulInfo, InitFailed) {
// matmul4 attr is wrong // matmul4 attr is wrong
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul4->Init(strategy, nullptr); Status ret = matmul4->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestOneHotInfo::SetUp() {
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape1) { TEST_F(TestOneHotInfo, InferDevMatrixShape1) {
Strategys inputs = {{8, 1}, {}, {}}; Strategies inputs = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -76,7 +76,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape1) {
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape2) { TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
Strategys inputs = {{4, 1}, {}, {}}; Strategies inputs = {{4, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -88,7 +88,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape3) { TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
Strategys inputs = {{4, 2}, {}, {}}; Strategies inputs = {{4, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -100,7 +100,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
} }
TEST_F(TestOneHotInfo, InferTensorMap2) { TEST_F(TestOneHotInfo, InferTensorMap2) {
Strategys str = {{8, 1}, {}, {}}; Strategies str = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo, InferTensorMap2) {
} }
TEST_F(TestOneHotInfo, InferSliceShape1) { TEST_F(TestOneHotInfo, InferSliceShape1) {
Strategys str = {{8, 1}, {}, {}}; Strategies str = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo, InferSliceShape1) {
} }
TEST_F(TestOneHotInfo, InferSliceShape2) { TEST_F(TestOneHotInfo, InferSliceShape2) {
Strategys str = {{4, 2}, {}, {}}; Strategies str = {{4, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) {
} }
TEST_F(TestOneHotInfo, InferSliceShape3) { TEST_F(TestOneHotInfo, InferSliceShape3) {
Strategys str = {{2, 2}, {}, {}}; Strategies str = {{2, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -188,7 +188,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) {
} }
TEST_F(TestOneHotInfo, GetMirrorOPs1) { TEST_F(TestOneHotInfo, GetMirrorOPs1) {
Strategys inputs = {{8, 1}, {}, {}}; Strategies inputs = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy, nullptr); Status status = onehot_info->Init(strategy, nullptr);
@ -199,7 +199,7 @@ TEST_F(TestOneHotInfo, GetMirrorOPs1) {
} }
TEST_F(TestOneHotInfo, CheckStrategy1) { TEST_F(TestOneHotInfo, CheckStrategy1) {
Strategys inputs = {{16}, {}, {}}; Strategies inputs = {{16}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = onehot_info->Init(strategy, nullptr); Status ret = onehot_info->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestOneHotInfo2::SetUp() {
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape1) { TEST_F(TestOneHotInfo2, InferDevMatrixShape1) {
Strategys inputs = {{1, 8}, {}, {}}; Strategies inputs = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -76,7 +76,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape1) {
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape2) { TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
Strategys inputs = {{1, 4}, {}, {}}; Strategies inputs = {{1, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -88,7 +88,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape3) { TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
Strategys inputs = {{2, 4}, {}, {}}; Strategies inputs = {{2, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -100,7 +100,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
} }
TEST_F(TestOneHotInfo2, InferTensorMap2) { TEST_F(TestOneHotInfo2, InferTensorMap2) {
Strategys str = {{1, 8}, {}, {}}; Strategies str = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo2, InferTensorMap2) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape1) { TEST_F(TestOneHotInfo2, InferSliceShape1) {
Strategys str = {{1, 8}, {}, {}}; Strategies str = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape1) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape2) { TEST_F(TestOneHotInfo2, InferSliceShape2) {
Strategys str = {{2, 4}, {}, {}}; Strategies str = {{2, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);
@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape3) { TEST_F(TestOneHotInfo2, InferSliceShape3) {
Strategys str = {{2, 2}, {}, {}}; Strategies str = {{2, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy, nullptr); Status status = onehot_info2->Init(strategy, nullptr);

View File

@ -63,7 +63,7 @@ void TestPowInfo::SetUp() {
} }
TEST_F(TestPowInfo, InferDevMatrixShape1) { TEST_F(TestPowInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy, nullptr); pow->Init(strategy, nullptr);
@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
} }
TEST_F(TestPowInfo, InferSliceShape1) { TEST_F(TestPowInfo, InferSliceShape1) {
Strategys str = {{2, 4, 8}, {2, 4, 8}}; Strategies str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy, nullptr); pow->Init(strategy, nullptr);
@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
} }
TEST_F(TestPowInfo, GetTensorLayout1) { TEST_F(TestPowInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 8}, {2, 4, 8}}; Strategies str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy, nullptr); pow->Init(strategy, nullptr);
@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
} }
TEST_F(TestPowInfo, GetForwardOp1) { TEST_F(TestPowInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy, nullptr); pow->Init(strategy, nullptr);
@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
} }
TEST_F(TestPowInfo, GetMirrorOPs1) { TEST_F(TestPowInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy, nullptr); pow->Init(strategy, nullptr);
@ -139,7 +139,7 @@ TEST_F(TestPowInfo, GetMirrorOPs1) {
} }
TEST_F(TestPowInfo, CheckStrategy1) { TEST_F(TestPowInfo, CheckStrategy1) {
Strategys inputs = {{2, 2, 8}, {2, 4, 8}}; Strategies inputs = {{2, 2, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy, nullptr); Status ret = pow->Init(strategy, nullptr);
@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
} }
TEST_F(TestPowInfo, CheckStrategy2) { TEST_F(TestPowInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy, nullptr); Status ret = pow->Init(strategy, nullptr);
@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
} }
TEST_F(TestPowInfo, CheckStrategy3) { TEST_F(TestPowInfo, CheckStrategy3) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy, nullptr); Status ret = pow->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestPReLUInfo::SetUp() {
} }
TEST_F(TestPReLUInfo, InferDevMatrixShape1) { TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 1, 8, 16}, {1}}; Strategies inputs = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
prelu->Init(strategy, nullptr); prelu->Init(strategy, nullptr);
@ -75,7 +75,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
} }
TEST_F(TestPReLUInfo, InferSliceShape1) { TEST_F(TestPReLUInfo, InferSliceShape1) {
Strategys str = {{2, 1, 8, 16}, {1}}; Strategies str = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy, nullptr); prelu->Init(strategy, nullptr);
@ -98,7 +98,7 @@ TEST_F(TestPReLUInfo, InferSliceShape1) {
} }
TEST_F(TestPReLUInfo, GetTensorLayout1) { TEST_F(TestPReLUInfo, GetTensorLayout1) {
Strategys str = {{2, 1, 8, 16}, {1}}; Strategies str = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy, nullptr); prelu->Init(strategy, nullptr);
@ -122,7 +122,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) {
} }
TEST_F(TestPReLUInfo, GetMirrorOPs1) { TEST_F(TestPReLUInfo, GetMirrorOPs1) {
Strategys str = {{2, 1, 2, 2}, {1}}; Strategies str = {{2, 1, 2, 2}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy, nullptr); prelu->Init(strategy, nullptr);
MirrorOps mirror_ops = prelu->mirror_ops(); MirrorOps mirror_ops = prelu->mirror_ops();
@ -139,14 +139,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs1) {
TEST_F(TestPReLUInfo, CheckStrategy1) { TEST_F(TestPReLUInfo, CheckStrategy1) {
// Success: {{2,1,8,16},{1}} // Success: {{2,1,8,16},{1}}
Strategys inputs = {{2, 1, 8, 16}}; Strategies inputs = {{2, 1, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu->Init(strategy, nullptr); Status ret = prelu->Init(strategy, nullptr);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
} }
TEST_F(TestPReLUInfo, CheckStrategy2) { TEST_F(TestPReLUInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8, 16}, {4}}; Strategies inputs = {{2, 4, 8, 16}, {4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu->Init(strategy, nullptr); Status ret = prelu->Init(strategy, nullptr);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
@ -169,7 +169,7 @@ TEST_F(TestPReLUInfo, AutoStrategy1) {
} }
TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) { TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
Strategys inputs = {{128, 1}, {1}}; Strategies inputs = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
prelu_2d->Init(strategy, nullptr); prelu_2d->Init(strategy, nullptr);
@ -180,7 +180,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
} }
TEST_F(TestPReLUInfo, InferSliceShape_2d1) { TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
Strategys str = {{128, 1}, {1}}; Strategies str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy, nullptr); prelu_2d->Init(strategy, nullptr);
@ -203,7 +203,7 @@ TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
} }
TEST_F(TestPReLUInfo, GetTensorLayout_2d1) { TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
Strategys str = {{128, 1}, {1}}; Strategies str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy, nullptr); prelu_2d->Init(strategy, nullptr);
@ -227,7 +227,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
} }
TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) { TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {
Strategys str = {{128, 1}, {1}}; Strategies str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy, nullptr); prelu_2d->Init(strategy, nullptr);
MirrorOps mirror_ops = prelu_2d->mirror_ops(); MirrorOps mirror_ops = prelu_2d->mirror_ops();
@ -244,14 +244,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {
TEST_F(TestPReLUInfo, CheckStrategy_2d1) { TEST_F(TestPReLUInfo, CheckStrategy_2d1) {
// Success: {{2,1,8,16},{1}} // Success: {{2,1,8,16},{1}}
Strategys inputs = {{128, 1}}; Strategies inputs = {{128, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu_2d->Init(strategy, nullptr); Status ret = prelu_2d->Init(strategy, nullptr);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
} }
TEST_F(TestPReLUInfo, CheckStrategy_2d2) { TEST_F(TestPReLUInfo, CheckStrategy_2d2) {
Strategys inputs = {{128, 4}, {4}}; Strategies inputs = {{128, 4}, {4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu_2d->Init(strategy, nullptr); Status ret = prelu_2d->Init(strategy, nullptr);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);

View File

@ -68,7 +68,7 @@ void TestReduceSumInfo::SetUp() {
} }
TEST_F(TestReduceSumInfo, InferDevMatrixShape1) { TEST_F(TestReduceSumInfo, InferDevMatrixShape1) {
Strategys inputs = {{4, 8, 1}}; Strategies inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -79,7 +79,7 @@ TEST_F(TestReduceSumInfo, InferDevMatrixShape1) {
} }
TEST_F(TestReduceSumInfo, InferSliceShape1) { TEST_F(TestReduceSumInfo, InferSliceShape1) {
Strategys str = {{4, 8, 1}}; Strategies str = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -100,7 +100,7 @@ TEST_F(TestReduceSumInfo, InferSliceShape1) {
} }
TEST_F(TestReduceSumInfo, GetTensorLayout1) { TEST_F(TestReduceSumInfo, GetTensorLayout1) {
Strategys str = {{4, 8, 1}}; Strategies str = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -121,7 +121,7 @@ TEST_F(TestReduceSumInfo, GetTensorLayout1) {
} }
TEST_F(TestReduceSumInfo, GetForwardOp1) { TEST_F(TestReduceSumInfo, GetForwardOp1) {
Strategys inputs = {{4, 8, 1}}; Strategies inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -132,7 +132,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp1) {
} }
TEST_F(TestReduceSumInfo, GetForwardOp2) { TEST_F(TestReduceSumInfo, GetForwardOp2) {
Strategys inputs = {{4, 4, 2}}; Strategies inputs = {{4, 4, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -156,7 +156,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp2) {
} }
TEST_F(TestReduceSumInfo, GetMirrorOPs1) { TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
Strategys inputs = {{4, 8, 1}}; Strategies inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -168,7 +168,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
} }
TEST_F(TestReduceSumInfo, GetMirrorOPs2) { TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
Strategys inputs = {{4, 4, 1}}; Strategies inputs = {{4, 4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy, nullptr); reduce_sum->Init(strategy, nullptr);
@ -187,7 +187,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy1) { TEST_F(TestReduceSumInfo, CheckStrategy1) {
Strategys inputs = {{2, 2, 8, 16}}; Strategies inputs = {{2, 2, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy, nullptr); Status ret = reduce_sum->Init(strategy, nullptr);
@ -195,7 +195,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy1) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy2) { TEST_F(TestReduceSumInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy, nullptr); Status ret = reduce_sum->Init(strategy, nullptr);
@ -203,7 +203,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy2) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy3) { TEST_F(TestReduceSumInfo, CheckStrategy3) {
Strategys inputs = {{4, 4, 2}}; Strategies inputs = {{4, 4, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy, nullptr); Status ret = reduce_sum->Init(strategy, nullptr);
@ -211,7 +211,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy3) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy4) { TEST_F(TestReduceSumInfo, CheckStrategy4) {
Strategys inputs = {{4, 8, 1}}; Strategies inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy, nullptr); Status ret = reduce_sum->Init(strategy, nullptr);

View File

@ -68,7 +68,7 @@ void TestReshapeInfo::SetUp() {
} }
TEST_F(TestReshapeInfo, InferDevMatrixShape1) { TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
Strategys inputs = {{4, 1, 1, 1}}; Strategies inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -79,7 +79,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
} }
TEST_F(TestReshapeInfo, InferDevMatrixShape2) { TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
Strategys inputs = {{32, 1, 1, 1}}; Strategies inputs = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -90,7 +90,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
} }
TEST_F(TestReshapeInfo, InferSliceShape1) { TEST_F(TestReshapeInfo, InferSliceShape1) {
Strategys str = {{4, 1, 1, 1}}; Strategies str = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -111,7 +111,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) {
} }
TEST_F(TestReshapeInfo, InferSliceShape2) { TEST_F(TestReshapeInfo, InferSliceShape2) {
Strategys str = {{32, 1, 1, 1}}; Strategies str = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -132,7 +132,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) {
} }
TEST_F(TestReshapeInfo, GetTensorLayout1) { TEST_F(TestReshapeInfo, GetTensorLayout1) {
Strategys str = {{4, 1, 1, 1}}; Strategies str = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -153,7 +153,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
} }
TEST_F(TestReshapeInfo, GetTensorLayout2) { TEST_F(TestReshapeInfo, GetTensorLayout2) {
Strategys str = {{32, 1, 1, 1}}; Strategies str = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -174,7 +174,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) {
} }
TEST_F(TestReshapeInfo, GetForwardOp1) { TEST_F(TestReshapeInfo, GetForwardOp1) {
Strategys inputs = {{4, 1, 1, 1}}; Strategies inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -185,7 +185,7 @@ TEST_F(TestReshapeInfo, GetForwardOp1) {
} }
TEST_F(TestReshapeInfo, GetMirrorOPs1) { TEST_F(TestReshapeInfo, GetMirrorOPs1) {
Strategys inputs = {{4, 1, 1, 1}}; Strategies inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy, nullptr); reshape->Init(strategy, nullptr);
@ -197,7 +197,7 @@ TEST_F(TestReshapeInfo, GetMirrorOPs1) {
} }
TEST_F(TestReshapeInfo, CheckStrategy1) { TEST_F(TestReshapeInfo, CheckStrategy1) {
Strategys inputs = {{1, 4, 8}}; Strategies inputs = {{1, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy, nullptr); Status ret = reshape->Init(strategy, nullptr);
@ -205,7 +205,7 @@ TEST_F(TestReshapeInfo, CheckStrategy1) {
} }
TEST_F(TestReshapeInfo, CheckStrategy2) { TEST_F(TestReshapeInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy, nullptr); Status ret = reshape->Init(strategy, nullptr);
@ -213,7 +213,7 @@ TEST_F(TestReshapeInfo, CheckStrategy2) {
} }
TEST_F(TestReshapeInfo, CheckStrategy3) { TEST_F(TestReshapeInfo, CheckStrategy3) {
Strategys inputs = {{4, 1, 1, 1}}; Strategies inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy, nullptr); Status ret = reshape->Init(strategy, nullptr);

View File

@ -64,7 +64,7 @@ void TestSoftmaxLoss::SetUp() {
} }
TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) { TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -75,7 +75,7 @@ TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) {
} }
TEST_F(TestSoftmaxLoss, InferSliceShape1) { TEST_F(TestSoftmaxLoss, InferSliceShape1) {
Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategies str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -104,7 +104,7 @@ TEST_F(TestSoftmaxLoss, InferSliceShape1) {
} }
TEST_F(TestSoftmaxLoss, GetTensorLayout1) { TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategies str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -133,7 +133,7 @@ TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
} }
TEST_F(TestSoftmaxLoss, GetForwardOp1) { TEST_F(TestSoftmaxLoss, GetForwardOp1) {
Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -144,7 +144,7 @@ TEST_F(TestSoftmaxLoss, GetForwardOp1) {
} }
TEST_F(TestSoftmaxLoss, GetMirrorOPs1) { TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -156,7 +156,7 @@ TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
} }
TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) { TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {
Strategys inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}}; Strategies inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy, nullptr); loss->Init(strategy, nullptr);
@ -176,7 +176,7 @@ TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {
TEST_F(TestSoftmaxLoss, CheckStrategy1) { TEST_F(TestSoftmaxLoss, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = loss->Init(strategy, nullptr); Status ret = loss->Init(strategy, nullptr);
@ -185,7 +185,7 @@ TEST_F(TestSoftmaxLoss, CheckStrategy1) {
TEST_F(TestSoftmaxLoss, CheckStrategy2) { TEST_F(TestSoftmaxLoss, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = loss->Init(strategy, nullptr); Status ret = loss->Init(strategy, nullptr);

View File

@ -68,7 +68,7 @@ void TestSoftmaxInfo::SetUp() {
} }
TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) { TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy, nullptr); softmax->Init(strategy, nullptr);
@ -79,7 +79,7 @@ TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) {
} }
TEST_F(TestSoftmaxInfo, InferSliceShape1) { TEST_F(TestSoftmaxInfo, InferSliceShape1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
softmax->Init(strategy, nullptr); softmax->Init(strategy, nullptr);
@ -100,7 +100,7 @@ TEST_F(TestSoftmaxInfo, InferSliceShape1) {
} }
TEST_F(TestSoftmaxInfo, GetTensorLayout1) { TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
softmax->Init(strategy, nullptr); softmax->Init(strategy, nullptr);
@ -121,7 +121,7 @@ TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
} }
TEST_F(TestSoftmaxInfo, GetForwardOp1) { TEST_F(TestSoftmaxInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy, nullptr); softmax->Init(strategy, nullptr);
@ -132,7 +132,7 @@ TEST_F(TestSoftmaxInfo, GetForwardOp1) {
} }
TEST_F(TestSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy, nullptr); softmax->Init(strategy, nullptr);
@ -145,7 +145,7 @@ TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {
TEST_F(TestSoftmaxInfo, CheckStrategy1) { TEST_F(TestSoftmaxInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy, nullptr); Status ret = softmax->Init(strategy, nullptr);
@ -154,7 +154,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy1) {
TEST_F(TestSoftmaxInfo, CheckStrategy2) { TEST_F(TestSoftmaxInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy, nullptr); Status ret = softmax->Init(strategy, nullptr);
@ -163,7 +163,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy2) {
TEST_F(TestSoftmaxInfo, CheckStrategy3) { TEST_F(TestSoftmaxInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy, nullptr); Status ret = softmax->Init(strategy, nullptr);
@ -172,7 +172,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy3) {
TEST_F(TestSoftmaxInfo, InitFailed1) { TEST_F(TestSoftmaxInfo, InitFailed1) {
// softmax2's axis is wrong // softmax2's axis is wrong
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax2->Init(strategy, nullptr); Status ret = softmax2->Init(strategy, nullptr);
@ -181,7 +181,7 @@ TEST_F(TestSoftmaxInfo, InitFailed1) {
TEST_F(TestSoftmaxInfo, InitFailed2) { TEST_F(TestSoftmaxInfo, InitFailed2) {
// dev num is wrong // dev num is wrong
Strategys inputs = {{2, 4, 1, 100}}; Strategies inputs = {{2, 4, 1, 100}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax2->Init(strategy, nullptr); Status ret = softmax2->Init(strategy, nullptr);

View File

@ -63,7 +63,7 @@ void TestTanhInfo::SetUp() {
} }
TEST_F(TestTanhInfo, InferDevMatrixShape1) { TEST_F(TestTanhInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy, nullptr); tanh->Init(strategy, nullptr);
@ -74,7 +74,7 @@ TEST_F(TestTanhInfo, InferDevMatrixShape1) {
} }
TEST_F(TestTanhInfo, InferSliceShape1) { TEST_F(TestTanhInfo, InferSliceShape1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tanh->Init(strategy, nullptr); tanh->Init(strategy, nullptr);
@ -95,7 +95,7 @@ TEST_F(TestTanhInfo, InferSliceShape1) {
} }
TEST_F(TestTanhInfo, GetTensorLayout1) { TEST_F(TestTanhInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 1, 16}}; Strategies str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tanh->Init(strategy, nullptr); tanh->Init(strategy, nullptr);
@ -116,7 +116,7 @@ TEST_F(TestTanhInfo, GetTensorLayout1) {
} }
TEST_F(TestTanhInfo, GetForwardOp1) { TEST_F(TestTanhInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy, nullptr); tanh->Init(strategy, nullptr);
@ -127,7 +127,7 @@ TEST_F(TestTanhInfo, GetForwardOp1) {
} }
TEST_F(TestTanhInfo, GetMirrorOPs1) { TEST_F(TestTanhInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy, nullptr); tanh->Init(strategy, nullptr);
@ -140,7 +140,7 @@ TEST_F(TestTanhInfo, GetMirrorOPs1) {
TEST_F(TestTanhInfo, CheckStrategy1) { TEST_F(TestTanhInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy, nullptr); Status ret = tanh->Init(strategy, nullptr);
@ -149,7 +149,7 @@ TEST_F(TestTanhInfo, CheckStrategy1) {
TEST_F(TestTanhInfo, CheckStrategy2) { TEST_F(TestTanhInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy, nullptr); Status ret = tanh->Init(strategy, nullptr);
@ -158,7 +158,7 @@ TEST_F(TestTanhInfo, CheckStrategy2) {
TEST_F(TestTanhInfo, CheckStrategy3) { TEST_F(TestTanhInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
Strategys inputs = {{2, 4, 1, 16}}; Strategies inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy, nullptr); Status ret = tanh->Init(strategy, nullptr);

View File

@ -66,7 +66,7 @@ void TestTensorAddInfo::SetUp() {
} }
TEST_F(TestTensorAddInfo, InferDevMatrixShape1) { TEST_F(TestTensorAddInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy, nullptr); tensor_add->Init(strategy, nullptr);
@ -77,7 +77,7 @@ TEST_F(TestTensorAddInfo, InferDevMatrixShape1) {
} }
TEST_F(TestTensorAddInfo, InferSliceShape1) { TEST_F(TestTensorAddInfo, InferSliceShape1) {
Strategys str = {{2, 4, 4}, {2, 4, 4}}; Strategies str = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tensor_add->Init(strategy, nullptr); tensor_add->Init(strategy, nullptr);
@ -101,7 +101,7 @@ TEST_F(TestTensorAddInfo, InferSliceShape1) {
} }
TEST_F(TestTensorAddInfo, GetTensorLayout1) { TEST_F(TestTensorAddInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 4}, {2, 4, 4}}; Strategies str = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tensor_add->Init(strategy, nullptr); tensor_add->Init(strategy, nullptr);
@ -125,7 +125,7 @@ TEST_F(TestTensorAddInfo, GetTensorLayout1) {
} }
TEST_F(TestTensorAddInfo, GetForwardOp1) { TEST_F(TestTensorAddInfo, GetForwardOp1) {
Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy, nullptr); tensor_add->Init(strategy, nullptr);
@ -136,7 +136,7 @@ TEST_F(TestTensorAddInfo, GetForwardOp1) {
} }
TEST_F(TestTensorAddInfo, GetMirrorOPs1) { TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy, nullptr); tensor_add->Init(strategy, nullptr);
@ -148,7 +148,7 @@ TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy1) { TEST_F(TestTensorAddInfo, CheckStrategy1) {
Strategys inputs = {{2, 4, 4}, {2, 6, 4}}; Strategies inputs = {{2, 4, 4}, {2, 6, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy, nullptr); Status ret = tensor_add->Init(strategy, nullptr);
@ -156,7 +156,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy1) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy2) { TEST_F(TestTensorAddInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy, nullptr); Status ret = tensor_add->Init(strategy, nullptr);
@ -164,7 +164,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy2) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy3) { TEST_F(TestTensorAddInfo, CheckStrategy3) {
Strategys inputs = {{2, 4, 6}}; Strategies inputs = {{2, 4, 6}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy, nullptr); Status ret = tensor_add->Init(strategy, nullptr);
@ -172,7 +172,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy3) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy4) { TEST_F(TestTensorAddInfo, CheckStrategy4) {
Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy, nullptr); Status ret = tensor_add->Init(strategy, nullptr);
@ -224,7 +224,7 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
} }
TEST_F(TestTensorAddInfo, mirror_ops) { TEST_F(TestTensorAddInfo, mirror_ops) {
Strategys inputs = {{1, 8}, {4, 1}}; Strategies inputs = {{1, 8}, {4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add1->Init(strategy, nullptr); tensor_add1->Init(strategy, nullptr);

View File

@ -65,7 +65,7 @@ void TestTmpIdentityInfo::SetUp() {
} }
TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) { TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) {
Strategys inputs = {{2, 4, 8, 16}}; Strategies inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
identity_ptr->Init(strategy, nullptr); identity_ptr->Init(strategy, nullptr);
@ -76,7 +76,7 @@ TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) {
} }
TEST_F(TestTmpIdentityInfo, InferSliceShape1) { TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
Strategys str = {{2, 4, 8, 16}}; Strategies str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
identity_ptr->Init(strategy, nullptr); identity_ptr->Init(strategy, nullptr);
@ -97,7 +97,7 @@ TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
} }
TEST_F(TestTmpIdentityInfo, GetTensorLayout1) { TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {
Strategys str = {{2, 4, 8, 16}}; Strategies str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
identity_ptr->Init(strategy, nullptr); identity_ptr->Init(strategy, nullptr);
@ -119,7 +119,7 @@ TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {
TEST_F(TestTmpIdentityInfo, CheckStrategy1) { TEST_F(TestTmpIdentityInfo, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = identity_ptr->Init(strategy, nullptr); Status ret = identity_ptr->Init(strategy, nullptr);
@ -128,7 +128,7 @@ TEST_F(TestTmpIdentityInfo, CheckStrategy1) {
TEST_F(TestTmpIdentityInfo, CheckStrategy2) { TEST_F(TestTmpIdentityInfo, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
Strategys inputs = {{2, 4, 8}}; Strategies inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = identity_ptr->Init(strategy, nullptr); Status ret = identity_ptr->Init(strategy, nullptr);

View File

@ -68,7 +68,7 @@ void TestTransposeInfo::SetUp() {
} }
TEST_F(TestTransposeInfo, InferDevMatrixShape1) { TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
Strategys inputs = {{4, 8}}; Strategies inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -79,7 +79,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
} }
TEST_F(TestTransposeInfo, InferDevMatrixShape2) { TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
Strategys inputs = {{4, 1}}; Strategies inputs = {{4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -90,7 +90,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
} }
TEST_F(TestTransposeInfo, InferSliceShape1) { TEST_F(TestTransposeInfo, InferSliceShape1) {
Strategys str = {{4, 8}}; Strategies str = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -111,7 +111,7 @@ TEST_F(TestTransposeInfo, InferSliceShape1) {
} }
TEST_F(TestTransposeInfo, GetTensorLayout1) { TEST_F(TestTransposeInfo, GetTensorLayout1) {
Strategys str = {{4, 8}}; Strategies str = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -132,7 +132,7 @@ TEST_F(TestTransposeInfo, GetTensorLayout1) {
} }
TEST_F(TestTransposeInfo, GetForwardOp1) { TEST_F(TestTransposeInfo, GetForwardOp1) {
Strategys inputs = {{4, 8}}; Strategies inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -143,7 +143,7 @@ TEST_F(TestTransposeInfo, GetForwardOp1) {
} }
TEST_F(TestTransposeInfo, GetMirrorOPs1) { TEST_F(TestTransposeInfo, GetMirrorOPs1) {
Strategys inputs = {{4, 8}}; Strategies inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy, nullptr); transpose->Init(strategy, nullptr);
@ -155,7 +155,7 @@ TEST_F(TestTransposeInfo, GetMirrorOPs1) {
} }
TEST_F(TestTransposeInfo, CheckStrategy1) { TEST_F(TestTransposeInfo, CheckStrategy1) {
Strategys inputs = {{1, 4, 8}}; Strategies inputs = {{1, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy, nullptr); Status ret = transpose->Init(strategy, nullptr);
@ -163,7 +163,7 @@ TEST_F(TestTransposeInfo, CheckStrategy1) {
} }
TEST_F(TestTransposeInfo, CheckStrategy2) { TEST_F(TestTransposeInfo, CheckStrategy2) {
Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy, nullptr); Status ret = transpose->Init(strategy, nullptr);
@ -171,7 +171,7 @@ TEST_F(TestTransposeInfo, CheckStrategy2) {
} }
TEST_F(TestTransposeInfo, CheckStrategy3) { TEST_F(TestTransposeInfo, CheckStrategy3) {
Strategys inputs = {{4, 8}}; Strategies inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy, nullptr); Status ret = transpose->Init(strategy, nullptr);

View File

@ -237,9 +237,9 @@ TEST_F(TestStepParallel, ExtractStrategy) {
std::vector<ValuePtr> elements = {val1, val2}; std::vector<ValuePtr> elements = {val1, val2};
ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements); ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
attrs["in_strategy"] = strategy_tuple; attrs["in_strategy"] = strategy_tuple;
Strategys strategy_expect = {v1, v2}; Strategies strategy_expect = {v1, v2};
StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]); StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]);
Strategys strategy_test = strategy->GetInputDim(); Strategies strategy_test = strategy->GetInputDim();
ASSERT_EQ(strategy_expect, strategy_test); ASSERT_EQ(strategy_expect, strategy_test);
} }
@ -380,7 +380,7 @@ TEST_F(TestStepParallel, OperatorInstance) {
prim->set_attr("transpose_b", transpose_b); prim->set_attr("transpose_b", transpose_b);
auto attrs = prim->attrs(); auto attrs = prim->attrs();
// create strategy // create strategy
Strategys strategy = {{2, 2}, {2, 4}}; Strategies strategy = {{2, 2}, {2, 4}};
StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
// create shape // create shape
Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}}; Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
@ -557,7 +557,7 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
prim->set_attr("transpose_b", transpose_b); prim->set_attr("transpose_b", transpose_b);
auto attrs = prim->attrs(); auto attrs = prim->attrs();
// create strategy // create strategy
Strategys strategy = {{2, 2}, {2, 4}}; Strategies strategy = {{2, 2}, {2, 4}};
StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
// create shape // create shape
Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}}; Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};

View File

@ -35,7 +35,7 @@ TEST_F(TestStrategy, GetInputNumber) {
int32_t stage = 1; int32_t stage = 1;
Dimensions dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
Dimensions dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
Strategys inputs = {dimension1, dimension2}; Strategies inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
int32_t number_test = strategy.GetInputNumber(); int32_t number_test = strategy.GetInputNumber();
@ -46,7 +46,7 @@ TEST_F(TestStrategy, GetInputStage) {
int32_t stage = 1; int32_t stage = 1;
Dimensions dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
Dimensions dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
Strategys inputs = {dimension1, dimension2}; Strategies inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
int32_t stage_test = strategy.GetInputStage(); int32_t stage_test = strategy.GetInputStage();
@ -57,10 +57,10 @@ TEST_F(TestStrategy, GetInputDim) {
int32_t stage = 1; int32_t stage = 1;
Dimensions dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
Dimensions dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
Strategys inputs = {dimension1, dimension2}; Strategies inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
Strategys inputs_test = strategy.GetInputDim(); Strategies inputs_test = strategy.GetInputDim();
ASSERT_EQ(inputs, inputs_test); ASSERT_EQ(inputs, inputs_test);
} }
@ -68,10 +68,10 @@ TEST_F(TestStrategy, IsEqual) {
int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0; int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
Dimensions dimension1 = {8, 1}; Dimensions dimension1 = {8, 1};
Dimensions dimension2 = {1, 8}; Dimensions dimension2 = {1, 8};
Strategys inputs1 = {dimension1}; Strategies inputs1 = {dimension1};
Strategys inputs2 = {dimension1}; Strategies inputs2 = {dimension1};
Strategys inputs3 = {dimension2}; Strategies inputs3 = {dimension2};
Strategys inputs4 = {dimension1, dimension2}; Strategies inputs4 = {dimension1, dimension2};
StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1); StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1);
StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2); StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2);

View File

@ -62,7 +62,7 @@ void TestConstructOperator::SetUp() {
MatMulInfoPtr matmul = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_1, outputs_shape_1, attr_1); MatMulInfoPtr matmul = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_1, outputs_shape_1, attr_1);
Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul->Init(strategy, nullptr); matmul->Init(strategy, nullptr);
Shape tensor_shape = {512, 1024}; Shape tensor_shape = {512, 1024};

View File

@ -62,7 +62,7 @@ void TestVirtualDatasetInfo::SetUp() {
} }
TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) { TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
Strategys inputs = {{16, 1}, {16, 1}, {16, 1}}; Strategies inputs = {{16, 1}, {16, 1}, {16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
virtual_dataset->Init(strategy, nullptr); virtual_dataset->Init(strategy, nullptr);
Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape(); Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape();
@ -72,7 +72,7 @@ TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
} }
TEST_F(TestVirtualDatasetInfo, GetForwardOp1) { TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; Strategies inputs = {{8, 1}, {8, 1}, {8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
virtual_dataset->Init(strategy, nullptr); virtual_dataset->Init(strategy, nullptr);
@ -83,7 +83,7 @@ TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
} }
TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) { TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) {
Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; Strategies inputs = {{8, 1}, {8, 1}, {8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
virtual_dataset->Init(strategy, nullptr); virtual_dataset->Init(strategy, nullptr);