@@ -39,30 +39,23 @@ def encode_unsigned_variant_encoding(number: int) -> bytes:
39
39
return number .to_bytes (1 , "little" )
40
40
41
41
42
- def encode_tensor_info ( dtype : str , shape : Tuple [int , ...], offset : Tuple [int , int ]) -> List [ bytes ] :
43
- """Encodes the struct TensorInfo into byte buffer"""
42
+ def encode_header ( id : str , dtype : str , shape : Tuple [int , ...], offset : Tuple [int , int ]) -> bytes :
43
+ """Encodes the struct TensorInfo into byte buffer with string ID prefix. """
44
44
if dtype not in _DTYPE :
45
45
raise ValueError (f"Unsupported dtype: { dtype } " )
46
46
47
- # flatten out the tensor info
48
- layout = chain ([_DTYPE [dtype ], len (shape )], shape , offset )
49
- return b"" .join (list (map (encode_unsigned_variant_encoding , layout )))
47
+ encoded_id = encode_unsigned_variant_encoding (len (id )) + id .encode ("utf-8" )
50
48
51
-
52
- def encode_hash_map (index_map : Dict [str , int ]) -> List [bytes ]:
53
- """Encodes a dictionary of string keys and integer values."""
54
- length = encode_unsigned_variant_encoding (len (index_map ))
55
-
56
- hash_map_layout = chain .from_iterable (
57
- (
58
- encode_unsigned_variant_encoding (len (k )),
59
- k .encode ("utf-8" ),
60
- encode_unsigned_variant_encoding (v ),
61
- )
62
- for k , v in index_map .items ()
49
+ # Compose numeric fields
50
+ numeric_layout = chain (
51
+ [_DTYPE [dtype ], len (shape )],
52
+ shape ,
53
+ offset
63
54
)
64
55
65
- return b"" .join (chain ([length ], hash_map_layout ))
56
+ encoded_tensor_info = b"" .join (encode_unsigned_variant_encoding (x ) for x in numeric_layout )
57
+
58
+ return encoded_id + encoded_tensor_info
66
59
67
60
68
61
def test_empty_file ():
@@ -74,7 +67,7 @@ def test_empty_file():
74
67
# header size + metadata + empty tensors
75
68
MAX_FILE_SIZE = 8 + header_size
76
69
assert header_size == 8 , "expected packed buffer shoudl be unsinged interger 8."
77
- assert buffer [8 :] == b"\x00 \x00 \x00 " , "expected empty metadata fields."
70
+ assert buffer [8 :] == b"\x00 \x00 " , "expected empty metadata fields."
78
71
assert MAX_FILE_SIZE == len (buffer ), "These should be equal"
79
72
80
73
@@ -87,35 +80,27 @@ def test_man_cmp():
87
80
88
81
# Create tensor info buffer
89
82
tensor_info_buffer = b"" .join (
90
- encode_tensor_info (
83
+ encode_header (
84
+ f"weight_{ i } " ,
91
85
"F32" ,
92
86
shape ,
93
87
(i * tensor_chunk_length , i * tensor_chunk_length + tensor_chunk_length ),
94
88
)
95
89
for i in range (size )
96
90
)
97
- layout_tensor_info = length + tensor_info_buffer
98
-
99
- expected = []
100
- for (start , end , step ) in [(0 , size , 1 ), (size - 1 , - 1 , - 1 )]:
101
- # Create hash map layout
102
- hash_map_layout = encode_hash_map ({f"weight_{ i } " : i for i in range (start , end , step )})
103
-
104
- # Construct full layout
105
- layout = b"\0 " + layout_tensor_info + hash_map_layout
106
- layout += b" " * (((8 - len (layout )) % 8 ) % 8 )
107
- n = len (layout )
108
- n_header = n .to_bytes (8 , "little" )
109
-
110
- # layout together
111
- buffer = n_header + layout + b"\0 " * (tensor_chunk_length * 2 )
112
- expected .append (buffer )
91
+ layout = length + tensor_info_buffer
92
+ layout = b"\0 " + layout
93
+ layout += b" " * (((8 - len (layout )) % 8 ) % 8 )
94
+ n = len (layout )
95
+ n_header = n .to_bytes (8 , "little" )
96
+
97
+ expected = n_header + layout + (b"\0 " * tensor_chunk_length * size )
113
98
114
99
tensor_dict = {"weight_0" : torch .zeros (shape ), "weight_1" : torch .zeros (shape )}
115
100
116
101
buffer = save (tensor_dict )
117
102
# we need to check both since there is no order in the hashmap
118
- assert buffer in expected , f"got { buffer } , and expected { expected } "
103
+ assert buffer == expected , f"got { buffer } , and expected { expected } "
119
104
120
105
121
106
def test_missmatch_length_of_metadata_large ():
@@ -127,28 +112,22 @@ def test_missmatch_length_of_metadata_large():
127
112
128
113
# Create tensor info buffer
129
114
tensor_info_buffer = b"" .join (
130
- encode_tensor_info (
115
+ encode_header (
116
+ f"weight_{ i } " ,
131
117
"F32" ,
132
118
shape ,
133
119
(i * tensor_chunk_length , i * tensor_chunk_length + tensor_chunk_length ),
134
120
)
135
121
for i in range (size )
136
122
)
137
- layout_tensor_info = length + tensor_info_buffer
138
-
139
- expected = [0 ] * 2
140
-
141
- # Create hash map layout
142
- hash_map_layout = encode_hash_map ({f"weight_{ i } " : i for i in range (0 , 2 , 1 )})
143
-
144
- # Construct full layout
145
- layout = b"\0 " + layout_tensor_info + hash_map_layout
123
+ layout = length + tensor_info_buffer
124
+ layout = b"\0 " + layout
146
125
layout += b" " * (((8 - len (layout )) % 8 ) % 8 )
147
126
n = len (layout )
148
127
n_header = n .to_bytes (8 , "little" )
149
-
128
+
150
129
# layout together
151
- buffer = n_header + layout + b"\0 " * (tensor_chunk_length * 2 )
130
+ buffer = n_header + layout + b"\0 " * (tensor_chunk_length * size )
152
131
153
132
with pytest .raises (BintensorError ):
154
133
# this is not a valid since the metadata
@@ -165,70 +144,25 @@ def test_missmatch_length_of_metadata_small():
165
144
166
145
# Create tensor info buffer
167
146
tensor_info_buffer = b"" .join (
168
- encode_tensor_info (
147
+ encode_header (
148
+ f"weight_{ i } " ,
169
149
"F32" ,
170
150
shape ,
171
151
(i * tensor_chunk_length , i * tensor_chunk_length + tensor_chunk_length ),
172
152
)
173
153
for i in range (size )
174
154
)
175
- layout_tensor_info = length + tensor_info_buffer
176
-
177
- # Create hash map layout
178
- hash_map_layout = encode_hash_map ({f"weight_{ i } " : i for i in range (0 , 2 , 1 )})
179
-
180
- # Construct full layout
181
- layout = b"\0 " + layout_tensor_info + hash_map_layout
155
+ layout = length + tensor_info_buffer
156
+ layout = b"\0 " + layout
182
157
layout += b" " * (((8 - len (layout )) % 8 ) % 8 )
183
158
n = len (layout )
184
159
n_header = n .to_bytes (8 , "little" )
185
160
186
161
# layout together
187
- buffer = n_header + layout + b"\0 " * (tensor_chunk_length * 2 )
162
+ buffer = n_header + layout + b"\0 " * (tensor_chunk_length * size )
188
163
189
164
with pytest .raises (BintensorError ):
190
165
# this is not a valid since the metadata
191
166
# size doe not match as it too big
192
167
_ = load (buffer )
193
168
194
-
195
- def test_missmatch_length_of_metadata ():
196
- size = 2
197
- shape = (2 , 2 )
198
- tensor_chunk_length = shape [0 ] * shape [1 ] * 4 # Size of a tensor buffer
199
-
200
- # convert usize or unsigned long long into variant encoding
201
- length = encode_unsigned_variant_encoding (size * 1000 )
202
-
203
- # Create tensor info byte buffer
204
- tensor_info_buffer = b"" .join (
205
- encode_tensor_info (
206
- "F32" ,
207
- shape ,
208
- (i * tensor_chunk_length , i * tensor_chunk_length + tensor_chunk_length ),
209
- )
210
- for i in range (size )
211
- )
212
- layout_tensor_info = length + tensor_info_buffer
213
-
214
- # Create hash map layout
215
- hash_map_layout = encode_hash_map ({f"weight_{ i } " : i for i in range (0 , 2 , 1 )})
216
-
217
- # Construct full layout
218
- # metadata empty + tensor_info + hash_map_index_map
219
- layout = b"\0 " + layout_tensor_info + hash_map_layout
220
-
221
- # empty padding
222
- layout += b" " * (((8 - len (layout )) % 8 ) % 8 )
223
- n = len (layout )
224
-
225
- # size of full header (metadata + tensors info + index map)
226
- n_header = n .to_bytes (8 , "little" )
227
-
228
- # layout together into buffer
229
- buffer = n_header + layout + b"\0 " * (tensor_chunk_length * 2 )
230
-
231
- with pytest .raises (BintensorError ):
232
- # this is not a valid since the metadata
233
- # size doe not match as it too big
234
- _ = load (buffer )
0 commit comments