Skip to content

Commit 9f1a31c

Browse files
committed
fix: correctly return number of bytes read from chunked streams
1 parent 2542cd9 commit 9f1a31c

File tree

9 files changed

+65
-29
lines changed

9 files changed

+65
-29
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"id": "7251f5e7-0e9a-4ea4-b4c7-30dad31f4622",
3+
"type": "bugfix",
4+
"description": "⚠️ **IMPORTANT**: Correctly return number of bytes read from chunked streams",
5+
"issues": [
6+
"awslabs/smithy-kotlin#1285"
7+
],
8+
"requiresMinorVersionBump": true
9+
}

runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsChunkedByteReadChannel.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ public class AwsChunkedByteReadChannel(
4848
override suspend fun read(sink: SdkBuffer, limit: Long): Long {
4949
require(limit >= 0L) { "Invalid limit ($limit) must be >= 0L" }
5050
if (!chunkReader.ensureValidChunk()) return -1L
51-
return chunkReader.chunk.read(sink, limit)
51+
chunkReader.chunk.read(sink, limit)
52+
return chunkReader.readCountAndReset()
5253
}
5354
}
5455

runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsChunkedSource.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public class AwsChunkedSource(
4747
chunkReader.ensureValidChunk()
4848
}
4949
if (!isChunkValid) return -1L
50-
return chunkReader.chunk.read(sink, limit)
50+
chunkReader.chunk.read(sink, limit)
51+
return chunkReader.readCountAndReset()
5152
}
5253

