@@ -47,18 +47,23 @@ type activeTransfer struct {
47
47
}
48
48
49
49
type activeRequest struct {
50
- id string
51
50
deadline int64
52
51
result chan <- AsyncQueryResult
53
52
}
54
53
54
+ type expectedTransfer struct {
55
+ deadline int64
56
+ maxSize int64
57
+ }
58
+
55
59
type RLDP struct {
56
60
adnl ADNL
57
61
useV2 bool
58
62
59
63
activateRecoverySender chan bool
60
64
activeRequests map [string ]* activeRequest
61
65
activeTransfers map [string ]* activeTransfer
66
+ expectedTransfers map [string ]* expectedTransfer
62
67
63
68
recvStreams map [string ]* decoderStream
64
69
@@ -94,6 +99,7 @@ type decoderStream struct {
94
99
}
95
100
96
101
var DefaultSymbolSize uint32 = 768
102
+ var MaxUnexpectedTransferSize int64 = 1 << 16 // 64 KB
97
103
98
104
const _MTU = 1 << 37
99
105
@@ -103,6 +109,7 @@ func NewClient(a ADNL) *RLDP {
103
109
activeRequests : map [string ]* activeRequest {},
104
110
activeTransfers : map [string ]* activeTransfer {},
105
111
recvStreams : map [string ]* decoderStream {},
112
+ expectedTransfers : map [string ]* expectedTransfer {},
106
113
activateRecoverySender : make (chan bool , 1 ),
107
114
}
108
115
@@ -168,12 +175,22 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error {
168
175
id := string (m .TransferID )
169
176
r .mx .RLock ()
170
177
stream := r .recvStreams [id ]
178
+ expected := r .expectedTransfers [id ]
171
179
r .mx .RUnlock ()
172
180
173
181
if stream == nil {
174
- // TODO: limit unexpected transfer size to 1024 bytes
175
182
if m .TotalSize > _MTU || m .TotalSize <= 0 {
176
- return fmt .Errorf ("bad rldp packet total size" )
183
+ return fmt .Errorf ("bad rldp packet total size %d" , m .TotalSize )
184
+ }
185
+
186
+ // unexpected transfers limited to this size, for protection
187
+ var maxTransferSize = MaxUnexpectedTransferSize
188
+ if expected != nil {
189
+ maxTransferSize = expected .maxSize
190
+ }
191
+
192
+ if m .TotalSize > maxTransferSize {
193
+ return fmt .Errorf ("too big transfer size %d, max allowed %d" , m .TotalSize , maxTransferSize )
177
194
}
178
195
179
196
stream = & decoderStream {
@@ -189,6 +206,7 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error {
189
206
} else {
190
207
r .recvStreams [id ] = stream
191
208
}
209
+ delete (r .expectedTransfers , id )
192
210
r .mx .Unlock ()
193
211
}
194
212
@@ -405,7 +423,8 @@ func (r *RLDP) recoverySender() {
405
423
packets := make ([]tl.Serializable , 0 , 1024 )
406
424
transfersToProcess := make ([]* activeTransfer , 0 , 128 )
407
425
timedOut := make ([]* activeTransfer , 0 , 32 )
408
- timedOutReq := make ([]* activeRequest , 0 , 32 )
426
+ timedOutReq := make ([]string , 0 , 32 )
427
+ timedOutExp := make ([]string , 0 , 32 )
409
428
closerCtx := r .adnl .GetCloserCtx ()
410
429
ticker := time .NewTicker (1 * time .Millisecond )
411
430
defer ticker .Stop ()
@@ -419,6 +438,7 @@ func (r *RLDP) recoverySender() {
419
438
transfersToProcess = transfersToProcess [:0 ]
420
439
timedOut = timedOut [:0 ]
421
440
timedOutReq = timedOutReq [:0 ]
441
+ timedOutExp = timedOutExp [:0 ]
422
442
423
443
ms := time .Now ().UnixNano () / int64 (time .Millisecond )
424
444
@@ -436,13 +456,19 @@ func (r *RLDP) recoverySender() {
436
456
}
437
457
}
438
458
439
- for _ , req := range r .activeRequests {
459
+ for id , req := range r .activeRequests {
440
460
if req .deadline < ms {
441
- timedOutReq = append (timedOutReq , req )
461
+ timedOutReq = append (timedOutReq , id )
442
462
}
443
463
}
444
464
445
- if len (r .activeRequests )+ len (r .activeTransfers ) == 0 {
465
+ for id , req := range r .expectedTransfers {
466
+ if req .deadline < ms {
467
+ timedOutExp = append (timedOutExp , id )
468
+ }
469
+ }
470
+
471
+ if len (r .activeRequests )+ len (r .activeTransfers )+ len (r .expectedTransfers ) == 0 {
446
472
// stop ticks to not consume resources
447
473
ticker .Stop ()
448
474
}
@@ -481,13 +507,16 @@ func (r *RLDP) recoverySender() {
481
507
}
482
508
}
483
509
484
- if len (timedOut ) > 0 || len (timedOutReq ) > 0 {
510
+ if len (timedOut ) > 0 || len (timedOutReq ) > 0 || len ( timedOutExp ) > 0 {
485
511
r .mx .Lock ()
486
512
for _ , transfer := range timedOut {
487
513
delete (r .activeTransfers , string (transfer .id ))
488
514
}
489
515
for _ , req := range timedOutReq {
490
- delete (r .activeRequests , req .id )
516
+ delete (r .activeRequests , req )
517
+ }
518
+ for _ , req := range timedOutExp {
519
+ delete (r .expectedTransfers , req )
491
520
}
492
521
r .mx .Unlock ()
493
522
}
@@ -580,17 +609,9 @@ func (r *RLDP) DoQuery(ctx context.Context, maxAnswerSize int64, query, result t
580
609
581
610
select {
582
611
case resp := <- res :
583
- r .mx .Lock ()
584
- delete (r .activeRequests , string (qid ))
585
- r .mx .Unlock ()
586
-
587
612
reflect .ValueOf (result ).Elem ().Set (reflect .ValueOf (resp .Result ))
588
613
return nil
589
614
case <- ctx .Done ():
590
- r .mx .Lock ()
591
- delete (r .activeRequests , string (qid ))
592
- r .mx .Unlock ()
593
-
594
615
return fmt .Errorf ("response deadline exceeded, err: %w" , ctx .Err ())
595
616
}
596
617
}
@@ -627,6 +648,7 @@ func (r *RLDP) DoQueryAsync(ctx context.Context, maxAnswerSize int64, id []byte,
627
648
if err != nil {
628
649
return err
629
650
}
651
+ reverseId := reverseTransferId (transferId )
630
652
631
653
out := timeout .UnixNano () / int64 (time .Millisecond )
632
654
@@ -635,6 +657,10 @@ func (r *RLDP) DoQueryAsync(ctx context.Context, maxAnswerSize int64, id []byte,
635
657
deadline : out ,
636
658
result : result ,
637
659
}
660
+ r .expectedTransfers [string (reverseId )] = & expectedTransfer {
661
+ deadline : out ,
662
+ maxSize : maxAnswerSize ,
663
+ }
638
664
r .mx .Unlock ()
639
665
640
666
if err = r .sendMessageParts (ctx , transferId , data , (time .Duration (q .Timeout )- time .Duration (time .Now ().Unix ()))* time .Second ); err != nil {
@@ -659,17 +685,14 @@ func (r *RLDP) SendAnswer(ctx context.Context, maxAnswerSize int64, queryId, toT
659
685
return fmt .Errorf ("too big answer for that client, client wants no more than %d bytes" , maxAnswerSize )
660
686
}
661
687
662
- transferId := make ( []byte , 32 )
688
+ var transferId []byte
663
689
664
690
if toTransferId != nil {
665
691
// if we have transfer to respond, invert it and use id
666
- copy (transferId , toTransferId )
667
- for i := range transferId {
668
- transferId [i ] ^= 0xFF
669
- }
692
+ transferId = reverseTransferId (toTransferId )
670
693
} else {
671
- _ , err = rand . Read ( transferId )
672
- if err != nil {
694
+ transferId = make ([] byte , 32 )
695
+ if _ , err = rand . Read ( transferId ); err != nil {
673
696
return err
674
697
}
675
698
}
@@ -680,3 +703,12 @@ func (r *RLDP) SendAnswer(ctx context.Context, maxAnswerSize int64, queryId, toT
680
703
681
704
return nil
682
705
}
706
+
707
+ func reverseTransferId (id []byte ) []byte {
708
+ rev := make ([]byte , 32 )
709
+ copy (rev , id )
710
+ for i := range rev {
711
+ rev [i ] ^= 0xFF
712
+ }
713
+ return rev
714
+ }
0 commit comments