dont create hccl group in auto parallel strategy search

This commit is contained in:
yao_yf 2020-06-23 11:29:04 +08:00
parent 4f3ea8015e
commit ad06ab4049
3 changed files with 17 additions and 4 deletions

View File

@ -1406,7 +1406,9 @@ Status CostGraph::InitSelectedStrategy() {
int32_t next_index = reshape_info->next_operator_index();
reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
}
return reshape_info->Init(nullptr);
if (reshape_info->Init(nullptr) != SUCCESS) {
return FAILED;
}
}
}
return SUCCESS;

View File

@ -133,9 +133,13 @@ Status ReshapeInfo::GetParameterInput() {
Status ReshapeInfo::ComputeReplaceOp() {
RankList dev_list = global_device_list();
TensorRedistribution tensor_redistribution(true, true);
TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
if (is_generating_costs_) {
MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed.";
} else {
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
}
return FAILED;
}
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
@ -143,7 +147,11 @@ Status ReshapeInfo::ComputeReplaceOp() {
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
if (is_generating_costs_) {
MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
} else {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
}
return FAILED;
}
replace_op_ = redistribution_oplist_ptr->first;
@ -444,6 +452,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
int32_t out_index, int32_t in_index, bool is_prev_param) {
is_generating_costs_ = true;
for (auto pre_stra_cost : pre_stra_costs) {
std::vector<TensorInfo> pre_out_tensor_infos;
if (is_prev_param) {
@ -488,6 +497,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
SetCostForReshape(reshape_stra);
}
}
is_generating_costs_ = false;
if (strategy_cost_.empty()) {
return FAILED;
}

View File

@ -97,6 +97,7 @@ class ReshapeInfo : public OperatorInfo {
TensorLayout output_layout_;
bool input_layout_set_flag_;
bool output_layout_set_flag_;
bool is_generating_costs_;
std::string pre_operator_name_;
std::string next_operator_name_;
};