@@ -10,7 +10,6 @@ import org.ivovk.connect_rpc_scala.Mappings.*
10
10
import org .ivovk .connect_rpc_scala .grpc .{ClientCalls , GrpcHeaders , MethodRegistry }
11
11
import org .ivovk .connect_rpc_scala .http .Headers .`X-Test-Case-Name`
12
12
import org .ivovk .connect_rpc_scala .http .RequestEntity
13
- import org .ivovk .connect_rpc_scala .http .codec .MessageCodec .given
14
13
import org .ivovk .connect_rpc_scala .http .codec .{Compressor , EncodeOptions , MessageCodec }
15
14
import org .slf4j .{Logger , LoggerFactory }
16
15
import scalapb .GeneratedMessage
@@ -19,12 +18,22 @@ import scala.concurrent.duration.*
19
18
import scala .jdk .CollectionConverters .*
20
19
import scala .util .chaining .*
21
20
21
+ object ConnectHandler {
22
+
23
+ extension [F [_]](response : Response [F ]) {
24
+ def withMessage (entity : GeneratedMessage )(using codec : MessageCodec [F ], options : EncodeOptions ): Response [F ] =
25
+ codec.encode(entity, options).applyTo(response)
26
+ }
27
+
28
+ }
29
+
22
30
class ConnectHandler [F [_] : Async ](
23
31
channel : Channel ,
24
32
httpDsl : Http4sDsl [F ],
25
33
treatTrailersAsHeaders : Boolean ,
26
34
) {
27
35
36
+ import ConnectHandler .*
28
37
import httpDsl .*
29
38
30
39
private val logger : Logger = LoggerFactory .getLogger(getClass)
@@ -37,14 +46,53 @@ class ConnectHandler[F[_] : Async](
37
46
encoding = req.encoding.filter(Compressor .supportedEncodings.contains)
38
47
)
39
48
40
- method.descriptor.getType match
49
+ val f = method.descriptor.getType match
41
50
case MethodType .UNARY =>
42
51
handleUnary(req, method)
43
52
case unsupported =>
44
- NotImplemented (connectrpc.Error (
45
- code = io.grpc.Status .UNIMPLEMENTED .toConnectCode,
46
- message = s " Unsupported method type: $unsupported" .some
53
+ Async [F ].raiseError(new StatusException (
54
+ io.grpc.Status .UNIMPLEMENTED .withDescription(s " Unsupported method type: $unsupported" )
47
55
))
56
+
57
+ f.handleError { e =>
58
+ val grpcStatus = e match {
59
+ case e : StatusException =>
60
+ e.getStatus.getDescription match {
61
+ case " an implementation is missing" => io.grpc.Status .UNIMPLEMENTED
62
+ case _ => e.getStatus
63
+ }
64
+ case e : StatusRuntimeException => e.getStatus
65
+ case _ : MessageFailure => io.grpc.Status .INVALID_ARGUMENT
66
+ case _ => io.grpc.Status .INTERNAL
67
+ }
68
+
69
+ val (message, metadata) = e match {
70
+ case e : StatusRuntimeException => (Option (e.getStatus.getDescription), e.getTrailers)
71
+ case e : StatusException => (Option (e.getStatus.getDescription), e.getTrailers)
72
+ case e => (Option (e.getMessage), new Metadata ())
73
+ }
74
+
75
+ val httpStatus = grpcStatus.toHttpStatus
76
+ val connectCode = grpcStatus.toConnectCode
77
+
78
+ // Should be called before converting metadata to headers
79
+ val details = Option (metadata.removeAll(GrpcHeaders .ErrorDetailsKey ))
80
+ .fold(Seq .empty)(_.asScala.toSeq)
81
+
82
+ val headers = metadata.toHeaders(trailing = ! treatTrailersAsHeaders)
83
+
84
+ if (logger.isTraceEnabled) {
85
+ logger.trace(s " <<< Http Status: $httpStatus, Connect Error Code: $connectCode" )
86
+ logger.trace(s " <<< Headers: ${headers.redactSensitive()}" )
87
+ logger.trace(s " <<< Error processing request " , e)
88
+ }
89
+
90
+ Response [F ](httpStatus, headers = headers).withMessage(connectrpc.Error (
91
+ code = connectCode,
92
+ message = message,
93
+ details = details
94
+ ))
95
+ }
48
96
}
49
97
50
98
private def handleUnary (
@@ -90,46 +138,7 @@ class ConnectHandler[F[_] : Async](
90
138
logger.trace(s " <<< Headers: ${headers.redactSensitive()}" )
91
139
}
92
140
93
- Response (Ok , headers = headers).withEntity(response.value)
94
- }
95
- .recover { case e =>
96
- val grpcStatus = e match {
97
- case e : StatusException =>
98
- e.getStatus.getDescription match {
99
- case " an implementation is missing" => io.grpc.Status .UNIMPLEMENTED
100
- case _ => e.getStatus
101
- }
102
- case e : StatusRuntimeException => e.getStatus
103
- case _ : MessageFailure => io.grpc.Status .INVALID_ARGUMENT
104
- case _ => io.grpc.Status .INTERNAL
105
- }
106
-
107
- val (message, metadata) = e match {
108
- case e : StatusRuntimeException => (Option (e.getStatus.getDescription), e.getTrailers)
109
- case e : StatusException => (Option (e.getStatus.getDescription), e.getTrailers)
110
- case e => (Option (e.getMessage), new Metadata ())
111
- }
112
-
113
- val httpStatus = grpcStatus.toHttpStatus
114
- val connectCode = grpcStatus.toConnectCode
115
-
116
- // Should be called before converting metadata to headers
117
- val details = Option (metadata.removeAll(GrpcHeaders .ErrorDetailsKey ))
118
- .fold(Seq .empty)(_.asScala.toSeq)
119
-
120
- val headers = metadata.toHeaders(trailing = ! treatTrailersAsHeaders)
121
-
122
- if (logger.isTraceEnabled) {
123
- logger.trace(s " <<< Http Status: $httpStatus, Connect Error Code: $connectCode" )
124
- logger.trace(s " <<< Headers: ${headers.redactSensitive()}" )
125
- logger.trace(s " <<< Error processing request " , e)
126
- }
127
-
128
- Response [F ](httpStatus, headers = headers).withEntity(connectrpc.Error (
129
- code = connectCode,
130
- message = message,
131
- details = details
132
- ))
141
+ Response (Ok , headers = headers).withMessage(response.value)
133
142
}
134
143
}
135
144
0 commit comments