forked from mindspore-Ecosystem/mindspore
auto parallel context modify

parent 042ac51f05
commit 8f7aa5bd5a

@@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() {
-  communication_backend_ = HCCL_BACKEND;
-  Reset();
-}
+ParallelContext::ParallelContext() { Reset(); }
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
   full_batch_ = false;
-  cast_before_mirror_ = true;
+  gradient_fp32_sync_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
@@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
 
 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
 
-void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
+void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
 
 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
 
-void ParallelContext::set_communication_backend(const std::string &communication_backend) {
-  communication_backend_ = communication_backend;
-}
-
 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
   auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
   if (iter == PARALLEL_MODE_LIST.end()) {
@@ -58,8 +58,8 @@ class ParallelContext {
   void set_full_batch(bool full_batch);
   bool full_batch() const { return full_batch_; }
 
-  void set_cast_before_mirror(bool cast_before_mirror);
-  bool cast_before_mirror() const { return cast_before_mirror_; }
+  void set_gradient_fp32_sync(bool gradient_fp32_sync);
+  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
 
   void set_loss_repeated_mean(bool loss_repeated_mean);
   bool loss_repeated_mean() const { return loss_repeated_mean_; }
@@ -70,9 +70,6 @@ class ParallelContext {
   void set_global_rank(int32_t global_rank);
   int32_t global_rank() const { return global_rank_; }
 
-  void set_communication_backend(const std::string &communication_backend);
-  std::string communication_backend() const { return communication_backend_; }
-
   bool set_parallel_mode(const std::string &parallel_mode);
   std::string parallel_mode() const { return parallel_mode_; }
 
@@ -112,11 +109,10 @@ class ParallelContext {
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
   bool full_batch_;
-  bool cast_before_mirror_;
+  bool gradient_fp32_sync_;
   bool loss_repeated_mean_;
   int32_t device_num_;
   int32_t global_rank_;
-  std::string communication_backend_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
   bool parameter_broadcast_;
@@ -43,6 +43,7 @@
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "utils/comm_manager.h"
 #include "utils/symbolic.h"
+#include "utils/ms_context.h"
 
 using mindspore::tensor::Tensor;
 
@@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 }
 
 bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
-  // only if cast_before_mirror is true, pre node is cast and type is not float32 return true
-  if (!ParallelContext::GetInstance()->cast_before_mirror()) {
+  // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
+  if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
     return false;
   }
   auto pre_node = node->input(index);
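
Note: the renamed flag keeps the old behaviour here; the mirror (gradient all-reduce) is only placed after the Cast when fp32 sync is enabled and the casted operand is not already float32. A loose standalone Python reading of the comment above (illustrative names only, not MindSpore API):

def is_cast_before_mirror(gradient_fp32_sync, pre_node_is_cast, pre_node_dtype):
    # All three conditions from the comment must hold for the check to return True.
    return gradient_fp32_sync and pre_node_is_cast and pre_node_dtype != "float32"
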
@@ -2421,13 +2422,17 @@ Status ParallelInit() {
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   int32_t device_num = ParallelContext::GetInstance()->device_num();
   int32_t global_rank = ParallelContext::GetInstance()->global_rank();
-  std::string backend = ParallelContext::GetInstance()->communication_backend();
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   std::string world_group;
-  if (backend == HCCL_BACKEND) {
+  std::string communication_backend;
+  if (backend == kAscendDevice || backend == kDavinciDevice) {
     world_group = HCCL_WORLD_GROUP;
-  } else if (backend == NCCL_BACKEND) {
+    communication_backend = HCCL_BACKEND;
+  } else if (backend == kGPUDevice) {
     world_group = NCCL_WORLD_GROUP;
+    communication_backend = NCCL_BACKEND;
   } else {
     MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
   }
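
Note: ParallelInit no longer reads a user-set communication backend; it derives both the world group and the backend from the device target held by MsContext. The branch above amounts to the following Python sketch (the literal backend and group strings are assumed values, written out only for illustration):

def pick_backend_and_world_group(device_target):
    # "Ascend"/"Davinci" map to HCCL, "GPU" maps to NCCL; any other target is rejected.
    if device_target in ("Ascend", "Davinci"):
        return "hccl", "hccl_world_group"
    if device_target == "GPU":
        return "nccl", "nccl_world_group"
    raise RuntimeError("Invalid communication backend: " + device_target)
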
@@ -2450,14 +2455,14 @@ Status ParallelInit() {
     MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
   }
 
-  if (!InitDevice(device_num, global_rank, backend)) {
+  if (!InitDevice(device_num, global_rank, communication_backend)) {
     MS_LOG(ERROR) << "Init device failed";
     return FAILED;
   }
 
   MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
                << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
-               << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror();
+               << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
   return SUCCESS;
 }
 
@@ -209,12 +209,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
     .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
     .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
-    .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.")
-    .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.")
+    .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
+    .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
     .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
     .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
-    .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
-    .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
     .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
     .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
     .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
@@ -15,7 +15,6 @@
 """Communication management API"""
 import os
 from mindspore import context
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -86,9 +85,6 @@ def init(backend_name=None):
     else:
         raise RuntimeError("Backend name {} is not supported.".format(backend_name))
 
-    auto_parallel_context().set_communication_backend(backend_name)
-
-
 def release():
     """
     Release distributed resource. e.g., hccl/nccl.
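
Note: since init() no longer pushes the backend name into the auto-parallel context, a launch script only sets the device target and calls init(); the backend follows from the device. A minimal sketch, assuming a correctly configured Ascend (HCCL) or GPU (NCCL) environment:

from mindspore import context
from mindspore.communication.management import init

context.set_context(device_target="Ascend")  # or "GPU" for an NCCL setup
init()  # no auto_parallel_context().set_communication_backend(...) call is needed anymore
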
@@ -434,7 +434,7 @@ def _context():
     return _k_context
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def set_auto_parallel_context(**kwargs):
@@ -454,9 +454,9 @@ def set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
                      "stand_alone" do not support mirror_mean. Default: False.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
+        gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
                      "stand_alone", "data_parallel" and "hybrid_parallel" do not support
-                     cast_before_mirror. Default: True.
+                     gradient_fp32_sync. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
 
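
Note: for user code only the keyword changes; the flag is set and read back the same way as before. A minimal sketch of the renamed option:

from mindspore import context

context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
assert not context.get_auto_parallel_context("gradient_fp32_sync")
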
@@ -492,7 +492,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(device_num=8)
         >>> context.set_auto_parallel_context(global_rank=0)
         >>> context.set_auto_parallel_context(mirror_mean=True)
-        >>> context.set_auto_parallel_context(cast_before_mirror=False)
+        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
@@ -524,7 +524,7 @@ def reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: "".
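
Note: resetting the auto-parallel context restores the renamed flag to its default as well, which the updated tests later in this diff check. Sketch:

from mindspore import context

context.set_auto_parallel_context(gradient_fp32_sync=False)
context.reset_auto_parallel_context()
assert context.get_auto_parallel_context("gradient_fp32_sync")  # back to the default True
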
@@ -113,24 +113,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_mirror_mean()
 
-    def set_cast_before_mirror(self, cast_before_mirror):
+    def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
-        Set cast_before_mirror.
+        Set gradient_fp32_sync.
 
         Note:
-            If cast_before_mirror is true,
+            If gradient_fp32_sync is true,
             it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
 
         Args:
-            cast_before_mirror (bool): The cast_before_mirror flag.
+            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
         """
         self.check_context_handle()
-        self._context_handle.set_cast_before_mirror(cast_before_mirror)
+        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
 
-    def get_cast_before_mirror(self):
-        """Get cast_before_mirror flag."""
+    def get_gradient_fp32_sync(self):
+        """Get gradient_fp32_sync flag."""
         self.check_context_handle()
-        return self._context_handle.get_cast_before_mirror()
+        return self._context_handle.get_gradient_fp32_sync()
 
     def set_loss_repeated_mean(self, loss_repeated_mean):
         """
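
Note: the same switch is reachable through this internal wrapper, exercised by the unit tests later in this diff; a short sketch (private API, shown only to illustrate the renamed methods):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_gradient_fp32_sync(False)
assert not auto_parallel_context().get_gradient_fp32_sync()
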
@@ -152,21 +152,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_loss_repeated_mean()
 
-    def set_communication_backend(self, communication_backend):
-        """
-        Set communication backend.
-
-        Args:
-            communication_backend (str): The communication backend.
-        """
-        self.check_context_handle()
-        self._context_handle.set_communication_backend(communication_backend)
-
-    def get_communication_backend(self):
-        """Get communication backend."""
-        self.check_context_handle()
-        return self._context_handle.get_communication_backend()
-
     def set_parallel_mode(self, parallel_mode):
         """
         Set parallel mode for auto parallel.
@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
     "mirror_mean": auto_parallel_context().set_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
     "mirror_mean": auto_parallel_context().get_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
@@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
                         calculations. Default: True.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+        gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
+                        Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
 
@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
@@ -61,7 +61,7 @@ def get_rank_id(group=None):
 
 def get_rank_size(group=None):
     hccl = Hccl()
-    if group is None:
+    if group is None or "nccl_world_group" in group:
         return hccl.rank_size
     if isinstance(group, str):
         return int(group.split("-")[0])
@@ -830,7 +830,7 @@ def test_matmul_cast():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror():
+def test_gradient_fp32_sync():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -843,7 +843,7 @@ def test_cast_before_mirror():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -854,7 +854,7 @@ def test_cast_before_mirror():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror1():
+def test_gradient_fp32_sync1():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -867,7 +867,7 @@ def test_cast_before_mirror1():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -878,7 +878,7 @@ def test_cast_before_mirror1():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror2():
+def test_gradient_fp32_sync2():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -891,7 +891,7 @@ def test_cast_before_mirror2():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -902,7 +902,7 @@ def test_cast_before_mirror2():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror3():
+def test_gradient_fp32_sync3():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 
 def test_set_auto_parallel_context():
-    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False,
+    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
                                       parallel_mode="auto_parallel", parameter_broadcast=False)
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     assert device_num == 4
     assert global_rank == 3
     assert mirror_mean
-    assert not cast_before_mirror
+    assert not gradient_fp32_sync
     assert parallel_mode == "auto_parallel"
     assert not parameter_broadcast
 
-    auto_parallel_context().set_communication_backend("hccl")
-    backend = auto_parallel_context().get_communication_backend()
-    assert backend == "hccl"
-
     auto_parallel_context().set_device_num(4)
     device_num = auto_parallel_context().get_device_num()
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -53,9 +49,9 @@ def test_set_auto_parallel_context():
     mirror_mean = auto_parallel_context().get_mirror_mean()
     assert mirror_mean
 
-    auto_parallel_context().set_cast_before_mirror(False)
-    cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    assert not cast_before_mirror
+    auto_parallel_context().set_gradient_fp32_sync(False)
+    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
+    assert not gradient_fp32_sync
 
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
@@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -99,7 +95,7 @@ def test_reset_auto_parallel_context():
     assert device_num == 1
     assert global_rank == 0
     assert not mirror_mean
-    assert cast_before_mirror
+    assert gradient_fp32_sync
     assert parallel_mode == "stand_alone"
     assert not parameter_broadcast
     assert not device_num_is_set