diff --git a/model_zoo/official/rl/dqn/train.py b/model_zoo/official/rl/dqn/train.py index 40a1234028a..d6e193d2878 100644 --- a/model_zoo/official/rl/dqn/train.py +++ b/model_zoo/official/rl/dqn/train.py @@ -49,7 +49,6 @@ if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) if args.device_target == 'GPU': cfg = cfg_gpu - context.set_context(device_id=1) env = gym.make(cfg.game) env = env.unwrapped @@ -105,4 +104,3 @@ if __name__ == "__main__": times_numpy = np.array(times) print(rewards_numpy.mean(), times_numpy.mean()) - \ No newline at end of file