Skip to content

Commit a85ddd5

Browse files
committed
feat(core): support for Flax (JAX) and MLX backends
1 parent 67e3f99 commit a85ddd5

File tree

5 files changed

+727
-0
lines changed

5 files changed

+727
-0
lines changed

binding/python/py/bintensors/flax.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import os
2+
import hashlib
3+
from _hashlib import HASH
4+
from typing import Dict, Optional, Union, Callable, Tuple
5+
6+
import numpy as np
7+
8+
import jax.numpy as jnp
9+
from jax import Array
10+
from bintensors import numpy, safe_open
11+
12+
13+
def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
    """
    Serialize a dictionary of JAX arrays into raw bytes in bintensors format.

    The arrays are converted to NumPy arrays and the actual serialization is
    delegated to `bintensors.numpy.save`.

    Args:
        tensors (`Dict[str, Array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.

    Returns:
        `bytes`: The raw bytes representing the format.

    Example:

    ```python
    from bintensors.flax import save
    from jax import numpy as jnp

    tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
    byte_data = save(tensors)
    ```
    """
    return numpy.save(_jnp2np(tensors), metadata=metadata)
40+
41+
42+
def save_file(
    tensors: Dict[str, Array],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    """
    Write a dictionary of JAX arrays to a file in bintensors format.

    The arrays are converted to NumPy arrays and the write itself is
    delegated to `bintensors.numpy.save_file`.

    Args:
        tensors (`Dict[str, Array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        filename (`str` or `os.PathLike`):
            The filename we're saving into.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.

    Returns:
        `None`

    Example:

    ```python
    from bintensors.flax import save_file
    from jax import numpy as jnp

    tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
    save_file(tensors, "model.bintensors")
    ```
    """
    return numpy.save_file(_jnp2np(tensors), filename, metadata=metadata)
75+
76+
77+
def load(data: bytes) -> Dict[str, Array]:
    """
    Deserialize bintensors-formatted bytes into a dictionary of JAX arrays.

    The bytes are decoded by `bintensors.numpy.load` and every resulting
    NumPy array is then converted to a JAX array.

    Args:
        data (`bytes`):
            The content of a bintensors file.

    Returns:
        `Dict[str, Array]`: dictionary that contains name as key, value as `Array`.

    Example:

    ```python
    from bintensors.flax import load

    file_path = "./my_folder/bert.bintensors"
    with open(file_path, "rb") as f:
        data = f.read()

    loaded = load(data)
    ```
    """
    return _np2jnp(numpy.load(data))
102+
103+
104+
def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
    """
    Load every tensor stored in a bintensors file.

    The file is opened through `safe_open` with `framework="flax"` and each
    key reported by `offset_keys()` is fetched with `get_tensor`.

    Args:
        filename (`str` or `os.PathLike`):
            The name of the file which contains the tensors.

    Returns:
        `Dict[str, Array]`: dictionary that contains name as key, value as `Array`.

    Example:

    ```python
    from bintensors.flax import load_file

    file_path = "./my_folder/bert.bt"
    loaded = load_file(file_path)
    ```
    """
    with safe_open(filename, framework="flax") as handle:
        return {name: handle.get_tensor(name) for name in handle.offset_keys()}
129+
130+
def save_with_checksum(
    tensor_dict: Dict[str, Array],
    metadata: Optional[Dict[str, str]] = None,
    hasher: Callable[[bytes], HASH] = hashlib.sha1,
) -> Tuple[bytes, bytes]:
    """
    Serialize a dictionary of JAX arrays to bintensors bytes together with a checksum.

    The arrays are converted to NumPy arrays and the work is delegated to
    `bintensors.numpy.save_with_checksum`, whose result is returned unchanged.

    Args:
        tensor_dict (`Dict[str, Array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.
        hasher (`Callable[[bytes], HASH]`, *optional*, defaults to `hashlib.sha1`):
            Hash constructor used to calculate the checksum.

    Returns:
        `Tuple[bytes, bytes]`: the pair produced by the NumPy backend; the
        example below unpacks it as `(checksum, byte_data)`.

    Example:

    ```python
    from bintensors.flax import save_with_checksum
    import jax.numpy as jnp

    tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
    checksum, byte_data = save_with_checksum(tensors)
    ```
    """
    return numpy.save_with_checksum(_jnp2np(tensor_dict), metadata, hasher)
164+
165+
166+
167+
def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
    """
    Perform conversion from numpy storage backend to jax storage backend.

    A new dictionary is returned; `numpy_dict` itself is left untouched so
    the caller's mapping keeps holding NumPy arrays.

    Args:
        numpy_dict (`Dict[str, np.ndarray]`):
            The incoming tensors. Tensors need to be contiguous and dense.

    Returns:
        `Dict[str, Array]`: dictionary that contains name as key, value as `Array`
    """
    # Build a fresh mapping instead of overwriting the argument's values in
    # place, which would silently change the element type of the input dict.
    return {name: jnp.array(value) for name, value in numpy_dict.items()}
181+
182+
183+
def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.ndarray]:
    """
    Perform conversion from jax storage backend to numpy storage backend.

    A new dictionary is returned; `jnp_dict` itself is left untouched. This
    matters because the public `save*` functions pass the caller's own
    dictionary here, and it must still hold JAX arrays afterwards.

    Args:
        jnp_dict (`Dict[str, Array]`):
            The incoming tensors. Tensors need to be contiguous and dense.

    Returns:
        `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray`
    """
    # Never write back into the argument: doing so would replace the user's
    # JAX arrays with NumPy arrays as a side effect of saving.
    return {name: np.asarray(value) for name, value in jnp_dict.items()}

