!1674 GPU fix resent scripts

Merge pull request !1674 from VectorSL/gpu-resnet-scripts
This commit is contained in:
mindspore-ci-bot 2020-05-29 18:15:12 +08:00 committed by Gitee
commit 9cbed69ee5
2 changed files with 3 additions and 2 deletions

View File

@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from config import config from config import config

View File

@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
""" """
@ -40,6 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
device_num = int(os.getenv("DEVICE_NUM")) device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID")) rank_id = int(os.getenv("RANK_ID"))
else: else:
init("nccl")
rank_id = get_rank() rank_id = get_rank()
device_num = get_group_size() device_num = get_group_size()