!1327 fix resnet cifar 1p test case device error

Merge pull request !1327 from chujinjin/fix_cifar_1p_test
This commit is contained in:
mindspore-ci-bot 2020-05-21 16:59:09 +08:00 committed by Gitee
commit beb3cadb75
1 changed files with 4 additions and 31 deletions

View File

@ -134,12 +134,8 @@ class LossGet(Callback):
return self._loss
def train_process(device_id, epoch_size, num_classes, batch_size):
os.system("mkdir " + str(device_id))
os.chdir(str(device_id))
def train_process(epoch_size, num_classes, batch_size):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(mode=context.GRAPH_MODE)
net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss()
opt = Momentum(filter(lambda x: x.requires_grad,
@ -148,34 +144,15 @@ def train_process(device_id, epoch_size, num_classes, batch_size):
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
config=config_ck)
loss_cb = LossGet()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
model.train(epoch_size, dataset, callbacks=[loss_cb])
def eval(batch_size, num_classes):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss()
opt = Momentum(filter(lambda x: x.requires_grad,
net.get_parameters()), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
param_dict = load_checkpoint(checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
eval_dataset = create_dataset(1, training=False)
res = model.eval(eval_dataset)
print("result: ", res)
return res
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -184,11 +161,7 @@ def test_resnet_cifar_1p():
epoch_size = 1
num_classes = 10
batch_size = 32
device_id = 0
train_process(device_id, epoch_size, num_classes, batch_size)
time.sleep(3)
acc = eval(batch_size, num_classes)
os.chdir("../")
os.system("rm -rf " + str(device_id))
acc = train_process(epoch_size, num_classes, batch_size)
os.system("rm -rf kernel_meta")
print("End training...")
assert acc['acc'] > 0.35