-
Notifications
You must be signed in to change notification settings - Fork 5
chore(deps): update dependency jaxlib to v0.7.1 #590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
renovate
wants to merge
1
commit into
main
Choose a base branch
from
renovate/jaxlib-0.x-lockfile
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6639dec
to
2d9a256
Compare
2d9a256
to
54d7fc8
Compare
54d7fc8
to
46c2431
Compare
46c2431
to
8b09534
Compare
8b09534
to
e6819e3
Compare
e6819e3
to
4edadca
Compare
4edadca
to
95eb103
Compare
95eb103
to
5a70223
Compare
5a70223
to
386120a
Compare
386120a
to
6a67abe
Compare
6a67abe
to
e6ab988
Compare
e6ab988
to
3693135
Compare
3693135
to
bd9e07f
Compare
bd9e07f
to
d5f13e8
Compare
d5f13e8
to
6016efb
Compare
6016efb
to
1ec8c3d
Compare
1ec8c3d
to
1e66439
Compare
1e66439
to
cf37bef
Compare
cf37bef
to
67cd53b
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
0.4.23
->0.7.1
Warning
Some dependencies could not be looked up. Check the Dependency Dashboard for more information.
Release Notes
jax-ml/jax (jaxlib)
v0.7.1
New features
offered free-threading builds on Linux.
Changes
jax.set_mesh
which acts as a global setter and a context manager.Removed
jax.sharding.use_mesh
in favor ofjax.set_mesh
.supported.
jax.lax.dot
now implements the general dot product via the optionaldimension_numbers
argument.Deprecations:
jax.lax.zeros_like_array
is deprecated. Please use{func}
jax.numpy.zeros_like
instead.jax.experimental.host_callback
now results ina
DeprecationWarning
, and will result in anImportError
starting in JAXv0.8.0. Its APIs have raised
NotImplementedError
since JAX version 0.4.35.jax.lax.dot
, passing theprecision
andpreferred_element_type
arguments by position is deprecated. Pass them by explicit keyword instead.
jax.interpreters.ad
,{mod}
jax.interpreters.batching
, and {mod}jax.interpreters.partial_eval
; theyare used rarely if ever outside JAX itself, and most are deprecated without any
public replacement.
v0.7.0
New features:
jax.P
which is an alias forjax.sharding.PartitionSpec
.jax.tree.reduce_associative
.jax.numpy.ndarray.at
indexing methods now support awrap_negative_indices
argument, which defaults to
True
to match the current behavior ({jax-issue}#29434
).Breaking changes:
migration guide
for more information.
implementing linearization via JVP and partial eval).
See migration guide
for more information.
jax.stages.OutInfo
has been replaced withjax.ShapeDtypeStruct
.jax.jit
now requiresfun
to be passed by position, and additionalarguments to be passed by keyword. Doing otherwise will result in an error
starting in v0.7.x. This raised a DeprecationWarning in v0.6.x.
supported version until July 2026.
Layout
,.layout
,.input_layouts
and.output_layouts
have beenrenamed to
Format
,.format
,.input_formats
and.output_formats
DeviceLocalLayout
,.device_local_layout
have been renamed toLayout
and
.layout
jax.experimental.shard
module has been deleted and all the APIs have beenmoved to the
jax.sharding
endpoint. So usejax.sharding.reshard
,jax.sharding.auto_axes
andjax.sharding.explicit_axes
instead of theirexperimental endpoints.
lax.infeed
andlax.outfeed
were removed, after being deprecated inJAX 0.6. The
transfer_to_infeed
andtransfer_from_outfeed
methods werealso removed the
Device
objects.jax.extend.core.primitives.pjit_p
primitive has been renamed tojit_p
, and itsname
attribute has changed from"pjit"
to"jit"
.This affects the string representations of jaxprs. The same primitive is no
longer exported from the
jax.experimental.pjit
module.jax.extend.backend.add_clear_backends_callback
has been removed. Users should use
jax.extend.backend.register_backend_cache
instead.
out_sharding
arg added tox.at[y].set
andx.at[y].add
. Previousbehavior propagating operand sharding removed. Please use
x.at[y].set/add(z, out_sharding=jax.typeof(x).sharding)
to retain previousbehavior if scatter op requires collectives.
Deprecations:
jax.dlpack.SUPPORTED_DTYPES
is deprecated; please use the new{func}
jax.dlpack.is_supported_dtype
function.jax.scipy.special.sph_harm
has been deprecated following a similardeprecation in SciPy; use {func}
jax.scipy.special.sph_harm_y
instead.jax.interpreters.xla
, the previously deprecated symbolsabstractify
andpytype_aval_mappings
have been removed.jax.interpreters.xla.canonicalize_dtype
is deprecated. Forcanonicalizing dtypes, prefer {func}
jax.dtypes.canonicalize_dtype
.For checking whether an object is a valid jax input, prefer
{func}
jax.core.valid_jaxtype
.jax.core
, the previously deprecated symbolsAxisName
,ConcretizationTypeError
,axis_frame
,call_p
,closed_call_p
,get_type
,trace_state_clean
,typematch
, andtypecheck
have beenremoved.
jax.lib.xla_client
, the previously deprecated symbolsDeviceAssignment
,get_topology_for_devices
, andmlir_api_version
have been removed.
jax.extend.ffi
was removed after being deprecated in v0.5.0.Use {mod}
jax.ffi
instead.jax.lib.xla_bridge.get_compile_options
is deprecated, and replaced by{func}
jax.extend.backend.get_compile_options
.v0.6.2
New features:
jax.tree.broadcast
which implements a pytree prefix broadcasting helper.Changes
v0.6.1
New features:
jax.lax.axis_size
which returns the size of the mapped axisgiven its name.
Changes
re-enabled, having been accidentally disabled in a previous release.
these packages, see the JAX installation guide.
jax.sharding.PartitionSpec
no longer inherits from a tuple.jax.ShapeDtypeStruct
is immutable now. Please use.update
method toupdate your
ShapeDtypeStruct
instead of doing in-place updates.Deprecations
jax.custom_derivatives.custom_jvp_call_jaxpr_p
is deprecated, and will beremoved in JAX v0.7.0.
v0.6.0
Breaking changes
jax.numpy.array
no longer acceptsNone
. This behavior wasdeprecated since November 2023 and is now removed.
config.jax_data_dependent_tracing_fallback
config option,which was added temporarily in v0.4.36 to allow users to opt out of the
new "stackless" tracing machinery.
config.jax_eager_pmap
config option.lower
andtrace
AOT APIs on the resultof
jax.jit
if there have been subsequent wrappers applied.Previously this worked, but silently ignored the wrappers.
The workaround is to apply
jax.jit
last among the wrappers,and similarly for
jax.pmap
.See {jax-issue}
#27873
.cuda12_pip
extra forjax
has been removed; usepip install jax[cuda12]
instead.
Changes
supported.
align with PEP 685. For instance, if you were previously using
pip install jax[cuda12_local]
to install JAX, run
pip install jax[cuda12-local]
instead.jax.jit
now requiresfun
to be passed by position, and additionalarguments to be passed by keyword. Doing otherwise will result in a
DeprecationWarning in v0.6.X, and an error in starting in v0.7.X.
Deprecations
jax.tree_util.build_tree
is deprecated. Use {func}jax.tree.unflatten
instead.
and removed existing CPU/GPU handlers using XLA's custom call.
jax.lib.xla_extension
are now deprecated.jax.interpreters.mlir.hlo
andjax.interpreters.mlir.func_dialect
,which were accidental exports, have been removed. If needed, they are
available from
jax.extend.mlir
.jax.interpreters.mlir.custom_call
is deprecated. The APIs provided by{mod}
jax.ffi
should be used instead.jax.ffi.ffi_call
with inline arguments is nolonger supported. {func}
~jax.ffi.ffi_call
now unconditionally returns acallable.
jax.lib.xla_client
are deprecated:get_topology_for_devices
,heap_profile
,mlir_api_version
,Client
,CompileOptions
,DeviceAssignment
,Frame
,HloSharding
,OpSharding
,Traceback
.jax.util
are deprecated:HashableFunction
,as_hashable_function
,cache
,safe_map
,safe_zip
,split_dict
,split_list
,split_list_checked
,split_merge
,subvals
,toposort
,unzip2
,wrap_name
, andwraps
.jax.dlpack.to_dlpack
has been deprecated. You can usually pass a JAXArray
directly to thefrom_dlpack
function of another framework. If youneed the functionality of
to_dlpack
, use the__dlpack__
attribute of anarray.
jax.lax.infeed
,jax.lax.infeed_p
,jax.lax.outfeed
, andjax.lax.outfeed_p
are deprecated and will be removed in JAX v0.7.0.jax.lib.xla_client
:ArrayImpl
,FftType
,PaddingType
,PrimitiveType
,XlaBuilder
,dtype_to_etype
,ops
,register_custom_call_target
,shape_from_pyval
,Shape
,XlaComputation
.jax.lib.xla_extension
:ArrayImpl
,XlaRuntimeError
.jax
:jax.treedef_is_leaf
,jax.tree_flatten
,jax.tree_map
,jax.tree_leaves
,jax.tree_structure
,jax.tree_transpose
, andjax.tree_unflatten
. Replacements can be found in {mod}jax.tree
or{mod}
jax.tree_util
.jax.core
:AxisSize
,ClosedJaxpr
,EvalTrace
,InDBIdx
,InputType
,Jaxpr
,JaxprEqn
,Literal
,MapPrimitive
,OpaqueTraceState
,OutDBIdx
,Primitive
,Token
,TRACER_LEAK_DEBUGGER_WARNING
,Var
,concrete_aval
,dedup_referents
,escaped_tracer_error
,extend_axis_env_nd
,full_lower
,get_referent
,jaxpr_as_fun
,join_effects
,lattice_join
,leaked_tracer_error
,maybe_find_leaked_tracers
,raise_to_shaped
,raise_to_shaped_mappings
,reset_trace_state
,str_eqn_compact
,substitute_vars_in_output_ty
,typecompat
, andused_axis_names_jaxpr
. Mosthave no public replacement, though a few are available at {mod}
jax.extend.core
.vectorized
argument to {func}~jax.pure_callback
and{func}
~jax.ffi.ffi_call
. Use thevmap_method
parameter instead.v0.5.3
New Features
allow_negative_indices
option to {func}jax.lax.dynamic_slice
,{func}
jax.lax.dynamic_update_slice
and related functions. The default istrue, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
replace
option to {func}jax.random.categorical
to enable samplingwithout replacement.
v0.5.1
tracing cache did not include sharding information at all
(although subsequent jit caches did like lowering and compilation caches),
so two equivalent shardings of different types would not retrace,
but now they do. For example:
inp1.sharding is of type SingleDeviceSharding
inp2.sharding is of type NamedSharding
New Features
jax.experimental.custom_dce.custom_dce
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}
#25956
for moredetails.
jax.lax
: {func}jax.lax.reduce_sum
,{func}
jax.lax.reduce_prod
, {func}jax.lax.reduce_max
, {func}jax.lax.reduce_min
,{func}
jax.lax.reduce_and
, {func}jax.lax.reduce_or
, and {func}jax.lax.reduce_xor
.jax.lax.linalg.qr
, and {func}jax.scipy.linalg.qr
, now supportcolumn-pivoting on CPU and GPU. See {jax-issue}
#20282
andjax.random.multinomial
.{jax-issue}
#25955
for more details.Changes
JAX_CPU_COLLECTIVES_IMPLEMENTATION
andJAX_NUM_CPU_DEVICES
now work asenv vars. Before they could only be specified via jax.config or flags.
JAX_CPU_COLLECTIVES_IMPLEMENTATION
now defaults to'gloo'
, meaningmulti-process CPU communication works out-of-the-box.
jax[tpu]
TPU extra no longer depends on thelibtpu-nightly
package.This package may safely be removed if it is present on your machine; JAX now
uses
libtpu
instead.Deprecations
linear_util.wrap_init
and the constructorcore.Jaxpr
now must take a non-emptycore.DebugInfo
kwarg. Fora limited time, a
DeprecationWarning
is printed ifjax.extend.linear_util.wrap_init
is used without debugging info.A downstream effect of this several other internal functions need debug
info. This change does not affect public APhttps://github.com/jax-ml/jax/issues/26480issues/26480 for more detail.
jax.numpy.ndim
, {func}jax.numpy.shape
, and {func}jax.numpy.size
,non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
Bug fixes
TPU v5e and newer (from around 17s to around 8s). If not already set, you may
need to enable transparent hugepages in your VM image
(
sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'
).We hope to improve this further in future releases.
JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
eviction policy isn't enabled. This should improve performance when using
the cache with large-scale network storage.
v0.5.0
As of this release, JAX now uses
effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.
Breaking changes
Enable
jax_threefry_partitionable
by default (seethe update note).
This release drops support for Mac x86 wheels. Mac ARM of course remains
supported. For a recent discussion, https://github.com/jax-ml/jax/discussions/22936ns/22936.
Two key factors motivated this decision:
would prefer to ship no release than a broken release.
developers at this point. So it is difficult for us to fix this kind of
problem even if we wanted to.
We are open to re-adding support for Mac x86 if the community is willing
to help support that platform: in particular, we would need the JAX test
suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
supported version until June 2025.
supported version until June 2025.
jax.numpy.einsum
now defaults tooptimize='auto'
rather thanoptimize='optimal'
. This avoids exponentially-scaling trace-time inthe case of many arguments ({jax-issue}
#25214
).jax.numpy.linalg.solve
no longer supports batched 1D argumentson the right hand side. To recover the previous behavior in these cases,
use
solve(a, b[..., None]).squeeze(-1)
.New Features
jax.numpy.fft.fftn
, {func}jax.numpy.fft.rfftn
,{func}
jax.numpy.fft.ifftn
, and {func}jax.numpy.fft.irfftn
now supporttransforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}
#25606
for more details.{func}
jax.ffi.register_ffi_type_id
function..as_text()
method now supports thedebug_info
optionto include debugging information, e.g., source location, in the output.
Deprecations
jax.interpreters.xla
,abstractify
andpytype_aval_mappings
are now deprecated, having been replaced by symbols of the same name
in {mod}
jax.core
.jax.scipy.special.lpmn
and {func}jax.scipy.special.lpmn_values
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
jax.extend.ffi
submodule was moved to {mod}jax.ffi
, and theprevious import path is deprecated.
Deletions
jax_enable_memories
flag has been deleted and the behavior of that flagis on by default.
jax.lib.xla_client
, the previously-deprecatedDevice
andXlaRuntimeError
symbols have been removed; instead usejax.Device
and
jax.errors.JaxRuntimeError
respectively.jax.experimental.array_api
module has been removed after beingdeprecated in JAX v0.4.32. Since that release, {mod}
jax.numpy
supportsthe array API directly.
v0.4.38
Breaking Changes
XlaExecutable.cost_analysis
now returns adict[str, float]
(instead of asingle-element
list[dict[str, float]]
).Changes:
jax.tree.flatten_with_path
andjax.tree.map_with_path
are addedas shortcuts of the corresponding
tree_util
functions.Deprecations
jax.core
namespace have been deprecated.Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}
jax.extend.core
; see the documentation for {mod}jax.extend
for information on the compatibility guarantees of these semi-public extensions.
jax.core
:check_eqn
,check_type
,check_valid_jaxtype
, andnon_negative_dim
.jax.lib.xla_bridge
:xla_client
anddefault_backend
.jax.lib.xla_client
:_xla
andbfloat16
.jax.numpy
:round_
.New Features
jax.export.export
can be used for device-polymorphic export withshardings constructed with {func}
jax.sharding.AbstractMesh
.See the jax.export documentation.
jax.lax.split
. This is a primitive version of{func}
jax.numpy.split
, added because it yields a more compacttranspose during automatic differentiation.
v0.4.36
Breaking Changes
This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels,
post_process_call
,new_base_main
,custom_bind
, and so on. The change should only affectusers that use JAX internals.
If you do use JAX internals then you may need to
update your code (see
jax-ml/jax@c36e1f7
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
config.jax_data_dependent_tracing_fallback
flag as a workaround, and ifyou need help updating your code then please file a bug.
{func}
jax.experimental.jax2tf.convert
withnative_serialization=False
or with
enable_xla=False
have been deprecated since July 2024, withJAX version 0.4.31. Now we removed support for these use cases.
jax2tf
with native serialization will still be supported.
In
jax.interpreters.xla
, thexb
,xc
, andxe
symbols have been removedafter being deprecated in JAX v0.4.31. Instead use
xb = jax.lib.xla_bridge
,xc = jax.lib.xla_client
, andxe = jax.lib.xla_extension
.The deprecated module
jax.experimental.export
has been removed. It was replacedby {mod}
jax.export
in JAX v0.4.30. See the migration guidefor information on migrating to the new API.
The
initial
argument to {func}jax.nn.softmax
and {func}jax.nn.log_softmax
has been removed, after being deprecated in v0.4.27.
Calling
np.asarray
on typed PRNG keys (i.e. keys produced by {func}jax.random.key
)now raises an error. Previously, this returned a scalar object array.
The following deprecated methods and functions in {mod}
jax.export
havebeen removed:
jax.export.DisabledSafetyCheck.shape_assertions
: it had no effectalready.
jax.export.Exported.lowering_platforms
: useplatforms
.jax.export.Exported.mlir_module_serialization_version
:use
calling_convention_version
.jax.export.Exported.uses_shape_polymorphism
:use
uses_global_constants
.lowering_platforms
kwarg for {func}jax.export.export
: useplatforms
instead.The kwargs
symbolic_scope
andsymbolic_constraints
from{func}
jax.export.symbolic_args_specs
have been removed. They weredeprecated in June 2024. Use
scope
andconstraints
instead.Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a
TypeError
.Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run
python build/build.py --help
formore details. Brief overview of the new subcommand options:
build
: Builds JAX wheel packages. For e.g.,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
: Updates requirements_lock.txt files.{func}
jax.scipy.linalg.toeplitz
now does implicit batching on multi-dimensionalinputs. To recover the previous behavior, you can call {func}
jax.numpy.ravel
on the function inputs.
{func}
jax.scipy.special.gamma
and {func}jax.scipy.special.gammasgn
nowreturn NaN for negative integer inputs, to match the behavior of SciPy fhttps://github.com/scipy/scipy/pull/21827ll/21827.
jax.clear_backends
was removed after being deprecated in v0.4.26.We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
call that we guarantee export stability. This is because this custom call
relies on Triton IR, which is not guaranteed to be stable. If you need
to export code that uses this custom call, you can use the
disabled_checks
parameter. See more details in the documentation.
New Features
jax.jit
got a newcompiler_options: dict[str, Any]
argument, forpassing compilation options to XLA. For the moment it's undocumented and
may be in flux.
jax.tree_util.register_dataclass
now allows metadata fields to bedeclared inline via {func}
dataclasses.field
. See the function documentationfor examples.
jax.numpy.put_along_axis
.jax.lax.linalg.eig
and the relatedjax.numpy
functions({func}
jax.numpy.linalg.eig
and {func}jax.numpy.linalg.eigvals
) are nowsupported on GPU. See {jax-issue}
#24663
for more details.jax_exec_time_optimization_effort
andjax_memory_fitting_effort
, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.Bug fixes
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}
#24843
for more details.Deprecations
jax.lib.xla_extension.ArrayImpl
andjax.lib.xla_client.ArrayImpl
are deprecated;use
jax.Array
instead.jax.lib.xla_extension.XlaRuntimeError
is deprecated; usejax.errors.JaxRuntimeError
instead.
v0.4.35
Breaking Changes
jax.numpy.isscalar
now returns True for any array-like object withzero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
jax.experimental.host_callback
has been deprecated since March 2024, withJAX version 0.4.26. Now we removed it.
See {jax-issue}
#20385
for a discussion of alternatives.Changes:
jax.lax.FftType
was introduced as a public name for the enum of FFToperations. The semi-public API
jax.lib.xla_client.FftType
has beendeprecated.
libtpu
package rather thanlibtpu-nightly
. For the next few releases JAX will pin an empty version oflibtpu-nightly
as well aslibtpu
to ease the transition; that dependencywill be removed in Q1 2025.
Deprecations:
jax.lib.xla_client.PaddingType
has been deprecated.No JAX APIs consume this type, so there is no replacement.
jax.pure_callback
and{func}
jax.extend.ffi.ffi_call
undervmap
has been deprecated and so hasthe
vectorized
parameter to those functions. Thevmap_method
parametershould be used instead for better defined behavior. See the discussion in
{jax-issue}
#23881
for more details.jax.lib.xla_client.register_custom_call_target
hasbeen deprecated. Use the JAX FFI instead.
jax.lib.xla_client.dtype_to_etype
,jax.lib.xla_client.ops
,jax.lib.xla_client.shape_from_pyval
,jax.lib.xla_client.PrimitiveType
,jax.lib.xla_client.Shape
,jax.lib.xla_client.XlaBuilder
, andjax.lib.xla_client.XlaComputation
have been deprecated. Use StableHLOinstead.
v0.4.34
New Functionality
supported.
jax.errors.JaxRuntimeError
has been added as a public alias for theformerly private
XlaRuntimeError
type.Breaking changes
jax_pmap_no_rank_reduction
flag is set toTrue
by default.instead).
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
jax.experimental.host_callback
has been deprecated since March 2024, withJAX version 0.4.26. Now we set the default value of the
--jax_host_callback_legacy
configuration value toTrue
, which means thatif your code uses
jax.experimental.host_callback
APIs, those API callswill be implemented in terms of the new
jax.experimental.io_callback
API.If this breaks your code, for a very limited time, you can set the
--jax_host_callback_legacy
toTrue
. Soon we will remove thatconfiguration option, so you should instead transition to using the
new JAX callback APIs. See {jax-issue}
#20385
for a discussion.Deprecations
jax.numpy.trim_zeros
, non-arraylike arguments or arraylikearguments with
ndim != 1
are now deprecated, and in the future will resultin an error.
jax.core.pp_*
have been removed, afterbeing deprecated in JAX v0.4.30.
jax.lib.xla_client.Device
is deprecated; usejax.Device
instead.jax.lib.xla_client.XlaRuntimeError
has been deprecated. Usejax.errors.JaxRuntimeError
instead.Deletion:
jax.xla_computation
is deleted. It's been 3 months since it's deprecationin 0.4.30 JAX release.
Please use the AOT APIs to get the same functionality as
jax.xla_computation
.jax.xla_computation(fn)(*args, **kwargs)
can be replaced withjax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
..out_info
property ofjax.stages.Lowered
to get theoutput information (like tree structure, shape and dtype).
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
withjax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
.jax.ShapeDtypeStruct
no longer accepts thenamed_shape
argument.The argument was only used by
xmap
which was removed in 0.4.31.jax.tree.map(f, None, non-None)
, which previously emitted aDeprecationWarning
, now raises an error in a future version of jax.None
is only a tree-prefix of itself. To preserve the current behavior, you can
ask
jax.tree.map
to treatNone
as a leaf value by writing:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
.jax.sharding.XLACompatibleSharding
has been removed. Please usejax.sharding.Sharding
.Bug fixes
jax.numpy.cumsum
would produce incorrect outputsif a non-boolean input was provided and
dtype=bool
was specified.jax.numpy.ldexp
to get correct gradient.v0.4.33
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of
libtpu
.This release fixes an inaccurate result for F64 tanh on CPU (#23590).
v0.4.32
Compare Source
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
New Functionality
jax.extend.ffi.ffi_call
and {func}jax.extend.ffi.ffi_lowering
to support the use of the new {ref}
ffi-tutorial
to interface with customC++ and CUDA code from JAX.
Changes
jax_enable_memories
flag is set toTrue
by default.jax.numpy
now supports v2023.12 of the Python Array API Standard.See {ref}
python-array-api
for more information.more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
jax.config.update('jax_cpu_enable_async_dispatch', False)
.jax.process_indices
function to replace thejax.host_ids()
function that was deprecated in JAX v0.2.13.numpy.fabs
,jax.numpy.fabs
has beenmodified to no longer support
complex dtypes
.jax.tree_util.register_dataclass
now checks thatdata_fields
and
meta_fields
includes all dataclass fields withinit=True
and only them, if
nodetype
is a dataclass.jax.numpy
functions now have full {class}~jax.numpy.ufunc
interfaces, including {obj}
~jax.numpy.add
, {obj}~jax.numpy.multiply
,{obj}
~jax.numpy.bitwise_and
, {obj}~jax.numpy.bitwise_or
,{obj}
~jax.numpy.bitwise_xor
, {obj}~jax.numpy.logical_and
,{obj}
~jax.numpy.logical_and
, and {obj}~jax.numpy.logical_and
.In future releases we plan to expand these to other ufuncs.
jax.lax.optimization_barrier
, which allows users to preventcompiler optimizations such as common-subexpression elimination and to
control scheduling.
Breaking changes
jax.extend.mlir.mhlo
) has been removed. Use thestablehlo
dialect instead.Deprecations
jax.numpy.clip
and {func}jax.numpy.hypot
areno longer allowed, after being deprecated since JAX v0.4.27.
jax.lib.xla_bridge.xla_client
: use {mod}jax.lib.xla_client
directly.jax.lib.xla_bridge.get_backend
: use {func}jax.extend.backend.get_backend
.jax.lib.xla_bridge.default_backend
: use {func}jax.extend.backend.default_backend
.jax.experimental.array_api
module is deprecated, and importing it is nolonger required to use the Array API.
jax.numpy
supports the array APIdirectly; see {ref}
python-array-api
for more information.jax.core.check_eqn
,jax.core.check_type
, andjax.core.check_valid_jaxtype
are now deprecated, and will be removed inthe future.
jax.numpy.round_
has been deprecated, following removal of the correspondingAPI in NumPy 2.0. Use {func}
jax.numpy.round
instead.jax.dlpack.from_dlpack
is deprecated.The argument to {func}
jax.dlpack.from_dlpack
should be an array fromanother framework that implements the
__dlpack__
protocol.v0.4.31
Compare Source
Deletion
shard_map
as the replacement.Changes
but we now declare this version constraint formally.
supported version until July 2025.
supported version until December 2024.
supported version until January 2025.
jax.numpy.ceil
, {func}jax.numpy.floor
and {func}jax.numpy.trunc
now return the outputof the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
libdevice.10.bc
is no longer bundled with CUDA wheels. It must beinstalled either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
jax.experimental.pallas.BlockSpec
now expectsblock_shape
tobe passed before
index_map
. The old argument order is deprecated andwill be removed in a future release.
with TPUs/CPUs. For example,
cuda(id=0)
will now beCudaDevice(id=0)
.device
property andto_device
method to {class}jax.Array
, aspart of JAX's Array API support.
Deprecations
polymorphic shapes. From {mod}
jax.core
: removedcanonicalize_shape
,dimension_as_value
,definitely_equal
, andsymbolic_equal_dim
.Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
jax.experimental.jax2tf.convert
withnative_serialization=False
or
enable_xla=False
is now deprecated and this support will be removed ina future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
jax.random.shuffle
has been removed;instead use
jax.random.permutation
withindependent=True
.v0.4.30
Compare Source
Changes
bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
jax.experimental.mesh_utils
can now create an efficient mesh for TPU v5e.plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with
pip install jax
, no extras required.to exist in
jax.experimental.export
(which is being deprecated),and will now live in
jax.export
.See the documentation.
Deprecations
jax.core.pp_*
are deprecated, and will be removedin a future release.
TypeError
in a future JAXrelease. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
jax.experimental.export
is deprecated. Use {mod}jax.export
instead.See the migration guide.
x
andy
,x.astype(y)
will raise a warning. To silence it usex.astype(y.dtype)
.jax.xla_computation
is deprecated and will be removed in a future release.Please use the AOT APIs to get the same functionality as
jax.xla_computation
.jax.xla_computation(fn)(*args, **kwargs)
can be replaced withjax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
..out_info
property ofjax.stages.Lowered
to get theoutput information (like tree structure, shape and dtype).
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
withjax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
.v0.4.29
Compare Source
Changes
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g.
pip install jax[cuda12]
).jax.experimental.export
API. It is not possible anymore to usefrom jax.experimental.export import export
, and instead you should usefrom jax.experimental import export
.The removed functionality has been deprecated since 0.4.24.
is_leaf
argument to {func}jax.tree.all
& {func}jax.tree_util.tree_all
.Deprecations
jax.sharding.XLACompatibleSharding
is deprecated. Please usejax.sharding.Sharding
.jax.experimental.Exported.in_shardings
has been renamed asjax.experimental.Exported.in_shardings_hlo
. Same forout_shardings
.The old names will be removed after 3 months.
jax.core
:non_negative_dim
,DimSize
,Shape
jax.lax
:tie_in
jax.nn
:normalize
jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,XlaOp
.tol
argument of {func}jax.numpy.linalg.matrix_rank
is beingdeprecated and will soon be removed. Use
rtol
instead.rcond
argument of {func}jax.numpy.linalg.pinv
is beingdeprecated and will soon be removed. Use
rtol
instead.jax.config
submodule has been removed. To configure JAXuse
import jax
and then reference the config object viajax.config
.jax.random
APIs no longer accept batched keys, where previouslysome did unintentionally. Going forward, we recommend explicit use of
{func}
jax.vmap
in such cases.jax.scipy.special.beta
, thex
andy
parameters have beenrenamed to
a
andb
for consistency with otherbeta
APIs.New Functionality
jax.experimental.Exported.in_shardings_jax
to constructshardings that can be used with the JAX APIs from the HloShardings
that are stored in the
Exported
objects.v0.4.28
Compare Source
Bug fixes
make_jaxpr
that was breaking Equinox (#21116).Deprecations & removals
kind
argument to {func}jax.numpy.sort
and {func}jax.numpy.argsort
is now removed. Use
stable=True
orstable=False
instead.get_compute_capability
from thejax.experimental.pallas.gpu
module. Use the
compute_capability
attribute of a GPU device, returnedby {func}
jax.devices
or {func}jax.local_devices
, instead.newshape
argument to {func}jax.numpy.reshape
is being deprecatedand will soon be removed. Use
shape
instead.Changes
v0.4.27
Compare Source
New Functionality
jax.numpy.unstack
and {func}jax.numpy.cumulative_sum
,following their addition in the array API 2023 standard, soon to be
adopted by NumPy.
jax_cpu_collectives_implementation
to select theimplementation of cross-process collective operations used by the CPU backend.
Choices available are
'none'
(default),'gloo'
and'mpi'
(requires jaxlib 0.4.26).If set to
'none'
, cross-process collective operations are disabled.Changes
jax.pure_callback
, {func}jax.experimental.io_callback
and {func}
jax.debug.callback
now use {class}jax.Array
insteadof {class}
np.ndarray
. You can recover the old behavior by transformingthe arguments via
jax.tree.map(np.asarray, args)
before passing themto the callback.
complex_arr.astype(bool)
now follows the same semantics as NumPy, returningFalse where
complex_arr
is equal to0 + 0j
, and True otherwise.core.Token
now is a non-trivial class which wraps ajax.Array
. It couldbe created and threaded in and out of computations to build up dependency.
The singleton object
core.token
has been removed, users now should createand use fresh
core.Token
objects instead.by default. This choice can improve runtime memory usage at a compile-time
cost. Prior behavior, which produces a kernel call, can be recovered with
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
. If the newdefault causes issues, please file a bug. Otherwise, we intend to remove
this flag in a future release.
Deprecations & Removals
lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA
environment variable no longer has any effect.jax.numpy.clip
has a new argument signature:a
,a_min
, anda_max
are deprecated in favor ofx
(positional only),min
, andmax
({jax-issue}20550
).device()
method of JAX arrays has been removed, after being deprecatedsince JAX v0.4.21. Use
arr.devices()
instead.initial
argument to {func}jax.nn.softmax
and {func}jax.nn.log_softmax
is deprecated; empty inputs to softmax are now supported without setting this.
jax.jit
, passing invalidstatic_argnums
orstatic_argnames
now leads to an error rather than a warning.
jax.numpy.hypot
function now issues a deprecation warning whenpassing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
jax.numpy.nonzero
, {func}jax.numpy.where
, andrelated functions now raise an error, following a similar change in NumPy.
jax_cpu_enable_gloo_collectives
is deprecated.Use
jax.config.update('jax_cpu_collectives_implementation', 'gloo')
instead.jax.Array.device_buffer
andjax.Array.device_buffers
methods havebeen removed after being deprecated in JAX v0.4.22. Instead use
{attr}
jax.Array.addressable_shards
and {meth}jax.Array.addressable_data
.condition
,x
, andy
parameters ofjax.numpy.where
are nowpositional-only, following deprecation of the keywords in JAX v0.4.21.
jax.lax.linalg
now must bespecified by keyword. Previously, this raised a DeprecationWarning.
jax.numpy
APIs,including {func}
~jax.numpy.apply_along_axis
,{func}
~jax.numpy.apply_over_axes
, {func}~jax.numpy.inner
,{func}
~jax.numpy.outer
, {func}~jax.numpy.cross
,{func}
~jax.numpy.kron
, and {func}~jax.numpy.lexsort
.Bug fixes
jax.numpy.astype
will now always return a copy whencopy=True
.Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
The default value is set to
copy=False
to preserve backwards compatibility.v0.4.26
Compare Source
New Functionality
jax.numpy.trapezoid
, following the addition of this function inNumPy 2.0.
Changes
jax.numpy.geomspace
now chooses the logarithmic spiralbranch consistent with that of NumPy 2.0.
lax.rng_bit_generator
, and in turn the'rbg'
and
'unsafe_rbg'
PRNG implementations, underjax.vmap
haschanged so that
mapping over keys results in random generation only from the first
key in the batch.
jax.random.key
for construction of PRNG key arraysrather than
jax.random.PRNGKey
.Deprecations & Removals
jax.tree_map
is deprecated; usejax.tree.map
instead, or for backwardcompatibility with older JAX versions, use {func}
jax.tree_util.tree_map
.jax.clear_backends
is deprecated as it does not necessarily do whatits name suggests and can lead to unexpected consequences, e.g., it will not
destroy existing backends and release corresponding owned resources. Use
{func}
jax.clear_caches
if you only want to clean up compilation caches.For backward compatibility or you really need to switch/reinitialize the
default backend, use {func}
jax.extend.backend.clear_backends
.jax.experimental.maps
module andjax.experimental.maps.xmap
aredeprecated. Use
jax.experimental.shard_map
orjax.vmap
with thespmd_axis_name
argument for expressing SPMD device-parallel computations.jax.experimental.host_callback
module is deprecated.Use instead the new JAX external callbacks.
Added
JAX_HOST_CALLBACK_LEGACY
flag to assist in the transition to thenew callbacks. See {jax-issue}
#20385
for a discussion.jax.numpy.array_equal
and {func}jax.numpy.array_equiv
that cannot be converted to a JAX array now results in an exception.
jax_parallel_functions_output_gda
has been removed.This flag was long deprecated and did nothing; its use was a no-op.
jax.interpreters.ad.config
andjax.interpreters.ad.source_info_util
have now been removed. Usejax.config
and
jax.extend.source_info_util
instead.has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
See a description of the versions.
This change could break clients that set a specific
JAX serialization version lower than 9.
v0.4.25
Compare Source
New Features
Interface
import support (requires jaxlib 0.4.24).
x[True]
orx[False]
.jax.tree
module, with a more convenient interface for referencing functionsin {mod}
jax.tree_util
.jax.tree.transpose
(i.e. {func}jax.tree_util.tree_transpose
) now acceptsinner_treedef=None
, in which case the inner treedef will be automatically inferred.Changes
kernels. You can revert to the old behavior by setting the
JAX_TRITON_COMPILE_VIA_XLA
environment variable to"0"
.jax.interpreters.xla
that were removed in v0.4.24have been re-added in v0.4.25, including
backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
, andXLAOp
. These are still considered deprecated, andwill be removed again in the future when better replacements are available.
Refer to {jax-issue}
#19816
for discussion.Deprecations & Removals
jax.numpy.linalg.solve
now shows a deprecation warning for batched 1Dsolves with
b.ndim > 1
. In the future these will be treated as batched 2Dsolves.
of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
following a standard 3 months deprecation cycle (see {ref}
api-compatibility
).These include
jax.config.config
object anddefine_*_state
andDEFINE_*
methods of {data}jax.config
.jax.config
submodule viaimport jax.config
is deprecated.To configure JAX use
import jax
and then reference the config objectvia
jax.config
.v0.4.24
Compare Source
Changes
If your primitive wraps custom_partitioning or JAX callbacks in the lowering
rule i.e. function passed to
rule
parameter ofmlir.register_lowering
then add yourprimitive to
jax._src.dispatch.prim_requires_devices_during_lowering
set.This is needed because custom_partitioning and JAX callbacks need physical
devices to create
Sharding
s during lowering.This is a temporary state until we can create
Sharding
s without physicaldevices.
jax.numpy.argsort
and {func}jax.numpy.sort
now support thestable
and
descending
arguments.{mod}
jax.experimental.jax2tf
and {mod}jax.experimental.export
):#19227
)This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
#19235
) we nowconsider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
cannot interact, e.g., in arithmetic operations.
Scopes are introduced by {func}
jax.experimental.jax2tf.convert
,{func}
jax.experimental.export.symbolic_shape
, {func}jax.experimental.export.symbolic_args_specs
.The scope of a symbolic expression
e
can be read withe.scope
and passedinto the above functions to direct them to construct symbolic expressions in
a given scope.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
to be equal if the normalized form of their difference reduces to 0
({jax-issue}
#19231
; note that this may result in user-visibleConfiguration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.