!2376 [CT][ME][parallel] fixed One-hot runs failing in RP-search mode.

Merge pull request !2376 from Chong/zc
This commit is contained in:
mindspore-ci-bot 2020-06-22 20:54:41 +08:00 committed by Gitee
commit f975963a58
6 changed files with 76 additions and 137 deletions

View File

@@ -28,7 +28,6 @@
 namespace mindspore {
 namespace parallel {
-#define DOUBLE_MAX (std::numeric_limits<double>::max)()
 // Compute redistributed cost
 double CostRedis(const Graph::NodeType &node,
@@ -621,75 +620,50 @@ StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec
       break;
     default:
-      MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
+      MS_LOG(EXCEPTION) << "Failure: Common failed.";
   }
   return str;
 }
-// Get weight for BN
-double CostBatchNorm::GetMinCostIn(const OperatorRec &op) {
-  int tensor = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
-               static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
-  std::vector<double> cost_in;
-  cost_in.push_back(StrDimB(tensor) * 1.2);
-  cost_in.push_back(DOUBLE_MAX);
-  cost_in.push_back(StrDimH(tensor) * 1.2);
-  cost_in.push_back(StrDimW(tensor) * 1.2);
-  return *min_element(cost_in.begin(), cost_in.end());
-}
-// Get optimal strategy for BN
-StrategyRec CostBatchNorm::GetOptimalStr(const Graph::NodeType &node,
-                                         const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                         const Graph &graph) {
+// Get optimal strategy for BatchParallel OPs
+StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;
-  int tensor_filter_n = static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
-  int tensor_filter_c = static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
-  int tensor_filter_h = static_cast<int>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
-  int tensor_filter_w = static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
-  int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;
-  int output_tensor_h = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
-  int output_tensor_w = static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
-  int output_tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
+  int tensor_n = static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
+  int tensor_c = static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
+  int tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
+  int tensor_w = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
   std::vector<double> cost_op;
-  std::vector<std::vector<float>> mode;
-  if (output_tensor_n < 2 || output_tensor_n % 2 != 0) {
+  if (tensor_n < 2 || tensor_n % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
+    cost_op.push_back(cost_in_);
   }
-  cost_op.push_back(DOUBLE_MAX);
-  if (output_tensor_h < 2 || output_tensor_h % 2 != 0) {
+  if (tensor_c < 2 || tensor_c % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimH(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, graph));
+    cost_op.push_back(cost_in_);
   }
-  if (output_tensor_w < 2 || output_tensor_w % 2 != 0) {
+  if (tensor_h < 2 || tensor_h % 2 != 0) {
     cost_op.push_back(DOUBLE_MAX);
   } else {
-    cost_op.push_back(StrDimW(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{1, 1, 1, 0.5}, {1, 1, 1, 1}, {1, 1, 1, 0.5}}, graph));
+    cost_op.push_back(cost_in_);
+  }
+  if (tensor_w < 2 || tensor_w % 2 != 0) {
+    cost_op.push_back(DOUBLE_MAX);
+  } else {
+    cost_op.push_back(cost_in_);
   }
   return ChoseStr(cost_op, node.apply.str);
 }
-// Chose strategy for BatchNorm
-StrategyRec CostBatchNorm::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
+// Chose strategy for BatchParallel op
+StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
   uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
     return str;
@@ -700,36 +674,32 @@ StrategyRec CostBatchNorm::ChoseStr(const std::vector<double> &cost_op, Strategy
       str.inputTensor[0].str_n /= 2.0;
       str.outputTensor.str_n /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_b_;
+      str.cost = str.cost + cost_in_;
       break;
     case 1:
       str.inputTensor[0].str_c /= 2.0;
-      str.inputTensor[1].str_c /= 2.0;
-      str.inputTensor[2].str_c /= 2.0;
-      str.inputTensor[3].str_c /= 2.0;
-      str.inputTensor[4].str_c /= 2.0;
       str.outputTensor.str_c /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_c_;
+      str.cost = str.cost + cost_in_;
       break;
     case 2:
       str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
       str.cut_counter += 1;
-      str.cost = str.cost + cost_in_h_;
+      str.cost = str.cost + cost_in_;
       break;
    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_w_;
+      str.cost = str.cost + cost_in_;
      break;
    default:
-      MS_LOG(EXCEPTION) << "Failure: CostBatchNorm failed.";
+      MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed.";
   }
   return str;
 }
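
Note on the new cost path: CostBatchParallel no longer weighs per-dimension redistribution costs the way CostBatchNorm did. It only checks which of the N/C/H/W dimensions of the first input can still be halved and charges the flat member cost cost_in_ for each viable cut, so ChoseStr ends up halving the first dimension that is still evenly divisible. A minimal standalone sketch of that selection logic, using simplified stand-in types rather than the real Graph/StrategyRec classes (names here are illustrative, not MindSpore APIs):

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

// Simplified stand-in for StrategyRec: one fractional cut per dimension.
struct Strategy {
  double str_n = 1.0, str_c = 1.0, str_h = 1.0, str_w = 1.0;
  int cut_counter = 0;
  double cost = 0.0;
};

constexpr double kDoubleMax = std::numeric_limits<double>::max();

// Mirrors the combined effect of GetOptimalStr + ChoseStr: a flat cost for
// every dimension that is still evenly divisible, DOUBLE_MAX otherwise; the
// first cheapest dimension gets its strategy halved.
Strategy CutCheapestDim(const std::vector<int> &shape /* {n, c, h, w} */, Strategy str,
                        double cost_in = 1.0) {
  std::vector<double> cost_op;
  for (int dim : shape) {
    cost_op.push_back((dim < 2 || dim % 2 != 0) ? kDoubleMax : cost_in);
  }
  size_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (kDoubleMax - 0.1)) {
    return str;  // no dimension can be cut any further
  }
  double *dims[] = {&str.str_n, &str.str_c, &str.str_h, &str.str_w};
  *dims[min_position] /= 2.0;
  str.cut_counter += 1;
  str.cost += cost_in;
  return str;
}

int main() {
  // A batch-parallel op with shape N=32, C=3, H=224, W=224: N is cut first.
  Strategy s = CutCheapestDim({32, 3, 224, 224}, Strategy{});
  std::cout << "str_n after one cut: " << s.str_n << "\n";  // prints 0.5
  return 0;
}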

View File

@@ -28,6 +28,8 @@
 namespace mindspore {
 namespace parallel {
+#define DOUBLE_MAX (std::numeric_limits<double>::max)()
 double CostRedis(const Graph::NodeType &node,
                  const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                  const std::vector<std::vector<float>> &mode, const Graph &graph);
@@ -195,7 +197,6 @@ class CostTensorAdd : public CostCommon {
 };
 // all the following operation are element-wise and have the same cost
-class CostOneHot : public CostCommon {};
 class CostReLU : public CostCommon {};
 class CostLog : public CostCommon {};
 class CostExp : public CostCommon {};
@@ -206,50 +207,21 @@ class CostDiv : public CostCommon {};
 class CostSqueeze : public CostCommon {};
 class CostCast : public CostCommon {};
-// class BatchNorm is used to compute the cost of BatchNorm operator.
-class CostBatchNorm {
+// class BatchParallel is used to compute the cost of BatchParallel operator.
+class CostBatchParallel {
  public:
-  StrategyRec GetOptimalStr(const Graph::NodeType &node,
-                            const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                            const Graph &graph);
-  double GetMinCostIn(const OperatorRec &op);
- private:
-  double StrDimB(int32_t Tensor) {
-    cost_in_b_ = (static_cast<double>(Tensor) * 4.0) / 2.0;
-    return cost_in_b_;
-  }
-  double StrDimC() {
-    cost_in_c_ = 0.0;
-    return cost_in_c_;
-  }
-  double StrDimH(int32_t Tensor) {
-    cost_in_h_ = (static_cast<double>(Tensor) * 4.0) / 2.0;
-    return cost_in_h_;
-  }
-  double StrDimW(int32_t Tensor) {
-    cost_in_w_ = (static_cast<double>(Tensor) * 4.0) / 2.0;
-    return cost_in_w_;
-  }
-  StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
-  double cost_in_b_ = 0;
-  double cost_in_c_ = 0;
-  double cost_in_h_ = 0;
-  double cost_in_w_ = 0;
-};  // class BatchNorm is used to compute the cost of BatchNorm operator.
+  virtual StrategyRec GetOptimalStr(const Graph::NodeType &node);
+  virtual double GetMaxCostIn() const { return DOUBLE_MAX; }
+ protected:
+  virtual StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
+  double cost_in_ = 0;
+};  // class BatchParallel is used to compute the cost of BatchParallel operator.
+class CostBatchNorm : public CostBatchParallel {};
+class CostOneHot : public CostBatchParallel {};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_COST_H_

View File

@@ -135,17 +135,6 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
   return strategies;
 }
-std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
-                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                   const size_t iter_graph, const size_t iter_ops) {
-  std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
-  for (size_t i = 1; i < strategies.size(); i++) {
-    strategies[i][0] = strategies[0][1];
-  }
-  strategies[1][0] = 1;
-  return strategies;
-}
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
   std::vector<std::vector<int32_t>> strategies;
   strategies.push_back(*s);
@@ -155,10 +144,15 @@ std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vect
   return strategies;
 }
-std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s) {
-  std::vector<std::vector<int32_t>> strategies;
+std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
+                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const size_t iter_graph, const size_t iter_ops) {
+  std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
+  strategies[0][0] = strategies[0][1];
+  strategies[0][1] = 1;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
+  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
   std::vector<int32_t> s_empty = {};
-  strategies.push_back(*s);
   strategies.push_back(s_empty);
   strategies.push_back(s_empty);
   return strategies;
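
In effect, the reworked PrepareOneHot takes the strategy returned by MakeRecSearchStrategy, moves the cut found on the second axis onto the first (batch) axis, and shifts the node's output strategy the same way (str_h takes str_w's value, str_w becomes 1.0); the other two inputs carry no shard dimensions, so they get empty strategies. A small illustrative snippet with hypothetical values (an 8-way cut found on the second axis) shows the strategy rewrite in isolation:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical result of MakeRecSearchStrategy for a OneHot node whose
  // indices input was cut 8-way on the second (class) axis.
  std::vector<std::vector<int32_t>> strategies = {{1, 8}};

  // Same rewrite as in PrepareOneHot: move the cut onto the batch axis.
  strategies[0][0] = strategies[0][1];
  strategies[0][1] = 1;

  // The remaining two inputs get empty strategies.
  std::vector<int32_t> s_empty = {};
  strategies.push_back(s_empty);
  strategies.push_back(s_empty);

  // strategies is now {{8, 1}, {}, {}}.
  std::cout << strategies[0][0] << " " << strategies[0][1] << "\n";
  return 0;
}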
@@ -287,8 +281,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == PRELU) {
     return PreparePReLU(graph, ops, iter_graph, iter_ops);
-  } else if (type == BATCH_NORM) {
-    return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
+  } else if (type == ONEHOT) {
+    return PrepareOneHot(graph, ops, iter_graph, iter_ops);
   } else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS ||
              type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
@@ -513,9 +507,6 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   if (ops[iter_ops]->type() == BIAS_ADD) {
     return PrepareBiasAdd(s_ptr);
   }
-  if (ops[iter_ops]->type() == ONEHOT) {
-    return PrepareOneHot(s_ptr);
-  }
   if (ops[iter_ops]->type() == GATHERV2) {
     return PrepareGatherV2(s_ptr);
   }
@@ -559,7 +550,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> grap
   std::vector<std::vector<int32_t>> stra;
   std::vector<int32_t> s;
   size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
-  if (incoming_op_index != SIZE_MAX && ops[iter_ops]->type() != ONEHOT) {
+  if (incoming_op_index != SIZE_MAX) {
     auto iter_graph = index_list->at(incoming_op_index);
     if (iter_graph != SIZE_MAX) {
       s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
@@ -640,7 +631,7 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
   }
   if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
-    for (size_t k = 0; k < ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs].size(); ++k) {
+    for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) {
       s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
     }
   }

View File

@@ -37,11 +37,10 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
 std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
-                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                   const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
-std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
+                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,

View File

@@ -216,10 +216,10 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
                                       const std::shared_ptr<std::vector<size_t>> index_list) {
   MS_EXCEPTION_IF_NULL(graph);
   const std::set<OperatorType> type_list = {
-    OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
-    OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
-    OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
-    OperatorType::kRecCast, OperatorType::kRecReshape, OperatorType::kRecGatherV2};
+    OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
+    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
+    OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
+    OperatorType::kRecReshape, OperatorType::kRecGatherV2};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
     if (type_list.find(type) != type_list.end()) {

View File

@@ -68,17 +68,21 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared<CostBiasAdd>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
-             op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
-             op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
-             op.op_type == OperatorType::kRecDiv || op.op_type == OperatorType::kRecSqueeze ||
-             op.op_type == OperatorType::kRecCast) {
+  } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp ||
+             op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub ||
+             op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv ||
+             op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) {
     // For element-wise op
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot) {
+    // For BatchParallel op
+    auto cost_ptr = std::make_shared<CostBatchParallel>();
+    return cost_ptr->GetMaxCostIn();
   } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU ||
-             op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecSoftmax ||
              op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For unprocessed type
     return 0.0;
@@ -158,17 +162,20 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostBiasAdd>();
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecOneHot || node.apply.op_type == OperatorType::kRecLog ||
-             node.apply.op_type == OperatorType::kRecExp || node.apply.op_type == OperatorType::kRecAdd ||
-             node.apply.op_type == OperatorType::kRecSub || node.apply.op_type == OperatorType::kRecMul ||
-             node.apply.op_type == OperatorType::kRecDiv || node.apply.op_type == OperatorType::kRecSqueeze ||
-             node.apply.op_type == OperatorType::kRecCast) {
+  } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp ||
+             node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub ||
+             node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv ||
+             node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) {
     // For element-wise op
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot) {
+    // For BatchParallel type
+    auto cost_ptr = std::make_shared<CostBatchParallel>();
+    return cost_ptr->GetOptimalStr(node);
   } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU ||
-             node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecSoftmax ||
+             node.apply.op_type == OperatorType::kRecSoftmax ||
              node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For unprocessed type
     StrategyRec default_strategy;