forked from mindspore-Ecosystem/mindspore
fix l2normalize/prelu/softmax cost
This commit is contained in:
parent
948ea950af
commit
8e20d4d84e
|
@ -703,5 +703,48 @@ StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, Stra
|
|||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
// Chose strategy for CostSoftmaxCrossEntropyWithLogits
|
||||
StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
|
||||
uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
|
||||
if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
|
||||
return str;
|
||||
}
|
||||
|
||||
switch (min_position) {
|
||||
case 0:
|
||||
str.inputTensor[0].str_n /= 2.0;
|
||||
str.inputTensor[1].str_n /= 2.0;
|
||||
str.cut_counter += 1;
|
||||
str.cost = str.cost + cost_in_;
|
||||
break;
|
||||
|
||||
case 1:
|
||||
str.inputTensor[0].str_c /= 2.0;
|
||||
str.inputTensor[1].str_c /= 2.0;
|
||||
str.cut_counter += 1;
|
||||
str.cost = str.cost + cost_in_;
|
||||
break;
|
||||
|
||||
case 2:
|
||||
str.inputTensor[0].str_h /= 2.0;
|
||||
str.inputTensor[1].str_h /= 2.0;
|
||||
str.outputTensor.str_w /= 2.0;
|
||||
str.cut_counter += 1;
|
||||
str.cost = str.cost + cost_in_;
|
||||
break;
|
||||
|
||||
case 3:
|
||||
str.inputTensor[0].str_w /= 2.0;
|
||||
str.inputTensor[1].str_w /= 2.0;
|
||||
str.cut_counter += 1;
|
||||
str.cost = str.cost + cost_in_;
|
||||
break;
|
||||
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
|
||||
}
|
||||
return str;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -222,6 +222,12 @@ class CostBatchParallel {
|
|||
|
||||
class CostBatchNorm : public CostBatchParallel {};
|
||||
class CostOneHot : public CostBatchParallel {};
|
||||
class CostPRelu : public CostBatchParallel {};
|
||||
class CostSoftmax : public CostBatchParallel {};
|
||||
|
||||
class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel {
|
||||
StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_
|
||||
|
|
|
@ -127,14 +127,6 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
|
|||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
strategies[1][0] = 1;
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
|
||||
std::vector<std::vector<int32_t>> strategies;
|
||||
strategies.push_back(*s);
|
||||
|
@ -164,6 +156,32 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vec
|
|||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, std::vector<int32_t> s) {
|
||||
int32_t axis = 0;
|
||||
auto iter = ops[iter_ops]->attrs().find(AXIS);
|
||||
if (iter != ops[iter_ops]->attrs().end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
if (iter->second->isa<Int32Imm>()) {
|
||||
axis = iter->second->cast<Int32ImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int.";
|
||||
}
|
||||
}
|
||||
|
||||
int32_t axis_index = axis;
|
||||
if (axis < 0) {
|
||||
size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
|
||||
axis_index = static_cast<int32_t>(input_dim) + axis;
|
||||
}
|
||||
|
||||
s[IntToSize(axis_index)] = 1;
|
||||
|
||||
std::vector<std::vector<int32_t>> strategies;
|
||||
strategies.push_back(s);
|
||||
return strategies;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
|
@ -279,13 +297,8 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
|
|||
|
||||
if (type == MATMUL) {
|
||||
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == PRELU) {
|
||||
return PreparePReLU(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == ONEHOT) {
|
||||
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS ||
|
||||
type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
} else {
|
||||
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
||||
}
|
||||
|
@ -510,6 +523,9 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
|
|||
if (ops[iter_ops]->type() == GATHERV2) {
|
||||
return PrepareGatherV2(s_ptr);
|
||||
}
|
||||
if (ops[iter_ops]->type() == L2_NORMALIZE) {
|
||||
return PrepareL2Normalize(ops, iter_ops, basic_stra);
|
||||
}
|
||||
|
||||
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
|
||||
iter_op_inputs++) {
|
||||
|
|
|
@ -34,14 +34,13 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
|
|||
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
|
||||
std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
|
||||
std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_ops, std::vector<int32_t> s);
|
||||
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
||||
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
|
|
|
@ -38,6 +38,7 @@ enum OperatorType {
|
|||
kRecBiasAdd,
|
||||
kRecSoftmax,
|
||||
kRecSparseSoftmaxCrossEntropyWithLogits,
|
||||
kRecSoftmaxCrossEntropyWithLogits,
|
||||
kRecOneHot,
|
||||
kRecLog,
|
||||
kRecExp,
|
||||
|
|
|
@ -250,12 +250,22 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
|
|||
|
||||
new_graph->nodes.push_back(graph->nodes[i]);
|
||||
auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
|
||||
for (size_t j = 0; j < node_in->size(); j++) {
|
||||
node_in->at(j) = index_list->at(node_in->at(j));
|
||||
for (size_t j = node_in->size(); j > 0; j--) {
|
||||
bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX);
|
||||
if (IsEliminated) {
|
||||
node_in->erase(node_in->begin() + j - 1);
|
||||
} else {
|
||||
node_in->at(j - 1) = index_list->at(node_in->at(j - 1));
|
||||
}
|
||||
}
|
||||
auto *node_out = &new_graph->nodes[index_list->at(i)].node_out;
|
||||
for (size_t j = 0; j < node_out->size(); j++) {
|
||||
node_out->at(j) = index_list->at(node_out->at(j));
|
||||
for (size_t j = node_out->size(); j > 0; j--) {
|
||||
bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX);
|
||||
if (IsEliminated) {
|
||||
node_out->erase(node_out->begin() + j - 1);
|
||||
} else {
|
||||
node_out->at(j - 1) = index_list->at(node_out->at(j - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
return new_graph;
|
||||
|
|
|
@ -67,7 +67,7 @@ const std::map<std::string, OperatorType> DictOpType{
|
|||
{REAL_DIV, OperatorType::kRecElmWiseOp},
|
||||
{SOFTMAX, OperatorType::kRecSoftmax},
|
||||
{LOG_SOFTMAX, OperatorType::kRecSoftmax},
|
||||
{SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmax},
|
||||
{SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits},
|
||||
{SQRT, OperatorType::kRecElmWiseOp},
|
||||
{NEG, OperatorType::kRecElmWiseOp},
|
||||
{POW, OperatorType::kRecElmWiseOp},
|
||||
|
|
|
@ -76,15 +76,16 @@ double GetWeights(const Graph::NodeType &node) {
|
|||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetMinCostIn();
|
||||
} else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot) {
|
||||
} else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
|
||||
op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
|
||||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
|
||||
op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
|
||||
// For BatchParallel op
|
||||
auto cost_ptr = std::make_shared<CostBatchParallel>();
|
||||
|
||||
return cost_ptr->GetMaxCostIn();
|
||||
} else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU ||
|
||||
op.op_type == OperatorType::kRecSoftmax ||
|
||||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For unprocessed type
|
||||
} else if (op.op_type == OperatorType::kRecUnkownType) {
|
||||
// For Unkown type
|
||||
return 0.0;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
|
||||
|
@ -170,14 +171,18 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|||
auto cost_ptr = std::make_shared<CostCommon>();
|
||||
|
||||
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
||||
} else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot) {
|
||||
} else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
|
||||
node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
|
||||
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For BatchParallel type
|
||||
auto cost_ptr = std::make_shared<CostBatchParallel>();
|
||||
return cost_ptr->GetOptimalStr(node);
|
||||
} else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU ||
|
||||
node.apply.op_type == OperatorType::kRecSoftmax ||
|
||||
node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
||||
// For unprocessed type
|
||||
} else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
|
||||
// For SoftmaxCrossEntropyWithLogits type
|
||||
auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>();
|
||||
return cost_ptr->GetOptimalStr(node);
|
||||
} else if (node.apply.op_type == OperatorType::kRecUnkownType) {
|
||||
// For Unkown type
|
||||
StrategyRec default_strategy;
|
||||
return default_strategy;
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue