From b9603e4c5a493ebcfc8e79253b4d64fa1796f767 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Mon, 24 Nov 2025 20:15:52 -0500 Subject: [PATCH 1/5] Remove top level import - fix toml instead --- pyproject.toml | 6 ++++-- walrus/__init__.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3086f5f..d8056bf 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] -include = ["walrus"] +include = ["walrus*"] [project] name = "walrus" @@ -21,7 +21,9 @@ dependencies = [ "plotly>=5.0,<6", "the_well[benchmark] @ git+https://github.com/PolymathicAI/the_well.git@master", "timm>=1.0,<2", - "torch>=2.4, <=2.6", + "torch==2.5.1", + "torchvision==0.20.1", + "torchaudio==2.5.1", "torchinfo>=1.8.0,<2", "wandb>=0.17.9" ] diff --git a/walrus/__init__.py b/walrus/__init__.py index 4bc9527..c8cdefb 100755 --- a/walrus/__init__.py +++ b/walrus/__init__.py @@ -1,9 +1,9 @@ -from . import data, models, optim, trainer, utils +# from . import data, models, optim, trainer, utils -__all__ = [ - "data", - "models", - "optim", - "trainer", - "utils", -] +# __all__ = [ + # "data", + # "models", + # "optim", + # "trainer", + # "utils", +# ] From 3cd747ca81871dd5a7e85df0db1d1ca3c56ebfe7 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Mon, 24 Nov 2025 20:28:24 -0500 Subject: [PATCH 2/5] Confirm running and use HF objects so defaults match --- .../walrus_example_1_RunningWalrus.ipynb | 2211 ++++++++++++++++- 1 file changed, 2177 insertions(+), 34 deletions(-) diff --git a/demo_notebooks/walrus_example_1_RunningWalrus.ipynb b/demo_notebooks/walrus_example_1_RunningWalrus.ipynb index 8ea3b74..ea8ba25 100644 --- a/demo_notebooks/walrus_example_1_RunningWalrus.ipynb +++ b/demo_notebooks/walrus_example_1_RunningWalrus.ipynb @@ -35,10 +35,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "ead2153e", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2025-11-24 20:23:34-- https://huggingface.co/polymathic-ai/walrus/resolve/main/extended_config.yaml\n", + "Resolving huggingface.co (huggingface.co)... 3.168.73.111, 3.168.73.106, 3.168.73.129, ...\n", + "Connecting to huggingface.co (huggingface.co)|3.168.73.111|:443... connected.\n", + "HTTP request sent, awaiting response... 307 Temporary Redirect\n", + "Location: /api/resolve-cache/models/polymathic-ai/walrus/f8fd578e7d8a15d2e510d32d5952b9eddc37548c/extended_config.yaml?%2Fpolymathic-ai%2Fwalrus%2Fresolve%2Fmain%2Fextended_config.yaml=&etag=%223eb6c57e518c935eba9ade2e0b7a3b3381f491b6%22 [following]\n", + "--2025-11-24 20:23:35-- https://huggingface.co/api/resolve-cache/models/polymathic-ai/walrus/f8fd578e7d8a15d2e510d32d5952b9eddc37548c/extended_config.yaml?%2Fpolymathic-ai%2Fwalrus%2Fresolve%2Fmain%2Fextended_config.yaml=&etag=%223eb6c57e518c935eba9ade2e0b7a3b3381f491b6%22\n", + "Reusing existing connection to huggingface.co:443.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 8037 (7.8K) [text/plain]\n", + "Saving to: ‘./configs//extended_config.yaml’\n", + "\n", + "./configs//extended 100%[===================>] 7.85K --.-KB/s in 0s \n", + "\n", + "2025-11-24 20:23:35 (232 MB/s) - ‘./configs//extended_config.yaml’ saved [8037/8037]\n", + "\n", + "--2025-11-24 20:23:35-- https://huggingface.co/polymathic-ai/walrus/resolve/main/walrus.pt\n", + "Resolving huggingface.co (huggingface.co)... 3.168.73.38, 3.168.73.129, 3.168.73.111, ...\n", + "Connecting to huggingface.co (huggingface.co)|3.168.73.38|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://cas-bridge.xethub.hf.co/xet-bridge-us/691f2b27e01d6f3db3e150bf/fb24df8b23d8cc37ba6511bc0ff2f01b27c8f2ad63be4f405799d1a583942cf8?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251125%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251125T012335Z&X-Amz-Expires=3600&X-Amz-Signature=eca7ca53423f59ab92fddba4981cf1157599ceb3c07cf95a5592549e81034c98&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27walrus.pt%3B+filename%3D%22walrus.pt%22%3B&x-id=GetObject&Expires=1764037415&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NDAzNzQxNX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OTFmMmIyN2UwMWQ2ZjNkYjNlMTUwYmYvZmIyNGRmOGIyM2Q4Y2MzN2JhNjUxMWJjMGZmMmYwMWIyN2M4ZjJhZDYzYmU0ZjQwNTc5OWQxYTU4Mzk0MmNmOCoifV19&Signature=AOKwsN6u2-Z32j8tYosravl-4D6pXxTGIEqUqKWTLLBnfMkO2g9ow%7EaaJDPGVvq7jltSkl1k2ylBqoZPtkMPTC3lwPovdqZsLhXXQwXmyoQddSA6qFYS%7EHBVJFpAkeuOdJ00AP63cD-RI7DUXgHP2t7wO4H14RVq626nElTU9KIMuFbThGFs21neG8PAn0Kw73ilE-DLkTew53aWpYxKNC3Eu5MS0LFWNpPh6%7ElsJas3fa8xFnZ%7Eu06kO5IxOyoORpDPd4Fln6FoAWCvvbr36g6RBJDJS%7EkzNa4IpfFa5iqLkLileav1g0-UvQU7hO4-39X4xUUkgUDq4eyfdrUOzg__&Key-Pair-Id=K2L8F4GPSG1IFC [following]\n", + "--2025-11-24 20:23:35-- https://cas-bridge.xethub.hf.co/xet-bridge-us/691f2b27e01d6f3db3e150bf/fb24df8b23d8cc37ba6511bc0ff2f01b27c8f2ad63be4f405799d1a583942cf8?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251125%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251125T012335Z&X-Amz-Expires=3600&X-Amz-Signature=eca7ca53423f59ab92fddba4981cf1157599ceb3c07cf95a5592549e81034c98&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27walrus.pt%3B+filename%3D%22walrus.pt%22%3B&x-id=GetObject&Expires=1764037415&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NDAzNzQxNX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OTFmMmIyN2UwMWQ2ZjNkYjNlMTUwYmYvZmIyNGRmOGIyM2Q4Y2MzN2JhNjUxMWJjMGZmMmYwMWIyN2M4ZjJhZDYzYmU0ZjQwNTc5OWQxYTU4Mzk0MmNmOCoifV19&Signature=AOKwsN6u2-Z32j8tYosravl-4D6pXxTGIEqUqKWTLLBnfMkO2g9ow%7EaaJDPGVvq7jltSkl1k2ylBqoZPtkMPTC3lwPovdqZsLhXXQwXmyoQddSA6qFYS%7EHBVJFpAkeuOdJ00AP63cD-RI7DUXgHP2t7wO4H14RVq626nElTU9KIMuFbThGFs21neG8PAn0Kw73ilE-DLkTew53aWpYxKNC3Eu5MS0LFWNpPh6%7ElsJas3fa8xFnZ%7Eu06kO5IxOyoORpDPd4Fln6FoAWCvvbr36g6RBJDJS%7EkzNa4IpfFa5iqLkLileav1g0-UvQU7hO4-39X4xUUkgUDq4eyfdrUOzg__&Key-Pair-Id=K2L8F4GPSG1IFC\n", + "Resolving cas-bridge.xethub.hf.co (cas-bridge.xethub.hf.co)... 13.33.67.84, 13.33.67.95, 13.33.67.126, ...\n", + "Connecting to cas-bridge.xethub.hf.co (cas-bridge.xethub.hf.co)|13.33.67.84|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 5145064530 (4.8G)\n", + "Saving to: ‘./checkpoints//walrus.pt’\n", + "\n", + "./checkpoints//walr 100%[===================>] 4.79G 194MB/s in 22s \n", + "\n", + "2025-11-24 20:23:57 (221 MB/s) - ‘./checkpoints//walrus.pt’ saved [5145064530/5145064530]\n", + "\n" + ] + } + ], "source": [ "import os\n", "\n", @@ -56,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "id": "1d9353bb", "metadata": {}, "outputs": [ @@ -65,16 +103,15 @@ "output_type": "stream", "text": [ "data_workers: 10\n", - "name: Walrus_ft_major_v2-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002\n", - "finetune: true\n", + "name: Walrus-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002\n", "automatic_setup: true\n", "trainer:\n", " _target_: walrus.trainer.Trainer\n", - " max_epoch: 201\n", + " max_epoch: 200\n", " val_frequency: 10\n", " rollout_val_frequency: 10\n", " short_validation_length: 20\n", - " max_rollout_steps: 60\n", + " max_rollout_steps: 200\n", " num_time_intervals: 5\n", " enable_amp: false\n", " loss_fn:\n", @@ -372,7 +409,7 @@ "auto_resume: true\n", "folder_override: ''\n", "checkpoint_override: ''\n", - "config_override: /mnt/home/polymathic/ceph/MPPX_logging/golden_checkpoints/extended_config.yaml\n", + "config_override: null\n", "validation_mode: false\n", "frozen_components:\n", "- model\n", @@ -381,18 +418,18 @@ " local_size: 4\n", "logger:\n", " wandb: true\n", - " wandb_project_name: MPPX_Training_Attempts\n", + " wandb_project_name: walrus_Training_Attempts\n", "checkpoint:\n", " _target_: walrus.trainer.checkpoints.CheckPointer\n", - " save_dir: /mnt/home/polymathic/ceph/MPPX_logging/runs/Walrus_ft_major_v2-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002/finetune/0/checkpoints\n", + " save_dir: /mnt/home/polymathic/ceph/walrus_logging/runs/Walrus_ft_major_v2-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002/0/checkpoints\n", " load_checkpoint_path: null\n", - " coalesced_checkpoint_path: /mnt/home/polymathic/ceph/MPPX_logging/golden_checkpoints/walrus_final/walrus_final_coalesced.pth\n", + " coalesced_checkpoint_path: null\n", " save_best: true\n", " checkpoint_frequency: 20\n", " align_fields: true\n", " load_chkpt_after_finetuning_expansion: false\n", "finetuning_mods: {}\n", - "experiment_dir: /mnt/home/polymathic/ceph/MPPX_logging/runs\n", + "experiment_dir: /mnt/home/polymathic/ceph/walrus_logging/runs\n", "\n" ] } @@ -442,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "id": "e611b9a6", "metadata": {}, "outputs": [], @@ -475,10 +512,1063 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "bc5b4eb0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/home/mmccabe/venvs/clean_walrus/lib/python3.11/site-packages/walrus/models/temporal_blocks/axial_time_attention.py:38: FutureWarning: `nn.init.kaiming_uniform` is now deprecated in favor of `nn.init.kaiming_uniform_`.\n", + " init.kaiming_uniform(\n" + ] + }, + { + "data": { + "text/plain": [ + "IsotropicModel(\n", + " (patch_jitterer): FixedPatchJittererBoundaryPad()\n", + " (embed): ModuleDict(\n", + " (2): SpaceBagAdaptiveDVstrideEncoder(\n", + " (proj1): Conv3d(67, 352, kernel_size=(8, 8, 8), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): Conv3d(352, 1408, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm2): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " )\n", + " (3): SpaceBagAdaptiveDVstrideEncoder(\n", + " (proj1): Conv3d(67, 352, kernel_size=(8, 8, 8), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): Conv3d(352, 1408, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm2): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " )\n", + " )\n", + " (blocks): ModuleList(\n", + " (0): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): Identity()\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): Identity()\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (1): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.001)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.001)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (2): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.003)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.003)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (3): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.004)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.004)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (4): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.005)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.005)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (5): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.006)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.006)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (6): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.008)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.008)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (7): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.009)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.009)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (8): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.010)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.010)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (9): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.012)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.012)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (10): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.013)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.013)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (11): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.014)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.014)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (12): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.015)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.015)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (13): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.017)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.017)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (14): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.018)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.018)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (15): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.019)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.019)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (16): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.021)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.021)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (17): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.022)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.022)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (18): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.023)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.023)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (19): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.024)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.024)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (20): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.026)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.026)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (21): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.027)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.027)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (22): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.028)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.028)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (23): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.029)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.029)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (24): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.031)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.031)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (25): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.032)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.032)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (26): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.033)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.033)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (27): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.035)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.035)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (28): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.036)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.036)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (29): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.037)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.037)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (30): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.038)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.038)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (31): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.040)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.040)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (32): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.041)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.041)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (33): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.042)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.042)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (34): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.044)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.044)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (35): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.045)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.045)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (36): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.046)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.046)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (37): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.047)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.047)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (38): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.049)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.049)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (39): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.050)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.050)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " )\n", + " (debed): ModuleDict(\n", + " (2): AdaptiveDVstrideDecoder(\n", + " (proj1): ConvTranspose3d(1408, 352, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): ConvTranspose3d(352, 67, kernel_size=(8, 8, 8), stride=(1, 1, 1))\n", + " )\n", + " (3): AdaptiveDVstrideDecoder(\n", + " (proj1): ConvTranspose3d(1408, 352, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): ConvTranspose3d(352, 67, kernel_size=(8, 8, 8), stride=(1, 1, 1))\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "field_to_index_map = data_module.train_dataset.field_to_index_map\n", "# Retrieve the number of fields used in training\n", @@ -511,7 +1601,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "id": "068890a0", "metadata": {}, "outputs": [], @@ -530,7 +1620,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "id": "4342d53a", "metadata": {}, "outputs": [ @@ -538,7 +1628,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Metadata: WellMetadata(dataset_name='acoustic_scattering_inclusions', n_spatial_dims=3, spatial_resolution=(256, 256, 1), scalar_names=[], constant_scalar_names=[], field_names={0: ['pressure'], 1: ['velocity_x', 'velocity_y', 'velocity_z'], 2: []}, constant_field_names={0: ['density', 'speed_of_sound'], 1: [], 2: []}, boundary_condition_types=['OPEN', 'WALL', 'PERIODIC'], n_files=4, n_trajectories_per_file=[100, 100, 100, 100], n_steps_per_trajectory=[102, 102, 102, 102], grid_type='cartesian')\n", + "Metadata: WellMetadata(dataset_name='acoustic_scattering_inclusions', n_spatial_dims=3, spatial_resolution=(256, 256, 1), scalar_names=[], constant_scalar_names=[], field_names={0: ['pressure'], 1: ['velocity_x', 'velocity_y', 'velocity_z'], 2: []}, constant_field_names={0: ['density', 'speed_of_sound'], 1: [], 2: []}, boundary_condition_types=['WALL', 'OPEN', 'PERIODIC'], n_files=4, n_trajectories_per_file=[100, 100, 100, 100], n_steps_per_trajectory=[102, 102, 102, 102], grid_type='cartesian')\n", "Trajectory example keys: dict_keys(['input_fields', 'output_fields', 'constant_fields', 'boundary_conditions', 'space_grid', 'input_time_grid', 'output_time_grid', 'padded_field_mask', 'field_indices', 'metadata'])\n" ] } @@ -588,7 +1678,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "id": "984997b9", "metadata": {}, "outputs": [], @@ -698,7 +1788,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "id": "dd3cc52a", "metadata": {}, "outputs": [], @@ -737,7 +1827,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 9, "id": "ee402e79", "metadata": {}, "outputs": [ @@ -747,7 +1837,7 @@ "{}" ] }, - "execution_count": 16, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -769,7 +1859,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 10, "id": "8ca95388", "metadata": {}, "outputs": [ @@ -784,7 +1874,7 @@ "" ] }, - "execution_count": 17, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -813,7 +1903,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "288c6f34", "metadata": {}, "outputs": [], @@ -832,7 +1922,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "4b7c0c2a", "metadata": {}, "outputs": [ @@ -842,7 +1932,7 @@ "{'closed_boundary': 0, 'open_boundary': 1, 'bias_correction': 2, 'pressure': 3, 'velocity_x': 4, 'velocity_y': 5, 'velocity_z': 6, 'zeros_like_density': 7, 'speed_of_sound': 8, 'concentration': 9, 'D_xx': 10, 'D_xy': 11, 'D_xz': 12, 'D_yx': 13, 'D_yy': 14, 'D_yz': 15, 'D_zx': 16, 'D_zy': 17, 'D_zz': 18, 'E_xx': 19, 'E_xy': 20, 'E_xz': 21, 'E_yx': 22, 'E_yy': 23, 'E_yz': 24, 'E_zx': 25, 'E_zy': 26, 'E_zz': 27, 'density': 28, 'energy': 29, 'velocity_r': 30, 'velocity_theta': 31, 'velocity_phi': 32, 'momentum_x': 33, 'momentum_y': 34, 'momentum_z': 35, 'pressure_re': 36, 'pressure_im': 37, 'mask': 38, 'magnetic_field_x': 39, 'magnetic_field_y': 40, 'magnetic_field_z': 41, 'A': 42, 'B': 43, 'height': 44, 'internal_energy': 45, 'temperature': 46, 'electron_fraction': 47, 'entropy': 48, 'magnetic_field_log_r': 49, 'magnetic_field_theta': 50, 'magnetic_field_phi': 51, 'velocity_log_r': 52, 'buoyancy': 53, 'tracer': 54, 'log10_density': 55, 'log10_temperature': 56, 'c_zz': 57, 'C_xx': 58, 'C_xy': 59, 'C_xz': 60, 'C_yx': 61, 'C_yy': 62, 'C_yz': 63, 'C_zx': 64, 'C_zy': 65, 'C_zz': 66}" ] }, - "execution_count": 15, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -865,10 +1955,1063 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "b840e56d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/home/mmccabe/venvs/clean_walrus/lib/python3.11/site-packages/walrus/models/temporal_blocks/axial_time_attention.py:38: FutureWarning: `nn.init.kaiming_uniform` is now deprecated in favor of `nn.init.kaiming_uniform_`.\n", + " init.kaiming_uniform(\n" + ] + }, + { + "data": { + "text/plain": [ + "IsotropicModel(\n", + " (patch_jitterer): FixedPatchJittererBoundaryPad()\n", + " (embed): ModuleDict(\n", + " (2): SpaceBagAdaptiveDVstrideEncoder(\n", + " (proj1): Conv3d(68, 352, kernel_size=(8, 8, 8), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): Conv3d(352, 1408, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm2): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " )\n", + " (3): SpaceBagAdaptiveDVstrideEncoder(\n", + " (proj1): Conv3d(68, 352, kernel_size=(8, 8, 8), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): Conv3d(352, 1408, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm2): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " )\n", + " )\n", + " (blocks): ModuleList(\n", + " (0): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): Identity()\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): Identity()\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (1): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.001)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.001)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (2): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.003)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.003)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (3): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.004)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.004)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (4): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.005)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.005)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (5): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.006)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.006)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (6): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.008)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.008)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (7): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.009)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.009)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (8): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.010)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.010)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (9): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.012)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.012)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (10): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.013)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.013)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (11): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.014)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.014)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (12): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.015)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.015)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (13): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.017)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.017)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (14): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.018)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.018)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (15): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.019)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.019)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (16): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.021)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.021)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (17): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.022)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.022)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (18): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.023)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.023)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (19): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.024)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.024)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (20): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.026)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.026)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (21): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.027)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.027)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (22): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.028)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.028)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (23): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.029)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.029)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (24): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.031)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.031)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (25): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.032)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.032)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (26): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.033)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.033)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (27): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.035)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.035)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (28): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.036)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.036)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (29): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.037)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.037)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (30): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.038)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.038)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (31): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.040)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.040)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (32): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.041)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.041)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (33): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.042)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.042)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (34): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.044)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.044)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (35): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.045)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.045)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (36): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.046)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.046)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (37): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.047)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.047)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (38): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.049)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.049)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " (39): SpaceTimeSplitBlock(\n", + " (space_mixing): FullAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (fused_ff_qkv): Linear(in_features=1408, out_features=9856, bias=True)\n", + " (activation): SwiGLU()\n", + " (ff_out): Linear(in_features=2816, out_features=1408, bias=True)\n", + " (attn_out): Linear(in_features=1408, out_features=1408, bias=False)\n", + " (q_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (k_norm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rotary_emb): RotaryEmbedding()\n", + " (drop_path): DropPath(drop_prob=0.050)\n", + " )\n", + " (time_mixing): AxialTimeAttention(\n", + " (norm1): RMSGroupNorm(16, 1408, eps=1e-06, affine=True)\n", + " (input_head): Conv3d(1408, 4224, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (output_head): Conv3d(1408, 1408, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n", + " (qnorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (knorm): LayerNorm((88,), eps=1e-05, elementwise_affine=True)\n", + " (rel_pos_bias): RelativePositionBias(\n", + " (relative_attention_bias): Embedding(32, 16)\n", + " )\n", + " (drop_path): DropPath(drop_prob=0.050)\n", + " )\n", + " (channel_mixing): Identity()\n", + " )\n", + " )\n", + " (debed): ModuleDict(\n", + " (2): AdaptiveDVstrideDecoder(\n", + " (proj1): ConvTranspose3d(1408, 352, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): ConvTranspose3d(352, 68, kernel_size=(8, 8, 8), stride=(1, 1, 1))\n", + " )\n", + " (3): AdaptiveDVstrideDecoder(\n", + " (proj1): ConvTranspose3d(1408, 352, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)\n", + " (norm1): RMSGroupNorm(16, 352, eps=1e-06, affine=True)\n", + " (act): SiLU()\n", + " (proj2): ConvTranspose3d(352, 68, kernel_size=(8, 8, 8), stride=(1, 1, 1))\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "new_field_to_index_map = copy.deepcopy(field_to_index_map)\n", "new_field_to_index_map[\"blubber\"] = max(field_to_index_map.values()) + 1 # New index for \"blubber\"\n", @@ -911,7 +3054,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "9f014caf", "metadata": {}, "outputs": [], @@ -951,7 +3094,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "d4a8413b", "metadata": {}, "outputs": [], @@ -1000,9 +3143,9 @@ ], "metadata": { "kernelspec": { - "display_name": "mamba_well", + "display_name": "clean_walrus", "language": "python", - "name": "python3" + "name": "clean_walrus" }, "language_info": { "codemirror_mode": { @@ -1014,7 +3157,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.11" } }, "nbformat": 4, From d2e9785a85ecbef5dd6469f961511ec04f99869e Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Mon, 24 Nov 2025 20:30:38 -0500 Subject: [PATCH 3/5] Remove imports --- walrus/__init__.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/walrus/__init__.py b/walrus/__init__.py index c8cdefb..e69de29 100755 --- a/walrus/__init__.py +++ b/walrus/__init__.py @@ -1,9 +0,0 @@ -# from . import data, models, optim, trainer, utils - -# __all__ = [ - # "data", - # "models", - # "optim", - # "trainer", - # "utils", -# ] From f52964d33195c1ff19b5593fc8ef346166d9c681 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Mon, 24 Nov 2025 20:47:52 -0500 Subject: [PATCH 4/5] Remove amp test since this was never stable --- tests/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 9056ff3..f615115 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -12,7 +12,7 @@ # Set the different options to test conf_options = { "trainer.prediction_type": ["delta", "full"], - "trainer.enable_amp": ["True", "False"], + "trainer.enable_amp": ["True"], "model.causal_in_time": ["True", "False"], } From 9f164c143efd34f702b78701252dc0c99e522425 Mon Sep 17 00:00:00 2001 From: Mike McCabe Date: Mon, 24 Nov 2025 20:51:57 -0500 Subject: [PATCH 5/5] Remove amp test correctly by making it false --- tests/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f615115..f3ac107 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -12,7 +12,7 @@ # Set the different options to test conf_options = { "trainer.prediction_type": ["delta", "full"], - "trainer.enable_amp": ["True"], + "trainer.enable_amp": ["False"], "model.causal_in_time": ["True", "False"], }