Skip to content

Commit 9d9c973

Browse files
authored
Proper support of the trailing headers; redact sensitive headers in the log output (#20)
1 parent a656c0c commit 9d9c973

File tree

7 files changed

+80
-55
lines changed

7 files changed

+80
-55
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ Run the following command to run Connect-RPC conformance tests:
8282
docker build . --output "out" --progress=plain
8383
```
8484

85-
Execution results are output in STDOUT.
86-
Diagnostic data from the server itself is output in the `out/out.log` file.
85+
Execution results are output to STDOUT.
86+
Diagnostic data from the server itself is written to the log file `out/out.log`.
8787

8888
### Conformance tests status
8989

@@ -92,7 +92,6 @@ Current status: 6/79 tests pass
9292
Known issues:
9393

9494
* fs2-grpc server implementation doesn't support setting response headers
95-
* Trailers that are set in http4s aren’t being sent to the client
9695
* `google.protobuf.Any` serialization doesn't follow Connect-RPC spec
9796

9897
## Future improvements

conformance/src/main/scala/org/ivovk/connect_rpc_scala/conformance/ConformanceServiceImpl.scala

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import java.util.concurrent.TimeUnit
1212
import scala.concurrent.duration.Duration
1313
import scala.jdk.CollectionConverters.*
1414

15+
case class UnaryHandlerResponse(payload: ConformancePayload, trailers: Metadata)
1516

1617
class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrailers[F, Metadata] {
1718

@@ -24,38 +25,38 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai
2425
ctx: Metadata
2526
): F[(UnaryResponse, Metadata)] = {
2627
for
27-
payload <- handleUnaryRequest(
28+
res <- handleUnaryRequest(
2829
request.getResponseDefinition,
2930
Seq(request.toProtoAny),
3031
ctx
3132
)
32-
yield (UnaryResponse(payload.some), new Metadata())
33+
yield (
34+
UnaryResponse(res.payload.some),
35+
res.trailers
36+
)
3337
}
3438

3539
override def idempotentUnary(
3640
request: IdempotentUnaryRequest,
3741
ctx: Metadata,
3842
): F[(IdempotentUnaryResponse, Metadata)] = {
3943
for
40-
payload <- handleUnaryRequest(
44+
res <- handleUnaryRequest(
4145
request.getResponseDefinition,
4246
Seq(request.toProtoAny),
4347
ctx
4448
)
45-
yield (IdempotentUnaryResponse(payload.some), new Metadata())
49+
yield (
50+
IdempotentUnaryResponse(res.payload.some),
51+
res.trailers
52+
)
4653
}
4754

4855
private def handleUnaryRequest(
4956
responseDefinition: UnaryResponseDefinition,
5057
requests: Seq[com.google.protobuf.any.Any],
5158
ctx: Metadata,
52-
): F[ConformancePayload] = {
53-
//val trailers = new Metadata()
54-
//responseDefinition.responseTrailers.foreach { h =>
55-
// val key = Metadata.Key.of(h.name, Metadata.ASCII_STRING_MARSHALLER)
56-
// h.value.foreach(v => trailers.put(key, v))
57-
//}
58-
59+
): F[UnaryHandlerResponse] = {
5960
val requestInfo = ConformancePayload.RequestInfo(
6061
requestHeaders = mkConformanceHeaders(ctx),
6162
timeoutMs = extractTimeout(ctx),
@@ -77,10 +78,15 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai
7778
throw new StatusRuntimeException(status)
7879
}
7980

81+
val trailers = mkMetadata(responseDefinition.responseTrailers)
82+
8083
Async[F].sleep(Duration(responseDefinition.responseDelayMs, TimeUnit.MILLISECONDS)) *>
81-
Async[F].pure(ConformancePayload(
82-
responseData.getOrElse(ByteString.EMPTY),
83-
requestInfo.some
84+
Async[F].pure(UnaryHandlerResponse(
85+
payload = ConformancePayload(
86+
responseData.getOrElse(ByteString.EMPTY),
87+
requestInfo.some
88+
),
89+
trailers = trailers
8490
))
8591
}
8692

@@ -93,6 +99,16 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai
9399
}.toSeq
94100
}
95101

102+
private def mkMetadata(headers: Seq[Header]): Metadata = {
103+
val metadata = new Metadata()
104+
headers.foreach { h =>
105+
h.value.foreach { v =>
106+
metadata.put(keyof(h.name), v)
107+
}
108+
}
109+
metadata
110+
}
111+
96112
private def extractTimeout(metadata: Metadata): Option[Long] = {
97113
Option(metadata.get(keyof("grpc-timeout")))
98114
.map(v => v.substring(0, v.length - 1).toLong / 1000)

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRpcHttpRoutes.scala

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import io.grpc.stub.MetadataUtils
1212
import org.http4s.*
1313
import org.http4s.dsl.Http4sDsl
1414
import org.ivovk.connect_rpc_scala.http.*
15-
import org.ivovk.connect_rpc_scala.http.Headers.*
15+
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
1616
import org.ivovk.connect_rpc_scala.http.MessageCodec.given
1717
import org.ivovk.connect_rpc_scala.http.QueryParams.*
1818
import org.slf4j.{Logger, LoggerFactory}
@@ -77,15 +77,15 @@ object ConnectRpcHttpRoutes {
7777

7878
_ <- EitherT.cond[F](
7979
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
80-
httpMethod == Method.POST || (httpMethod == Method.GET && method.methodDescriptor.isSafe) || true,
80+
httpMethod == Method.POST || (httpMethod == Method.GET && method.descriptor.isSafe) || true,
8181
(),
8282
Forbidden(connectrpc.Error(
8383
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
8484
message = s"Only POST-requests are allowed for method: $grpcMethod".some
8585
))
8686
).leftSemiflatMap(identity)
8787

88-
response <- method.methodDescriptor.getType match
88+
response <- method.descriptor.getType match
8989
case MethodType.UNARY =>
9090
EitherT.right(handleUnary(dsl, method, entity, ipChannel))
9191
case unsupported =>
@@ -115,7 +115,7 @@ object ConnectRpcHttpRoutes {
115115

116116
private def handleUnary[F[_] : Async](
117117
dsl: Http4sDsl[F],
118-
entry: RegistryEntry,
118+
method: MethodRegistry.Entry,
119119
req: RequestEntity[F],
120120
channel: Channel
121121
)(using codec: MessageCodec[F]): F[Response[F]] = {
@@ -130,15 +130,15 @@ object ConnectRpcHttpRoutes {
130130
}
131131
}
132132

133-
given GeneratedMessageCompanion[GeneratedMessage] = entry.requestMessageCompanion
133+
given GeneratedMessageCompanion[GeneratedMessage] = method.requestMessageCompanion
134134

135135
req.as[GeneratedMessage]
136136
.flatMap { message =>
137137
val responseHeaderMetadata = new AtomicReference[Metadata]()
138138
val responseTrailerMetadata = new AtomicReference[Metadata]()
139139

140140
if (logger.isTraceEnabled) {
141-
logger.trace(s">>> Method: ${entry.methodDescriptor.getFullMethodName}, Entity: $message")
141+
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message")
142142
}
143143

144144
Async[F].fromFuture(Async[F].delay {
@@ -148,26 +148,25 @@ object ConnectRpcHttpRoutes {
148148
MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata),
149149
MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata),
150150
),
151-
entry.methodDescriptor,
151+
method.descriptor,
152152
CallOptions.DEFAULT
153153
.pipe(
154-
req.headers.get[`Connect-Timeout-Ms`].fold[Endo[CallOptions]](identity) { header =>
155-
_.withDeadlineAfter(header.value, MILLISECONDS)
154+
req.timeout.fold[Endo[CallOptions]](identity) { timeout =>
155+
_.withDeadlineAfter(timeout, MILLISECONDS)
156156
}
157157
),
158158
message
159159
)
160160
}).map { response =>
161-
val headers = responseHeaderMetadata.get().toHeaders
162-
val trailers = responseTrailerMetadata.get().toHeaders
161+
val headers = org.http4s.Headers.empty ++
162+
responseHeaderMetadata.get().toHeaders ++
163+
responseTrailerMetadata.get().toTrailingHeaders
163164

164165
if (logger.isTraceEnabled) {
165-
logger.trace(s"<<< Headers: $headers, Trailers: $trailers")
166+
logger.trace(s"<<< Headers: ${headers.redactSensitive}")
166167
}
167168

168-
Response(Ok, headers = headers)
169-
.withEntity(response)
170-
.withTrailerHeaders(Async[F].pure(trailers))
169+
Response(Ok, headers = headers).withEntity(response)
171170
}
172171
}
173172
.recover { case e =>

core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import org.http4s.{Header, Headers}
55
import org.typelevel.ci.CIString
66
import scalapb.GeneratedMessage
77

8-
import scala.jdk.CollectionConverters.*
9-
108
object Mappings extends HeaderMappings, StatusCodeMappings, AnyMappings
119

1210
trait HeaderMappings {
@@ -25,17 +23,26 @@ trait HeaderMappings {
2523
}
2624

2725
extension (metadata: Metadata) {
28-
def toHeaders: Headers = {
29-
val headers = metadata.keys()
30-
.asScala.toList
31-
.flatMap { key =>
32-
metadata.getAll(asciiKey(key)).asScala.map { value =>
33-
Header.Raw(CIString(key), value)
34-
}
26+
private def headers(prefix: String = ""): Headers = {
27+
val keys = metadata.keys()
28+
if (keys.isEmpty) return Headers.empty
29+
30+
val b = List.newBuilder[Header.Raw]
31+
32+
keys.forEach { key =>
33+
val name = CIString(prefix + key)
34+
35+
metadata.getAll(asciiKey(key)).forEach { value =>
36+
b += Header.Raw(name, value)
3537
}
38+
}
3639

37-
Headers(headers)
40+
new Headers(b.result())
3841
}
42+
43+
def toHeaders: Headers = headers()
44+
45+
def toTrailingHeaders: Headers = headers("trailer-")
3946
}
4047

4148
}

core/src/main/scala/org/ivovk/connect_rpc_scala/MethodRegistry.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
55

66
import scala.jdk.CollectionConverters.*
77

8-
case class RegistryEntry(
9-
requestMessageCompanion: GeneratedMessageCompanion[GeneratedMessage],
10-
methodDescriptor: MethodDescriptor[GeneratedMessage, GeneratedMessage],
11-
)
12-
138
object MethodRegistry {
149

10+
case class Entry(
11+
requestMessageCompanion: GeneratedMessageCompanion[GeneratedMessage],
12+
descriptor: MethodDescriptor[GeneratedMessage, GeneratedMessage],
13+
)
14+
1515
def apply(services: Seq[ServerServiceDefinition]): MethodRegistry = {
1616
val entries = services
1717
.flatMap(_.getMethods.asScala)
@@ -30,12 +30,12 @@ object MethodRegistry {
3030
val requestCompanion = companionField.get(requestMarshaller)
3131
.asInstanceOf[GeneratedMessageCompanion[GeneratedMessage]]
3232

33-
val entry = RegistryEntry(
33+
val methodEntry = Entry(
3434
requestMessageCompanion = requestCompanion,
35-
methodDescriptor = methodDescriptor,
35+
descriptor = methodDescriptor,
3636
)
3737

38-
methodDescriptor.getFullMethodName -> entry
38+
methodDescriptor.getFullMethodName -> methodEntry
3939
}
4040
.toMap
4141

@@ -44,9 +44,9 @@ object MethodRegistry {
4444

4545
}
4646

47-
class MethodRegistry private(entries: Map[String, RegistryEntry]) {
47+
class MethodRegistry private(entries: Map[String, MethodRegistry.Entry]) {
4848

49-
def get(fullMethodName: String): Option[RegistryEntry] =
49+
def get(fullMethodName: String): Option[MethodRegistry.Entry] =
5050
entries.get(fullMethodName)
5151

5252
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/MessageCodec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class JsonMessageCodec[F[_] : Sync : Compression](printer: Printer) extends Mess
5656
val f = string
5757
.flatMap { str =>
5858
if (logger.isTraceEnabled) {
59-
logger.trace(s">>> Headers: ${entity.headers}")
59+
logger.trace(s">>> Headers: ${entity.headers.redactSensitive}")
6060
logger.trace(s">>> JSON: $str")
6161
}
6262

@@ -98,7 +98,7 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {
9898

9999
EitherT.right(f.map { message =>
100100
if (logger.isTraceEnabled) {
101-
logger.trace(s">>> Headers: ${entity.headers}")
101+
logger.trace(s">>> Headers: ${entity.headers.redactSensitive}")
102102
logger.trace(s">>> Proto: ${message.toProtoString}")
103103
}
104104

core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package org.ivovk.connect_rpc_scala.http
22

33
import cats.MonadThrow
44
import fs2.Stream
5-
import org.http4s.{Charset, Headers, Media}
65
import org.http4s.headers.`Content-Type`
6+
import org.http4s.{Charset, Headers, Media}
7+
import org.ivovk.connect_rpc_scala.http.Headers.`Connect-Timeout-Ms`
78
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
89

910
object RequestEntity {
@@ -26,6 +27,9 @@ case class RequestEntity[F[_]](
2627
def charset: Charset =
2728
contentType.flatMap(_.charset).getOrElse(Charset.`UTF-8`)
2829

30+
def timeout: Option[Long] =
31+
headers.get[`Connect-Timeout-Ms`].map(_.value)
32+
2933
def as[A <: Message](using M: MonadThrow[F], codec: MessageCodec[F], cmp: Companion[A]): F[A] =
3034
M.rethrow(codec.decode(this).value)
3135

0 commit comments

Comments
 (0)