Skip to content

Commit 48c9499

Browse files
authored
Extract error handler to a separate entity so it can be reused (#61)
1 parent ab4c608 commit 48c9499

File tree

6 files changed

+97
-129
lines changed

6 files changed

+97
-129
lines changed

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,25 @@ import cats.effect.Async
44
import cats.implicits.*
55
import io.grpc.*
66
import io.grpc.MethodDescriptor.MethodType
7-
import org.http4s.dsl.Http4sDsl
8-
import org.http4s.{Header, MessageFailure, Response}
7+
import org.http4s.Status.Ok
8+
import org.http4s.{Header, Response}
99
import org.ivovk.connect_rpc_scala.Mappings.*
10-
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
10+
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, MethodRegistry}
1111
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
1212
import org.ivovk.connect_rpc_scala.http.RequestEntity
1313
import org.ivovk.connect_rpc_scala.http.codec.{Compressor, EncodeOptions, MessageCodec}
1414
import org.slf4j.{Logger, LoggerFactory}
1515
import scalapb.GeneratedMessage
1616

1717
import scala.concurrent.duration.*
18-
import scala.jdk.CollectionConverters.*
1918
import scala.util.chaining.*
2019

21-
object ConnectHandler {
22-
23-
extension [F[_]](response: Response[F]) {
24-
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
25-
codec.encode(entity, options).applyTo(response)
26-
}
27-
28-
}
29-
3020
class ConnectHandler[F[_] : Async](
3121
channel: Channel,
32-
httpDsl: Http4sDsl[F],
22+
errorHandler: ErrorHandler[F],
3323
treatTrailersAsHeaders: Boolean,
3424
) {
3525

36-
import ConnectHandler.*
37-
import httpDsl.*
38-
3926
private val logger: Logger = LoggerFactory.getLogger(getClass)
4027

4128
def handle(
@@ -54,45 +41,7 @@ class ConnectHandler[F[_] : Async](
5441
io.grpc.Status.UNIMPLEMENTED.withDescription(s"Unsupported method type: $unsupported")
5542
))
5643

57-
f.handleError { e =>
58-
val grpcStatus = e match {
59-
case e: StatusException =>
60-
e.getStatus.getDescription match {
61-
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
62-
case _ => e.getStatus
63-
}
64-
case e: StatusRuntimeException => e.getStatus
65-
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
66-
case _ => io.grpc.Status.INTERNAL
67-
}
68-
69-
val (message, metadata) = e match {
70-
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
71-
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
72-
case e => (Option(e.getMessage), new Metadata())
73-
}
74-
75-
val httpStatus = grpcStatus.toHttpStatus
76-
val connectCode = grpcStatus.toConnectCode
77-
78-
// Should be called before converting metadata to headers
79-
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
80-
.fold(Seq.empty)(_.asScala.toSeq)
81-
82-
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
83-
84-
if (logger.isTraceEnabled) {
85-
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
86-
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
87-
logger.trace(s"<<< Error processing request", e)
88-
}
89-
90-
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
91-
code = connectCode,
92-
message = message,
93-
details = details
94-
))
95-
}
44+
f.handleError(errorHandler.handle)
9645
}
9746

