Skip to content

Commit 73a632c

Browse files
authored
Initial work on the GRPC Transcoding (#54)
1 parent 3a0e68f commit 73a632c

File tree

12 files changed

+474
-23
lines changed

12 files changed

+474
-23
lines changed

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.ivovk.connect_rpc_scala.grpc.*
1111
import org.ivovk.connect_rpc_scala.http.*
1212
import org.ivovk.connect_rpc_scala.http.QueryParams.*
1313
import org.ivovk.connect_rpc_scala.http.codec.*
14+
import scalapb.GeneratedMessage
1415

1516
import java.util.concurrent.Executor
1617
import scala.concurrent.ExecutionContext
@@ -108,8 +109,9 @@ final class ConnectRouteBuilder[F[_] : Async] private(
108109
val httpDsl = Http4sDsl[F]
109110
import httpDsl.*
110111

112+
val jsonCodec = customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build)
111113
val codecRegistry = MessageCodecRegistry[F](
112-
customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build),
114+
jsonCodec,
113115
ProtoMessageCodec[F](),
114116
)
115117

@@ -124,13 +126,13 @@ final class ConnectRouteBuilder[F[_] : Async] private(
124126
waitForShutdown,
125127
)
126128
yield
127-
val handler = new ConnectHandler(
129+
val connectHandler = new ConnectHandler(
128130
channel,
129131
httpDsl,
130132
treatTrailersAsHeaders,
131133
)
132134

133-
HttpRoutes[F] {
135+
val connectRoutes = HttpRoutes[F] {
134136
case req@Method.GET -> `pathPrefix` / service / method :? EncodingQP(mediaType) +& MessageQP(message) =>
135137
OptionT.fromOption[F](methodRegistry.get(service, method))
136138
// Temporary support GET-requests for all methods,
@@ -140,7 +142,7 @@ final class ConnectRouteBuilder[F[_] : Async] private(
140142
withCodec(httpDsl, codecRegistry, mediaType.some) { codec =>
141143
val entity = RequestEntity[F](message, req.headers)
142144

143-
handler.handle(entity, methodEntry)(using codec)
145+
connectHandler.handle(entity, methodEntry)(using codec)
144146
}
145147
}
146148
case req@Method.POST -> `pathPrefix` / service / method =>
@@ -149,12 +151,41 @@ final class ConnectRouteBuilder[F[_] : Async] private(
149151
withCodec(httpDsl, codecRegistry, req.contentType.map(_.mediaType)) { codec =>
150152
val entity = RequestEntity[F](req.body, req.headers)
151153

152-
handler.handle(entity, methodEntry)(using codec)
154+
connectHandler.handle(entity, methodEntry)(using codec)
153155
}
154156
}
155157
case _ =>
156158
OptionT.none
157159
}
160+
161+
val transcodingUrlMatcher = TranscodingUrlMatcher.create[F](
162+
methodRegistry.all,
163+
pathPrefix,
164+
)
165+
val transcodingHandler = new TranscodingHandler(
166+
channel,
167+
httpDsl,
168+
treatTrailersAsHeaders,
169+
)
170+
171+
val transcodingRoutes = HttpRoutes[F] { req =>
172+
OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req))
173+
.semiflatMap { case MatchedRequest(method, json) =>
174+
given MessageCodec[F] = jsonCodec
175+
given EncodeOptions = EncodeOptions(None)
176+
177+
RequestEntity[F](req.body, req.headers)
178+
.as[GeneratedMessage](method.requestMessageCompanion)
179+
.flatMap { entity =>
180+
val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion)
181+
val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray)
182+
183+
transcodingHandler.handleUnary(finalEntity, req.headers, method)
184+
}
185+
}
186+
}
187+
188+
connectRoutes <+> transcodingRoutes
158189
}
159190

