Add some reinforcement learning example. (#1090)
* Add some reinforcement learning example. * Python initialization. * Get the example to run. * Vectorized gym envs for the atari wrappers. * Get some simulation loop to run.
This commit is contained in:
parent
9309cfc47d
commit
29c7f2565d
|
@ -21,6 +21,7 @@ half = { workspace = true, optional = true }
|
|||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
@ -58,3 +59,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
|||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
required-features = ["cuda", "nccl", "flash-attn"]
|
||||
|
||||
[[example]]
|
||||
name = "reinforcement-learning"
|
||||
required-features = ["pyo3"]
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# candle-reinforcement-learning
|
||||
|
||||
Reinforcement Learning examples for candle.
|
||||
|
||||
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
||||
Python package with:
|
||||
```bash
|
||||
pip install "gymnasium[accept-rom-license]"
|
||||
```
|
||||
|
||||
In order to run the example, use the following command. Note the additional
|
||||
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
|
||||
crate.
|
||||
```bash
|
||||
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
|
||||
```
|
|
@ -0,0 +1,308 @@
|
|||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from PIL import Image
|
||||
from multiprocessing import Process, Pipe
|
||||
|
||||
# atari_wrappers.py
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
def __init__(self, env, noop_max=30):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.noop_max = noop_max
|
||||
self.override_num_noops = None
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def reset(self):
|
||||
""" Do no-op action for a number of steps in [1, noop_max]."""
|
||||
self.env.reset()
|
||||
if self.override_num_noops is not None:
|
||||
noops = self.override_num_noops
|
||||
else:
|
||||
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
|
||||
assert noops > 0
|
||||
obs = None
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(0)
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
return obs
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Take action on reset for environments that are fixed until firing."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(1)
|
||||
if done:
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(2)
|
||||
if done:
|
||||
self.env.reset()
|
||||
return obs
|
||||
|
||||
class ImageSaver(gym.Wrapper):
|
||||
def __init__(self, env, img_path, rank):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self._cnt = 0
|
||||
self._img_path = img_path
|
||||
self._rank = rank
|
||||
|
||||
def step(self, action):
|
||||
step_result = self.env.step(action)
|
||||
obs, _, _, _ = step_result
|
||||
img = Image.fromarray(obs, 'RGB')
|
||||
img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
|
||||
self._cnt += 1
|
||||
return step_result
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
||||
Done by DeepMind for the DQN and co. since it helps value estimation.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal,
|
||||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if lives < self.lives and lives > 0:
|
||||
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
|
||||
# so its important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
"""Reset only when lives are exhausted.
|
||||
This way all states are still reachable even though lives are episodic,
|
||||
and the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs, _, _, _ = self.env.step(0)
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
def __init__(self, env, skip=4):
|
||||
"""Return only every `skip`-th frame"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
# most recent raw observations (for max pooling across time steps)
|
||||
self._obs_buffer = deque(maxlen=2)
|
||||
self._skip = skip
|
||||
|
||||
def step(self, action):
|
||||
"""Repeat action, sum reward, and max over last observations."""
|
||||
total_reward = 0.0
|
||||
done = None
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._obs_buffer.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
|
||||
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
def reset(self):
|
||||
"""Clear past frame buffer and init. to first obs. from inner env."""
|
||||
self._obs_buffer.clear()
|
||||
obs = self.env.reset()
|
||||
self._obs_buffer.append(obs)
|
||||
return obs
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
def reward(self, reward):
|
||||
"""Bin reward to {+1, 0, -1} by its sign."""
|
||||
return np.sign(reward)
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
self.res = 84
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
|
||||
|
||||
def observation(self, obs):
|
||||
frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
|
||||
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
|
||||
resample=Image.BILINEAR), dtype=np.uint8)
|
||||
return frame.reshape((self.res, self.res, 1))
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
"""Buffer observations and stack across channels (last axis)."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.k = k
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
assert shp[2] == 1 # can only stack 1-channel frames
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
|
||||
|
||||
def reset(self):
|
||||
"""Clear buffer and re-fill by duplicating the first observation."""
|
||||
ob = self.env.reset()
|
||||
for _ in range(self.k): self.frames.append(ob)
|
||||
return self.observation()
|
||||
|
||||
def step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
self.frames.append(ob)
|
||||
return self.observation(), reward, done, info
|
||||
|
||||
def observation(self):
|
||||
assert len(self.frames) == self.k
|
||||
return np.concatenate(self.frames, axis=2)
|
||||
|
||||
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
|
||||
"""Configure environment for DeepMind-style Atari.
|
||||
|
||||
Note: this does not include frame stacking!"""
|
||||
assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
|
||||
if episode_life:
|
||||
env = EpisodicLifeEnv(env)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
env = WarpFrame(env)
|
||||
if clip_rewards:
|
||||
env = ClipRewardEnv(env)
|
||||
return env
|
||||
|
||||
# envs.py
|
||||
def make_env(env_id, img_dir, seed, rank):
|
||||
def _thunk():
|
||||
env = gym.make(env_id)
|
||||
env.reset(seed=(seed + rank))
|
||||
if img_dir is not None:
|
||||
env = ImageSaver(env, img_dir, rank)
|
||||
env = wrap_deepmind(env)
|
||||
env = WrapPyTorch(env)
|
||||
return env
|
||||
|
||||
return _thunk
|
||||
|
||||
class WrapPyTorch(gym.ObservationWrapper):
|
||||
def __init__(self, env=None):
|
||||
super(WrapPyTorch, self).__init__(env)
|
||||
self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
|
||||
|
||||
def observation(self, observation):
|
||||
return observation.transpose(2, 0, 1)
|
||||
|
||||
# vecenv.py
|
||||
class VecEnv(object):
|
||||
"""
|
||||
Vectorized environment base class
|
||||
"""
|
||||
def step(self, vac):
|
||||
"""
|
||||
Apply sequence of actions to sequence of environments
|
||||
actions -> (observations, rewards, news)
|
||||
|
||||
where 'news' is a boolean vector indicating whether each element is new.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def reset(self):
|
||||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# subproc_vec_env.py
|
||||
def worker(remote, env_fn_wrapper):
|
||||
env = env_fn_wrapper.x()
|
||||
while True:
|
||||
cmd, data = remote.recv()
|
||||
if cmd == 'step':
|
||||
ob, reward, done, info = env.step(data)
|
||||
if done:
|
||||
ob = env.reset()
|
||||
remote.send((ob, reward, done, info))
|
||||
elif cmd == 'reset':
|
||||
ob = env.reset()
|
||||
remote.send(ob)
|
||||
elif cmd == 'close':
|
||||
remote.close()
|
||||
break
|
||||
elif cmd == 'get_spaces':
|
||||
remote.send((env.action_space, env.observation_space))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
"""
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __getstate__(self):
|
||||
import cloudpickle
|
||||
return cloudpickle.dumps(self.x)
|
||||
def __setstate__(self, ob):
|
||||
import pickle
|
||||
self.x = pickle.loads(ob)
|
||||
|
||||
class SubprocVecEnv(VecEnv):
|
||||
def __init__(self, env_fns):
|
||||
"""
|
||||
envs: list of gym environments to run in subprocesses
|
||||
"""
|
||||
nenvs = len(env_fns)
|
||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||
self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
|
||||
for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
|
||||
for p in self.ps:
|
||||
p.start()
|
||||
|
||||
self.remotes[0].send(('get_spaces', None))
|
||||
self.action_space, self.observation_space = self.remotes[0].recv()
|
||||
|
||||
|
||||
def step(self, actions):
|
||||
for remote, action in zip(self.remotes, actions):
|
||||
remote.send(('step', action))
|
||||
results = [remote.recv() for remote in self.remotes]
|
||||
obs, rews, dones, infos = zip(*results)
|
||||
return np.stack(obs), np.stack(rews), np.stack(dones), infos
|
||||
|
||||
def reset(self):
|
||||
for remote in self.remotes:
|
||||
remote.send(('reset', None))
|
||||
return np.stack([remote.recv() for remote in self.remotes])
|
||||
|
||||
def close(self):
|
||||
for remote in self.remotes:
|
||||
remote.send(('close', None))
|
||||
for p in self.ps:
|
||||
p.join()
|
||||
|
||||
@property
|
||||
def num_envs(self):
|
||||
return len(self.remotes)
|
||||
|
||||
# Create the environment.
|
||||
def make(env_name, img_dir, num_processes):
|
||||
envs = SubprocVecEnv([
|
||||
make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
|
||||
])
|
||||
return envs
|
|
@ -0,0 +1,108 @@
|
|||
#![allow(unused)]
|
||||
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
|
||||
use candle::{Device, Result, Tensor};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
/// The return value for a step.
|
||||
#[derive(Debug)]
|
||||
pub struct Step<A> {
|
||||
pub obs: Tensor,
|
||||
pub action: A,
|
||||
pub reward: f64,
|
||||
pub is_done: bool,
|
||||
}
|
||||
|
||||
impl<A: Copy> Step<A> {
|
||||
/// Returns a copy of this step changing the observation tensor.
|
||||
pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
|
||||
Step {
|
||||
obs: obs.clone(),
|
||||
action: self.action,
|
||||
reward: self.reward,
|
||||
is_done: self.is_done,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An OpenAI Gym session.
|
||||
pub struct GymEnv {
|
||||
env: PyObject,
|
||||
action_space: usize,
|
||||
observation_space: Vec<usize>,
|
||||
}
|
||||
|
||||
fn w(res: PyErr) -> candle::Error {
|
||||
candle::Error::wrap(res)
|
||||
}
|
||||
|
||||
impl GymEnv {
|
||||
/// Creates a new session of the specified OpenAI Gym environment.
|
||||
pub fn new(name: &str) -> Result<GymEnv> {
|
||||
Python::with_gil(|py| {
|
||||
let gym = py.import("gymnasium")?;
|
||||
let make = gym.getattr("make")?;
|
||||
let env = make.call1((name,))?;
|
||||
let action_space = env.getattr("action_space")?;
|
||||
let action_space = if let Ok(val) = action_space.getattr("n") {
|
||||
val.extract()?
|
||||
} else {
|
||||
let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
|
||||
action_space[0]
|
||||
};
|
||||
let observation_space = env.getattr("observation_space")?;
|
||||
let observation_space = observation_space.getattr("shape")?.extract()?;
|
||||
Ok(GymEnv {
|
||||
env: env.into(),
|
||||
action_space,
|
||||
observation_space,
|
||||
})
|
||||
})
|
||||
.map_err(w)
|
||||
}
|
||||
|
||||
/// Resets the environment, returning the observation tensor.
|
||||
pub fn reset(&self, seed: u64) -> Result<Tensor> {
|
||||
let obs: Vec<f32> = Python::with_gil(|py| {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("seed", seed)?;
|
||||
let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
|
||||
obs.as_ref(py).get_item(0)?.extract()
|
||||
})
|
||||
.map_err(w)?;
|
||||
Tensor::new(obs, &Device::Cpu)
|
||||
}
|
||||
|
||||
/// Applies an environment step using the specified action.
|
||||
pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
|
||||
&self,
|
||||
action: A,
|
||||
) -> Result<Step<A>> {
|
||||
let (obs, reward, is_done) = Python::with_gil(|py| {
|
||||
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
|
||||
let step = step.as_ref(py);
|
||||
let obs: Vec<f32> = step.get_item(0)?.extract()?;
|
||||
let reward: f64 = step.get_item(1)?.extract()?;
|
||||
let is_done: bool = step.get_item(2)?.extract()?;
|
||||
Ok((obs, reward, is_done))
|
||||
})
|
||||
.map_err(w)?;
|
||||
let obs = Tensor::new(obs, &Device::Cpu)?;
|
||||
Ok(Step {
|
||||
obs,
|
||||
reward,
|
||||
is_done,
|
||||
action,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the number of allowed actions for this environment.
|
||||
pub fn action_space(&self) -> usize {
|
||||
self.action_space
|
||||
}
|
||||
|
||||
/// Returns the shape of the observation tensors.
|
||||
pub fn observation_space(&self) -> &[usize] {
|
||||
&self.observation_space
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod gym_env;
|
||||
mod vec_gym_env;
|
||||
|
||||
use candle::Result;
|
||||
use clap::Parser;
|
||||
use rand::Rng;
|
||||
|
||||
// The total number of episodes.
|
||||
const MAX_EPISODES: usize = 100;
|
||||
// The maximum length of an episode.
|
||||
const EPISODE_LENGTH: usize = 200;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let env = gym_env::GymEnv::new("Pendulum-v1")?;
|
||||
println!("action space: {}", env.action_space());
|
||||
println!("observation space: {:?}", env.observation_space());
|
||||
|
||||
let _num_obs = env.observation_space().iter().product::<usize>();
|
||||
let _num_actions = env.action_space();
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for episode in 0..MAX_EPISODES {
|
||||
let mut obs = env.reset(episode as u64)?;
|
||||
|
||||
let mut total_reward = 0.0;
|
||||
for _ in 0..EPISODE_LENGTH {
|
||||
let actions = rng.gen_range(-2.0..2.0);
|
||||
|
||||
let step = env.step(vec![actions])?;
|
||||
total_reward += step.reward;
|
||||
|
||||
if step.is_done {
|
||||
break;
|
||||
}
|
||||
obs = step.obs;
|
||||
}
|
||||
|
||||
println!("episode {episode} with total reward of {total_reward}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
#![allow(unused)]
|
||||
//! Vectorized version of the gym environment.
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Step {
|
||||
pub obs: Tensor,
|
||||
pub reward: Tensor,
|
||||
pub is_done: Tensor,
|
||||
}
|
||||
|
||||
pub struct VecGymEnv {
|
||||
env: PyObject,
|
||||
action_space: usize,
|
||||
observation_space: Vec<usize>,
|
||||
}
|
||||
|
||||
fn w(res: PyErr) -> candle::Error {
|
||||
candle::Error::wrap(res)
|
||||
}
|
||||
|
||||
impl VecGymEnv {
|
||||
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
|
||||
Python::with_gil(|py| {
|
||||
let sys = py.import("sys")?;
|
||||
let path = sys.getattr("path")?;
|
||||
let _ = path.call_method1(
|
||||
"append",
|
||||
("candle-examples/examples/reinforcement-learning",),
|
||||
)?;
|
||||
let gym = py.import("atari_wrappers")?;
|
||||
let make = gym.getattr("make")?;
|
||||
let env = make.call1((name, img_dir, nprocesses))?;
|
||||
let action_space = env.getattr("action_space")?;
|
||||
let action_space = action_space.getattr("n")?.extract()?;
|
||||
let observation_space = env.getattr("observation_space")?;
|
||||
let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
|
||||
let observation_space =
|
||||
[vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
|
||||
Ok(VecGymEnv {
|
||||
env: env.into(),
|
||||
action_space,
|
||||
observation_space,
|
||||
})
|
||||
})
|
||||
.map_err(w)
|
||||
}
|
||||
|
||||
pub fn reset(&self) -> Result<Tensor> {
|
||||
let obs = Python::with_gil(|py| {
|
||||
let obs = self.env.call_method0(py, "reset")?;
|
||||
let obs = obs.call_method0(py, "flatten")?;
|
||||
obs.extract::<Vec<f32>>(py)
|
||||
})
|
||||
.map_err(w)?;
|
||||
Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
|
||||
}
|
||||
|
||||
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
|
||||
let (obs, reward, is_done) = Python::with_gil(|py| {
|
||||
let step = self.env.call_method(py, "step", (action,), None)?;
|
||||
let step = step.as_ref(py);
|
||||
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
|
||||
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
|
||||
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
|
||||
let reward: Vec<f32> = step.get_item(1)?.extract()?;
|
||||
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
|
||||
Ok((obs, reward, is_done))
|
||||
})
|
||||
.map_err(w)?;
|
||||
let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let reward = Tensor::new(reward, &Device::Cpu)?;
|
||||
let is_done = Tensor::new(is_done, &Device::Cpu)?;
|
||||
Ok(Step {
|
||||
obs,
|
||||
reward,
|
||||
is_done,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn action_space(&self) -> usize {
|
||||
self.action_space
|
||||
}
|
||||
|
||||
pub fn observation_space(&self) -> &[usize] {
|
||||
&self.observation_space
|
||||
}
|
||||
}
|
|
@ -143,7 +143,6 @@ fn main() -> Result<()> {
|
|||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
|
|
Loading…
Reference in New Issue