9847
private def handleUnary(

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import cats.data.OptionT
55
import cats.effect.{Async, Resource}
66
import cats.implicits.*
77
import io.grpc.{ManagedChannelBuilder, ServerBuilder, ServerServiceDefinition}
8+
import org.http4s.Status.*
89
import org.http4s.dsl.Http4sDsl
910
import org.http4s.{HttpApp, HttpRoutes, MediaType, Method, Response, Uri}
1011
import org.ivovk.connect_rpc_scala.grpc.*
@@ -127,9 +128,13 @@ final class ConnectRouteBuilder[F[_] : Async] private(
127128
waitForShutdown,
128129
)
129130
yield
131+
val errorHandler = new ConnectErrorHandler[F](
132+
treatTrailersAsHeaders,
133+
)
134+
130135
val connectHandler = new ConnectHandler(
131136
channel,
132-
httpDsl,
137+
errorHandler,
133138
treatTrailersAsHeaders,
134139
)
135140

@@ -140,7 +145,7 @@ final class ConnectRouteBuilder[F[_] : Async] private(
140145
// until https://github.com/scalapb/ScalaPB/pull/1774 is merged
141146
.filter(_.descriptor.isSafe || true)
142147
.semiflatMap { methodEntry =>
143-
withCodec(httpDsl, codecRegistry, mediaType.some) { codec =>
148+
withCodec(codecRegistry, mediaType.some) { codec =>
144149
val entity = RequestEntity[F](message, req.headers)
145150

146151
connectHandler.handle(entity, methodEntry)(using codec)
@@ -149,7 +154,7 @@ final class ConnectRouteBuilder[F[_] : Async] private(
149154
case req@Method.POST -> `pathPrefix` / service / method =>
150155
OptionT.fromOption[F](methodRegistry.get(service, method))
151156
.semiflatMap { methodEntry =>
152-
withCodec(httpDsl, codecRegistry, req.contentType.map(_.mediaType)) { codec =>
157+
withCodec(codecRegistry, req.contentType.map(_.mediaType)) { codec =>
153158
val entity = RequestEntity[F](req.body, req.headers)
154159

155160
connectHandler.handle(entity, methodEntry)(using codec)
@@ -163,10 +168,10 @@ final class ConnectRouteBuilder[F[_] : Async] private(
163168
methodRegistry.all,
164169
pathPrefix,
165170
)
166-
val transcodingHandler = new TranscodingHandler(
171+
172+
val transcodingHandler = new TranscodingHandler(
167173
channel,
168-
httpDsl,
169-
treatTrailersAsHeaders,
174+
errorHandler,
170175
)
171176

172177
val transcodingRoutes = HttpRoutes[F] { req =>
@@ -197,19 +202,16 @@ final class ConnectRouteBuilder[F[_] : Async] private(
197202
buildRoutes.map(_.orNotFound)
198203

199204
private def withCodec(
200-
dsl: Http4sDsl[F],
201205
registry: MessageCodecRegistry[F],
202206
mediaType: Option[MediaType]
203207
)(r: MessageCodec[F] => F[Response[F]]): F[Response[F]] = {
204-
import dsl.*
205-
206208
mediaType.flatMap(registry.byMediaType) match {
207209
case Some(codec) => r(codec)
208210
case None =>
209211
val message = s"Unsupported media-type ${mediaType.show}. " +
210212
s"Supported media types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}"
211213

212-
UnsupportedMediaType(message)
214+
Response(UnsupportedMediaType).withEntity(message).pure[F]
213215
}
214216
}
215217

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package org.ivovk.connect_rpc_scala
2+
3+
import io.grpc.{Metadata, StatusException, StatusRuntimeException}
4+
import org.http4s.{MessageFailure, Response}
5+
import org.ivovk.connect_rpc_scala.Mappings.*
6+
import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders
7+
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec
8+
import org.slf4j.LoggerFactory
9+
10+
import scala.jdk.CollectionConverters.*
11+
12+
trait ErrorHandler[F[_]] {
13+
def handle(e: Throwable)(using MessageCodec[F]): Response[F]
14+
}
15+
16+
class ConnectErrorHandler[F[_]](
17+
treatTrailersAsHeaders: Boolean = false,
18+
) extends ErrorHandler[F] {
19+
20+
private val logger = LoggerFactory.getLogger(getClass)
21+
22+
override def handle(e: Throwable)(using MessageCodec[F]): Response[F] = {
23+
val grpcStatus = e match {
24+
case e: StatusException =>
25+
e.getStatus.getDescription match {
26+
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
27+
case _ => e.getStatus
28+
}
29+
case e: StatusRuntimeException => e.getStatus
30+
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
31+
case _ => io.grpc.Status.INTERNAL
32+
}
33+
34+
val (message, metadata) = e match {
35+
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
36+
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
37+
case e => (Option(e.getMessage), new Metadata())
38+
}
39+
40+
val httpStatus = grpcStatus.toHttpStatus
41+
val connectCode = grpcStatus.toConnectCode
42+
43+
// Should be called before converting metadata to headers
44+
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
45+
.fold(Seq.empty)(_.asScala.toSeq)
46+
47+
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
48+
49+
if (logger.isTraceEnabled) {
50+
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
51+
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
52+
logger.trace(s"<<< Error processing request", e)
53+
}
54+
55+
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
56+
code = connectCode,
57+
message = message,
58+
details = details
59+
))
60+
}
61+
}

core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package org.ivovk.connect_rpc_scala
22

33
import io.grpc.{Metadata, Status}
4-
import org.http4s.{Header, Headers}
4+
import org.http4s.{Header, Headers, Response}
55
import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders.asciiKey
6+
import org.ivovk.connect_rpc_scala.http.codec.{EncodeOptions, MessageCodec}
67
import org.typelevel.ci.CIString
8+
import scalapb.GeneratedMessage
79

8-
object Mappings extends HeaderMappings, StatusCodeMappings
10+
object Mappings extends HeaderMappings, StatusCodeMappings, ResponseCodeExtensions
911

1012
trait HeaderMappings {
1113

@@ -47,6 +49,13 @@ trait HeaderMappings {
4749

4850
}
4951

52+
trait ResponseCodeExtensions {
53+
extension [F[_]](response: Response[F]) {
54+
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
55+
codec.encode(entity, options).applyTo(response)
56+
}
57+
}
58+
5059
trait StatusCodeMappings {
5160

5261
private val httpStatusCodesByGrpcStatusCode: Array[org.http4s.Status] = {

core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package org.ivovk.connect_rpc_scala
33
import cats.effect.Async
44
import cats.implicits.*
55
import io.grpc.*
6-
import org.http4s.dsl.Http4sDsl
7-
import org.http4s.{Header, Headers, MessageFailure, Response}
6+
import org.http4s.Status.Ok
7+
import org.http4s.{Header, Headers, Response}
88
import org.ivovk.connect_rpc_scala.Mappings.*
9-
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
9+
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, MethodRegistry}
1010
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
1111
import org.ivovk.connect_rpc_scala.http.RequestEntity
1212
import org.ivovk.connect_rpc_scala.http.RequestEntity.*
@@ -15,27 +15,13 @@ import org.slf4j.{Logger, LoggerFactory}
1515
import scalapb.GeneratedMessage
1616

1717
import scala.concurrent.duration.*
18-
import scala.jdk.CollectionConverters.*
1918
import scala.util.chaining.*
2019

21-
object TranscodingHandler {
22-
23-
extension [F[_]](response: Response[F]) {
24-
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
25-
codec.encode(entity, options).applyTo(response)
26-
}
27-
28-
}
29-
3020
class TranscodingHandler[F[_] : Async](
3121
channel: Channel,
32-
httpDsl: Http4sDsl[F],
33-
treatTrailersAsHeaders: Boolean,
22+
errorHandler: ErrorHandler[F],
3423
) {
3524

36-
import TranscodingHandler.*
37-
import httpDsl.*
38-
3925
private val logger: Logger = LoggerFactory.getLogger(getClass)
4026

4127
def handleUnary(
@@ -73,54 +59,15 @@ class TranscodingHandler[F[_] : Async](
7359
message
7460
)
7561
.map { response =>
76-
val headers = response.headers.toHeaders() ++
77-
response.trailers.toHeaders(trailing = !treatTrailersAsHeaders)
62+
val headers = response.headers.toHeaders() ++ response.trailers.toHeaders()
7863

7964
if (logger.isTraceEnabled) {
8065
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
8166
}
8267

8368
Response(Ok, headers = headers).withMessage(response.value)
8469
}
85-
.handleError { e =>
86-
val grpcStatus = e match {
87-
case e: StatusException =>
88-
e.getStatus.getDescription match {
89-
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
90-
case _ => e.getStatus
91-
}
92-
case e: StatusRuntimeException => e.getStatus
93-
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
94-
case _ => io.grpc.Status.INTERNAL
95-
}
96-
97-
val (message, metadata) = e match {
98-
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
99-
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
100-
case e => (Option(e.getMessage), new Metadata())
101-
}
102-
103-
val httpStatus = grpcStatus.toHttpStatus
104-
val connectCode = grpcStatus.toConnectCode
105-
106-
// Should be called before converting metadata to headers
107-
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
108-
.fold(Seq.empty)(_.asScala.toSeq)
109-
110-
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
111-
112-
if (logger.isTraceEnabled) {
113-
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
114-
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
115-
logger.trace(s"<<< Error processing request", e)
116-
}
117-
118-
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
119-
code = connectCode,
120-
message = message,
121-
details = details
122-
))
123-
}
70+
.handleError(errorHandler.handle)
12471
}
12572

12673
}

core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ object MethodRegistry {
5252
private def extractHttpRule(methodDescriptor: MethodDescriptor[_, _]): Option[HttpRule] = {
5353
methodDescriptor.getSchemaDescriptor match
5454
case sd: ConcreteProtoMethodDescriptorSupplier =>
55-
val fields = sd.getMethodDescriptor.getOptions.getUnknownFields
55+
val fields = sd.getMethodDescriptor.getOptions.getUnknownFields
5656

5757
if fields.hasField(HttpFieldNumber) then
5858
Some(HttpRule.parseFrom(fields.getField(HttpFieldNumber).getLengthDelimitedList.get(0).toByteArray))
5959
else
6060
None
61-
case _ =>
61+
case _ =>
6262
None
6363
}
6464

0 commit comments

Comments
 (0)