Skip to content

Commit bd90f26

Browse files
authored
feat: traced() (#96)
* feat: traced() fixes: #38 * chore: Improve tests --------- Co-authored-by: danthorpe <danthorpe@users.noreply.github.com>
1 parent b5dd469 commit bd90f26

File tree

6 files changed

+308
-0
lines changed

6 files changed

+308
-0
lines changed

Sources/Helpers/Data+Crypto.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import CryptoKit
2+
import Foundation
3+
4+
extension Data {
5+
6+
static func secureRandomData(length: UInt) -> Data? {
7+
let count = Int(length)
8+
var bytes = [Int8](repeating: 0, count: count)
9+
guard errSecSuccess == SecRandomCopyBytes(kSecRandomDefault, count, &bytes) else {
10+
return nil
11+
}
12+
return Data(bytes: bytes, count: count)
13+
}
14+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import ConcurrencyExtras
2+
import Foundation
3+
4+
package enum UniqueIdentifier: Hashable {
5+
package enum Format: Hashable {
6+
case base64, hex
7+
}
8+
case secureBytes(length: UInt = 10, format: Format)
9+
}
10+
11+
extension UniqueIdentifier {
12+
13+
func generate() -> String {
14+
switch self {
15+
case let .secureBytes(length, _):
16+
var data = Data()
17+
repeat {
18+
data = .secureRandomData(length: length) ?? Data()
19+
} while data.isEmpty
20+
return format(data: data)
21+
}
22+
}
23+
24+
func format(data: Data) -> String {
25+
switch self {
26+
case .secureBytes(_, .base64):
27+
return data.base64EncodedString(options: [])
28+
case .secureBytes(_, .hex):
29+
return data.map { String(format: "%02hhx", $0) }.joined()
30+
}
31+
}
32+
}
33+
34+
// MARK: - Generator
35+
36+
extension UniqueIdentifier {
37+
package struct Generator: Sendable {
38+
private let generate: @Sendable () -> String
39+
40+
package init(_ id: UniqueIdentifier) {
41+
self.init { id.generate() }
42+
}
43+
44+
package init(generate: @escaping @Sendable () -> String) {
45+
self.generate = generate
46+
}
47+
48+
@discardableResult
49+
package func callAsFunction() -> String {
50+
generate()
51+
}
52+
}
53+
}
54+
55+
extension UniqueIdentifier.Generator {
56+
package static func constant(_ id: UniqueIdentifier) -> Self {
57+
let generation = id.generate()
58+
return Self { generation }
59+
}
60+
61+
package static func incrementing(_ id: UniqueIdentifier) -> Self {
62+
let sequence = LockIsolated<Int>(0)
63+
return Self {
64+
let number = sequence.withValue {
65+
$0 += 1
66+
return $0
67+
}
68+
let data = withUnsafeBytes(of: number.bigEndian) { Data($0) }
69+
return id.format(data: data)
70+
}
71+
}
72+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import Dependencies
2+
import DependenciesMacros
3+
import Foundation
4+
import HTTPTypes
5+
import Helpers
6+
7+
extension NetworkingComponent {
8+
9+
/// Generates a HTTP Trace Parent header for each request.
10+
///
11+
/// - See-Also: [Trace-Context](https://www.w3.org/TR/trace-context/)
12+
public func traced() -> some NetworkingComponent {
13+
modified(Traced())
14+
}
15+
}
16+
17+
private struct Traced: NetworkingModifier {
18+
@Dependency(\.traceParentGenerator) var generate
19+
func resolve(upstream: some NetworkingComponent, request: HTTPRequestData) -> HTTPRequestData {
20+
guard nil == request.traceParent else {
21+
return request
22+
}
23+
var copy = request
24+
copy.traceParent = generate()
25+
return copy
26+
}
27+
}
28+
29+
extension HTTPField.Name {
30+
public static let traceparent = HTTPField.Name("traceparent")!
31+
}
32+
33+
extension HTTPRequestData {
34+
package fileprivate(set) var traceParent: TraceParent? {
35+
get { self[option: TraceParent.self] }
36+
set {
37+
self[option: TraceParent.self] = newValue
38+
self.headerFields[.traceparent] = newValue?.description
39+
}
40+
}
41+
42+
public var traceId: String? {
43+
traceParent?.traceId
44+
}
45+
46+
public var parentId: String? {
47+
traceParent?.parentId
48+
}
49+
}
50+
51+
public struct TraceParent: Sendable, Hashable, HTTPRequestDataOption {
52+
public static var defaultOption: Self?
53+
54+
// Current version of the spec only supports 01 flag
55+
// Future versions of the spec will require support for bit-field mask
56+
public let traceId: String
57+
public let parentId: String
58+
59+
public var description: String {
60+
"00-\(traceId)-\(parentId)-01"
61+
}
62+
63+
public init(traceId: String, parentId: String) {
64+
self.traceId = traceId
65+
self.parentId = parentId
66+
}
67+
}
68+
69+
// MARK: - Generator
70+
71+
@DependencyClient
72+
public struct TraceParentGenerator: Sendable {
73+
public var generate: @Sendable () -> TraceParent = {
74+
TraceParent(traceId: "dummy-trace-id", parentId: "dummy-parent-id")
75+
}
76+
77+
package func callAsFunction() -> TraceParent {
78+
generate()
79+
}
80+
}
81+
82+
extension TraceParentGenerator: DependencyKey {
83+
public static let liveValue = {
84+
let traceId = UniqueIdentifier.Generator(.secureBytes(length: 16, format: .hex))
85+
let parentId = UniqueIdentifier.Generator(.secureBytes(length: 8, format: .hex))
86+
return TraceParentGenerator {
87+
TraceParent(
88+
traceId: traceId(),
89+
parentId: parentId()
90+
)
91+
}
92+
}()
93+
}
94+
95+
extension DependencyValues {
96+
public var traceParentGenerator: TraceParentGenerator {
97+
get { self[TraceParentGenerator.self] }
98+
set { self[TraceParentGenerator.self] = newValue }
99+
}
100+
}
101+
102+
extension TraceParentGenerator {
103+
public static var incrementing: TraceParentGenerator {
104+
let traceId = UniqueIdentifier.Generator.incrementing(.secureBytes(length: 16, format: .hex))
105+
let parentId = UniqueIdentifier.Generator.incrementing(.secureBytes(length: 8, format: .hex))
106+
return TraceParentGenerator {
107+
TraceParent(
108+
traceId: traceId(),
109+
parentId: parentId()
110+
)
111+
}
112+
}
113+
}

Sources/TestSupport/Mocked.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ import Networking
22

33
extension NetworkingComponent {
44

5+
/// Mock all requests with a stub
6+
public func mocked(
7+
all stub: StubbedResponseStream
8+
) -> some NetworkingComponent {
9+
mocked(stub) { _ in true }
10+
}
11+
512
/// Mock a given request with a stub
613
public func mocked(
714
_ request: HTTPRequestData,

Sources/TestSupport/NetworkingTestCase.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ open class NetworkingTestCase: XCTestCase {
2323
) {
2424
withDependencies {
2525
$0.shortID = shortIdGenerator ?? .incrementing
26+
$0.traceParentGenerator = .incrementing
2627
$0.continuousClock = continuousClock ?? TestClock()
2728
updateValuesForOperation(&$0)
2829
} operation: {
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import Foundation
2+
import TestSupport
3+
import XCTest
4+
5+
@testable import Networking
6+
7+
final class TracedTests: NetworkingTestCase {
8+
override func invokeTest() {
9+
withTestDependencies {
10+
super.invokeTest()
11+
}
12+
}
13+
14+
func test__request_includes_trace() async throws {
15+
let reporter = TestReporter()
16+
17+
let network = TerminalNetworkingComponent()
18+
.mocked(all: .ok())
19+
.reported(by: reporter)
20+
.traced()
21+
22+
try await withThrowingTaskGroup(of: HTTPResponseData.self) { group in
23+
for _ in 0 ..< 10 {
24+
group.addTask {
25+
try await network.data(HTTPRequestData())
26+
}
27+
}
28+
29+
var responses: [HTTPResponseData] = []
30+
for try await response in group {
31+
responses.append(response)
32+
}
33+
}
34+
35+
let sentRequests = await reporter.requests
36+
37+
XCTAssertEqual(
38+
sentRequests.map(\.headerFields[.traceparent]),
39+
[
40+
"00-0000000000000001-0000000000000001-01",
41+
"00-0000000000000002-0000000000000002-01",
42+
"00-0000000000000003-0000000000000003-01",
43+
"00-0000000000000004-0000000000000004-01",
44+
"00-0000000000000005-0000000000000005-01",
45+
"00-0000000000000006-0000000000000006-01",
46+
"00-0000000000000007-0000000000000007-01",
47+
"00-0000000000000008-0000000000000008-01",
48+
"00-0000000000000009-0000000000000009-01",
49+
"00-000000000000000a-000000000000000a-01",
50+
]
51+
)
52+
53+
XCTAssertEqual(sentRequests.last?.traceId, "000000000000000a")
54+
XCTAssertEqual(sentRequests.first?.parentId, "0000000000000001")
55+
}
56+
57+
func test__traced_requests_do_not_get_another_trace() async throws {
58+
let reporter = TestReporter()
59+
60+
let network = TerminalNetworkingComponent()
61+
.mocked(all: .ok())
62+
.reported(by: reporter)
63+
.traced()
64+
65+
// Make an initial request
66+
let response = try await network.data(HTTPRequestData())
67+
68+
// Resend the request
69+
try await network.data(response.request)
70+
71+
let sentRequests = await reporter.requests
72+
73+
XCTAssertEqual(
74+
sentRequests.map(\.headerFields[.traceparent]),
75+
[
76+
"00-0000000000000001-0000000000000001-01",
77+
"00-0000000000000001-0000000000000001-01",
78+
]
79+
)
80+
}
81+
82+
func test__live_trace_generator() async {
83+
let generate = TraceParentGenerator.liveValue
84+
85+
let traces = await withTaskGroup(of: TraceParent.self) { group in
86+
for _ in 0 ..< 10 {
87+
group.addTask {
88+
generate()
89+
}
90+
}
91+
92+
var traces: [TraceParent] = []
93+
for await trace in group {
94+
traces.append(trace)
95+
}
96+
return traces
97+
}
98+
99+
XCTAssertEqual(Set(traces).count, 10)
100+
}
101+
}

0 commit comments

Comments
 (0)