Skip to content

Commit 1fbc2f0

Browse files
authored
[transcoding] Work with a tree of routes instead of a list (#58)
1 parent 73a632c commit 1fbc2f0

File tree

7 files changed

+159
-74
lines changed

7 files changed

+159
-74
lines changed

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class ConnectHandler[F[_] : Async](
108108
}
109109
}
110110

111-
req.as[GeneratedMessage](method.requestMessageCompanion)
111+
req.as[GeneratedMessage](using method.requestMessageCompanion)
112112
.flatMap { message =>
113113
if (logger.isTraceEnabled) {
114114
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}")

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ import org.ivovk.connect_rpc_scala.grpc.*
1111
import org.ivovk.connect_rpc_scala.http.*
1212
import org.ivovk.connect_rpc_scala.http.QueryParams.*
1313
import org.ivovk.connect_rpc_scala.http.codec.*
14-
import scalapb.GeneratedMessage
14+
import org.ivovk.connect_rpc_scala.syntax.all.*
15+
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
1516

1617
import java.util.concurrent.Executor
1718
import scala.concurrent.ExecutionContext
@@ -170,17 +171,21 @@ final class ConnectRouteBuilder[F[_] : Async] private(
170171

171172
val transcodingRoutes = HttpRoutes[F] { req =>
172173
OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req))
173-
.semiflatMap { case MatchedRequest(method, json) =>
174+
.semiflatMap { case MatchedRequest(method, pathJson, queryJson) =>
174175
given MessageCodec[F] = jsonCodec
175-
given EncodeOptions = EncodeOptions(None)
176176

177-
RequestEntity[F](req.body, req.headers)
178-
.as[GeneratedMessage](method.requestMessageCompanion)
179-
.flatMap { entity =>
180-
val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion)
181-
val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray)
177+
given Companion[Message] = method.requestMessageCompanion
182178

183-
transcodingHandler.handleUnary(finalEntity, req.headers, method)
179+
RequestEntity[F](req.body, req.headers).as[Message]
180+
.flatMap { bodyMessage =>
181+
val pathMessage = jsonCodec.parser.fromJson[Message](pathJson)
182+
val queryMessage = jsonCodec.parser.fromJson[Message](queryJson)
183+
184+
transcodingHandler.handleUnary(
185+
bodyMessage.concat(pathMessage, queryMessage),
186+
req.headers,
187+
method
188+
)
184189
}
185190
}
186191
}

core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala

Lines changed: 119 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,123 @@ import com.google.api.HttpRule
55
import org.http4s.{Method, Request, Uri}
66
import org.ivovk.connect_rpc_scala
77
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry
8+
import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*
89
import org.json4s.JsonAST.{JField, JObject}
910
import org.json4s.{JString, JValue}
1011

11-
import scala.util.boundary
12-
import scala.util.boundary.break
12+
import scala.jdk.CollectionConverters.*
1313

14-
case class MatchedRequest(method: MethodRegistry.Entry, json: JValue)
14+
case class MatchedRequest(
15+
method: MethodRegistry.Entry,
16+
pathJson: JValue,
17+
queryJson: JValue,
18+
)
1519

