auto parallel context modify

This commit is contained in:
yao_yf 2020-08-29 16:55:19 +08:00
parent 042ac51f05
commit 8f7aa5bd5a
10 changed files with 57 additions and 87 deletions

View File

@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
return inst_context_; return inst_context_;
} }
ParallelContext::ParallelContext() { ParallelContext::ParallelContext() { Reset(); }
communication_backend_ = HCCL_BACKEND;
Reset();
}
void ParallelContext::Reset() { void ParallelContext::Reset() {
mirror_mean_ = false; mirror_mean_ = false;
full_batch_ = false; full_batch_ = false;
cast_before_mirror_ = true; gradient_fp32_sync_ = true;
loss_repeated_mean_ = true; loss_repeated_mean_ = true;
device_num_ = 1; device_num_ = 1;
global_rank_ = 0; global_rank_ = 0;
@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
void ParallelContext::set_communication_backend(const std::string &communication_backend) {
communication_backend_ = communication_backend;
}
bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) { bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
if (iter == PARALLEL_MODE_LIST.end()) { if (iter == PARALLEL_MODE_LIST.end()) {

View File

@ -58,8 +58,8 @@ class ParallelContext {
void set_full_batch(bool full_batch); void set_full_batch(bool full_batch);
bool full_batch() const { return full_batch_; } bool full_batch() const { return full_batch_; }
void set_cast_before_mirror(bool cast_before_mirror); void set_gradient_fp32_sync(bool gradient_fp32_sync);
bool cast_before_mirror() const { return cast_before_mirror_; } bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
void set_loss_repeated_mean(bool loss_repeated_mean); void set_loss_repeated_mean(bool loss_repeated_mean);
bool loss_repeated_mean() const { return loss_repeated_mean_; } bool loss_repeated_mean() const { return loss_repeated_mean_; }
@ -70,9 +70,6 @@ class ParallelContext {
void set_global_rank(int32_t global_rank); void set_global_rank(int32_t global_rank);
int32_t global_rank() const { return global_rank_; } int32_t global_rank() const { return global_rank_; }
void set_communication_backend(const std::string &communication_backend);
std::string communication_backend() const { return communication_backend_; }
bool set_parallel_mode(const std::string &parallel_mode); bool set_parallel_mode(const std::string &parallel_mode);
std::string parallel_mode() const { return parallel_mode_; } std::string parallel_mode() const { return parallel_mode_; }
@ -112,11 +109,10 @@ class ParallelContext {
static std::shared_ptr<ParallelContext> inst_context_; static std::shared_ptr<ParallelContext> inst_context_;
bool mirror_mean_; bool mirror_mean_;
bool full_batch_; bool full_batch_;
bool cast_before_mirror_; bool gradient_fp32_sync_;
bool loss_repeated_mean_; bool loss_repeated_mean_;
int32_t device_num_; int32_t device_num_;
int32_t global_rank_; int32_t global_rank_;
std::string communication_backend_;
std::string parallel_mode_; std::string parallel_mode_;
std::string strategy_search_mode_; std::string strategy_search_mode_;
bool parameter_broadcast_; bool parameter_broadcast_;

View File

@ -43,6 +43,7 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/ms_context.h"
using mindspore::tensor::Tensor; using mindspore::tensor::Tensor;
@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
} }
bool IsCastBeforMirror(const CNodePtr &node, size_t index) { bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
// only if cast_before_mirror is true, pre node is cast and type is not float32 return true // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
if (!ParallelContext::GetInstance()->cast_before_mirror()) { if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
return false; return false;
} }
auto pre_node = node->input(index); auto pre_node = node->input(index);
@ -2421,13 +2422,17 @@ Status ParallelInit() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
int32_t device_num = ParallelContext::GetInstance()->device_num(); int32_t device_num = ParallelContext::GetInstance()->device_num();
int32_t global_rank = ParallelContext::GetInstance()->global_rank(); int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::string backend = ParallelContext::GetInstance()->communication_backend(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
std::string world_group; std::string world_group;
std::string communication_backend;
if (backend == HCCL_BACKEND) { if (backend == kAscendDevice || backend == kDavinciDevice) {
world_group = HCCL_WORLD_GROUP; world_group = HCCL_WORLD_GROUP;
} else if (backend == NCCL_BACKEND) { communication_backend = HCCL_BACKEND;
} else if (backend == kGPUDevice) {
world_group = NCCL_WORLD_GROUP; world_group = NCCL_WORLD_GROUP;
communication_backend = NCCL_BACKEND;
} else { } else {
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
} }
@ -2450,14 +2455,14 @@ Status ParallelInit() {
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
} }
if (!InitDevice(device_num, global_rank, backend)) { if (!InitDevice(device_num, global_rank, communication_backend)) {
MS_LOG(ERROR) << "Init device failed"; MS_LOG(ERROR) << "Init device failed";
return FAILED; return FAILED;
} }
MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
<< ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
<< ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
return SUCCESS; return SUCCESS;
} }

View File

@ -209,12 +209,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
.def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
.def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
.def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
.def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
.def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
.def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
.def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
.def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
.def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
.def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
.def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")

View File

@ -15,7 +15,6 @@
"""Communication management API""" """Communication management API"""
import os import os
from mindspore import context from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@ -86,9 +85,6 @@ def init(backend_name=None):
else: else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name)) raise RuntimeError("Backend name {} is not supported.".format(backend_name))
auto_parallel_context().set_communication_backend(backend_name)
def release(): def release():
""" """
Release distributed resource. e.g., hccl/nccl. Release distributed resource. e.g., hccl/nccl.

View File

@ -434,7 +434,7 @@ def _context():
return _k_context return _k_context
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def set_auto_parallel_context(**kwargs): def set_auto_parallel_context(**kwargs):
@ -454,9 +454,9 @@ def set_auto_parallel_context(**kwargs):
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
"stand_alone" do not support mirror_mean. Default: False. "stand_alone" do not support mirror_mean. Default: False.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True..
"stand_alone", "data_parallel" and "hybrid_parallel" do not support "stand_alone", "data_parallel" and "hybrid_parallel" do not support
cast_before_mirror. Default: True. gradient_fp32_sync. Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@ -492,7 +492,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(device_num=8) >>> context.set_auto_parallel_context(device_num=8)
>>> context.set_auto_parallel_context(global_rank=0) >>> context.set_auto_parallel_context(global_rank=0)
>>> context.set_auto_parallel_context(mirror_mean=True) >>> context.set_auto_parallel_context(mirror_mean=True)
>>> context.set_auto_parallel_context(cast_before_mirror=False) >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel") >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(parameter_broadcast=False) >>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
@ -524,7 +524,7 @@ def reset_auto_parallel_context():
- device_num: 1. - device_num: 1.
- global_rank: 0. - global_rank: 0.
- mirror_mean: False. - mirror_mean: False.
- cast_before_mirror: True. - gradient_fp32_sync: True.
- parallel_mode: "stand_alone". - parallel_mode: "stand_alone".
- parameter_broadcast: False. - parameter_broadcast: False.
- strategy_ckpt_load_file: "". - strategy_ckpt_load_file: "".

View File

@ -113,24 +113,24 @@ class _AutoParallelContext:
self.check_context_handle() self.check_context_handle()
return self._context_handle.get_mirror_mean() return self._context_handle.get_mirror_mean()
def set_cast_before_mirror(self, cast_before_mirror): def set_gradient_fp32_sync(self, gradient_fp32_sync):
""" """
Set cast_before_mirror. Set gradient_fp32_sync.
Note: Note:
If cast_before_mirror is true, If gradient_fp32_sync is true,
it will convert tensor type from fp16 to fp32 before parameter gradients allreduce. it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
Args: Args:
cast_before_mirror (bool): The cast_before_mirror flag. gradient_fp32_sync (bool): The gradient_fp32_sync flag.
""" """
self.check_context_handle() self.check_context_handle()
self._context_handle.set_cast_before_mirror(cast_before_mirror) self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
def get_cast_before_mirror(self): def get_gradient_fp32_sync(self):
"""Get cast_before_mirror flag.""" """Get gradient_fp32_sync flag."""
self.check_context_handle() self.check_context_handle()
return self._context_handle.get_cast_before_mirror() return self._context_handle.get_gradient_fp32_sync()
def set_loss_repeated_mean(self, loss_repeated_mean): def set_loss_repeated_mean(self, loss_repeated_mean):
""" """
@ -152,21 +152,6 @@ class _AutoParallelContext:
self.check_context_handle() self.check_context_handle()
return self._context_handle.get_loss_repeated_mean() return self._context_handle.get_loss_repeated_mean()
def set_communication_backend(self, communication_backend):
"""
Set communication backend.
Args:
communication_backend (str): The communication backend.
"""
self.check_context_handle()
self._context_handle.set_communication_backend(communication_backend)
def get_communication_backend(self):
"""Get communication backend."""
self.check_context_handle()
return self._context_handle.get_communication_backend()
def set_parallel_mode(self, parallel_mode): def set_parallel_mode(self, parallel_mode):
""" """
Set parallel mode for auto parallel. Set parallel mode for auto parallel.
@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
"device_num": auto_parallel_context().set_device_num, "device_num": auto_parallel_context().set_device_num,
"global_rank": auto_parallel_context().set_global_rank, "global_rank": auto_parallel_context().set_global_rank,
"mirror_mean": auto_parallel_context().set_mirror_mean, "mirror_mean": auto_parallel_context().set_mirror_mean,
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror, "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"parallel_mode": auto_parallel_context().set_parallel_mode, "parallel_mode": auto_parallel_context().set_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
"device_num": auto_parallel_context().get_device_num, "device_num": auto_parallel_context().get_device_num,
"global_rank": auto_parallel_context().get_global_rank, "global_rank": auto_parallel_context().get_global_rank,
"mirror_mean": auto_parallel_context().get_mirror_mean, "mirror_mean": auto_parallel_context().get_mirror_mean,
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror, "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"parallel_mode": auto_parallel_context().get_parallel_mode, "parallel_mode": auto_parallel_context().get_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
calculations. Default: True. calculations. Default: True.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True. gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
- device_num: 1. - device_num: 1.
- global_rank: 0. - global_rank: 0.
- mirror_mean: False. - mirror_mean: False.
- cast_before_mirror: True. - gradient_fp32_sync: True.
- parallel_mode: "stand_alone". - parallel_mode: "stand_alone".
- parameter_broadcast: False. - parameter_broadcast: False.
- strategy_ckpt_load_file: "" - strategy_ckpt_load_file: ""

View File

@ -61,7 +61,7 @@ def get_rank_id(group=None):
def get_rank_size(group=None): def get_rank_size(group=None):
hccl = Hccl() hccl = Hccl()
if group is None: if group is None or "nccl_world_group" in group:
return hccl.rank_size return hccl.rank_size
if isinstance(group, str): if isinstance(group, str):
return int(group.split("-")[0]) return int(group.split("-")[0])

View File

@ -830,7 +830,7 @@ def test_matmul_cast():
compile_net(net, x, y, b) compile_net(net, x, y, b)
def test_cast_before_mirror(): def test_gradient_fp32_sync():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, strategy1): def __init__(self, strategy1):
super().__init__() super().__init__()
@ -843,7 +843,7 @@ def test_cast_before_mirror():
out = self.matmul(out, b) out = self.matmul(out, b)
return out return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
strategy1 = ((2, 2), (2, 2)) strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@ -854,7 +854,7 @@ def test_cast_before_mirror():
compile_net(net, x, y, b) compile_net(net, x, y, b)
def test_cast_before_mirror1(): def test_gradient_fp32_sync1():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, strategy1): def __init__(self, strategy1):
super().__init__() super().__init__()
@ -867,7 +867,7 @@ def test_cast_before_mirror1():
out = self.matmul(out, b) out = self.matmul(out, b)
return out return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
strategy1 = ((2, 2), (2, 2)) strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@ -878,7 +878,7 @@ def test_cast_before_mirror1():
compile_net(net, x, y, b) compile_net(net, x, y, b)
def test_cast_before_mirror2(): def test_gradient_fp32_sync2():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, strategy1): def __init__(self, strategy1):
super().__init__() super().__init__()
@ -891,7 +891,7 @@ def test_cast_before_mirror2():
out = self.matmul(out, b) out = self.matmul(out, b)
return out return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False) context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
strategy1 = ((2, 2), (2, 2)) strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@ -902,7 +902,7 @@ def test_cast_before_mirror2():
compile_net(net, x, y, b) compile_net(net, x, y, b)
def test_cast_before_mirror3(): def test_gradient_fp32_sync3():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, strategy1): def __init__(self, strategy1):
super().__init__() super().__init__()

