fix opt shard not sharding param in pipeline
parent a04b337d78
commit dd773cb404
@@ -1265,13 +1265,18 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
     op = CreateAllGatherOp(group);
   }
   CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
-  std::string opt_shard_mirror_group;
+  bool is_with_mirror = false;
   auto param_ptr = node->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(param_ptr);
   if (param_ptr->user_data<TensorLayout>()) {
-    opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+    auto opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+    is_with_mirror = !opt_shard_mirror_group.empty();
+    if (!is_with_mirror) {
+      auto mirror_group = mirror_group_list(param_ptr->user_data<TensorLayout>());
+      is_with_mirror = !mirror_group.empty();
+    }
   }
-  bool is_with_mirror = !opt_shard_mirror_group.empty();
+
   if (!is_shared_param && cast_node) {
     allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
     MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
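The substance of this hunk: previously `is_with_mirror` was derived only from the layout's `opt_shard_mirror_group` string, so a pipeline-stage parameter whose layout carries an empty group string was treated as having no mirror and its optimizer shard was skipped. The patch falls back to inferring the mirror group from the tensor layout itself via the new `mirror_group_list` helper. Below is a minimal standalone sketch of the before/after decision logic; the stub type and field names are illustrative stand-ins for the MindSpore internals, not the real API:

// Standalone sketch (not the MindSpore sources): a stub type models the
// decision this hunk changes.
#include <cstdint>
#include <string>
#include <vector>

struct TensorLayoutStub {
  std::string opt_shard_mirror_group;          // explicit group name, "" if unset
  std::vector<int64_t> inferred_mirror_group;  // stands in for mirror_group_list()
};

// Before: mirror detection relied solely on the explicit group string.
bool IsWithMirrorBefore(const TensorLayoutStub *layout) {
  std::string opt_shard_mirror_group;
  if (layout != nullptr) {
    opt_shard_mirror_group = layout->opt_shard_mirror_group;
  }
  return !opt_shard_mirror_group.empty();
}

// After: when the explicit string is empty, fall back to the group
// inferred from the tensor map, so the parameter is still sharded.
bool IsWithMirrorAfter(const TensorLayoutStub *layout) {
  bool is_with_mirror = false;
  if (layout != nullptr) {
    is_with_mirror = !layout->opt_shard_mirror_group.empty();
    if (!is_with_mirror) {
      is_with_mirror = !layout->inferred_mirror_group.empty();
    }
  }
  return is_with_mirror;
}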
@@ -1515,6 +1515,17 @@ TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_
   return tensorlayout_in;
 }
+
+Shape mirror_group_list(const TensorLayoutPtr &layout) {
+  int64_t rank = g_device_manager->global_rank();
+  auto stage_dev_list = g_device_manager->GetDeviceListInThisStage();
+  DeviceMatrix dev_matrix(rank, stage_dev_list, layout->device_arrangement().array());
+  RankList group_devices;
+  if (dev_matrix.GetDevicesByTensorMap(layout->tensor_map().array(), &group_devices) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "For layout:" << layout->ToString() << ", infer mirror failed";
+  }
+  return group_devices;
+}
 
 std::string GetSerialNumberString(size_t number) {
   std::string suffix = "th";
   if (number == kSizeOne) {
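The new helper delegates to DeviceMatrix::GetDevicesByTensorMap to find the ranks in the current stage that hold the same slice of the parameter: ranks that agree on every device-matrix axis the tensor map uses, and differ only on unused axes, mirror each other. The following self-contained sketch models that semantics under assumed conventions (tensor-map values index device-matrix axes from the right, -1 meaning unsharded); it illustrates the idea and is not the MindSpore implementation:

// Sketch of "infer the mirror group from a tensor map" (assumed conventions).
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Convert a linear rank to multi-dim coordinates in the device matrix.
Shape RankToCoord(int64_t rank, const Shape &dev_mat) {
  Shape coord(dev_mat.size());
  for (int64_t i = static_cast<int64_t>(dev_mat.size()) - 1; i >= 0; --i) {
    coord[i] = rank % dev_mat[i];
    rank /= dev_mat[i];
  }
  return coord;
}

// All ranks whose coordinates match `rank` on every device-matrix axis used
// by the tensor map hold the same slice, i.e. form one mirror group.
Shape MirrorGroup(int64_t rank, const Shape &dev_mat, const Shape &tensor_map) {
  int64_t total = 1;
  for (int64_t d : dev_mat) total *= d;
  Shape my = RankToCoord(rank, dev_mat);
  std::vector<bool> used(dev_mat.size(), false);
  for (int64_t m : tensor_map) {
    // Assumed convention: value m >= 0 counts axes from the rightmost
    // axis of the device matrix; -1 means "not sharded on any axis".
    if (m >= 0) used[dev_mat.size() - 1 - static_cast<size_t>(m)] = true;
  }
  Shape group;
  for (int64_t r = 0; r < total; ++r) {
    Shape c = RankToCoord(r, dev_mat);
    bool same_slice = true;
    for (size_t i = 0; i < dev_mat.size(); ++i) {
      if (used[i] && c[i] != my[i]) { same_slice = false; break; }
    }
    if (same_slice) group.push_back(r);
  }
  return group;
}

int main() {
  // Device matrix [2, 4] (8 devices); tensor sharded only along the last axis.
  for (int64_t r : MirrorGroup(1, {2, 4}, {0})) std::cout << r << ' ';
  std::cout << '\n';  // prints "1 5": ranks differing only on the unused axis
}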
@@ -111,7 +111,9 @@ StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
 bool IsInsertVirtualOutput(const FuncGraphPtr &root);
 TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
+
+Shape mirror_group_list(const TensorLayoutPtr &layout);
 // Transfer number to serial number string
 std::string GetSerialNumberString(size_t number);
 }  // namespace parallel