forked from mindspore-Ecosystem/mindspore
!1674 GPU fix resent scripts
Merge pull request !1674 from VectorSL/gpu-resnet-scripts
This commit is contained in:
commit
9cbed69ee5
|
@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
|
||||||
import mindspore.dataset.engine as de
|
import mindspore.dataset.engine as de
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
import mindspore.dataset.transforms.c_transforms as C2
|
import mindspore.dataset.transforms.c_transforms as C2
|
||||||
from mindspore.communication.management import get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
from config import config
|
from config import config
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
|
||||||
import mindspore.dataset.engine as de
|
import mindspore.dataset.engine as de
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
import mindspore.dataset.transforms.c_transforms as C2
|
import mindspore.dataset.transforms.c_transforms as C2
|
||||||
from mindspore.communication.management import get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
|
||||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
|
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
|
||||||
"""
|
"""
|
||||||
|
@ -40,6 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
||||||
device_num = int(os.getenv("DEVICE_NUM"))
|
device_num = int(os.getenv("DEVICE_NUM"))
|
||||||
rank_id = int(os.getenv("RANK_ID"))
|
rank_id = int(os.getenv("RANK_ID"))
|
||||||
else:
|
else:
|
||||||
|
init("nccl")
|
||||||
rank_id = get_rank()
|
rank_id = get_rank()
|
||||||
device_num = get_group_size()
|
device_num = get_group_size()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue