MindSpore-Model-Development/examples/train_parallel_with_func_ex...

157 lines
5.0 KiB
Python

import os
import sys
from time import time
sys.path.append(".")
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.communication import get_group_size, get_rank, init
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.optim import create_optimizer
from mindcv.utils import Allreduce
def main():
    """Run data-parallel training of ResNet-18 on CIFAR-10.

    Initializes the communication group, shards the train/test datasets
    across ranks, builds the model/loss/optimizer via mindcv factories,
    and runs `epochs` rounds of train + eval. Only rank 0 saves
    checkpoints. Assumes launch under a distributed launcher so that
    `init()` / `get_rank()` succeed.
    """
    ms.set_seed(1)
    ms.set_context(mode=ms.PYNATIVE_MODE)
    # --------------------------- Prepare data -------------------------#
    # Initialize the collective-communication backend and query topology.
    init()
    device_num = get_group_size()
    rank_id = get_rank()
    ms.set_auto_parallel_context(
        device_num=device_num,
        parallel_mode="data_parallel",
        gradients_mean=True,
    )
    num_classes = 10
    num_workers = 8
    data_dir = "/data/cifar-10-batches-bin"
    # Download only when the dataset is not already on disk
    # (idiomatic form of `False if exists else True`).
    download = not os.path.exists(data_dir)
    dataset_train = create_dataset(
        name="cifar10",
        root=data_dir,
        split="train",
        shuffle=True,
        download=download,
        num_shards=device_num,
        shard_id=rank_id,
        num_parallel_workers=num_workers,
    )
    dataset_test = create_dataset(
        name="cifar10",
        root=data_dir,
        split="test",
        shuffle=False,
        download=False,
        num_shards=device_num,
        shard_id=rank_id,
        num_parallel_workers=num_workers,
    )
    # create transform and get trans list
    trans_train = create_transforms(dataset_name="cifar10", is_training=True)
    trans_test = create_transforms(dataset_name="cifar10", is_training=False)
    # get data loader
    loader_train = create_loader(
        dataset=dataset_train,
        batch_size=64,
        is_training=True,
        num_classes=num_classes,
        transform=trans_train,
        num_parallel_workers=num_workers,
        drop_remainder=True,
    )
    loader_test = create_loader(
        dataset=dataset_test,
        batch_size=32,
        is_training=False,
        num_classes=num_classes,
        transform=trans_test,
    )
    num_batches = loader_train.get_dataset_size()
    print("Num batches: ", num_batches)
    # --------------------------- Build model -------------------------#
    network = create_model(model_name="resnet18", num_classes=num_classes, pretrained=False)
    loss = create_loss(name="CE")
    opt = create_optimizer(network.trainable_params(), opt="adam", lr=1e-3)
    # --------------------------- Training and monitoring -------------------------#
    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        save_path = f"./resnet18-{t + 1}_{num_batches}.ckpt"
        b = time()
        train_epoch(network, loader_train, loss, opt)
        print("Epoch time cost: ", time() - b)
        test_epoch(network, loader_test)
        # get_rank() always returns an int, so only rank 0 ever matches;
        # restrict checkpointing to rank 0 to avoid duplicate writes.
        if rank_id == 0:
            ms.save_checkpoint(network, save_path, async_save=True)
    print("Done!")
def train_epoch(network, dataset, loss_fn, optimizer):
    """Train `network` for one epoch over `dataset` with gradient all-reduce.

    Builds a graph-compiled train step that computes the loss, reduces
    gradients across data-parallel ranks, and applies the optimizer.
    Prints the loss every 100 batches.
    """
    # Define forward function
    def forward_fn(data, label):
        logits = network(data)
        loss = loss_fn(logits, label)
        return loss, logits
    # Get gradient function: differentiates w.r.t. optimizer.parameters;
    # has_aux=True means only the first output (loss) is differentiated,
    # logits are returned as auxiliary data.
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
    # Pull the mean/degree settings from the auto-parallel context
    # (set in main via set_auto_parallel_context) for gradient reduction.
    mean = _get_gradients_mean()
    degree = _get_device_num()
    grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
    # Define function of one-step training,
    # compiled to a graph by ms_function for performance.
    @ms.ms_function
    def train_step_parallel(data, label):
        (loss, _), grads = grad_fn(data, label)
        # All-reduce gradients across ranks before the update.
        grads = grad_reducer(grads)
        # ops.depend pins execution order: the optimizer update must run
        # before loss is returned from the compiled graph.
        loss = ops.depend(loss, optimizer(grads))
        return loss
    network.set_train()
    size = dataset.get_dataset_size()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step_parallel(data, label)
        if batch % 100 == 0:
            # .asnumpy() forces sync with device so the printed loss is current.
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
def test_epoch(network, dataset):
    """Evaluate `network` on `dataset` and return top-1 accuracy (percent).

    Counts correct predictions locally, then all-reduces the correct and
    total counts across data-parallel ranks so every rank reports the
    global accuracy. Handles both index labels and one-hot/soft labels.
    """
    network.set_train(False)
    seen = 0
    hits = 0
    for images, labels in dataset.create_tuple_iterator():
        logits = network(images)
        seen += len(images)
        predicted = logits.argmax(1)
        # One-dimensional labels are class indices; otherwise treat the
        # label tensor as one-hot/soft and compare argmax to argmax.
        target = labels if len(labels.shape) == 1 else labels.argmax(1)
        hits += (predicted == target).asnumpy().sum()
    # Aggregate counts over all ranks before computing the ratio.
    reduce_sum = Allreduce()
    hits = reduce_sum(Tensor(hits, ms.float32))
    seen = reduce_sum(Tensor(seen, ms.float32))
    hits /= seen
    acc = 100 * hits.asnumpy()
    print(f"Test Accuracy: {acc:>0.2f}% \n")
    return acc
if __name__ == "__main__":
main()