160191
def build: Resource[F, HttpApp[F]] =
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package org.ivovk.connect_rpc_scala
2+
3+
import cats.effect.Async
4+
import cats.implicits.*
5+
import io.grpc.*
6+
import org.http4s.dsl.Http4sDsl
7+
import org.http4s.{Header, Headers, MessageFailure, Response}
8+
import org.ivovk.connect_rpc_scala.Mappings.*
9+
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
10+
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
11+
import org.ivovk.connect_rpc_scala.http.RequestEntity
12+
import org.ivovk.connect_rpc_scala.http.RequestEntity.*
13+
import org.ivovk.connect_rpc_scala.http.codec.{EncodeOptions, MessageCodec}
14+
import org.slf4j.{Logger, LoggerFactory}
15+
import scalapb.GeneratedMessage
16+
17+
import scala.concurrent.duration.*
18+
import scala.jdk.CollectionConverters.*
19+
import scala.util.chaining.*
20+
21+
object TranscodingHandler {
22+
23+
extension [F[_]](response: Response[F]) {
24+
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
25+
codec.encode(entity, options).applyTo(response)
26+
}
27+
28+
}
29+
30+
class TranscodingHandler[F[_] : Async](
31+
channel: Channel,
32+
httpDsl: Http4sDsl[F],
33+
treatTrailersAsHeaders: Boolean,
34+
) {
35+
36+
import TranscodingHandler.*
37+
import httpDsl.*
38+
39+
private val logger: Logger = LoggerFactory.getLogger(getClass)
40+
41+
def handleUnary(
42+
message: GeneratedMessage,
43+
headers: Headers,
44+
method: MethodRegistry.Entry,
45+
)(using MessageCodec[F], EncodeOptions): F[Response[F]] = {
46+
if (logger.isTraceEnabled) {
47+
// Used in conformance tests
48+
headers.get[`X-Test-Case-Name`] match {
49+
case Some(header) =>
50+
logger.trace(s">>> Test Case: ${header.value}")
51+
case None => // ignore
52+
}
53+
}
54+
55+
if (logger.isTraceEnabled) {
56+
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}")
57+
}
58+
59+
val callOptions = CallOptions.DEFAULT
60+
.pipe(
61+
headers.timeout match {
62+
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
63+
case None => identity
64+
}
65+
)
66+
67+
ClientCalls
68+
.asyncUnaryCall(
69+
channel,
70+
method.descriptor,
71+
callOptions,
72+
headers.toMetadata,
73+
message
74+
)
75+
.map { response =>
76+
val headers = response.headers.toHeaders() ++
77+
response.trailers.toHeaders(trailing = !treatTrailersAsHeaders)
78+
79+
if (logger.isTraceEnabled) {
80+
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
81+
}
82+
83+
Response(Ok, headers = headers).withMessage(response.value)
84+
}
85+
.handleError { e =>
86+
val grpcStatus = e match {
87+
case e: StatusException =>
88+
e.getStatus.getDescription match {
89+
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
90+
case _ => e.getStatus
91+
}
92+
case e: StatusRuntimeException => e.getStatus
93+
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
94+
case _ => io.grpc.Status.INTERNAL
95+
}
96+
97+
val (message, metadata) = e match {
98+
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
99+
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
100+
case e => (Option(e.getMessage), new Metadata())
101+
}
102+
103+
val httpStatus = grpcStatus.toHttpStatus
104+
val connectCode = grpcStatus.toConnectCode
105+
106+
// Should be called before converting metadata to headers
107+
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
108+
.fold(Seq.empty)(_.asScala.toSeq)
109+
110+
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)
111+
112+
if (logger.isTraceEnabled) {
113+
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
114+
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
115+
logger.trace(s"<<< Error processing request", e)
116+
}
117+
118+
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
119+
code = connectCode,
120+
message = message,
121+
details = details
122+
))
123+
}
124+
}
125+
126+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package org.ivovk.connect_rpc_scala
2+
3+
import cats.implicits.*
4+
import com.google.api.HttpRule
5+
import org.http4s.{Method, Request, Uri}
6+
import org.ivovk.connect_rpc_scala
7+
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry
8+
import org.json4s.JsonAST.{JField, JObject}
9+
import org.json4s.{JString, JValue}
10+
11+
import scala.util.boundary
12+
import scala.util.boundary.break
13+
14+
case class MatchedRequest(method: MethodRegistry.Entry, json: JValue)
15+
16+
object TranscodingUrlMatcher {
17+
case class Entry(
18+
method: MethodRegistry.Entry,
19+
httpMethodMatcher: Method => Boolean,
20+
pattern: Uri.Path,
21+
)
22+
23+
def create[F[_]](
24+
methods: Seq[MethodRegistry.Entry],
25+
pathPrefix: Uri.Path,
26+
): TranscodingUrlMatcher[F] = {
27+
val entries = methods.flatMap { method =>
28+
method.httpRule match {
29+
case Some(httpRule) =>
30+
val (httpMethod, pattern) = extractMethodAndPattern(httpRule)
31+
32+
val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m)
33+
34+
Entry(
35+
method,
36+
httpMethodMatcher,
37+
pathPrefix.dropEndsWithSlash.concat(pattern.toRelative)
38+
).some
39+
case None => none
40+
}
41+
}
42+
43+
new TranscodingUrlMatcher(
44+
entries,
45+
)
46+
}
47+
48+
private def extractMethodAndPattern(rule: HttpRule): (Option[Method], Uri.Path) = {
49+
val (method, str) = rule.getPatternCase match
50+
case HttpRule.PatternCase.GET => (Method.GET.some, rule.getGet)
51+
case HttpRule.PatternCase.PUT => (Method.PUT.some, rule.getPut)
52+
case HttpRule.PatternCase.POST => (Method.POST.some, rule.getPost)
53+
case HttpRule.PatternCase.DELETE => (Method.DELETE.some, rule.getDelete)
54+
case HttpRule.PatternCase.PATCH => (Method.PATCH.some, rule.getPatch)
55+
case HttpRule.PatternCase.CUSTOM => (none, rule.getCustom.getPath)
56+
case other => throw new RuntimeException(s"Unsupported pattern case $other (Rule: $rule)")
57+
58+
val path = Uri.Path.unsafeFromString(str).dropEndsWithSlash
59+
60+
(method, path)
61+
}
62+
}
63+
64+
class TranscodingUrlMatcher[F[_]](
65+
entries: Seq[TranscodingUrlMatcher.Entry],
66+
) {
67+
68+
import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*
69+
70+
def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary {
71+
entries.foreach { entry =>
72+
if (entry.httpMethodMatcher(req.method)) {
73+
matchExtract(entry.pattern, req.uri.path) match {
74+
case Some(pathParams) =>
75+
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))
76+
77+
val merged = mergeFields(groupFields(pathParams), groupFields(queryParams))
78+
79+
break(Some(MatchedRequest(entry.method, JObject(merged))))
80+
case None => // continue
81+
}
82+
}
83+
}
84+
85+
none
86+
}
87+
88+
/**
89+
* Matches path segments with pattern segments and extracts variables from the path.
90+
* Returns None if the path does not match the pattern.
91+
*/
92+
private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary {
93+
if path.segments.length != pattern.segments.length then boundary.break(none)
94+
95+
path.segments.indices
96+
.foldLeft(List.empty[JField]) { (state, idx) =>
97+
val pathSegment = path.segments(idx)
98+
val patternSegment = pattern.segments(idx)
99+
100+
if isVariable(patternSegment) then
101+
val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1)
102+
103+
(varName -> JString(pathSegment.encoded)) :: state
104+
else if pathSegment != patternSegment then
105+
boundary.break(none)
106+
else state
107+
}
108+
.some
109+
}
110+
111+
private def isVariable(segment: Uri.Path.Segment): Boolean = {
112+
val enc = segment.encoded
113+
val length = enc.length
114+
115+
length > 2 && enc(0) == '{' && enc(length - 1) == '}'
116+
}
117+
}

