This commit is contained in:
马天姿 2020-07-09 23:17:15 +08:00
parent a861db0669
commit 0f2c5c3289
1 changed files with 86 additions and 0 deletions

View File

@ -0,0 +1,86 @@
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = C.HWC2CHW()
type_cast_op = C2.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
changeswap_op]
# apply map operations on images
cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)
# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
ls = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
eval_dataset = create_dataset(1, training=False)
res = model.eval(eval_dataset)
print("result: ", res)
import cifar_resnet50
import object
if __name__ == '__main__':
# add objects for searching
objs = [
"random_crop_op=C.RandomCrop((32,32),(4,4,4,4))",
"random_horizontal_op=C.RandomHorizontalFlip()",
"resize_op=C.Resize((resize_height,resize_width))",
"rescale_op=C.Rescale(rescale,shift)",
"normalize_op=C.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))",
"changeswap_op=C.HWC2CHW()",
"type_cast_op=C2.TypeCast(mstype.int32)",
"c_trans=[random_crop_op,random_horizontal_op]",
"cifar_ds=cifar_ds.map(input_columns=",
"cifar_ds=cifar_ds.shuffle(buffer_size=10)",
"cifar_ds=cifar_ds.batch(batch_size=args_opt.batch_size,drop_remainder=True)",
"cifar_ds=cifar_ds.repeat(repeat_num)",
"ls=SoftmaxCrossEntropyWithLogits(sparse=True,is_grad=False,reduction=",
"opt=Momentum(filter(lambda x:x.requires_grad,net.get_parameters()),0.01,0.9)",
"config_ck=CheckpointConfig(save_checkpoint_steps=batch_num,keep_checkpoint_max=35)",
"ckpoint_cb=ModelCheckpoint(prefix=",
"loss_cb=LossMonitor()",
"model.train(epoch_size,dataset,callbacks=[ckpoint_cb,loss_cb])",
"param_dict=load_checkpoint(args_opt.checkpoint_path)",
"load_param_into_net(net,param_dict)",
"eval_dataset=create_dataset(1,training=False)",
"res=model.eval(eval_dataset)",
]
filepath = "./MindSpore/src/step1/cifar_resnet50.py"
if (object.objectFind(objs, filepath)):
print("----------------")
print("ok!")
else:
print("object error!")