forked from mindspore-Ecosystem/mindspore
fix transformer eval device id
This commit is contained in:
parent
2b58af0e9d
commit
55ee90e696
|
@ -193,10 +193,10 @@ Parameters for learning rate:
|
||||||
|
|
||||||
- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||||
|
|
||||||
- Run `run_standalone_train_ascend.sh` for non-distributed training of Transformer model.
|
- Run `run_standalone_train.sh` for non-distributed training of Transformer model.
|
||||||
|
|
||||||
``` bash
|
``` bash
|
||||||
sh scripts/run_standalone_train_ascend.sh DEVICE_ID EPOCH_SIZE DATA_PATH
|
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
|
||||||
```
|
```
|
||||||
- Run `run_distribute_train_ascend.sh` for distributed training of Transformer model.
|
- Run `run_distribute_train_ascend.sh` for distributed training of Transformer model.
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ def run_transformer_eval():
|
||||||
parser = argparse.ArgumentParser(description='tranformer')
|
parser = argparse.ArgumentParser(description='tranformer')
|
||||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
help="device where the code will be implemented, default is Ascend")
|
help="device where the code will be implemented, default is Ascend")
|
||||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False,
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False,
|
||||||
|
|
Loading…
Reference in New Issue