diff --git a/README.md b/README.md index e8f88f3b5..075b4e709 100644 --- a/README.md +++ b/README.md @@ -15,21 +15,23 @@ [![Paper](https://img.shields.io/badge/Paper-arXiv:2504.04395-red)](https://arxiv.org/abs/2504.04395) [![Website](https://img.shields.io/badge/Project-Website-blue)](https://metamon.tech) +[![Discord](https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white)](https://discord.gg/9zuJqgDpGg)
-
- **Metamon** enables reinforcement learning (RL) research on [Pokémon Showdown](https://pokemonshowdown.com/) by providing:
+
+1) 20+ pretrained policies ranging from ~average to expert-level human play.
+2) A dataset of >4M (and counting) trajectories "reconstructed" from real human battles.
+3) A dataset of >20M (and counting) trajectories generated by self-play between agents.
+4) Starting points for training (or finetuning) your own imitation learning (IL) and RL policies.
+5) A standardized suite of teams and heuristic opponents for evaluation.
-1) A standardized suite of teams and opponents for evaluation.
-2) A large dataset of RL trajectories "reconstructed" from real human battles.
-3) Starting points for training imitation learning (IL) and RL policies.
+Metamon is the codebase behind ["Human-Level Competitive Pokémon via Scalable Offline RL and Transformers"](https://arxiv.org/abs/2504.04395) (RLC, 2025). Please check out our [project website](https://metamon.tech) for an overview of our original results. After the release of our conference paper, metamon served as a starter kit and winning baseline for the [NeurIPS 2025 PokéAgent Challenge](https://pokeagent.github.io), which motivated significant improvements to our results and datasets.
-Metamon is the codebase behind ["Human-Level Competitive Pokémon via Scalable Offline RL and Transformers"](https://arxiv.org/abs/2504.04395) (RLC, 2025). Please check out our [project website](https://metamon.tech) for an overview of our results. This README documents the dataset, pretrained models, training, and evaluation details to help you get battling!
@@ -41,24 +43,13 @@ Metamon is the codebase behind ["Human-Level Competitive Pokémon via Scalable O #### Supported Rulesets -Pokémon Showdown hosts many different rulesets spanning nine generations of the video game franchise. Metamon initially focused on the most popular singles ruleset ("OverUsed") for **Generations 1, 2, 3, and 4**. However, we are gradually expanding to Gen 9 to support the [NeurIPS 2025 PokéAgent Challenge](https://pokeagent.github.io). This is a large project that will not be finalized in time for the competition launch; please stay tuned for updates. - -The current status is: - -| | Gen 1 OU | Gen 2 OU | Gen 3 OU | Gen 4 OU | Gen 9 OU | -|------------|---------------------|----------|----------|----------|----------| -| Datasets | ✅ | ✅ | ✅ | ✅ | 🟠 (beta) | -| Teams | ✅ | ✅ | ✅ | ✅ | ✅ | -| Heuristic Baselines | ✅ | ✅ | ✅ | ✅ | ✅ | -| Learned Baselines | ✅ | ✅ | ✅ | ✅ | 🟠 (beta) | - -We also support the UnderUsed (UU), NeverUsed (NU), and Ubers tiers for Generations 1, 2, 3, and 4 —-- though constant rule changes and small dataset sizes have always made these a bit of an afterthought. - +Pokémon Showdown hosts many different rulesets spanning nine generations of the video game franchise. Metamon initially focused on the most popular singles ruleset ("OverUsed") for **Generations 1, 2, 3, and 4** but has recently expanded to include **Generation 9 OverUsed** (OU). We also support the UnderUsed (UU), NeverUsed (NU), and Ubers tiers for Generations 1, 2, 3, and 4 – though constant rule changes and small dataset sizes have always made these a bit of an afterthought.
### Table of Contents + 1. [**Installation**](#installation) 2. [**Quick Start**](#quick-start) @@ -79,9 +70,11 @@ We also support the UnderUsed (UU), NeverUsed (NU), and Ubers tiers for Generati 10. [**Battle Backends**](#battle-backends) -11. [**Acknowledgement**](#acknowledgements) +11. [**FAQ**](#faq) -12. [**Citation**](#citation) +12. [**Acknowledgement**](#acknowledgements) + +13. [**Citation**](#citation)
@@ -90,7 +83,8 @@ We also support the UnderUsed (UU), NeverUsed (NU), and Ubers tiers for Generati
-## Installation +
+

## Installation

Metamon is written and tested for linux and python 3.10+. We recommend creating a fresh virtual environment or [conda](https://docs.anaconda.com/anaconda/install/) environment: @@ -135,6 +129,8 @@ Metamon provides large datasets of Pokémon team files, human battles, and other export METAMON_CACHE_DIR=/path/to/plenty/of/disk/space ``` +
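Most of the download helpers refuse to run without this cache directory, so it is worth verifying it from Python before kicking off a long download or training job. A minimal sketch (everything used here is exposed by `metamon` itself):

```python
import os

from metamon import METAMON_CACHE_DIR

# metamon's download utilities raise an error when this is unset,
# so fail fast before starting anything expensive
assert METAMON_CACHE_DIR is not None, "set METAMON_CACHE_DIR to a large disk first"
os.makedirs(METAMON_CACHE_DIR, exist_ok=True)
print(f"metamon cache directory: {METAMON_CACHE_DIR}")
```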
+
____ @@ -228,17 +224,6 @@ online_dset.refresh_files() You are free to use this data to train an agent however you'd like, but we provide starting points for smaller-scale IL (`python -m metamon.il.train`) and RL (`python -m metamon.rl.train`), and a large set of pretrained models from our paper. - -### PokéAgent Challenge -To run agents on the [PokéAgent Challenge ladder](http://pokeagentshowdown.com.insecure.psim.us/): - -1. Go to the link above and click "Choose name" in the top right corner. *Pick a username that begins with `"PAC"`*. - -2. Click the gear icon, then "register", and create a password. - -2. Use `metamon.env.PokeAgentLadder` exactly how you use `QueueOnLocalLadder` in local tests. Provide your account details with `player_username` and `player_password` args. - -
____ @@ -248,7 +233,7 @@ ____ ## Pretrained Models -We have made every checkpoint of 20 models available on huggingface at [`jakegrigsby/metamon`](https://huggingface.co/jakegrigsby/metamon/tree/main). Pretrained models can run without research GPUs, but you will need to install [`amago`](https://github.com/UT-Austin-RPL/amago), which is an RL codebase by the same authors. Follow instructions [here](https://ut-austin-rpl.github.io/amago/installation.html). +We have made every checkpoint of 29 models available on huggingface at [`jakegrigsby/metamon`](https://huggingface.co/jakegrigsby/metamon/tree/main). You will need to install [`amago`](https://github.com/UT-Austin-RPL/amago), which is an RL codebase by the same authors. Follow instructions [here](https://ut-austin-rpl.github.io/amago/installation.html).
@@ -260,7 +245,7 @@ We have made every checkpoint of 20 models available on huggingface at [`jakegri Load and run pretrained models with `metamon.rl.evaluate`. For example: ```bash -python -m metamon.rl.evaluate --eval_type heuristic --agent SyntheticRLV2 --gens 1 --formats ou --total_battles 100 +python -m metamon.rl.evaluate --eval_type heuristic --agent Kakuna --gens 1 --formats ou --total_battles 100 ``` Will run the default checkpoint of the best model for 100 battles against a set of heuristic baselines highlighted in the paper. @@ -268,45 +253,96 @@ Will run the default checkpoint of the best model for 100 battles against a set Or to battle against whatever is logged onto the local Showdown server (including other pretrained models that are already waiting): ```bash -python -m metamon.rl.evaluate --eval_type ladder --agent SyntheticRLV2 --gens 1 --formats ou --total_battles 50 --username --team_set competitive -``` - -Deploy pretrained agents on the PokéAgent Challenge ladder: - -```bash -python -m metamon.rl.evaluate --eval_type pokeagent --agent SyntheticRLV2 --gens 1 --formats ou --total_battles 10 --username --password --team_set competitive +python -m metamon.rl.evaluate --eval_type ladder --agent Kakuna --gens 1 --formats ou --total_battles 50 --username --team_set competitive ``` -Some model sizes have several variants testing different RL objectives. See `metamon/rl/pretrained.py` for a complete list. +
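If you want to browse or pre-fetch the checkpoints directly, the standard `huggingface_hub` client works against the same repo. A minimal sketch (the `allow_patterns` filter is an assumption about how checkpoint folders are named; check the repo page for the actual layout):

```python
from huggingface_hub import snapshot_download

# download (a subset of) jakegrigsby/metamon into the local HF cache
local_dir = snapshot_download(
    repo_id="jakegrigsby/metamon",
    allow_patterns=["Kakuna*"],  # hypothetical filter; drop it to grab everything
)
print(f"checkpoints downloaded to: {local_dir}")
```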
-### Paper Policies -*Paper policies play Gens 1-4 and are discussed in detail in the RLC 2025 paper.* +### Featured Policies + +There are now **29 official metamon models**. Most of them were stepping stones to later (better) versions, and are now mainly useful as baselines or extra opponents in self-play data collection. Some notable exceptions worth knowing about are: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
*The G1-G9 columns are human ladder ratings (GXE), where available.*

| Model | Size | Date | Description | G1 | G2 | G3 | G4 | G9 |
|-------|------|------|-------------|----|----|----|----|----|
| **SyntheticRLV2** | 200M | Sep 2024 | Original paper's best policy. Remains the basis of several successful third-party efforts to specialize in Gen1. Most previous models have complete human ratings (see Paper Policies below), but we have become a lot more cautious about laddering. | 77% | 68% | 64% | 66% | |
| **Abra** | 57M | Jul 2025 | The best gen9ou agent that was open-sourced during the PokéAgent Challenge, and therefore the basis of many of the best third-party metamon extensions. | | | | | 50% |
| **Kadabra3** | 57M | Sep 2025 | The best policy trained in time to participate in the PokéAgent Challenge (as an organizer baseline). #1 in the Gen1OU qualifier and #2 in Gen9OU behind foul-play. | 80% | | | | 64% |
| **Alakazam** | 57M | Sep 2025 | The final version of the PokéAgent Challenge effort. Patched a bug that made tera types invisible to the policy, which makes it the best candidate for future work at this model size. | | | | | |
| **Kakuna** | 142M | Dec 2025 | The best public metamon model – leading by nearly every metric. Trained on diverse teams to serve as a strong foundation for further research in any gen. Appears on all 5 OU leaderboards and is consistently 1500+ Elo in Gen1OU. | 82% | 70% | 63% | 64% | 71% |
+ + + +Models can be loosely divided into two eras of active development: + +1. **RLC Paper** (Jan 2024 – Feb 2025): Trained on Gen 1-4 with old versions of the replay dataset and team sets. +2. **NeurIPS PokéAgent Challenge** (July – November 2025): Basically restarted from scratch. Broadly speaking, we *reduced* model sizes, reward shaping, and the paper's emphasis on long-term memory while *improving* generalization over diverse team choices and prioritizing support for gen9ou. However, it took several iterations to recover the paper's Gen 1-4 performance. -| Model Name (`--agent`) | Description | -|-----------------------------|-----------------------------------------------------------------------------| -| **`SmallIL`** (2 variants) | 15M imitation learning model trained on 1M human battles | -| **`SmallRL`** (5 variants) | 15M actor-critic model trained on 1M human battles | -| **`MediumIL`** | 50M imitation learning model trained on 1M human battles | -| **`MediumRL`** (3 variants) | 50M actor-critic model trained on 1M human battles | -| **`LargeIL`** | 200M imitation learning model trained on 1M human battles | -| **`LargeRL`** | 200M actor-critic model trained on 1M human battles | -| **`SyntheticRLV0`** | 200M actor-critic model trained on 1M human + 1M diverse self-play battles | -| **`SyntheticRLV1`** | 200M actor-critic model trained on 1M human + 2M diverse self-play battles | -| **`SyntheticRLV1_SelfPlay`** | SyntheticRLV1 fine-tuned on 2M extra battles against itself | -| **`SyntheticRLV1_PlusPlus`** | SyntheticRLV1 finetuned on 2M extra battles against diverse opponents | -| **`SyntheticRLV2`** | Final 200M actor-critic model with value classification trained on 1M human + 4M diverse self-play battles. | -### PokéAgent Challenge Policies -*Policies trained during the PokéAgent Challenge play Gens 1-4 **and 9**, but have a clear bias towards Gen 1 OU and Gen 9 OU. Their docstrings in `metamon/rl/pretrained.py` have some extra discussion and eval metrics.* -| Model Name (`--agent`) | Description | -|-----------------------------|-----------------------------------------------------------------------------| -| **`SmallRLGen9Beta`** | Prototype 15M actor-critic model trained *after* the dataset was expanded to include Gen9OU | -| **`Abra`** | 57M actor-critic trained on `parsed-replays v3` and a small set of synthetic battles. First of a new series of Gen9OU-compatible policies trained in a similar style to the paper's "Synthetic" agents.| -| **`Kadabra, Alakazam, Alakazam2, Alakazam3`** | Are further extensions to large datasets of self-play battles (> 11M). They appear on the PokéAgent Challenge practice ladder, but checkpoint releases are on hold to avoid interfering with the competition. | -| **`Minikazam`** | 4.7M RNN trained on `parsed-replays v4` and a large dataset of self-play battles. Tries to compensate for low parameter count by training on `Alakazam`'s dataset. Creates a decent starting point for finetuning on any GPU. [Evals here](https://docs.google.com/spreadsheets/d/1GU7-Jh0MkIKWhiS1WNQiPfv49WIajanUF4MjKeghMAc/edit?usp=sharing). | +
### Paper Policies

*Paper policies play Gens 1-4 and are discussed in detail in the RLC 2025 paper. Some model sizes have several variants testing different RL objectives. See `metamon/rl/pretrained.py` for a complete list.*

| Model Name (`--agent`) | Description |
|-----------------------------|-----------------------------------------------------------------------------|
| **`SmallIL`** (2 variants) | 15M imitation learning model trained on 1M human battles |
| **`SmallRL`** (5 variants) | 15M actor-critic model trained on 1M human battles |
| **`MediumIL`** | 50M imitation learning model trained on 1M human battles |
| **`MediumRL`** (3 variants) | 50M actor-critic model trained on 1M human battles |
| **`LargeIL`** | 200M imitation learning model trained on 1M human battles |
| **`LargeRL`** | 200M actor-critic model trained on 1M human battles |
| **`SyntheticRLV0`** | 200M actor-critic model trained on 1M human + 1M diverse self-play battles |
| **`SyntheticRLV1`** | 200M actor-critic model trained on 1M human + 2M diverse self-play battles |
| **`SyntheticRLV1_SelfPlay`** | SyntheticRLV1 fine-tuned on 2M extra battles against itself |
| **`SyntheticRLV1_PlusPlus`** | SyntheticRLV1 finetuned on 2M extra battles against diverse opponents |
| **`SyntheticRLV2`** | Final 200M actor-critic model with value classification trained on 1M human + 4M diverse self-play battles. |
Here is a reference of human evals for key models according to our paper: @@ -314,8 +350,402 @@ Here is a reference of human evals for key models according to our paper: Figure 1
### PokéAgent Challenge Policies

*Policies trained during the PokéAgent Challenge play Gens 1-4 **and 9**, but have a clear bias towards Gen 1 OU and Gen 9 OU. Their docstrings in `metamon/rl/pretrained.py` have some extra discussion and eval metrics.*

| Model Name (`--agent`) | Description |
|-----------------------------|-----------------------------------------------------------------------------|
| **`SmallRLGen9Beta`** | Prototype 15M actor-critic model trained *after* the dataset was expanded to include Gen9OU |
| **`Abra`** | 57M actor-critic trained on `parsed-replays v3` and a small set of synthetic battles. First of a new series of Gen9OU-compatible policies trained in a similar style to the paper's "Synthetic" agents. |
| **`Kadabra`, `Kadabra2`, `Kadabra3`, `Kadabra4`** | Further extensions of Abra to larger datasets of self-play battles (>11M), trained and deployed as organizer baselines throughout the PokéAgent Challenge practice ladder. |
| **`Alakazam`** | Considered the final edition of the main PokéAgent Challenge effort. Patches a bug that impacted tera type visibility. Actually slightly worse than Kadabra3/4 with competitive teams, but is more robust to diverse team choices thanks to a larger dataset. |
| **`Minikazam`** | 4.7M RNN trained on `parsed-replays v4` and a large dataset of self-play battles. Tries to compensate for low parameter count by training on Alakazam's dataset. Creates a decent starting point for finetuning on any GPU. [Evals here](https://docs.google.com/spreadsheets/d/1GU7-Jh0MkIKWhiS1WNQiPfv49WIajanUF4MjKeghMAc/edit?usp=sharing). |
| **`Superkazam`** | An attempt to revisit Alakazam's (11M self-play + 4M human replay) dataset at a model size closer to the original paper (142M). Evals here. |
| **`Kakuna`** | The best public metamon agent. Superkazam finetuned on 7M additional self-play battles collected at higher sampling temperature for improved exploration and value estimation. Reduced sampling weight of human replays to prioritize high-Elo self-play data. Compensates for our inattention to Gens 2-4 during the PokéAgent Challenge. Evals here. |
+ +
+ +
+ +### Internal Leaderboards + + +The human ratings above are clearly the best way to anchor performance to an external metric, but we primarily rely on self comparisons across generations and [team sets](#team-sets) to guide new research. We typically use head-to-head comparisons between key baselines: see [this Kakuna eval](https://docs.google.com/spreadsheets/d/1lU8tQ0tnnupY28kIyK6FVtvPmxLSVT9_slLShOhRsqg/edit?usp=sharing) as an example. But we can get a general sense of the ***relative* strength** of metamon over time by turning policies loose on a locally hosted Showdown ladder and sampling from the same `TeamSet`. + +*![Gold](https://img.shields.io/badge/Gold-DAA520?style=flat) = PokéAgent Challenge policy, ![Pink](https://img.shields.io/badge/Pink-E91E63?style=flat) = Paper policy.* + + +> [!TIP] +> *These GXE values are a measure of performance *relative* to the listed models and **have no connection to ratings on the public ladder**.* + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
**Early Gen OU Local GXE** *("Comp." = Competitive TeamSet, "Modern" = Modern Replays TeamSet)*

| Model | Comp. G1 | Comp. G2 | Comp. G3 | Comp. G4 | Modern G1 | Modern G2 | Modern G3 | Modern G4 | Avg Rank |
|-------|----------|----------|----------|----------|-----------|-----------|-----------|-----------|----------|
| Kakuna | 75% | 66% | 63% | 60% | 68% | 71% | 67% | 69% | 1.0 |
| Superkazam | 67% | 63% | 59% | 58% | 64% | 61% | 62% | 61% | 3.0 |
| Kadabra4 | 66% | 60% | 58% | 58% | 68% | 60% | 66% | 63% | 3.5 |
| Kadabra3 | 68% | 61% | 57% | 57% | 67% | 60% | 60% | 60% | 4.0 |
| Kadabra2 | 67% | 60% | 58% | 57% | 64% | 62% | 59% | 60% | 4.4 |
| Alakazam | 66% | 59% | 56% | 57% | 64% | 58% | 61% | 58% | 5.5 |
| SynRLV2 | 50% | 59% | 55% | 55% | 54% | 61% | 55% | 56% | 6.9 |
| Kadabra | 56% | 50% | 47% | 47% | 55% | 53% | 50% | 54% | 7.9 |
| SynRLV1++ | 43% | 47% | 41% | 45% | 47% | 49% | 48% | 48% | 10.0 |
| SynRLV1 | 43% | 39% | 42% | 46% | 46% | 45% | 44% | 49% | 10.2 |
| SynRLV0 | 41% | 38% | 48% | 40% | 45% | 41% | 49% | 45% | 11.1 |
| Abra | 39% | 44% | 44% | 45% | 40% | 45% | 48% | 48% | 11.2 |
| SmallRLGen9Beta | | | | | 44% | 42% | 45% | 48% | 12.0 |
| LargeRL | 25% | 35% | 39% | 39% | 30% | 39% | 41% | 44% | 13.9 |
| Minikazam | 39% | 34% | 34% | 34% | 41% | 36% | 36% | 39% | 14.6 |
| SmallILFA | 24% | 36% | 39% | 35% | 28% | 35% | 38% | 41% | 14.8 |
+ > [!TIP] -> Most these policies predate our expansion to Gen 9. They *can* play Gen 9 OU, but won't play it well. Gen 9 training runs are ongoing. +> ![Paper Policies](https://img.shields.io/badge/Paper%20Policies-E91E63?style=flat) are (predictably) weak in Gen9OU because they were never trained to play the format and use observation spaces that assume Team Preview is not available. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
**Gen9OU Local GXE**

| Model | Competitive TeamSet | Modern Replays TeamSet | Avg Rank |
|-------|---------------------|------------------------|----------|
| Kakuna | 76% | 74% | 1.0 |
| Superkazam | 75% | 73% | 2.5 |
| Kadabra4 | 75% | 73% | 2.5 |
| Kadabra3 | 73% | 71% | 4.5 |
| Kadabra2 | 73% | 69% | 5.0 |
| Alakazam | 73% | 71% | 5.5 |
| Abra | 61% | 57% | 7.0 |
| SmallRLGen9Beta | 56% | 57% | 8.5 |
| Kadabra | 58% | 55% | 8.5 |
| Minikazam | 50% | 50% | 10.0 |
| SynRLV0 | 32% | 36% | 11.5 |
| SynRLV2 | 32% | 38% | 11.5 |
| SynRLV1++ | 32% | 33% | 13.5 |
| LargeRL | 29% | 34% | 14.0 |
| SynRLV1 | 31% | 32% | 14.5 |
| SmallILFA | 23% | 27% | 16.0 |

@@ -326,9 +756,13 @@ ____ ## Battle Datasets -Showdown creates "replays" of battles that players can choose to upload to the website before they expire. We gathered all surviving historical replays for Gen 1-4 OU/NU/UU/Ubers and Gen 9 OU, and continuously save new battles to grow the dataset. + +Metamon provides two types of offline RL datasets in a flexible format that lets you [customize observations, rewards, and actions on-the-fly](#observation-spaces-action-spaces--reward-functions). +### Human Replay Datasets + +Showdown creates "replays" of battles that players can choose to upload to the website before they expire. We gathered all surviving historical replays for Gen 1-4 OU/NU/UU/Ubers and Gen 9 OU, and continuously save new battles to grow the dataset.
Dataset Overview @@ -343,14 +777,68 @@ Datasets are stored on huggingface in two formats: |**[`metamon-raw-replays`](https://huggingface.co/datasets/jakegrigsby/metamon-raw-replays)** | 2M Battles | Our curated set of Pokémon Showdown replay `.json` files... to save the Showdown API some download requests and to maintain an official reference of our training data. Will be regularly updated as new battles are played and collected. | |**[`metamon-parsed-replays`](https://huggingface.co/datasets/jakegrigsby/metamon-parsed-replays)** | 4M Trajectories | The RL-compatible version of the dataset as reconstructed by the [replay parser](metamon/backend/replay_parser/README.md). This dataset has been significantly expanded and improved since the original paper.| -Parsed replays will download automatically when requested by the `ParsedReplayDataset`, but these datasets are large. Use `python -m metamon.data.download parsed-replays` to download them in advance. +Parsed replays will download automatically when requested by the `ParsedReplayDataset`, but these datasets are large. Download in advance with: + +```bash +python -m metamon.data.download parsed-replays +``` + +```python +from metamon.data import ParsedReplayDataset + +replay_dset = ParsedReplayDataset( + observation_space=obs_space, + action_space=action_space, + reward_function=reward_func, + formats=["gen1ou", "gen9ou"], +) +obs_seq, action_seq, reward_seq, done_seq = replay_dset[0] +``` #### Server/Replay Sim2Sim Gap In Showdown RL, we have to embrace a **mismatch between the trajectories we *observe in our own battles* and those we *gather from other player's replays***. In short, replays are saved from the point-of-view of a *spectator* rather than the point-of-view of a *player*. The server sends info to the players that it does not save to its replay, and we need to try and simulate that missing info. Metamon goes to great lengths to handle this, and is always improving ([more info here](metamon/backend/replay_parser/README.md)), but there is no way to be perfect. -**Therefore, replay data is perhaps best viewed as pretraining data for an offline-to-online finetuning problem.** Self-collected data from the online env fixes inaccuracies and can help concentrate on teams we'll be using on the ladder. The whole project is now set up to do this (see [Quick Start](#quick-start)). +**Therefore, replay data is perhaps best viewed as pretraining data for an offline-to-online finetuning problem.** Self-collected data from the online env fixes inaccuracies and can help concentrate on teams we'll be using on the ladder. The whole project is now set up to do this (see [Quick Start](#quick-start)), and we have open-sourced large self-play sets (below). + + +
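Before committing to a long offline training run, it can be useful to load a handful of parsed trajectories and sanity-check their lengths and returns. A minimal sketch, assuming the default observation/action/reward classes are importable from `metamon.interface` and construct with no arguments (as `DefaultObservationSpace` does above), and that reward sequences are plain numeric sequences:

```python
import numpy as np

from metamon.data import ParsedReplayDataset
from metamon.interface import (
    DefaultObservationSpace,
    DefaultActionSpace,
    DefaultShapedReward,
)

dset = ParsedReplayDataset(
    observation_space=DefaultObservationSpace(),
    action_space=DefaultActionSpace(),
    reward_function=DefaultShapedReward(),
    formats=["gen1ou"],
)

for i in range(3):
    obs_seq, action_seq, reward_seq, done_seq = dset[i]
    # each entry is one reconstructed battle from a single player's point of view
    print(f"battle {i}: {len(action_seq)} actions, return {np.sum(reward_seq):.1f}")
```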
+ +### Self-Play Datasets + +Almost all improvement in `metamon`'s performance is driven by large and diverse datasets of agent vs. agent battles. Public self-play datasets are stored on huggingface at [`jakegrigsby/metamon-parsed-pile`](https://huggingface.co/datasets/jakegrigsby/metamon-parsed-pile). Trajectories were generated by the `rl/self_play` launcher with various team sets and model pools. + + There are currently two subsets: + + +| Name | Size | Description | +|------|------|-------------| +|**`pac-base`** | 11M Trajectories | Partially comprised of battles played by organizer baselines on the PokéAgent Challenge practice ladder, but the vast majority are battles collected locally for the purposes of training the ![Abra](https://img.shields.io/badge/Abra-DAA520?style=flat), ![Kadabra](https://img.shields.io/badge/Kadabra-DAA520?style=flat), and ![Alakazam](https://img.shields.io/badge/Alakazam-DAA520?style=flat) line of policies. The version uploaded here trained ![Alakazam](https://img.shields.io/badge/Alakazam-DAA520?style=flat), and previous models were trained on subsets of this dataset. | +|**`pac-exploratory`** | 7M Trajectories | Self-play revisited after the NeurIPS challenge with higher sampling temperature (to improve value estimates of sub-optimal actions). Notably also includes battles of official metamon policies against `PA-Agent` (the winning team of the gen1ou tournament), who trained a great policy by (~overfitting) ![SynRLV2](https://img.shields.io/badge/SynRLV2-E91E63?style=flat) to the "competitive" gen1ou team set. This has inspired a fresh approach of distilling specialized policies back into the main line models. ![Kakuna](https://img.shields.io/badge/Kakuna-DAA520?style=flat) was trained on `metamon-parsed-replays`, `pac-base`, and `pac-exploratory`.| + +Self-play data will download automatically when requested by the `SelfPlayDataset`, but these datasets are large. Download in advance with: + +```bash +python -m metamon.data.download self-play +``` + +This downloads both subsets for all available formats (gen1ou, gen2ou, gen3ou, gen4ou, gen9ou). You can also specify formats explicitly: `--formats gen1ou gen9ou`. The download includes pre-built SQLite indexes for fast loading. + +```python +from metamon.data import SelfPlayDataset + +self_play_dset = SelfPlayDataset( + observation_space=obs_space, + action_space=action_space, + reward_function=reward_func, + subset="pac-base", # or "pac-exploratory" + formats=["gen1ou", "gen9ou"], +) +obs_seq, action_seq, reward_seq, done_seq = self_play_dset[0] +``` + +These datasets are currently only available in the parsed replay format, which makes them liable to be deprecated should that format change or a major bug in the [battle backend](#battle-backends) be found. When/if this happens, the [replay parser](metamon/backend/replay_parser/README.md) would be expanded to parse ground-truth battle logs and the datasets would be re-released as a noisier aggregate of all the logs from every metamon development server during the same time period. @@ -396,7 +884,7 @@ ___ ## Baselines -`baselines/` contains baseline opponents that we can battle against via `BattleAgainstBasline`. `baselines/heuritics` provides more than a dozen heuristic opponents and starter code for developing new ones (or mixing ground-truth Pokémon knowledge into ML agents). `baselines/model_based` ties the simple `il` model checkpoints to `poke-env` (with CPU inference). 
+`baselines/` contains baseline opponents that we can battle against via `BattleAgainstBaseline`. `baselines/heuristics` provides more than a dozen heuristic opponents and starter code for developing new ones (or mixing ground-truth Pokémon knowledge into ML agents). `baselines/model_based` ties the simple `il` model checkpoints to `poke-env` (with CPU inference). Here is an overview of the opponents mentioned in the paper: @@ -447,7 +935,8 @@ Metamon tries to separate the RL from Pokémon. All we need to do is pick an `Ob 5. The environment takes the current (`UniversalState`, `UniversalAction`) and outputs the next `UniversalState`. Our `RewardFunction` gives the agent a scalar reward. 7. Repeat until victory. -### Observations +
+

### Observations

`UniversalState` defines all the features we have access to at each timestep. @@ -458,7 +947,7 @@ We could create a custom version with more/less features by inheriting from `met |--------------------------------------|-----------------------------------------------------------------------------| | `DefaultObservationSpace` | The text/numerical observation space used in our paper. | | `ExpandedObservationSpace` | A slight improvement based on lessons learned from the paper. It also adds tera types for Gen 9. | -| `TeamPreviewObeservationSpace` | Further extends `ExpandedObservationSpace` with a preview of the opponent's team (for Gen 9). | +| `TeamPreviewObservationSpace` | Further extends `ExpandedObservationSpace` with a preview of the opponent's team (for Gen 9). | | `OpponentMoveObservationSpace` | Modifies `TeamPreviewObservationSpace` to include the opponent Pokémon's revealed moves. Continues our trend of deemphasizing long-term memory. | ##### Tokenization @@ -468,7 +957,7 @@ of known vocab words. The built-in observation spaces are designed such that the ```python from metamon.interface import TokenizedObservationSpace, DefaultObservationSpace -from metamon.tokenizer import get_toknenizer +from metamon.tokenizer import get_tokenizer base_obs = DefaultObservationSpace() tokenized_space = TokenizedObservationSpace( @@ -486,7 +975,10 @@ words across the entire replay dataset, with an unknown token for rare cases we |`DefaultObservationSpace-v0`| Updated post-release vocabulary as of `metamon-parsed-replays` dataset `v2`. | |`DefaultObservationSpace-v1`| Updated vocabulary as of `metamon-parsed-replays` dataset `v3-beta` (adds ~1k words for Gen 9). | -### Actions +
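The tokenizer behind these vocabularies can also be loaded and queried on its own, which is handy for checking whether a name survives the parser's normalization. A minimal sketch (`get_tokenizer`, integer lookups, and format strings like `<gen9ou>` all appear elsewhere in this release; anything outside the vocabulary falls back to the unknown token):

```python
from metamon.tokenizer import get_tokenizer

tokenizer = get_tokenizer("DefaultObservationSpace-v1")
print(len(tokenizer))         # total vocabulary size
print(tokenizer["<gen9ou>"])  # id of the gen9ou format string (unknown words map to the UNK id)
```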
+ +
+

### Actions

Metamon uses a fixed `UniversalAction` space of 13 discrete choices: - `{0, 1, 2, 3}` use the active Pokémon's moves in alphabetical order. @@ -500,20 +992,25 @@ That might not be how we want to set up our agent. The `ActionSpace` converts be | `DefaultActionSpace` | Standard discrete space of 13 and supports Gen 9. | | `MinimalActionSpace` | The original space of 9 choices (4 moves + 5 switches) --- which is all we need for Gen 1-4. | -Any new action spaces would be added to `metamon.interface.ALL_ACTION_SPACES`. A text action space (for LLM-Agents) is on the short-term roadmap. +Any new action spaces would be added to `metamon.interface.ALL_ACTION_SPACES`. A text action space (for LLM-Agents) is on the short-term roadmap. + +
-### Rewards +
+

Rewards

-Reward functions assign a scalar reward based on consecutive states (R(s, s')). - +Reward functions assign a scalar reward based on consecutive states (R(s, s')). | Reward Function | Description | |--------------------------|----------------------------------------------------------------------------------------------| | `DefaultShapedReward` | Shaped reward used by the paper. +/- 100 for win/loss, light shaping for damage dealt, health recovered, status received/inflicted. | | `BinaryReward` | Removes the smaller shaping terms and simply provides +/- 100 for win/loss. | -| `AggresiveShapedReward` | Edits `DefaultShapedReward`'s sparse reward to +200 for winning +0 for losing. | +| `AggressiveShapedReward` | Edits `DefaultShapedReward`'s sparse reward to +200 for winning +0 for losing. | Any new reward functions would be added to `metamon.interface.ALL_REWARD_FUNCTIONS`, and we can implement a new one by inheriting from `metamon.interface.RewardFunction`. +
+ ---
@@ -558,7 +1055,7 @@ Provides the same setup as the main `train` script but takes care of downloading We might finetune "`SmallRL`" to the new gen 9 replay dataset and custom battles like this: ```bash -python -m metamon.rl.finetune_from_hf --finetune_from_model SmallRL --run_name MyCustomSmallRL --save_dir ~/metamon_finetunes/ --custom_replay_dir /my/custom/parsed_replay_dataset --custom_replay_sample_weight .25 --epochs 10 --steps_per_epoch 10000 --log --formats gen9ou --eval_gens 9 +python -m metamon.rl.finetune_from_hf --finetune_from_model SmallRL --run_name MyCustomSmallRL --save_dir ~/metamon_finetunes/ --custom_replay_dir /my/custom/parsed_replay_dataset --custom_replay_weight .25 --epochs 10 --steps_per_epoch 10000 --log --formats gen9ou --eval_gens 9 ``` You can start from any checkpoint number with `--finetune_from_ckpt`. See the huggingface for a full list. Defaults to the official eval checkpoint. @@ -604,7 +1101,9 @@ python train.py --run_name any_name_will_do --model_config configs/transformer_e To support the main [raw-replays](https://huggingface.co/datasets/jakegrigsby/metamon-raw-replays), [parsed-replays](https://huggingface.co/datasets/jakegrigsby/metamon-parsed-replays), and [teams](https://huggingface.co/datasets/jakegrigsby/metamon-teams) datasets, metamon creates a few resources that may be useful for other purposes: - #### Usage Stats +
+

#### Usage Stats

+ Showdown records the frequency of team choices (items, moves, abilities, etc.) brought to battles in a given month. The community mainly uses this data to consider rule changes, but we use it to help predict missing details of partially revealed teams. We load data for an arbitrary window of history around the date a battle was played, and fall back to all-time stats for rare Pokémon where data is limited: ```python @@ -624,7 +1123,11 @@ python -m metamon.data.download usage-stats The data is stored on huggingface at [`jakegrigsby/metamon-usage-stats`](https://huggingface.co/datasets/jakegrigsby/metamon-usage-stats). -#### Revealed Teams +
+ +
+

#### Revealed Teams

+ One of the main problems the replay parser has to solve is predicting a player's full team based on the "partially revealed" team at the end of the battle. As part of this, we record the revealed team in the [standard Showdown team builder format](https://pokepast.es/syntax.html), but with some magic keywords for missing elements. For example: ``` @@ -647,6 +1150,8 @@ python -m metamon.data.download revealed-teams `metamon/backend/team_prediction` contains tools for filling in the blanks of these files, but this is all poorly documented and changes frequently, so we'll leave it at that for now. +
+ ----
@@ -671,6 +1176,35 @@ A `PretrainedAgent`saves the backend it "should" be evaluated with (if you're us
+ + ## FAQ + + + #### How can I contribute? + +Please get in touch! Currently, the easiest place to reach us is via the [PokéAgent Challenge Discord Server](https://discord.gg/9zuJqgDpGg). You can also email the lead author. + + + #### Why do you focus on Gens 1-4? + + Because there is no team preview before Gen 5, and inferring hidden information via long-term memory was our main focus from an RL research perspective. There's more about this in the paper. A common criticism was that we were avoiding the complexity that comes with later generations' increase in the number of available Pokémon, items, abilities, and so on. If this gap exists, it is more than made up for by the volume of gen9ou replays, as Gen9OU is now arguably our second best format. + + + #### Will you add support for the missing Gens 5-8? + + The main engineering barrier is the [replay parser and dataset](metamon/backend/replay_parser/README.md), which supports gen9 but would surely need some updates for backwards-compatible edge cases. This is not a huge job... but redoing the self-play training process to catch up to the performance in existing gens would be. We would definitely accept contributions on this front, but honestly have no plans to do it ourselves, as in our opinion the expansion to gen9 answered research doubts about generality and model-free RL at low search depth and new (singles) formats are more Showdown infra trouble than they're worth. + + + #### What about VGC (doubles)? + +Support for VGC has been in development but we aren't announcing any timelines on this just yet. + + + +
+ + + ## Acknowledgements This project owes a huge debt to the amazing [`poke-env`](https://github.com/hsahovic/poke-env), as well Pokémon resources like [Bulbapedia](https://bulbapedia.bulbagarden.net/wiki/Main_Page), [Smogon](https://www.smogon.com), and of course [Pokémon Showdown](https://github.com/smogon/pokemon-showdown). diff --git a/media/icons/abra.png b/media/icons/abra.png new file mode 100644 index 000000000..1e940030a Binary files /dev/null and b/media/icons/abra.png differ diff --git a/media/icons/alakazam.png b/media/icons/alakazam.png new file mode 100644 index 000000000..82a963cb9 Binary files /dev/null and b/media/icons/alakazam.png differ diff --git a/media/icons/ditto.png b/media/icons/ditto.png new file mode 100644 index 000000000..1bffa5ea1 Binary files /dev/null and b/media/icons/ditto.png differ diff --git a/media/icons/kadabra.png b/media/icons/kadabra.png new file mode 100644 index 000000000..da829f7bc Binary files /dev/null and b/media/icons/kadabra.png differ diff --git a/media/icons/kakuna.png b/media/icons/kakuna.png new file mode 100644 index 000000000..b406d8f9c Binary files /dev/null and b/media/icons/kakuna.png differ diff --git a/media/icons/minikazam.png b/media/icons/minikazam.png new file mode 100644 index 000000000..ef9bf2c1f Binary files /dev/null and b/media/icons/minikazam.png differ diff --git a/metamon/__init__.py b/metamon/__init__.py index c9b717bf4..f35393e23 100644 --- a/metamon/__init__.py +++ b/metamon/__init__.py @@ -1,7 +1,46 @@ import os from importlib.metadata import version -__version__ = "1.4.0" +__version__ = "1.5.0" + +# ANSI color codes +_YELLOW = "\033[38;5;228m" +_BLUE = "\033[94m" +_CYAN = "\033[96m" +_RED = "\033[91m" +_WHITE = "\033[97m" +_BOLD = "\033[1m" +_DIM = "\033[2m" +_RESET = "\033[0m" + +_METAMON_LOGO_LINES = [ + r" __ ___ __ ", + r" / |/ /__ / /_____ _____ ___ ____ ____ ", + " / /|_/ / _ \\/ __/ __ `/ __ `__ \\/ __ \\/ __ \\", + " / / / / __/ /_/ /_/ / / / / / / /_/ / / / /", + "/_/ /_/\\___/\\__/\\__,_/_/ /_/ /_/\\____/_/ /_/ ", +] + + +def print_banner(): + print(f'{_BLUE}╔{"═" * 60}╗{_RESET}') + + for line in _METAMON_LOGO_LINES: + padding = 60 - len(line) + print( + f'{_BLUE}║{_RESET}{_YELLOW}{_BOLD}{line}{" " * padding}{_RESET}{_BLUE}║{_RESET}' + ) + + print(f'{_BLUE}╠{"═" * 60}╣{_RESET}') + tagline = f"Pokémon Showdown RL • v{__version__} • UT-Austin-RPL/metamon" + pad_left = (60 - len(tagline)) // 2 + pad_right = 60 - len(tagline) - pad_left + print( + f'{_BLUE}║{_RESET}{" " * pad_left}{_WHITE}{tagline}{_RESET}{" " * pad_right}{_BLUE}║{_RESET}' + ) + print(f'{_BLUE}╚{"═" * 60}╝{_RESET}') + print() + poke_env_version = version("poke-env") diff --git a/metamon/backend/showdown_dex/__init__.py b/metamon/backend/showdown_dex/__init__.py index 736c1cc47..e254d6fcb 100644 --- a/metamon/backend/showdown_dex/__init__.py +++ b/metamon/backend/showdown_dex/__init__.py @@ -1 +1 @@ -from .dex import Dex \ No newline at end of file +from .dex import Dex diff --git a/metamon/backend/team_preview/gen9ou_high_elo_v4/best_model.pt b/metamon/backend/team_preview/gen9ou_high_elo_v4/best_model.pt new file mode 100644 index 000000000..51620b94d Binary files /dev/null and b/metamon/backend/team_preview/gen9ou_high_elo_v4/best_model.pt differ diff --git a/metamon/backend/team_preview/preview.py b/metamon/backend/team_preview/preview.py new file mode 100644 index 000000000..efbee84f3 --- /dev/null +++ b/metamon/backend/team_preview/preview.py @@ -0,0 +1,834 @@ +""" +Team preview prediction: predict which pokemon to lead with given 
both teams. +12 inputs (6 ours + 6 opponent) -> 1 output (which of our 6 to lead). +Perceiver-style cross attention architecture. +""" + +import os +import json +import random +import lz4.frame +from typing import Optional, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from einops import rearrange +import wandb +from tqdm import tqdm + +from metamon.interface import ( + UniversalState, + consistent_pokemon_order, + consistent_move_order, +) +from metamon.tokenizer import PokemonTokenizer, get_tokenizer +from metamon.backend.replay_parser.str_parsing import pokemon_name, move_name +from metamon.il.model import CrossAttentionBlock, SelfAttentionBlock +from metamon.data.download import download_parsed_replays + + +class TeamPreviewDataset(Dataset): + """Dataset for team preview prediction from parsed replays.""" + + def __init__( + self, + tokenizer: PokemonTokenizer, + battle_format: str = "gen9ou", + dset_root: Optional[str] = None, + min_rating: int = 1300, + max_rating: Optional[int] = None, + wins_losses_both: str = "both", + max_samples: Optional[int] = None, + shuffle: bool = True, + ): + self.tokenizer = tokenizer + self.battle_format = battle_format + self.min_rating = min_rating + self.max_rating = max_rating + self.wins_losses_both = wins_losses_both + self.shuffle = shuffle + + if dset_root is None: + print(f"Downloading {battle_format} parsed replays...") + format_path = download_parsed_replays(battle_format) + dset_root = os.path.dirname(format_path) + + format_dir = os.path.join(dset_root, battle_format) + if not os.path.exists(format_dir): + raise ValueError(f"Format directory not found: {format_dir}") + + self.filenames = self._find_and_filter_files(format_dir) + if len(self.filenames) == 0: + raise ValueError(f"No replays found for {battle_format} with given filters") + + print(f"Found {len(self.filenames)} {battle_format} replays matching filters") + + if max_samples and max_samples < len(self.filenames): + random.shuffle(self.filenames) + self.filenames = self.filenames[:max_samples] + print(f"Using {len(self.filenames)} samples") + + def _rating_to_int(self, rating_str: str) -> int: + try: + return int(rating_str) + except ValueError: + return 1000 + + def _find_and_filter_files(self, format_dir: str) -> List[str]: + filenames = [] + all_files = os.listdir(format_dir) + json_files = [f for f in all_files if f.endswith((".json", ".json.lz4"))] + + has_rating_filter = self.min_rating is not None or self.max_rating is not None + has_result_filter = self.wins_losses_both in ("wins", "losses") + + for filename in json_files: + name_without_ext = ( + filename[:-9] if filename.endswith(".json.lz4") else filename[:-5] + ) + parts = name_without_ext.split("_") + if len(parts) != 7: + continue + + battle_id, rating_str, p1_name, _, p2_name, mm_dd_yyyy, result = parts + + if has_result_filter: + if self.wins_losses_both == "wins" and result != "WIN": + continue + if self.wins_losses_both == "losses" and result != "LOSS": + continue + + battle_id_clean = ( + battle_id.replace("[", "").replace("]", "").replace(" ", "").lower() + ) + if self.battle_format not in battle_id_clean: + continue + + if has_rating_filter: + rating = self._rating_to_int(rating_str) + if (self.min_rating is not None and rating < self.min_rating) or ( + self.max_rating is not None and rating > self.max_rating + ): + continue + + filenames.append(os.path.join(format_dir, filename)) + + if self.shuffle: + 
random.shuffle(filenames) + return filenames + + def __len__(self): + return len(self.filenames) + + def _load_json(self, filename: str) -> dict: + if filename.endswith(".lz4"): + with lz4.frame.open(filename, "rb") as f: + return json.load(f) + with open(filename, "r") as f: + return json.load(f) + + def __getitem__( + self, idx + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """returns (team_tokens, additional_info_tokens, lead_idx, format_token)""" + max_attempts = 100 + attempts = 0 + current_idx = idx + + while attempts < max_attempts: + try: + data = self._load_json(self.filenames[current_idx]) + first_state = UniversalState.from_dict(data["states"][0]) + + our_team = [ + first_state.player_active_pokemon + ] + first_state.available_switches + + if len(our_team) != 6: + attempts += 1 + current_idx = (current_idx + 1) % len(self.filenames) + continue + + our_team_sorted = consistent_pokemon_order(our_team) + + opponent_team_names = first_state.opponent_teampreview + if len(opponent_team_names) != 6: + attempts += 1 + current_idx = (current_idx + 1) % len(self.filenames) + continue + + opponent_team_sorted = consistent_pokemon_order(opponent_team_names) + + our_tokens = [ + self.tokenizer[pokemon_name(p.name)] for p in our_team_sorted + ] + opp_tokens = [ + self.tokenizer[pokemon_name(name)] for name in opponent_team_sorted + ] + team_tokens = torch.tensor(our_tokens + opp_tokens, dtype=torch.long) + + additional_info_tokens = [] + for p in our_team_sorted: + pokemon_info = [] + moves = consistent_move_order(p.moves)[:4] + for move in moves: + pokemon_info.append(self.tokenizer[move_name(move.name)]) + while len(pokemon_info) < 4: + pokemon_info.append(self.tokenizer[""]) + pokemon_info.append(self.tokenizer[p.ability]) + pokemon_info.append(self.tokenizer[p.item]) + additional_info_tokens.append(pokemon_info) + + additional_info_tokens = torch.tensor( + additional_info_tokens, dtype=torch.long + ) + + lead_name = pokemon_name(first_state.player_active_pokemon.name) + try: + lead_idx = next( + i + for i, p in enumerate(our_team_sorted) + if pokemon_name(p.name) == lead_name + ) + except StopIteration: + if attempts == 0: + print(f"WARNING: Active pokemon {lead_name} not found in team") + attempts += 1 + current_idx = (current_idx + 1) % len(self.filenames) + continue + + format_str = f"<{self.battle_format}>" + format_token = torch.tensor( + self.tokenizer[format_str], dtype=torch.long + ) + + return ( + team_tokens, + additional_info_tokens, + torch.tensor(lead_idx, dtype=torch.long), + format_token, + ) + + except Exception as e: + if attempts == 0: + print(f"ERROR loading {self.filenames[current_idx]}: {e}") + attempts += 1 + current_idx = (current_idx + 1) % len(self.filenames) + + raise RuntimeError( + f"Failed to load valid sample after {max_attempts} attempts from idx {idx}" + ) + + +class TeamPreviewModel(nn.Module): + """Perceiver-style model: 12 pokemon + optional additional info -> predict lead (1 of 6).""" + + def __init__( + self, + tokenizer: PokemonTokenizer, + d_model: int = 128, + n_heads: int = 4, + n_layers: int = 3, + latent_tokens: int = 4, + dropout: float = 0.1, + use_additional_info: bool = True, + use_argmax: bool = False, + ): + super().__init__() + self.tokenizer = tokenizer + self.d_model = d_model + self.use_additional_info = use_additional_info + self.use_argmax = use_argmax + self.token_emb = nn.Embedding(len(tokenizer), d_model) + self.pokemon_pos_emb = nn.Embedding(12, d_model) + self.team_emb = nn.Embedding(2, d_model) + + if 
use_additional_info: + self.info_pokemon_emb = nn.Embedding(6, d_model) + self.info_slot_emb = nn.Embedding(6, d_model) + self.type_emb = nn.Embedding(2, d_model) + + self.latents = nn.Parameter(torch.randn(latent_tokens, d_model) * 0.02) + + self.cross_blocks = nn.ModuleList( + [ + CrossAttentionBlock(d_model=d_model, n_heads=n_heads, dropout=dropout) + for _ in range(n_layers) + ] + ) + self.self_blocks = nn.ModuleList( + [ + SelfAttentionBlock(d_model=d_model, n_heads=n_heads, dropout=dropout) + for _ in range(n_layers) + ] + ) + self.final_norm = nn.LayerNorm(d_model) + self.classifier = nn.Linear(latent_tokens * d_model, 6) + + def forward( + self, + team_tokens: torch.Tensor, + additional_info_tokens: Optional[torch.Tensor] = None, + format_token: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B = team_tokens.shape[0] + device = team_tokens.device + + pokemon_token_emb = self.token_emb(team_tokens) + pokemon_pos_ids = torch.arange(12, device=device).unsqueeze(0).expand(B, -1) + pokemon_pos_emb = self.pokemon_pos_emb(pokemon_pos_ids) + + team_ids = torch.cat( + [ + torch.zeros(B, 6, dtype=torch.long, device=device), + torch.ones(B, 6, dtype=torch.long, device=device), + ], + dim=1, + ) + pokemon_team_emb = self.team_emb(team_ids) + + if self.use_additional_info: + pokemon_type_ids = torch.zeros(B, 12, dtype=torch.long, device=device) + pokemon_type_emb = self.type_emb(pokemon_type_ids) + pokemon_emb = ( + pokemon_token_emb + + pokemon_pos_emb + + pokemon_team_emb + + pokemon_type_emb + ) + else: + pokemon_emb = pokemon_token_emb + pokemon_pos_emb + pokemon_team_emb + + if self.use_additional_info: + if additional_info_tokens is None: + raise ValueError( + "additional_info_tokens required when use_additional_info=True" + ) + + info_token_emb = self.token_emb(additional_info_tokens) + info_pokemon_ids = ( + torch.arange(6, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(B, -1, 6) + ) + info_slot_ids = ( + torch.arange(6, device=device) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, 6, -1) + ) + + info_pokemon_emb = self.info_pokemon_emb(info_pokemon_ids) + info_slot_emb = self.info_slot_emb(info_slot_ids) + info_team_ids = torch.zeros(B, 6, 6, dtype=torch.long, device=device) + info_team_emb = self.team_emb(info_team_ids) + info_type_ids = torch.ones(B, 6, 6, dtype=torch.long, device=device) + info_type_emb = self.type_emb(info_type_ids) + + info_emb = ( + info_token_emb + + info_pokemon_emb + + info_slot_emb + + info_team_emb + + info_type_emb + ) + info_emb = rearrange(info_emb, "b p i d -> b (p i) d") + emb = torch.cat([pokemon_emb, info_emb], dim=1) + else: + emb = pokemon_emb + + if format_token is not None: + format_emb = self.token_emb(format_token).unsqueeze(1) + emb = torch.cat([format_emb, emb], dim=1) + + latents = self.latents.unsqueeze(0).expand(B, -1, -1) + + for cross, self_attn in zip(self.cross_blocks, self.self_blocks): + latents = cross(latents, emb) + latents = self_attn(latents) + + latents = self.final_norm(latents) + latents_flat = rearrange(latents, "b n d -> b (n d)") + return self.classifier(latents_flat) + + def predict_lead_from_state( + self, + state: UniversalState, + device: Optional[str] = None, + ) -> Tuple[str, torch.Tensor]: + """predict lead from a UniversalState (should be first state with teampreview)""" + our_team = [state.player_active_pokemon] + state.available_switches + if len(our_team) != 6: + raise ValueError(f"Expected 6 pokemon in our team, got {len(our_team)}") + + opponent_team_names = state.opponent_teampreview + if 
len(opponent_team_names) != 6: + raise ValueError( + f"Expected 6 pokemon in opponent teampreview, got {len(opponent_team_names)}" + ) + + return self.predict_lead( + our_team=[p.name for p in our_team], + our_team_moves=[[m.name for m in p.moves] for p in our_team], + our_team_abilities=[p.ability for p in our_team], + our_team_items=[p.item for p in our_team], + opponent_team=opponent_team_names, + device=device, + ) + + def predict_lead( + self, + our_team: List[str], + our_team_moves: List[List[str]], + our_team_abilities: List[str], + our_team_items: List[str], + opponent_team: List[str], + battle_format: Optional[str] = None, + device: Optional[str] = None, + ) -> Tuple[str, torch.Tensor]: + """predict which pokemon to lead with""" + if device is None: + device = next(self.parameters()).device + + # sort teams consistently + our_team_with_info = list( + zip(our_team, our_team_moves, our_team_abilities, our_team_items) + ) + our_team_with_info_sorted = sorted( + our_team_with_info, key=lambda x: pokemon_name(x[0]) + ) + our_team_sorted = [name for name, _, _, _ in our_team_with_info_sorted] + our_moves_sorted = [moves for _, moves, _, _ in our_team_with_info_sorted] + our_abilities_sorted = [ + ability for _, _, ability, _ in our_team_with_info_sorted + ] + our_items_sorted = [item for _, _, _, item in our_team_with_info_sorted] + opponent_team_sorted = consistent_pokemon_order(opponent_team) + + # tokenize pokemon + our_tokens = [self.tokenizer[pokemon_name(name)] for name in our_team_sorted] + opp_tokens = [ + self.tokenizer[pokemon_name(name)] for name in opponent_team_sorted + ] + team_tokens = torch.tensor([our_tokens + opp_tokens], dtype=torch.long).to( + device + ) + + # tokenize additional info + additional_info_tokens = None + if self.use_additional_info: + additional_info_tokens = [] + for moves, ability, item in zip( + our_moves_sorted, our_abilities_sorted, our_items_sorted + ): + pokemon_info = [] + moves_sorted = consistent_move_order(moves)[:4] + for move in moves_sorted: + pokemon_info.append(self.tokenizer[move_name(move)]) + while len(pokemon_info) < 4: + pokemon_info.append(self.tokenizer[""]) + pokemon_info.append(self.tokenizer[ability]) + pokemon_info.append(self.tokenizer[item]) + additional_info_tokens.append(pokemon_info) + additional_info_tokens = torch.tensor( + [additional_info_tokens], dtype=torch.long + ).to(device) + + format_token = None + if battle_format is not None: + format_token = torch.tensor( + [self.tokenizer[f"<{battle_format}>"]], dtype=torch.long + ).to(device) + + self.eval() + with torch.no_grad(): + logits = self(team_tokens, additional_info_tokens, format_token) + probs = F.softmax(logits, dim=-1).squeeze(0) + + if self.use_argmax: + lead_idx = probs.argmax().item() + else: + lead_idx = torch.multinomial(probs, num_samples=1).item() + + return our_team_sorted[lead_idx], probs + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + tokenizer: Optional[PokemonTokenizer] = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + use_argmax: bool = False, + ) -> "TeamPreviewModel": + """load model from checkpoint. 
tokenizer auto-loaded if not provided.""" + ckpt = torch.load(checkpoint_path, map_location=device) + + # auto-load tokenizer if not provided + if tokenizer is None: + tokenizer_name = ckpt["hparams"]["data"].get( + "tokenizer_name", "DefaultObservationSpace-v1" + ) + tokenizer = get_tokenizer(tokenizer_name) + + model = cls(tokenizer=tokenizer, **ckpt["hparams"]["model"]) + model.load_state_dict(ckpt["model_state_dict"]) + model.to(device) + model.eval() + model.use_argmax = use_argmax + + model.trained_formats = ckpt["hparams"]["data"].get( + "battle_formats", [ckpt["hparams"]["data"]["battle_format"]] + ) + + print( + f"Loaded checkpoint from epoch {ckpt['epoch']} (val_acc={ckpt['val_acc']:.4f})" + ) + print(f"Trained on formats: {model.trained_formats}") + print(f"Using {'argmax' if use_argmax else 'sampling'} for lead selection") + + return model + + +def train_team_preview( + tokenizer: PokemonTokenizer, + save_dir: str, + battle_format: str = "gen9ou", + dset_root: Optional[str] = None, + min_rating: int = 1300, + max_rating: Optional[int] = None, + wins_losses_both: str = "both", + epochs: int = 10, + steps_per_epoch: int = 1000, + batch_size: int = 128, + lr: float = 3e-4, + d_model: int = 128, + n_heads: int = 4, + n_layers: int = 3, + latent_tokens: int = 4, + dropout: float = 0.1, + use_additional_info: bool = True, + max_samples: Optional[int] = None, + patience: int = 5, + dloader_workers: int = 4, + log_wandb: bool = True, + device: str = "cuda" if torch.cuda.is_available() else "cpu", +): + """train team preview model with early stopping""" + os.makedirs(save_dir, exist_ok=True) + + hparams = { + "model": { + "d_model": d_model, + "n_heads": n_heads, + "n_layers": n_layers, + "latent_tokens": latent_tokens, + "dropout": dropout, + "use_additional_info": use_additional_info, + }, + "training": { + "epochs": epochs, + "steps_per_epoch": steps_per_epoch, + "batch_size": batch_size, + "lr": lr, + "max_samples": max_samples, + "patience": patience, + }, + "data": { + "battle_format": battle_format, + "dset_root": dset_root, + "min_rating": min_rating, + "max_rating": max_rating, + "wins_losses_both": wins_losses_both, + "tokenizer_name": tokenizer.name, + "train_size": None, + "val_size": None, + }, + } + + if log_wandb: + wandb.init(project="metamon", entity="ut-austin-rpl-metamon", config=hparams) + + full_dataset = TeamPreviewDataset( + tokenizer=tokenizer, + battle_format=battle_format, + dset_root=dset_root, + min_rating=min_rating, + max_rating=max_rating, + wins_losses_both=wins_losses_both, + max_samples=max_samples, + ) + train_size = int(0.95 * len(full_dataset)) + val_size = len(full_dataset) - train_size + train_dataset, val_dataset = torch.utils.data.random_split( + full_dataset, [train_size, val_size] + ) + + hparams["data"]["train_size"] = train_size + hparams["data"]["val_size"] = val_size + + dloader_kwargs = { + "batch_size": batch_size, + "num_workers": dloader_workers, + "pin_memory": torch.cuda.is_available(), + "persistent_workers": dloader_workers > 0, + "prefetch_factor": 2 if dloader_workers > 0 else None, + } + train_loader = DataLoader(train_dataset, shuffle=True, **dloader_kwargs) + val_loader = DataLoader(val_dataset, shuffle=False, **dloader_kwargs) + + model = TeamPreviewModel(tokenizer=tokenizer, **hparams["model"]).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + print(f"Training on {train_size} samples, validating on {val_size} samples") + print(f"Steps per epoch: {steps_per_epoch}, Total epochs: {epochs}") + + 
best_val_acc = 0.0 + best_val_loss = float("inf") + epochs_without_improvement = 0 + best_checkpoint_path = os.path.join(save_dir, "best_model.pt") + + def infinite_dataloader(dataloader): + while True: + for batch in dataloader: + yield batch + + train_iter = infinite_dataloader(train_loader) + ema_alpha = 0.1 + + for epoch in range(epochs): + model.train() + train_loss, train_acc, train_count = 0.0, 0.0, 0 + train_loss_ema, train_acc_ema = None, None + + train_pbar = tqdm( + range(steps_per_epoch), desc=f"Epoch {epoch} [Train]", leave=False + ) + for step in train_pbar: + team_tokens, additional_info_tokens, lead_idx, format_token = next( + train_iter + ) + team_tokens = team_tokens.to(device) + additional_info_tokens = additional_info_tokens.to(device) + lead_idx = lead_idx.to(device) + format_token = format_token.to(device) + + optimizer.zero_grad() + logits = model( + team_tokens, + additional_info_tokens if model.use_additional_info else None, + format_token, + ) + loss = F.cross_entropy(logits, lead_idx) + loss.backward() + optimizer.step() + + train_loss += loss.item() * len(team_tokens) + train_acc += (logits.argmax(1) == lead_idx).float().sum().item() + train_count += len(team_tokens) + + batch_loss = loss.item() + batch_acc = (logits.argmax(1) == lead_idx).float().mean().item() + + if train_loss_ema is None: + train_loss_ema = batch_loss + train_acc_ema = batch_acc + else: + train_loss_ema = ( + ema_alpha * batch_loss + (1 - ema_alpha) * train_loss_ema + ) + train_acc_ema = ema_alpha * batch_acc + (1 - ema_alpha) * train_acc_ema + + global_step = epoch * steps_per_epoch + step + if log_wandb: + wandb.log( + { + "train_loss_ema": train_loss_ema, + "train_acc_ema": train_acc_ema, + "global_step": global_step, + }, + step=global_step, + ) + + train_pbar.set_postfix( + {"loss": f"{train_loss_ema:.4f}", "acc": f"{train_acc_ema:.4f}"} + ) + + # validation + model.eval() + val_loss, val_acc, val_count = 0.0, 0.0, 0 + val_loss_ema, val_acc_ema = None, None + + val_pbar = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False) + with torch.no_grad(): + for team_tokens, additional_info_tokens, lead_idx, format_token in val_pbar: + team_tokens = team_tokens.to(device) + additional_info_tokens = additional_info_tokens.to(device) + lead_idx = lead_idx.to(device) + format_token = format_token.to(device) + + logits = model( + team_tokens, + additional_info_tokens if model.use_additional_info else None, + format_token, + ) + loss = F.cross_entropy(logits, lead_idx) + + val_loss += loss.item() * len(team_tokens) + val_acc += (logits.argmax(1) == lead_idx).float().sum().item() + val_count += len(team_tokens) + + batch_loss = loss.item() + batch_acc = (logits.argmax(1) == lead_idx).float().mean().item() + + if val_loss_ema is None: + val_loss_ema = batch_loss + val_acc_ema = batch_acc + else: + val_loss_ema = ( + ema_alpha * batch_loss + (1 - ema_alpha) * val_loss_ema + ) + val_acc_ema = ema_alpha * batch_acc + (1 - ema_alpha) * val_acc_ema + + val_pbar.set_postfix( + {"loss": f"{val_loss_ema:.4f}", "acc": f"{val_acc_ema:.4f}"} + ) + + epoch_global_step = (epoch + 1) * steps_per_epoch + metrics = { + "epoch": epoch, + "train_loss_epoch": train_loss / train_count, + "train_acc_epoch": train_acc / train_count, + "val_loss": val_loss / val_count, + "val_acc": val_acc / val_count, + } + + print( + f"Epoch {epoch} (step {epoch_global_step}): " + f"train_loss={metrics['train_loss_epoch']:.4f}, train_acc={metrics['train_acc_epoch']:.4f}, " + f"val_loss={metrics['val_loss']:.4f}, 
val_acc={metrics['val_acc']:.4f}" + ) + + if log_wandb: + wandb.log(metrics, step=epoch_global_step) + + if metrics["val_acc"] > best_val_acc: + best_val_acc = metrics["val_acc"] + best_val_loss = metrics["val_loss"] + epochs_without_improvement = 0 + + checkpoint = { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "epoch": epoch, + "global_step": epoch_global_step, + "hparams": hparams, + "train_loss": metrics["train_loss_epoch"], + "train_acc": metrics["train_acc_epoch"], + "val_loss": metrics["val_loss"], + "val_acc": metrics["val_acc"], + } + torch.save(checkpoint, best_checkpoint_path) + print(f" New best model saved! val_acc={best_val_acc:.4f}") + else: + epochs_without_improvement += 1 + print( + f" No improvement for {epochs_without_improvement} epoch(s). Best={best_val_acc:.4f}" + ) + + if epochs_without_improvement >= patience: + print( + f"\nEarly stopping after {epoch + 1} epochs ({epoch_global_step} steps)" + ) + print(f"Best val_acc: {best_val_acc:.4f}") + break + + latest_checkpoint_path = os.path.join(save_dir, "latest_model.pt") + checkpoint = { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "epoch": epoch, + "global_step": epoch_global_step, + "hparams": hparams, + "train_loss": metrics["train_loss_epoch"], + "train_acc": metrics["train_acc_epoch"], + "val_loss": metrics["val_loss"], + "val_acc": metrics["val_acc"], + } + torch.save(checkpoint, latest_checkpoint_path) + + if log_wandb: + wandb.finish() + + print(f"\nTraining complete") + print(f"Best checkpoint: {best_checkpoint_path}") + print(f"Best val_acc: {best_val_acc:.4f}, val_loss: {best_val_loss:.4f}") + + return TeamPreviewModel.load_from_checkpoint( + best_checkpoint_path, tokenizer, device + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Train team preview prediction model") + parser.add_argument("--save_dir", type=str, required=True) + parser.add_argument("--battle_format", type=str, default="gen9ou") + parser.add_argument("--dset_root", type=str, default=None) + parser.add_argument("--tokenizer", type=str, default="DefaultObservationSpace-v1") + parser.add_argument("--min_rating", type=int, default=1250) + parser.add_argument("--max_rating", type=int, default=None) + parser.add_argument( + "--wins_losses_both", + type=str, + default="both", + choices=["wins", "losses", "both"], + ) + parser.add_argument("--epochs", type=int, default=250) + parser.add_argument("--steps_per_epoch", type=int, default=5000) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--lr", type=float, default=8e-4) + parser.add_argument("--d_model", type=int, default=132) + parser.add_argument("--n_heads", type=int, default=6) + parser.add_argument("--n_layers", type=int, default=6) + parser.add_argument("--latent_tokens", type=int, default=4) + parser.add_argument("--dropout", type=float, default=0.05) + parser.add_argument( + "--no-additional-info", + action="store_false", + dest="use_additional_info", + ) + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument("--patience", type=int, default=5) + parser.add_argument("--dloader_workers", type=int, default=4) + parser.add_argument("--no-wandb", action="store_true", dest="no_wandb") + args = parser.parse_args() + + tokenizer = get_tokenizer(args.tokenizer) + + train_team_preview( + tokenizer=tokenizer, + save_dir=args.save_dir, + battle_format=args.battle_format, + 
dset_root=args.dset_root, + min_rating=args.min_rating, + max_rating=args.max_rating, + wins_losses_both=args.wins_losses_both, + epochs=args.epochs, + steps_per_epoch=args.steps_per_epoch, + batch_size=args.batch_size, + lr=args.lr, + d_model=args.d_model, + n_heads=args.n_heads, + n_layers=args.n_layers, + latent_tokens=args.latent_tokens, + dropout=args.dropout, + use_additional_info=args.use_additional_info, + max_samples=args.max_samples, + patience=args.patience, + dloader_workers=args.dloader_workers, + log_wandb=not args.no_wandb, + ) diff --git a/metamon/data/__init__.py b/metamon/data/__init__.py index cd151b42a..2901be227 100644 --- a/metamon/data/__init__.py +++ b/metamon/data/__init__.py @@ -2,5 +2,5 @@ DATA_PATH = os.path.dirname(__file__) -from .parsed_replay_dset import ParsedReplayDataset +from .parsed_replay_dset import MetamonDataset, ParsedReplayDataset, SelfPlayDataset from . import raw_replay_util diff --git a/metamon/data/download.py b/metamon/data/download.py index 853cc5465..9a5e4659f 100644 --- a/metamon/data/download.py +++ b/metamon/data/download.py @@ -10,6 +10,15 @@ import metamon from metamon import SUPPORTED_BATTLE_FORMATS, METAMON_CACHE_DIR +SELF_PLAY_SUBSETS = ["pac-base", "pac-exploratory"] +SELF_PLAY_FORMATS = [ + "gen1ou", + "gen2ou", + "gen3ou", + "gen4ou", + "gen9ou", +] # OU formats available for self-play + if METAMON_CACHE_DIR is not None: VERSION_REFERENCE_PATH = os.path.join(METAMON_CACHE_DIR, "version_reference.json") else: @@ -193,6 +202,110 @@ def download_raw_replays(version: str = LATEST_RAW_REPLAY_REVISION) -> str: return os.path.join(METAMON_CACHE_DIR, "raw-replays") +def download_self_play_data( + subset: str, + battle_format: str, + version: str = "main", + force_download: bool = False, + extract: bool = False, +) -> str: + """Download self-play data from the metamon-parsed-pile dataset. + + Args: + subset: The subset to download. Options: "pac-base", "pac-exploratory" + battle_format: Showdown battle format (e.g. "gen1ou") + version: Version/revision of the dataset to download. Defaults to "main". + force_download: If True, download the dataset even if a previous version + already exists in the cache. + extract: If True, extract all files from the tar archive (slow, uses many inodes). + If False (default), keep data in .tar format for direct reading. + + Returns: + The path to the .tar file (if extract=False) or extracted directory (if extract=True). + """ + if METAMON_CACHE_DIR is None: + raise ValueError("METAMON_CACHE_DIR environment variable is not set") + if subset not in SELF_PLAY_SUBSETS: + raise ValueError( + f"Invalid subset: {subset}. 
Must be one of {SELF_PLAY_SUBSETS}" + ) + + self_play_dir = os.path.join(METAMON_CACHE_DIR, "self-play", subset) + tar_lz4_path = os.path.join(self_play_dir, f"{battle_format}.tar.lz4") + tar_path = os.path.join(self_play_dir, f"{battle_format}.tar") + extracted_path = os.path.join(self_play_dir, battle_format) + + # Determine output path based on extract flag + out_path = extracted_path if extract else tar_path + + if os.path.exists(out_path): + if not force_download: + return out_path + if extract: + print(f"Clearing existing dataset at {out_path}...") + shutil.rmtree(out_path) + else: + os.remove(out_path) + + hf_hub_download( + cache_dir=self_play_dir, + repo_id="jakegrigsby/metamon-parsed-pile", + filename=f"{subset}/{battle_format}.tar.lz4", + local_dir=os.path.join(METAMON_CACHE_DIR, "self-play"), + revision=version, + repo_type="dataset", + ) + + # Download pre-built SQLite index (skips expensive index build) + sqlite_index_path = os.path.join(self_play_dir, f"{battle_format}.tar.index.sqlite") + if not os.path.exists(sqlite_index_path) or force_download: + try: + hf_hub_download( + cache_dir=self_play_dir, + repo_id="jakegrigsby/metamon-parsed-pile", + filename=f"{subset}/{battle_format}.tar.index.sqlite", + local_dir=os.path.join(METAMON_CACHE_DIR, "self-play"), + revision=version, + repo_type="dataset", + ) + print(f"Downloaded pre-built index: {sqlite_index_path}") + except Exception as e: + print( + f"Note: Pre-built index not available, will be built on first load ({e})" + ) + + # Decompress .tar.lz4 -> .tar + import lz4.frame + from tqdm import tqdm + + compressed_size = os.path.getsize(tar_lz4_path) + print(f"Decompressing {tar_lz4_path} ({compressed_size / 1e9:.1f}GB compressed)...") + + with lz4.frame.open(tar_lz4_path, "rb") as lz4_file: + with open(tar_path, "wb") as tar_file: + # Stream in chunks to handle large files + bytes_written = 0 + with tqdm(unit="B", unit_scale=True, desc="Decompressing") as pbar: + while True: + chunk = lz4_file.read(64 * 1024 * 1024) # 64MB chunks + if not chunk: + break + tar_file.write(chunk) + bytes_written += len(chunk) + pbar.update(len(chunk)) + + os.remove(tar_lz4_path) + + if extract: + print(f"Extracting {tar_path} (this may take a while for large datasets)...") + with tarfile.open(tar_path) as tar: + tar.extractall(path=self_play_dir) + os.remove(tar_path) + + _update_version_reference("self-play", f"{subset}/{battle_format}", version) + return out_path + + def download_usage_stats( gen: int, version: str = LATEST_USAGE_STATS_REVISION, @@ -286,6 +399,9 @@ def print_version_tree(version_dict: dict, indent: int = 0): # Download (anonymized) Showdown replay logs (all formats) python -m metamon.data.download raw-replays + # Download self-play datasets (pac-base and pac-exploratory) + python -m metamon.data.download self-play --formats gen1ou gen9ou + Note: Requires METAMON_CACHE_DIR environment variable to be set. 
The cache directory is currently: {colored(METAMON_CACHE_DIR or 'NOT SET', 'red')} @@ -300,6 +416,7 @@ def print_version_tree(version_dict: dict, indent: int = 0): choices=[ "raw-replays", "parsed-replays", + "self-play", "revealed-teams", "replay-stats", "teams", @@ -312,6 +429,7 @@ def print_version_tree(version_dict: dict, indent: int = 0): Dataset to download: raw-replays: Unprocessed Showdown replays (stripped of usernames/chat) parsed-replays: RL-compatible version of replays with reconstructed player actions + self-play: Self-play battle data (pac-base and pac-exploratory subsets) revealed-teams: Teams that were revealed during battles replay-stats: Statistics generated from revealed teams. Used to predict team sets. teams: Various team sets (competitive, paper_variety, paper_replays) @@ -321,9 +439,11 @@ def print_version_tree(version_dict: dict, indent: int = 0): "--formats", nargs="+", type=str, - default=SUPPORTED_BATTLE_FORMATS, + default=None, help=""" -Battle formats to download. Defaults to all Gen 1-4 formats (OU, UU, NU, Ubers). +Battle formats to download. Defaults depend on dataset type: + - parsed-replays, teams, usage-stats: All Gen 1-4 formats (OU, UU, NU, Ubers) + Gen 9 OU + - self-play: gen1ou, gen2ou, gen3ou, gen4ou, gen9ou (only OU available) Examples: --formats gen1ou gen2ou # Only Gen 1-2 OU --formats gen3uu gen4uu # Only Gen 3-4 UU @@ -351,10 +471,19 @@ def print_version_tree(version_dict: dict, indent: int = 0): download_raw_replays(version=version) elif args.dataset == "parsed-replays": version = args.version or LATEST_PARSED_REPLAY_REVISION - if args.formats is None: - raise ValueError("Must specify at least one battle format (e.g., gen1ou)") - for format in args.formats: + formats = args.formats or SUPPORTED_BATTLE_FORMATS + for format in formats: download_parsed_replays(format, version=version, force_download=True) + elif args.dataset == "self-play": + version = args.version or "main" + formats = args.formats or SELF_PLAY_FORMATS + print(f"Downloading self-play data for formats: {formats}") + for subset in SELF_PLAY_SUBSETS: + print(f"\nDownloading {subset}...") + for format in formats: + download_self_play_data( + subset, format, version=version, force_download=True + ) elif args.dataset == "revealed-teams": version = args.version or LATEST_PARSED_REPLAY_REVISION download_revealed_teams(version=version, force_download=True) @@ -363,22 +492,18 @@ def print_version_tree(version_dict: dict, indent: int = 0): download_replay_stats(version=version, force_download=True) elif args.dataset == "usage-stats": version = args.version or LATEST_USAGE_STATS_REVISION - if args.formats is None: - raise ValueError("Must specify at least one battle format (e.g., gen1ou)") - generations = set( - metamon.backend.format_to_gen(format) for format in args.formats - ) + formats = args.formats or SUPPORTED_BATTLE_FORMATS + generations = set(metamon.backend.format_to_gen(format) for format in formats) for gen in generations: download_usage_stats(gen=gen, version=version, force_download=True) elif args.dataset == "teams": version = args.version or LATEST_TEAMS_REVISION - if args.formats is None: - raise ValueError("Must specify at least one set name (e.g., gen1ou)") + formats = args.formats or SUPPORTED_BATTLE_FORMATS set_names = ["competitive", "paper_variety", "paper_replays"] if version > "v0": set_names += ["modern_replays", "modern_replays_v2"] for set_name in set_names: - for format in args.formats: + for format in formats: if "ou" not in format and "replays" in set_name: # 
only OU tiers have replay sets currently continue diff --git a/metamon/data/parsed_replay_dset.py b/metamon/data/parsed_replay_dset.py index fccae4c57..6a1210953 100644 --- a/metamon/data/parsed_replay_dset.py +++ b/metamon/data/parsed_replay_dset.py @@ -1,9 +1,22 @@ +""" +PyTorch datasets for loading parsed Pokémon battle trajectories. + +Classes: + MetamonDataset: Base class for custom/local datasets. + ParsedReplayDataset: Human replays from HuggingFace (flat directories). + SelfPlayDataset: Self-play data from HuggingFace (tar archives). + +Storage formats supported: + - Flat directories: {format}/*.json[.lz4] + - Tar archives: {format}.tar (O(1) access via ratarmountcore SQLite index) +""" + import os import json import random import csv import copy -from typing import Optional, Dict, Tuple, List, Any, Set +from typing import Optional, Dict, Tuple, List, Any from datetime import datetime from collections import defaultdict @@ -11,6 +24,7 @@ import lz4.frame import numpy as np import tqdm +from ratarmountcore.SQLiteIndexedTarFsspec import SQLiteIndexedTarFileSystem import metamon from metamon.interface import ( @@ -20,87 +34,68 @@ ActionSpace, UniversalAction, ) -from metamon.data.download import download_parsed_replays - - -class ParsedReplayDataset(Dataset): - """An iterable dataset of "parsed replays" - - Parsed replays are records of Pokémon Showdown battles that have been converted to the partially observed - point-of-view of a single player, matching the problem our agents face in the RL environment. They are created by the - `metamon.backend.replay_parser` module from "raw" Showdown replay logs - downloaded from publicly available battles. - - This is a pytorch `Dataset` that returns (nested_obs, actions, rewards, dones) trajectory tuples, - where: - - nested_obs: List of numpy arrays of length seq_len (arrays may have different shapes). - If the observation space is a dict, this becomes a dict of lists of arrays for each key. - - actions: Dict with keys: - - "chosen": list (length seq_len) of actions taken by the agent in the chosen action space - - "legal": list (length seq_len) of sets of legal actions available at each timestep in the chosen action space - - "missing": list (length seq_len) of bools indicating the action is missing (should probably be masked) - - rewards: Numpy array of shape (seq_len,) - - dones: Numpy array of shape (seq_len,) - - Note that depending on the observation space, you may need a custom pad_collate_fn in the pytorch dataloader - to handle the variable-shaped arrays in nested_obs. - - Missing actions are a bool mask where idx i = True if action i is missing (actions[i] == -1, or was originally - missing but has since been filled by some prediction scheme). Missing actions are caused by player choices that - are not revealed to spectators and do not show up in the replay logs (e.g., paralysis, sleep, flinch). - - Data is stored as interface.UniversalStates and observations and rewards are created on the fly. This - means we no longer have to create new versions of the parsed replay dataset to experiment with different - observation spaces or reward functions. 
- - Example: - ```python - dset = ParsedReplayDataset( - observation_space=TokenizedObservationSpace( - DefaultObservationSpace(), - tokenizer=get_tokenizer("DefaultObservationSpace-v1"), - ), - reward_function=DefaultShapedReward(), - formats=["gen1nu"], - verbose=True, - ) +from metamon.data.download import ( + download_parsed_replays, + download_self_play_data, + SELF_PLAY_SUBSETS, + SELF_PLAY_FORMATS, + METAMON_CACHE_DIR, +) - obs, action_infos, rewards, dones = dset[0] - ``` + +class MetamonDataset(Dataset): + """Base dataset class for loading parsed Pokémon battle trajectories. + + Parsed replays are records of Pokémon Showdown battles converted to the partially + observed point-of-view of a single player, matching the problem our agents face in + the RL environment. They are created by the `metamon.backend.replay_parser` module. + + This class auto-detects whether data is stored as: + - Flat directories: {format}/*.json or {format}/*.json.lz4 + - Tar archives: {format}.tar (uses ratarmountcore for O(1) random access) + + Use MetamonDataset directly for local/custom datasets. For official HuggingFace + datasets, use the subclasses: + - ParsedReplayDataset: Human replays from jakegrigsby/metamon-parsed-replays + - SelfPlayDataset: Self-play data from jakegrigsby/metamon-parsed-pile Args: - observation_space: The observation space to use. Must be an instance of `interface.ObservationSpace`. - reward_function: The reward function to use. Must be an instance of `interface.RewardFunction`. - dset_root: The root directory of the parsed replays. If not specified, the parsed replays will be - downloaded and extracted from the latest version of the huggingface dataset, but this may take minutes. - formats: A list of formats to load (e.g. ["gen1ou", "gen2ubers"]). Defaults to all supported formats - (Gen 1-4 ou, uu, nu, and ubers), but this will take a long time to download and extract the first time. - wins_losses_both: Whether to only load the perspective of players who won their battle, lost their - battle, or both. {"wins", "losses", "both"} - min_rating: The minimum rating of battles to load (in ELO). Note that most replays are Unrated, which - is mapped to 1000 ELO (the minimum rating on Showdown). In reality many of these battles were played - as part of tournaments and should probably not be ignored. - max_rating: The maximum rating of battles to load (in ELO). In Generations 1-4, ELO ratings above 1500 - are very good. - min_date: The minimum date of battles to load (as a datetime). Our dataset begins in 2014. Many replays - from 2021-2024 are missing due to a Showdown database issue. See the raw-replay dataset README on - HF for a visual timeline of the dataset. - max_date: The maximum date of battles to load (as a datetime). The latest date available will depend on - the current version of the parsed replays dataset. - max_seq_len: The maximum sequence length to load. Trajectories are randomly sliced to this length. - verbose: Whether to print progress bars while loading large datasets. - shuffle: Whether to shuffle the filenames. Defaults to False. - use_cached_filenames: Whether to use the cached filenames from a manifest.csv file saved during a previous experiment with this replay directory. - Saves time on startup of large training runs. Defaults to False. + dset_root: Root directory containing format subdirs or tar files. + observation_space: Observation space for converting states to observations. + action_space: Action space for converting actions to agent outputs. 
+ reward_function: Reward function for computing rewards from state transitions. + formats: List of battle formats to load (e.g., ["gen1ou", "gen9ou"]). + wins_losses_both: Filter by outcome: "wins", "losses", or "both". + min_rating: Minimum ELO rating filter (unrated battles default to 1000). + max_rating: Maximum ELO rating filter. + min_date: Minimum battle date filter. + max_date: Maximum battle date filter. + max_seq_len: Maximum trajectory length (randomly sliced if exceeded). + verbose: Print progress information. + shuffle: Shuffle the filename list. + use_cached_filenames: Use cached index files for faster startup. + + Returns (from __getitem__): + nested_obs: Dict of lists of numpy arrays for each observation key. + actions: Dict with keys "chosen" (list), "legal" (list of sets), "missing" (list of bools). + rewards: numpy array of shape (seq_len,). + dones: numpy array of shape (seq_len,). + + Note: + Missing actions (actions["missing"][i] == True) occur when player choices are not + revealed in replay logs (e.g., paralysis, sleep, flinch decisions). """ + # Prefix for tar-backed filenames to distinguish from disk files + TAR_PREFIX = "tar://" + def __init__( self, + dset_root: str, observation_space: ObservationSpace, action_space: ActionSpace, reward_function: RewardFunction, - dset_root: Optional[str] = None, - formats: Optional[List[str]] = None, + formats: List[str], wins_losses_both: str = "both", min_rating: Optional[int] = None, max_rating: Optional[int] = None, @@ -111,187 +106,356 @@ def __init__( shuffle: bool = False, use_cached_filenames: bool = False, ): - formats = formats or metamon.SUPPORTED_BATTLE_FORMATS - - if dset_root is None: - for format in formats: - path_to_format_data = download_parsed_replays(format) - dset_root = os.path.dirname(path_to_format_data) + assert os.path.exists(dset_root), f"Dataset root not found: {dset_root}" - assert dset_root is not None and os.path.exists(dset_root) + self.dset_root = dset_root self.observation_space = copy.deepcopy(observation_space) self.action_space = copy.deepcopy(action_space) self.reward_function = copy.deepcopy(reward_function) - self.dset_root = dset_root self.formats = formats self.min_rating = min_rating self.max_rating = max_rating self.min_date = min_date self.max_date = max_date self.wins_losses_both = wins_losses_both - self.verbose = verbose self.max_seq_len = max_seq_len + self.verbose = verbose self.shuffle = shuffle - self.manifest_path = os.path.join(self.dset_root, "manifest.csv") - if os.path.exists(self.manifest_path) and use_cached_filenames: - with open(self.manifest_path, "r") as f: - reader = csv.reader(f) - next(reader) - self.filenames = [row[0] for row in reader] - if verbose: - print(f"Loaded {len(self.filenames)} battles from {self.manifest_path}") - if shuffle: - random.shuffle(self.filenames) - else: - self.refresh_files() + self.use_cached_filenames = use_cached_filenames + + self.index_path = os.path.join(self.dset_root, "index.csv") + + self._tar_files: Dict[str, SQLiteIndexedTarFileSystem] = {} + self._tar_paths: Dict[str, str] = {} + self._format_is_tar: Dict[str, bool] = {} + self._owner_pid: int = os.getpid() # Track PID for fork-safety + + self._detect_formats() + self.refresh_files() - def parse_battle_date(self, filename: str) -> datetime: - # parsed replays saved by our own gym env will have hour/minute/sec - # while Showdown replays will not. 
- date_str = filename.split("_")[-2] + ###################### + ## Format Detection ## + ###################### - # Try the more common format first (without time) for faster parsing + def _detect_formats(self): + """Detect whether each format is stored as flat directory or tar archive.""" + available_formats = [] + + for format_name in self.formats: + tar_path = os.path.join(self.dset_root, f"{format_name}.tar") + dir_path = os.path.join(self.dset_root, format_name) + + if os.path.exists(tar_path): + # TAR ARCHIVE: gen1ou.tar + self._format_is_tar[format_name] = True + self._tar_paths[format_name] = tar_path + available_formats.append(format_name) + if self.verbose: + print(f"Detected tar archive for {format_name}") + + elif os.path.isdir(dir_path): + # FLAT DIRECTORY: gen1ou/ + self._format_is_tar[format_name] = False + available_formats.append(format_name) + if self.verbose: + print(f"Detected flat directory for {format_name}") + + else: + if self.verbose: + print(f"Skipping {format_name}: no data found") + + self.formats = available_formats + + ######################################### + ## Tar Archive Handling (ratarmountcore) # + ######################################### + + def _get_tar(self, format_name: str) -> SQLiteIndexedTarFileSystem: + current_pid = os.getpid() + is_worker = current_pid != self._owner_pid + if is_worker: + self._tar_files.clear() + self._owner_pid = current_pid + + if format_name not in self._tar_files: + if self.verbose and not is_worker: + print(f"Opening {format_name}.tar...") + self._tar_files[format_name] = SQLiteIndexedTarFileSystem( + self._tar_paths[format_name], + printDebug=-1, + ) + return self._tar_files[format_name] + + def _get_tar_index_path(self, format_name: str) -> str: + """Get path to our cached filename list for a tar archive.""" + return os.path.join(self.dset_root, f"{format_name}.tar.index.txt") + + def _index_tar(self, format_name: str) -> List[str]: + """TAR: List all json files in archive using ratarmountcore. + + This opens the tar and builds the SQLite index if not present. + Also caches filename list to .txt for use_cached_filenames=True. 
+ """ + fs = self._get_tar(format_name) + + # List files in the format directory within the tar try: - return datetime.strptime(date_str, "%m-%d-%Y") - except ValueError: - try: - return datetime.strptime(date_str, "%m-%d-%Y-%H:%M:%S") - except ValueError: - raise ValueError(f"Could not parse date string: {date_str}") + files = fs.ls(f"/{format_name}", detail=False) + except FileNotFoundError: + files = fs.ls("/", detail=False) - def refresh_files(self): - self.filenames = [] + # Filter to json files and strip leading slash + member_names = [ + f.lstrip("/") for f in files if f.endswith((".json", ".json.lz4")) + ] + + # Cache the filename list for use_cached_filenames + index_path = self._get_tar_index_path(format_name) + with open(index_path, "w") as f: + for name in member_names: + f.write(name + "\n") + + if self.verbose: + print(f"Found {len(member_names)} files in {format_name}.tar") + + return member_names + + def _load_tar_index(self, format_name: str) -> List[str]: + """TAR: Load cached filename list from .txt file.""" + index_path = self._get_tar_index_path(format_name) + with open(index_path, "r") as f: + return [line.strip() for line in f if line.strip()] + + ############################# + ## Flat Directory Handling ## + ############################# - def _rating_to_int(rating: str) -> int: + def _index_directory(self, format_name: str) -> List[str]: + """DIRECTORY: Scan directory for json files.""" + format_dir = os.path.join(self.dset_root, format_name) + try: + files = os.listdir(format_dir) + except (OSError, PermissionError) as e: + if self.verbose: + print(f" Warning: Could not read {format_dir}: {e}") + return [] + + return [ + os.path.join(format_name, f) + for f in files + if f.endswith((".json", ".json.lz4")) + ] + + ######################## + ## Filename Filtering ## + ######################## + + def _filter_filename(self, filename: str, format_name: str) -> bool: + """Apply rating, date, and win/loss filters to a filename.""" + # Parse filename: battle_id_rating_p1_vs_p2_date_result.json + name_without_ext = ( + filename[:-9] if filename.endswith(".json.lz4") else filename[:-5] + ) + parts = name_without_ext.split("_") + + if len(parts) == 7: + battle_id, rating_str, p1, _, p2, date_str, result = parts + elif len(parts) == 8: + battle_id, rating_str, p1a, p1b, _, p2, date_str, result = parts + else: + return False + + # Validate format in battle_id + if ( + format_name + not in battle_id.replace("[", "").replace("]", "").replace(" ", "").lower() + ): + return False + + # Result filter + if self.wins_losses_both == "wins" and result != "WIN": + return False + if self.wins_losses_both == "losses" and result != "LOSS": + return False + + # Rating filter + if self.min_rating is not None or self.max_rating is not None: try: - return int(rating) + rating = int(rating_str) except ValueError: - return 1000 + rating = 1000 + if self.min_rating and rating < self.min_rating: + return False + if self.max_rating and rating > self.max_rating: + return False + + # Date filter + if self.min_date is not None or self.max_date is not None: + try: + date = self._parse_date(date_str) + if self.min_date and date < self.min_date: + return False + if self.max_date and date > self.max_date: + return False + except ValueError: + return False - bar = lambda it, desc: ( - it if not self.verbose else tqdm.tqdm(it, desc=desc, colour="red") - ) + return True - has_rating_filter = self.min_rating is not None or self.max_rating is not None - has_date_filter = self.min_date is not None or 
self.max_date is not None - has_result_filter = self.wins_losses_both in ("wins", "losses") + def _parse_date(self, date_str: str) -> datetime: + """Parse date string from filename.""" + try: + return datetime.strptime(date_str, "%m-%d-%Y") + except ValueError: + return datetime.strptime(date_str, "%m-%d-%Y-%H:%M:%S") - for format in self.formats: - path = os.path.join(self.dset_root, format) - if not os.path.exists(path): - if self.verbose: - print( - f"Requested data for format `{format}`, but did not find {path}" - ) - continue - - # Get all files at once and filter by extension first - all_files = os.listdir(path) - json_files = [f for f in all_files if f.endswith((".json", ".json.lz4"))] - - for filename in bar(json_files, desc=f"Finding {format} battles"): - name_without_ext = ( - filename[:-9] if filename.endswith(".json.lz4") else filename[:-5] - ) + ################### + ## File Indexing ## + ################### - parts = name_without_ext.split("_") - if len(parts) != 7: - continue + def refresh_files(self): + """Build the list of files to load, applying filters.""" + self.filenames = [] - battle_id, rating_str, p1_name, _, p2_name, mm_dd_yyyy, result = parts + # Check if we need to rebuild directory index.csv + has_directory_formats = any( + not self._format_is_tar.get(fmt, False) for fmt in self.formats + ) + will_rebuild_dir_index = has_directory_formats and ( + not self.use_cached_filenames or not os.path.exists(self.index_path) + ) + if will_rebuild_dir_index and os.path.exists(self.index_path): + os.remove(self.index_path) # Clear stale index before rebuilding - if has_result_filter: - if self.wins_losses_both == "wins" and result != "WIN": - continue - if self.wins_losses_both == "losses" and result != "LOSS": - continue + for format_name in self.formats: + if self._format_is_tar.get(format_name, False): + self._refresh_tar_format(format_name) + else: + self._refresh_directory_format(format_name) - battle_id_clean = ( - battle_id.replace("[", "").replace("]", "").replace(" ", "").lower() - ) - if format not in battle_id_clean: - continue - - if has_rating_filter: - rating = _rating_to_int(rating_str) - if (self.min_rating is not None and rating < self.min_rating) or ( - self.max_rating is not None and rating > self.max_rating - ): - continue - - if has_date_filter: - try: - date = self.parse_battle_date(filename) - if (self.min_date is not None and date < self.min_date) or ( - self.max_date is not None and date > self.max_date - ): - continue - except ValueError: - continue - self.filenames.append(os.path.join(path, filename)) + if self.verbose: + print(f"Total: {len(self.filenames)} battles after filtering") if self.shuffle: random.shuffle(self.filenames) - with open(self.manifest_path, "w") as f: - if self.verbose: - print( - f"Writing {self.manifest_path} with {len(self.filenames)} battles" - ) - writer = csv.writer(f) - writer.writerow(["filename"]) # Write header row - for filename in self.filenames: - writer.writerow([filename]) + def _refresh_tar_format(self, format_name: str): + """TAR: Index and filter files from a tar archive.""" + index_path = self._get_tar_index_path(format_name) - if self.verbose: - print(f"Dataset contains {len(self.filenames)} battles") + # Get file list (from .txt cache or fresh scan) + if self.use_cached_filenames and os.path.exists(index_path): + if self.verbose: + print(f"Loading cached tar index from {index_path}") + member_names = self._load_tar_index(format_name) + else: + # This will open tar, build SQLite index if needed, and cache 
.txt + member_names = self._index_tar(format_name) + + # Filter and add with TAR_PREFIX + iterator = ( + tqdm.tqdm(member_names, desc=f"Filtering {format_name}", colour="green") + if self.verbose + else member_names + ) + for member_name in iterator: + if self._filter_filename(os.path.basename(member_name), format_name): + self.filenames.append(f"{self.TAR_PREFIX}{format_name}/{member_name}") + + def _refresh_directory_format(self, format_name: str): + """DIRECTORY: Index and filter files from a flat directory.""" + # Get file list (from cache or fresh scan) + if self.use_cached_filenames and os.path.exists(self.index_path): + with open(self.index_path, "r") as f: + reader = csv.reader(f) + next(reader) # skip header + rel_paths = [ + row[0] for row in reader if row[0].startswith(format_name + os.sep) + ] + if self.verbose: + print(f"Loaded {len(rel_paths)} files from index.csv for {format_name}") + else: + rel_paths = self._index_directory(format_name) + if self.verbose: + print(f"Indexed {len(rel_paths)} files from {format_name}/") + # Write to index.csv cache (append if exists, create with header if not) + write_header = not os.path.exists(self.index_path) + with open(self.index_path, "a") as f: + if write_header: + f.write("filename\n") + for rel_path in rel_paths: + f.write(f"{rel_path}\n") + + # Filter and add as absolute paths + iterator = ( + tqdm.tqdm(rel_paths, desc=f"Filtering {format_name}", colour="green") + if self.verbose + else rel_paths + ) + for rel_path in iterator: + if self._filter_filename(os.path.basename(rel_path), format_name): + self.filenames.append(os.path.join(self.dset_root, rel_path)) - def __len__(self): - return len(self.filenames) + ################## + ## Data Loading ## + ################## def _load_json(self, filename: str) -> dict: + """Load JSON data from either tar archive or disk file.""" + if filename.startswith(self.TAR_PREFIX): + return self._load_json_from_tar(filename) + else: + return self._load_json_from_disk(filename) + + def _load_json_from_tar(self, filename: str) -> dict: + """TAR: Read file using ratarmountcore (O(1) random access).""" + path = filename[len(self.TAR_PREFIX) :] + format_name, member_name = path.split("/", 1) + fs = self._get_tar(format_name) + + data = fs.cat("/" + member_name) + if member_name.endswith(".lz4"): + data = lz4.frame.decompress(data) + return json.loads(data.decode("utf-8")) + + def _load_json_from_disk(self, filename: str) -> dict: + """DIRECTORY: Read file from disk.""" if filename.endswith(".json.lz4"): with lz4.frame.open(filename, "rb") as f: - data = json.loads(f.read().decode("utf-8")) - elif filename.endswith(".json"): - with open(filename, "r") as f: - data = json.load(f) + return json.loads(f.read().decode("utf-8")) else: - raise ValueError(f"Unknown file extension: {filename}") - return data + with open(filename, "r") as f: + return json.load(f) def load_filename(self, filename: str): + """Load and process a single battle trajectory.""" data = self._load_json(filename) states = [UniversalState.from_dict(s) for s in data["states"]] - # reset the observation space, then call once on each state, which lets - # any history-dependent features behave as they would in an online battle + + # Build observations self.observation_space.reset() obs = [self.observation_space.state_to_obs(s) for s in states] - # TODO: handle case where observation space is not a dict. don't have one to test yet. 
nested_obs = defaultdict(list) for o in obs: for k, v in o.items(): nested_obs[k].append(v) - action_infos = { - "chosen": [], - "legal": [], - "missing": [], - } - # NOTE: the replay parser leaves a blank final action + + # Build actions + action_infos = {"chosen": [], "legal": [], "missing": []} for s, a_idx in zip(states, data["actions"][:-1]): universal_action = UniversalAction(action_idx=a_idx) - missing = universal_action.missing - chosen_agent_action = self.action_space.action_to_agent_output( - s, universal_action + action_infos["chosen"].append( + self.action_space.action_to_agent_output(s, universal_action) ) - legal_universal_actions = UniversalAction.maybe_valid_actions(s) - legal_agent_actions = set( - self.action_space.action_to_agent_output(s, l) - for l in legal_universal_actions + action_infos["legal"].append( + set( + self.action_space.action_to_agent_output(s, l) + for l in UniversalAction.maybe_valid_actions(s) + ) ) - action_infos["chosen"].append(chosen_agent_action) - action_infos["legal"].append(legal_agent_actions) - action_infos["missing"].append(missing) + action_infos["missing"].append(universal_action.missing) + + # Build rewards and dones rewards = np.array( [ self.reward_function(s_t, s_t1) @@ -302,39 +466,140 @@ def load_filename(self, filename: str): dones = np.zeros_like(rewards, dtype=bool) dones[-1] = True + # Random slice if max_seq_len specified if self.max_seq_len is not None: - # s s s s s s s s - # a a a a a a a - # r r r r r r r - # d d d d d d d - safe_start = random.randint( + start = random.randint( 0, max(len(action_infos["chosen"]) - self.max_seq_len, 0) ) - nested_obs = { - k: v[safe_start : safe_start + 1 + self.max_seq_len] - for k, v in nested_obs.items() - } - action_infos = { - k: v[safe_start : safe_start + self.max_seq_len] - for k, v in action_infos.items() - } - rewards = rewards[safe_start : safe_start + self.max_seq_len] - dones = dones[safe_start : safe_start + self.max_seq_len] + end = start + self.max_seq_len + nested_obs = {k: v[start : end + 1] for k, v in nested_obs.items()} + action_infos = {k: v[start:end] for k, v in action_infos.items()} + rewards = rewards[start:end] + dones = dones[start:end] return dict(nested_obs), action_infos, rewards, dones - def random_sample(self): - filename = random.choice(self.filenames) - return self.load_filename(filename) - - def __getitem__(self, i) -> Tuple[ - Dict[str, list[np.ndarray]], - Dict[str, list[Any]], - np.ndarray, - np.ndarray, - ]: + ############################### + ## PyTorch Dataset Interface ## + ############################### + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, i) -> Tuple[Dict, Dict, np.ndarray, np.ndarray]: return self.load_filename(self.filenames[i]) + def random_sample(self): + return self.load_filename(random.choice(self.filenames)) + + +class ParsedReplayDataset(MetamonDataset): + """Human replay dataset from jakegrigsby/metamon-parsed-replays. + + Auto-downloads from HuggingFace to {$METAMON_CACHE_DIR}/parsed-replays. + + See MetamonDataset for full argument documentation. 
+ """ + + def __init__( + self, + observation_space: ObservationSpace, + action_space: ActionSpace, + reward_function: RewardFunction, + dset_root: Optional[str] = None, + formats: Optional[List[str]] = None, + wins_losses_both: str = "both", + min_rating: Optional[int] = None, + max_rating: Optional[int] = None, + min_date: Optional[datetime] = None, + max_date: Optional[datetime] = None, + max_seq_len: Optional[int] = None, + verbose: bool = False, + shuffle: bool = False, + use_cached_filenames: bool = False, + ): + formats = formats or metamon.SUPPORTED_BATTLE_FORMATS + + if dset_root is None: + for format_name in formats: + path = download_parsed_replays(format_name) + dset_root = os.path.dirname(path) + + super().__init__( + dset_root=dset_root, + observation_space=observation_space, + action_space=action_space, + reward_function=reward_function, + formats=formats, + wins_losses_both=wins_losses_both, + min_rating=min_rating, + max_rating=max_rating, + min_date=min_date, + max_date=max_date, + max_seq_len=max_seq_len, + verbose=verbose, + shuffle=shuffle, + use_cached_filenames=use_cached_filenames, + ) + + +class SelfPlayDataset(MetamonDataset): + """Self-play dataset from jakegrigsby/metamon-parsed-pile. + + Auto-downloads from HuggingFace to {$METAMON_CACHE_DIR}/self-play/{subset}. + + Args: + subset: Which self-play subset to load: + - "pac-base": 11M trajectories from PokéAgent Challenge training + - "pac-exploratory": 7M trajectories from higher-temperature sampling. + formats: Defaults to SELF_PLAY_FORMATS (gen1-4ou, gen9ou). + + See MetamonDataset for remaining argument documentation. + """ + + def __init__( + self, + subset: str, + observation_space: ObservationSpace, + action_space: ActionSpace, + reward_function: RewardFunction, + formats: Optional[List[str]] = None, + wins_losses_both: str = "both", + min_date: Optional[datetime] = None, + max_date: Optional[datetime] = None, + max_seq_len: Optional[int] = None, + verbose: bool = False, + shuffle: bool = False, + use_cached_filenames: bool = False, + ): + if subset not in SELF_PLAY_SUBSETS: + raise ValueError( + f"Invalid subset: {subset}. 
Must be one of {SELF_PLAY_SUBSETS}" + ) + + self.subset = subset + formats = formats or SELF_PLAY_FORMATS + + # Download tar files (without extracting) + for format_name in formats: + download_self_play_data(subset, format_name, extract=False) + dset_root = os.path.join(METAMON_CACHE_DIR, "self-play", subset) + + super().__init__( + dset_root=dset_root, + observation_space=observation_space, + action_space=action_space, + reward_function=reward_function, + formats=formats, + wins_losses_both=wins_losses_both, + min_date=min_date, + max_date=max_date, + max_seq_len=max_seq_len, + verbose=verbose, + shuffle=shuffle, + use_cached_filenames=use_cached_filenames, + ) + if __name__ == "__main__": from argparse import ArgumentParser @@ -363,6 +628,7 @@ def __getitem__(self, i) -> Tuple[ formats=args.formats, verbose=True, shuffle=True, + use_cached_filenames=True, ) for i in tqdm.tqdm(range(len(dset))): obs, actions, rewards, dones = dset[i] diff --git a/metamon/env/metamon_player.py b/metamon/env/metamon_player.py index 02ac45f51..b8f962e78 100644 --- a/metamon/env/metamon_player.py +++ b/metamon/env/metamon_player.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import orjson @@ -8,9 +8,24 @@ from metamon.env.metamon_battle import MetamonBackendBattle, PokeAgentBackendBattle from metamon.backend.showdown_dex import Dex +from metamon.backend.replay_parser.str_parsing import pokemon_name, move_name +from metamon.interface import UniversalPokemon class MetamonPlayer(Player): + """Extended Player with optional team preview prediction model.""" + + def __init__(self, *args, team_preview_model=None, **kwargs): + """ + Initialize MetamonPlayer. + + Args: + team_preview_model: Optional TeamPreviewModel to use for predicting leads. + If None, falls back to random team preview selection. + *args, **kwargs: Arguments passed to Player.__init__ + """ + super().__init__(*args, **kwargs) + self.team_preview_model = team_preview_model def create_metamon_battle(self, battle_tag: str) -> MetamonBackendBattle: return MetamonBackendBattle( @@ -176,6 +191,86 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): elif split_message[1] == "uhtml" and split_message[2] == "otsrequest": await self._handle_ots_request(battle.battle_tag) + def teampreview(self, battle: AbstractBattle) -> str: + """ + Returns a teampreview order for the given battle. + + If a team_preview_model is provided, uses it to predict the best lead. + Otherwise, falls back to random selection. + + Args: + battle: The battle in team preview + + Returns: + Team order string in format "/team 3461..." where first pokemon is the lead + """ + if self.team_preview_model is None: + # fallback to random if no model provided + return self.random_teampreview(battle) + + # fallback to random for non-team-preview (or untrained teampreview) formats + battle_format = self._format.replace("-", "").lower() # e.g., "gen9ou" + if battle_format not in self.team_preview_model.trained_formats: + self.logger.warning( + f"Battle format {battle_format} not in trained formats {self.team_preview_model.trained_formats}. " + f"Falling back to random." + ) + return self.random_teampreview(battle) + + team_list = list(battle.team.values()) + opponent_list = list(battle.opponent_team.values()) + if len(team_list) != 6 or len(opponent_list) != 6: + self.logger.warning( + f"Invalid team sizes: our={len(team_list)}, opponent={len(opponent_list)}. " + f"Falling back to random." 
+ ) + return self.random_teampreview(battle) + + # build team preview input + our_team_names = [pokemon_name(p.species) for p in team_list] + our_team_moves = [ + [move_name(m.id) for m in p.moves.values()] if p.moves else [] + for p in team_list + ] + our_team_abilities = [ + UniversalPokemon.universal_abilities(p.ability) for p in team_list + ] + our_team_items = [UniversalPokemon.universal_items(p.item) for p in team_list] + opponent_team_names = [pokemon_name(p.species) for p in opponent_list] + + # team preview inference + predicted_lead_name, probs = self.team_preview_model.predict_lead( + our_team=our_team_names, + our_team_moves=our_team_moves, + our_team_abilities=our_team_abilities, + our_team_items=our_team_items, + opponent_team=opponent_team_names, + battle_format=battle_format, + ) + + # format team preview prediction output to showdown command + lead_position = None + for i, pokemon in enumerate(team_list): + if pokemon_name(pokemon.species) == predicted_lead_name: + lead_position = i + 1 # 1-indexed + break + if lead_position is None: + self.logger.warning( + f"Could not find predicted lead {predicted_lead_name} in team, falling back to random" + ) + return self.random_teampreview(battle) + members = [lead_position] + for i in range(1, len(team_list) + 1): + if i != lead_position: + members.append(i) + team_order = "/team " + "".join([str(c) for c in members]) + self.logger.warning( + f"Team preview prediction: leading with {predicted_lead_name} (position {lead_position}), " + f"probs: {probs.cpu().numpy()}" + ) + + return team_order + @staticmethod def choose_random_move(battle: MetamonBackendBattle): # default version demands built-in Battle/DoubleBattle types diff --git a/metamon/env/wrappers.py b/metamon/env/wrappers.py index 190c6bad4..e2419eb3b 100644 --- a/metamon/env/wrappers.py +++ b/metamon/env/wrappers.py @@ -3,7 +3,6 @@ import copy import json import warnings -import warnings from datetime import datetime from typing import Optional, Type, Any, List @@ -207,6 +206,9 @@ class PokeEnvWrapper(OpenAIGymEnv): Player Username, Team File, Opponent Username, Result, Turn Count, Battle ID. battle_backend: The Showdown state parsing backend. Options are 'poke-env' or 'metamon'. + team_preview_model: Optional TeamPreviewModel to use for predicting leads during + team preview. Only works with battle_backend='metamon'. If None, uses random + team preview selection. """ _INIT_RETRIES = 250 @@ -230,7 +232,8 @@ def __init__( turn_limit: int = 1000, save_trajectories_to: Optional[str] = None, save_team_results_to: Optional[str] = None, - battle_backend: str = "poke-env", + battle_backend: str = "metamon", + team_preview_model=None, ): opponent_team_set = opponent_team_set or copy.deepcopy(player_team_set) random_username = ( @@ -290,6 +293,12 @@ def __init__( if battle_backend == "poke-env": player_class = Player + # Warn if team preview model is provided with poke-env backend + if team_preview_model is not None: + warnings.warn( + "team_preview_model is only supported with battle_backend='metamon'. " + "The model will be ignored with battle_backend='poke-env'." 
+                )
         elif battle_backend == "metamon":
             player_class = MetamonPlayer
         elif battle_backend == "pokeagent":
@@ -310,6 +319,10 @@ def __init__(
             start_challenging=start_challenging,
         )
 
+        # Set team preview model on the agent if provided and using metamon backend
+        if team_preview_model is not None and battle_backend == "metamon":
+            self.agent.team_preview_model = team_preview_model
+
     @property
     def server_configuration(self):
         return LocalhostServerConfiguration
@@ -463,6 +476,7 @@ def __init__(
         save_trajectories_to: Optional[str] = None,
         save_team_results_to: Optional[str] = None,
         battle_backend: str = "poke-env",
+        team_preview_model=None,
     ):
         super().__init__(
             battle_format=battle_format,
@@ -472,6 +486,7 @@ def __init__(
             player_team_set=team_set,
             opponent_team_set=team_set,
             opponent_type=opponent_type,
+            team_preview_model=team_preview_model,
             turn_limit=turn_limit,
             save_trajectories_to=save_trajectories_to,
             save_team_results_to=save_team_results_to,
@@ -509,6 +524,7 @@ def __init__(
         player_password: Optional[str] = None,
         battle_backend: str = "poke-env",
         print_battle_bar: bool = True,
+        team_preview_model=None,
     ):
         super().__init__(
             battle_format=battle_format,
@@ -526,6 +542,7 @@ def __init__(
             save_trajectories_to=save_trajectories_to,
             save_team_results_to=save_team_results_to,
             battle_backend=battle_backend,
+            team_preview_model=team_preview_model,
         )
         print(f"Laddering for {num_battles} battles")
         self.print_battle_bar = print_battle_bar
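For reference, a minimal sketch of how the new `team_preview_model` argument threads through the wrappers above. Only `battle_backend="metamon"` and `team_preview_model=` come from this diff; the checkpoint path, the battle format, the explicit `tokenizer`/`device` keywords, and the elided wrapper arguments are illustrative assumptions rather than prescribed usage.

```python
# Sketch only: wiring a trained team preview model into the patched env wrappers.
# Placeholder names/values: the checkpoint path and "gen9ou" are illustrative, and
# the trailing comment stands in for arguments this diff does not change.
from metamon.backend.team_preview.preview import TeamPreviewModel
from metamon.env.wrappers import PokeEnvWrapper

preview_model = TeamPreviewModel.load_from_checkpoint(
    "team_preview_ckpts/best_model.pt",  # placeholder path written by train_team_preview
    tokenizer=None,                      # None -> tokenizer is restored from the checkpoint hparams
    device="cpu",
)

env = PokeEnvWrapper(
    battle_format="gen9ou",              # placeholder: a format with team preview
    battle_backend="metamon",            # team preview is only wired through the metamon backend
    team_preview_model=preview_model,
    # ... plus the wrapper's usual team/opponent arguments, unchanged by this diff
)
```

With `battle_backend="poke-env"`, the wrapper warns and ignores the model, as shown in the hunk above.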
diff --git a/metamon/rl/configs/models/alakazam.gin b/metamon/rl/configs/models/alakazam.gin
new file mode 100644
index 000000000..c1d1c7977
--- /dev/null
+++ b/metamon/rl/configs/models/alakazam.gin
@@ -0,0 +1,57 @@
+import amago.nets.actor_critic
+import amago.nets.traj_encoders
+import amago.nets.transformer
+import amago.agent
+import amago.experiment
+
+MetamonAMAGOExperiment.agent_type = @agent.MultiTaskAgent
+MetamonAMAGOExperiment.tstep_encoder_type = @MetamonTstepEncoder
+MetamonAMAGOExperiment.traj_encoder_type = @traj_encoders.TformerTrajEncoder
+MetamonAMAGOExperiment.max_seq_len = 200
+
+# actor
+MultiTaskAgent.actor_type = @MetamonMaskedResidualActor
+MultiTaskAgent.pass_obs_keys_to_actor = ["illegal_actions"]
+MetamonMaskedActor.activation = "leaky_relu"
+MetamonMaskedActor.n_layers = 2
+MetamonMaskedActor.d_hidden = 400
+
+# critic
+MultiTaskAgent.critic_type = @actor_critic.NCriticsTwoHot
+actor_critic.NCriticsTwoHot.activation = "leaky_relu"
+actor_critic.NCriticsTwoHot.n_layers = 2
+actor_critic.NCriticsTwoHot.d_hidden = 512
+MultiTaskAgent.popart = True
+MultiTaskAgent.num_critics = 6
+actor_critic.NCriticsTwoHot.output_bins = 96
+actor_critic.NCriticsTwoHot.min_return = -100
+actor_critic.NCriticsTwoHot.max_return = 2100
+actor_critic.NCriticsTwoHot.use_symlog = False
+
+
+# local metamon architectures
+MetamonTstepEncoder.extra_emb_dim = 18
+MetamonTstepEncoder.d_model = 108
+MetamonTstepEncoder.n_layers = 4
+MetamonTstepEncoder.n_heads = 6
+MetamonTstepEncoder.scratch_tokens = 6
+MetamonTstepEncoder.numerical_tokens = 6
+MetamonTstepEncoder.token_mask_aug = False
+MetamonTstepEncoder.dropout = .05
+
+
+
+# amago transformer
+traj_encoders.TformerTrajEncoder.n_layers = 6
+traj_encoders.TformerTrajEncoder.n_heads = 12
+traj_encoders.TformerTrajEncoder.d_ff = 3072
+traj_encoders.TformerTrajEncoder.d_model = 768
+traj_encoders.TformerTrajEncoder.normformer_norms = True
+traj_encoders.TformerTrajEncoder.sigma_reparam = False
+traj_encoders.TformerTrajEncoder.norm = "layer"
+traj_encoders.TformerTrajEncoder.head_scaling = True
+traj_encoders.TformerTrajEncoder.activation = "leaky_relu"
+traj_encoders.TformerTrajEncoder.attention_type = @transformer.ClippedSlidingSinkAttention
+transformer.ClippedSlidingSinkAttention.window_size = 96
+transformer.ClippedSlidingSinkAttention.logit_clip = 50
+transformer.ClippedSlidingSinkAttention.sink_size = 5
\ No newline at end of file
diff --git a/metamon/rl/configs/models/alakazam2.gin b/metamon/rl/configs/models/alakazam2.gin
new file mode 100644
index 000000000..7b2f7523f
--- /dev/null
+++ b/metamon/rl/configs/models/alakazam2.gin
@@ -0,0 +1,56 @@
+import amago.nets.actor_critic
+import amago.nets.traj_encoders
+import amago.nets.transformer
+import amago.agent
+import amago.experiment
+
+MetamonAMAGOExperiment.agent_type = @agent.MultiTaskAgent
+MetamonAMAGOExperiment.tstep_encoder_type = @MetamonPerceiverTstepEncoder
+MetamonAMAGOExperiment.traj_encoder_type = @traj_encoders.TformerTrajEncoder
+MetamonAMAGOExperiment.max_seq_len = 200
+
+# actor
+MultiTaskAgent.actor_type = @MetamonMaskedResidualActor
+MultiTaskAgent.pass_obs_keys_to_actor = ["illegal_actions"]
+MetamonMaskedResidualActor.activation = "leaky_relu"
+MetamonMaskedResidualActor.feature_dim = 400
+MetamonMaskedResidualActor.residual_ff_dim = 512
+MetamonMaskedResidualActor.residual_blocks = 2
+
+# critic
+MultiTaskAgent.critic_type = @actor_critic.NCriticsTwoHot
+actor_critic.NCriticsTwoHot.activation = "leaky_relu" +actor_critic.NCriticsTwoHot.n_layers = 2 +actor_critic.NCriticsTwoHot.d_hidden = 512 +MultiTaskAgent.popart = True +MultiTaskAgent.num_critics = 6 +actor_critic.NCriticsTwoHot.output_bins = 96 +actor_critic.NCriticsTwoHot.min_return = -100 +actor_critic.NCriticsTwoHot.max_return = 2100 +actor_critic.NCriticsTwoHot.use_symlog = False + + +# Perceiver variant (optional: switch tstep_encoder_type below) +MetamonPerceiverTstepEncoder.extra_emb_dim = 18 +MetamonPerceiverTstepEncoder.d_model = 108 +MetamonPerceiverTstepEncoder.n_layers = 5 +MetamonPerceiverTstepEncoder.n_heads = 6 +MetamonPerceiverTstepEncoder.latent_tokens = 8 +MetamonPerceiverTstepEncoder.numerical_tokens = 6 +MetamonPerceiverTstepEncoder.token_mask_aug = False +MetamonPerceiverTstepEncoder.dropout = .05 +MetamonPerceiverTstepEncoder.max_tokens_per_turn = 128 + + +# amago transformer +traj_encoders.TformerTrajEncoder.n_layers = 6 +traj_encoders.TformerTrajEncoder.n_heads = 12 +traj_encoders.TformerTrajEncoder.d_ff = 3072 +traj_encoders.TformerTrajEncoder.d_model = 768 +traj_encoders.TformerTrajEncoder.normformer_norms = True +traj_encoders.TformerTrajEncoder.sigma_reparam = True +traj_encoders.TformerTrajEncoder.norm = "layer" +traj_encoders.TformerTrajEncoder.head_scaling = True +traj_encoders.TformerTrajEncoder.activation = "leaky_relu" +traj_encoders.TformerTrajEncoder.attention_type = @transformer.FlashAttention +transformer.FlashAttention.window_size = (96, 0) diff --git a/metamon/rl/configs/models/alakazam4.gin b/metamon/rl/configs/models/alakazam4.gin new file mode 100644 index 000000000..7f9b95e8b --- /dev/null +++ b/metamon/rl/configs/models/alakazam4.gin @@ -0,0 +1,56 @@ +import amago.nets.actor_critic +import amago.nets.traj_encoders +import amago.nets.transformer +import amago.agent +import amago.experiment + +MetamonAMAGOExperiment.agent_type = @agent.MultiTaskAgent +MetamonAMAGOExperiment.tstep_encoder_type = @MetamonPerceiverTstepEncoder +MetamonAMAGOExperiment.traj_encoder_type = @traj_encoders.TformerTrajEncoder +MetamonAMAGOExperiment.max_seq_len = 128 + +# actor +MultiTaskAgent.actor_type = @MetamonMaskedResidualActor +MultiTaskAgent.pass_obs_keys_to_actor = ["illegal_actions"] +MetamonMaskedResidualActor.activation = "leaky_relu" +MetamonMaskedResidualActor.feature_dim = 450 +MetamonMaskedResidualActor.residual_ff_dim = 512 +MetamonMaskedResidualActor.residual_blocks = 2 + +# critic +MultiTaskAgent.critic_type = @actor_critic.NCriticsTwoHot +actor_critic.NCriticsTwoHot.activation = "leaky_relu" +actor_critic.NCriticsTwoHot.n_layers = 2 +actor_critic.NCriticsTwoHot.d_hidden = 512 +MultiTaskAgent.popart = True +MultiTaskAgent.num_critics = 6 +actor_critic.NCriticsTwoHot.output_bins = 96 +actor_critic.NCriticsTwoHot.min_return = -100 +actor_critic.NCriticsTwoHot.max_return = 2100 +actor_critic.NCriticsTwoHot.use_symlog = False + + +# Perceiver variant (optional: switch tstep_encoder_type below) +MetamonPerceiverTstepEncoder.extra_emb_dim = 18 +MetamonPerceiverTstepEncoder.d_model = 108 +MetamonPerceiverTstepEncoder.n_layers = 8 +MetamonPerceiverTstepEncoder.n_heads = 6 +MetamonPerceiverTstepEncoder.latent_tokens = 8 +MetamonPerceiverTstepEncoder.numerical_tokens = 6 +MetamonPerceiverTstepEncoder.token_mask_aug = False +MetamonPerceiverTstepEncoder.dropout = .05 +MetamonPerceiverTstepEncoder.max_tokens_per_turn = 128 + + +# amago transformer +traj_encoders.TformerTrajEncoder.n_layers = 8 +traj_encoders.TformerTrajEncoder.n_heads = 12 
+traj_encoders.TformerTrajEncoder.d_ff = 3072 +traj_encoders.TformerTrajEncoder.d_model = 768 +traj_encoders.TformerTrajEncoder.normformer_norms = True +traj_encoders.TformerTrajEncoder.sigma_reparam = True +traj_encoders.TformerTrajEncoder.norm = "layer" +traj_encoders.TformerTrajEncoder.head_scaling = True +traj_encoders.TformerTrajEncoder.activation = "leaky_relu" +traj_encoders.TformerTrajEncoder.attention_type = @transformer.FlashAttention +transformer.FlashAttention.window_size = (96, 0) diff --git a/metamon/rl/configs/models/superkazam.gin b/metamon/rl/configs/models/superkazam.gin new file mode 100644 index 000000000..67f7f2998 --- /dev/null +++ b/metamon/rl/configs/models/superkazam.gin @@ -0,0 +1,56 @@ +import amago.nets.actor_critic +import amago.nets.traj_encoders +import amago.nets.transformer +import amago.agent +import amago.experiment + +MetamonAMAGOExperiment.agent_type = @agent.MultiTaskAgent +MetamonAMAGOExperiment.tstep_encoder_type = @MetamonPerceiverTstepEncoder +MetamonAMAGOExperiment.traj_encoder_type = @traj_encoders.TformerTrajEncoder +MetamonAMAGOExperiment.max_seq_len = 128 + +# actor +MultiTaskAgent.actor_type = @MetamonMaskedResidualActor +MultiTaskAgent.pass_obs_keys_to_actor = ["illegal_actions"] +MetamonMaskedResidualActor.activation = "leaky_relu" +MetamonMaskedResidualActor.feature_dim = 500 +MetamonMaskedResidualActor.residual_ff_dim = 800 +MetamonMaskedResidualActor.residual_blocks = 3 + +# critic +MultiTaskAgent.critic_type = @actor_critic.NCriticsTwoHot +actor_critic.NCriticsTwoHot.activation = "leaky_relu" +actor_critic.NCriticsTwoHot.n_layers = 3 +actor_critic.NCriticsTwoHot.d_hidden = 700 +MultiTaskAgent.popart = True +MultiTaskAgent.num_critics = 6 +actor_critic.NCriticsTwoHot.output_bins = 96 +actor_critic.NCriticsTwoHot.min_return = -100 +actor_critic.NCriticsTwoHot.max_return = 2100 +actor_critic.NCriticsTwoHot.use_symlog = False + + +# Perceiver variant (optional: switch tstep_encoder_type below) +MetamonPerceiverTstepEncoder.extra_emb_dim = 18 +MetamonPerceiverTstepEncoder.d_model = 168 +MetamonPerceiverTstepEncoder.n_layers = 10 +MetamonPerceiverTstepEncoder.n_heads = 8 +MetamonPerceiverTstepEncoder.latent_tokens = 8 +MetamonPerceiverTstepEncoder.numerical_tokens = 6 +MetamonPerceiverTstepEncoder.token_mask_aug = True +MetamonPerceiverTstepEncoder.dropout = .08 +MetamonPerceiverTstepEncoder.max_tokens_per_turn = 128 + + +# amago transformer +traj_encoders.TformerTrajEncoder.n_layers = 10 +traj_encoders.TformerTrajEncoder.n_heads = 12 +traj_encoders.TformerTrajEncoder.d_ff = 3600 +traj_encoders.TformerTrajEncoder.d_model = 900 +traj_encoders.TformerTrajEncoder.normformer_norms = True +traj_encoders.TformerTrajEncoder.sigma_reparam = True +traj_encoders.TformerTrajEncoder.norm = "layer" +traj_encoders.TformerTrajEncoder.head_scaling = True +traj_encoders.TformerTrajEncoder.activation = "leaky_relu" +traj_encoders.TformerTrajEncoder.attention_type = @transformer.FlashAttention +transformer.FlashAttention.window_size = (96, 0) diff --git a/metamon/rl/configs/training/alakazam.gin b/metamon/rl/configs/training/alakazam.gin new file mode 100644 index 000000000..929a28b88 --- /dev/null +++ b/metamon/rl/configs/training/alakazam.gin @@ -0,0 +1,31 @@ +import amago.agent + +agent.Agent.reward_multiplier = 10. +agent.MultiTaskAgent.reward_multiplier = 10. 
+ +agent.Agent.tau = .004 +agent.MultiTaskAgent.tau = .004 + +agent.Agent.num_actions_for_value_in_critic_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_critic_loss = 4 + +agent.Agent.num_actions_for_value_in_actor_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_actor_loss = 4 + +agent.Agent.online_coeff = 0.0 +agent.MultiTaskAgent.online_coeff = 0.0 +agent.Agent.offline_coeff = 1.0 +agent.MultiTaskAgent.offline_coeff = 1.0 +agent.Agent.fbc_filter_func = @agent.leaky_relu_filter +agent.MultiTaskAgent.fbc_filter_func = @agent.leaky_relu_filter +agent.leaky_relu_filter.beta = .5 +agent.leaky_relu_filter.tau = 1e-2 +agent.leaky_relu_filter.neg_slope = .05 +agent.leaky_relu_filter.clip_weights_low = 1e-3 +agent.leaky_relu_filter.clip_weights_high = 15. + +MetamonAMAGOExperiment.l2_coeff = 1e-4 +MetamonAMAGOExperiment.learning_rate = 1.25e-4 +MetamonAMAGOExperiment.grad_clip = 1.5 +MetamonAMAGOExperiment.critic_loss_weight = 12.5 +MetamonAMAGOExperiment.lr_warmup_steps = 1250 diff --git a/metamon/rl/configs/training/alakazam2.gin b/metamon/rl/configs/training/alakazam2.gin new file mode 100644 index 000000000..96be2ce8c --- /dev/null +++ b/metamon/rl/configs/training/alakazam2.gin @@ -0,0 +1,31 @@ +import amago.agent + +agent.Agent.reward_multiplier = 10. +agent.MultiTaskAgent.reward_multiplier = 10. + +agent.Agent.tau = .008 +agent.MultiTaskAgent.tau = .008 + +agent.Agent.num_actions_for_value_in_critic_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_critic_loss = 4 + +agent.Agent.num_actions_for_value_in_actor_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_actor_loss = 4 + +agent.Agent.online_coeff = 0.1 +agent.MultiTaskAgent.online_coeff = 0.1 +agent.Agent.offline_coeff = 1.0 +agent.MultiTaskAgent.offline_coeff = 1.0 +agent.Agent.fbc_filter_func = @agent.leaky_relu_filter +agent.MultiTaskAgent.fbc_filter_func = @agent.leaky_relu_filter +agent.leaky_relu_filter.beta = .4 +agent.leaky_relu_filter.tau = 1e-3 +agent.leaky_relu_filter.neg_slope = .05 +agent.leaky_relu_filter.clip_weights_low = 1e-3 +agent.leaky_relu_filter.clip_weights_high = 15. + +MetamonAMAGOExperiment.l2_coeff = 1e-4 +MetamonAMAGOExperiment.learning_rate = 1.25e-4 +MetamonAMAGOExperiment.grad_clip = 1.5 +MetamonAMAGOExperiment.critic_loss_weight = 13.5 +MetamonAMAGOExperiment.lr_warmup_steps = 1500 diff --git a/metamon/rl/configs/training/alakazam3.gin b/metamon/rl/configs/training/alakazam3.gin new file mode 100644 index 000000000..11839afd2 --- /dev/null +++ b/metamon/rl/configs/training/alakazam3.gin @@ -0,0 +1,31 @@ +import amago.agent + +agent.Agent.reward_multiplier = 10. +agent.MultiTaskAgent.reward_multiplier = 10. + +agent.Agent.tau = .008 +agent.MultiTaskAgent.tau = .008 + +agent.Agent.num_actions_for_value_in_critic_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_critic_loss = 4 + +agent.Agent.num_actions_for_value_in_actor_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_actor_loss = 4 + +agent.Agent.online_coeff = 0.1 +agent.MultiTaskAgent.online_coeff = 0.2 +agent.Agent.offline_coeff = 1.0 +agent.MultiTaskAgent.offline_coeff = 1.0 +agent.Agent.fbc_filter_func = @agent.leaky_relu_filter +agent.MultiTaskAgent.fbc_filter_func = @agent.leaky_relu_filter +agent.leaky_relu_filter.beta = .4 +agent.leaky_relu_filter.tau = 1e-3 +agent.leaky_relu_filter.neg_slope = .05 +agent.leaky_relu_filter.clip_weights_low = 1e-3 +agent.leaky_relu_filter.clip_weights_high = 15. 
+ +MetamonAMAGOExperiment.l2_coeff = 1e-4 +MetamonAMAGOExperiment.learning_rate = 1.25e-4 +MetamonAMAGOExperiment.grad_clip = 1.5 +MetamonAMAGOExperiment.critic_loss_weight = 13.5 +MetamonAMAGOExperiment.lr_warmup_steps = 2000 diff --git a/metamon/rl/configs/training/kakuna.gin b/metamon/rl/configs/training/kakuna.gin new file mode 100644 index 000000000..8e1419789 --- /dev/null +++ b/metamon/rl/configs/training/kakuna.gin @@ -0,0 +1,31 @@ +import amago.agent + +agent.Agent.reward_multiplier = 10. +agent.MultiTaskAgent.reward_multiplier = 10. + +agent.Agent.tau = .008 +agent.MultiTaskAgent.tau = .008 + +agent.Agent.num_actions_for_value_in_critic_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_critic_loss = 5 + +agent.Agent.num_actions_for_value_in_actor_loss = 1 +agent.MultiTaskAgent.num_actions_for_value_in_actor_loss = 4 + +agent.Agent.online_coeff = 0.1 +agent.MultiTaskAgent.online_coeff = 0.25 +agent.Agent.offline_coeff = 1.0 +agent.MultiTaskAgent.offline_coeff = 1.0 +agent.Agent.fbc_filter_func = @agent.leaky_relu_filter +agent.MultiTaskAgent.fbc_filter_func = @agent.leaky_relu_filter +agent.leaky_relu_filter.beta = .3 +agent.leaky_relu_filter.tau = 1e-3 +agent.leaky_relu_filter.neg_slope = .05 +agent.leaky_relu_filter.clip_weights_low = 1e-3 +agent.leaky_relu_filter.clip_weights_high = 15. + +MetamonAMAGOExperiment.l2_coeff = 1e-4 +MetamonAMAGOExperiment.learning_rate = 1e-4 +MetamonAMAGOExperiment.grad_clip = 1.0 +MetamonAMAGOExperiment.critic_loss_weight = 12.5 +MetamonAMAGOExperiment.lr_warmup_steps = 10000 diff --git a/metamon/rl/evaluate.py b/metamon/rl/evaluate.py index 1c2778e06..307eb5167 100644 --- a/metamon/rl/evaluate.py +++ b/metamon/rl/evaluate.py @@ -10,6 +10,7 @@ PretrainedModel, ) from metamon.baselines import get_baseline +from metamon.backend.team_preview.preview import TeamPreviewModel from metamon.rl.metamon_to_amago import ( make_baseline_env, make_local_ladder_env, @@ -34,19 +35,23 @@ def pretrained_vs_baselines( checkpoint: Optional[int] = None, total_battles: int = 250, parallel_actors_per_baseline: int = 5, + action_temperature: float = 1.0, async_mp_context: str = "forkserver", - battle_backend: str = "poke-env", + battle_backend: str = "metamon", log_to_wandb: bool = False, save_trajectories_to: Optional[str] = None, save_team_results_to: Optional[str] = None, baselines: Optional[List[str]] = None, + team_preview_model: Optional[TeamPreviewModel] = None, ) -> Dict[str, Any]: """Evaluate a pretrained model against built-in baseline opponents. Defaults to the 6 baselines that the paper calls the "Heuristic Composite Score", but you can specify a list of any of the available baselines (see metamon.baselines.get_all_baseline_names()). 
""" - agent = pretrained_model.initialize_agent(checkpoint=checkpoint, log=log_to_wandb) + agent = pretrained_model.initialize_agent( + checkpoint=checkpoint, log=log_to_wandb, action_temperature=action_temperature + ) baselines = baselines or HEURISTIC_COMPOSITE_BASELINES agent.async_env_mp_context = async_mp_context # create envs that match the agent's observation/actions/rewards @@ -62,6 +67,7 @@ def pretrained_vs_baselines( battle_backend=battle_backend, team_set=team_set, opponent_type=get_baseline(opponent), + team_preview_model=team_preview_model, ) for opponent in baselines ] @@ -84,10 +90,14 @@ def _pretrained_on_ladder( total_battles: int, checkpoint: Optional[int], log_to_wandb: bool, + action_temperature: float = 1.0, + team_preview_model: Optional[TeamPreviewModel] = None, **ladder_kwargs, ) -> Dict[str, Any]: """Helper function for ladder-based evaluation.""" - agent = pretrained_model.initialize_agent(checkpoint=checkpoint, log=log_to_wandb) + agent = pretrained_model.initialize_agent( + checkpoint=checkpoint, log=log_to_wandb, action_temperature=action_temperature + ) agent.env_mode = "sync" agent.parallel_actors = 1 agent.verbose = False # turn off tqdm progress bar and print poke-env battle status @@ -98,6 +108,7 @@ def _pretrained_on_ladder( action_space=pretrained_model.action_space, reward_function=pretrained_model.reward_function, num_battles=total_battles, + team_preview_model=team_preview_model, **ladder_kwargs, ) @@ -117,10 +128,12 @@ def pretrained_vs_local_ladder( total_battles: int, avatar: Optional[str] = None, checkpoint: Optional[int] = None, - battle_backend: str = "poke-env", + battle_backend: str = "metamon", + action_temperature: float = 1.0, save_trajectories_to: Optional[str] = None, save_team_results_to: Optional[str] = None, log_to_wandb: bool = False, + team_preview_model: Optional[TeamPreviewModel] = None, ) -> Dict[str, Any]: """Evaluate a pretrained model on the ladder of your Local Showdown server. @@ -140,6 +153,8 @@ def pretrained_vs_local_ladder( total_battles=total_battles, checkpoint=checkpoint, log_to_wandb=log_to_wandb, + action_temperature=action_temperature, + team_preview_model=team_preview_model, player_username=username, player_avatar=avatar, player_team_set=team_set, @@ -159,10 +174,12 @@ def pretrained_vs_pokeagent_ladder( total_battles: int, avatar: Optional[str] = None, checkpoint: Optional[int] = None, - battle_backend: str = "poke-env", + battle_backend: str = "metamon", + action_temperature: float = 1.0, save_trajectories_to: Optional[str] = None, save_team_results_to: Optional[str] = None, log_to_wandb: bool = False, + team_preview_model: Optional[TeamPreviewModel] = None, ) -> Dict[str, Any]: """Evaluate a pretrained model on the PokéAgent Challenge ladder. 
@@ -182,6 +199,8 @@ def pretrained_vs_pokeagent_ladder( total_battles=total_battles, checkpoint=checkpoint, log_to_wandb=log_to_wandb, + action_temperature=action_temperature, + team_preview_model=team_preview_model, player_username=username, player_password=password, player_avatar=avatar, @@ -239,15 +258,30 @@ def _run_default_evaluation(args) -> Dict[str, List[Dict[str, Any]]]: all_results = collections.defaultdict(list) backend = args.battle_backend or pretrained_model.battle_backend - # Print a pretty header with agent, preferred backend, and active backend - pre_header = "=" * 60 - print(pre_header) - print(" Metamon RL Agent Evaluation".center(60)) - print(pre_header) - print(f" Pretrained Agent : {pretrained_model.model_name}") - print(f" Preferred Backend: {pretrained_model.battle_backend}") - print(f" Active Backend : {backend}") - print(pre_header) + # Load team preview model if checkpoint provided + team_preview_model = None + if args.team_preview_checkpoint is not None: + team_preview_model = TeamPreviewModel.load_from_checkpoint( + checkpoint_path=args.team_preview_checkpoint, + device="cuda" if backend == "metamon" else "cpu", + use_argmax=args.team_preview_use_argmax, + ) + print(f"Team preview model loaded from: {args.team_preview_checkpoint}") + + if backend != "metamon": + print( + "WARNING: team_preview_model only works with --battle_backend metamon. It will be ignored." + ) + team_preview_model = None + + # Print banner and evaluation info + metamon.print_banner() + print(f" Agent: {pretrained_model.model_name} | Backend: {backend}", end="") + if team_preview_model is not None: + print(f" | Team Preview: ✓") + else: + print() + print() for gen in args.gens: for format_name in args.formats: @@ -269,8 +303,10 @@ def _run_default_evaluation(args) -> Dict[str, List[Dict[str, Any]]]: "checkpoint": checkpoint, "battle_backend": backend, "save_trajectories_to": args.save_trajectories_to, + "action_temperature": args.temperature, "save_team_results_to": args.save_team_results_to, "log_to_wandb": args.log_to_wandb, + "team_preview_model": team_preview_model, } eval_function = _get_default_eval(args, eval_kwargs) results = eval_function(**eval_kwargs) @@ -383,7 +419,30 @@ def add_cli(parser): action="store_true", help="Log results to Weights & Biases.", ) - + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature for temperature-based sampling. Higher temperature means more exploration.", + ) + parser.add_argument( + "--team_preview_checkpoint", + type=str, + default=None, + help=( + "Path to a team preview model checkpoint (e.g., './checkpoints/best_model.pt'). " + "If provided, the model will predict which pokemon to lead with during team preview. " + "Only works with --battle_backend metamon." + ), + ) + parser.add_argument( + "--team_preview_use_argmax", + action="store_true", + help=( + "If set, use argmax for team preview lead selection instead of sampling from the distribution. " + "Only applies when --team_preview_checkpoint is provided." + ), + ) return parser diff --git a/metamon/rl/finetune_from_hf.py b/metamon/rl/finetune_from_hf.py index 6a1aa07b8..4c580ae0e 100644 --- a/metamon/rl/finetune_from_hf.py +++ b/metamon/rl/finetune_from_hf.py @@ -1,5 +1,6 @@ import wandb +import metamon from metamon.rl.train import ( create_offline_dataset, create_offline_rl_trainer, @@ -84,17 +85,42 @@ def add_cli(parser): default=None, help="Path to the parsed replay directory. 
Defaults to the official huggingface version.", ) + parser.add_argument( + "--replay_weight", + type=float, + default=1.0, + help="Sampling weight for the human parsed replay dataset (metamon-parsed-replays). Will be renormalized with other weights.", + ) + parser.add_argument( + "--self_play_subsets", + type=str, + nargs="+", + default=None, + help="Official self-play dataset (metamon-parsed-pile) subsets to include (e.g., 'pac-base', 'pac-exploratory'). If not provided, self-play data is not used.", + ) + parser.add_argument( + "--self_play_weights", + type=float, + nargs="+", + default=None, + help="Sampling weights for each self-play subset. Must match length of --self_play_subsets.", + ) parser.add_argument( "--custom_replay_dir", type=str, default=None, - help="Path to an optional second parsed replay dataset (e.g., self-play data you've collected).", + help="Path to an optional custom parsed replay dataset (e.g., additional self-play data you've collected).", ) parser.add_argument( - "--custom_replay_sample_weight", + "--custom_replay_weight", type=float, default=0.25, - help="[0, 1] portion of each batch to sample from the custom dataset (if provided).", + help="Sampling weight for the custom dataset (if provided). Will be renormalized with other weights.", + ) + parser.add_argument( + "--use_cached_filenames", + action="store_true", + help="Use cached filename index for faster startup when reusing an identical training set.", ) parser.add_argument( "--async_env_mp_context", @@ -107,7 +133,7 @@ def add_cli(parser): type=int, nargs="*", default=[1, 2, 3, 4, 9], - help="Generations (of OU) to play against heuristics between training epochs. Win rates usually saturate at 90\%+ quickly, so this is mostly a sanity-check. Reduce gens to save time on launch! Use `--eval_gens` (no arguments) to disable evaluation.", + help="Generations (of OU) to play against heuristics between training epochs. Win rates usually saturate at 90%%+ quickly, so this is mostly a sanity-check. Reduce gens to save time on launch! Use `--eval_gens` (no arguments) to disable evaluation.", ) parser.add_argument( "--formats", @@ -131,6 +157,10 @@ def add_cli(parser): add_cli(parser) args = parser.parse_args() + metamon.print_banner() + print(f" Finetuning: {args.finetune_from_model} → {args.run_name}") + print() + pretrained = get_pretrained_model(args.finetune_from_model) # create the dataset we'll be finetuning on amago_dataset = create_offline_dataset( @@ -138,9 +168,13 @@ def add_cli(parser): action_space=pretrained.action_space, reward_function=pretrained.reward_function, parsed_replay_dir=args.parsed_replay_dir, + replay_weight=args.replay_weight, + self_play_subsets=args.self_play_subsets, + self_play_weights=args.self_play_weights, custom_replay_dir=args.custom_replay_dir, - custom_replay_sample_weight=args.custom_replay_sample_weight, + custom_replay_weight=args.custom_replay_weight, formats=args.formats, + use_cached_filenames=args.use_cached_filenames, ) if args.reward_function is not None: # custom reward function diff --git a/metamon/rl/metamon_to_amago.py b/metamon/rl/metamon_to_amago.py index 94a7fafe0..965019058 100644 --- a/metamon/rl/metamon_to_amago.py +++ b/metamon/rl/metamon_to_amago.py @@ -239,6 +239,59 @@ def env_name(self): return f"{self.env.metamon_battle_format}_vs_{self.env.metamon_opponent_name}" +@gin.configurable +class MetamonDiscrete(amago.nets.policy_dists.Discrete): + """Discrete policy with temperature-based sampling. 
+ + Extends AMAGO's Discrete PolicyOutput to add temperature scaling to the logits. + High-temperature sampling is a better alternative to epsilon-greedy exploration + for self-play in metamon due to illegal action masking. + + Args: + d_action: Dimension of the action space. + temperature: Temperature for scaling logits. Default is 1.0 (no scaling). + clip_prob_low: Clips action probabilities to this value before + renormalizing. Default is 0.001. + clip_prob_high: Clips action probabilities to this value before + renormalizing. Default is 0.99. + """ + + def __init__( + self, + d_action: int, + clip_prob_low: float = 0.001, + clip_prob_high: float = 0.99, + temperature: float = 1.0, + ): + super().__init__( + d_action=d_action, + clip_prob_low=clip_prob_low, + clip_prob_high=clip_prob_high, + ) + self.temperature = temperature + + def forward( + self, vec: torch.Tensor, log_dict: Optional[dict] = None + ) -> amago.nets.policy_dists._Categorical: + scaled_logits = vec / self.temperature + + dist = amago.nets.policy_dists._Categorical(logits=scaled_logits) + probs = dist.probs + clip_probs = probs.clamp(self.clip_prob_low, self.clip_prob_high) + safe_probs = clip_probs / clip_probs.sum(-1, keepdims=True).detach() + safe_dist = amago.nets.policy_dists._Categorical(probs=safe_probs) + + if log_dict is not None: + from amago.nets.utils import add_activation_log + + add_activation_log("MetamonDiscrete-probs", probs, log_dict) + add_activation_log( + "MetamonDiscrete-temperature", torch.tensor(self.temperature), log_dict + ) + + return safe_dist + + @gin.configurable class MetamonMaskedActor(amago.nets.actor_critic.Actor): """ @@ -273,6 +326,7 @@ def __init__( activation=activation, dropout_p=dropout_p, continuous_dist_type=continuous_dist_type, + discrete_dist_type=MetamonDiscrete, ) self.mask_illegal_actions = mask_illegal_actions @@ -335,6 +389,7 @@ def __init__( normalization=normalization, dropout_p=dropout_p, continuous_dist_type=continuous_dist_type, + discrete_dist_type=MetamonDiscrete, ) self.mask_illegal_actions = mask_illegal_actions @@ -385,7 +440,7 @@ def env_name(self): return f"psladder_{self.env.env.username}" -def unknown_token_mask(tokens, skip_prob: float = 0.2, batch_max_prob: float = 0.33): +def unknown_token_mask(tokens, skip_prob: float = 0.5, batch_max_prob: float = 0.2): """Randomly set entries in the text component of the observation space to UNKNOWN_TOKEN. Args: @@ -614,6 +669,9 @@ class MetamonAMAGOExperiment(amago.Experiment): Adds actions masking to the main AMAGO experiment, and leaves room for further tweaks. """ + def start(self): + super().start() + def init_envs(self): out = super().init_envs() amago.utils.call_async_env(self.val_envs, "take_long_break") diff --git a/metamon/rl/pretrained.py b/metamon/rl/pretrained.py index 11c0ff533..2970a1157 100644 --- a/metamon/rl/pretrained.py +++ b/metamon/rl/pretrained.py @@ -17,6 +17,7 @@ def red_warning(msg: str): import metamon from metamon.rl.metamon_to_amago import ( make_placeholder_experiment, + MetamonDiscrete, ) from metamon.interface import ( ObservationSpace, @@ -110,6 +111,7 @@ class PretrainedModel: 'poke-env' is deprecated; maintains the original paper's models. 'metamon' is the lateset version 'pokeagent' maintains policies trained (and used as the organizer baselines) during the PokéAgent Challenge + action_temperature: Temperature for temperature-based sampling. Higher temperature means more exploration. Default is 1.0 (no scaling). 
""" HF_REPO_ID = "jakegrigsby/metamon" @@ -128,7 +130,7 @@ def __init__( hf_cache_dir: Optional[str] = None, default_checkpoint: int = 40, gin_overrides: Optional[dict] = None, - battle_backend: str = "poke-env", + battle_backend: str = "metamon", ): self.model_name = model_name self.model_gin_config = model_gin_config @@ -193,11 +195,14 @@ def get_path_to_checkpoint(self, checkpoint: int) -> str: return checkpoint_path def initialize_agent( - self, checkpoint: Optional[int] = None, log: bool = False + self, + checkpoint: Optional[int] = None, + log: bool = False, + action_temperature: float = 1.0, ) -> amago.Experiment: # use the base config and the gin file to configure the model amago.cli_utils.use_config( - self.base_config, + self.base_config | {"MetamonDiscrete.temperature": action_temperature}, [self.model_gin_config_path, self.train_gin_config_path], finalize=False, ) @@ -576,6 +581,35 @@ def __init__(self): ) +@pretrained_model() +class Minikazam(PretrainedModel): + """ + An attempt to create an affordable starting point for finetuning. + + Small RNN trained on parsed-replays v4 and ~5M self-play battles. + + Detailed evals compiled here: https://docs.google.com/spreadsheets/d/1GU7-Jh0MkIKWhiS1WNQiPfv49WIajanUF4MjKeghMAc/edit?usp=sharing + """ + + def __init__(self): + super().__init__( + model_name="minikazam", + model_gin_config="minikazam.gin", + train_gin_config="binary_rl.gin", + default_checkpoint=40, + action_space=get_action_space("DefaultActionSpace"), + observation_space=get_observation_space("PAC-OpponentMoveObservationSpace"), + reward_function=get_reward_function("AggressiveShapedReward"), + tokenizer=get_tokenizer("DefaultObservationSpace-v1"), + battle_backend="pokeagent", + gin_overrides={ + "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer( + "DefaultObservationSpace-v1" + ), + }, + ) + + @pretrained_model() class Abra(PretrainedModel): """ @@ -608,67 +642,217 @@ def __init__(self): @pretrained_model() -class Minikazam(PretrainedModel): +class Kadabra(PretrainedModel): """ - An attempt to create an affordable starting point for finetuning. + A second attempt at self-play on gens1-4 & 9 that was featured in the PokéAgent Challenge. - Small RNN trained on parsed-replays v4 and ~5M self-play battles. - - Detailed evals compiled here: https://docs.google.com/spreadsheets/d/1GU7-Jh0MkIKWhiS1WNQiPfv49WIajanUF4MjKeghMAc/edit?usp=sharing + This policy held the top organizer gen9ou rank for most of the "practice ladder" period in Summer 2025. """ def __init__(self): super().__init__( - model_name="minikazam", - model_gin_config="minikazam.gin", + model_name="kadabra", + model_gin_config="medium_multitaskagent.gin", train_gin_config="binary_rl.gin", - default_checkpoint=40, + default_checkpoint=46, + action_space=get_action_space("DefaultActionSpace"), + observation_space=get_observation_space("TeamPreviewObservationSpace"), + tokenizer=get_tokenizer("DefaultObservationSpace-v1"), + battle_backend="pokeagent", + gin_overrides={ + "amago.nets.traj_encoders.TformerTrajEncoder.attention_type": amago.nets.transformer.FlashAttention, + "amago.nets.transformer.FlashAttention.window_size": (32, 0), + }, + ) + + +@pretrained_model() +class Kadabra2(PretrainedModel): + """ + A third attempt at self-play on gens1-4 & 9 that was featured in the PokéAgent Challenge. + + Confusingly, this policy played under the username "PAC-MM-Alakazam" for most of the challange, and held + the top organizer gen9ou rank at the end of the Summer 2025 practice ladder. 
Checkpoints have been renamed + for public release such that the best policy with this architecture gets to be "Alakazam" :) + + This marks the first time where performance of policies *trained on Gen9OU* roughly match the paper policies in Gens1-4; + all policies below can play Gen9OU without sacrificing significant performance in Gens1-4. + """ + + def __init__(self): + super().__init__( + model_name="kadabra2", + model_gin_config="alakazam2.gin", + train_gin_config="alakazam2.gin", + default_checkpoint=44, action_space=get_action_space("DefaultActionSpace"), observation_space=get_observation_space("PAC-OpponentMoveObservationSpace"), reward_function=get_reward_function("AggressiveShapedReward"), tokenizer=get_tokenizer("DefaultObservationSpace-v1"), battle_backend="pokeagent", + gin_overrides={ + "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer( + "DefaultObservationSpace-v1" + ), + }, ) - @property - def base_config(self): - return { - "MetamonPerceiverTstepEncoder.tokenizer": self.tokenizer, - "amago.nets.transformer.SigmaReparam.fast_init": True, - } +@pretrained_model() +class Kadabra3(PretrainedModel): + """ + A fourth attempt at self-play on gens1-4 & 9 that was featured in the PokéAgent Challenge. -############################################## -## 100% Correct PokéAgent Challenge Aliases ## -############################################## + This policy played under the username "PAC-MM-Wildcard" or "PAC-MM-Mystery" during the qualification period. + If it had been pubilcly available, it would have qualified as the #2 seed in Gen1OU and #3 seed in Gen9OU. + """ -# These policies use the unpatched observation space. They will play (slightly) better -# than the main version when the backend is correctly specified as "pokeagent", because -# they can see the tera types that appear when the agent or opponent uses tera in battle. + def __init__(self): + super().__init__( + model_name="kadabra3", + model_gin_config="alakazam2.gin", + train_gin_config="alakazam3.gin", + default_checkpoint=20, + action_space=get_action_space("DefaultActionSpace"), + observation_space=get_observation_space("PAC-OpponentMoveObservationSpace"), + reward_function=get_reward_function("AggressiveShapedReward"), + tokenizer=get_tokenizer("DefaultObservationSpace-v1"), + battle_backend="pokeagent", + gin_overrides={ + "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer( + "DefaultObservationSpace-v1" + ), + }, + ) + + +@pretrained_model() +class Kadabra4(PretrainedModel): + """ + A fifth attempt at self-play on gens1-4 & 9 that was featured in the PokéAgent Challenge. + + The final PokéAgent Challenge era dataset was 11.6M self-play battles + parsed-replays-v4. + + This policy played under the username "PAC-MM-Mystery" or "PAC-MM-Wildcard" during the qualification period. + If it had been pubilcally available, it would have qualified as the #1 seed in Gen1OU and #2 seed in Gen9OU + (behind FoulPlay). + + Most of the performance gains from Kadabra2 --> Kadabra4 are seen in diverse team evaluations (i.e., "modern_replays_v2" TeamSet). 
+    """
+
+    def __init__(self):
+        super().__init__(
+            model_name="kadabra4",
+            model_gin_config="alakazam4.gin",
+            train_gin_config="alakazam3.gin",
+            default_checkpoint=50,
+            action_space=get_action_space("DefaultActionSpace"),
+            observation_space=get_observation_space("PAC-OpponentMoveObservationSpace"),
+            reward_function=get_reward_function("AggressiveShapedReward"),
+            tokenizer=get_tokenizer("DefaultObservationSpace-v1"),
+            battle_backend="pokeagent",
+            gin_overrides={
+                "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer(
+                    "DefaultObservationSpace-v1"
+                ),
+            },
+        )
+
+
+@pretrained_model()
+class Alakazam(PretrainedModel):
+    """
+    This policy patches a bug (https://github.com/UT-Austin-RPL/metamon/pull/54) that impacted all PokéAgent Challenge training runs.
+    We finetuned Kadabra4 on a new version of the self-play dataset that was patched to include tera types.
+    The "Kadabra*" policies now intentionally *preserve* the bug for backwards compatibility, so this policy gains a slight
+    edge when evaluated today (after the bug was patched).
+
+    This policy never appeared on the PokéAgent Challenge ladder but is called "Alakazam" because it is the last model
+    of this size (~50M params) to be trained on the PokéAgent Challenge dataset.
+    """
 
-@pretrained_model("PAC-SmallRLGen9Beta")
-class PACSmallRLGen9Beta(SmallRLGen9Beta):
     def __init__(self):
-        super().__init__()
-        self.observation_space.base_obs_space = get_observation_space(
-            "TeamPreviewObservationSpace"
+        super().__init__(
+            model_name="alakazam",
+            model_gin_config="alakazam4.gin",
+            train_gin_config="alakazam3.gin",
+            default_checkpoint=8,
+            action_space=get_action_space("DefaultActionSpace"),
+            observation_space=get_observation_space("OpponentMoveObservationSpace"),
+            reward_function=get_reward_function("AggressiveShapedReward"),
+            tokenizer=get_tokenizer("DefaultObservationSpace-v1"),
+            battle_backend="metamon",
+            gin_overrides={
+                "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer(
+                    "DefaultObservationSpace-v1"
+                ),
+            },
         )
 
-@pretrained_model("PAC-Abra")
-class PACAbra(Abra):
+@pretrained_model()
+class Superkazam(PretrainedModel):
+    """
+    Revisits the PokéAgent Challenge dataset at a model size closer to the paper's SyntheticRLV2 configuration (~140M params).
+
+    - PokéAgent Challenge self-play dataset (11.6M battles)
+    - (Human) parsed-replays-v4 (4M battles)
+
+    Evals against the most important (modern) baselines are available here: https://docs.google.com/spreadsheets/d/1lU8tQ0tnnupY28kIyK6FVtvPmxLSVT9_slLShOhRsqg/edit?usp=sharing
+    """
+
     def __init__(self):
-        super().__init__()
-        self.observation_space.base_obs_space = get_observation_space(
-            "TeamPreviewObservationSpace"
+        super().__init__(
+            model_name="superkazam",
+            model_gin_config="superkazam.gin",
+            train_gin_config="alakazam3.gin",
+            default_checkpoint=50,
+            action_space=get_action_space("DefaultActionSpace"),
+            observation_space=get_observation_space("OpponentMoveObservationSpace"),
+            reward_function=get_reward_function("AggressiveShapedReward"),
+            tokenizer=get_tokenizer("DefaultObservationSpace-v1"),
+            battle_backend="metamon",
+            gin_overrides={
+                "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer(
+                    "DefaultObservationSpace-v1"
+                ),
+            },
        )
 
-@pretrained_model("PAC-Minikazam")
-class PACMinikazam(Minikazam):
+@pretrained_model()
+class Kakuna(PretrainedModel):
+    """
+    The current best Metamon policy.
+
+    Superkazam, finetuned on a dataset of self-play battles collected at increased temperature for exploration and value learning (+7.8M battles).
+
+    After > 700 total games played over a span of a month, we estimate GXEs vs. humans (with "competitive" TeamSet) of:
+
+    gen1ou: ~82%
+    gen2ou: ~70%
+    gen3ou: ~63%
+    gen4ou: ~64%
+    gen9ou: ~71%
+
+    Evals against the most important (modern) metamon baselines are available here: https://docs.google.com/spreadsheets/d/1lU8tQ0tnnupY28kIyK6FVtvPmxLSVT9_slLShOhRsqg/edit?usp=sharing
+    """
+
     def __init__(self):
-        super().__init__()
-        self.observation_space.base_obs_space = get_observation_space(
-            "OpponentMoveObservationSpace"
+        super().__init__(
+            model_name="kakuna",
+            model_gin_config="superkazam.gin",
+            train_gin_config="kakuna.gin",
+            default_checkpoint=34,
+            action_space=get_action_space("DefaultActionSpace"),
+            observation_space=get_observation_space("OpponentMoveObservationSpace"),
+            reward_function=get_reward_function("AggressiveShapedReward"),
+            tokenizer=get_tokenizer("DefaultObservationSpace-v1"),
+            battle_backend="metamon",
+            gin_overrides={
+                "MetamonPerceiverTstepEncoder.tokenizer": get_tokenizer(
+                    "DefaultObservationSpace-v1"
+                ),
+            },
+        )
diff --git a/metamon/rl/self_play/README.md b/metamon/rl/self_play/README.md
new file mode 100644
index 000000000..75dbbc13a
--- /dev/null
+++ b/metamon/rl/self_play/README.md
@@ -0,0 +1,31 @@
+# Self-Play
+
+Utility to auto-manage a local ladder of agents for self-play data collection or bulk eval purposes.
+
+Specify the participating agents with a `.yaml` file. Here is an example:
+
+```yaml
+defaults:
+  team_set: competitive
+  battle_backend: metamon
+  # if a list, each agent launch will pick a value from the list at random
+  checkpoints: [null]
+  temperatures: [1.0]
+  num_agents: 1 # number of parallel copies to launch per agent
+
+agents:
+  # USERNAME:
+  #   model_name: SomeModel
+  #   checkpoints: [2] # override default
+  #   num_agents: 3 # will launch USERNAME-1, USERNAME-2, USERNAME-3
+
+  PAC-MM-Kadabra:
+    model_name: Kadabra
+
+  PAC-MM-SynRLV2:
+    model_name: SyntheticRLV2
+```
+
+```bash
+python launch_models.py --format gen2ou --gpus 0 1 --config earlygen_config.yaml
+```
\ No newline at end of file
diff --git a/metamon/rl/self_play/__main__.py b/metamon/rl/self_play/__main__.py
new file mode 100644
index 000000000..5ce8ae76f
--- /dev/null
+++ b/metamon/rl/self_play/__main__.py
@@ -0,0 +1,11 @@
+"""
+Run the self-play launcher as a module.
+ +Usage: + python -m metamon.rl.self_play --format gen9ou --gpus 0 1 --config metamon/rl/self_play/gen9ou_config.yaml --save_trajectories_to ./trajectories +""" + +from metamon.rl.self_play.launch_models import main + +if __name__ == "__main__": + main() diff --git a/metamon/rl/self_play/earlygen_config.yaml b/metamon/rl/self_play/earlygen_config.yaml new file mode 100644 index 000000000..efcc4f704 --- /dev/null +++ b/metamon/rl/self_play/earlygen_config.yaml @@ -0,0 +1,53 @@ +defaults: + team_set: modern_replays_v2 + battle_backend: metamon + checkpoints: [null] + temperatures: [1.0, 1.2, 1.3, 1.5, 1.75, 2.0, 2.25] + num_agents: 1 + +agents: + SynRLV0: + model_name: SyntheticRLV0 + checkpoints: [null] + num_agents: 1 + + SynRLV1: + model_name: SyntheticRLV1 + checkpoints: [null] + num_agents: 1 + + SynRLV1_PlusPlus: + model_name: SyntheticRLV1_PlusPlus + checkpoints: [null] + num_agents: 1 + + PAC-MM-SynRLV2: + model_name: SyntheticRLV2 + checkpoints: [null, 32, 36, 34] + num_agents: 5 + + PAC-MM-Kadabra: + model_name: Kadabra + + PAC-MM-SmallILFA: + model_name: SmallILFA + checkpoints: [2] + num_agents: 2 + + PAC-MM-Minikazam: + model_name: Minikazam + + PAC-MM-Alakazam: + model_name: Alakazam2 + checkpoints: [40, 48, 36, 32] + num_agents: 2 + + PAC-MM-Wildcard: + model_name: Alakazam3 + checkpoints: [10, 12, 20, 22, 14] + num_agents: 4 + + PAC-MM-Mystery: + model_name: Alakazam4 + checkpoints: [30, 40, 48, 50, 52, 32] + num_agents: 4 diff --git a/metamon/rl/self_play/launch_models.py b/metamon/rl/self_play/launch_models.py new file mode 100644 index 000000000..2e46aa399 --- /dev/null +++ b/metamon/rl/self_play/launch_models.py @@ -0,0 +1,341 @@ +import gc +import os +import random +import subprocess +import sys +import threading +import time +import yaml +from argparse import ArgumentParser +from typing import List, Dict + + +def run_username_on_gpu_continuous( + gpu_id: int, + username: str, + format_name: str, + config_path: str, + n_challenges: int = 50, + startup_delay: int = 0, + restart_delay: int = 60, + timeout: int = 2700, + save_trajectories_to: str = None, + verbose: bool = False, +): + if startup_delay > 0: + print( + f"Waiting {startup_delay} seconds before starting {username} on GPU {gpu_id}..." + ) + time.sleep(startup_delay) + + run_count = 0 + while True: + run_count += 1 + print(f"\n{'='*60}") + print( + f"[Run #{run_count}] Starting {username} on GPU {gpu_id} for format {format_name} with {n_challenges} challenges..." 
+ ) + print(f"{'='*60}") + + # set GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + cmd = [ + "python", + "serve_model.py", + "--username", + username, + "--format", + format_name, + "--n_challenges", + str(n_challenges), + "--config", + config_path, + ] + + if save_trajectories_to: + cmd.extend(["--save_trajectories_to", save_trajectories_to]) + + process = None + try: + if verbose: + # verbose mode: stream output in real-time + process = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.abspath(__file__)), + text=True, + ) + else: + # quiet mode: capture output + process = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.abspath(__file__)), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + # wait for completion + try: + process.wait(timeout=timeout) + if process.returncode == 0: + print( + f"✓ {username} on GPU {gpu_id} [Run #{run_count}] completed successfully" + ) + else: + print( + f"✗ {username} on GPU {gpu_id} [Run #{run_count}] failed with code {process.returncode}" + ) + if verbose: + # stderr already printed in real-time + pass + else: + stderr_output = process.stderr.read() + if stderr_output: + print(f"Error output from {username}:") + print(stderr_output) + except subprocess.TimeoutExpired: + print( + f"⏰ {username} on GPU {gpu_id} [Run #{run_count}] timed out after {timeout} seconds" + ) + process.kill() + process.wait() + + except Exception as e: + print( + f"✗ {username} on GPU {gpu_id} [Run #{run_count}] failed with exception: {e}" + ) + if process and process.poll() is None: + process.kill() + process.wait() + + finally: + # cleanup resources + if process: + if process.poll() is None: + process.kill() + try: + process.wait(timeout=5) + except: + pass + + if hasattr(process, "stdout") and process.stdout: + process.stdout.close() + if hasattr(process, "stderr") and process.stderr: + process.stderr.close() + + del process + + gc.collect() + + print(f"Waiting {restart_delay} seconds before relaunching {username}...") + time.sleep(restart_delay) + + +def get_usernames(config_path: str) -> List[str]: + """Expand agents based on num_agents field""" + with open(config_path, "r") as f: + raw_config = yaml.safe_load(f) + + # validate structure + if "agents" not in raw_config: + raise ValueError("Config must have 'agents' section") + + defaults = raw_config.get("defaults", {}) + agents = raw_config.get("agents", {}) + + # validate defaults + required_defaults = ["team_set", "battle_backend", "checkpoints", "num_agents"] + missing_defaults = [field for field in required_defaults if field not in defaults] + if missing_defaults: + raise ValueError( + f"defaults section missing required fields: {', '.join(missing_defaults)}" + ) + + expanded_usernames = [] + for base_username, agent_config in agents.items(): + # validate required fields + if "model_name" not in agent_config and "model_name" not in defaults: + raise ValueError( + f"Agent {base_username} missing required field: model_name" + ) + + # expand based on num_agents + merged_config = {**defaults, **agent_config} + num_agents = merged_config.get("num_agents", 1) + # handle None/null values in yaml + if num_agents is None: + num_agents = 1 + + if num_agents == 1: + expanded_usernames.append(base_username) + else: + # add numbered copies + for i in range(1, num_agents + 1): + expanded_username = f"{base_username}-{i}" + expanded_usernames.append(expanded_username) + + print( + f"Found {len(agents)} base agents, expanded to 
{len(expanded_usernames)} total: {', '.join(expanded_usernames)}" + ) + return expanded_usernames + + +def distribute_across_gpus( + usernames: List[str], gpus: List[int] +) -> Dict[int, List[str]]: + gpu_assignments = {gpu: [] for gpu in gpus} + for i, username in enumerate(usernames): + gpu_id = gpus[i % len(gpus)] + gpu_assignments[gpu_id].append(username) + return gpu_assignments + + +def run_all_usernames_parallel( + format_name: str, + gpus: List[int], + config_path: str, + n_challenges: int = 50, + restart_delay: int = 60, + timeout: int = 2700, + save_trajectories_to: str = None, + verbose: bool = False, +): + usernames = get_usernames(config_path) + + print(f"Running usernames: {', '.join(usernames)}") + print(f"Available GPUs: {gpus}") + print(f"Format: {format_name}") + print(f"Config: {config_path}") + print(f"Challenges per username: {n_challenges}") + print(f"Restart delay: {restart_delay} seconds") + print(f"Timeout per run: {timeout} seconds ({timeout//60} minutes)") + if save_trajectories_to: + print(f"Saving trajectories to: {save_trajectories_to}") + print("-" * 50) + + # distribute usernames across GPUs + gpu_assignments = distribute_across_gpus(usernames, gpus) + + for gpu_id, usernames_for_gpu in gpu_assignments.items(): + print(f"GPU {gpu_id}: {', '.join(usernames_for_gpu)}") + print("-" * 50) + + threads = [] + startup_delay = 0 + for gpu_id, usernames_for_gpu in gpu_assignments.items(): + for username in usernames_for_gpu: + thread = threading.Thread( + target=run_username_on_gpu_continuous, + args=( + gpu_id, + username, + format_name, + config_path, + n_challenges, + startup_delay, + restart_delay, + timeout, + save_trajectories_to, + verbose, + ), + daemon=True, + ) + threads.append(thread) + thread.start() + startup_delay += ( + 10 # increased delay to allow bots to connect and start challenging + ) + + print(f"\n✓ All {len(threads)} bots launched and running continuously!") + print("Press Ctrl+C to stop all bots") + print("-" * 50) + + try: + while True: + time.sleep(60) + except KeyboardInterrupt: + print("\n\nShutting down all bots...") + sys.exit(0) + + +def main(): + parser = ArgumentParser( + description="Run serve_model.py for all usernames across multiple GPUs (self-play)" + ) + parser.add_argument( + "--format", + required=True, + choices=["gen1ou", "gen2ou", "gen3ou", "gen4ou", "gen9ou"], + help="The battle format to use", + ) + parser.add_argument( + "--gpus", + nargs="+", + type=int, + required=True, + help="List of GPU IDs to use (e.g., --gpus 0 1 2 3)", + ) + parser.add_argument( + "--config", + default="earlygen_config.yaml", + help="Path to YAML config file (default: earlygen_config.yaml)", + ) + parser.add_argument( + "--n_challenges", + type=int, + default=50, + help="Number of challenges per username (default: 50)", + ) + parser.add_argument( + "--restart_delay", + type=int, + default=80, + help="Seconds to wait before relaunching each bot after completion (default: 80)", + ) + parser.add_argument( + "--timeout", + type=int, + default=2700, + help="Timeout in seconds for each bot run (default: 2700 = 45 minutes)", + ) + parser.add_argument( + "--save_trajectories_to", + required=True, + help="Base directory to save trajectories (will create subdirs per model)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print error messages from failed runs", + ) + + args = parser.parse_args() + + # validate GPUs + if not args.gpus: + print("Error: At least one GPU ID must be specified") + sys.exit(1) + + # convert config path to 
absolute path so subprocesses can find it + config_path = os.path.abspath(args.config) + + # run continuously + run_all_usernames_parallel( + args.format, + args.gpus, + config_path, + args.n_challenges, + args.restart_delay, + args.timeout, + args.save_trajectories_to, + args.verbose, + ) + + +if __name__ == "__main__": + main() diff --git a/metamon/rl/self_play/serve_model.py b/metamon/rl/self_play/serve_model.py new file mode 100644 index 000000000..4ea23d8a9 --- /dev/null +++ b/metamon/rl/self_play/serve_model.py @@ -0,0 +1,207 @@ +import os +import random +import warnings +import yaml +from functools import partial +from typing import Optional, Iterable +import json + +import amago + +from metamon.env import get_metamon_teams, QueueOnLocalLadder, TeamSet +from metamon.interface import ObservationSpace, RewardFunction, ActionSpace +from metamon.rl.pretrained import get_pretrained_model +from metamon.rl.metamon_to_amago import PSLadderAMAGOWrapper + +warnings.filterwarnings("ignore") + + +def make_ladder_env( + battle_format: str, + player_team_set: TeamSet, + observation_space: ObservationSpace, + action_space: ActionSpace, + reward_function: RewardFunction, + num_battles: int, + username: str, + save_trajectories_to: Optional[str] = None, + battle_backend: str = "metamon", +): + """ + Battle on the local Showdown ladder + """ + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=amago.utils.AmagoWarning) + menv = QueueOnLocalLadder( + battle_format=battle_format, + num_battles=num_battles, + observation_space=observation_space, + action_space=action_space, + reward_function=reward_function, + player_team_set=player_team_set, + player_username=username, + save_trajectories_to=save_trajectories_to, + battle_backend=battle_backend, + print_battle_bar=False, + ) + return PSLadderAMAGOWrapper(menv) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument( + "--username", + required=True, + help="Choose a username from the config to evaluate.", + ) + parser.add_argument( + "--format", + default="gen1ou", + choices=["gen1ou", "gen2ou", "gen3ou", "gen4ou", "gen9ou"], + help="Specify the battle format/tier.", + ) + parser.add_argument( + "--save_trajectories_to", + type=str, + default=None, + help="Path to save trajectories to.", + ) + parser.add_argument( + "--n_challenges", + type=int, + default=10, + help=( + "Number of battles to run before returning eval stats. " + "Note this is the total sample size across all parallel actors." 
+ ), + ) + parser.add_argument( + "--config", + required=True, + help="Path to the YAML config file.", + ) + args = parser.parse_args() + + # load config + with open(args.config, "r") as f: + raw_config = yaml.safe_load(f) + + # validate structure + if "agents" not in raw_config: + raise ValueError("Config must have 'agents' section") + + defaults = raw_config.get("defaults", {}) + agents = raw_config.get("agents", {}) + + # validate defaults + required_defaults = [ + "team_set", + "battle_backend", + "checkpoints", + "temperatures", + "num_agents", + ] + missing_defaults = [field for field in required_defaults if field not in defaults] + if missing_defaults: + raise ValueError( + f"defaults section missing required fields: {', '.join(missing_defaults)}" + ) + + # find base username (strip numeric suffix if from num_agents expansion) + username = args.username + if username in agents: + base_username = username + else: + if "-" in username and username.split("-")[-1].isdigit(): + potential_base = "-".join(username.split("-")[:-1]) + if potential_base in agents: + base_username = potential_base + else: + raise ValueError( + f"Username {username} not found in config and could not find base" + ) + else: + raise ValueError(f"Username {username} not found in config") + + # merge config + agent_config = agents[base_username] + account_config = {**defaults, **agent_config} + account_config["battle_format"] = args.format + + # validate required fields + if "model_name" not in account_config: + raise ValueError(f"Agent {base_username} missing required field: model_name") + + # load model and team set + model_name = account_config["model_name"] + agent_maker = get_pretrained_model(model_name) + + # get team_set - uniform random sampling if list + team_set_config = account_config["team_set"] + if isinstance(team_set_config, list) and len(team_set_config) > 0: + team_set_choice = random.choice(team_set_config) + else: + team_set_choice = team_set_config + print(f"Using team_set {team_set_choice}") + player_team_set = get_metamon_teams(args.format, team_set_choice) + + # get checkpoint - uniform random sampling + checkpoints = account_config["checkpoints"] + if checkpoints is not None and len(checkpoints) > 0: + checkpoint = random.choice(checkpoints) + else: + checkpoint = None + print(f"Using checkpoint {checkpoint}") + + # get temperature - uniform random sampling + temperatures = account_config.get("temperatures", [1.0]) + if isinstance(temperatures, Iterable) and not isinstance(temperatures, str): + temperature = random.choice(temperatures) + else: + temperature = float(temperatures) + print(f"Using temperature {temperature}") + battle_backend = account_config["battle_backend"] + print(f"Using battle backend {battle_backend}") + + save_trajectories_to = os.path.join( + args.save_trajectories_to, model_name, battle_backend + ) + os.makedirs(save_trajectories_to, exist_ok=True) + + # initialize agent + agent = agent_maker.initialize_agent( + checkpoint=checkpoint, log=False, action_temperature=temperature + ) + agent.env_mode = "sync" + # create envs + env_kwargs = dict( + battle_format=args.format, + player_team_set=player_team_set, + observation_space=agent_maker.observation_space, + action_space=agent_maker.action_space, + reward_function=agent_maker.reward_function, + save_trajectories_to=save_trajectories_to, + battle_backend=battle_backend, + ) + make_envs = [ + partial( + make_ladder_env, + **env_kwargs, + num_battles=args.n_challenges, + username=username, + ) + ] + agent.verbose = False + 
agent.parallel_actors = len(make_envs) + + # evaluate + results = agent.evaluate_test( + make_envs, + # sets upper bound on total timesteps + timesteps=args.n_challenges * 350, + # terminates after n_challenges + episodes=args.n_challenges, + ) + print(json.dumps(results, indent=4, sort_keys=True)) diff --git a/metamon/rl/train.py b/metamon/rl/train.py index 1e48b4061..f79965986 100644 --- a/metamon/rl/train.py +++ b/metamon/rl/train.py @@ -15,7 +15,7 @@ RewardFunction, ) from metamon.tokenizer import get_tokenizer -from metamon.data import ParsedReplayDataset +from metamon.data import ParsedReplayDataset, SelfPlayDataset, MetamonDataset from metamon.rl.metamon_to_amago import ( MetamonAMAGOExperiment, MetamonAMAGODataset, @@ -120,17 +120,42 @@ def add_cli(parser): default=None, help="Path to the parsed replay directory. Defaults to the official huggingface version.", ) + parser.add_argument( + "--replay_weight", + type=float, + default=1.0, + help="Sampling weight for the human parsed replay dataset (metamon-parsed-replays). Will be renormalized with other weights.", + ) + parser.add_argument( + "--self_play_subsets", + type=str, + nargs="+", + default=None, + help="Official self-play dataset (metamon-parsed-pile) subsets to include (e.g., 'pac-base', 'pac-exploratory'). If not provided, self-play data is not used.", + ) + parser.add_argument( + "--self_play_weights", + type=float, + nargs="+", + default=None, + help="Sampling weights for each self-play subset. Must match length of --self_play_subsets.", + ) parser.add_argument( "--custom_replay_dir", type=str, default=None, - help="Path to an optional second parsed replay dataset (e.g., self-play data you've collected).", + help="Path to an optional custom parsed replay dataset (e.g., additional self-play data you've collected).", ) parser.add_argument( - "--custom_replay_sample_weight", + "--custom_replay_weight", type=float, default=0.25, - help="[0, 1] portion of each batch to sample from the custom dataset (if provided).", + help="Sampling weight for the custom dataset (if provided). Will be renormalized with other weights.", + ) + parser.add_argument( + "--use_cached_filenames", + action="store_true", + help="Use cached filename index for faster startup when reusing an identical training set.", ) parser.add_argument( "--async_env_mp_context", @@ -143,7 +168,7 @@ def add_cli(parser): type=int, nargs="*", default=[1, 2, 3, 4, 9], - help="Generations (of OU) to play against heuristics between training epochs. Win rates usually saturate at 90\%+ quickly, so this is mostly a sanity-check. Reduce gens to save time on launch! Use `--eval_gens` (no arguments) to disable evaluation.", + help="Generations (of OU) to play against heuristics between training epochs. Win rates usually saturate at 90%%+ quickly, so this is mostly a sanity-check. Reduce gens to save time on launch! 
Use `--eval_gens` (no arguments) to disable evaluation.", ) parser.add_argument( "--formats", @@ -159,47 +184,138 @@ def create_offline_dataset( obs_space: TokenizedObservationSpace, action_space: ActionSpace, reward_function: RewardFunction, - parsed_replay_dir: str, + parsed_replay_dir: Optional[str] = None, + replay_weight: float = 1.0, + self_play_subsets: Optional[List[str]] = None, + self_play_weights: Optional[List[float]] = None, custom_replay_dir: Optional[str] = None, - custom_replay_sample_weight: float = 0.25, + custom_replay_weight: float = 0.25, verbose: bool = True, formats: Optional[List[str]] = None, + use_cached_filenames: bool = False, ) -> amago.loading.RLDataset: + """ + Create a mixed offline RL dataset from multiple sources. + Args: + obs_space: Tokenized observation space + action_space: Action space + reward_function: Reward function + parsed_replay_dir: Path to parsed replays (None = download from HuggingFace) + replay_weight: Sampling weight for parsed replays + self_play_subsets: List of self-play subsets to include (e.g., ["pac-base", "pac-exploratory"]) + self_play_weights: Sampling weights for each self-play subset (must match length of self_play_subsets) + custom_replay_dir: Path to custom replay directory + custom_replay_weight: Sampling weight for custom replays + verbose: Print dataset loading progress + formats: Battle formats to include + use_cached_filenames: Use cached filename index for faster startup + + Returns: + AMAGO RLDataset (possibly a MixtureOfDatasets) + """ formats = formats or metamon.SUPPORTED_BATTLE_FORMATS + + # Validate self-play weights + if self_play_subsets is not None: + if self_play_weights is None: + # Default to equal weights + self_play_weights = [1.0] * len(self_play_subsets) + elif len(self_play_weights) != len(self_play_subsets): + raise ValueError( + f"--self_play_weights ({len(self_play_weights)}) must match " + f"--self_play_subsets ({len(self_play_subsets)})" + ) + + # Common dataset kwargs dset_kwargs = { "observation_space": obs_space, "action_space": action_space, "reward_function": reward_function, - # amago will handle sequence lengths on its side - "max_seq_len": None, + "max_seq_len": None, # amago handles sequence lengths "formats": formats, - "verbose": verbose, # False to hide dset setup progress bar - "use_cached_filenames": False, # Switch to True to save a lot of time on startup when reusing an identical training set + "verbose": verbose, + "use_cached_filenames": use_cached_filenames, } - parsed_replays_amago = MetamonAMAGODataset( - dset_name="Metamon Parsed Replays", - parsed_replay_dset=ParsedReplayDataset( - dset_root=parsed_replay_dir, **dset_kwargs - ), - ) - if custom_replay_dir is not None: - custom_dset_amago = MetamonAMAGODataset( - dset_name="Custom Parsed Replays", - parsed_replay_dset=ParsedReplayDataset( - dset_root=custom_replay_dir, **dset_kwargs - ), + + # Collect all datasets and weights + datasets = [] + weights = [] + dataset_info = [] # For pretty printing + + # 1. Parsed Replays (human battles) + if replay_weight > 0: + parsed_dset = ParsedReplayDataset(dset_root=parsed_replay_dir, **dset_kwargs) + datasets.append( + MetamonAMAGODataset( + dset_name="Parsed Replays (Human)", + parsed_replay_dset=parsed_dset, + ) + ) + weights.append(replay_weight) + dataset_info.append(("Parsed Replays (Human)", len(parsed_dset), replay_weight)) + + # 2. 
Self-Play Datasets + if self_play_subsets is not None: + for subset, weight in zip(self_play_subsets, self_play_weights): + if weight > 0: + selfplay_dset = SelfPlayDataset(subset=subset, **dset_kwargs) + datasets.append( + MetamonAMAGODataset( + dset_name=f"Self-Play ({subset})", + parsed_replay_dset=selfplay_dset, + ) + ) + weights.append(weight) + dataset_info.append( + (f"Self-Play ({subset})", len(selfplay_dset), weight) + ) + + # 3. Custom Replay Directory + if custom_replay_dir is not None and custom_replay_weight > 0: + custom_dset = MetamonDataset(dset_root=custom_replay_dir, **dset_kwargs) + datasets.append( + MetamonAMAGODataset( + dset_name="Custom Replays", + parsed_replay_dset=custom_dset, + ) ) - amago_dataset = amago.loading.MixtureOfDatasets( - datasets=[parsed_replays_amago, custom_dset_amago], - sampling_weights=[ - 1 - custom_replay_sample_weight, - custom_replay_sample_weight, - ], + weights.append(custom_replay_weight) + dataset_info.append(("Custom Replays", len(custom_dset), custom_replay_weight)) + + if not datasets: + raise ValueError( + "No datasets configured! Provide at least one of: parsed replays, self-play subsets, or custom replay dir." ) + + # Renormalize weights to sum to 1 + total_weight = sum(weights) + normalized_weights = [w / total_weight for w in weights] + + # Print pretty summary + print("\n" + "=" * 70) + print("TRAINING DATASET SUMMARY") + print("=" * 70) + print(f"{'Dataset':<35} {'Files':>12} {'Weight':>10} {'Norm Weight':>12}") + print("-" * 70) + total_files = 0 + for (name, num_files, raw_weight), norm_weight in zip( + dataset_info, normalized_weights + ): + total_files += num_files + print(f"{name:<35} {num_files:>12,} {raw_weight:>10.2f} {norm_weight:>11.1%}") + print("-" * 70) + print(f"{'TOTAL':<35} {total_files:>12,} {total_weight:>10.2f} {'100.0%':>12}") + print("=" * 70 + "\n") + + # Create final dataset + if len(datasets) == 1: + return datasets[0] else: - amago_dataset = parsed_replays_amago - return amago_dataset + return amago.loading.MixtureOfDatasets( + datasets=datasets, + sampling_weights=normalized_weights, + ) def create_offline_rl_trainer( @@ -323,6 +439,12 @@ def create_offline_rl_trainer( add_cli(parser) args = parser.parse_args() + metamon.print_banner() + print( + f" Run: {args.run_name} | Model: {args.model_gin_config} | Training: {args.train_gin_config}" + ) + print() + # agent input/output/rewards obs_space = TokenizedObservationSpace( get_observation_space(args.obs_space), get_tokenizer(args.tokenizer) @@ -336,9 +458,13 @@ def create_offline_rl_trainer( action_space=action_space, reward_function=reward_function, parsed_replay_dir=args.parsed_replay_dir, + replay_weight=args.replay_weight, + self_play_subsets=args.self_play_subsets, + self_play_weights=args.self_play_weights, custom_replay_dir=args.custom_replay_dir, - custom_replay_sample_weight=args.custom_replay_sample_weight, + custom_replay_weight=args.custom_replay_weight, formats=args.formats, + use_cached_filenames=args.use_cached_filenames, ) # quick-setup for an offline RL experiment diff --git a/pyproject.toml b/pyproject.toml index 5c57d87a0..8da348675 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "metamon" -version = "1.4.0" +version = "1.5.0" description = "Baselines and Datasets for Pokémon Showdown RL" readme = { file = "README.md", content-type = "text/markdown" } authors = [ @@ -33,6 +33,7 @@ dependencies = [ "termcolor", "huggingface_hub", "datasets", + "ratarmountcore", 
"poke-env @ git+https://github.com/UT-Austin-RPL/poke-env.git" ]