Skip to content

Conversation

theo-brown
Copy link
Contributor

@theo-brown theo-brown commented Mar 24, 2025

Checklist

Changes only to pyproject.toml, so these do not apply

  • I've formatted the new code by running hatch run dev:format before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

  • Added a pin to a specific patch of cola-ml to support jax>0.5.0, pending it being made part of a release.

Issue Number: #490

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for opening your first PR into GPJax!

If you have not heard from us in a while, please feel free to ping @gpjax/developers or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on Slack for real-time discussion.

For details on testing, writing docs, and our review process, please see the developer guide

We strive to be a welcoming and open project. Please follow our Code of Conduct.

@theo-brown
Copy link
Contributor Author

A few unit tests are failing with TypeCheckError. There are also AssertionErrors in test_transform.

@thomaspinder
Copy link
Collaborator

Can you try rebasing from #498 and see if this fixes the issues please?

@theo-brown
Copy link
Contributor Author

theo-brown commented Mar 25, 2025

Oops, my unpin of JAX might be an overreach (see discussions about optax vs jaxopt) - should I limit this PR to just pinning cola?

@thomaspinder
Copy link
Collaborator

Yes, let's restrict this just to cola right now. Once we've removed jaxopt as a dependency, then I think we're good to remove the JAX pin and release a new version :)

@theo-brown theo-brown changed the title Unpin JAX version; add patch to specific cola version Add patch to specific cola version to support jax>0.5.0 Mar 25, 2025
@theo-brown
Copy link
Contributor Author

Rebased and reduced to only pinning cola.

7 tests fail with AttributeError: 'jaxlib.xla_extension.DeviceList' object has no attribute 'device'. Did you mean: 'devices'?, which I presume is a JAX version compatibility problem. I'm not super worried about chasing this down exactly, because it feels like everything's a bit interdependent.

Combining the cola pin with unpinning jax, jaxlib, and flax (required to match unpinned jax version) gives 100% test success.

@Thomas-Christie
Copy link
Contributor

JFYI I noticed there was a new release of Cola yesterday, so we can pin to that now instead of the specific commit.

@thomaspinder
Copy link
Collaborator

Thanks for the effort here, but I'm closing as the issue is resolve in the latest release.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants