Skip to content

Commit 2301712

Browse files
committed
Move size checks to DataDecoder
This further decouples the code. There are probably future improvements for reducing redundancy however.
1 parent 5eec7e6 commit 2301712

File tree

2 files changed

+117
-46
lines changed

2 files changed

+117
-46
lines changed

internal/decoder/decoder.go

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -156,39 +156,66 @@ func (d *DataDecoder) decodeFromTypeToDeserializer(
156156
return newOffset, err
157157
case _Slice:
158158
return d.decodeSliceToDeserializer(size, offset, dser, depth)
159-
}
160-
161-
// For the remaining types, size is the byte size
162-
if offset+size > uint(len(d.buffer)) {
163-
return 0, mmdberrors.NewOffsetError()
164-
}
165-
switch dtype {
166159
case _Bytes:
167-
v, offset := d.decodeBytes(size, offset)
160+
v, offset, err := d.decodeBytes(size, offset)
161+
if err != nil {
162+
return 0, err
163+
}
168164
return offset, dser.Bytes(v)
169165
case _Float32:
170-
v, offset := d.decodeFloat32(size, offset)
166+
v, offset, err := d.decodeFloat32(size, offset)
167+
if err != nil {
168+
return 0, err
169+
}
171170
return offset, dser.Float32(v)
172171
case _Float64:
173-
v, offset := d.decodeFloat64(size, offset)
172+
v, offset, err := d.decodeFloat64(size, offset)
173+
if err != nil {
174+
return 0, err
175+
}
176+
174177
return offset, dser.Float64(v)
175178
case _Int32:
176-
v, offset := d.decodeInt(size, offset)
179+
v, offset, err := d.decodeInt(size, offset)
180+
if err != nil {
181+
return 0, err
182+
}
183+
177184
return offset, dser.Int32(int32(v))
178185
case _String:
179-
v, offset := d.decodeString(size, offset)
186+
v, offset, err := d.decodeString(size, offset)
187+
if err != nil {
188+
return 0, err
189+
}
190+
180191
return offset, dser.String(v)
181192
case _Uint16:
182-
v, offset := d.decodeUint(size, offset)
193+
v, offset, err := d.decodeUint(size, offset)
194+
if err != nil {
195+
return 0, err
196+
}
197+
183198
return offset, dser.Uint16(uint16(v))
184199
case _Uint32:
185-
v, offset := d.decodeUint(size, offset)
200+
v, offset, err := d.decodeUint(size, offset)
201+
if err != nil {
202+
return 0, err
203+
}
204+
186205
return offset, dser.Uint32(uint32(v))
187206
case _Uint64:
188-
v, offset := d.decodeUint(size, offset)
207+
v, offset, err := d.decodeUint(size, offset)
208+
if err != nil {
209+
return 0, err
210+
}
211+
189212
return offset, dser.Uint64(v)
190213
case _Uint128:
191-
v, offset := d.decodeUint128(size, offset)
214+
v, offset, err := d.decodeUint128(size, offset)
215+
if err != nil {
216+
return 0, err
217+
}
218+
192219
return offset, dser.Uint128(v)
193220
default:
194221
return 0, mmdberrors.NewInvalidDatabaseError("unknown type: %d", dtype)
@@ -199,32 +226,48 @@ func decodeBool(size, offset uint) (bool, uint) {
199226
return size != 0, offset
200227
}
201228

202-
func (d *DataDecoder) decodeBytes(size, offset uint) ([]byte, uint) {
229+
func (d *DataDecoder) decodeBytes(size, offset uint) ([]byte, uint, error) {
230+
if offset+size > uint(len(d.buffer)) {
231+
return nil, 0, mmdberrors.NewOffsetError()
232+
}
233+
203234
newOffset := offset + size
204235
bytes := make([]byte, size)
205236
copy(bytes, d.buffer[offset:newOffset])
206-
return bytes, newOffset
237+
return bytes, newOffset, nil
207238
}
208239

209-
func (d *DataDecoder) decodeFloat64(size, offset uint) (float64, uint) {
240+
func (d *DataDecoder) decodeFloat64(size, offset uint) (float64, uint, error) {
241+
if offset+size > uint(len(d.buffer)) {
242+
return 0, 0, mmdberrors.NewOffsetError()
243+
}
244+
210245
newOffset := offset + size
211246
bits := binary.BigEndian.Uint64(d.buffer[offset:newOffset])
212-
return math.Float64frombits(bits), newOffset
247+
return math.Float64frombits(bits), newOffset, nil
213248
}
214249

215-
func (d *DataDecoder) decodeFloat32(size, offset uint) (float32, uint) {
250+
func (d *DataDecoder) decodeFloat32(size, offset uint) (float32, uint, error) {
251+
if offset+size > uint(len(d.buffer)) {
252+
return 0, 0, mmdberrors.NewOffsetError()
253+
}
254+
216255
newOffset := offset + size
217256
bits := binary.BigEndian.Uint32(d.buffer[offset:newOffset])
218-
return math.Float32frombits(bits), newOffset
257+
return math.Float32frombits(bits), newOffset, nil
219258
}
220259

221-
func (d *DataDecoder) decodeInt(size, offset uint) (int, uint) {
260+
func (d *DataDecoder) decodeInt(size, offset uint) (int, uint, error) {
261+
if offset+size > uint(len(d.buffer)) {
262+
return 0, 0, mmdberrors.NewOffsetError()
263+
}
264+
222265
newOffset := offset + size
223266
var val int32
224267
for _, b := range d.buffer[offset:newOffset] {
225268
val = (val << 8) | int32(b)
226269
}
227-
return int(val), newOffset
270+
return int(val), newOffset, nil
228271
}
229272

230273
func (d *DataDecoder) decodeMapToDeserializer(
@@ -314,28 +357,40 @@ func (d *DataDecoder) decodeSliceToDeserializer(
314357
return offset, nil
315358
}
316359

317-
func (d *DataDecoder) decodeString(size, offset uint) (string, uint) {
360+
func (d *DataDecoder) decodeString(size, offset uint) (string, uint, error) {
361+
if offset+size > uint(len(d.buffer)) {
362+
return "", 0, mmdberrors.NewOffsetError()
363+
}
364+
318365
newOffset := offset + size
319-
return string(d.buffer[offset:newOffset]), newOffset
366+
return string(d.buffer[offset:newOffset]), newOffset, nil
320367
}
321368

322-
func (d *DataDecoder) decodeUint(size, offset uint) (uint64, uint) {
369+
func (d *DataDecoder) decodeUint(size, offset uint) (uint64, uint, error) {
370+
if offset+size > uint(len(d.buffer)) {
371+
return 0, 0, mmdberrors.NewOffsetError()
372+
}
373+
323374
newOffset := offset + size
324375
bytes := d.buffer[offset:newOffset]
325376

326377
var val uint64
327378
for _, b := range bytes {
328379
val = (val << 8) | uint64(b)
329380
}
330-
return val, newOffset
381+
return val, newOffset, nil
331382
}
332383

333-
func (d *DataDecoder) decodeUint128(size, offset uint) (*big.Int, uint) {
384+
func (d *DataDecoder) decodeUint128(size, offset uint) (*big.Int, uint, error) {
385+
if offset+size > uint(len(d.buffer)) {
386+
return nil, 0, mmdberrors.NewOffsetError()
387+
}
388+
334389
newOffset := offset + size
335390
val := new(big.Int)
336391
val.SetBytes(d.buffer[offset:newOffset])
337392

338-
return val, newOffset
393+
return val, newOffset, nil
339394
}
340395

341396
func uintFromBytes(prefix uint, uintBytes []byte) uint {

internal/decoder/reflection.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,6 @@ func (d *Decoder) decodeFromType(
166166
return d.unmarshalPointer(size, offset, result, depth)
167167
case _Slice:
168168
return d.unmarshalSlice(size, offset, result, depth)
169-
}
170-
171-
// For the remaining types, size is the byte size
172-
if offset+size > uint(len(d.buffer)) {
173-
return 0, mmdberrors.NewOffsetError()
174-
}
175-
switch dtype {
176169
case _Bytes:
177170
return d.unmarshalBytes(size, offset, result)
178171
case _Float32:
@@ -181,14 +174,14 @@ func (d *Decoder) decodeFromType(
181174
return d.unmarshalFloat64(size, offset, result)
182175
case _Int32:
183176
return d.unmarshalInt32(size, offset, result)
184-
case _String:
185-
return d.unmarshalString(size, offset, result)
186177
case _Uint16:
187178
return d.unmarshalUint(size, offset, result, 16)
188179
case _Uint32:
189180
return d.unmarshalUint(size, offset, result, 32)
190181
case _Uint64:
191182
return d.unmarshalUint(size, offset, result, 64)
183+
case _String:
184+
return d.unmarshalString(size, offset, result)
192185
case _Uint128:
193186
return d.unmarshalUint128(size, offset, result)
194187
default:
@@ -250,7 +243,10 @@ func indirect(result reflect.Value) reflect.Value {
250243
var sliceType = reflect.TypeOf([]byte{})
251244

252245
func (d *Decoder) unmarshalBytes(size, offset uint, result reflect.Value) (uint, error) {
253-
value, newOffset := d.decodeBytes(size, offset)
246+
value, newOffset, err := d.decodeBytes(size, offset)
247+
if err != nil {
248+
return 0, err
249+
}
254250

255251
switch result.Kind() {
256252
case reflect.Slice:
@@ -274,7 +270,10 @@ func (d *Decoder) unmarshalFloat32(size, offset uint, result reflect.Value) (uin
274270
size,
275271
)
276272
}
277-
value, newOffset := d.decodeFloat32(size, offset)
273+
value, newOffset, err := d.decodeFloat32(size, offset)
274+
if err != nil {
275+
return 0, err
276+
}
278277

279278
switch result.Kind() {
280279
case reflect.Float32, reflect.Float64:
@@ -296,7 +295,10 @@ func (d *Decoder) unmarshalFloat64(size, offset uint, result reflect.Value) (uin
296295
size,
297296
)
298297
}
299-
value, newOffset := d.decodeFloat64(size, offset)
298+
value, newOffset, err := d.decodeFloat64(size, offset)
299+
if err != nil {
300+
return 0, err
301+
}
300302

301303
switch result.Kind() {
302304
case reflect.Float32, reflect.Float64:
@@ -321,7 +323,11 @@ func (d *Decoder) unmarshalInt32(size, offset uint, result reflect.Value) (uint,
321323
size,
322324
)
323325
}
324-
value, newOffset := d.decodeInt(size, offset)
326+
327+
value, newOffset, err := d.decodeInt(size, offset)
328+
if err != nil {
329+
return 0, err
330+
}
325331

326332
switch result.Kind() {
327333
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -410,7 +416,10 @@ func (d *Decoder) unmarshalSlice(
410416
}
411417

412418
func (d *Decoder) unmarshalString(size, offset uint, result reflect.Value) (uint, error) {
413-
value, newOffset := d.decodeString(size, offset)
419+
value, newOffset, err := d.decodeString(size, offset)
420+
if err != nil {
421+
return 0, err
422+
}
414423

415424
switch result.Kind() {
416425
case reflect.String:
@@ -438,7 +447,10 @@ func (d *Decoder) unmarshalUint(
438447
)
439448
}
440449

441-
value, newOffset := d.decodeUint(size, offset)
450+
value, newOffset, err := d.decodeUint(size, offset)
451+
if err != nil {
452+
return 0, err
453+
}
442454

443455
switch result.Kind() {
444456
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -475,7 +487,11 @@ func (d *Decoder) unmarshalUint128(size, offset uint, result reflect.Value) (uin
475487
size,
476488
)
477489
}
478-
value, newOffset := d.decodeUint128(size, offset)
490+
491+
value, newOffset, err := d.decodeUint128(size, offset)
492+
if err != nil {
493+
return 0, err
494+
}
479495

480496
switch result.Kind() {
481497
case reflect.Struct:

0 commit comments

Comments
 (0)