View File

@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
def test_set_auto_parallel_context(): def test_set_auto_parallel_context():
context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False, context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
parallel_mode="auto_parallel", parameter_broadcast=False) parallel_mode="auto_parallel", parameter_broadcast=False)
device_num = context.get_auto_parallel_context("device_num") device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank") global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean") mirror_mean = context.get_auto_parallel_context("mirror_mean")
cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror") gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
assert device_num == 4 assert device_num == 4
assert global_rank == 3 assert global_rank == 3
assert mirror_mean assert mirror_mean
assert not cast_before_mirror assert not gradient_fp32_sync
assert parallel_mode == "auto_parallel" assert parallel_mode == "auto_parallel"
assert not parameter_broadcast assert not parameter_broadcast
auto_parallel_context().set_communication_backend("hccl")
backend = auto_parallel_context().get_communication_backend()
assert backend == "hccl"
auto_parallel_context().set_device_num(4) auto_parallel_context().set_device_num(4)
device_num = auto_parallel_context().get_device_num() device_num = auto_parallel_context().get_device_num()
device_num_is_set = auto_parallel_context().get_device_num_is_set() device_num_is_set = auto_parallel_context().get_device_num_is_set()
@ -53,9 +49,9 @@ def test_set_auto_parallel_context():
mirror_mean = auto_parallel_context().get_mirror_mean() mirror_mean = auto_parallel_context().get_mirror_mean()
assert mirror_mean assert mirror_mean
auto_parallel_context().set_cast_before_mirror(False) auto_parallel_context().set_gradient_fp32_sync(False)
cast_before_mirror = auto_parallel_context().get_cast_before_mirror() gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
assert not cast_before_mirror assert not gradient_fp32_sync
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
assert parameter_broadcast_is_set assert parameter_broadcast_is_set
@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
device_num = context.get_auto_parallel_context("device_num") device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank") global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean") mirror_mean = context.get_auto_parallel_context("mirror_mean")
cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror") gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
device_num_is_set = auto_parallel_context().get_device_num_is_set() device_num_is_set = auto_parallel_context().get_device_num_is_set()
@ -99,7 +95,7 @@ def test_reset_auto_parallel_context():
assert device_num == 1 assert device_num == 1
assert global_rank == 0 assert global_rank == 0
assert not mirror_mean assert not mirror_mean
assert cast_before_mirror assert gradient_fp32_sync
assert parallel_mode == "stand_alone" assert parallel_mode == "stand_alone"
assert not parameter_broadcast assert not parameter_broadcast
assert not device_num_is_set assert not device_num_is_set