1620
object TranscodingUrlMatcher {
1721
case class Entry(
1822
method: MethodRegistry.Entry,
19-
httpMethodMatcher: Method => Boolean,
23+
httpMethod: Option[Method],
2024
pattern: Uri.Path,
2125
)
2226

27+
sealed trait RouteTree
28+
29+
case class RootNode(
30+
children: Vector[RouteTree],
31+
) extends RouteTree
32+
33+
case class Node(
34+
isVariable: Boolean,
35+
segment: String,
36+
children: Vector[RouteTree],
37+
) extends RouteTree
38+
39+
case class Leaf(
40+
httpMethod: Option[Method],
41+
method: MethodRegistry.Entry,
42+
) extends RouteTree
43+
44+
private def mkTree(entries: Seq[Entry]): Vector[RouteTree] = {
45+
entries.groupByOrd(_.pattern.segments.headOption)
46+
.flatMap { (maybeSegment, entries) =>
47+
maybeSegment match {
48+
case None =>
49+
entries.map { entry =>
50+
Leaf(entry.httpMethod, entry.method)
51+
}
52+
case Some(head) =>
53+
val variableDef = this.isVariable(head)
54+
val segment =
55+
if variableDef then
56+
head.encoded.substring(1, head.encoded.length - 1)
57+
else head.encoded
58+
59+
List(
60+
Node(
61+
variableDef,
62+
segment,
63+
mkTree(entries.map(e => e.copy(pattern = e.pattern.splitAt(1)._2)).toVector),
64+
)
65+
)
66+
}
67+
}
68+
.toVector
69+
}
70+
71+
extension [A](it: Iterable[A]) {
72+
// Preserves ordering of elements
73+
def groupByOrd[B](f: A => B): Map[B, Vector[A]] = {
74+
val result = collection.mutable.LinkedHashMap.empty[B, Vector[A]]
75+
76+
it.foreach { elem =>
77+
val key = f(elem)
78+
val vec = result.getOrElse(key, Vector.empty)
79+
result.update(key, vec :+ elem)
80+
}
81+
82+
result.toMap
83+
}
84+
85+
// Returns the first element that is Some
86+
def colFirst[B](f: A => Option[B]): Option[B] = {
87+
val iter = it.iterator
88+
while (iter.hasNext) {
89+
val x = f(iter.next())
90+
if x.isDefined then return x
91+
}
92+
None
93+
}
94+
}
95+
96+
private def isVariable(segment: Uri.Path.Segment): Boolean = {
97+
val enc = segment.encoded
98+
val length = enc.length
99+
100+
length > 2 && enc(0) == '{' && enc(length - 1) == '}'
101+
}
102+
23103
def create[F[_]](
24104
methods: Seq[MethodRegistry.Entry],
25105
pathPrefix: Uri.Path,
26106
): TranscodingUrlMatcher[F] = {
27107
val entries = methods.flatMap { method =>
28-
method.httpRule match {
29-
case Some(httpRule) =>
30-
val (httpMethod, pattern) = extractMethodAndPattern(httpRule)
108+
method.httpRule.fold(List.empty[Entry]) { httpRule =>
109+
val additionalBindings = httpRule.getAdditionalBindingsList.asScala.toList
31110

32-
val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m)
111+
(httpRule :: additionalBindings).map { rule =>
112+
val (httpMethod, pattern) = extractMethodAndPattern(rule)
33113

34114
Entry(
35115
method,
36-
httpMethodMatcher,
37-
pathPrefix.dropEndsWithSlash.concat(pattern.toRelative)
38-
).some
39-
case None => none
116+
httpMethod,
117+
pathPrefix.concat(pattern),
118+
)
119+
}
40120
}
41121
}
42122

43123
new TranscodingUrlMatcher(
44-
entries,
124+
RootNode(mkTree(entries)),
45125
)
46126
}
47127

@@ -62,56 +142,40 @@ object TranscodingUrlMatcher {
62142
}
63143

64144
class TranscodingUrlMatcher[F[_]](
65-
entries: Seq[TranscodingUrlMatcher.Entry],
145+
tree: TranscodingUrlMatcher.RootNode,
66146
) {
67147

68-
import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*
148+
import TranscodingUrlMatcher.*
69149

70-
def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary {
71-
entries.foreach { entry =>
72-
if (entry.httpMethodMatcher(req.method)) {
73-
matchExtract(entry.pattern, req.uri.path) match {
74-
case Some(pathParams) =>
75-
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))
150+
def matchesRequest(req: Request[F]): Option[MatchedRequest] = {
151+
def doMatch(node: RouteTree, path: List[Uri.Path.Segment], pathVars: List[JField]): Option[MatchedRequest] = {
152+
node match {
153+
case Node(isVariable, patternSegment, children) if path.nonEmpty =>
154+
val pathSegment = path.head
155+
val pathTail = path.tail
76156

77-
val merged = mergeFields(groupFields(pathParams), groupFields(queryParams))
157+
if isVariable then
158+
val newPatchVars = (patternSegment -> JString(pathSegment.encoded)) :: pathVars
78159

79-
break(Some(MatchedRequest(entry.method, JObject(merged))))
80-
case None => // continue
81-
}
160+
children.colFirst(doMatch(_, pathTail, newPatchVars))
161+
else if pathSegment.encoded == patternSegment then
162+
children.colFirst(doMatch(_, pathTail, pathVars))
163+
else none
164+
case Leaf(httpMethod, method) if path.isEmpty && httpMethod.forall(_ == req.method) =>
165+
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))
166+
167+
MatchedRequest(
168+
method,
169+
JObject(groupFields(pathVars)),
170+
JObject(groupFields(queryParams))
171+
).some
172+
case RootNode(children) =>
173+
children.colFirst(doMatch(_, path, pathVars))
174+
case _ => none
82175
}
83176
}
84177

85-
none
178+
doMatch(tree, req.uri.path.segments.toList, List.empty)
86179
}
87180

