test
This commit is contained in:
parent
a861db0669
commit
0f2c5c3289
|
@ -0,0 +1,86 @@
|
|||
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
|
||||
random_horizontal_op = C.RandomHorizontalFlip()
|
||||
resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEAR
|
||||
rescale_op = C.Rescale(rescale, shift)
|
||||
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
||||
changeswap_op = C.HWC2CHW()
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
c_trans = []
|
||||
if training:
|
||||
c_trans = [random_crop_op, random_horizontal_op]
|
||||
c_trans += [resize_op, rescale_op, normalize_op,
|
||||
changeswap_op]
|
||||
|
||||
# apply map operations on images
|
||||
cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
|
||||
cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)
|
||||
|
||||
# apply shuffle operations
|
||||
cifar_ds = cifar_ds.shuffle(buffer_size=10)
|
||||
|
||||
# apply batch operations
|
||||
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
|
||||
|
||||
# apply repeat operations
|
||||
cifar_ds = cifar_ds.repeat(repeat_num)
|
||||
|
||||
ls = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction="mean")
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
|
||||
|
||||
if args_opt.checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
eval_dataset = create_dataset(1, training=False)
|
||||
res = model.eval(eval_dataset)
|
||||
print("result: ", res)
|
||||
|
||||
|
||||
import cifar_resnet50
|
||||
import object
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# add objects for searching
|
||||
objs = [
|
||||
|
||||
|
||||
"random_crop_op=C.RandomCrop((32,32),(4,4,4,4))",
|
||||
"random_horizontal_op=C.RandomHorizontalFlip()",
|
||||
"resize_op=C.Resize((resize_height,resize_width))",
|
||||
"rescale_op=C.Rescale(rescale,shift)",
|
||||
"normalize_op=C.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))",
|
||||
"changeswap_op=C.HWC2CHW()",
|
||||
"type_cast_op=C2.TypeCast(mstype.int32)",
|
||||
"c_trans=[random_crop_op,random_horizontal_op]",
|
||||
"cifar_ds=cifar_ds.map(input_columns=",
|
||||
"cifar_ds=cifar_ds.shuffle(buffer_size=10)",
|
||||
"cifar_ds=cifar_ds.batch(batch_size=args_opt.batch_size,drop_remainder=True)",
|
||||
"cifar_ds=cifar_ds.repeat(repeat_num)",
|
||||
"ls=SoftmaxCrossEntropyWithLogits(sparse=True,is_grad=False,reduction=",
|
||||
"opt=Momentum(filter(lambda x:x.requires_grad,net.get_parameters()),0.01,0.9)",
|
||||
"config_ck=CheckpointConfig(save_checkpoint_steps=batch_num,keep_checkpoint_max=35)",
|
||||
"ckpoint_cb=ModelCheckpoint(prefix=",
|
||||
"loss_cb=LossMonitor()",
|
||||
"model.train(epoch_size,dataset,callbacks=[ckpoint_cb,loss_cb])",
|
||||
"param_dict=load_checkpoint(args_opt.checkpoint_path)",
|
||||
"load_param_into_net(net,param_dict)",
|
||||
"eval_dataset=create_dataset(1,training=False)",
|
||||
"res=model.eval(eval_dataset)",
|
||||
|
||||
]
|
||||
|
||||
filepath = "./MindSpore/src/step1/cifar_resnet50.py"
|
||||
|
||||
if (object.objectFind(objs, filepath)):
|
||||
print("----------------")
|
||||
print("ok!")
|
||||
else:
|
||||
print("object error!")
|
||||
|
||||
|
Loading…
Reference in New Issue