Skip to content

Conversation

stephen-huan
Copy link
Contributor

Checklist

  • I've formatted the new code by running hatch run dev:format before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Fix some failing tests with jax 0.5.1. It seems from 0.5.0 -> 0.5.1 the numerics got slightly worse (especially for inverse).

For the get_batch test, as I understand since it's sampling with replacement it's possible to get unlucky and have two rows coincide. This is most likely when the batch size is large (the tests that fail are when the batch size is the same as the dataset size). We'd expect n_dim * batch_size / n_data elements to collide (each row has a 1 / n_data chance of colliding, each collision makes n_dim elements the same, and there are batch_size rows).

Related to #490.

@stephen-huan stephen-huan changed the title Fix tests Fix tests with jax 0.5.1 Mar 3, 2025
@stephen-huan
Copy link
Contributor Author

CI failures seem unrelated.

@thomaspinder
Copy link
Collaborator

Yes, unclear what is causing the CI failures right now. I am unable to reproduce locally.

@thomaspinder thomaspinder merged commit af27f0a into JaxGaussianProcesses:main Mar 4, 2025
12 checks passed
@stephen-huan stephen-huan deleted the fix-tests branch March 5, 2025 00:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants