forked from mindspore-Ecosystem/mindspore
!11320 [ModelZoo]fix bgcf train and eval device id bug
From: @zhan_ke Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
This commit is contained in:
commit
064ee0b383
|
@ -15,7 +15,6 @@
|
|||
"""
|
||||
BGCF evaluation script.
|
||||
"""
|
||||
import os
|
||||
import datetime
|
||||
|
||||
import mindspore.context as context
|
||||
|
@ -78,12 +77,11 @@ def evaluation():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = parser_args()
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False)
|
||||
|
||||
parser = parser_args()
|
||||
os.environ['DEVICE_ID'] = parser.device
|
||||
save_graphs=False,
|
||||
device_id=int(parser.device))
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
|
@ -31,7 +30,7 @@ cp *.sh ./eval
|
|||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
echo "start evaluation"
|
||||
|
||||
python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
|
@ -37,7 +36,7 @@ cp *.sh ./train
|
|||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
echo "start training"
|
||||
|
||||
python train.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ if __name__ == "__main__":
|
|||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=parser.device)
|
||||
device_id=int(parser.device))
|
||||
|
||||
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
|
||||
|
|
Loading…
Reference in New Issue