import pytest
import struct

import torch
from typing import Dict, Tuple
from itertools import chain

from bintensors import BintensorError
from bintensors.torch import save, load

# dtype name -> discriminant used in the serialized tensor info
_DTYPE = {
    "BOOL": 0,
    "U8": 1,
    "I8": 2,
    "F8_E5M2": 3,
    "F8_E4M3": 4,
    "I16": 5,
    "U16": 6,
    "F16": 7,
    "BF16": 8,
    "I32": 9,
    "U32": 10,
    "F32": 11,
    "F64": 12,
    "I64": 13,
    "U64": 14,
}


def encode_unsigned_variant_encoding(number: int) -> bytes:
    """Encode an unsigned integer in the variable-length format used by the
    header: a single byte below 0xFB, otherwise a marker byte (0xFB/0xFC/0xFD)
    followed by a little-endian u16/u32/u64."""
    if number > 0xFFFFFFFF:
        return b"\xfd" + number.to_bytes(8, "little")
    elif number > 0xFFFF:
        return b"\xfc" + number.to_bytes(4, "little")
    elif number > 0xFA:
        return b"\xfb" + number.to_bytes(2, "little")
    else:
        return number.to_bytes(1, "little")
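

def test_varint_encoding_examples():
    # Minimal sanity sketch of the encoder's branches; the expected byte
    # strings below are worked out from the thresholds above, not taken
    # from an external spec.
    assert encode_unsigned_variant_encoding(250) == b"\xfa"
    assert encode_unsigned_variant_encoding(251) == b"\xfb\xfb\x00"
    assert encode_unsigned_variant_encoding(0x10000) == b"\xfc\x00\x00\x01\x00"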


def encode_tensor_info(dtype: str, shape: Tuple[int, ...], offset: Tuple[int, int]) -> bytes:
    """Encode a TensorInfo struct into a byte buffer."""
    if dtype not in _DTYPE:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # flatten the tensor info: dtype discriminant, rank, shape, then offsets
    layout = chain([_DTYPE[dtype], len(shape)], shape, offset)
    return b"".join(map(encode_unsigned_variant_encoding, layout))
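

def test_tensor_info_encoding_example():
    # Worked example of the layout above for one F32 tensor of shape (2, 2)
    # occupying bytes 0..16: discriminant 11, rank 2, the two dims, then the
    # (start, end) offsets, each varint-encoded.
    assert encode_tensor_info("F32", (2, 2), (0, 16)) == b"\x0b\x02\x02\x02\x00\x10"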


def encode_hash_map(index_map: Dict[str, int]) -> bytes:
    """Encode a dictionary of string keys and integer values."""
    length = encode_unsigned_variant_encoding(len(index_map))

    hash_map_layout = chain.from_iterable(
        (
            encode_unsigned_variant_encoding(len(k)),
            k.encode("utf-8"),
            encode_unsigned_variant_encoding(v),
        )
        for k, v in index_map.items()
    )

    return b"".join(chain([length], hash_map_layout))
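

def test_hash_map_encoding_example():
    # Worked example of the map encoding above: entry count first, then
    # (key length, utf-8 key bytes, value) per entry.
    assert encode_hash_map({}) == b"\x00"
    assert encode_hash_map({"weight_0": 0}) == b"\x01\x08weight_0\x00"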


def test_empty_file():
    """bintensors allows an empty tensor dictionary."""
    tensor_dict = {}
    buffer = save(tensor_dict)
    # unpack the first 8 bytes of the buffer as a little-endian unsigned 64-bit integer
    header_size = struct.unpack("<Q", buffer[0:8])[0]
    # header-size field + header itself; an empty file carries no tensor data
    MAX_FILE_SIZE = 8 + header_size
    assert header_size == 8, "expected an 8-byte header for an empty file."
    # metadata option (None), zero tensors, zero index entries, space-padded to 8 bytes
    assert buffer[8:] == b"\x00\x00\x00     ", "expected empty metadata fields."
    assert MAX_FILE_SIZE == len(buffer), "these should be equal"
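

def test_empty_file_layout_example():
    # Rebuilds the expected empty file from the helpers above, mirroring the
    # manual construction used in test_man_cmp below: metadata tag, varint
    # tensor count, index map, space padding, all prefixed by the 8-byte
    # header size.
    layout = b"\0" + encode_unsigned_variant_encoding(0) + encode_hash_map({})
    layout += b" " * ((8 - len(layout)) % 8)
    expected = len(layout).to_bytes(8, "little") + layout
    assert save({}) == expected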


def test_man_cmp():
    """Manually construct the serialized layout and compare it against save()."""
    size = 2
    shape = (2, 2)
    tensor_chunk_length = shape[0] * shape[1] * 4  # byte size of one F32 tensor buffer

    length = encode_unsigned_variant_encoding(size)

    # Create tensor info buffer
    tensor_info_buffer = b"".join(
        encode_tensor_info(
            "F32",
            shape,
            (i * tensor_chunk_length, i * tensor_chunk_length + tensor_chunk_length),
        )
        for i in range(size)
    )
    layout_tensor_info = length + tensor_info_buffer

    expected = []
    for (start, end, step) in [(0, size, 1), (size - 1, -1, -1)]:
        # Create hash map layout
        hash_map_layout = encode_hash_map({f"weight_{i}": i for i in range(start, end, step)})

        # Construct full layout: empty metadata + tensor infos + index map
        layout = b"\0" + layout_tensor_info + hash_map_layout
        layout += b" " * ((8 - len(layout)) % 8)  # space padding to an 8-byte boundary
        n = len(layout)
        n_header = n.to_bytes(8, "little")

        # assemble header-size field, header, and zeroed tensor data
        buffer = n_header + layout + b"\0" * (tensor_chunk_length * 2)
        expected.append(buffer)

    tensor_dict = {"weight_0": torch.zeros(shape), "weight_1": torch.zeros(shape)}

    buffer = save(tensor_dict)
    # check both orderings, since the index map's iteration order is not guaranteed
    assert buffer in expected, f"got {buffer}, and expected {expected}"
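

def test_save_load_roundtrip():
    # Minimal round-trip sketch, assuming `load` is the inverse of `save`
    # (both imported from bintensors.torch at the top of this file) and
    # returns a dict of tensors keyed by name.
    tensors = {"weight_0": torch.zeros((2, 2)), "weight_1": torch.ones((2, 2))}
    reloaded = load(save(tensors))
    assert all(torch.equal(reloaded[k], tensors[k]) for k in tensors)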


def test_mismatch_length_of_metadata_large():
    size = 2
    shape = (2, 2)
    tensor_chunk_length = shape[0] * shape[1] * 4  # byte size of one F32 tensor buffer

    # declare far more tensors in the header than are actually encoded
    length = encode_unsigned_variant_encoding(size * 1000)

    # Create tensor info buffer
    tensor_info_buffer = b"".join(
        encode_tensor_info(
            "F32",
            shape,
            (i * tensor_chunk_length, i * tensor_chunk_length + tensor_chunk_length),
        )
        for i in range(size)
    )
    layout_tensor_info = length + tensor_info_buffer

    # Create hash map layout
    hash_map_layout = encode_hash_map({f"weight_{i}": i for i in range(size)})

    # Construct full layout
    layout = b"\0" + layout_tensor_info + hash_map_layout
    layout += b" " * ((8 - len(layout)) % 8)  # space padding to an 8-byte boundary
    n = len(layout)
    n_header = n.to_bytes(8, "little")

    # assemble header-size field, header, and zeroed tensor data
    buffer = n_header + layout + b"\0" * (tensor_chunk_length * 2)

    with pytest.raises(BintensorError):
        # invalid: the declared tensor count is far larger than the
        # number of tensor-info entries that follow
        _ = load(buffer)


def test_mismatch_length_of_metadata_small():
    size = 2
    shape = (2, 2)
    tensor_chunk_length = shape[0] * shape[1] * 4  # byte size of one F32 tensor buffer

    # declare fewer tensors in the header than are actually encoded
    length = encode_unsigned_variant_encoding(size - 1)

    # Create tensor info buffer
    tensor_info_buffer = b"".join(
        encode_tensor_info(
            "F32",
            shape,
            (i * tensor_chunk_length, i * tensor_chunk_length + tensor_chunk_length),
        )
        for i in range(size)
    )
    layout_tensor_info = length + tensor_info_buffer

    # Create hash map layout
    hash_map_layout = encode_hash_map({f"weight_{i}": i for i in range(size)})

    # Construct full layout
    layout = b"\0" + layout_tensor_info + hash_map_layout
    layout += b" " * ((8 - len(layout)) % 8)  # space padding to an 8-byte boundary
    n = len(layout)
    n_header = n.to_bytes(8, "little")

    # assemble header-size field, header, and zeroed tensor data
    buffer = n_header + layout + b"\0" * (tensor_chunk_length * 2)

    with pytest.raises(BintensorError):
        # invalid: the declared tensor count is smaller than the
        # number of tensor-info entries that follow
        _ = load(buffer)


def test_mismatch_length_of_metadata():
    size = 2
    shape = (2, 2)
    tensor_chunk_length = shape[0] * shape[1] * 4  # byte size of one F32 tensor buffer

    # encode the (deliberately wrong) tensor count as a varint
    length = encode_unsigned_variant_encoding(size * 1000)

    # Create tensor info byte buffer
    tensor_info_buffer = b"".join(
        encode_tensor_info(
            "F32",
            shape,
            (i * tensor_chunk_length, i * tensor_chunk_length + tensor_chunk_length),
        )
        for i in range(size)
    )
    layout_tensor_info = length + tensor_info_buffer

    # Create hash map layout
    hash_map_layout = encode_hash_map({f"weight_{i}": i for i in range(size)})

    # Construct full layout:
    # empty metadata + tensor infos + index map
    layout = b"\0" + layout_tensor_info + hash_map_layout

    # space padding to an 8-byte boundary
    layout += b" " * ((8 - len(layout)) % 8)
    n = len(layout)

    # size of the full header (metadata + tensor infos + index map)
    n_header = n.to_bytes(8, "little")

    # assemble everything into one buffer
    buffer = n_header + layout + b"\0" * (tensor_chunk_length * 2)

    with pytest.raises(BintensorError):
        # invalid: the declared tensor count does not match
        # the number of tensor-info entries that follow
        _ = load(buffer)