From 77b1f71833fc4500e6fb35b192236a742b34dc2d Mon Sep 17 00:00:00 2001 From: VectorSL Date: Fri, 29 May 2020 16:39:25 +0800 Subject: [PATCH] gpu fix resnet script --- example/resnet50_cifar10/dataset.py | 2 +- example/resnet50_imagenet2012/dataset.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example/resnet50_cifar10/dataset.py b/example/resnet50_cifar10/dataset.py index 1d7074d7333..fabdf0b181c 100755 --- a/example/resnet50_cifar10/dataset.py +++ b/example/resnet50_cifar10/dataset.py @@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 -from mindspore.communication.management import get_rank, get_group_size +from mindspore.communication.management import init, get_rank, get_group_size from config import config diff --git a/example/resnet50_imagenet2012/dataset.py b/example/resnet50_imagenet2012/dataset.py index 400a4dc4fa9..0691985e0b1 100755 --- a/example/resnet50_imagenet2012/dataset.py +++ b/example/resnet50_imagenet2012/dataset.py @@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 -from mindspore.communication.management import get_rank, get_group_size +from mindspore.communication.management import init, get_rank, get_group_size def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): """ @@ -40,6 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" device_num = int(os.getenv("DEVICE_NUM")) rank_id = int(os.getenv("RANK_ID")) else: + init("nccl") rank_id = get_rank() device_num = get_group_size()