forked from mindspore-Ecosystem/mindspore
!10050 GPU update init nccl in dataset
From: @VectorSL
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
This commit is contained in:
commit ed763748de
```diff
@@ -23,7 +23,7 @@ import mindspore.dataset.transforms.c_transforms as C2
 from mindspore.communication.management import init, get_rank, get_group_size
 
 
-def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
     """
     create a train or evaluate cifar10 dataset for resnet50
     Args:
@@ -32,6 +32,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
+        distribute(bool): data for distribute or not. Default: False
 
     Returns:
         dataset
@@ -39,10 +40,12 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
     else:
-        init()
-        rank_id = get_rank()
-        device_num = get_group_size()
+        if distribute:
+            init()
+            rank_id = get_rank()
+            device_num = get_group_size()
+        else:
+            device_num = 1
 
     if device_num == 1:
         ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
```
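Before this change, any non-Ascend caller of `create_dataset1` triggered `init()`, which stands up NCCL on GPU even for single-device runs; NCCL now comes up only when the caller passes `distribute=True`. A minimal sketch of the resulting GPU-side logic (the function name `make_cifar10` is mine; `num_shards`/`shard_id` are the standard MindSpore dataset-sharding parameters, and the sharded call is implied by the unchanged `else:` branch in the diff):

```python
import mindspore.dataset as de
from mindspore.communication.management import init, get_rank, get_group_size


def make_cifar10(dataset_path, distribute=False):
    """Sketch: initialize NCCL only when distributed loading is requested."""
    if distribute:
        init()                          # on a GPU backend this brings up NCCL
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        rank_id = 0                     # single process: NCCL is never touched
        device_num = 1
    if device_num == 1:
        return de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
    return de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                             num_shards=device_num, shard_id=rank_id)
```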
```diff
@@ -77,7 +80,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     return ds
 
 
-def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for resnet50
 
@@ -87,6 +90,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
+        distribute(bool): data for distribute or not. Default: False
 
     Returns:
         dataset
@@ -94,9 +98,12 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
     else:
-        init()
-        rank_id = get_rank()
-        device_num = get_group_size()
+        if distribute:
+            init()
+            rank_id = get_rank()
+            device_num = get_group_size()
+        else:
+            device_num = 1
 
     if device_num == 1:
         ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
```
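For illustration, the updated `create_dataset2` from a GPU caller's point of view (the ImageNet paths below are placeholders, not from the diff):

```python
# Single-process evaluation: init() is skipped, so no NCCL setup is required.
eval_ds = create_dataset2("/data/imagenet/val", do_train=False, target="GPU")

# Multi-process training (e.g. under mpirun): opt in explicitly; the dataset
# then shards itself by NCCL rank.
train_ds = create_dataset2("/data/imagenet/train", do_train=True, target="GPU",
                           distribute=True)
```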
```diff
@@ -139,7 +146,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     return ds
 
 
-def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for resnet101
     Args:
@@ -147,12 +154,21 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         do_train(bool): whether dataset is used for train or eval.
         repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
+        distribute(bool): data for distribute or not. Default: False
 
     Returns:
         dataset
     """
-    device_num, rank_id = _get_rank_info()
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+    else:
+        if distribute:
+            init()
+            rank_id = get_rank()
+            device_num = get_group_size()
+        else:
+            device_num = 1
 
     if device_num == 1:
         ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
```
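Note that `create_dataset3` previously called `_get_rank_info()` unconditionally, an Ascend-oriented helper; the hunk above confines it to the Ascend branch. The helper's body is not part of this commit; as a rough sketch, model-zoo scripts of this vintage typically derive rank information from the Ascend environment variables, along these lines (an assumption, not code from the diff):

```python
import os
from mindspore.communication.management import get_rank, get_group_size


def _get_rank_info():
    """Assumed shape of the helper (not shown in this diff): rank info for Ascend runs."""
    rank_size = int(os.environ.get("RANK_SIZE", 1))
    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size, rank_id = 1, 0
    return rank_size, rank_id
```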
```diff
@@ -192,7 +208,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     return ds
 
 
-def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for se-resnet50
 
@@ -202,12 +218,20 @@ def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
+        distribute(bool): data for distribute or not. Default: False
 
     Returns:
         dataset
     """
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
+    else:
+        if distribute:
+            init()
+            rank_id = get_rank()
+            device_num = get_group_size()
+        else:
+            device_num = 1
 
     if device_num == 1:
         ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
     else:
```
```diff
@@ -94,7 +94,7 @@ if __name__ == '__main__':
 
     # create dataset
     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
-                             batch_size=config.batch_size, target=target)
+                             batch_size=config.batch_size, target=target, distribute=args_opt.run_distribute)
     step_size = dataset.get_dataset_size()
 
     # define net
```
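In the training entry script, `args_opt.run_distribute` is the same flag that gates the distributed setup earlier in the file (that part is untouched by this commit and not shown). Assuming the script follows the common MindSpore model-zoo pattern, the surrounding wiring looks roughly like this:

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

# Assumed GPU setup earlier in the same script (typical model-zoo pattern;
# not part of this diff): bring up NCCL and match the parallel context.
if args_opt.run_distribute and target == "GPU":
    init()  # NCCL process group; the init() inside the dataset helpers covers
            # scripts that build datasets standalone, e.g. for evaluation
    context.set_auto_parallel_context(device_num=get_group_size(),
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
```

The design choice, presumably, is that the dataset helpers should be usable from scripts that never configure a parallel context, hence the opt-in `init()` inside them rather than an unconditional one.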