
Conversation

@rka97 (Contributor) commented Jan 12, 2026

  1. Increase the number of data-loader workers.
  2. Add caching so that we don't have to rescan the file system every time we run the PyTorch workload (this made a big difference while debugging!). A rough sketch of the idea follows this list.
  3. Use torch's built-in attention for the vision transformer. The custom module was much slower than JAX's built-in attention on the A100 (but crucially not on the V100, where the two were the same speed!). See the sketch after this list.
  4. Compile the ViT workload as well; without compilation it is much slower than JAX.
  5. Add a test that measures the speed difference between the JAX and PyTorch workloads. Currently this depends on conda envs, but we can modify it to use Docker instead. A rough sketch also follows the list.
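For item 1, this is just the `num_workers` argument to `torch.utils.data.DataLoader`. For item 2, the caching is conceptually along these lines (a minimal sketch; the function name and cache path are illustrative, not the actual workload code):

```python
# Sketch of the file-listing cache: walk the data directory once, persist the
# result, and reuse it on later runs instead of rescanning the file system.
import os
import pickle

def list_files_cached(data_dir: str, cache_path: str = "/tmp/file_list.pkl"):
    # Reuse a previously saved listing if one exists.
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Otherwise do the (slow) scan once and persist the result.
    files = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(data_dir)
        for name in names
    )
    with open(cache_path, "wb") as f:
        pickle.dump(files, f)
    return files
```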
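For items 3 and 4, the change amounts to routing the ViT attention through `torch.nn.functional.scaled_dot_product_attention` and wrapping the model in `torch.compile`. A minimal sketch, assuming PyTorch >= 2.0 (the module below is illustrative, not the workload's actual class):

```python
# Self-attention built on PyTorch's fused attention op instead of an
# explicit softmax(QK^T)V written out by hand.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        # (b, n, 3 * dim) -> three tensors of shape (b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Dispatches to a fused kernel where one is available.
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.proj(out)

# Item 4 is then a one-liner on top of the assembled model:
# model = torch.compile(model)
```

Using the built-in op also means any future kernel improvements come for free instead of living in a hand-written module.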
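For item 5, the test is, in spirit, a timing of a few steps of each implementation (a rough sketch only; the real test runs each workload in its own conda env, `jax_step`/`pytorch_step` are placeholders for those entry points, and the tolerance is made up):

```python
# Sketch of the JAX-vs-PyTorch speed check: time a few steps of each
# implementation and assert the PyTorch version stays within a tolerance.
import time

def mean_step_time(step_fn, num_steps: int = 10) -> float:
    start = time.monotonic()
    for _ in range(num_steps):
        step_fn()
    return (time.monotonic() - start) / num_steps

def compare_speeds(jax_step, pytorch_step, tolerance: float = 1.1) -> None:
    jax_time = mean_step_time(jax_step)
    torch_time = mean_step_time(pytorch_step)
    # Tolerance is illustrative, not the threshold used by the actual test.
    assert torch_time <= tolerance * jax_time, (
        f"PyTorch step {torch_time:.3f}s vs JAX step {jax_time:.3f}s"
    )
```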

@rka97 requested a review from priyakasimbeg January 12, 2026 04:26
@rka97 self-assigned this Jan 12, 2026
@rka97 requested a review from a team as a code owner January 12, 2026 04:26
@github-actions (MLCommons CLA bot): All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg merged commit 056281e into a100 Jan 12, 2026
43 checks passed
@github-actions bot locked and limited conversation to collaborators Jan 12, 2026
