forked from mindspore-Ecosystem/mindspore
maskrcnn support 16p
This commit is contained in:
parent b8b4ce7442
commit b34d206f84
@@ -29,6 +29,7 @@ from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn import Momentum
 from mindspore.common import set_seed
+from mindspore.communication.management import get_rank, get_group_size
 
 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
 from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
@@ -56,11 +57,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a
 if __name__ == '__main__':
     print("Start train for maskrcnn!")
     if not args_opt.do_eval and args_opt.run_distribute:
-        rank = args_opt.rank_id
-        device_num = args_opt.device_num
+        init()
+        rank = get_rank()
+        device_num = get_group_size()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-        init()
     else:
         rank = 0
         device_num = 1
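
What the commit actually changes: in distributed runs, rank and device_num are no longer read from command-line arguments but queried from the communication layer, and init() is moved ahead of those queries and of the auto-parallel setup. Below is a minimal sketch of the resulting pattern, assuming an Ascend backend with HCCL configured; setup_parallel is a hypothetical helper name introduced here for illustration, not part of the commit.

# Minimal sketch (assumption: Ascend backend with an HCCL rank table
# configured). setup_parallel is a hypothetical name, not from the diff.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

def setup_parallel(run_distribute):
    """Return (rank, device_num) for distributed or single-device training."""
    if run_distribute:
        init()                         # must run before get_rank()/get_group_size()
        rank = get_rank()              # rank comes from the communication group,
        device_num = get_group_size()  # not from command-line arguments
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        rank, device_num = 0, 1
    return rank, device_num

Calling init() first matters because get_rank() and get_group_size() require the communication group to exist, which is presumably why the commit moves init() ahead of the context setup instead of leaving it after set_auto_parallel_context.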