Skip to content

Commit 9462a3e

Browse files
authored
Better interface for setting headers during response encoding (#53)
1 parent de571fd commit 9462a3e

File tree

6 files changed

+106
-75
lines changed

6 files changed

+106
-75
lines changed

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import org.ivovk.connect_rpc_scala.Mappings.*
1010
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
1111
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
1212
import org.ivovk.connect_rpc_scala.http.RequestEntity
13-
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
1413
import org.ivovk.connect_rpc_scala.http.codec.{Compressor, EncodeOptions, MessageCodec}
1514
import org.slf4j.{Logger, LoggerFactory}
1615
import scalapb.GeneratedMessage
@@ -19,12 +18,22 @@ import scala.concurrent.duration.*
1918
import scala.jdk.CollectionConverters.*
2019
import scala.util.chaining.*
2120

21+
object ConnectHandler {
22+
23+
extension [F[_]](response: Response[F]) {
24+
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
25+
codec.encode(entity, options).applyTo(response)
26+
}
27+
28+
}
29+
2230
class ConnectHandler[F[_] : Async](
2331
channel: Channel,
2432
httpDsl: Http4sDsl[F],
2533
treatTrailersAsHeaders: Boolean,
2634
) {
2735

36+
import ConnectHandler.*
2837
import httpDsl.*
2938

3039
private val logger: Logger = LoggerFactory.getLogger(getClass)
@@ -37,14 +46,53 @@ class ConnectHandler[F[_] : Async](
3746
encoding = req.encoding.filter(Compressor.supportedEncodings.contains)
3847
)
3948

40-
method.descriptor.getType match
49+
val f = method.descriptor.getType match
4150
case MethodType.UNARY =>
4251
handleUnary(req, method)
4352
case unsupported =>
44-
NotImplemented(connectrpc.Error(
45-
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
46-
message = s"Unsupported method type: $unsupported".some
53+
Async[F].raiseError(new StatusException(
54+
io.grpc.Status.UNIMPLEMENTED.withDescription(s"Unsupported method type: $unsupported")
4755
))
56+
57+
f.handleError { e =>
58+
val grpcStatus = e match {
59+
case e: StatusException =>
60+
e.getStatus.getDescription match {
61+
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
62+
case _ => e.getStatus
63+
}
64+
case e: StatusRuntimeException => e.getStatus
65+
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
66+
case _ => io.grpc.Status.INTERNAL
67+
}
68+
69+
val (message, metadata) = e match {
70+
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
71+
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
72+
case e => (Option(e.getMessage), new Metadata())
73+
}
74+
75+
val httpStatus = grpcStatus.toHttpStatus
76+
val connectCode = grpcStatus.toConnectCode
77+
78+
// Should be called before converting metadata to headers
79+
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
80+
.fold(Seq.empty)(_.asScala.toSeq)
81+
82+
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
83+
84+
if (logger.isTraceEnabled) {
85+
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
86+
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
87+
logger.trace(s"<<< Error processing request", e)
88+
}
89+
90+
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
91+
code = connectCode,
92+
message = message,
93+
details = details
94+
))
95+
}
4896
}
4997

5098
private def handleUnary(
@@ -90,46 +138,7 @@ class ConnectHandler[F[_] : Async](
90138
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
91139
}
92140

93-
Response(Ok, headers = headers).withEntity(response.value)
94-
}
95-
.recover { case e =>
96-
val grpcStatus = e match {
97-
case e: StatusException =>
98-
e.getStatus.getDescription match {
99-
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
100-
case _ => e.getStatus
101-
}
102-
case e: StatusRuntimeException => e.getStatus
103-
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
104-
case _ => io.grpc.Status.INTERNAL
105-
}
106-
107-
val (message, metadata) = e match {
108-
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
109-
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
110-
case e => (Option(e.getMessage), new Metadata())
111-
}
112-
113-
val httpStatus = grpcStatus.toHttpStatus
114-
val connectCode = grpcStatus.toConnectCode
115-
116-
// Should be called before converting metadata to headers
117-
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
118-
.fold(Seq.empty)(_.asScala.toSeq)
119-
120-
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
121-
122-
if (logger.isTraceEnabled) {
123-
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
124-
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
125-
logger.trace(s"<<< Error processing request", e)
126-
}
127-
128-
Response[F](httpStatus, headers = headers).withEntity(connectrpc.Error(
129-
code = connectCode,
130-
message = message,
131-
details = details
132-
))
141+
Response(Ok, headers = headers).withMessage(response.value)
133142
}
134143
}
135144

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.ivovk.connect_rpc_scala.http
2+
3+
import org.http4s.headers.`Content-Length`
4+
import org.http4s.{EntityBody, Headers, Response}
5+
6+
import scala.util.chaining.*
7+
8+
case class ResponseEntity[F[_]](
9+
headers: Headers,
10+
body: EntityBody[F],
11+
length: Option[Long] = None
12+
) {
13+
14+
def applyTo(response: Response[F]): Response[F] = {
15+
val headers = (response.headers ++ this.headers)
16+
.pipe(
17+
length match
18+
case Some(length) => _.withContentLength(`Content-Length`(length))
19+
case None => identity
20+
)
21+
22+
response.copy(
23+
headers = headers,
24+
body = body,
25+
)
26+
}
27+
28+
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/Compressor.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package org.ivovk.connect_rpc_scala.http.codec
33
import cats.effect.Sync
44
import fs2.Stream
55
import fs2.compression.Compression
6-
import org.http4s.{ContentCoding, Entity}
6+
import io.grpc.{Status, StatusException}
7+
import org.http4s.ContentCoding
8+
import org.http4s.headers.`Content-Encoding`
9+
import org.ivovk.connect_rpc_scala.http.ResponseEntity
710

811
object Compressor {
912
val supportedEncodings: Set[ContentCoding] = Set(ContentCoding.gzip)
@@ -18,19 +21,20 @@ class Compressor[F[_] : Sync] {
1821
case Some(ContentCoding.gzip) =>
1922
Compression[F].gunzip().andThen(_.flatMap(_.content))
2023
case Some(other) =>
21-
throw new IllegalArgumentException(s"Unsupported encoding: $other")
24+
throw new StatusException(Status.INVALID_ARGUMENT.withDescription(s"Unsupported encoding: $other"))
2225
case None =>
2326
identity
2427
})
2528

26-
def compressed(encoding: Option[ContentCoding], entity: Entity[F]): Entity[F] =
29+
def compressed(encoding: Option[ContentCoding], entity: ResponseEntity[F]): ResponseEntity[F] =
2730
encoding match {
2831
case Some(ContentCoding.gzip) =>
29-
Entity(
32+
ResponseEntity(
33+
headers = entity.headers.put(`Content-Encoding`(ContentCoding.gzip)),
3034
body = entity.body.through(Compression[F].gzip()),
3135
)
3236
case Some(other) =>
33-
throw new IllegalArgumentException(s"Unsupported encoding: $other")
37+
throw new StatusException(Status.INVALID_ARGUMENT.withDescription(s"Unsupported encoding: $other"))
3438
case None =>
3539
entity
3640
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/JsonMessageCodec.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import cats.effect.Sync
44
import cats.implicits.*
55
import fs2.text.decodeWithCharset
66
import fs2.{Chunk, Stream}
7-
import org.http4s.{DecodeResult, Entity, InvalidMessageBodyFailure, MediaType}
8-
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
7+
import org.http4s.headers.`Content-Type`
8+
import org.http4s.{DecodeResult, Headers, InvalidMessageBodyFailure, MediaType}
9+
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity, ResponseEntity}
910
import org.slf4j.LoggerFactory
1011
import scalapb.json4s.{Parser, Printer}
1112
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
@@ -46,7 +47,7 @@ class JsonMessageCodec[F[_] : Sync](
4647
.leftMap(e => InvalidMessageBodyFailure(e.getMessage, e.some))
4748
}
4849

49-
override def encode[A <: Message](message: A, options: EncodeOptions): Entity[F] = {
50+
override def encode[A <: Message](message: A, options: EncodeOptions): ResponseEntity[F] = {
5051
val string = printer.print(message)
5152

5253
if (logger.isTraceEnabled) {
@@ -55,7 +56,8 @@ class JsonMessageCodec[F[_] : Sync](
5556

5657
val bytes = string.getBytes()
5758

58-
val entity = Entity(
59+
val entity = ResponseEntity[F](
60+
headers = Headers(`Content-Type`(mediaType)),
5961
body = Stream.chunk(Chunk.array(bytes)),
6062
length = Some(bytes.length.toLong),
6163
)
Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package org.ivovk.connect_rpc_scala.http.codec
22

3-
import org.http4s.headers.{`Content-Encoding`, `Content-Type`}
4-
import org.http4s.{ContentCoding, DecodeResult, Entity, EntityEncoder, Header, Headers, MediaType}
5-
import org.ivovk.connect_rpc_scala.http.RequestEntity
3+
import org.http4s.{ContentCoding, DecodeResult, MediaType}
4+
import org.ivovk.connect_rpc_scala.http.{RequestEntity, ResponseEntity}
65
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
76

87
import scala.util.chaining.*
@@ -11,25 +10,12 @@ case class EncodeOptions(
1110
encoding: Option[ContentCoding]
1211
)
1312

14-
object MessageCodec {
15-
given [F[_], A <: Message](using codec: MessageCodec[F], options: EncodeOptions): EntityEncoder[F, A] = {
16-
val headers = Headers(`Content-Type`(codec.mediaType))
17-
.pipe(
18-
options.encoding match
19-
case Some(encoding) => _.put(`Content-Encoding`(encoding))
20-
case None => identity
21-
)
22-
23-
EntityEncoder.encodeBy(headers)(codec.encode(_, options))
24-
}
25-
}
26-
2713
trait MessageCodec[F[_]] {
2814

2915
val mediaType: MediaType
3016

3117
def decode[A <: Message](m: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A]
3218

33-
def encode[A <: Message](message: A, options: EncodeOptions): Entity[F]
19+
def encode[A <: Message](message: A, options: EncodeOptions): ResponseEntity[F]
3420

3521
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/ProtoMessageCodec.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import cats.implicits.*
55
import com.google.protobuf.CodedOutputStream
66
import fs2.Stream
77
import fs2.io.{readOutputStream, toInputStreamResource}
8-
import org.http4s.{DecodeResult, Entity, InvalidMessageBodyFailure, MediaType}
9-
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
8+
import org.http4s.headers.`Content-Type`
9+
import org.http4s.{DecodeResult, Headers, InvalidMessageBodyFailure, MediaType}
10+
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity, ResponseEntity}
1011
import org.slf4j.LoggerFactory
1112
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
1213

@@ -44,15 +45,16 @@ class ProtoMessageCodec[F[_] : Async] extends MessageCodec[F] {
4445
.leftMap(e => InvalidMessageBodyFailure(e.getMessage, e.some))
4546
}
4647

47-
override def encode[A <: Message](message: A, options: EncodeOptions): Entity[F] = {
48+
override def encode[A <: Message](message: A, options: EncodeOptions): ResponseEntity[F] = {
4849
if (logger.isTraceEnabled) {
4950
logger.trace(s"<<< Proto: ${message.toProtoString}")
5051
}
5152

5253
val dataLength = message.serializedSize
5354
val chunkSize = CodedOutputStream.DEFAULT_BUFFER_SIZE min dataLength
5455

55-
val entity = Entity(
56+
val entity = ResponseEntity(
57+
headers = Headers(`Content-Type`(mediaType)),
5658
body = readOutputStream(chunkSize)(os => Async[F].delay(message.writeTo(os))),
5759
length = Some(dataLength.toLong),
5860
)

0 commit comments

Comments
 (0)