From e1c4e35094ac50462bf80a1445ad5b8b555dad2e Mon Sep 17 00:00:00 2001 From: zhanke Date: Fri, 15 Jan 2021 16:12:57 +0800 Subject: [PATCH] fix bgcf device id bug --- model_zoo/official/gnn/bgcf/eval.py | 8 +++----- model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh | 3 +-- model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh | 3 +-- model_zoo/official/gnn/bgcf/train.py | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/model_zoo/official/gnn/bgcf/eval.py b/model_zoo/official/gnn/bgcf/eval.py index 70f01b481b2..86888aa0420 100644 --- a/model_zoo/official/gnn/bgcf/eval.py +++ b/model_zoo/official/gnn/bgcf/eval.py @@ -15,7 +15,6 @@ """ BGCF evaluation script. """ -import os import datetime import mindspore.context as context @@ -78,12 +77,11 @@ def evaluation(): if __name__ == "__main__": + parser = parser_args() context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - save_graphs=False) - - parser = parser_args() - os.environ['DEVICE_ID'] = parser.device + save_graphs=False, + device_id=int(parser.device)) train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, diff --git a/model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh b/model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh index 2195e0abb2e..18e28b7bdc7 100644 --- a/model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh +++ b/model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh @@ -17,7 +17,6 @@ ulimit -u unlimited export DEVICE_NUM=1 export RANK_SIZE=$DEVICE_NUM -export DEVICE_ID=0 export RANK_ID=0 if [ -d "eval" ]; @@ -31,7 +30,7 @@ cp *.sh ./eval cp -r ../src ./eval cd ./eval || exit env > env.log -echo "start evaluation for device $DEVICE_ID" +echo "start evaluation" python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log & diff --git a/model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh b/model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh index e14b8ada9a9..eef920ba32c 100644 --- a/model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh +++ b/model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh @@ -17,7 +17,6 @@ ulimit -u unlimited export DEVICE_NUM=1 export RANK_SIZE=$DEVICE_NUM -export DEVICE_ID=0 export RANK_ID=0 if [ -d "train" ]; @@ -37,7 +36,7 @@ cp *.sh ./train cp -r ../src ./train cd ./train || exit env > env.log -echo "start training for device $DEVICE_ID" +echo "start training" python train.py --datapath=../data_mr --ckptpath=../ckpts &> log & diff --git a/model_zoo/official/gnn/bgcf/train.py b/model_zoo/official/gnn/bgcf/train.py index 01bb0498c12..f6d0bb24e5e 100644 --- a/model_zoo/official/gnn/bgcf/train.py +++ b/model_zoo/official/gnn/bgcf/train.py @@ -105,7 +105,7 @@ if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, - device_id=parser.device) + device_id=int(parser.device)) train_graph, _, sampled_graph_list = load_graph(parser.datapath) train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,