binding/python/py/bintensors/mlx.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import os
2+
import hashlib
3+
from _hashlib import HASH
4+
from typing import Dict, Optional, Union, Tuple, Callable
5+
6+
import numpy as np
7+
8+
import mlx.core as mx
9+
from bintensors import numpy, safe_open
10+
11+
# Public API of the MLX backend; the underscore-prefixed conversion helpers
# defined in this module are intentionally left out.
__all__ = ["save", "save_file", "load", "load_file", "save_with_checksum"]
12+
13+
def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes:
    """
    Serialize a dictionary of MLX arrays into raw bytes in bintensors format.

    The arrays are converted to NumPy arrays and the actual serialization is
    delegated to `bintensors.numpy.save`.

    Args:
        tensors (`Dict[str, mx.array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.

    Returns:
        `bytes`: The raw bytes representing the format.

    Example:

    ```python
    from bintensors.mlx import save
    import mlx.core as mx

    tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
    byte_data = save(tensors)
    ```
    """
    return numpy.save(_mx2np(tensors), metadata=metadata)
40+
41+
42+
def save_file(
    tensors: Dict[str, mx.array],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    """
    Write a dictionary of MLX arrays to a file in bintensors format.

    The arrays are converted to NumPy arrays and the write itself is
    delegated to `bintensors.numpy.save_file`.

    Args:
        tensors (`Dict[str, mx.array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        filename (`str` or `os.PathLike`):
            The filename we're saving into.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.

    Returns:
        `None`

    Example:

    ```python
    from bintensors.mlx import save_file
    import mlx.core as mx

    tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
    save_file(tensors, "model.bt")
    ```
    """
    return numpy.save_file(_mx2np(tensors), filename, metadata=metadata)
75+
76+
77+
def save_with_checksum(
    tensor_dict: Dict[str, mx.array],
    metadata: Optional[Dict[str, str]] = None,
    hasher: Callable[[bytes], HASH] = hashlib.sha1,
) -> Tuple[bytes, bytes]:
    """
    Serialize a dictionary of MLX arrays to bintensors bytes together with a checksum.

    The arrays are converted to NumPy arrays and the work is delegated to
    `bintensors.numpy.save_with_checksum`, whose result is returned unchanged.

    Args:
        tensor_dict (`Dict[str, mx.array]`):
            The incoming tensors, keyed by name. Tensors need to be
            contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata to store in the header, e.g. extra
            information about the underlying tensors. It is purely
            informative and does not affect tensor loading.
        hasher (`Callable[[bytes], HASH]`, *optional*, defaults to `hashlib.sha1`):
            Hash constructor used to calculate the checksum.

    Returns:
        `Tuple[bytes, bytes]`: the pair produced by the NumPy backend; the
        example below unpacks it as `(checksum, byte_data)`.

    Example:

    ```python
    from bintensors.mlx import save_with_checksum
    import mlx.core as mx

    tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
    checksum, byte_data = save_with_checksum(tensors)
    ```
    """
    # The input holds MLX arrays (see the annotation); convert them before
    # handing over to the NumPy backend.
    np_tensors = _mx2np(tensor_dict)
    return numpy.save_with_checksum(np_tensors, metadata, hasher)
111+
112+
113+
def load(data: bytes) -> Dict[str, mx.array]:
    """
    Deserialize bintensors-formatted bytes into a dictionary of MLX arrays.

    The bytes are decoded by `bintensors.numpy.load` and every resulting
    NumPy array is then converted to an `mx.array`.

    Args:
        data (`bytes`):
            The content of a bintensors file.

    Returns:
        `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`

    Example:

    ```python
    from bintensors.mlx import load

    file_path = "./my_folder/bert.bt"
    with open(file_path, "rb") as f:
        data = f.read()

    loaded = load(data)
    ```
    """
    return _np2mx(numpy.load(data))
138+
139+
140+
def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]:
    """
    Load every tensor stored in a bintensors file into MLX format.

    The file is opened through `safe_open` with `framework="mlx"` and each
    key reported by `offset_keys()` is fetched with `get_tensor`.

    Args:
        filename (`str` or `os.PathLike`):
            The name of the file which contains the tensors.

    Returns:
        `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`

    Example:

    ```python
    from bintensors.mlx import load_file

    file_path = "./my_folder/bert.bt"
    loaded = load_file(file_path)
    ```
    """
    with safe_open(filename, framework="mlx") as handle:
        return {name: handle.get_tensor(name) for name in handle.offset_keys()}
165+
166+
167+
def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]:
    """
    Perform conversion from numpy storage backend to mlx storage backend.

    A new dictionary is returned; `numpy_dict` itself is left untouched,
    matching the behaviour of the reverse helper `_mx2np`.

    Args:
        numpy_dict (`Dict[str, np.ndarray]`):
            The incoming tensors. Tensors need to be contiguous and dense.

    Returns:
        `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
    """
    # Build a fresh mapping instead of overwriting the argument's values in
    # place, which would silently change the element type of the input dict.
    return {name: mx.array(value) for name, value in numpy_dict.items()}
181+
182+
183+
def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.ndarray]:
    """
    Perform conversion from mlx storage backend to numpy storage backend.

    The result is a newly built dictionary; `mx_dict` is not modified.

    Args:
        mx_dict (`Dict[str, mx.array]`):
            The incoming tensors. Tensors need to be contiguous and dense.

    Returns:
        `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray`
    """
    return {name: np.asarray(value) for name, value in mx_dict.items()}

0 commit comments

Comments
 (0)