!4714 bug fix for ME

Merge pull request !4714 from chenzhongming/new_master
This commit is contained in:
mindspore-ci-bot 2020-08-19 14:12:18 +08:00 committed by Gitee
commit c55a8297c8
5 changed files with 28 additions and 35 deletions

View File

@ -51,6 +51,8 @@ class _Conv(Cell):
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
self.weight_init = weight_init
self.bias_init = bias_init
if isinstance(padding, int):
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
self.padding = padding
@ -85,12 +87,12 @@ class _Conv(Cell):
shape = [in_channels, out_channels // group, *kernel_size]
else:
shape = [out_channels, in_channels // group, *kernel_size]
self.weight = Parameter(initializer(weight_init, shape), name='weight')
self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
else:
if bias_init != 'zeros':
if self.bias_init != 'zeros':
logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
self.bias = None
@ -249,11 +251,8 @@ class Conv2d(_Conv):
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
self.weight_init,
self.bias_init)
return s
@ -431,11 +430,8 @@ class Conv1d(_Conv):
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
self.weight_init,
self.bias_init)
return s
@ -605,8 +601,8 @@ class Conv2dTranspose(_Conv):
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
self.weight_init,
self.bias_init)
return s
@ -788,8 +784,8 @@ class Conv1dTranspose(_Conv):
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
self.weight_init,
self.bias_init)
return s

View File

@ -30,7 +30,7 @@ from src.mobilenetV2 import mobilenet_v2
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_targe', type=str, default=None, help='run device_targe')
parser.add_argument('--device_target', type=str, default=None, help='run device_target')
args_opt = parser.parse_args()

View File

@ -73,6 +73,7 @@ run_gpu()
mpirun -n $2 --allow-run-as-root \
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--pre_trained=$5 \
--device_target=$1 \
&> ../train.log & # dataset train folder
}
@ -81,7 +82,7 @@ if [ $# -gt 6 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi

View File

@ -49,10 +49,10 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--device_targe', type=str, default=None, help='run device_targe')
parser.add_argument('--device_target', type=str, default=None, help='run device_target')
args_opt = parser.parse_args()
if args_opt.device_targe == "Ascend":
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
rank_id = int(os.getenv('RANK_ID'))
rank_size = int(os.getenv('RANK_SIZE'))
@ -61,7 +61,7 @@ if args_opt.device_targe == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.device_targe == "GPU":
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU",
save_graphs=False)
@ -161,13 +161,13 @@ class Monitor(Callback):
if __name__ == '__main__':
if args_opt.device_targe == "GPU":
if args_opt.device_target == "GPU":
# train on gpu
print("train args: ", args_opt)
print("cfg: ", config_gpu)
# define network
net = mobilenet_v2(num_classes=config_gpu.num_classes, device_targe="GPU")
net = mobilenet_v2(num_classes=config_gpu.num_classes, device_target="GPU")
# define loss
if config_gpu.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth,
@ -179,7 +179,7 @@ if __name__ == '__main__':
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config_gpu,
device_targe=args_opt.device_targe,
device_target=args_opt.device_target,
repeat_num=1,
batch_size=config_gpu.batch_size)
step_size = dataset.get_dataset_size()
@ -216,7 +216,7 @@ if __name__ == '__main__':
# begin train
model.train(epoch_size, dataset, callbacks=cb)
print("============== End Training ==============")
elif args_opt.device_targe == "Ascend":
elif args_opt.device_target == "Ascend":
# train on ascend
print("train args: ", args_opt, "\ncfg: ", config_ascend,
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
@ -228,7 +228,7 @@ if __name__ == '__main__':
init()
epoch_size = config_ascend.epoch_size
net = mobilenet_v2(num_classes=config_ascend.num_classes, device_targe="Ascend")
net = mobilenet_v2(num_classes=config_ascend.num_classes, device_target="Ascend")
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
@ -242,7 +242,7 @@ if __name__ == '__main__':
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config_ascend,
device_targe=args_opt.device_targe,
device_target=args_opt.device_target,
repeat_num=1,
batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size()
@ -276,4 +276,4 @@ if __name__ == '__main__':
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)
else:
raise ValueError("Unsupported device_targe.")
raise ValueError("Unsupported device_target.")

View File

@ -27,8 +27,8 @@ Dataset used: [imagenet](http://www.image-net.org/)
# Environment Requirements
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- HardwareGPU
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below
@ -60,14 +60,12 @@ Dataset used: [imagenet](http://www.image-net.org/)
### Usage
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
### Launch
```
# training example
Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
```
@ -86,14 +84,12 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
### Usage
- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
### Launch
```
# infer example
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
```