test

This commit is contained in: parent a861db0669, commit 0f2c5c3289

@@ -0,0 +1,86 @@
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op = C.Resize((resize_height, resize_width))  # interpolation default BILINEAR
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10 channel means / stds
changeswap_op = C.HWC2CHW()  # HWC -> CHW layout expected by the network
type_cast_op = C2.TypeCast(mstype.int32)  # cast labels to int32

c_trans = []
if training:
    c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
            changeswap_op]

# apply map operations on the label and image columns
cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)

# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)

# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)

# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)

# loss function and optimizer
ls = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

# checkpoint saving and loss monitoring callbacks, then training
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

# load a checkpoint (if given) and evaluate on the test split
if args_opt.checkpoint_path:
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
eval_dataset = create_dataset(1, training=False)
res = model.eval(eval_dataset)
print("result: ", res)

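The excerpt above omits its imports. As a minimal, non-authoritative sketch (assuming the early MindSpore API that the is_grad argument points to; exact module paths vary between releases), the names it uses would resolve with something like:

import mindspore.dataset as ds                                  # cifar_ds would be built from ds.Cifar10Dataset(...)
import mindspore.dataset.transforms.c_transforms as C2          # TypeCast
import mindspore.dataset.transforms.vision.c_transforms as C    # RandomCrop, Resize, Rescale, Normalize, HWC2CHW, ...
from mindspore import dtype as mstype
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net

net, model, args_opt, resize_height, resize_width, rescale, shift, batch_num, epoch_size and repeat_num are defined elsewhere in cifar_resnet50.py and are not shown in the excerpt.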

import cifar_resnet50
import object


if __name__ == '__main__':

    # add objects (expected code snippets) for searching
    objs = [
"random_crop_op=C.RandomCrop((32,32),(4,4,4,4))",
|
||||||
|
"random_horizontal_op=C.RandomHorizontalFlip()",
|
||||||
|
"resize_op=C.Resize((resize_height,resize_width))",
|
||||||
|
"rescale_op=C.Rescale(rescale,shift)",
|
||||||
|
"normalize_op=C.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))",
|
||||||
|
"changeswap_op=C.HWC2CHW()",
|
||||||
|
"type_cast_op=C2.TypeCast(mstype.int32)",
|
||||||
|
"c_trans=[random_crop_op,random_horizontal_op]",
|
||||||
|
"cifar_ds=cifar_ds.map(input_columns=",
|
||||||
|
"cifar_ds=cifar_ds.shuffle(buffer_size=10)",
|
||||||
|
"cifar_ds=cifar_ds.batch(batch_size=args_opt.batch_size,drop_remainder=True)",
|
||||||
|
"cifar_ds=cifar_ds.repeat(repeat_num)",
|
||||||
|
"ls=SoftmaxCrossEntropyWithLogits(sparse=True,is_grad=False,reduction=",
|
||||||
|
"opt=Momentum(filter(lambda x:x.requires_grad,net.get_parameters()),0.01,0.9)",
|
||||||
|
"config_ck=CheckpointConfig(save_checkpoint_steps=batch_num,keep_checkpoint_max=35)",
|
||||||
|
"ckpoint_cb=ModelCheckpoint(prefix=",
|
||||||
|
"loss_cb=LossMonitor()",
|
||||||
|
"model.train(epoch_size,dataset,callbacks=[ckpoint_cb,loss_cb])",
|
||||||
|
"param_dict=load_checkpoint(args_opt.checkpoint_path)",
|
||||||
|
"load_param_into_net(net,param_dict)",
|
||||||
|
"eval_dataset=create_dataset(1,training=False)",
|
||||||
|
"res=model.eval(eval_dataset)",
|
||||||
|
|
||||||
|
]

    filepath = "./MindSpore/src/step1/cifar_resnet50.py"

    if object.objectFind(objs, filepath):
        print("----------------")
        print("ok!")
    else:
        print("object error!")