Merge pull request #177 from Lux-AI-Challenge/sb3-integration
Sb3 integration
Showing 46 changed files with 2,686 additions and 113 deletions.

@@ -14,4 +14,6 @@ __temp__.mp4
.ipynb_checkpoints

node_modules
dist

.vscode

@@ -0,0 +1,319 @@
import copy
import os.path as osp

import gym
import numpy as np
import torch as th
import torch.nn as nn
from gym import spaces
from gym.wrappers import TimeLimit
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    SubprocVecEnv,
    VecCheckNan,
    VecVideoRecorder,
)
from stable_baselines3.ppo import PPO

from luxai_s2.state import ObservationStateDict, StatsStateDict, create_empty_stats
from luxai_s2.utils.heuristics.factory import build_single_heavy
from luxai_s2.utils.heuristics.factory_placement import place_near_random_ice
from luxai_s2.wrappers import (
    SB3Wrapper,
    SimpleSingleUnitDiscreteController,
    SingleUnitObservationWrapper,
)


class CustomEnvWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env) -> None:
        """
        Adds a custom reward and turns the LuxAI_S2 environment into a
        single-agent environment for easy training
        """
        super().__init__(env)
        self.prev_step_metrics = None

    def step(self, action):
        agent = "player_0"
        opp_agent = "player_1"

        opp_factories = self.env.state.factories[opp_agent]
        for k in opp_factories:
            factory = opp_factories[k]
            # set enemy factories to have 1000 water to keep them alive for the
            # whole game and treat the game as single-agent
            factory.cargo.water = 1000

        action = {agent: action}
        obs, reward, done, info = super().step(action)

        # this is the observation seen by both agents
        shared_obs: ObservationStateDict = self.env.prev_obs[agent]
        done = done[agent]

        # we collect stats on teams here:
        stats: StatsStateDict = self.env.state.stats[agent]

        # compute reward: we simply want to encourage the heavy units to move to
        # ice tiles, mine them, bring the ice back to the factory and dump it
        # there, and survive as long as possible

        factories = shared_obs["factories"][agent]
        factory_pos = None
        for unit_id in factories:
            factory = factories[unit_id]
            # note that ice converts to water at a 4:1 ratio
            factory_pos = np.array(factory["pos"])
            break
        units = shared_obs["units"][agent]
        unit_deliver_ice_reward = 0
        unit_move_to_ice_reward = 0
        unit_overmining_penalty = 0
        penalize_power_waste = 0

        ice_map = shared_obs["board"]["ice"]
        ice_tile_locations = np.argwhere(ice_map == 1)

        def manhattan_dist(p1, p2):
            return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])

        unit_power = 0
        for unit_id in units:
            unit = units[unit_id]
            if unit["unit_type"] == "HEAVY":
                pos = np.array(unit["pos"])
                # pick the closest ice tile (by mean squared difference), then
                # shape the reward with the manhattan distance to it
                ice_tile_distances = np.mean((ice_tile_locations - pos) ** 2, 1)
                closest_ice_tile = ice_tile_locations[np.argmin(ice_tile_distances)]
                dist_to_ice = manhattan_dist(closest_ice_tile, pos)
                unit_power = unit["power"]
                if unit["cargo"]["ice"] < 20:
                    dist_penalty = min(
                        1.0, dist_to_ice / 10
                    )  # go beyond 10 squares manhattan dist and there is no reward
                    unit_move_to_ice_reward += (
                        1 - dist_penalty
                    ) * 0.1  # encourage unit to move to ice
                else:
                    if factory_pos is not None:
                        dist_to_factory = manhattan_dist(pos, factory_pos)
                        dist_penalty = min(1.0, dist_to_factory / 10)
                        unit_deliver_ice_reward = (
                            0.2 + (1 - dist_penalty) * 0.1
                        )  # encourage unit to move back to the factory
                if action[agent] == 15 and unit["power"] < 70:
                    # penalize the agent for trying to dig with insufficient power,
                    # which wastes 10 power on the failed action queue update
                    penalize_power_waste -= 0.005

        # save some stats to the info object so we can record them with our SB3 logger
        info = dict()
        metrics = dict()
        metrics["ice_dug"] = (
            stats["generation"]["ice"]["HEAVY"] + stats["generation"]["ice"]["LIGHT"]
        )
        metrics["water_produced"] = stats["generation"]["water"]
        metrics["action_queue_updates_success"] = stats["action_queue_updates_success"]
        metrics["action_queue_updates_total"] = stats["action_queue_updates_total"]

        metrics["unit_deliver_ice_reward"] = unit_deliver_ice_reward
        metrics["unit_move_to_ice_reward"] = unit_move_to_ice_reward

        info["metrics"] = metrics

        reward = (
            unit_move_to_ice_reward
            + unit_deliver_ice_reward
            + unit_overmining_penalty
            + metrics["water_produced"] / 10
            + penalize_power_waste
        )
        if self.prev_step_metrics is not None:
            ice_dug_this_step = metrics["ice_dug"] - self.prev_step_metrics["ice_dug"]
            water_produced_this_step = (
                metrics["water_produced"] - self.prev_step_metrics["water_produced"]
            )
            # reward += ice_dug_this_step  # reward agent for digging ice
            # reward += water_produced_this_step * 100  # reward agent even more for producing water by delivering ice back to base
        self.prev_step_metrics = copy.deepcopy(metrics)
        return obs["player_0"], reward, done, info

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)["player_0"]
        self.prev_step_metrics = None
        return obs


def parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description="Simple script that simplifies Lux AI Season 2 into a single-agent environment with a reduced observation and action space. It trains a policy that can successfully control a heavy unit to dig ice and transfer it back to a factory to keep it alive"
    )
    parser.add_argument("-s", "--seed", type=int, default=12, help="seed for training")
    parser.add_argument(
        "-n",
        "--n-envs",
        type=int,
        default=8,
        help="Number of parallel envs to run. Note that the rollout size is configured separately and invariant to this value",
    )
    parser.add_argument(
        "--max-episode-steps",
        type=int,
        default=100,
        help="Max steps per episode before truncating them",
    )
    parser.add_argument(
        "--total-timesteps",
        type=int,
        default=3_000_000,
        help="Total timesteps for training",
    )

    parser.add_argument(
        "--eval",
        action="store_true",
        help="If set, will only evaluate a given policy. Otherwise enters training mode",
    )
    parser.add_argument(
        "--model-path", type=str, help="Path to SB3 model weights to use for evaluation"
    )
    parser.add_argument(
        "-l",
        "--log-path",
        type=str,
        default="logs",
        help="Logging path",
    )
    args = parser.parse_args()
    return args


def make_env(env_id: str, rank: int, seed: int = 0, max_episode_steps=100):
    def _init() -> gym.Env:
        # verbose = 0
        # collect stats so we can create reward functions
        # max factories set to 2 for simplification and to keep returns consistent,
        # since we survive longer if there are more initial resources
        env = gym.make(env_id, verbose=0, collect_stats=True, MAX_FACTORIES=2)

        # Add the SB3 wrapper to make the env work with SB3 and simplify the action
        # space with the controller. This removes the bidding phase and the factory
        # placement phase. For factory placement we use the provided
        # place_near_random_ice function, which randomly selects an ice tile and
        # places a factory near it.
        env = SB3Wrapper(
            env,
            controller=SimpleSingleUnitDiscreteController(env.state.env_cfg),
            factory_placement_policy=place_near_random_ice,
            heuristic_policy=build_single_heavy,
        )
        env = SingleUnitObservationWrapper(
            env
        )  # changes observation to include a few simple features
        env = CustomEnvWrapper(env)  # convert to single agent and add our reward
        env = TimeLimit(
            env, max_episode_steps=max_episode_steps
        )  # set horizon to 100 to make training faster. Default is 1000
        env = Monitor(env)  # for SB3 to allow it to record metrics
        env.reset(seed=seed + rank)
        set_random_seed(seed)
        return env

    return _init
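
# DummyVecEnv and VecCheckNan are imported above but never used in this script; a
# hypothetical single-process debug variant of the vectorized setup could look like:
#
#     debug_env = DummyVecEnv([make_env("LuxAI_S2-v0", 0, max_episode_steps=200)])
#     debug_env = VecCheckNan(debug_env, raise_exception=True)
#
# VecCheckNan raises as soon as a NaN or inf appears in observations, rewards, or
# actions, which helps when tuning a shaped reward like the one above.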


env_id = "LuxAI_S2-v0"

from collections import defaultdict


class TensorboardCallback(BaseCallback):
    def __init__(self, tag: str, verbose=0):
        super().__init__(verbose)
        self.tag = tag

    def _on_step(self) -> bool:
        c = 0

        for i, done in enumerate(self.locals["dones"]):
            if done:
                info = self.locals["infos"][i]
                c += 1
                for k in info["metrics"]:
                    stat = info["metrics"][k]
                    # record_mean averages each stat over all episodes that
                    # finish within the current logging window
                    self.logger.record_mean(f"{self.tag}/{k}", stat)
        return True


def evaluate(args, model):
    model = model.load(args.model_path)
    video_length = 1000  # default horizon
    eval_env = SubprocVecEnv(
        [make_env(env_id, i, max_episode_steps=1000) for i in range(args.n_envs)]
    )
    eval_env = VecVideoRecorder(
        eval_env,
        osp.join(args.log_path, "eval_videos"),
        record_video_trigger=lambda x: x == 0,
        video_length=video_length,
        name_prefix="evaluation_video",
    )
    eval_env.reset()
    out = evaluate_policy(model, eval_env, render=False, deterministic=False)
    print(out)


def train(args, model: PPO):
    eval_env = SubprocVecEnv(
        [make_env(env_id, i, max_episode_steps=1000) for i in range(4)]
    )
    video_length = 1000
    eval_env = VecVideoRecorder(
        eval_env,
        osp.join(args.log_path, "eval_videos"),
        record_video_trigger=lambda x: x == 0,
        video_length=video_length,
        name_prefix=f"evaluation-{env_id}",
    )
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=osp.join(args.log_path, "models"),
        log_path=osp.join(args.log_path, "eval_logs"),
        eval_freq=24_000,  # counted in vectorized env steps, so multiply by n_envs for total timesteps
        deterministic=False,
        render=False,
    )
    model.learn(
        args.total_timesteps,
        callback=[TensorboardCallback(tag="train_metrics"), eval_callback],
    )
    model.save(osp.join(args.log_path, "latest_model"))


def main(args):
    print("Training with args", args)
    set_random_seed(args.seed)

    env = SubprocVecEnv(
        [
            make_env(env_id, i, max_episode_steps=args.max_episode_steps)
            for i in range(args.n_envs)
        ]
    )
    env.reset()
    rollout_steps = 4_000
    policy_kwargs = dict(net_arch=(128, 128))
    model = PPO(
        "MlpPolicy",
        env,
        n_steps=rollout_steps // args.n_envs,
        batch_size=800,
        learning_rate=1e-3,
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_epochs=3,
        target_kl=0.07,
        gamma=0.97,
        tensorboard_log=osp.join(args.log_path),
    )
    if args.eval:
        evaluate(args, model)
    else:
        train(args, model)


if __name__ == "__main__":
    # python ../examples/sb3.py -l logs/exp_1 -s 42 -n 1
    main(parse_args())
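
The script trains by default and evaluates a saved policy when --eval is passed. A hypothetical pair of invocations, assuming EvalCallback's convention of saving the best checkpoint as best_model.zip under the models directory (the exact paths are illustrative):

python ../examples/sb3.py -l logs/exp_1 -s 42 -n 8
python ../examples/sb3.py --eval --model-path logs/exp_1/models/best_model.zip -l logs/exp_1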

@@ -0,0 +1,59 @@
from lux.kit import obs_to_game_state, GameState
from lux.config import EnvConfig
from lux.utils import direction_to, my_turn_to_place_factory
import numpy as np
import sys
from wrappers import SimpleSingleUnitDiscreteController
from wrappers import SingleUnitObservationWrapper
import torch as th


class Agent:
    def __init__(self, player: str, env_cfg: EnvConfig) -> None:
        self.player = player
        self.opp_player = "player_1" if self.player == "player_0" else "player_0"
        np.random.seed(0)
        self.env_cfg: EnvConfig = env_cfg

        # load our RL policy (the path is left empty in this commit)
        th.load("")

    def bid_policy(self, step: int, obs, remainingOverageTime: int = 60):
        return dict(faction="AlphaStrike", bid=0)

    def factory_placement_policy(self, step: int, obs, remainingOverageTime: int = 60):
        if obs["teams"][self.player]["metal"] == 0:
            return dict()
        potential_spawns = list(zip(*np.where(obs["board"]["valid_spawns_mask"] == 1)))
        potential_spawns_set = set(potential_spawns)
        done_search = False
        # if player == "player_1":
        ice_diff = np.diff(obs["board"]["ice"])
        pot_ice_spots = np.argwhere(ice_diff == 1)
        if len(pot_ice_spots) == 0:
            pot_ice_spots = potential_spawns
        trials = 5
        while trials > 0:
            pos_idx = np.random.randint(0, len(pot_ice_spots))
            pos = pot_ice_spots[pos_idx]

            # search a 3x3 area around the candidate for a legal spawn location
            area = 3
            for x in range(area):
                for y in range(area):
                    check_pos = [pos[0] + x - area // 2, pos[1] + y - area // 2]
                    if tuple(check_pos) in potential_spawns_set:
                        done_search = True
                        pos = check_pos
                        break
                if done_search:
                    break
            if done_search:
                break
            trials -= 1
        spawn_loc = potential_spawns[np.random.randint(0, len(potential_spawns))]
        if not done_search:
            pos = spawn_loc

        metal = obs["teams"][self.player]["metal"]
        return dict(spawn=pos, metal=metal, water=metal)

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        # policy inference is left unimplemented in this commit
        actions = dict()
        return actions
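
The act method above is committed as a stub. A hypothetical sketch of how the trained policy could drive it, assuming the loaded model is kept on self.policy and a SimpleSingleUnitDiscreteController on self.controller; the convert_obs and action_to_lux_action helpers mirror what the SB3 wrapper does during training, but their use here is an assumption, not part of this commit:

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        # the kit hands both players the same shared observation
        raw_obs = dict(player_0=obs, player_1=obs)
        # hypothetical: reduce the shared obs to the policy's feature vector
        obs_vec = SingleUnitObservationWrapper.convert_obs(raw_obs, self.env_cfg)[self.player]
        with th.no_grad():
            # SB3's predict returns (action, next hidden state)
            action, _ = self.policy.predict(obs_vec, deterministic=True)
        # hypothetical: translate the discrete action back into a lux action dict
        return self.controller.action_to_lux_action(self.player, raw_obs, action)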

@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass
class UnitCargo:
    ice: int = 0
    ore: int = 0
    water: int = 0
    metal: int = 0
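
UnitCargo is a plain dataclass, so kit code can construct and mutate it directly. A minimal usage sketch (the 4:1 ice-to-water ratio is the one noted in the training script's reward comments):

    cargo = UnitCargo(ice=8)
    cargo.water += cargo.ice // 4  # ice refines to water at a 4:1 ratio
    cargo.ice = 0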