Commit

fixed small typos and minor todos
LabChameleon committed Jun 5, 2024
1 parent 6c74a3f commit d42dee6
Showing 7 changed files with 54 additions and 64 deletions.
32 changes: 19 additions & 13 deletions arlbench/core/algorithms/dqn/dqn.py
@@ -131,6 +131,7 @@ def __init__(
env: Environment | Wrapper,
eval_env: Environment | Wrapper | None = None,
deterministic_eval: bool = True,
eval_eps: float = 0.05,
cnn_policy: bool = False,
nas_config: Configuration | None = None,
track_trajectories: bool = False,
@@ -141,7 +142,9 @@ def __init__(
Args:
hpo_config (Configuration): Hyperparameter configuration.
env (Environment | AutoRLWrapper): Training environment.
eval_env (Environment | AutoRLWrapper | None, optional): Evaluation environent (otherwise training environment is used for evaluation). Defaults to None.
eval_env (Environment | AutoRLWrapper | None, optional): Evaluation environment (otherwise training environment is used for evaluation). Defaults to None.
deterministic_eval (bool, optional): Use deterministic evaluation. Defaults to True.
eval_eps (float, optional): Epsilon value for non-deterministic evaluation. Defaults to 0.05.
cnn_policy (bool, optional): Use CNN network architecture. Defaults to False.
nas_config (Configuration | None, optional): Neural architecture configuration. Defaults to None.
track_trajectories (bool, optional): Track metrics such as loss and gradients during training. Defaults to False.
@@ -160,6 +163,8 @@ def __init__(
track_metrics=track_metrics,
)

self.eval_eps = eval_eps

# For the network, we need the properties of the action space
action_size, discrete = self.action_type
network_cls = CNNQ if cnn_policy else MLPQ
@@ -218,6 +223,7 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:
"tau": Float("tau", (0.01, 1.0), default=1.0),
"initial_epsilon": Float("initial_epsilon", (0.5, 1.0), default=1.0),
"target_epsilon": Float("target_epsilon", (0.001, 0.2), default=0.05),
"exploration_fraction": Float("initial_epsilon", (0.005, 0.5), default=0.1),
"use_target_network": Categorical(
"use_target_network", [True, False], default=True
),
@@ -267,6 +273,7 @@ def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
"tau": Float("tau", (0.01, 1.0), default=1.0),
"initial_epsilon": Float("initial_epsilon", (0.5, 1.0), default=1.0),
"target_epsilon": Float("target_epsilon", (0.001, 0.2), default=0.05),
"exploration_fraction": Float("initial_epsilon", (0.005, 0.5), default=0.1),
"use_target_network": Categorical(
"use_target_network", [True, False], default=True
),
@@ -315,7 +322,7 @@ def get_checkpoint_factory(
runner_state: DQNRunnerState,
train_result: DQNTrainingResult | None,
) -> dict[str, Callable]:
"""Creates a factory dictionary of all posssible checkpointing options for DQN.
"""Creates a factory dictionary of all possible checkpointing options for DQN.
Args:
runner_state (DQNRunnerState): Algorithm runner state.
@@ -381,7 +388,7 @@ def init(
_, (_obs, _reward, _done, _) = self.env.step(env_state, _action, dummy_rng)

if buffer_state is None:
# This is how transitions will look like during training so we need to pass one
# This is how transitions will look like during training, so we need to pass one
# once to the buffer to estimate and allocate the required buffer size
_timestep = TimeStep(
last_obs=_obs[0],
@@ -458,7 +465,7 @@ def sample_action(rng: chex.PRNGKey, obs: jnp.ndarray) -> jnp.ndarray:
rnd_action = random_action(rng, obs)
grd_action = greedy_action(rng, obs)
return jax.lax.select(
jax.random.uniform(rng, obs.shape[:1]) < 0.05, rnd_action, grd_action
jax.random.uniform(rng, obs.shape[:1]) < self.eval_eps, rnd_action, grd_action
)

return jax.lax.cond(
@@ -478,7 +485,7 @@ def train(
n_eval_steps: int = 100,
n_eval_episodes: int = 10,
) -> DQNTrainReturnT:
"""Performs one iteration of training.
"""Performs one full training.
Args:
runner_state (DQNRunnerState): DQN runner state.
@@ -521,7 +528,6 @@ def train_eval_step(
n_update_steps,
)
eval_returns = self.eval(runner_state, n_eval_episodes)
#jax.debug.print("{eval_returns}", eval_returns=eval_returns.mean())

return (runner_state, buffer_state), DQNTrainingResult(
eval_rewards=eval_returns, trajectories=trajectories, metrics=metrics
@@ -549,7 +555,7 @@ def update(
Args:
train_state (DQNTrainState): DQN training state.
observations (jnp.ndarray): Batch of observations..
observations (jnp.ndarray): Batch of observations.
actions (jnp.ndarray): Batch of actions.
next_observations (jnp.ndarray): Batch of next observations.
rewards (jnp.ndarray): Batch of rewards.
@@ -632,7 +638,7 @@ def take_step(
],
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict],
]:
"""Takes one environment step (n_envs timesteps).
"""Takes one environment step (n_envs many steps).
Args:
carry (tuple[chex.PRNGKey, DQNTrainState, RunningStatisticsState, jnp.ndarray, Any, int, PrioritisedTrajectoryBufferState]): Carry for jax.lax.scan().
Expand Down Expand Up @@ -685,11 +691,11 @@ def greedy_action(_: chex.PRNGKey, obs: jnp.ndarray) -> jnp.ndarray:

rng, sample_rng, action_rng = jax.random.split(rng, 3)
training_fraction = jnp.min(
jnp.array([global_step * self.env.n_envs / n_total_timesteps, 0.1])
jnp.array([global_step * self.env.n_envs / n_total_timesteps, self.hpo_config["exploration_fraction"]])
)
epsilon = self.hpo_config["initial_epsilon"] - training_fraction * (
(self.hpo_config["initial_epsilon"] - self.hpo_config["target_epsilon"])
/ 0.1
/ self.hpo_config["exploration_fraction"]
)
rand_action = random_action(sample_rng, last_obs)
greedy_action = greedy_action(action_rng, last_obs)
@@ -964,13 +970,13 @@ def dont_update(
obs=last_obs,
global_step=global_step,
)
tracjectories = None
trajectories = None
if self.track_trajectories:
tracjectories = Transition(
trajectories = Transition(
obs=observations,
action=action,
reward=reward,
done=done,
info=info,
)
return (runner_state, buffer_state), (metrics, tracjectories)
return (runner_state, buffer_state), (metrics, trajectories)
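
Beyond the typo fixes, the dqn.py changes above replace the hard-coded exploration constants (0.05 at evaluation time, 0.1 for the decay window) with the new eval_eps argument and exploration_fraction hyperparameter. The resulting linear epsilon schedule can be summarized by the following sketch; the standalone function form and argument names are illustrative and not part of the repository.

import jax.numpy as jnp

def linear_epsilon(global_step, n_envs, n_total_timesteps,
                   initial_epsilon, target_epsilon, exploration_fraction):
    # Fraction of total training completed so far, capped at exploration_fraction
    training_fraction = jnp.min(
        jnp.array([global_step * n_envs / n_total_timesteps, exploration_fraction])
    )
    # Linear decay from initial_epsilon to target_epsilon over the first
    # exploration_fraction of training; afterwards epsilon stays at target_epsilon
    return initial_epsilon - training_fraction * (
        (initial_epsilon - target_epsilon) / exploration_fraction
    )

For example, with initial_epsilon=1.0, target_epsilon=0.05 and exploration_fraction=0.1, epsilon decays to 0.05 over the first 10% of n_total_timesteps and is held there, matching the previously hard-coded behaviour.
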
2 changes: 1 addition & 1 deletion arlbench/core/algorithms/dqn/models.py
@@ -70,7 +70,7 @@ def __call__(self, x):


class MLPQ(nn.Module):
"""A MLP-based Q-Network for DQN."""
"""An MLP-based Q-Network for DQN."""

action_dim: int
activation: str = "tanh"
4 changes: 2 additions & 2 deletions arlbench/core/algorithms/ppo/models.py
@@ -7,7 +7,7 @@


class MLPActorCritic(nn.Module):
"""A MLP-based Actor-Critic network for PPO."""
"""An MLP-based Actor-Critic network for PPO."""

action_dim: int
activation: str = "tanh"
@@ -131,7 +131,7 @@ def setup(self):
)

def __call__(self, x):
x = x / 255.0 # todo: make a clean solution for this
x = x / 255.0
x = jnp.transpose(x, (0, 2, 3, 1))
features = self.feature_conv0(x)
features = self.activation_func(features)
22 changes: 0 additions & 22 deletions arlbench/core/algorithms/ppo/ppo.py
@@ -56,9 +56,6 @@ class PPORunnerState(NamedTuple):
env_state: Any
obs: chex.Array
global_step: int
return_buffer_idx: chex.Array | None = None
return_buffer: chex.Array | None = None
cur_rewards: chex.Array | None = None


class PPOState(NamedTuple):
@@ -327,9 +324,6 @@ def init(
env_state=env_state,
obs=obs,
global_step=0,
return_buffer_idx=jnp.array([0]),
return_buffer=jnp.zeros(100),
cur_rewards=jnp.zeros(self.env.n_envs),
)

return PPOState(runner_state=runner_state, buffer_state=None)
@@ -409,7 +403,6 @@ def train_eval_step(
Returns:
tuple[PPORunnerState, PPOTrainingResult]: Tuple of PPO runner state and training result.
"""
#jax.debug.print("hallo")
_runner_state, (metrics, trajectories) = jax.lax.scan(
self._update_step,
_runner_state,
@@ -422,7 +415,6 @@
),
)
eval_returns = self.eval(_runner_state, n_eval_episodes)
#jax.debug.print("{ret}", ret=eval_returns.mean())

return _runner_state, PPOTrainingResult(
eval_rewards=eval_returns, trajectories=trajectories, metrics=metrics
@@ -459,9 +451,6 @@ def _update_step(
env_state,
last_obs,
global_step,
return_buffer_idx,
return_buffer,
cur_rewards,
) = runner_state
if self.hpo_config["normalize_observations"]:
normalizer_state = running_statistics.update(
@@ -501,9 +490,6 @@
env_state=env_state,
obs=last_obs,
global_step=global_step,
return_buffer_idx=return_buffer_idx,
return_buffer=runner_state.return_buffer,
cur_rewards=runner_state.cur_rewards,
)
metrics, trajectories = None, None
if self.track_metrics:
@@ -532,9 +518,6 @@ def _env_step(
env_state,
last_obs,
global_step,
return_buffer_idx,
return_buffer,
cur_rewards,
) = runner_state

# Select action(s)
@@ -563,8 +546,6 @@
global_step += 1

transition = Transition(done, action, value, reward, log_prob, last_obs, info)
cur_rewards += reward
jnp.array([False]) # todo: print_reward!!

runner_state = PPORunnerState(
train_state=train_state,
@@ -573,9 +554,6 @@
obs=obsv,
rng=rng,
global_step=global_step,
return_buffer_idx=return_buffer_idx,
return_buffer=return_buffer,
cur_rewards=cur_rewards,
)
return runner_state, transition

4 changes: 2 additions & 2 deletions arlbench/core/algorithms/sac/models.py
@@ -33,7 +33,7 @@ def __call__(self) -> jnp.ndarray:


class SACMLPActor(nn.Module):
"""A MLP-based actor network for PPO."""
"""An MLP-based actor network for PPO."""

action_dim: int
activation: int
@@ -153,7 +153,7 @@ def __call__(self, x):


class SACMLPCritic(nn.Module):
"""A MLP-based critic network for SAC."""
"""An MLP-based critic network for SAC."""

action_dim: int
activation: int
(diffs for the remaining 2 changed files not shown)
