Skip to content

Improve minimization reporting#46

Merged
jessegrabowski merged 2 commits intomainfrom
jax-tests
Dec 13, 2025
Merged

Improve minimization reporting#46
jessegrabowski merged 2 commits intomainfrom
jax-tests

Conversation

@jessegrabowski
Copy link
Owner

  • Fix bug where loss functions that return scalars would raise
  • Don't square the objective of scalar loss function
  • Add tests for jax functions

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves minimization reporting by fixing a bug where scalar loss functions would raise errors and adds comprehensive JAX integration tests. The changes remove the incorrect squaring of objective values and properly handle JAX arrays and other array-like return values.

Key Changes:

  • Fixed update_progressbar method to properly handle scalar and array return values from objective functions
  • Added comprehensive JAX integration tests covering minimize, root, and basinhopping functions
  • Added JAX as a conda environment dependency

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
better_optimize/wrapper.py Fixed the update_progressbar method to use np.asarray for converting values, properly distinguish between root-finding (uses norm) and minimization (requires scalar), and removed incorrect squaring of objective values
tests/test_jax_integration.py Added comprehensive tests for JAX array compatibility with minimize, root, and basinhopping functions
conda_envs/better_optimize.yml Added JAX as a dependency for testing JAX integration

After a thorough review of the code changes, I found no issues. The implementation is correct, follows best practices, and the tests are comprehensive. The changes successfully address all stated goals in the PR description.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@jessegrabowski jessegrabowski merged commit 487a1c9 into main Dec 13, 2025
7 checks passed
@jessegrabowski jessegrabowski deleted the jax-tests branch December 13, 2025 04:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant

Comments