forked from mindspore-Ecosystem/mindspore
commit
c5a1f5b855
|
@ -149,6 +149,13 @@ For more configuration details, please refer the script `config.py`.
|
|||
Usage: sh scripts/run_standalone_train.sh [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
|
||||
```
|
||||
|
||||
```shell
|
||||
# standalone training example
|
||||
sh scripts/run_standalone_train.sh 0 /data/imagenet/train
|
||||
```
|
||||
|
||||
checkpoint can be produced in training process and be saved in the folder ./train/ckpt_squeezenet.
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
||||
Please follow the instructions in the link [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
@ -182,11 +189,9 @@ Usage: sh scripts/run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
|
|||
|
||||
```shell
|
||||
# evaluation example
|
||||
sh scripts/run_eval.sh 0 ~/data/imagenet/train ckpt_squeezenet/squeezenet_imagenet-200_40036.ckpt
|
||||
sh scripts/run_eval.sh 0 /data/imagenet/val ./train/ckpt_squeezenet/squeezenet_imagenet-200_40036.ckpt
|
||||
```
|
||||
|
||||
checkpoint can be produced in training process.
|
||||
|
||||
### Result
|
||||
|
||||
Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log.
|
||||
|
|
|
@ -25,7 +25,6 @@ from src.CrossEntropySmooth import CrossEntropySmooth
|
|||
from src.squeezenet import SqueezeNet as squeezenet
|
||||
from src.dataset import create_dataset_imagenet as create_dataset
|
||||
from src.config import config
|
||||
import moxing as mox
|
||||
|
||||
local_data_url = '/cache/data'
|
||||
local_ckpt_url = '/cache/ckpt.ckpt'
|
||||
|
@ -33,7 +32,7 @@ local_ckpt_url = '/cache/ckpt.ckpt'
|
|||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--dataset', type=str, default='imagenet', help='Dataset.')
|
||||
parser.add_argument('--net', type=str, default='squeezenet', help='Model.')
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=False,
|
||||
help='Whether it is running on CloudBrain platform.')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
|
@ -60,6 +59,7 @@ if __name__ == '__main__':
|
|||
|
||||
# create dataset
|
||||
if args_opt.run_cloudbrain:
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(args_opt.checkpoint_path, local_ckpt_url)
|
||||
mox.file.copy_parallel(args_opt.data_url, local_data_url)
|
||||
dataset = create_dataset(dataset_path=local_data_url,
|
||||
|
@ -81,7 +81,10 @@ if __name__ == '__main__':
|
|||
net = squeezenet(num_classes=config.class_num)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(local_ckpt_url)
|
||||
if args_opt.run_cloudbrain:
|
||||
param_dict = load_checkpoint(local_ckpt_url)
|
||||
else:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
|
|
|
@ -37,9 +37,9 @@ from src.dataset import create_dataset_imagenet as create_dataset
|
|||
parser = argparse.ArgumentParser(description='SqueezeNet1_1')
|
||||
parser.add_argument('--net', type=str, default='squeezenet', help='Model.')
|
||||
parser.add_argument('--dataset', type=str, default='imagenet', help='Dataset.')
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=False,
|
||||
help='Whether it is running on CloudBrain platform.')
|
||||
parser.add_argument('--run_distribute', type=bool, default=True, help='Run distribute')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
|
|
Loading…
Reference in New Issue