core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package org.ivovk.connect_rpc_scala.grpc
22

3-
import com.google.api.AnnotationsProto
4-
import com.google.api.http.HttpRule
3+
import com.google.api.{AnnotationsProto, HttpRule}
54
import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition}
65
import scalapb.grpc.ConcreteProtoMethodDescriptorSupplier
76
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
@@ -44,7 +43,6 @@ object MethodRegistry {
4443
descriptor = methodDescriptor,
4544
)
4645
}
47-
.groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _)
4846

4947
new MethodRegistry(entries)
5048
}
@@ -63,9 +61,16 @@ object MethodRegistry {
6361

6462
}
6563

66-
class MethodRegistry private(entries: Map[Service, Map[Method, MethodRegistry.Entry]]) {
64+
class MethodRegistry private(entries: Seq[MethodRegistry.Entry]) {
65+
66+
private val serviceMethodEntries: Map[Service, Map[Method, MethodRegistry.Entry]] = entries
67+
.groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _)
68+
69+
def all: Seq[MethodRegistry.Entry] = entries
70+
71+
def get(name: MethodName): Option[MethodRegistry.Entry] = get(name.service, name.method)
6772

6873
def get(service: Service, method: Method): Option[MethodRegistry.Entry] =
69-
entries.getOrElse(service, Map.empty).get(method)
74+
serviceMethodEntries.getOrElse(service, Map.empty).get(method)
7075

7176
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ import org.ivovk.connect_rpc_scala.http.codec.MessageCodec
99
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
1010

1111

12+
object RequestEntity {
13+
extension (h: Headers) {
14+
def timeout: Option[Long] =
15+
h.get[`Connect-Timeout-Ms`].map(_.value)
16+
}
17+
}
18+
1219
/**
1320
* Encoded message and headers with the knowledge how this message can be decoded.
1421
* Similar to [[org.http4s.Media]], but extends the message with `String` type representing message that is
@@ -18,6 +25,7 @@ case class RequestEntity[F[_]](
1825
message: String | Stream[F, Byte],
1926
headers: Headers,
2027
) {
28+
import RequestEntity.*
2129

2230
private def contentType: Option[`Content-Type`] =
2331
headers.get[`Content-Type`]
@@ -28,8 +36,7 @@ case class RequestEntity[F[_]](
2836
def encoding: Option[ContentCoding] =
2937
headers.get[`Content-Encoding`].map(_.contentCoding)
3038

31-
def timeout: Option[Long] =
32-
headers.get[`Connect-Timeout-Ms`].map(_.value)
39+
def timeout: Option[Long] = headers.timeout
3340

3441
def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
3542
M.rethrow(codec.decode(this)(using cmp).value)

0 commit comments

Comments
 (0)