Skip to content

Commit c592970

Browse files
committed
Optimize GET-request entity parsing, merge GET- and POST-request handlers
1 parent 582dd16 commit c592970

File tree

5 files changed

+135
-96
lines changed

5 files changed

+135
-96
lines changed

conformance-suite.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
features:
33
versions: [ HTTP_VERSION_1 ]
44
protocols: [ PROTOCOL_CONNECT ]
5-
codecs: [ CODEC_JSON ]
5+
codecs: [ CODEC_JSON ] #, CODEC_PROTO ]
66
stream_types: [ STREAM_TYPE_UNARY ]
77
supports_tls: false
88
supports_trailers: false

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRpcHttpRoutes.scala

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
package org.ivovk.connect_rpc_scala
22

33
import cats.Endo
4+
import cats.data.EitherT
45
import cats.effect.Async
56
import cats.effect.kernel.Resource
67
import cats.implicits.*
78
import fs2.compression.Compression
8-
import fs2.{Chunk, Stream}
99
import io.grpc.*
1010
import io.grpc.MethodDescriptor.MethodType
1111
import io.grpc.stub.MetadataUtils
1212
import org.http4s.*
1313
import org.http4s.dsl.Http4sDsl
14-
import org.http4s.headers.`Content-Type`
1514
import org.ivovk.connect_rpc_scala.http.*
1615
import org.ivovk.connect_rpc_scala.http.Headers.*
1716
import org.ivovk.connect_rpc_scala.http.MessageCodec.given
@@ -55,78 +54,69 @@ object ConnectRpcHttpRoutes {
5554
for
5655
ipChannel <- InProcessChannelBridge.create(services, configuration.waitForShutdown)
5756
yield
57+
def handle(
58+
httpMethod: Method,
59+
contentType: Option[MediaType],
60+
entity: RequestEntity[F],
61+
grpcMethod: String,
62+
): F[Response[F]] = {
63+
val eitherT = for
64+
given MessageCodec[F] <- EitherT.fromOptionM(
65+
contentType.flatMap(codecRegistry.byContentType).pure[F],
66+
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
67+
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
68+
)
69+
70+
method <- EitherT.fromOptionM(
71+
methodRegistry.get(grpcMethod).pure[F],
72+
NotFound(connectrpc.Error(
73+
code = io.grpc.Status.NOT_FOUND.toConnectCode,
74+
message = s"Method not found: $grpcMethod".some
75+
))
76+
)
77+
78+
_ <- EitherT.cond[F](
79+
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
80+
httpMethod == Method.POST || (httpMethod == Method.GET && method.methodDescriptor.isSafe) || true,
81+
(),
82+
Forbidden(connectrpc.Error(
83+
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
84+
message = s"Only POST-requests are allowed for method: $grpcMethod".some
85+
))
86+
).leftSemiflatMap(identity)
87+
88+
response <- method.methodDescriptor.getType match
89+
case MethodType.UNARY =>
90+
EitherT.right(handleUnary(dsl, method, entity, ipChannel))
91+
case unsupported =>
92+
EitherT.left(NotImplemented(connectrpc.Error(
93+
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
94+
message = s"Unsupported method type: $unsupported".some
95+
)))
96+
yield response
97+
98+
eitherT.merge
99+
}
100+
58101
HttpRoutes.of[F] {
59102
case req@Method.GET -> Root / serviceName / methodName :? EncodingQP(contentType) +& MessageQP(message) =>
60103
val grpcMethod = grpcMethodName(serviceName, methodName)
104+
val entity = RequestEntity[F](message, req.headers)
61105

62-
codecRegistry.byContentType(contentType) match {
63-
case Some(codec) =>
64-
given MessageCodec[F] = codec
65-
66-
val media = Media[F](Stream.chunk(Chunk.array(message.getBytes)), req.headers)
67-
68-
methodRegistry.get(grpcMethod) match {
69-
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
70-
case Some(entry) if entry.methodDescriptor.isSafe || true =>
71-
entry.methodDescriptor.getType match
72-
case MethodType.UNARY =>
73-
handleUnary(dsl, entry, media, ipChannel)
74-
case unsupported =>
75-
NotImplemented(connectrpc.Error(
76-
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
77-
message = s"Unsupported method type: $unsupported".some
78-
))
79-
case Some(_) =>
80-
Forbidden(connectrpc.Error(
81-
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
82-
message = s"Method supports calling using POST: $grpcMethod".some
83-
))
84-
case None =>
85-
NotFound(connectrpc.Error(
86-
code = io.grpc.Status.NOT_FOUND.toConnectCode,
87-
message = s"Method not found: $grpcMethod".some
88-
))
89-
}
90-
case None =>
91-
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
92-
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
93-
}
106+
handle(Method.GET, contentType.some, entity, grpcMethod)
94107
case req@Method.POST -> Root / serviceName / methodName =>
95108
val grpcMethod = grpcMethodName(serviceName, methodName)
96-
val contentType = req.headers.get[`Content-Type`].map(_.mediaType)
97-
98-
contentType.flatMap(codecRegistry.byContentType) match {
99-
case Some(codec) =>
100-
given MessageCodec[F] = codec
101-
102-
methodRegistry.get(grpcMethod) match {
103-
case Some(entry) =>
104-
entry.methodDescriptor.getType match
105-
case MethodType.UNARY =>
106-
handleUnary(dsl, entry, req, ipChannel)
107-
case unsupported =>
108-
NotImplemented(connectrpc.Error(
109-
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
110-
message = s"Unsupported method type: $unsupported".some
111-
))
112-
case None =>
113-
NotFound(connectrpc.Error(
114-
code = io.grpc.Status.NOT_FOUND.toConnectCode,
115-
message = s"Method not found: $grpcMethod".some
116-
))
117-
}
118-
case None =>
119-
UnsupportedMediaType(s"Unsupported content-type ${contentType.map(_.show).orNull}. " +
120-
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
121-
}
109+
val contentType = req.contentType.map(_.mediaType)
110+
val entity = RequestEntity[F](req)
111+
112+
handle(Method.POST, contentType, entity, grpcMethod)
122113
}
123114
}
124115

125-
126116
private def handleUnary[F[_] : Async](
127117
dsl: Http4sDsl[F],
128118
entry: RegistryEntry,
129-
req: Media[F],
119+
req: RequestEntity[F],
130120
channel: Channel
131121
)(using codec: MessageCodec[F]): F[Response[F]] = {
132122
import dsl.*

core/src/main/scala/org/ivovk/connect_rpc_scala/http/MessageCodec.scala

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@ import cats.Applicative
44
import cats.data.EitherT
55
import cats.effect.{Async, Sync}
66
import cats.implicits.*
7+
import fs2.Stream
78
import fs2.compression.Compression
89
import fs2.io.{readOutputStream, toInputStreamResource}
910
import fs2.text.decodeWithCharset
1011
import org.http4s.headers.{`Content-Encoding`, `Content-Type`}
11-
import org.http4s.{Charset, ContentCoding, DecodeResult, Entity, EntityDecoder, EntityEncoder, Media, MediaRange, MediaType}
12+
import org.http4s.{ContentCoding, DecodeResult, Entity, EntityDecoder, EntityEncoder, Headers, MediaRange, MediaType}
1213
import org.ivovk.connect_rpc_scala.ConnectRpcHttpRoutes.getClass
1314
import org.slf4j.{Logger, LoggerFactory}
1415
import scalapb.json4s.{JsonFormat, Printer}
1516
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
1617

18+
import java.net.URLDecoder
19+
import java.util.Base64
20+
1721
object MessageCodec {
1822
given [F[_] : Applicative, A <: Message](using codec: MessageCodec[F], cmp: Companion[A]): EntityDecoder[F, A] =
19-
EntityDecoder.decodeBy(MediaRange.`*/*`)(codec.decode)
23+
EntityDecoder.decodeBy(MediaRange.`*/*`)(m => codec.decode(RequestEntity(m)))
2024

2125
given [F[_], A <: Message](using codec: MessageCodec[F]): EntityEncoder[F, A] =
2226
EntityEncoder.encodeBy(`Content-Type`(codec.mediaType))(codec.encode)
@@ -26,27 +30,33 @@ trait MessageCodec[F[_]] {
2630

2731
val mediaType: MediaType
2832

29-
def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A]
33+
def decode[A <: Message](m: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A]
3034

3135
def encode[A <: Message](message: A): Entity[F]
3236

3337
}
3438

35-
class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends MessageCodec[F] {
39+
class JsonMessageCodec[F[_] : Sync : Compression](printer: Printer) extends MessageCodec[F] {
3640

3741
private val logger: Logger = LoggerFactory.getLogger(getClass)
3842

3943
override val mediaType: MediaType = MediaTypes.`application/json`
4044

41-
override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
42-
val charset = m.charset.getOrElse(Charset.`UTF-8`).nioCharset
45+
override def decode[A <: Message](entity: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
46+
val charset = entity.charset.nioCharset
47+
val string = entity.message match {
48+
case str: String =>
49+
Sync[F].delay(URLDecoder.decode(str, charset))
50+
case stream: Stream[F, Byte] =>
51+
decompressed(entity.headers, stream)
52+
.through(decodeWithCharset(charset))
53+
.compile.string
54+
}
4355

44-
val f = decompressed(m)
45-
.through(decodeWithCharset(charset))
46-
.compile.string
56+
val f = string
4757
.flatMap { str =>
4858
if (logger.isTraceEnabled) {
49-
logger.trace(s">>> Headers: ${m.headers}")
59+
logger.trace(s">>> Headers: ${entity.headers}")
5060
logger.trace(s">>> JSON: $str")
5161
}
5262

@@ -57,7 +67,7 @@ class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends
5767
}
5868

5969
override def encode[A <: Message](message: A): Entity[F] = {
60-
val string = jsonPrinter.print(message)
70+
val string = printer.print(message)
6171

6272
if (logger.isTraceEnabled) {
6373
logger.trace(s"<<< JSON: $string")
@@ -72,23 +82,28 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {
7282

7383
private val logger: Logger = LoggerFactory.getLogger(getClass)
7484

75-
override val mediaType: MediaType = MediaTypes.`application/proto`
85+
private val base64dec = Base64.getUrlDecoder
7686

77-
override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
78-
val f = toInputStreamResource(decompressed(m)).use { is =>
79-
Async[F].delay {
80-
val message = cmp.parseFrom(is)
87+
override val mediaType: MediaType = MediaTypes.`application/proto`
8188

82-
if (logger.isTraceEnabled) {
83-
logger.trace(s">>> Headers: ${m.headers}")
84-
logger.trace(s">>> Proto: ${message.toProtoString}")
85-
}
89+
override def decode[A <: Message](entity: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
90+
val f = entity.message match {
91+
case str: String =>
92+
Async[F].delay(base64dec.decode(str.getBytes(entity.charset.nioCharset)))
93+
.flatMap(arr => Async[F].delay(cmp.parseFrom(arr)))
94+
case stream: Stream[F, Byte] =>
95+
toInputStreamResource(decompressed(entity.headers, stream))
96+
.use(is => Async[F].delay(cmp.parseFrom(is)))
97+
}
8698

87-
message
99+
EitherT.right(f.map { message =>
100+
if (logger.isTraceEnabled) {
101+
logger.trace(s">>> Headers: ${entity.headers}")
102+
logger.trace(s">>> Proto: ${message.toProtoString}")
88103
}
89-
}
90104

91-
EitherT.right(f)
105+
message
106+
})
92107
}
93108

94109
override def encode[A <: Message](message: A): Entity[F] = {
@@ -104,10 +119,10 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {
104119

105120
}
106121

107-
def decompressed[F[_] : Compression](m: Media[F]): fs2.Stream[F, Byte] = {
108-
val encoding = m.headers.get[`Content-Encoding`].map(_.contentCoding)
122+
def decompressed[F[_] : Compression](headers: Headers, body: Stream[F, Byte]): Stream[F, Byte] = {
123+
val encoding = headers.get[`Content-Encoding`].map(_.contentCoding)
109124

110-
m.body.through(encoding match {
125+
body.through(encoding match {
111126
case Some(ContentCoding.gzip) =>
112127
Compression[F].gunzip().andThen(_.flatMap(_.content))
113128
case _ =>
Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package org.ivovk.connect_rpc_scala.http
22

33
import org.http4s.dsl.impl.QueryParamDecoderMatcher
4-
import org.http4s.{Charset, MediaType, ParseFailure, QueryParamDecoder}
5-
6-
import java.net.URLDecoder
4+
import org.http4s.{MediaType, ParseFailure, QueryParamDecoder}
75

86
object QueryParams {
97

@@ -15,9 +13,6 @@ object QueryParams {
1513

1614
object EncodingQP extends QueryParamDecoderMatcher[MediaType]("encoding")(encodingQPDecoder)
1715

18-
private val messageQPDecoder = QueryParamDecoder[String]
19-
.map(URLDecoder.decode(_, Charset.`UTF-8`.nioCharset))
20-
21-
object MessageQP extends QueryParamDecoderMatcher[String]("message")(messageQPDecoder)
16+
object MessageQP extends QueryParamDecoderMatcher[String]("message")
2217

2318
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package org.ivovk.connect_rpc_scala.http
2+
3+
import cats.MonadThrow
4+
import fs2.Stream
5+
import org.http4s.{Charset, Headers, Media}
6+
import org.http4s.headers.`Content-Type`
7+
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}
8+
9+
object RequestEntity {
10+
def apply[F[_]](m: Media[F]): RequestEntity[F] =
11+
RequestEntity(m.body, m.headers)
12+
}
13+
14+
/**
15+
* Encoded message and headers with the knowledge how this message can be decoded.
16+
* Similar to [[org.http4s.Media]], but extends the message with `String` type representing message that is
17+
* passed in a query parameter.
18+
*/
19+
case class RequestEntity[F[_]](
20+
message: String | Stream[F, Byte],
21+
headers: Headers,
22+
) {
23+
def contentType: Option[`Content-Type`] =
24+
headers.get[`Content-Type`]
25+
26+
def charset: Charset =
27+
contentType.flatMap(_.charset).getOrElse(Charset.`UTF-8`)
28+
29+
def as[A <: Message](using M: MonadThrow[F], codec: MessageCodec[F], cmp: Companion[A]): F[A] =
30+
M.rethrow(codec.decode(this).value)
31+
32+
def fold[A](string: String => A)(stream: Stream[F, Byte] => A): A =
33+
message match {
34+
case s: String =>
35+
string(s)
36+
case b: Stream[F, Byte] =>
37+
stream(b)
38+
}
39+
}

0 commit comments

Comments
 (0)