fix opt shard not sharding param in pipeline

yao_yf 2023-03-06 11:50:29 +08:00
parent a04b337d78
commit dd773cb404
3 changed files with 20 additions and 4 deletions


@@ -1265,13 +1265,18 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
     op = CreateAllGatherOp(group);
   }
   CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
-  std::string opt_shard_mirror_group;
+  bool is_with_mirror = false;
   auto param_ptr = node->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(param_ptr);
   if (param_ptr->user_data<TensorLayout>()) {
-    opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+    auto opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+    is_with_mirror = !opt_shard_mirror_group.empty();
+    if (!is_with_mirror) {
+      auto mirror_group = mirror_group_list(param_ptr->user_data<TensorLayout>());
+      is_with_mirror = !mirror_group.empty();
+    }
   }
-  bool is_with_mirror = !opt_shard_mirror_group.empty();
   if (!is_shared_param && cast_node) {
     allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
     MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
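
Before this change, is_with_mirror was derived solely from opt_shard_mirror_group, which stays empty under pipeline parallel even when the parameter is replicated inside its stage, so the optimizer-shard path treated the parameter as unmirrored. The hunk adds a fallback that infers the mirror group directly from the tensor layout. A minimal sketch of the resulting decision logic in isolation (IsParamWithMirror is a hypothetical name, not part of the commit; the types and helpers are the ones used in the hunk above):

bool IsParamWithMirror(const ParameterPtr &param_ptr) {
  MS_EXCEPTION_IF_NULL(param_ptr);
  bool is_with_mirror = false;
  auto layout = param_ptr->user_data<TensorLayout>();
  if (layout != nullptr) {
    // Fast path: optimizer sharding already recorded a mirror group.
    is_with_mirror = !layout->opt_shard_mirror_group().empty();
    // Pipeline parallel: opt_shard_mirror_group can be empty even though the
    // parameter is replicated inside the stage, so fall back to inferring
    // the mirror group from the layout itself.
    if (!is_with_mirror) {
      is_with_mirror = !mirror_group_list(layout).empty();
    }
  }
  return is_with_mirror;
}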


@@ -1515,6 +1515,17 @@ TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_
   return tensorlayout_in;
 }
+
+Shape mirror_group_list(const TensorLayoutPtr &layout) {
+  int64_t rank = g_device_manager->global_rank();
+  auto stage_dev_list = g_device_manager->GetDeviceListInThisStage();
+  DeviceMatrix dev_matrix(rank, stage_dev_list, layout->device_arrangement().array());
+  RankList group_devices;
+  if (dev_matrix.GetDevicesByTensorMap(layout->tensor_map().array(), &group_devices) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "For layout:" << layout->ToString() << ", infer mirror failed";
+  }
+  return group_devices;
+}
 
 std::string GetSerialNumberString(size_t number) {
   std::string suffix = "th";
   if (number == kSizeOne) {
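
The new helper builds a DeviceMatrix from the current global rank, the device list of the current pipeline stage, and the layout's device arrangement, then asks GetDevicesByTensorMap for the ranks that hold the same slice of the tensor; the commit treats a non-empty result as "this parameter has a mirror". An illustrative caller, not from the commit (param_ptr and the null check are assumptions):

auto layout = param_ptr->user_data<TensorLayout>();
if (layout != nullptr) {
  // Ranks in this stage holding an identical slice of the parameter.
  Shape mirror_group = mirror_group_list(layout);
  if (!mirror_group.empty()) {
    // Mirrored parameter: optimizer sharding and its AllGather apply here.
  }
}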


@@ -111,7 +111,7 @@ StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
 bool IsInsertVirtualOutput(const FuncGraphPtr &root);
 TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
+Shape mirror_group_list(const TensorLayoutPtr &layout);
 // Transfer number to serial number string
 std::string GetSerialNumberString(size_t number);
 }  // namespace parallel