forked from mindspore-Ecosystem/mindspore
!1327 fix resnet cifar 1p test case device error
Merge pull request !1327 from chujinjin/fix_cifar_1p_test
This commit is contained in:
commit
beb3cadb75
|
@ -134,12 +134,8 @@ class LossGet(Callback):
|
|||
return self._loss
|
||||
|
||||
|
||||
def train_process(device_id, epoch_size, num_classes, batch_size):
|
||||
os.system("mkdir " + str(device_id))
|
||||
os.chdir(str(device_id))
|
||||
def train_process(epoch_size, num_classes, batch_size):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = resnet50(batch_size, num_classes)
|
||||
loss = CrossEntropyLoss()
|
||||
opt = Momentum(filter(lambda x: x.requires_grad,
|
||||
|
@ -148,34 +144,15 @@ def train_process(device_id, epoch_size, num_classes, batch_size):
|
|||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
|
||||
config=config_ck)
|
||||
loss_cb = LossGet()
|
||||
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
|
||||
model.train(epoch_size, dataset, callbacks=[loss_cb])
|
||||
|
||||
|
||||
def eval(batch_size, num_classes):
    """Evaluate the saved ResNet checkpoint and return the metric result.

    Sets the execution context to graph mode on the Ascend target with
    device id 0, rebuilds the network with ``resnet50``, loads parameters
    from the fixed checkpoint file in the current working directory, and
    runs ``Model.eval`` on the dataset produced by
    ``create_dataset(1, training=False)``.

    Args:
        batch_size: Forwarded to ``resnet50`` when building the network.
        num_classes: Forwarded to ``resnet50`` when building the network.

    Returns:
        The value returned by ``Model.eval``; the model is built with
        ``metrics={'acc'}``, so this is presumably a mapping with an
        ``'acc'`` entry -- confirm against the Model API.
    """
    # The context is configured in two separate calls: mode/target, then device id.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=0)

    network = resnet50(batch_size, num_classes)
    criterion = CrossEntropyLoss()
    # Only parameters flagged as requiring gradients are handed to the optimizer.
    trainable_params = filter(lambda param: param.requires_grad,
                              network.get_parameters())
    optimizer = Momentum(trainable_params, 0.01, 0.9)

    evaluator = Model(network, loss_fn=criterion, optimizer=optimizer, metrics={'acc'})

    # Restore the weights from the checkpoint file before switching to inference mode.
    load_param_into_net(
        network,
        load_checkpoint("./train_resnet_cifar10_device_id_0-1_1562.ckpt"),
    )
    network.set_train(False)

    res = evaluator.eval(create_dataset(1, training=False))
    print("result: ", res)
    return res
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -184,11 +161,7 @@ def test_resnet_cifar_1p():
|
|||
epoch_size = 1
|
||||
num_classes = 10
|
||||
batch_size = 32
|
||||
device_id = 0
|
||||
train_process(device_id, epoch_size, num_classes, batch_size)
|
||||
time.sleep(3)
|
||||
acc = eval(batch_size, num_classes)
|
||||
os.chdir("../")
|
||||
os.system("rm -rf " + str(device_id))
|
||||
acc = train_process(epoch_size, num_classes, batch_size)
|
||||
os.system("rm -rf kernel_meta")
|
||||
print("End training...")
|
||||
assert acc['acc'] > 0.35
|
||||
|
|
Loading…
Reference in New Issue