modify strategys to strategies
parent df68f7cb92
commit 3b7fc4db29
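Note (reviewer context, not part of the patch): the commit is a mechanical rename of the container type alias Strategys to the correctly spelled Strategies across the auto-parallel code; no behavior changes. The alias itself is assumed to be declared roughly as in the sketch below (its defining header is not shown in this diff), so only the spelling of the type and of the variables declared with it changes.

    #include <cstdint>
    #include <vector>

    // Hedged sketch; these definitions are assumptions, not copied from the patch.
    using Dimensions = std::vector<int64_t>;     // split factor per tensor dimension of one input
    using Strategies = std::vector<Dimensions>;  // one Dimensions entry per operator input
    // Example: shard input 0 over a 2x4 device grid and leave input 1 unsplit.
    static const Strategies kExample = {{2, 4}, {1, 1}};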
@@ -1674,7 +1674,7 @@ Status CostGraph::InitReshapeStrategy() {
 if (stra.empty()) {
 MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
 }
-Strategys stra_inputs = {stra};
+Strategies stra_inputs = {stra};
 StrategyPtr reshape_stra =
 std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
 reshape_info->set_strategy(reshape_stra);
@@ -80,9 +80,9 @@ Dimensions PrepareMatMulStrategy(const std::shared_ptr<Graph> &graph, const size
 return s;
 }

-Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+Strategies PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t iter_graph, const size_t iter_ops) {
-Strategys strategies;
+Strategies strategies;
 auto attrs = ops[iter_ops]->attrs();
 bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
 bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
@@ -94,8 +94,8 @@ Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<s
 return strategies;
 }

-Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
-Strategys strategies;
+Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
+Strategies strategies;
 strategies.push_back(*s);
 Dimensions s_biasadd;
 s_biasadd.push_back(s->at(1));
@@ -103,9 +103,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
 return strategies;
 }

-Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra) {
-Strategys stra;
+Strategies stra;

 auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
 auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
@@ -128,9 +128,9 @@ Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &
 return stra;
 }

-Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra) {
-Strategys strategies;
+Strategies strategies;
 strategies.push_back(basic_stra);
 std::vector<int64_t> axis_list;
 string axis_name = AXIS;
@@ -173,8 +173,8 @@ Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 return strategies;
 }

-Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
-Strategys strategies;
+Strategies PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
+Strategies strategies;

 // The Dimension size of the first input tensor of OneHot should be 2 even its Shape size is 1. Using the division of
 // the number of devices and the partition parts of the first dimension.
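Note (reviewer sketch, not part of the patch): the comment above says a 1-D OneHot input still receives a 2-D strategy whose second element comes from dividing the device count by the partition of the first dimension. A hedged illustration of that arithmetic with made-up numbers:

    // Illustrative values only; variable names are not from the patch.
    int64_t device_num = 8;      // assumed devices in the stage
    int64_t first_dim_cut = 2;   // assumed partition parts of the first dimension
    Dimensions onehot_input_strategy = {first_dim_cut, device_num / first_dim_cut};  // -> {2, 4}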
@@ -206,8 +206,8 @@ Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, c
 return strategies;
 }

-Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
-Strategys strategies;
+Strategies PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
+Strategies strategies;

 auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
 Dimensions index(output_shape.size() - 1, 0);
@@ -302,8 +302,8 @@ Dimensions PrepareGatherV2OutputStrategy(const std::vector<std::shared_ptr<Opera
 return strategie;
 }

-Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions s) {
 int64_t axis = 0;
 auto iter = ops[iter_ops]->attrs().find(AXIS);
 if (iter != ops[iter_ops]->attrs().end()) {
@@ -323,15 +323,15 @@ Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &o

 s[LongToSize(axis_index)] = 1;

-Strategys strategies;
+Strategies strategies;
 strategies.push_back(s);
 return strategies;
 }

-Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
+Strategies PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops) {
-Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
+Strategies strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
 if (strategies.size() < 1) {
 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
 }
@@ -380,9 +380,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
 return strategies;
 }

-Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
+Strategies MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops) {
 if (ops.empty()) {
 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
 }
@@ -394,7 +394,7 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
 }

 StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-Strategys strategies;
+Strategies strategies;
 for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
 if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
 MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@@ -437,9 +437,9 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
 return strategies;
 }

-Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+Strategies MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops) {
 if (ops.empty()) {
 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
 }
@@ -448,7 +448,7 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
 }

 StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-Strategys strategies;
+Strategies strategies;
 size_t max_device_num = g_device_manager->DeviceNum();
 size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
 for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
@@ -500,9 +500,9 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
 return strategies;
 }

-Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
+Strategies MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops) {
 if (ops.empty()) {
 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
 }
@@ -511,7 +511,7 @@ Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
 }

 StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-Strategys strategies;
+Strategies strategies;
 for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
 if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
 MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
@@ -540,7 +540,7 @@ Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,

 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
 StrategyPtr origin_strategy = op->strategy();
-Strategys strategies;
+Strategies strategies;

 for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
 Dimensions s;
@@ -561,8 +561,8 @@ void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
 op->SetSelectedStrategyAndCost(sp, op->selected_cost());
 }

-Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+Strategies PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t iter_graph, const size_t iter_ops) {
 if (ops.empty()) {
 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
 }
@@ -593,7 +593,7 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const std::shared_ptr<std::vector<size_t>> &index_list) {
 for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
-Strategys strategies;
+Strategies strategies;
 size_t iter_graph = index_list->at(iter_ops);
 if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
 strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
@@ -618,7 +618,7 @@ void ModifyParamSharingOpsStrategy(const std::vector<std::shared_ptr<OperatorInf
 } else {
 continue;
 }
-Strategys strategies;
+Strategies strategies;
 Dimensions str1, str2;
 str1 = str_j;
 size_t num_device_used = 1;
@@ -1107,9 +1107,9 @@ Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<O
 return s;
 }

-Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra) {
-Strategys stra;
+Strategies stra;
 MS_EXCEPTION_IF_NULL(ops[iter_ops]);

 if (iter_ops >= ops.size()) {
@@ -1157,9 +1157,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
 }

 // Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
-Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 const Dimensions s) {
-Strategys stra;
+Strategies stra;

 size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
 size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
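Note (reviewer sketch, not part of the patch): for broadcasting operators such as the ones handled by CheckBroadcast above, the usual alignment rule is to pad the lower-rank input's strategy with leading 1s (unsplit dimensions) until it matches the higher-rank input. A self-contained illustration of that rule, assuming trailing dimensions are the ones that correspond; this is not the ApplyBroadcast implementation:

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;

    // Pad a strategy for a lower-rank input with leading 1s so it lines up with
    // the higher-rank input's dimensions.
    Dimensions AlignForBroadcast(const Dimensions &s, size_t target_dim) {
      if (s.size() >= target_dim) return s;
      Dimensions aligned(target_dim, 1);
      for (size_t i = 0; i < s.size(); ++i) {
        aligned[target_dim - s.size() + i] = s[i];
      }
      return aligned;
    }
    // e.g. AlignForBroadcast({4, 2}, 4) -> {1, 1, 4, 2}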
@@ -1256,10 +1256,10 @@ Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 }

 // Check whether the operator can be divided by the current strategy.
-Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 const Dimensions basic_stra) {
 Dimensions s_empty = {};
-Strategys stra;
+Strategies stra;

 // For all the input tensors.
 for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
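Note (reviewer sketch, not part of the patch): "divisible" here means every dimension of an input's shape must be a multiple of the corresponding element of the strategy, otherwise the tensor cannot be cut that way. A minimal stand-alone check of that property (not the CheckDivisible implementation):

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;
    using Shape = std::vector<int64_t>;

    // True when every shape dimension splits evenly by the strategy.
    bool StrategyDividesShape(const Shape &shape, const Dimensions &stra) {
      if (shape.size() != stra.size()) return false;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (stra[i] <= 0 || shape[i] % stra[i] != 0) return false;
      }
      return true;
    }
    // e.g. StrategyDividesShape({64, 128}, {4, 2}) -> true; {64, 100} with {4, 3} -> false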
@@ -1302,7 +1302,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &gra

 for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
 size_t iter_ops = no_stra_op_list->at(iter_list - 1);
-Strategys stra;
+Strategies stra;
 Dimensions s;
 size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
 if (incoming_op_index != SIZE_MAX) {
@@ -1405,7 +1405,7 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_pt

 for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
 auto iter_ops = no_stra_op_list->at(iter_list - 1);
-Strategys stra;
+Strategies stra;
 Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
 if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
 s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
@@ -1444,7 +1444,7 @@ void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,

 for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) {
 auto iter_ops = no_stra_op_list->at(iter_list);
-Strategys stra;
+Strategies stra;
 Dimensions s;

 size_t max_dim_num = 0;
@@ -34,38 +34,38 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
 const std::vector<std::vector<size_t>> &shared_tensors_ops);
 Dimensions PrepareMatMulStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_graph, bool transpose_a,
 bool transpose_b, size_t iter_op_inputs);
-Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+Strategies PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t iter_graph, const size_t iter_ops);
-Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
-Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
+Strategies PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra);
-Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra);
-Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
-Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
+Strategies PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
+Strategies PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops);
-Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
+Strategies PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
 Dimensions PrepareGatherV2OutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t incoming_op_index);
-Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions s);
-Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
+Strategies MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops);
-Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
+Strategies CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
 Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
 size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor);
-Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
-Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+Strategies CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
+Strategies MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops);
-Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
+Strategies MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
 const size_t iter_ops);
 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
-Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+Strategies PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t iter_graph, const size_t iter_ops);
 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const std::shared_ptr<std::vector<size_t>> &index_list);
@@ -97,8 +97,8 @@ Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<Operato
 const size_t incoming_op_index, Dimensions s);
 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const size_t iter_ops, const size_t incoming_op_index);
-Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
 Dimensions basic_stra);
 void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 const std::vector<std::vector<std::string>> &input_tensor_names,
@@ -108,7 +108,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions input_strategy = stra.at(0);

 for (auto &element : axis_) {
@@ -243,7 +243,7 @@ Status CumOpBase::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 if (input_strategy.size() <= LongToSize(axis_)) {
 MS_LOG(ERROR) << "The " << name_ << " input strategy length: " << input_strategy.size() << ", is less ot equal to "
@@ -292,7 +292,7 @@ Status CumOpBase::InferMirrorOps() {
 }

 Status ActivationBase::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);

 dev_matrix_shape_ = input_strategy;
@@ -256,8 +256,8 @@ class ExpandDimsInfo : public ActivationOther {

 private:
 int64_t positive_axis_ = -1;
-Strategys inputs_strategy_;
-Strategys outputs_strategy_;
+Strategies inputs_strategy_;
+Strategies outputs_strategy_;
 };

 class SqueezeInfo : public ActivationOther {
@@ -27,7 +27,7 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 // The strategy for each input tensor must be equal
-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 for (size_t i = 1; i < strategies.size(); ++i) {
 if (strategies[i] != strategies[0]) {
 return FAILED;
@@ -39,7 +39,7 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status AddNInfo::InferDevMatrixShape() {
 dev_matrix_shape_.clear();

-Strategys strategies = strategy_->GetInputDim();
+Strategies strategies = strategy_->GetInputDim();
 if (strategies.empty()) {
 return SUCCESS;
 }
@@ -59,7 +59,7 @@ Status AddNInfo::InferTensorMap() {
 sub_tensor_map.push_back(dev_size - i - 1);
 }

-Strategys strategies = strategy_->GetInputDim();
+Strategies strategies = strategy_->GetInputDim();
 for (size_t i = 0; i < strategies.size(); ++i) {
 inputs_tensor_map_.push_back(sub_tensor_map);
 }
@@ -53,9 +53,9 @@ Shapes ArithmeticBase::InferExpandShape() {
 return input_shapes;
 }

-Strategys ExpandStrategy(const StrategyPtr &strategy) {
-Strategys expand_strategy;
-Strategys stra = strategy->GetInputDim();
+Strategies ExpandStrategy(const StrategyPtr &strategy) {
+Strategies expand_strategy;
+Strategies stra = strategy->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 Dimensions sub_b_strategy = stra.at(1);
 size_t input_a_size = sub_a_strategy.size();
@@ -77,7 +77,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }
 Shapes input_shapes = InferExpandShape();
-Strategys expand_strategy = ExpandStrategy(strategy);
+Strategies expand_strategy = ExpandStrategy(strategy);
 Dimensions sub_a_strategy = expand_strategy.at(0);
 Dimensions sub_b_strategy = expand_strategy.at(1);
 Shape input_a_shape = input_shapes.at(0);
@@ -93,7 +93,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status ArithmeticBase::InferDevMatrixShape() {
-Strategys expand_strategy = ExpandStrategy(strategy_);
+Strategies expand_strategy = ExpandStrategy(strategy_);
 Dimensions sub_a_strategy = expand_strategy.at(0);
 Dimensions sub_b_strategy = expand_strategy.at(1);
 Shape dev_shape;
@@ -150,10 +150,10 @@ void ArithmeticBase::ReComputeBatchSplitFlagList() {

 Status ArithmeticBase::InferTensorMap() {
 Shape tensor_map_index;
-Strategys expand_strategy = ExpandStrategy(strategy_);
+Strategies expand_strategy = ExpandStrategy(strategy_);
 Dimensions sub_a_expand_strategy = expand_strategy.at(0);
 Dimensions sub_b_expand_strategy = expand_strategy.at(1);
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 Dimensions sub_b_strategy = stra.at(1);
 for (size_t i = 0; i < sub_a_expand_strategy.size(); ++i) {
@@ -251,7 +251,7 @@ Status LerpInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 // validate strategy of weight
-Strategys expand_strategy = ExpandStrategy(strategy);
+Strategies expand_strategy = ExpandStrategy(strategy);
 Dimensions expand_begin_strategy = expand_strategy.at(0);
 Dimensions expand_end_strategy = expand_strategy.at(1);
 Dimensions expand_cmp_strategy;
@@ -286,7 +286,7 @@ Status LerpInfo::InferDevMatrixShape() {
 }

 dev_matrix_shape_.clear();
-Strategys expand_strategy = ExpandStrategy(strategy_);
+Strategies expand_strategy = ExpandStrategy(strategy_);
 Dimensions expand_start_strategy = expand_strategy.at(0);
 Dimensions expand_end_strategy = expand_strategy.at(1);
 auto strategies = strategy_->GetInputDim();
@@ -316,9 +316,9 @@ Status LerpInfo::InferTensorMap() {
 return FAILED;
 }
 // Generate tensor map for 'weight'
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions weight_strategy = stra.at(2);
-Strategys expand_strategy = ExpandStrategy(strategy_);
+Strategies expand_strategy = ExpandStrategy(strategy_);
 Dimensions expand_start_strategy = expand_strategy.at(0);
 Dimensions expand_weight_strategy = ExpandShape(expand_start_strategy, weight_strategy);
 Shape dev_shape = dev_matrix_shape_;
@@ -32,7 +32,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 size_t strategy_size = strategy->GetInputNumber();
-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 for (size_t i = 0; i < strategy_size; ++i) {
 Shape sub_strategy = stra.at(i);
 size_t strategy_len = sub_strategy.size();
@@ -122,7 +122,7 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {

 std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_id) {
 StrategyPtr sp;
-Strategys strategy;
+Strategies strategy;
 ComputeBatchSplitFlagList();

 for (size_t i = 0; i < inputs_shape_.size(); i++) {
@@ -181,7 +181,7 @@ Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (stra[0][1] != 1) {
 MS_LOG(ERROR) << name_ << ": The second dimension of the first input can not be split, but got " << stra[0][1];
 return FAILED;
@@ -195,7 +195,7 @@ Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status CheckValidInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 dev_matrix_shape_.push_back(stra[0][0]);
 return SUCCESS;
 }
@@ -315,7 +315,7 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 if (first_input_strategy.size() < 2) {
 MS_LOG(EXCEPTION) << name_ << ": The size of first input strategy can not smaller than 2, but got "
@@ -30,7 +30,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
 return FAILED;
 }
-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 Dimensions sub_b_strategy = stra.at(1);
 int64_t channel_a_strategy = sub_a_strategy.at(1);
@@ -43,7 +43,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status BiasAddInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 dev_matrix_shape_ = sub_a_strategy;
 return SUCCESS;
@@ -57,7 +57,7 @@ void BiasAddInfo::ReComputeBatchSplitFlagList() {
 Status BiasAddInfo::InferTensorMap() {
 TensorMap sub_a_tensor_map;
 TensorMap sub_b_tensor_map;
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 size_t sub_a_strategy_size = sub_a_strategy.size();
 for (size_t i = 0; i < sub_a_strategy_size; ++i) {
@@ -88,7 +88,7 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
 MS_LOG(INFO) << name_ << " : Generate strategies success.";

 for (auto &sp : sp_vector) {
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions input0_strategy = sp->GetInputDim()[0];
 tmp_strategy.push_back(input0_strategy); // input0

@@ -23,7 +23,7 @@ Status BoundingBoxEncodeInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 Dimensions input_a_strategy = strategies[0];
 Dimensions input_b_strategy = strategies[1];
 if (input_a_strategy != input_b_strategy) {
@@ -45,7 +45,7 @@ Status BoundingBoxEncodeInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status BoundingBoxEncodeInfo::InferDevMatrixShape() {
-Strategys strategies = strategy_->GetInputDim();
+Strategies strategies = strategy_->GetInputDim();
 Dimensions input_a_strategy = strategies.at(0);

 dev_matrix_shape_.clear();
@@ -106,7 +106,7 @@ Status BoundingBoxEncodeInfo::PrepareStrategy(int64_t stage_id, int64_t split_nu

 Dimensions input0_partitions = {split_num, 1};
 Dimensions input1_partitions = {split_num, 1};
-Strategys strategies = {input0_partitions, input1_partitions};
+Strategies strategies = {input0_partitions, input1_partitions};
 (*sp) = std::make_shared<Strategy>(stage_id, strategies);
 return SUCCESS;
 }
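Note (reviewer sketch, not part of the patch): the hunk above shows the construction pattern the rename touches everywhere: build a Strategies value (one Dimensions per input), wrap it in a Strategy, and read it back later with GetInputDim(). The outline below uses a minimal stand-in class so it compiles on its own; the real Strategy lives in the parallel strategy headers and has more members.

    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    using Dimensions = std::vector<int64_t>;
    using Strategies = std::vector<Dimensions>;

    // Stand-in for parallel::Strategy, only to show the round trip.
    class Strategy {
     public:
      Strategy(int64_t stage, Strategies inputs) : stage_(stage), inputs_(std::move(inputs)) {}
      const Strategies &GetInputDim() const { return inputs_; }
      int64_t GetInputStage() const { return stage_; }

     private:
      int64_t stage_;
      Strategies inputs_;
    };
    using StrategyPtr = std::shared_ptr<Strategy>;

    int main() {
      Strategies strategies = {{2, 1}, {2, 1}};  // e.g. split_num = 2, as in PrepareStrategy above
      StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
      return sp->GetInputDim().size() == 2 ? 0 : 1;
    }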
@@ -141,7 +141,7 @@ std::vector<StrategyPtr> BroadcastToInfo::GenerateOpStrategies(int64_t stage_id)
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
 tmp_strategy.push_back(first_input_strategy);
@@ -172,7 +172,7 @@ std::vector<StrategyPtr> ConcatInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
 tmp_strategy.push_back(first_input_strategy);
@@ -936,7 +936,7 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
 auto search_mode = parallel_context->strategy_search_mode();
 // generate data parallel strategy when the search mode is not sharding propagation
 if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
-Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
+Strategies strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
 StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
 sp_vector.push_back(data_parallel_sp);
 return sp_vector;
@@ -960,7 +960,7 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys replace_strategy;
+Strategies replace_strategy;
 Dimensions tmp_strategy = sp->GetInputDim()[0];
 if (tmp_strategy.size() != 5) {
 MS_LOG(EXCEPTION) << name_ << ": The size of first tmp strategy must be 5, but got " << tmp_strategy.size();
@@ -29,7 +29,7 @@ Status CropAndResizeInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 auto x_strategy = strategies.at(0);
 auto boxes_strategy = strategies.at(1);
 auto index_strategy = strategies.at(2);
@@ -39,7 +39,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (stra.size() != 1) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
 return FAILED;
@@ -61,7 +61,7 @@ Status DropoutDoMaskInfo::InferDevMatrixShape() {
 return FAILED;
 }

-Strategys strategy = strategy_->GetInputDim();
+Strategies strategy = strategy_->GetInputDim();
 if (strategy.empty()) {
 MS_LOG(ERROR) << name_ << ": The strategy is empty";
 return FAILED;
@@ -110,11 +110,11 @@ std::vector<StrategyPtr> DropoutDoMaskInfo::GenerateOpStrategies(int64_t stage_i
 return sp_vector;
 }

-std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> DropoutDoMaskInfo::GenerateBatchStrategies() {
 Dimensions strategy(inputs_shape_[0].size() - 1, 1);
 (void)strategy.insert(strategy.begin(), stage_device_size_);
-Strategys strategy_v = {strategy};
-return std::make_shared<Strategys>(strategy_v);
+Strategies strategy_v = {strategy};
+return std::make_shared<Strategies>(strategy_v);
 }

 size_t GetNonMonadInputSize(const CNodePtr &cnode) {
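Note (reviewer sketch, not part of the patch): GenerateBatchStrategies above builds a pure data-parallel strategy: the batch dimension of the first input gets stage_device_size_ and every other dimension gets 1, and the result holds a single entry, matching the hunk. A worked example with assumed numbers:

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;
    using Strategies = std::vector<Dimensions>;

    int main() {
      // Assumed: first input has rank 3 (e.g. [batch, seq, hidden]) and the stage has 8 devices.
      size_t input_rank = 3;
      int64_t stage_device_size = 8;
      Dimensions strategy(input_rank - 1, 1);                // {1, 1}
      strategy.insert(strategy.begin(), stage_device_size);  // {8, 1, 1}: split only the batch axis
      Strategies strategy_v = {strategy};
      return strategy_v[0][0] == 8 ? 0 : 1;
    }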
@@ -38,7 +38,7 @@ class DropoutDoMaskInfo : public OperatorInfo {

 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
 Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
 void ReplaceNodeInputOrAttrs() override;

@@ -48,7 +48,7 @@ Status DSDMatmulInfo::CheckStrategy(const StrategyPtr &strategy) {
 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
 return FAILED;
 }
-Strategys stras = strategy->GetInputDim();
+Strategies stras = strategy->GetInputDim();
 if (stras.size() != DSD_MATMUL_INPUTS_SIZE) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 3.";
 return FAILED;
@@ -89,7 +89,7 @@ Status DSDMatmulInfo::CheckStrategy(const StrategyPtr &strategy) {
 * device matrix use the strategy0.
 */
 Status DSDMatmulInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 input_strategy_ = input_strategy;
 dev_matrix_shape_ = input_strategy;
@@ -172,7 +172,7 @@ std::vector<StrategyPtr> DSDMatmulInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions input_w1_strategy = sp->GetInputDim()[0];
 Dimensions input_w2_strategy = input_w1_strategy;
 Dimensions input_v_strategy = {input_w1_strategy[0], input_w1_strategy[1], 1, 1};
@@ -175,7 +175,7 @@ Status GatherInfo::GetAttrs() {
 // output's strategy: [a, b, ..., c] or [1, a, b, ..., c]
 // dev_matrix: [a, b, ..., c]
 // can not support repeated calculation
-Status GatherInfo::CheckManualSplit(const Strategys &strategy) {
+Status GatherInfo::CheckManualSplit(const Strategies &strategy) {
 if (strategy.size() != 2) {
 MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
 return FAILED;
@@ -270,7 +270,7 @@ Status GatherInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {

 // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
 // otherwise return false
-bool GatherInfo::ShardBatchAndAxis(const Strategys &strategy) const {
+bool GatherInfo::ShardBatchAndAxis(const Strategies &strategy) const {
 if (axis_ != 0) {
 return false;
 }
@@ -1015,7 +1015,7 @@ std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
 return sp_vector;
 }

-std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> GatherInfo::GenerateBatchStrategies() {
 if (GetAttrs() != SUCCESS) {
 MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
 }
@@ -1029,8 +1029,8 @@ std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
 for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
 index_strategy.push_back(1);
 }
-Strategys strategy_v = {param_strategy, index_strategy};
-return std::make_shared<Strategys>(strategy_v);
+Strategies strategy_v = {param_strategy, index_strategy};
+return std::make_shared<Strategies>(strategy_v);
 }
 } // namespace parallel
 } // namespace mindspore
@@ -45,7 +45,7 @@ class GatherInfo : public OperatorInfo {
 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
 Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
 const std::vector<int64_t> &index_offsets() const { return index_offsets_; }

@@ -64,7 +64,7 @@ class GatherInfo : public OperatorInfo {
 void InferOutputsTensorMap();
 void InferTensorMapForManualSplit();
 Status ComputeReplaceGraph(const CNodePtr &cnode);
-Status CheckManualSplit(const Strategys &strategy);
+Status CheckManualSplit(const Strategies &strategy);
 Status CheckSplitAxisStrategy(const StrategyPtr &strategy);
 void SetAttribute(const StrategyPtr &strategy);
 Status GetManualSplitAttr();
@@ -73,7 +73,7 @@ class GatherInfo : public OperatorInfo {
 Status InferBias();
 Status InferOffset();
 Status InferGroup();
-bool ShardBatchAndAxis(const Strategys &strategy) const;
+bool ShardBatchAndAxis(const Strategies &strategy) const;
 Shape InferOutputsTensorMapSplitAxis();

 int64_t axis_;
@@ -206,7 +206,7 @@ std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
 tmp_strategy.push_back(first_input_strategy);
@@ -139,7 +139,7 @@ std::vector<StrategyPtr> GatherNdInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions indices_strategy = sp->GetInputDim()[0];
 Dimensions input_strategy(inputs_shape_[0].size(), 1);
 tmp_strategy.push_back(input_strategy);
@@ -105,7 +105,7 @@ Status GetNextInfo::InferDevMatrixShape() {
 }

 Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
-Strategys stras = strategy->GetInputDim();
+Strategies stras = strategy->GetInputDim();
 for (Dimensions stra : stras) {
 if (stra.size() != 0) {
 MS_LOG(ERROR) << name_ << " : Invalid strategy.";
@@ -219,7 +219,7 @@ void GetNextInfo::InferReplaceOps() {
 Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

 std::vector<StrategyPtr> GetNextInfo::GenerateOpStrategies(int64_t stage_id) {
-Strategys stra;
+Strategies stra;
 StrategyPtr sp = std::make_shared<Strategy>(stage_id, stra);
 std::vector<StrategyPtr> sp_vector;
 sp_vector.push_back(sp);
@@ -60,7 +60,7 @@ class GetNextInfo : public OperatorInfo {
 int64_t output_num_ = 0;
 int64_t shard_num_ = 1;
 std::string shared_name_;
-Strategys dataset_strategy_;
+Strategies dataset_strategy_;
 Shape dev_matrix_shape_origin_;
 };
 } // namespace parallel
@@ -27,7 +27,7 @@ Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy";
 return FAILED;
 }
-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 auto x_strategy = strategies.at(0);
 auto input_v_strategy = strategies.at(1);
 if (x_strategy[0] != 1 || input_v_strategy[0] != 1) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
Strategys strategies = strategy->GetInputDim();
|
Strategies strategies = strategy->GetInputDim();
|
||||||
if (strategies[0][1] != 1 || strategies[1][1] != 1) {
|
if (strategies[0][1] != 1 || strategies[1][1] != 1) {
|
||||||
MS_LOG(ERROR) << name_ << ": Only supports shard the 0th dimension of each input tensor, but got strategy "
|
MS_LOG(ERROR) << name_ << ": Only supports shard the 0th dimension of each input tensor, but got strategy "
|
||||||
<< StrategyToString(strategies);
|
<< StrategyToString(strategies);
|
||||||
|
@ -34,7 +34,7 @@ Status IOUInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IOUInfo::InferDevMatrixShape() {
|
Status IOUInfo::InferDevMatrixShape() {
|
||||||
Strategys strategise = strategy_->GetInputDim();
|
Strategies strategise = strategy_->GetInputDim();
|
||||||
int64_t dev1 = strategise[0][0];
|
int64_t dev1 = strategise[0][0];
|
||||||
int64_t dev0 = strategise[1][0];
|
int64_t dev0 = strategise[1][0];
|
||||||
|
|
||||||
|
|
|
@@ -31,7 +31,7 @@ Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 int64_t axis_index = axis_;
 if (axis_ < 0) {
@@ -61,7 +61,7 @@ Status LayerNormInfo::GetAttrs() {

 Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
 MS_EXCEPTION_IF_NULL(strategy);
-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (stra.size() != LAYER_NORM_INPUT_SIZE) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
 return FAILED;
@@ -116,7 +116,7 @@ Status LayerNormInfo::InferDevMatrixShape() {
 MS_LOG(ERROR) << name_ << ": The strategy is null";
 return FAILED;
 }
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 if (stra.empty()) {
 MS_LOG(ERROR) << name_ << ": The strategy is empty";
 return FAILED;
@@ -187,7 +187,7 @@ Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyP
 MS_LOG(ERROR) << name_ << ": Invalid strategy";
 return FAILED;
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions input_strategy = sp->GetInputDim()[0];
 Dimensions gamma_strategy = input_strategy;
 (void)gamma_strategy.erase(gamma_strategy.begin(),
@@ -85,14 +85,14 @@ std::vector<StrategyPtr> LinSpaceInfo::GenerateOpStrategies(int64_t stage_id) {
 return sp_vector;
 }

-std::shared_ptr<Strategys> LinSpaceInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> LinSpaceInfo::GenerateBatchStrategies() {
 if (InferAttrs() != SUCCESS) {
 MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
 }

 int64_t dev_num = g_device_manager->stage_device_num();
-Strategys strategies = {Dimensions{dev_num}};
-return std::make_shared<Strategys>(strategies);
+Strategies strategies = {Dimensions{dev_num}};
+return std::make_shared<Strategies>(strategies);
 }

 int64_t LinSpaceInfo::GetSplitNum() {
@@ -38,7 +38,7 @@ class LinSpaceInfo : public OperatorInfo {

 Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

 protected:
@@ -32,7 +32,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 Dimensions label_strategy = stra.at(1);
 if (input_strategy != label_strategy) {
@@ -69,7 +69,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
 }

 Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 dev_matrix_shape_ = input_strategy;
 return SUCCESS;
@@ -43,7 +43,7 @@ namespace parallel {
 * Only bs and num_heads can be splited, thus the q[0] should at least be size_per_head,
 * q[1] should at least be seq_len // 16. The strategy check can use bs/head from attrs.
 */
-Status MatmulDDSInfo::CheckStrategys(const Strategys &stras) {
+Status MatmulDDSInfo::CheckStrategys(const Strategies &stras) {
 if (stras.size() != MATMUL_DDS_INPUTS_SIZE) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 4.";
 return FAILED;
@@ -106,7 +106,7 @@ Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy.";
 return FAILED;
 }
-Strategys stras = strategy->GetInputDim();
+Strategies stras = strategy->GetInputDim();
 if (CheckStrategys(stras) != SUCCESS) {
 return FAILED;
 }
@@ -117,7 +117,7 @@ Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
 * device matrix is extended by the strategy0.
 */
 Status MatmulDDSInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 input_strategy_ = input_strategy;
 dev_matrix_shape_ = input_strategy;
@@ -288,7 +288,7 @@ std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions q_strategy = sp->GetInputDim()[0];
 Dimensions k_strategy = q_strategy;
 Dimensions local_mask_strategy = {1, q_strategy[0], 1, 1};
@@ -50,7 +50,7 @@ class MatmulDDSInfo : public OperatorInfo {
 Status GetAttrs() override;
 Status InferAsLossDivisor() override { return SUCCESS; }
 Status ComputeReplaceGraph(const CNodePtr &cnode);
-Status CheckStrategys(const Strategys &stras);
+Status CheckStrategys(const Strategies &stras);

 private:
 Dimensions input_strategy_;
@@ -152,7 +152,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions mat_a_strategy = stra.at(0);
 Dimensions mat_b_strategy = stra.at(1);

@@ -209,7 +209,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
 return FAILED;
 }

-Strategys in_stra = strategy_->GetInputDim();
+Strategies in_stra = strategy_->GetInputDim();
 Dimensions x_strategy = in_stra.at(0);
 Dimensions w_strategy = in_stra.at(1);

@@ -222,7 +222,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
 in_shard_c = w_strategy[1];
 }

-Strategys out_stra = out_strategy->GetInputDim();
+Strategies out_stra = out_strategy->GetInputDim();
 Dimensions output_strategy = out_stra[0];

 int64_t out_shard_a_or_ab = output_strategy[0];
@@ -248,7 +248,7 @@ Status MatMul::CheckOutputStrategy(const StrategyPtr &out_strategy) {
 }

 Status MatMulBase::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions mat_a_strategy = stra.at(0);
 Dimensions mat_b_strategy = stra.at(1);

@@ -460,7 +460,7 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys replace_strategy;
+Strategies replace_strategy;
 Dimensions tmp_strategy = sp->GetInputDim()[0];
 Dimensions mat_a_strategy = tmp_strategy;
 mat_a_strategy.pop_back();
@@ -483,11 +483,11 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
 return sp_vector;
 }

-std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
 Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
 (void)batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
-Strategys strategy_v = {batch_strategy, batch_strategy};
-return std::make_shared<Strategys>(strategy_v);
+Strategies strategy_v = {batch_strategy, batch_strategy};
+return std::make_shared<Strategies>(strategy_v);
 }

 Status MatMulBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
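The BatchMatMul hunk above only renames the container type; the batch-strategy construction itself is unchanged. For orientation, here is a minimal standalone sketch of the same pattern. The Strategies/Dimensions aliases and the helper name are assumptions (the real aliases are defined in the strategy headers, which are not part of this diff):

#include <cstdint>
#include <memory>
#include <vector>

// Assumed aliases; the real definitions live outside this diff.
using Dimensions = std::vector<int64_t>;
using Strategies = std::vector<Dimensions>;

// Mirrors the BatchMatMulInfo::GenerateBatchStrategies pattern shown above:
// shard only the leading (batch) dimension across the devices of the stage.
std::shared_ptr<Strategies> MakeBatchMatMulStrategy(size_t input_rank, int64_t stage_device_size) {
  Dimensions batch_strategy(input_rank - 1, 1);                            // non-batch dims stay whole
  (void)batch_strategy.insert(batch_strategy.begin(), stage_device_size);  // split the batch dim
  Strategies strategy_v = {batch_strategy, batch_strategy};                // same split for both inputs
  return std::make_shared<Strategies>(strategy_v);
}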
@@ -86,7 +86,7 @@ class BatchMatMulInfo : public MatMul {
 : MatMul(name, inputs_shape, outputs_shape, attrs) {}
 ~BatchMatMulInfo() override = default;

-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -70,7 +70,7 @@ Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status OneHotInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);

 if (axis_ == 0) {
@@ -235,11 +235,11 @@ std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {

 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

-std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
 Dimensions strategy = {stage_device_size_, 1};
 Dimensions empty_strategy;
-Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
-return std::make_shared<Strategys>(strategy_v);
+Strategies strategy_v = {strategy, empty_strategy, empty_strategy};
+return std::make_shared<Strategies>(strategy_v);
 }

 Shapes OneHotInfo::InferParamStrategy(const Shapes &default_strategy) {
@@ -39,7 +39,7 @@ class OneHotInfo : public OperatorInfo {
 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
 Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 Shapes InferParamStrategy(const Shapes &default_strategy) override;

 protected:
@@ -91,7 +91,7 @@ struct OutStrategyValueRegister {
 } out_regist;
 } // namespace

-std::string StrategyToString(const Strategys &strategy) {
+std::string StrategyToString(const Strategies &strategy) {
 std::string strategy_str = "";
 strategy_str += "(";
 for (size_t i = 0; i < strategy.size(); ++i) {
@@ -130,7 +130,7 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape

 size_t strategy_size = strategy->GetInputNumber();
 size_t inputs_shape_size = inputs_shape.size();
-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (strategy_size != inputs_shape_size) {
 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
 << " is not equal to inputs size: " << inputs_shape_size;
@@ -790,7 +790,7 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
 return slice_shape;
 }

-Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) {
+Status InferSliceShapeByStrategy(const Strategies &strategys, const Shapes &shapes, Shapes *slice_shapes) {
 if (slice_shapes == nullptr) {
 MS_LOG(ERROR) << "The slice_shapes is null.";
 return FAILED;
@@ -829,7 +829,7 @@ Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shape
 return SUCCESS;
 }

-Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
+Status OperatorInfo::InferSliceShape(const Strategies &inputs_strategy, const Strategies &outputs_strategy,
 Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) {
 if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) {
 MS_LOG(ERROR) << name_ << ": The slice_shape is null.";
@@ -1061,8 +1061,8 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op,
 succ_edges_ = update_pre_edges;
 }

-std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
+std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
 const std::vector<bool> &split_flag_list) {
 if (shapes.size() != split_flag_list.size()) {
 MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : "
 << shapes.size();
@@ -1070,7 +1070,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
 }
 CheckGlobalDeviceManager();
 int64_t dev_num = g_device_manager->stage_device_num();
-Strategys strategy_v;
+Strategies strategy_v;
 for (size_t i = 0; i != shapes.size(); i++) {
 if (shapes[i].empty()) {
 MS_LOG(INFO) << "Elements of shapes is empty.";
@@ -1084,7 +1084,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
 strategy_v.push_back(element);
 }
 }
-return std::make_shared<Strategys>(strategy_v);
+return std::make_shared<Strategies>(strategy_v);
 }

 void OperatorInfo::ReComputeBatchSplitFlagList() {
@@ -1122,12 +1122,12 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
 return FAILED;
 }
 }
-Strategys stras(inputs_partitions);
+Strategies stras(inputs_partitions);
 (*sp) = std::make_shared<Strategy>(stage_id, stras);
 return SUCCESS;
 }

-std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategies() {
 if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
 MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
 }
@@ -1212,7 +1212,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs

 // second, get the correct strategy for input0
 for (auto &sp : *sp_vector) {
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions input0_strategy = sp->GetInputDim()[0];
 size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();

@@ -1261,7 +1261,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input

 // second, get the correct strategy for input1
 for (auto &sp : *sp_vector) {
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 tmp_strategy.push_back(sp->GetInputDim()[0]); // input0

 Dimensions input1_strategy = sp->GetInputDim()[1];
@@ -1503,7 +1503,7 @@ Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inpu
 [stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) {
 auto sp_strategies = sp->GetInputDim();
 auto sp_sub_strategy = sp_strategies.at(0);
-Strategys strategies(splittable_inputs);
+Strategies strategies(splittable_inputs);
 for (size_t i = 0; i < strategies.size(); ++i) {
 for (size_t j = 0; j < strategies[i].size(); ++j) {
 if (splittable_inputs[i][j] == 0) {
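The GenerateBatchStrategiesBySplitFlag hunks above touch only the container type. As a rough picture of what that helper computes, every input gets an all-ones strategy and flagged inputs have their first dimension sharded across the stage's devices. This is a hedged sketch under the same assumed aliases, with the error handling and the device-manager lookup simplified away (dev_num stands in for the value normally obtained from the global device manager):

#include <cstdint>
#include <memory>
#include <vector>

// Assumed aliases; the real definitions live outside this diff.
using Dimensions = std::vector<int64_t>;
using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;
using Strategies = std::vector<Dimensions>;

// Simplified sketch of the split-flag helper shown in the hunks above.
std::shared_ptr<Strategies> BatchStrategiesBySplitFlag(const Shapes &shapes,
                                                       const std::vector<bool> &split_flag_list,
                                                       int64_t dev_num) {
  Strategies strategy_v;
  for (size_t i = 0; i < shapes.size(); ++i) {
    Dimensions element(shapes[i].size(), 1);  // keep every dimension whole by default
    if (split_flag_list[i] && !element.empty()) {
      element[0] = dev_num;  // shard the batch dimension of flagged inputs
    }
    strategy_v.push_back(element);
  }
  return std::make_shared<Strategies>(strategy_v);
}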
@@ -97,7 +97,7 @@ class OperatorInfo {
 virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
 Shapes GenerateParamStrategy(const Shapes &default_strategy);

-virtual std::shared_ptr<Strategys> GenerateBatchStrategies();
+virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
 virtual void ReComputeBatchSplitFlagList();
 void ComputeBatchSplitFlagList();

@@ -242,7 +242,7 @@ class OperatorInfo {
 // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output
 // is used for grad and overload the function. If the output is a scalar, need to override the function too.
 virtual Status InferAsLossDivisor();
-Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
+Status InferSliceShape(const Strategies &inputs_strategy, const Strategies &outputs_strategy,
 Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
 void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
 int64_t GetIntAttr(const std::string &attr_name);
@@ -351,9 +351,9 @@ void AddCommOpParamFlag(const CNodePtr &comm_node);
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
-std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
+std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
 const std::vector<bool> &split_flag_list);
-std::string StrategyToString(const Strategys &strategy);
+std::string StrategyToString(const Strategies &strategy);
 void PrintStrategy(const StrategyPtr &strategy);
 Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
 const Shapes &splittable_inputs,
@@ -150,7 +150,7 @@ std::vector<StrategyPtr> StackInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
 tmp_strategy.push_back(first_input_strategy);
@@ -37,7 +37,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
 return FAILED;
 }
-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy size.";
 return FAILED;
@@ -53,7 +53,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
 * device matrix is same with the strategy matrix
 */
 Status PReLUInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 input_strategy_ = input_strategy;
 dev_matrix_shape_ = input_strategy;
@@ -36,7 +36,7 @@ Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 Dimensions input_strategy = strategies[0];
 auto is_shard = [](int64_t val) -> bool { return val != 1; };
 if (std::any_of(input_strategy.begin(), input_strategy.end(), is_shard)) {
@@ -66,7 +66,7 @@ Status RandomChoiceWithMaskInfo::InferTensorMap() {

 std::vector<StrategyPtr> RandomChoiceWithMaskInfo::GenerateOpStrategies(int64_t stage_id) {
 Dimensions input_partitions(inputs_shape_[0].size(), 1);
-Strategys strategies = {input_partitions};
+Strategies strategies = {input_partitions};
 std::vector<StrategyPtr> sp_vector;
 (void)sp_vector.emplace_back(std::make_shared<Strategy>(stage_id, strategies));
 return sp_vector;
@@ -65,7 +65,7 @@ Status RangeInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status RangeInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 dev_matrix_shape_ = stra[0];
 split_num_ = stra[0][0];
 return SUCCESS;
@@ -33,7 +33,7 @@ namespace parallel {
 Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

 Status ReduceMethod::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);

 dev_matrix_shape_ = input_strategy;
@@ -400,10 +400,10 @@ Status ReduceMethod::InferTensorInfo() {

 // infer slice shape
 Shapes inputs_slice_shape, outputs_slice_shape;
-Strategys inputs_strategy = strategy_->GetInputDim();
+Strategies inputs_strategy = strategy_->GetInputDim();
 Dimensions output_strategy = InferOutputStrategy();

-Strategys outputs_strategy = {output_strategy};
+Strategies outputs_strategy = {output_strategy};
 if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
 return FAILED;
 }
@@ -478,7 +478,7 @@ Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) {
 std::vector<int64_t> dim_list = reduce_dim();
 MS_ASSERT(dim_list.size() == 1);

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 MS_ASSERT(stra.size() == 1);
 Shape input_strategy = stra.at(0);
 MS_ASSERT(dim_list.at(0) < input_strategy.size());
@@ -510,10 +510,10 @@ Status ArgMaxWithValueInfo::InferTensorInfo() {

 // infer slice shape
 Shapes inputs_slice_shape, outputs_slice_shape;
-Strategys inputs_strategy = strategy_->GetInputDim();
+Strategies inputs_strategy = strategy_->GetInputDim();
 Dimensions output_strategy = InferOutputStrategy();

-Strategys outputs_strategy = {output_strategy, output_strategy};
+Strategies outputs_strategy = {output_strategy, output_strategy};
 if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
 return FAILED;
 }
@@ -627,7 +627,7 @@ Status ArgmaxInfo::CheckStrategy(const StrategyPtr &strategy) {
 std::vector<int64_t> dim_list = reduce_dim();
 MS_ASSERT(dim_list.size() == 1);

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 MS_ASSERT(stra.size() == 1);
 Shape input_strategy = stra.at(0);
 MS_ASSERT(dim_list.at(0) < input_strategy.size());
@@ -689,7 +689,7 @@ Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions sub_a_strategy = stra.at(0);
 Dimensions sub_b_strategy = stra.at(1);
 Shape input_a_shape = inputs_shape_.at(0);
@@ -707,7 +707,7 @@ Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status SquareSumAllInfo::InferDevMatrixShape() {
-Strategys strategy = strategy_->GetInputDim();
+Strategies strategy = strategy_->GetInputDim();
 Dimensions sub_a_strategy = strategy.at(0);
 Shape dev_shape;
 for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
@@ -736,10 +736,10 @@ Status SquareSumAllInfo::InferTensorInfo() {

 // infer slice shape
 Shapes inputs_slice_shape, outputs_slice_shape;
-Strategys inputs_strategy = strategy_->GetInputDim();
+Strategies inputs_strategy = strategy_->GetInputDim();
 Dimensions output_strategy = InferOutputStrategy();

-Strategys outputs_strategy = {output_strategy, output_strategy};
+Strategies outputs_strategy = {output_strategy, output_strategy};
 if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
 return FAILED;
 }
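The ReduceMethod, ArgMaxWithValue and SquareSumAll hunks above all feed Strategies values into InferSliceShape. The relation those calls rely on is an element-wise division of each tensor shape by its strategy; a minimal sketch of that relation follows (not the library's actual implementation, same assumed aliases as before):

#include <cstdint>
#include <vector>

// Assumed aliases; the real definitions live outside this diff.
using Dimensions = std::vector<int64_t>;
using Shape = std::vector<int64_t>;

// Each dimension of the local slice is the full dimension divided by the number
// of shards the strategy assigns to it (the strategy is assumed to divide the
// shape evenly).
Shape SliceShapeOf(const Shape &tensor_shape, const Dimensions &strategy) {
  Shape slice_shape;
  for (size_t i = 0; i < tensor_shape.size(); ++i) {
    slice_shape.push_back(tensor_shape[i] / strategy[i]);
  }
  return slice_shape;
}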
@@ -38,7 +38,7 @@ Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 Dimensions input_strategy = stra.at(0);
 if (input_strategy[1] != 1) {
 MS_LOG(ERROR) << name_ << "The second dimension is not splitable.";
@@ -65,7 +65,7 @@ std::vector<StrategyPtr> ReLUV2Info::GenerateOpStrategies(int64_t stage_id) {
 }

 Status ReLUV2Info::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_strategy = stra.at(0);

 dev_matrix_shape_ = input_strategy;
@@ -37,7 +37,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStr
 * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
 */
 Status ReshapeInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 input_strategy_ = stra.at(0);
 dev_matrix_shape_ = stra.at(0);
 return SUCCESS;
@@ -195,8 +195,8 @@ Status ReshapeInfo::InferTensorMap() {
 * the output tensor strategy is the same as input tensor strategy
 * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
 */
-Strategys ReshapeInfo::GetOutputsStrategy() {
-Strategys outputs_strategy;
+Strategies ReshapeInfo::GetOutputsStrategy() {
+Strategies outputs_strategy;
 Dimensions strategy;
 for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
 strategy.push_back(1);
@@ -269,8 +269,8 @@ Status ReshapeInfo::InferTensorInfo() {
 }

 Shapes inputs_slice_shape, outputs_slice_shape;
-Strategys inputs_strategy = strategy_->GetInputDim();
-Strategys outputs_strategy = GetOutputsStrategy();
+Strategies inputs_strategy = strategy_->GetInputDim();
+Strategies outputs_strategy = GetOutputsStrategy();
 if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
 return FAILED;
 }
@@ -460,7 +460,7 @@ Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<Stra
 MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
 return FAILED;
 }
-Strategys stra_inputs = {stra};
+Strategies stra_inputs = {stra};
 StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
 if (is_next_reshape) {
 SetOutputLayout(pre_out_tensor_info.tensor_layout());
@@ -87,7 +87,7 @@ class ReshapeInfo : public OperatorInfo {
 Status InferDevMatrixShape() override;
 Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
 Status GetAttrs() override;
-Strategys GetOutputsStrategy();
+Strategies GetOutputsStrategy();

 private:
 Status GetParameterInput();
@@ -40,7 +40,7 @@ Status ROIAlignInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys strategies = strategy->GetInputDim();
+Strategies strategies = strategy->GetInputDim();
 auto features_strategy = strategies.at(0);
 auto rois_strategy = strategies.at(1);
 if (features_strategy[2] != 1 || features_strategy[3] != 1) {
@@ -158,7 +158,7 @@ std::vector<StrategyPtr> ScatterUpdateInfo::GenerateOpStrategies(int64_t stage_i
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 Dimensions indices_strategy(inputs_shape_[1].size(), 1);
 // updates_strategy = indices_strategy + input_strategy[1:]
@@ -115,7 +115,7 @@ std::vector<StrategyPtr> SelectInfo::GenerateOpStrategies(int64_t stage_id) {
 if ((sp == nullptr) || sp->GetInputDim().empty()) {
 MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
 }
-Strategys tmp_strategy;
+Strategies tmp_strategy;
 Dimensions first_input_strategy = sp->GetInputDim()[0];
 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
 tmp_strategy.push_back(first_input_strategy);
@@ -149,7 +149,7 @@ Status SliceInfo::InferMirrorOps() {
 }

 // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
-std::shared_ptr<Strategys> SliceInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> SliceInfo::GenerateBatchStrategies() {
 split_flag_list_ = {true};
 return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
@@ -40,7 +40,7 @@ class SliceInfo : public OperatorInfo {

 std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
 Status SetCostUnderStrategy(const StrategyPtr &) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;

 protected:
 Status GetAttrs() override;
@@ -144,7 +144,7 @@ std::vector<StrategyPtr> SplitInfo::GenerateOpStrategies(int64_t stage_id) {
 return sp_vector;
 }

-std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> SplitInfo::GenerateBatchStrategies() {
 if (GetAttrs() != SUCCESS) {
 MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
 }
@@ -157,8 +157,8 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
 input_strategy[0] = stage_device_size_;
 }
 }
-Strategys strategy_v = {input_strategy};
-return std::make_shared<Strategys>(strategy_v);
+Strategies strategy_v = {input_strategy};
+return std::make_shared<Strategies>(strategy_v);
 }

 Status SplitInfo::InferAsLossDivisor() {
@@ -36,7 +36,7 @@ class SplitInfo : public OperatorInfo {
 ~SplitInfo() override = default;

 std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 Status SetCostUnderStrategy(const StrategyPtr &) override;

 protected:
@@ -249,7 +249,7 @@ Status StridedSliceInfo::InferMirrorOps() {
 }

 // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
-std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
 split_flag_list_ = {true};
 return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
@@ -39,7 +39,7 @@ class StridedSliceInfo : public OperatorInfo {

 std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
 Status SetCostUnderStrategy(const StrategyPtr &) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 void ComputeBeginMask(int64_t begin_mask_);
 void ComputeEndMask(int64_t end_mask_);
 void ComputeEllipsisMask(int64_t ellipsis_mask_);
@@ -117,7 +117,7 @@ Status TensorDotInfo::CheckStrategy(const StrategyPtr &strategy) {
 return FAILED;
 }

-Strategys stra = strategy->GetInputDim();
+Strategies stra = strategy->GetInputDim();
 if (stra.size() != 2) {
 MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
 return FAILED;
@@ -148,7 +148,7 @@ Status TensorDotInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status TensorDotInfo::InferDevMatrixShape() {
-Strategys stra = strategy_->GetInputDim();
+Strategies stra = strategy_->GetInputDim();
 Dimensions input_a_strategy = stra.at(0);
 Dimensions input_b_strategy = stra.at(1);

@@ -306,7 +306,7 @@ Status TensorDotInfo::InferTensorMap() {
 return SUCCESS;
 }

-std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
+std::shared_ptr<Strategies> TensorDotInfo::GenerateBatchStrategies() {
 if (GetAttrs() != SUCCESS) {
 MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
 }
@@ -339,8 +339,8 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
 MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
 }

-Strategys strategy = {input_a_strategy, input_b_strategy};
-return std::make_shared<Strategys>(strategy);
+Strategies strategy = {input_a_strategy, input_b_strategy};
+return std::make_shared<Strategies>(strategy);
 }

 std::vector<StrategyPtr> TensorDotInfo::GenerateOpStrategies(int64_t) {
@@ -45,7 +45,7 @@ class TensorDotInfo : public OperatorInfo {
 ~TensorDotInfo() override = default;

 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
-std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+std::shared_ptr<Strategies> GenerateBatchStrategies() override;
 Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
 Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
 size_t input1_shape_size, StrategyPtr *sp);
@ -179,7 +179,7 @@ void TileInfo::UpdateMultiples() {
|
||||||
|
|
||||||
void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
|
void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
|
||||||
|
|
||||||
std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
|
std::shared_ptr<Strategies> TileInfo::GenerateBatchStrategies() {
|
||||||
if (InferAttrs() != SUCCESS) {
|
if (InferAttrs() != SUCCESS) {
|
||||||
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
|
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ class TileInfo : public OperatorInfo {
|
||||||
|
|
||||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
std::shared_ptr<Strategies> GenerateBatchStrategies() override;
|
||||||
void UpdateMultiples();
|
void UpdateMultiples();
|
||||||
void ReplaceNodeInputOrAttrs() override;
|
void ReplaceNodeInputOrAttrs() override;
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &st
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TmpIdentityInfo::InferDevMatrixShape() {
|
Status TmpIdentityInfo::InferDevMatrixShape() {
|
||||||
Strategys stra = strategy_->GetInputDim();
|
Strategies stra = strategy_->GetInputDim();
|
||||||
Dimensions input_strategy = stra.at(0);
|
Dimensions input_strategy = stra.at(0);
|
||||||
dev_matrix_shape_ = input_strategy;
|
dev_matrix_shape_ = input_strategy;
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace parallel {
|
||||||
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||||
|
|
||||||
Status TransposeInfo::InferDevMatrixShape() {
|
Status TransposeInfo::InferDevMatrixShape() {
|
||||||
Strategys stra = strategy_->GetInputDim();
|
Strategies stra = strategy_->GetInputDim();
|
||||||
input_strategy_ = stra.at(0);
|
input_strategy_ = stra.at(0);
|
||||||
for (auto &iter : input_strategy_) {
|
for (auto &iter : input_strategy_) {
|
||||||
dev_matrix_shape_.push_back(iter);
|
dev_matrix_shape_.push_back(iter);
|
||||||
|
|
|
@ -171,14 +171,14 @@ std::vector<StrategyPtr> UniformCandidateSamplerInfo::GenerateOpStrategies(int64
|
||||||
return sp_vector;
|
return sp_vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Strategys> UniformCandidateSamplerInfo::GenerateBatchStrategies() {
|
std::shared_ptr<Strategies> UniformCandidateSamplerInfo::GenerateBatchStrategies() {
|
||||||
if (GetAttrs() != SUCCESS) {
|
if (GetAttrs() != SUCCESS) {
|
||||||
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
||||||
}
|
}
|
||||||
CheckGlobalDeviceManager();
|
CheckGlobalDeviceManager();
|
||||||
Dimensions input_strategy(inputs_shape_[0].size(), 1);
|
Dimensions input_strategy(inputs_shape_[0].size(), 1);
|
||||||
Strategys strategy_v = {input_strategy};
|
Strategies strategy_v = {input_strategy};
|
||||||
return std::make_shared<Strategys>(strategy_v);
|
return std::make_shared<Strategies>(strategy_v);
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
|
|
|
@ -44,7 +44,7 @@ class UniformCandidateSamplerInfo : public OperatorInfo {
|
||||||
~UniformCandidateSamplerInfo() override = default;
|
~UniformCandidateSamplerInfo() override = default;
|
||||||
|
|
||||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
std::shared_ptr<Strategies> GenerateBatchStrategies() override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||||
Status InferAsLossDivisor() override;
|
Status InferAsLossDivisor() override;
|
||||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||||
|
|
|
@ -60,7 +60,7 @@ Status UniqueInfo::InferDevMatrixShape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
|
Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
Strategys stras = strategy->GetInputDim();
|
Strategies stras = strategy->GetInputDim();
|
||||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) {
|
if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
Strategys stra = strategy->GetInputDim();
|
Strategies stra = strategy->GetInputDim();
|
||||||
Dimensions sub_a_strategy = stra.at(0);
|
Dimensions sub_a_strategy = stra.at(0);
|
||||||
Dimensions sub_b_strategy = stra.at(1);
|
Dimensions sub_b_strategy = stra.at(1);
|
||||||
Shape input_a_shape = inputs_shape_.at(0);
|
Shape input_a_shape = inputs_shape_.at(0);
|
||||||
|
@ -91,7 +91,7 @@ Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
|
Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
|
||||||
Strategys stra = strategy_->GetInputDim();
|
Strategies stra = strategy_->GetInputDim();
|
||||||
dev_matrix_shape_ = stra.at(0);
|
dev_matrix_shape_ = stra.at(0);
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ std::vector<StrategyPtr> UnsortedSegmentOpInfo::GenerateOpStrategies(int64_t sta
|
||||||
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
|
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
|
||||||
}
|
}
|
||||||
for (auto &sp : sp_vector) {
|
for (auto &sp : sp_vector) {
|
||||||
Strategys tmp_strategy;
|
Strategies tmp_strategy;
|
||||||
Dimensions first_input_strategy = sp->GetInputDim()[0];
|
Dimensions first_input_strategy = sp->GetInputDim()[0];
|
||||||
Dimensions second_input_strategy;
|
Dimensions second_input_strategy;
|
||||||
for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
|
for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
|
||||||
|
@ -214,7 +214,7 @@ Status UnsortedSegmentOpInfo::SetCostUnderStrategy(const StrategyPtr &strategy)
|
||||||
return SetCostUnderStrategyBase(strategy);
|
return SetCostUnderStrategyBase(strategy);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
|
std::shared_ptr<Strategies> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
|
||||||
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
|
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
|
||||||
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
|
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
|
||||||
<< inputs_shape_.size();
|
<< inputs_shape_.size();
|
||||||
|
@ -233,8 +233,8 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
|
||||||
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
|
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
|
||||||
strategy_b.push_back(1);
|
strategy_b.push_back(1);
|
||||||
}
|
}
|
||||||
Strategys strategy_v = {strategy_a, strategy_b};
|
Strategies strategy_v = {strategy_a, strategy_b};
|
||||||
return std::make_shared<Strategys>(strategy_v);
|
return std::make_shared<Strategies>(strategy_v);
|
||||||
}
|
}
|
||||||
|
|
||||||
// When the index is splited, the graph should be replaced
|
// When the index is splited, the graph should be replaced
|
||||||
|
|
|
@ -47,7 +47,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
|
||||||
|
|
||||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
std::shared_ptr<Strategies> GenerateBatchStrategies() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::string reduce_method_;
|
std::string reduce_method_;
|
||||||
|
|
|
@ -35,7 +35,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
Strategys stra = strategy->GetInputDim();
|
Strategies stra = strategy->GetInputDim();
|
||||||
if (stra.size() < 1) {
|
if (stra.size() < 1) {
|
||||||
MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1.";
|
MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -84,7 +84,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status VirtualDatasetInfo::InferDevMatrixShape() {
|
Status VirtualDatasetInfo::InferDevMatrixShape() {
|
||||||
Strategys stra = strategy_->GetInputDim();
|
Strategies stra = strategy_->GetInputDim();
|
||||||
dev_matrix_shape_ = stra[max_size_strategy_dim_];
|
dev_matrix_shape_ = stra[max_size_strategy_dim_];
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||||
std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) {
|
std::vector<StrategyPtr> VirtualDatasetInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||||
StrategyPtr sp;
|
StrategyPtr sp;
|
||||||
Strategys strategy;
|
Strategies strategy;
|
||||||
if (!ParallelContext::GetInstance()->dataset_strategy().empty()) {
|
if (!ParallelContext::GetInstance()->dataset_strategy().empty()) {
|
||||||
strategy = ParallelContext::GetInstance()->dataset_strategy();
|
strategy = ParallelContext::GetInstance()->dataset_strategy();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -33,9 +33,9 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
Strategys stra = strategy->GetInputDim();
|
Strategies stra = strategy->GetInputDim();
|
||||||
if (stra.size() != 1) {
|
if (stra.size() != 1) {
|
||||||
MS_LOG(ERROR) << name_ << ": Strategys size must be 1.";
|
MS_LOG(ERROR) << name_ << ": Strategies size must be 1.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
Dimensions strategy_first = stra.at(0);
|
Dimensions strategy_first = stra.at(0);
|
||||||
|
@ -53,7 +53,7 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
|
|
||||||
std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) {
|
std::vector<StrategyPtr> VirtualOutputInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||||
StrategyPtr sp;
|
StrategyPtr sp;
|
||||||
Strategys strategy;
|
Strategies strategy;
|
||||||
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
||||||
size_t total_dev_num;
|
size_t total_dev_num;
|
||||||
if (full_batch) {
|
if (full_batch) {
|
||||||
|
|
|
@ -1250,7 +1250,7 @@ StrategyPtr ExtractStrategy(const ValuePtr &stra) {
|
||||||
MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
|
MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
|
||||||
if (var->size() > 0) {
|
if (var->size() > 0) {
|
||||||
std::vector<ValuePtr> elements = var->value();
|
std::vector<ValuePtr> elements = var->value();
|
||||||
Strategys strategy;
|
Strategies strategy;
|
||||||
for (uint64_t index = 0; index < elements.size(); ++index) {
|
for (uint64_t index = 0; index < elements.size(); ++index) {
|
||||||
Dimensions dim;
|
Dimensions dim;
|
||||||
if (elements[index]->isa<ValueSequence>()) {
|
if (elements[index]->isa<ValueSequence>()) {
|
||||||
|
@ -1725,7 +1725,7 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
|
||||||
MS_EXCEPTION_IF_NULL(operator_);
|
MS_EXCEPTION_IF_NULL(operator_);
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
StrategyPtr strategyPtr;
|
StrategyPtr strategyPtr;
|
||||||
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
||||||
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
|
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
|
||||||
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
||||||
std::vector<ValuePtr> elements;
|
std::vector<ValuePtr> elements;
|
||||||
|
|
|
@ -31,13 +31,13 @@ namespace parallel {
|
||||||
#define MIN_SLICE_NUM 1
|
#define MIN_SLICE_NUM 1
|
||||||
|
|
||||||
using Dimensions = Shape;
|
using Dimensions = Shape;
|
||||||
using Strategys = std::vector<Dimensions>;
|
using Strategies = std::vector<Dimensions>;
|
||||||
class Strategy;
|
class Strategy;
|
||||||
using StrategyPtr = std::shared_ptr<Strategy>;
|
using StrategyPtr = std::shared_ptr<Strategy>;
|
||||||
|
|
||||||
class Strategy {
|
class Strategy {
|
||||||
public:
|
public:
|
||||||
Strategy(int64_t stage, Strategys inputs)
|
Strategy(int64_t stage, Strategies inputs)
|
||||||
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
|
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
|
||||||
|
|
||||||
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
|
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
|
||||||
|
@ -52,14 +52,14 @@ class Strategy {
|
||||||
|
|
||||||
~Strategy() = default;
|
~Strategy() = default;
|
||||||
size_t GetInputNumber() const { return inputs_.size(); }
|
size_t GetInputNumber() const { return inputs_.size(); }
|
||||||
Strategys GetInputDim() const { return inputs_; }
|
Strategies GetInputDim() const { return inputs_; }
|
||||||
int64_t GetInputStage() const { return stage_; }
|
int64_t GetInputStage() const { return stage_; }
|
||||||
void ExpandInputDimFromOneToTwo() {
|
void ExpandInputDimFromOneToTwo() {
|
||||||
if (inputs_.size() == 1) {
|
if (inputs_.size() == 1) {
|
||||||
inputs_.push_back(inputs_[0]);
|
inputs_.push_back(inputs_[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void ResetInputs(const Strategys &input) { inputs_ = input; }
|
void ResetInputs(const Strategies &input) { inputs_ = input; }
|
||||||
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
|
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
|
||||||
size_t GetInternalSize() const { return internal_size_; }
|
size_t GetInternalSize() const { return internal_size_; }
|
||||||
|
|
||||||
|
@ -103,12 +103,12 @@ class Strategy {
|
||||||
const int64_t stage_;
|
const int64_t stage_;
|
||||||
|
|
||||||
// The size of Dimensions must be equal to inputs_ tensor dimension.
|
// The size of Dimensions must be equal to inputs_ tensor dimension.
|
||||||
Strategys inputs_;
|
Strategies inputs_;
|
||||||
size_t internal_size_ = 0;
|
size_t internal_size_ = 0;
|
||||||
std::vector<StrategyPtr> internal_stragies_;
|
std::vector<StrategyPtr> internal_stragies_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline StrategyPtr NewStrategy(const int64_t stage, const Strategys &inputs) {
|
inline StrategyPtr NewStrategy(const int64_t stage, const Strategies &inputs) {
|
||||||
return std::make_shared<Strategy>(stage, inputs);
|
return std::make_shared<Strategy>(stage, inputs);
|
||||||
}
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
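For readers following the rename, here is a minimal standalone sketch (not part of the patch) of how the renamed alias is used. The `Strategy` stand-in below is trimmed to the members visible in the strategy.h hunk above, `Dimensions` stands in for `parallel::Shape`, and the concrete split factors are illustrative only.

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using Dimensions = std::vector<int64_t>;     // stand-in for parallel::Shape
using Strategies = std::vector<Dimensions>;  // the alias renamed from "Strategys" in this commit

// Trimmed stand-in for parallel::Strategy, reduced to the members shown in the hunk above.
class Strategy {
 public:
  Strategy(int64_t stage, Strategies inputs) : stage_(stage), inputs_(std::move(inputs)) {}
  std::size_t GetInputNumber() const { return inputs_.size(); }
  Strategies GetInputDim() const { return inputs_; }
  int64_t GetInputStage() const { return stage_; }

 private:
  const int64_t stage_;
  Strategies inputs_;
};
using StrategyPtr = std::shared_ptr<Strategy>;

inline StrategyPtr NewStrategy(const int64_t stage, const Strategies &inputs) {
  return std::make_shared<Strategy>(stage, inputs);
}

int main() {
  // One Dimensions entry per operator input; each value is the split factor for one axis.
  Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
  StrategyPtr sp = NewStrategy(0, inputs);
  return sp->GetInputDim().size() == 2 ? 0 : 1;
}

Only the spelling of the alias changes; the constructor, NewStrategy, and GetInputDim call sites stay as they were, which is why the patch is mechanical.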
@@ -123,7 +123,7 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
 straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
 auto stage = (int64_t)parallel_strategys.stage();
 size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
-Strategys strategy_inputs;
+Strategies strategy_inputs;
 for (size_t j = 0; j < strategys_num; j++) {
 straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
 Dimensions dimension;

@@ -523,7 +523,7 @@ py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase)
 return mindspore::parallel::GetParallelParameterNameListFromGraph(graph);
 }

-void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
+void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy) {
 MS_LOG(DEBUG) << "SetCNodeStrategy!";
 stra_dict_[phase_][py::str(name)] = strategy;
 }

@@ -108,7 +108,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
 py::dict GetParallelGraphInfo(const std::string &phase);
 py::dict GetCNodeStrategy(const std::string &phase);
 py::list GetParallelParameterNameList(const std::string &phase);
-void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
+void SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy);
 size_t GetNumOpsInfo(const std::string &phase);
 void SetNumOpsInfo(size_t);
 py::dict GetAllreduceFusion(const std::string &phase);

@@ -40,7 +40,7 @@ constexpr auto kAnfConvWeight = 2;
 constexpr auto kAnfConvBias = 3;
 int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
 int split_count = 0;
-Strategys strategys = strategy.strategys;
+Strategies strategys = strategy.strategys;
 MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR);
 MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR);
 MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR);

@@ -281,7 +281,7 @@ std::shared_ptr<ops::Conv2DFusion> Conv2DInfo::GetNewConvPrimitive(const api::Sh
 prim->set_pad_list(conv_prim->get_pad_list());
 prim->set_stride(conv_prim->get_stride());
 prim->set_activation_type(conv_prim->get_activation_type());
-Strategys strategys = strategy_.strategys;
+Strategies strategys = strategy_.strategys;
 size_t dev_num = strategy_.dev_num;
 switch (split_mode_) {
 case SplitH: {

@@ -327,7 +327,7 @@ int Conv2DInfo::ConstructOutputCNodes(const api::SharedPtr<ops::Conv2DFusion> &c
 const std::vector<AnfNodePtr> &kernel_split_outputs,
 const std::vector<AnfNodePtr> &bias_split_outputs) {
 MS_ASSERT(conv_prim != nullptr);
-Strategys strategys = strategy_.strategys;
+Strategies strategys = strategy_.strategys;
 size_t dev_num = strategy_.dev_num;
 int cin_strategy_sum = std::accumulate(strategys[0][kAxisCIn].begin(), strategys[0][kAxisCIn].end(), 0);
 int cout_strategy_sum = std::accumulate(strategys[1][kAxisCOut].begin(), strategys[1][kAxisCOut].end(), 0);

@@ -131,7 +131,7 @@ int DepthwiseConv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
 // for depthwise conv2d, we only split channel && include split feature map, weight && bias
 // so just get the ratio from strategy
 int split_count = 0;
-Strategys strategys = strategy.strategys;
+Strategies strategys = strategy.strategys;
 MS_CHECK_GE(strategys.size(), kInputSizeTwo, RET_ERROR);
 MS_CHECK_GE(strategys[0].size(), kInputSizeFour, RET_ERROR);
 MS_CHECK_GE(strategys[1].size(), kInputSizeFour, RET_ERROR);

@@ -76,7 +76,7 @@ std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const st
 default:
 return split_strategys;
 }
-opt::Strategys strategys = {split_feature_map, split_weight};
+opt::Strategies strategys = {split_feature_map, split_weight};
 for (const auto &supported_parallel_op : kParallelOpNames) {
 split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
 }

@@ -37,7 +37,7 @@ const std::vector<int64_t> kSplitDefaultRatio = {0, 0};
 // user's device to split, only split to cpu && gpu, no support npu
 const std::vector<std::string> kSplitDevTypes = {"cpu", "gpu"};

-using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
+using Strategies = std::vector<std::vector<std::vector<int64_t>>>;

 constexpr auto kDeviceTypeNone = -1;
 // strategy format is NHWC-KHWC

@@ -71,7 +71,7 @@ enum SplitMode {
 };

 struct SplitStrategy {
-Strategys strategys{};
+Strategies strategys{};
 std::vector<std::string> dev_types{};
 size_t dev_num{0};
 SplitMode split_mode_{NoSplit};
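A similar standalone sketch (again, not from the patch) for the lite converter's three-level alias renamed in the hunks just above. The `SplitMode` enum is trimmed to the two values visible in the diff, and the ratio numbers are assumptions chosen only to show the shape of the data.

#include <cstdint>
#include <string>
#include <vector>

// The lite converter's alias is three levels deep: per input -> per axis -> per device ratio.
using Strategies = std::vector<std::vector<std::vector<int64_t>>>;

enum SplitMode { NoSplit, SplitH };  // trimmed to the values that appear in the hunks above

struct SplitStrategy {
  Strategies strategys{};
  std::vector<std::string> dev_types{};
  std::size_t dev_num{0};
  SplitMode split_mode_{NoSplit};
};

int main() {
  // Illustrative ratios only: split the feature map's H axis 1:1 between two devices (NHWC)
  // and leave the weight (KHWC) whole.
  std::vector<std::vector<int64_t>> split_feature_map = {{1}, {1, 1}, {1}, {1}};
  std::vector<std::vector<int64_t>> split_weight = {{1}, {1}, {1}, {1}};
  Strategies strategys = {split_feature_map, split_weight};

  SplitStrategy strategy;
  strategy.strategys = strategys;
  strategy.dev_types = {"cpu", "gpu"};
  strategy.dev_num = strategy.dev_types.size();
  strategy.split_mode_ = SplitH;
  return strategy.strategys.size() == 2 ? 0 : 1;
}

As with the parallel module, only the alias spelling changes here; the lowercase `strategys` member and variable names are left untouched by the commit.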
@@ -64,7 +64,7 @@ void TestActivationInfo::SetUp() {
 }

 TEST_F(TestActivationInfo, InferDevMatrixShape1) {
-Strategys inputs = {{2, 4, 8, 16}};
+Strategies inputs = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 activation->Init(strategy, nullptr);

@@ -75,7 +75,7 @@ TEST_F(TestActivationInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestActivationInfo, InferSliceShape1) {
-Strategys str = {{2, 4, 8, 16}};
+Strategies str = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 activation->Init(strategy, nullptr);

@@ -96,7 +96,7 @@ TEST_F(TestActivationInfo, InferSliceShape1) {
 }

 TEST_F(TestActivationInfo, GetTensorLayout1) {
-Strategys str = {{2, 4, 8, 16}};
+Strategies str = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 activation->Init(strategy, nullptr);

@@ -117,7 +117,7 @@ TEST_F(TestActivationInfo, GetTensorLayout1) {
 }

 TEST_F(TestActivationInfo, GetForwardOp1) {
-Strategys inputs = {{2, 4, 8, 16}};
+Strategies inputs = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 activation->Init(strategy, nullptr);

@@ -128,7 +128,7 @@ TEST_F(TestActivationInfo, GetForwardOp1) {
 }

 TEST_F(TestActivationInfo, GetMirrorOPs1) {
-Strategys inputs = {{1, 4, 8, 16}};
+Strategies inputs = {{1, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 activation->Init(strategy, nullptr);

@@ -148,7 +148,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs1) {
 }

 TEST_F(TestActivationInfo, GetMirrorOPs2) {
-Strategys inputs = {{2, 4, 8, 16}};
+Strategies inputs = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 activation->Init(strategy, nullptr);

@@ -161,7 +161,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs2) {

 TEST_F(TestActivationInfo, CheckStrategy1) {
 // Success: {{2,4,8,16}}
-Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = activation->Init(strategy, nullptr);

@@ -170,7 +170,7 @@ TEST_F(TestActivationInfo, CheckStrategy1) {

 TEST_F(TestActivationInfo, CheckStrategy2) {
 // Success: {{2,4,8,16}}
-Strategys inputs = {{2, 4, 8}};
+Strategies inputs = {{2, 4, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = activation->Init(strategy, nullptr);

@@ -101,7 +101,7 @@ TEST_F(TestActivation, test_softmax_strategies) {
 ASSERT_NE(sp, nullptr);
 Cost cost = *(swc->cost_list[0]);

-Strategys stra = sp->GetInputDim();
+Strategies stra = sp->GetInputDim();
 ASSERT_GT(stra.size(), 0);
 Dimensions input0_stra = stra[0];
 ASSERT_GT(input0_stra.size(), 2);

@@ -63,7 +63,7 @@ void TestGeLUInfo::SetUp() {
 }

 TEST_F(TestGeLUInfo, InferDevMatrixShape1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 gelu->Init(strategy, nullptr);

@@ -74,7 +74,7 @@ TEST_F(TestGeLUInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestGeLUInfo, InferSliceShape1) {
-Strategys str = {{2, 4, 1, 16}};
+Strategies str = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 gelu->Init(strategy, nullptr);

@@ -95,7 +95,7 @@ TEST_F(TestGeLUInfo, InferSliceShape1) {
 }

 TEST_F(TestGeLUInfo, GetTensorLayout1) {
-Strategys str = {{2, 4, 1, 16}};
+Strategies str = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 gelu->Init(strategy, nullptr);

@@ -116,7 +116,7 @@ TEST_F(TestGeLUInfo, GetTensorLayout1) {
 }

 TEST_F(TestGeLUInfo, GetForwardOp1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 gelu->Init(strategy, nullptr);

@@ -127,7 +127,7 @@ TEST_F(TestGeLUInfo, GetForwardOp1) {
 }

 TEST_F(TestGeLUInfo, GetMirrorOPs1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 gelu->Init(strategy, nullptr);

@@ -140,7 +140,7 @@ TEST_F(TestGeLUInfo, GetMirrorOPs1) {

 TEST_F(TestGeLUInfo, CheckStrategy1) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = gelu->Init(strategy, nullptr);

@@ -149,7 +149,7 @@ TEST_F(TestGeLUInfo, CheckStrategy1) {

 TEST_F(TestGeLUInfo, CheckStrategy2) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 4, 8}};
+Strategies inputs = {{2, 4, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = gelu->Init(strategy, nullptr);

@@ -158,7 +158,7 @@ TEST_F(TestGeLUInfo, CheckStrategy2) {

 TEST_F(TestGeLUInfo, CheckStrategy3) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = gelu->Init(strategy, nullptr);

@@ -64,7 +64,7 @@ void TestL2NormalizeInfo::SetUp() {
 }

 TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) {
-Strategys inputs = {{4, 1, 8}};
+Strategies inputs = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 norm->Init(strategy, nullptr);

@@ -75,7 +75,7 @@ TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
-Strategys str = {{4, 1, 8}};
+Strategies str = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, str);

 norm->Init(strategy, nullptr);

@@ -96,7 +96,7 @@ TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
 }

 TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
-Strategys str = {{4, 1, 8}};
+Strategies str = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, str);

 norm->Init(strategy, nullptr);

@@ -117,7 +117,7 @@ TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
 }

 TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
-Strategys inputs = {{4, 1, 8}};
+Strategies inputs = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 norm->Init(strategy, nullptr);

@@ -128,7 +128,7 @@ TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
 }

 TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
-Strategys inputs = {{4, 1, 8}};
+Strategies inputs = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 norm->Init(strategy, nullptr);

@@ -140,7 +140,7 @@ TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
 }

 TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
-Strategys inputs = {{4, 1, 8}, {4, 1, 8}};
+Strategies inputs = {{4, 1, 8}, {4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = norm->Init(strategy, nullptr);

@@ -148,7 +148,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
 }

 TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
-Strategys inputs = {{4, 2, 3}};
+Strategies inputs = {{4, 2, 3}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = norm->Init(strategy, nullptr);

@@ -156,7 +156,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
 }

 TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
-Strategys inputs = {{4, 2, 3, 4}};
+Strategies inputs = {{4, 2, 3, 4}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = norm->Init(strategy, nullptr);

@@ -164,7 +164,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
 }

 TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
-Strategys inputs = {{4, 1, 8}};
+Strategies inputs = {{4, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = norm->Init(strategy, nullptr);

@@ -172,7 +172,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
 }

 TEST_F(TestL2NormalizeInfo, mirror_ops) {
-Strategys inputs = {{2, 1, 8}};
+Strategies inputs = {{2, 1, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 norm->Init(strategy, nullptr);

@@ -64,7 +64,7 @@ void TestLogSoftmaxInfo::SetUp() {
 }

 TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 log_softmax->Init(strategy, nullptr);

@@ -75,7 +75,7 @@ TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
-Strategys str = {{2, 4, 1, 16}};
+Strategies str = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 log_softmax->Init(strategy, nullptr);

@@ -96,7 +96,7 @@ TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
 }

 TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
-Strategys str = {{2, 4, 1, 16}};
+Strategies str = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 log_softmax->Init(strategy, nullptr);

@@ -117,7 +117,7 @@ TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
 }

 TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 log_softmax->Init(strategy, nullptr);

@@ -128,7 +128,7 @@ TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
 }

 TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 log_softmax->Init(strategy, nullptr);

@@ -141,7 +141,7 @@ TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {

 TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = log_softmax->Init(strategy, nullptr);

@@ -150,7 +150,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {

 TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 4, 8}};
+Strategies inputs = {{2, 4, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = log_softmax->Init(strategy, nullptr);

@@ -159,7 +159,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {

 TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
 // Success: {{2,4,1,16}}
-Strategys inputs = {{2, 4, 8, 16}};
+Strategies inputs = {{2, 4, 8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = log_softmax->Init(strategy, nullptr);

@@ -167,7 +167,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
 }

 TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
-Strategys inputs = {{2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 log_softmax->Init(strategy, nullptr);

@@ -97,7 +97,7 @@ void TestMatmulInfo::SetUp() {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -111,7 +111,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
-Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
+Strategies inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -125,7 +125,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
-Strategys inputs = {{2, 4, 8, 16}, {1, 16}};
+Strategies inputs = {{2, 4, 8, 16}, {1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul2->Init(strategy, nullptr);

@@ -139,7 +139,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
-Strategys inputs = {{2, 4, 8, 8}, {2, 8}};
+Strategies inputs = {{2, 4, 8, 8}, {2, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul2->Init(strategy, nullptr);

@@ -153,7 +153,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
-Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
+Strategies inputs = {{8, 16}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul3->Init(strategy, nullptr);

@@ -167,7 +167,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
 /// Description: infer dev matrix
 /// Expectation: the dev matrix is right
 TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
-Strategys inputs = {{8, 8}, {2, 4, 2, 8}};
+Strategies inputs = {{8, 8}, {2, 4, 2, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul3->Init(strategy, nullptr);

@@ -181,7 +181,7 @@ TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
 /// Description: infer tensor map
 /// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap1) {
-Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul1->Init(strategy, nullptr);

@@ -209,7 +209,7 @@ TEST_F(TestMatmulInfo, InferTensorMap1) {
 /// Description: infer tensor map
 /// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap2) {
-Strategys str = {{2, 4, 8, 16}, {1, 16}};
+Strategies str = {{2, 4, 8, 16}, {1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul2->Init(strategy, nullptr);

@@ -237,7 +237,7 @@ TEST_F(TestMatmulInfo, InferTensorMap2) {
 /// Description: infer tensor map
 /// Expectation: the tensor map is right
 TEST_F(TestMatmulInfo, InferTensorMap3) {
-Strategys str = {{8, 16}, {2, 4, 1, 16}};
+Strategies str = {{8, 16}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul3->Init(strategy, nullptr);

@@ -265,7 +265,7 @@ TEST_F(TestMatmulInfo, InferTensorMap3) {
 /// Description: infer slice shape
 /// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape1) {
-Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul1->Init(strategy, nullptr);

@@ -293,7 +293,7 @@ TEST_F(TestMatmulInfo, InferSliceShape1) {
 /// Description: infer slice shape
 /// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape2) {
-Strategys str = {{2, 4, 8, 16}, {1, 16}};
+Strategies str = {{2, 4, 8, 16}, {1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul2->Init(strategy, nullptr);

@@ -321,7 +321,7 @@ TEST_F(TestMatmulInfo, InferSliceShape2) {
 /// Description: infer slice shape
 /// Expectation: the slice shape is right
 TEST_F(TestMatmulInfo, InferSliceShape3) {
-Strategys str = {{8, 16}, {2, 4, 1, 16}};
+Strategies str = {{8, 16}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul3->Init(strategy, nullptr);

@@ -349,7 +349,7 @@ TEST_F(TestMatmulInfo, InferSliceShape3) {
 /// Description: get tensor layout
 /// Expectation: the tensor layout is right
 TEST_F(TestMatmulInfo, GetTensorLayout3) {
-Strategys str = {{8, 16}, {2, 4, 1, 16}};
+Strategies str = {{8, 16}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, str);

 matmul3->Init(strategy, nullptr);

@@ -377,7 +377,7 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) {
 /// Description: infer forward op
 /// Expectation: the forward op is right
 TEST_F(TestMatmulInfo, GetForwardOp1) {
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -406,7 +406,7 @@ TEST_F(TestMatmulInfo, GetForwardOp1) {
 /// Description: infer forward op
 /// Expectation: the forward op is right
 TEST_F(TestMatmulInfo, GetForwardOp2) {
-Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
+Strategies inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -419,7 +419,7 @@ TEST_F(TestMatmulInfo, GetForwardOp2) {
 /// Description: infer virtual_div op
 /// Expectation: the virtual_div op is right
 TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -441,7 +441,7 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
 /// Description: infer mirror op
 /// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs1) {
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -463,7 +463,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) {
 /// Description: infer mirror op
 /// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs2) {
-Strategys inputs = {{2, 4, 1, 16}, {8, 16}};
+Strategies inputs = {{2, 4, 1, 16}, {8, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul2->Init(strategy, nullptr);

@@ -485,7 +485,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) {
 /// Description: infer mirror op
 /// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs3) {
-Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
+Strategies inputs = {{8, 16}, {2, 4, 1, 16}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul3->Init(strategy, nullptr);

@@ -506,7 +506,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) {
 /// Description: infer mirror op
 /// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, GetMirrorOPs4) {
-Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
+Strategies inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 matmul1->Init(strategy, nullptr);

@@ -519,7 +519,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
 /// Description: init twice
 /// Expectation: the mirror op is right
 TEST_F(TestMatmulInfo, InitTwice) {
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 // init twice

@@ -544,7 +544,7 @@ TEST_F(TestMatmulInfo, InitTwice) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy1) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -556,7 +556,7 @@ TEST_F(TestMatmulInfo, CheckStrategy1) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy2) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -568,7 +568,7 @@ TEST_F(TestMatmulInfo, CheckStrategy2) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy3) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -580,7 +580,7 @@ TEST_F(TestMatmulInfo, CheckStrategy3) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy4) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};
+Strategies inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -592,7 +592,7 @@ TEST_F(TestMatmulInfo, CheckStrategy4) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy5) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -604,7 +604,7 @@ TEST_F(TestMatmulInfo, CheckStrategy5) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy6) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
-Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};
+Strategies inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};
 StrategyPtr strategy = NewStrategy(0, inputs);

 Status ret = matmul1->Init(strategy, nullptr);

@@ -616,7 +616,7 @@ TEST_F(TestMatmulInfo, CheckStrategy6) {
 /// Expectation: return FAILED
 TEST_F(TestMatmulInfo, CheckStrategy7) {
 // Success: {{2,4,8,16}, {2,4,16,1}}
|
// Success: {{2,4,8,16}, {2,4,16,1}}
|
||||||
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
|
Strategies inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
||||||
Status ret = matmul1->Init(strategy, nullptr);
|
Status ret = matmul1->Init(strategy, nullptr);
|
||||||
|
@ -628,7 +628,7 @@ TEST_F(TestMatmulInfo, CheckStrategy7) {
|
||||||
/// Expectation: return FAILED
|
/// Expectation: return FAILED
|
||||||
TEST_F(TestMatmulInfo, InitFailed) {
|
TEST_F(TestMatmulInfo, InitFailed) {
|
||||||
// matmul4 attr is wrong
|
// matmul4 attr is wrong
|
||||||
Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
|
Strategies inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
||||||
Status ret = matmul4->Init(strategy, nullptr);
|
Status ret = matmul4->Init(strategy, nullptr);
|
||||||
|
|
|
@@ -64,7 +64,7 @@ void TestOneHotInfo::SetUp() {
 }

 TEST_F(TestOneHotInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{8, 1}, {}, {}};
+  Strategies inputs = {{8, 1}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -76,7 +76,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
-  Strategys inputs = {{4, 1}, {}, {}};
+  Strategies inputs = {{4, 1}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -88,7 +88,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
 }

 TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
-  Strategys inputs = {{4, 2}, {}, {}};
+  Strategies inputs = {{4, 2}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -100,7 +100,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
 }

 TEST_F(TestOneHotInfo, InferTensorMap2) {
-  Strategys str = {{8, 1}, {}, {}};
+  Strategies str = {{8, 1}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo, InferTensorMap2) {
 }

 TEST_F(TestOneHotInfo, InferSliceShape1) {
-  Strategys str = {{8, 1}, {}, {}};
+  Strategies str = {{8, 1}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo, InferSliceShape1) {
 }

 TEST_F(TestOneHotInfo, InferSliceShape2) {
-  Strategys str = {{4, 2}, {}, {}};
+  Strategies str = {{4, 2}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) {
 }

 TEST_F(TestOneHotInfo, InferSliceShape3) {
-  Strategys str = {{2, 2}, {}, {}};
+  Strategies str = {{2, 2}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -188,7 +188,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) {
 }

 TEST_F(TestOneHotInfo, GetMirrorOPs1) {
-  Strategys inputs = {{8, 1}, {}, {}};
+  Strategies inputs = {{8, 1}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info->Init(strategy, nullptr);
@@ -199,7 +199,7 @@ TEST_F(TestOneHotInfo, GetMirrorOPs1) {
 }

 TEST_F(TestOneHotInfo, CheckStrategy1) {
-  Strategys inputs = {{16}, {}, {}};
+  Strategies inputs = {{16}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = onehot_info->Init(strategy, nullptr);
@@ -64,7 +64,7 @@ void TestOneHotInfo2::SetUp() {
 }

 TEST_F(TestOneHotInfo2, InferDevMatrixShape1) {
-  Strategys inputs = {{1, 8}, {}, {}};
+  Strategies inputs = {{1, 8}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -76,7 +76,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape1) {
 }

 TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
-  Strategys inputs = {{1, 4}, {}, {}};
+  Strategies inputs = {{1, 4}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -88,7 +88,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
 }

 TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
-  Strategys inputs = {{2, 4}, {}, {}};
+  Strategies inputs = {{2, 4}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -100,7 +100,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
 }

 TEST_F(TestOneHotInfo2, InferTensorMap2) {
-  Strategys str = {{1, 8}, {}, {}};
+  Strategies str = {{1, 8}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo2, InferTensorMap2) {
 }

 TEST_F(TestOneHotInfo2, InferSliceShape1) {
-  Strategys str = {{1, 8}, {}, {}};
+  Strategies str = {{1, 8}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape1) {
 }

 TEST_F(TestOneHotInfo2, InferSliceShape2) {
-  Strategys str = {{2, 4}, {}, {}};
+  Strategies str = {{2, 4}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) {
 }

 TEST_F(TestOneHotInfo2, InferSliceShape3) {
-  Strategys str = {{2, 2}, {}, {}};
+  Strategies str = {{2, 2}, {}, {}};
   StrategyPtr strategy = NewStrategy(0, str);

   Status status = onehot_info2->Init(strategy, nullptr);
@@ -63,7 +63,7 @@ void TestPowInfo::SetUp() {
 }

 TEST_F(TestPowInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   pow->Init(strategy, nullptr);
@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestPowInfo, InferSliceShape1) {
-  Strategys str = {{2, 4, 8}, {2, 4, 8}};
+  Strategies str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);

   pow->Init(strategy, nullptr);
@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
 }

 TEST_F(TestPowInfo, GetTensorLayout1) {
-  Strategys str = {{2, 4, 8}, {2, 4, 8}};
+  Strategies str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);

   pow->Init(strategy, nullptr);
@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
 }

 TEST_F(TestPowInfo, GetForwardOp1) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   pow->Init(strategy, nullptr);
@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
 }

 TEST_F(TestPowInfo, GetMirrorOPs1) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   pow->Init(strategy, nullptr);
@@ -139,7 +139,7 @@ TEST_F(TestPowInfo, GetMirrorOPs1) {
 }

 TEST_F(TestPowInfo, CheckStrategy1) {
-  Strategys inputs = {{2, 2, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 2, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = pow->Init(strategy, nullptr);
@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
 }

 TEST_F(TestPowInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
+  Strategies inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = pow->Init(strategy, nullptr);
@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
 }

 TEST_F(TestPowInfo, CheckStrategy3) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = pow->Init(strategy, nullptr);
@@ -64,7 +64,7 @@ void TestPReLUInfo::SetUp() {
 }

 TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 1, 8, 16}, {1}};
+  Strategies inputs = {{2, 1, 8, 16}, {1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   prelu->Init(strategy, nullptr);
@@ -75,7 +75,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestPReLUInfo, InferSliceShape1) {
-  Strategys str = {{2, 1, 8, 16}, {1}};
+  Strategies str = {{2, 1, 8, 16}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);

   prelu->Init(strategy, nullptr);
@@ -98,7 +98,7 @@ TEST_F(TestPReLUInfo, InferSliceShape1) {
 }

 TEST_F(TestPReLUInfo, GetTensorLayout1) {
-  Strategys str = {{2, 1, 8, 16}, {1}};
+  Strategies str = {{2, 1, 8, 16}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);

   prelu->Init(strategy, nullptr);
@@ -122,7 +122,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) {
 }

 TEST_F(TestPReLUInfo, GetMirrorOPs1) {
-  Strategys str = {{2, 1, 2, 2}, {1}};
+  Strategies str = {{2, 1, 2, 2}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);
   prelu->Init(strategy, nullptr);
   MirrorOps mirror_ops = prelu->mirror_ops();
@@ -139,14 +139,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs1) {

 TEST_F(TestPReLUInfo, CheckStrategy1) {
   // Success: {{2,1,8,16},{1}}
-  Strategys inputs = {{2, 1, 8, 16}};
+  Strategies inputs = {{2, 1, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu->Init(strategy, nullptr);
   ASSERT_EQ(ret, FAILED);
 }

 TEST_F(TestPReLUInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8, 16}, {4}};
+  Strategies inputs = {{2, 4, 8, 16}, {4}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu->Init(strategy, nullptr);
   ASSERT_EQ(ret, SUCCESS);
@@ -169,7 +169,7 @@ TEST_F(TestPReLUInfo, AutoStrategy1) {
 }

 TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
-  Strategys inputs = {{128, 1}, {1}};
+  Strategies inputs = {{128, 1}, {1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   prelu_2d->Init(strategy, nullptr);
@@ -180,7 +180,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
 }

 TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
-  Strategys str = {{128, 1}, {1}};
+  Strategies str = {{128, 1}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);

   prelu_2d->Init(strategy, nullptr);
@@ -203,7 +203,7 @@ TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
 }

 TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
-  Strategys str = {{128, 1}, {1}};
+  Strategies str = {{128, 1}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);

   prelu_2d->Init(strategy, nullptr);
@@ -227,7 +227,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
 }

 TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {
-  Strategys str = {{128, 1}, {1}};
+  Strategies str = {{128, 1}, {1}};
   StrategyPtr strategy = NewStrategy(0, str);
   prelu_2d->Init(strategy, nullptr);
   MirrorOps mirror_ops = prelu_2d->mirror_ops();
@@ -244,14 +244,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {

 TEST_F(TestPReLUInfo, CheckStrategy_2d1) {
   // Success: {{2,1,8,16},{1}}
-  Strategys inputs = {{128, 1}};
+  Strategies inputs = {{128, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu_2d->Init(strategy, nullptr);
   ASSERT_EQ(ret, FAILED);
 }

 TEST_F(TestPReLUInfo, CheckStrategy_2d2) {
-  Strategys inputs = {{128, 4}, {4}};
+  Strategies inputs = {{128, 4}, {4}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu_2d->Init(strategy, nullptr);
   ASSERT_EQ(ret, SUCCESS);
@@ -68,7 +68,7 @@ void TestReduceSumInfo::SetUp() {
 }

 TEST_F(TestReduceSumInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{4, 8, 1}};
+  Strategies inputs = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reduce_sum->Init(strategy, nullptr);
@@ -79,7 +79,7 @@ TEST_F(TestReduceSumInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestReduceSumInfo, InferSliceShape1) {
-  Strategys str = {{4, 8, 1}};
+  Strategies str = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reduce_sum->Init(strategy, nullptr);
@@ -100,7 +100,7 @@ TEST_F(TestReduceSumInfo, InferSliceShape1) {
 }

 TEST_F(TestReduceSumInfo, GetTensorLayout1) {
-  Strategys str = {{4, 8, 1}};
+  Strategies str = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reduce_sum->Init(strategy, nullptr);
@@ -121,7 +121,7 @@ TEST_F(TestReduceSumInfo, GetTensorLayout1) {
 }

 TEST_F(TestReduceSumInfo, GetForwardOp1) {
-  Strategys inputs = {{4, 8, 1}};
+  Strategies inputs = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reduce_sum->Init(strategy, nullptr);
@@ -132,7 +132,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp1) {
 }

 TEST_F(TestReduceSumInfo, GetForwardOp2) {
-  Strategys inputs = {{4, 4, 2}};
+  Strategies inputs = {{4, 4, 2}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reduce_sum->Init(strategy, nullptr);
@@ -156,7 +156,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp2) {
 }

 TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
-  Strategys inputs = {{4, 8, 1}};
+  Strategies inputs = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reduce_sum->Init(strategy, nullptr);
@@ -168,7 +168,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
 }

 TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
-  Strategys inputs = {{4, 4, 1}};
+  Strategies inputs = {{4, 4, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reduce_sum->Init(strategy, nullptr);
@@ -187,7 +187,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
 }

 TEST_F(TestReduceSumInfo, CheckStrategy1) {
-  Strategys inputs = {{2, 2, 8, 16}};
+  Strategies inputs = {{2, 2, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reduce_sum->Init(strategy, nullptr);
@@ -195,7 +195,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy1) {
 }

 TEST_F(TestReduceSumInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reduce_sum->Init(strategy, nullptr);
@@ -203,7 +203,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy2) {
 }

 TEST_F(TestReduceSumInfo, CheckStrategy3) {
-  Strategys inputs = {{4, 4, 2}};
+  Strategies inputs = {{4, 4, 2}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reduce_sum->Init(strategy, nullptr);
@@ -211,7 +211,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy3) {
 }

 TEST_F(TestReduceSumInfo, CheckStrategy4) {
-  Strategys inputs = {{4, 8, 1}};
+  Strategies inputs = {{4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reduce_sum->Init(strategy, nullptr);
@@ -68,7 +68,7 @@ void TestReshapeInfo::SetUp() {
 }

 TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{4, 1, 1, 1}};
+  Strategies inputs = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reshape->Init(strategy, nullptr);
@@ -79,7 +79,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
-  Strategys inputs = {{32, 1, 1, 1}};
+  Strategies inputs = {{32, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reshape->Init(strategy, nullptr);
@@ -90,7 +90,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
 }

 TEST_F(TestReshapeInfo, InferSliceShape1) {
-  Strategys str = {{4, 1, 1, 1}};
+  Strategies str = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reshape->Init(strategy, nullptr);
@@ -111,7 +111,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) {
 }

 TEST_F(TestReshapeInfo, InferSliceShape2) {
-  Strategys str = {{32, 1, 1, 1}};
+  Strategies str = {{32, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reshape->Init(strategy, nullptr);
@@ -132,7 +132,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) {
 }

 TEST_F(TestReshapeInfo, GetTensorLayout1) {
-  Strategys str = {{4, 1, 1, 1}};
+  Strategies str = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reshape->Init(strategy, nullptr);
@@ -153,7 +153,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
 }

 TEST_F(TestReshapeInfo, GetTensorLayout2) {
-  Strategys str = {{32, 1, 1, 1}};
+  Strategies str = {{32, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   reshape->Init(strategy, nullptr);
@@ -174,7 +174,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) {
 }

 TEST_F(TestReshapeInfo, GetForwardOp1) {
-  Strategys inputs = {{4, 1, 1, 1}};
+  Strategies inputs = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reshape->Init(strategy, nullptr);
@@ -185,7 +185,7 @@ TEST_F(TestReshapeInfo, GetForwardOp1) {
 }

 TEST_F(TestReshapeInfo, GetMirrorOPs1) {
-  Strategys inputs = {{4, 1, 1, 1}};
+  Strategies inputs = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   reshape->Init(strategy, nullptr);
@@ -197,7 +197,7 @@ TEST_F(TestReshapeInfo, GetMirrorOPs1) {
 }

 TEST_F(TestReshapeInfo, CheckStrategy1) {
-  Strategys inputs = {{1, 4, 8}};
+  Strategies inputs = {{1, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reshape->Init(strategy, nullptr);
@@ -205,7 +205,7 @@ TEST_F(TestReshapeInfo, CheckStrategy1) {
 }

 TEST_F(TestReshapeInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reshape->Init(strategy, nullptr);
@@ -213,7 +213,7 @@ TEST_F(TestReshapeInfo, CheckStrategy2) {
 }

 TEST_F(TestReshapeInfo, CheckStrategy3) {
-  Strategys inputs = {{4, 1, 1, 1}};
+  Strategies inputs = {{4, 1, 1, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = reshape->Init(strategy, nullptr);
@@ -64,7 +64,7 @@ void TestSoftmaxLoss::SetUp() {
 }

 TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
+  Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   loss->Init(strategy, nullptr);
@@ -75,7 +75,7 @@ TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) {
 }

 TEST_F(TestSoftmaxLoss, InferSliceShape1) {
-  Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
+  Strategies str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   loss->Init(strategy, nullptr);
@@ -104,7 +104,7 @@ TEST_F(TestSoftmaxLoss, InferSliceShape1) {
 }

 TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
-  Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
+  Strategies str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, str);

   loss->Init(strategy, nullptr);
@@ -133,7 +133,7 @@ TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
 }

 TEST_F(TestSoftmaxLoss, GetForwardOp1) {
-  Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
+  Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   loss->Init(strategy, nullptr);
@@ -144,7 +144,7 @@ TEST_F(TestSoftmaxLoss, GetForwardOp1) {
 }

 TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
-  Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
+  Strategies inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   loss->Init(strategy, nullptr);
@@ -156,7 +156,7 @@ TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
 }

 TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {
-  Strategys inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}};
+  Strategies inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   loss->Init(strategy, nullptr);
@@ -176,7 +176,7 @@ TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {

 TEST_F(TestSoftmaxLoss, CheckStrategy1) {
   // Success: {{2,4,8,16}}
-  Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+  Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = loss->Init(strategy, nullptr);
@@ -185,7 +185,7 @@ TEST_F(TestSoftmaxLoss, CheckStrategy1) {

 TEST_F(TestSoftmaxLoss, CheckStrategy2) {
   // Success: {{2,4,8,16}}
-  Strategys inputs = {{2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = loss->Init(strategy, nullptr);
@@ -68,7 +68,7 @@ void TestSoftmaxInfo::SetUp() {
 }

 TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   softmax->Init(strategy, nullptr);
@@ -79,7 +79,7 @@ TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestSoftmaxInfo, InferSliceShape1) {
-  Strategys str = {{2, 4, 1, 16}};
+  Strategies str = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   softmax->Init(strategy, nullptr);
@@ -100,7 +100,7 @@ TEST_F(TestSoftmaxInfo, InferSliceShape1) {
 }

 TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
-  Strategys str = {{2, 4, 1, 16}};
+  Strategies str = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   softmax->Init(strategy, nullptr);
@@ -121,7 +121,7 @@ TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
 }

 TEST_F(TestSoftmaxInfo, GetForwardOp1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   softmax->Init(strategy, nullptr);
@@ -132,7 +132,7 @@ TEST_F(TestSoftmaxInfo, GetForwardOp1) {
 }

 TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   softmax->Init(strategy, nullptr);
@@ -145,7 +145,7 @@ TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {

 TEST_F(TestSoftmaxInfo, CheckStrategy1) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+  Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = softmax->Init(strategy, nullptr);
@@ -154,7 +154,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy1) {

 TEST_F(TestSoftmaxInfo, CheckStrategy2) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = softmax->Init(strategy, nullptr);
@@ -163,7 +163,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy2) {

 TEST_F(TestSoftmaxInfo, CheckStrategy3) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 4, 8, 16}};
+  Strategies inputs = {{2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = softmax->Init(strategy, nullptr);
@@ -172,7 +172,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy3) {

 TEST_F(TestSoftmaxInfo, InitFailed1) {
   // softmax2's axis is wrong
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = softmax2->Init(strategy, nullptr);
@@ -181,7 +181,7 @@ TEST_F(TestSoftmaxInfo, InitFailed1) {

 TEST_F(TestSoftmaxInfo, InitFailed2) {
   // dev num is wrong
-  Strategys inputs = {{2, 4, 1, 100}};
+  Strategies inputs = {{2, 4, 1, 100}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = softmax2->Init(strategy, nullptr);
@@ -63,7 +63,7 @@ void TestTanhInfo::SetUp() {
 }

 TEST_F(TestTanhInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tanh->Init(strategy, nullptr);
@@ -74,7 +74,7 @@ TEST_F(TestTanhInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestTanhInfo, InferSliceShape1) {
-  Strategys str = {{2, 4, 1, 16}};
+  Strategies str = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   tanh->Init(strategy, nullptr);
@@ -95,7 +95,7 @@ TEST_F(TestTanhInfo, InferSliceShape1) {
 }

 TEST_F(TestTanhInfo, GetTensorLayout1) {
-  Strategys str = {{2, 4, 1, 16}};
+  Strategies str = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   tanh->Init(strategy, nullptr);
@@ -116,7 +116,7 @@ TEST_F(TestTanhInfo, GetTensorLayout1) {
 }

 TEST_F(TestTanhInfo, GetForwardOp1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tanh->Init(strategy, nullptr);
@@ -127,7 +127,7 @@ TEST_F(TestTanhInfo, GetForwardOp1) {
 }

 TEST_F(TestTanhInfo, GetMirrorOPs1) {
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tanh->Init(strategy, nullptr);
@@ -140,7 +140,7 @@ TEST_F(TestTanhInfo, GetMirrorOPs1) {

 TEST_F(TestTanhInfo, CheckStrategy1) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+  Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tanh->Init(strategy, nullptr);
@@ -149,7 +149,7 @@ TEST_F(TestTanhInfo, CheckStrategy1) {

 TEST_F(TestTanhInfo, CheckStrategy2) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tanh->Init(strategy, nullptr);
@@ -158,7 +158,7 @@ TEST_F(TestTanhInfo, CheckStrategy2) {

 TEST_F(TestTanhInfo, CheckStrategy3) {
   // Success: {{2,4,1,16}}
-  Strategys inputs = {{2, 4, 1, 16}};
+  Strategies inputs = {{2, 4, 1, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tanh->Init(strategy, nullptr);
@@ -66,7 +66,7 @@ void TestTensorAddInfo::SetUp() {
 }

 TEST_F(TestTensorAddInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
+  Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tensor_add->Init(strategy, nullptr);
@@ -77,7 +77,7 @@ TEST_F(TestTensorAddInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestTensorAddInfo, InferSliceShape1) {
-  Strategys str = {{2, 4, 4}, {2, 4, 4}};
+  Strategies str = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, str);

   tensor_add->Init(strategy, nullptr);
@@ -101,7 +101,7 @@ TEST_F(TestTensorAddInfo, InferSliceShape1) {
 }

 TEST_F(TestTensorAddInfo, GetTensorLayout1) {
-  Strategys str = {{2, 4, 4}, {2, 4, 4}};
+  Strategies str = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, str);

   tensor_add->Init(strategy, nullptr);
@@ -125,7 +125,7 @@ TEST_F(TestTensorAddInfo, GetTensorLayout1) {
 }

 TEST_F(TestTensorAddInfo, GetForwardOp1) {
-  Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
+  Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tensor_add->Init(strategy, nullptr);
@@ -136,7 +136,7 @@ TEST_F(TestTensorAddInfo, GetForwardOp1) {
 }

 TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
-  Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
+  Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tensor_add->Init(strategy, nullptr);
@@ -148,7 +148,7 @@ TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
 }

 TEST_F(TestTensorAddInfo, CheckStrategy1) {
-  Strategys inputs = {{2, 4, 4}, {2, 6, 4}};
+  Strategies inputs = {{2, 4, 4}, {2, 6, 4}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tensor_add->Init(strategy, nullptr);
@@ -156,7 +156,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy1) {
 }

 TEST_F(TestTensorAddInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tensor_add->Init(strategy, nullptr);
@@ -164,7 +164,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy2) {
 }

 TEST_F(TestTensorAddInfo, CheckStrategy3) {
-  Strategys inputs = {{2, 4, 6}};
+  Strategies inputs = {{2, 4, 6}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tensor_add->Init(strategy, nullptr);
@@ -172,7 +172,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy3) {
 }

 TEST_F(TestTensorAddInfo, CheckStrategy4) {
-  Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
+  Strategies inputs = {{2, 4, 4}, {2, 4, 4}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = tensor_add->Init(strategy, nullptr);
@@ -224,7 +224,7 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
 }

 TEST_F(TestTensorAddInfo, mirror_ops) {
-  Strategys inputs = {{1, 8}, {4, 1}};
+  Strategies inputs = {{1, 8}, {4, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   tensor_add1->Init(strategy, nullptr);
@@ -65,7 +65,7 @@ void TestTmpIdentityInfo::SetUp() {
 }

 TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{2, 4, 8, 16}};
+  Strategies inputs = {{2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   identity_ptr->Init(strategy, nullptr);
@@ -76,7 +76,7 @@ TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
-  Strategys str = {{2, 4, 8, 16}};
+  Strategies str = {{2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   identity_ptr->Init(strategy, nullptr);
@@ -97,7 +97,7 @@ TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
 }

 TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {
-  Strategys str = {{2, 4, 8, 16}};
+  Strategies str = {{2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, str);

   identity_ptr->Init(strategy, nullptr);
@@ -119,7 +119,7 @@ TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {

 TEST_F(TestTmpIdentityInfo, CheckStrategy1) {
   // Success: {{2,4,8,16}}
-  Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
+  Strategies inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = identity_ptr->Init(strategy, nullptr);
@@ -128,7 +128,7 @@ TEST_F(TestTmpIdentityInfo, CheckStrategy1) {

 TEST_F(TestTmpIdentityInfo, CheckStrategy2) {
   // Success: {{2,4,8,16}}
-  Strategys inputs = {{2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = identity_ptr->Init(strategy, nullptr);
@@ -68,7 +68,7 @@ void TestTransposeInfo::SetUp() {
 }

 TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{4, 8}};
+  Strategies inputs = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   transpose->Init(strategy, nullptr);
@@ -79,7 +79,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
 }

 TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
-  Strategys inputs = {{4, 1}};
+  Strategies inputs = {{4, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   transpose->Init(strategy, nullptr);
@@ -90,7 +90,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
 }

 TEST_F(TestTransposeInfo, InferSliceShape1) {
-  Strategys str = {{4, 8}};
+  Strategies str = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);

   transpose->Init(strategy, nullptr);
@@ -111,7 +111,7 @@ TEST_F(TestTransposeInfo, InferSliceShape1) {
 }

 TEST_F(TestTransposeInfo, GetTensorLayout1) {
-  Strategys str = {{4, 8}};
+  Strategies str = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);

   transpose->Init(strategy, nullptr);
@@ -132,7 +132,7 @@ TEST_F(TestTransposeInfo, GetTensorLayout1) {
 }

 TEST_F(TestTransposeInfo, GetForwardOp1) {
-  Strategys inputs = {{4, 8}};
+  Strategies inputs = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   transpose->Init(strategy, nullptr);
@@ -143,7 +143,7 @@ TEST_F(TestTransposeInfo, GetForwardOp1) {
 }

 TEST_F(TestTransposeInfo, GetMirrorOPs1) {
-  Strategys inputs = {{4, 8}};
+  Strategies inputs = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   transpose->Init(strategy, nullptr);
@@ -155,7 +155,7 @@ TEST_F(TestTransposeInfo, GetMirrorOPs1) {
 }

 TEST_F(TestTransposeInfo, CheckStrategy1) {
-  Strategys inputs = {{1, 4, 8}};
+  Strategies inputs = {{1, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = transpose->Init(strategy, nullptr);
@@ -163,7 +163,7 @@ TEST_F(TestTransposeInfo, CheckStrategy1) {
 }

 TEST_F(TestTransposeInfo, CheckStrategy2) {
-  Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
+  Strategies inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = transpose->Init(strategy, nullptr);
@@ -171,7 +171,7 @@ TEST_F(TestTransposeInfo, CheckStrategy2) {
 }

 TEST_F(TestTransposeInfo, CheckStrategy3) {
-  Strategys inputs = {{4, 8}};
+  Strategies inputs = {{4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);

   Status ret = transpose->Init(strategy, nullptr);
@@ -237,9 +237,9 @@ TEST_F(TestStepParallel, ExtractStrategy) {
   std::vector<ValuePtr> elements = {val1, val2};
   ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
   attrs["in_strategy"] = strategy_tuple;
-  Strategys strategy_expect = {v1, v2};
+  Strategies strategy_expect = {v1, v2};
   StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]);
-  Strategys strategy_test = strategy->GetInputDim();
+  Strategies strategy_test = strategy->GetInputDim();
 
   ASSERT_EQ(strategy_expect, strategy_test);
 }
@@ -380,7 +380,7 @@ TEST_F(TestStepParallel, OperatorInstance) {
   prim->set_attr("transpose_b", transpose_b);
   auto attrs = prim->attrs();
   // create strategy
-  Strategys strategy = {{2, 2}, {2, 4}};
+  Strategies strategy = {{2, 2}, {2, 4}};
   StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
   // create shape
   Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
@@ -557,7 +557,7 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
   prim->set_attr("transpose_b", transpose_b);
   auto attrs = prim->attrs();
   // create strategy
-  Strategys strategy = {{2, 2}, {2, 4}};
+  Strategies strategy = {{2, 2}, {2, 4}};
   StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
   // create shape
   Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};

@@ -35,7 +35,7 @@ TEST_F(TestStrategy, GetInputNumber) {
   int32_t stage = 1;
   Dimensions dimension1 = {2, 4};
   Dimensions dimension2 = {2, 2};
-  Strategys inputs = {dimension1, dimension2};
+  Strategies inputs = {dimension1, dimension2};
 
   Strategy strategy(stage, inputs);
   int32_t number_test = strategy.GetInputNumber();
@@ -46,7 +46,7 @@ TEST_F(TestStrategy, GetInputStage) {
   int32_t stage = 1;
   Dimensions dimension1 = {2, 4};
   Dimensions dimension2 = {2, 2};
-  Strategys inputs = {dimension1, dimension2};
+  Strategies inputs = {dimension1, dimension2};
 
   Strategy strategy(stage, inputs);
   int32_t stage_test = strategy.GetInputStage();
@@ -57,10 +57,10 @@ TEST_F(TestStrategy, GetInputDim) {
   int32_t stage = 1;
   Dimensions dimension1 = {2, 4};
   Dimensions dimension2 = {2, 2};
-  Strategys inputs = {dimension1, dimension2};
+  Strategies inputs = {dimension1, dimension2};
 
   Strategy strategy(stage, inputs);
-  Strategys inputs_test = strategy.GetInputDim();
+  Strategies inputs_test = strategy.GetInputDim();
   ASSERT_EQ(inputs, inputs_test);
 }
 
@@ -68,10 +68,10 @@ TEST_F(TestStrategy, IsEqual) {
   int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
   Dimensions dimension1 = {8, 1};
   Dimensions dimension2 = {1, 8};
-  Strategys inputs1 = {dimension1};
-  Strategys inputs2 = {dimension1};
-  Strategys inputs3 = {dimension2};
-  Strategys inputs4 = {dimension1, dimension2};
+  Strategies inputs1 = {dimension1};
+  Strategies inputs2 = {dimension1};
+  Strategies inputs3 = {dimension2};
+  Strategies inputs4 = {dimension1, dimension2};
 
   StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1);
   StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2);

@@ -62,7 +62,7 @@ void TestConstructOperator::SetUp() {
 
   MatMulInfoPtr matmul = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_1, outputs_shape_1, attr_1);
 
-  Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
+  Strategies str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
   StrategyPtr strategy = NewStrategy(0, str);
   matmul->Init(strategy, nullptr);
   Shape tensor_shape = {512, 1024};

@@ -62,7 +62,7 @@ void TestVirtualDatasetInfo::SetUp() {
 }
 
 TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
-  Strategys inputs = {{16, 1}, {16, 1}, {16, 1}};
+  Strategies inputs = {{16, 1}, {16, 1}, {16, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   virtual_dataset->Init(strategy, nullptr);
   Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape();
@@ -72,7 +72,7 @@ TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
 }
 
 TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
-  Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
+  Strategies inputs = {{8, 1}, {8, 1}, {8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   virtual_dataset->Init(strategy, nullptr);
@@ -83,7 +83,7 @@ TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
 }
 
 TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) {
-  Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
+  Strategies inputs = {{8, 1}, {8, 1}, {8, 1}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   virtual_dataset->Init(strategy, nullptr);
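For context on the renamed alias, here is a minimal C++ sketch (not part of this commit) of the construction pattern the updated tests exercise: a Strategies value is a list of per-input Dimensions, wrapped either directly in a Strategy or through NewStrategy. The include path and the mindspore::parallel namespace are assumptions; only the type and function names come from the diff itself.

// Minimal sketch of how the renamed Strategies alias is used by the tests above.
// Assumption: the header path and namespace below match the surrounding test code;
// neither is shown in this diff.
#include <memory>

#include "frontend/parallel/strategy.h"  // assumed location of Dimensions/Strategies/Strategy

using mindspore::parallel::Dimensions;
using mindspore::parallel::NewStrategy;
using mindspore::parallel::Strategies;
using mindspore::parallel::Strategy;
using mindspore::parallel::StrategyPtr;

int main() {
  // One Dimensions entry per operator input, mirroring TestStrategy above.
  Dimensions dimension1 = {2, 4};
  Dimensions dimension2 = {2, 2};
  Strategies inputs = {dimension1, dimension2};  // formerly spelled "Strategys"

  // Construct directly or through the NewStrategy helper, exactly as the tests do.
  Strategy strategy(1, inputs);
  StrategyPtr strategy_ptr = NewStrategy(0, inputs);

  bool ok = strategy.GetInputNumber() == 2 && strategy_ptr->GetInputDim() == inputs;
  return ok ? 0 : 1;
}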