88-
/**
89-
* Matches path segments with pattern segments and extracts variables from the path.
90-
* Returns None if the path does not match the pattern.
91-
*/
92-
private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary {
93-
if path.segments.length != pattern.segments.length then boundary.break(none)
94-
95-
path.segments.indices
96-
.foldLeft(List.empty[JField]) { (state, idx) =>
97-
val pathSegment = path.segments(idx)
98-
val patternSegment = pattern.segments(idx)
99-
100-
if isVariable(patternSegment) then
101-
val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1)
102-
103-
(varName -> JString(pathSegment.encoded)) :: state
104-
else if pathSegment != patternSegment then
105-
boundary.break(none)
106-
else state
107-
}
108-
.some
109-
}
110-
111-
private def isVariable(segment: Uri.Path.Segment): Boolean = {
112-
val enc = segment.encoded
113-
val length = enc.length
114-
115-
length > 2 && enc(0) == '{' && enc(length - 1) == '}'
116-
}
117181
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ case class RequestEntity[F[_]](
3838

3939
def timeout: Option[Long] = headers.timeout
4040

41-
def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
42-
M.rethrow(codec.decode(this)(using cmp).value)
43-
41+
def as[A <: Message: Companion](using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
42+
M.rethrow(codec.decode(this)(using summon[Companion[A]]).value)
4443
}

core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/MessageCodec.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ case class EncodeOptions(
1010
encoding: Option[ContentCoding]
1111
)
1212

13+
object EncodeOptions {
14+
given EncodeOptions = EncodeOptions(None)
15+
}
16+
1317
trait MessageCodec[F[_]] {
1418

1519
val mediaType: MediaType

core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package org.ivovk.connect_rpc_scala.syntax
22

3+
import com.google.protobuf.ByteString
34
import io.grpc.{StatusException, StatusRuntimeException}
45
import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders
5-
import scalapb.GeneratedMessage
6+
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
67

78
object all extends ExceptionSyntax, ProtoMappingsSyntax
89

@@ -33,6 +34,20 @@ trait ExceptionSyntax {
3334
trait ProtoMappingsSyntax {
3435

3536
extension [T <: GeneratedMessage](t: T) {
37+
def concat(other: T, more: T*): T = {
38+
val cmp = t.companion.asInstanceOf[GeneratedMessageCompanion[T]]
39+
val empty = cmp.defaultInstance
40+
41+
val els = (t :: other :: more.toList).filter(_ != empty)
42+
43+
els match
44+
case Nil => empty
45+
case el :: Nil => el
46+
case _ =>
47+
val is = els.foldLeft(ByteString.empty)(_ concat _.toByteString).newCodedInput()
48+
cmp.parseFrom(is)
49+
}
50+
3651
def toProtoAny: com.google.protobuf.any.Any = {
3752
com.google.protobuf.any.Any(
3853
typeUrl = "type.googleapis.com/" + t.companion.scalaDescriptor.fullName,

core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,39 +40,37 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike {
4040

4141
assert(result.isDefined)
4242
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
43-
assert(result.get.json == JObject())
4443
}
4544

4645
test("matches request with POST method") {
4746
val result = matcher.matchesRequest(Request[IO](Method.POST, uri"/api/countries"))
4847

4948
assert(result.isDefined)
5049
assert(result.get.method.name == MethodName("CountriesService", "CreateCountry"))
51-
assert(result.get.json == JObject())
5250
}
5351

5452
test("extracts query parameters") {
5553
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&offset=5"))
5654

5755
assert(result.isDefined)
5856
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
59-
assert(result.get.json == JObject("limit" -> JString("10"), "offset" -> JString("5")))
57+
assert(result.get.queryJson == JObject("limit" -> JString("10"), "offset" -> JString("5")))
6058
}
6159

6260
test("matches request with path parameter and extracts it") {
6361
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/Uganda"))
6462

6563
assert(result.isDefined)
6664
assert(result.get.method.name == MethodName("CountriesService", "GetCountry"))
67-
assert(result.get.json == JObject("country_id" -> JString("Uganda")))
65+
assert(result.get.pathJson == JObject("country_id" -> JString("Uganda")))
6866
}
6967

7068
test("extracts repeating query parameters") {
7169
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&limit=20"))
7270

7371
assert(result.isDefined)
7472
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
75-
assert(result.get.json == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil)))
73+
assert(result.get.queryJson == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil)))
7674
}
7775

7876
}

0 commit comments

Comments
 (0)