5354
override fun close() {

runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/internal/AwsChunkedReader.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ internal class AwsChunkedReader(
5757
*/
5858
internal var hasLastChunkBeenSent: Boolean = false
5959

60+
private var read = 0L
61+
62+
/**
63+
* Gets the most recent read count and then resets the counter. This is meant to be checked after every caller
64+
* invocation of AwsChunkedReader.chunk.read(...). Note that because of buffering, this value may return _more_ than
65+
* the total number of bytes written to the sink.
66+
*/
67+
fun readCountAndReset(): Long = read.also { read = 0L }
68+
6069
/**
6170
* Ensures that the internal [chunk] is valid for reading. If it's not valid, try to load the next chunk. Note that
6271
* this function will suspend until the whole chunk has been loaded.
@@ -117,6 +126,7 @@ internal class AwsChunkedReader(
117126
while (remaining > 0L) {
118127
val rc = read(sink, remaining)
119128
if (rc == -1L) break
129+
read += rc
120130
remaining -= rc
121131
}
122132

runtime/auth/aws-signing-tests/common/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/AwsChunkedByteReadChannelTestBase.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ abstract class AwsChunkedByteReadChannelTestBase : AwsChunkedTestBase(AwsChunked
4747
val sink = SdkBuffer()
4848

4949
val bytesRead = awsChunked.readAll(sink)
50-
writeJob.join()
50+
// writeJob.join()
51+
assertEquals(dataLengthBytes.toLong(), bytesRead)
52+
assertEquals(totalBytesExpected.toLong(), sink.size)
5153

5254
val bytesAsString = sink.readUtf8()
5355

54-
assertEquals(totalBytesExpected.toLong(), bytesRead)
5556
assertTrue(awsChunked.isClosedForRead)
5657

5758
val chunkSignatures = getChunkSignatures(bytesAsString)

runtime/auth/aws-signing-tests/common/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/AwsChunkedTestBase.kt

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ interface AwsChunkedTestReader {
2222
// This may modify the chunked reader state and cause loss of data!
2323
fun isClosedForRead(): Boolean
2424
suspend fun read(sink: SdkBuffer, limit: Long): Long
25+
suspend fun readAll(sink: SdkBuffer): Long
2526
}
2627

2728
fun interface AwsChunkedReaderFactory {
@@ -32,6 +33,7 @@ fun interface AwsChunkedReaderFactory {
3233
object : AwsChunkedTestReader {
3334
override fun isClosedForRead(): Boolean = chunked.isClosedForRead
3435
override suspend fun read(sink: SdkBuffer, limit: Long): Long = chunked.read(sink, limit)
36+
override suspend fun readAll(sink: SdkBuffer): Long = chunked.readAll(sink)
3537
}
3638
}
3739
}
@@ -198,6 +200,7 @@ abstract class AwsChunkedTestBase(
198200
// need to make 2 successive calls because there are two chunks -- read will only fetch the first one due to limit
199201
var bytesRead = awsChunked.read(sink, readLimit.toLong())
200202
bytesRead += awsChunked.read(sink, readLimit - bytesRead)
203+
assertEquals(readLimit.toLong(), sink.size)
201204

202205
val bytesAsString = sink.readUtf8()
203206

@@ -213,7 +216,7 @@ abstract class AwsChunkedTestBase(
213216
assertEquals(CHUNK_SIZE_BYTES, chunkSizes[0])
214217
assertEquals(0, chunkSizes[1])
215218

216-
assertEquals(readLimit, bytesRead.toInt())
219+
assertEquals(dataLengthBytes, bytesRead.toInt())
217220
assertTrue(awsChunked.isClosedForRead())
218221
}
219222

@@ -257,7 +260,8 @@ abstract class AwsChunkedTestBase(
257260

258261
val sink = SdkBuffer()
259262
val bytesRead = awsChunked.read(sink, readLimit.toLong())
260-
assertEquals(readLimit.toLong(), bytesRead)
263+
assertEquals(CHUNK_SIZE_BYTES.toLong(), bytesRead)
264+
assertEquals(readLimit.toLong(), sink.size)
261265

262266
val bytesAsString = sink.readUtf8()
263267
val chunkSignatures = getChunkSignatures(bytesAsString)
@@ -289,7 +293,8 @@ abstract class AwsChunkedTestBase(
289293
bytesRead += awsChunked.read(sink, readLimit.toLong())
290294
}
291295
bytesRead += awsChunked.read(sink, readLimit.toLong())
292-
assertEquals(totalBytesExpected.toLong(), bytesRead)
296+
assertEquals(dataLengthBytes.toLong(), bytesRead)
297+
assertEquals(totalBytesExpected.toLong(), sink.size)
293298

294299
val bytesAsString = sink.readUtf8()
295300

@@ -345,9 +350,11 @@ abstract class AwsChunkedTestBase(
345350
}
346351
}
347352

353+
assertEquals(dataLengthBytes.toLong(), bytesRead)
354+
assertEquals(totalBytesExpected.toLong(), sink.size)
355+
348356
val bytesAsString = sink.readUtf8()
349357

350-
assertEquals(totalBytesExpected.toLong(), bytesRead)
351358
assertTrue(awsChunked.isClosedForRead())
352359

353360
val chunkSignatures = getChunkSignatures(bytesAsString)
@@ -402,13 +409,9 @@ abstract class AwsChunkedTestBase(
402409
val totalBytesExpected = encodedChunkLength(CHUNK_SIZE_BYTES) + encodedChunkLength(0) + trailingHeadersLength + "\r\n".length
403410
val sink = SdkBuffer()
404411

405-
var bytesRead = 0L
406-
407-
while (bytesRead < totalBytesExpected.toLong()) {
408-
bytesRead += awsChunked.read(sink, Long.MAX_VALUE)
409-
}
410-
411-
assertEquals(totalBytesExpected.toLong(), bytesRead)
412+
val bytesRead = awsChunked.readAll(sink)
413+
assertEquals(dataLengthBytes.toLong(), bytesRead)
414+
assertEquals(totalBytesExpected.toLong(), sink.size)
412415
assertTrue(awsChunked.isClosedForRead())
413416

414417
val bytesAsString = sink.readUtf8()
@@ -445,13 +448,9 @@ abstract class AwsChunkedTestBase(
445448
val totalBytesExpected = encodedUnsignedChunkLength(CHUNK_SIZE_BYTES) + encodedUnsignedChunkLength(0) + "\r\n".length
446449
val sink = SdkBuffer()
447450

448-
var bytesRead = 0L
449-
450-
while (bytesRead < totalBytesExpected.toLong()) {
451-
bytesRead += awsChunked.read(sink, Long.MAX_VALUE)
452-
}
453-
454-
assertEquals(totalBytesExpected.toLong(), bytesRead)
451+
val bytesRead = awsChunked.readAll(sink)
452+
assertEquals(dataLengthBytes.toLong(), bytesRead)
453+
assertEquals(totalBytesExpected.toLong(), sink.size)
455454
assertTrue(awsChunked.isClosedForRead())
456455

457456
val bytesAsString = sink.readUtf8()
@@ -482,13 +481,9 @@ abstract class AwsChunkedTestBase(
482481
val totalBytesExpected = encodedUnsignedChunkLength(CHUNK_SIZE_BYTES) + encodedUnsignedChunkLength(0) + trailingHeadersLength + "\r\n".length
483482
val sink = SdkBuffer()
484483

485-
var bytesRead = 0L
486-
487-
while (bytesRead < totalBytesExpected.toLong()) {
488-
bytesRead += awsChunked.read(sink, Long.MAX_VALUE)
489-
}
490-
491-
assertEquals(totalBytesExpected.toLong(), bytesRead)
484+
val bytesRead = awsChunked.readAll(sink)
485+
assertEquals(dataLengthBytes.toLong(), bytesRead)
486+
assertEquals(totalBytesExpected.toLong(), sink.size)
492487
assertTrue(awsChunked.isClosedForRead())
493488

494489
val bytesAsString = sink.readUtf8()

runtime/auth/aws-signing-tests/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/AwsChunkedSourceTestBase.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package aws.smithy.kotlin.runtime.auth.awssigning.tests
77
import aws.smithy.kotlin.runtime.auth.awssigning.AwsChunkedSource
88
import aws.smithy.kotlin.runtime.io.SdkBuffer
9+
import aws.smithy.kotlin.runtime.io.readAll
910
import aws.smithy.kotlin.runtime.io.source
1011

1112
val AwsChunkedReaderFactory.Companion.Source: AwsChunkedReaderFactory
@@ -19,6 +20,7 @@ val AwsChunkedReaderFactory.Companion.Source: AwsChunkedReaderFactory
1920
return rc == -1L
2021
}
2122
override suspend fun read(sink: SdkBuffer, limit: Long): Long = chunked.read(sink, limit)
23+
override suspend fun readAll(sink: SdkBuffer): Long = chunked.readAll(sink)
2224
}
2325
}
2426

runtime/runtime-core/api/runtime-core.api

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,7 @@ public final class aws/smithy/kotlin/runtime/io/SdkSourceJVMKt {
10751075
}
10761076

10771077
public final class aws/smithy/kotlin/runtime/io/SdkSourceKt {
1078+
public static final fun readAll (Laws/smithy/kotlin/runtime/io/SdkSource;Laws/smithy/kotlin/runtime/io/SdkSink;)J
10781079
public static final fun readFully (Laws/smithy/kotlin/runtime/io/SdkSource;Laws/smithy/kotlin/runtime/io/SdkBuffer;J)V
10791080
}
10801081

runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/io/SdkSource.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,19 @@ public fun SdkSource.readFully(sink: SdkBuffer, byteCount: Long) {
7676
totalBytesRead += rc
7777
}
7878
}
79+
80+
/**
81+
* Read all bytes from this source into [sink]. Returns the total number of bytes written.
82+
*/
83+
public fun SdkSource.readAll(sink: SdkSink): Long {
84+
val bufferedSink = sink.buffer()
85+
var totalWritten = 0L
86+
while (true) {
87+
val rc = read(bufferedSink.buffer, DEFAULT_BYTE_CHANNEL_MAX_BUFFER_SIZE.toLong())
88+
if (rc == -1L) break
89+
totalWritten += rc
90+
bufferedSink.emit()
91+
}
92+
bufferedSink.emit()
93+
return totalWritten
94+
}

0 commit comments

Comments
 (0)