!26794 transformer_slice_activation_config_fix
Merge pull request !26794 from yao_yf/add_transformer_slice_activation_config_fix
This commit is contained in:
commit
4a82477626
|
@ -57,7 +57,7 @@ CNodePtr CreateAllGatherCNode(const AnfNodePtr &node, std::string group) {
|
|||
return new_node;
|
||||
}
|
||||
|
||||
parallel::Group InferRepeatedRankList(const CNodePtr &cnode) {
|
||||
std::vector<parallel::Group> InferRepeatedRankList(const CNodePtr &cnode) {
|
||||
OperatorInfoPtr operator_info = cnode->user_data<parallel::OperatorInfo>();
|
||||
std::vector<parallel::TensorInfo> output_info = operator_info->outputs_tensor_info();
|
||||
if (output_info.size() != 1) {
|
||||
|
@ -67,7 +67,7 @@ parallel::Group InferRepeatedRankList(const CNodePtr &cnode) {
|
|||
auto tensor_map = tensor_layout.origin_tensor_map();
|
||||
std::vector<parallel::Group> groups;
|
||||
operator_info->CreateGroupByTensorMap(tensor_map.array(), &groups);
|
||||
return groups[0];
|
||||
return groups;
|
||||
}
|
||||
|
||||
bool IsDuplicateNode(const AnfNodePtr &node) {
|
||||
|
@ -160,7 +160,11 @@ void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNod
|
|||
int64_t device_num = parallel::g_device_manager->DeviceNum();
|
||||
int64_t stage_device_num = device_num / stage_num;
|
||||
int64_t local_rank_id = global_rank_id % stage_device_num;
|
||||
auto group = InferRepeatedRankList(node);
|
||||
auto groups = InferRepeatedRankList(node);
|
||||
if (groups.empty()) {
|
||||
return;
|
||||
}
|
||||
auto group = groups[0];
|
||||
if (out_shape_element[0] % group.GetDevNum() != 0) {
|
||||
MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0]
|
||||
<< " cannot be divisible by the repeated size: " << group.GetDevNum()
|
||||
|
|
|
@ -45,7 +45,7 @@ bool StrategyCheckpoint::CheckPath(const std::string path) const {
|
|||
MS_LOG(ERROR) << "The checkpoit path " << path << " is too long";
|
||||
return false;
|
||||
}
|
||||
auto realpath = Common::CreatePrefixPath(path);
|
||||
auto realpath = Common::CreatePrefixPath(path, true);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << path;
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue