Skip to content

Optimize GET-request entity parsing, merge GET- and POST-request hand… #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conformance-suite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
features:
versions: [ HTTP_VERSION_1 ]
protocols: [ PROTOCOL_CONNECT ]
codecs: [ CODEC_JSON ]
codecs: [ CODEC_JSON ] #, CODEC_PROTO ]
stream_types: [ STREAM_TYPE_UNARY ]
supports_tls: false
supports_trailers: false
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package org.ivovk.connect_rpc_scala

import cats.Endo
import cats.data.EitherT
import cats.effect.Async
import cats.effect.kernel.Resource
import cats.implicits.*
import fs2.compression.Compression
import fs2.{Chunk, Stream}
import io.grpc.*
import io.grpc.MethodDescriptor.MethodType
import io.grpc.stub.MetadataUtils
import org.http4s.*
import org.http4s.dsl.Http4sDsl
import org.http4s.headers.`Content-Type`
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.Headers.*
import org.ivovk.connect_rpc_scala.http.MessageCodec.given
Expand Down Expand Up @@ -55,78 +54,69 @@ object ConnectRpcHttpRoutes {
for
ipChannel <- InProcessChannelBridge.create(services, configuration.waitForShutdown)
yield
def handle(
httpMethod: Method,
contentType: Option[MediaType],
entity: RequestEntity[F],
grpcMethod: String,
): F[Response[F]] = {
val eitherT = for
given MessageCodec[F] <- EitherT.fromOptionM(
contentType.flatMap(codecRegistry.byContentType).pure[F],
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
)

method <- EitherT.fromOptionM(
methodRegistry.get(grpcMethod).pure[F],
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: $grpcMethod".some
))
)

_ <- EitherT.cond[F](
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
httpMethod == Method.POST || (httpMethod == Method.GET && method.methodDescriptor.isSafe) || true,
(),
Forbidden(connectrpc.Error(
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
message = s"Only POST-requests are allowed for method: $grpcMethod".some
))
).leftSemiflatMap(identity)

response <- method.methodDescriptor.getType match
case MethodType.UNARY =>
EitherT.right(handleUnary(dsl, method, entity, ipChannel))
case unsupported =>
EitherT.left(NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
)))
yield response

eitherT.merge
}

HttpRoutes.of[F] {
case req@Method.GET -> Root / serviceName / methodName :? EncodingQP(contentType) +& MessageQP(message) =>
val grpcMethod = grpcMethodName(serviceName, methodName)
val entity = RequestEntity[F](message, req.headers)

codecRegistry.byContentType(contentType) match {
case Some(codec) =>
given MessageCodec[F] = codec

val media = Media[F](Stream.chunk(Chunk.array(message.getBytes)), req.headers)

methodRegistry.get(grpcMethod) match {
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
case Some(entry) if entry.methodDescriptor.isSafe || true =>
entry.methodDescriptor.getType match
case MethodType.UNARY =>
handleUnary(dsl, entry, media, ipChannel)
case unsupported =>
NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
))
case Some(_) =>
Forbidden(connectrpc.Error(
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
message = s"Method supports calling using POST: $grpcMethod".some
))
case None =>
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: $grpcMethod".some
))
}
case None =>
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
}
handle(Method.GET, contentType.some, entity, grpcMethod)
case req@Method.POST -> Root / serviceName / methodName =>
val grpcMethod = grpcMethodName(serviceName, methodName)
val contentType = req.headers.get[`Content-Type`].map(_.mediaType)

contentType.flatMap(codecRegistry.byContentType) match {
case Some(codec) =>
given MessageCodec[F] = codec

methodRegistry.get(grpcMethod) match {
case Some(entry) =>
entry.methodDescriptor.getType match
case MethodType.UNARY =>
handleUnary(dsl, entry, req, ipChannel)
case unsupported =>
NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
))
case None =>
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: $grpcMethod".some
))
}
case None =>
UnsupportedMediaType(s"Unsupported content-type ${contentType.map(_.show).orNull}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
}
val contentType = req.contentType.map(_.mediaType)
val entity = RequestEntity[F](req)

handle(Method.POST, contentType, entity, grpcMethod)
}
}


private def handleUnary[F[_] : Async](
dsl: Http4sDsl[F],
entry: RegistryEntry,
req: Media[F],
req: RequestEntity[F],
channel: Channel
)(using codec: MessageCodec[F]): F[Response[F]] = {
import dsl.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@ import cats.Applicative
import cats.data.EitherT
import cats.effect.{Async, Sync}
import cats.implicits.*
import fs2.Stream
import fs2.compression.Compression
import fs2.io.{readOutputStream, toInputStreamResource}
import fs2.text.decodeWithCharset
import org.http4s.headers.{`Content-Encoding`, `Content-Type`}
import org.http4s.{Charset, ContentCoding, DecodeResult, Entity, EntityDecoder, EntityEncoder, Media, MediaRange, MediaType}
import org.http4s.{ContentCoding, DecodeResult, Entity, EntityDecoder, EntityEncoder, Headers, MediaRange, MediaType}
import org.ivovk.connect_rpc_scala.ConnectRpcHttpRoutes.getClass
import org.slf4j.{Logger, LoggerFactory}
import scalapb.json4s.{JsonFormat, Printer}
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}

import java.net.URLDecoder
import java.util.Base64

object MessageCodec {
given [F[_] : Applicative, A <: Message](using codec: MessageCodec[F], cmp: Companion[A]): EntityDecoder[F, A] =
EntityDecoder.decodeBy(MediaRange.`*/*`)(codec.decode)
EntityDecoder.decodeBy(MediaRange.`*/*`)(m => codec.decode(RequestEntity(m)))

given [F[_], A <: Message](using codec: MessageCodec[F]): EntityEncoder[F, A] =
EntityEncoder.encodeBy(`Content-Type`(codec.mediaType))(codec.encode)
Expand All @@ -26,27 +30,33 @@ trait MessageCodec[F[_]] {

val mediaType: MediaType

def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A]
def decode[A <: Message](m: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A]

def encode[A <: Message](message: A): Entity[F]

}

class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends MessageCodec[F] {
class JsonMessageCodec[F[_] : Sync : Compression](printer: Printer) extends MessageCodec[F] {

private val logger: Logger = LoggerFactory.getLogger(getClass)

override val mediaType: MediaType = MediaTypes.`application/json`

override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val charset = m.charset.getOrElse(Charset.`UTF-8`).nioCharset
override def decode[A <: Message](entity: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val charset = entity.charset.nioCharset
val string = entity.message match {
case str: String =>
Sync[F].delay(URLDecoder.decode(str, charset))
case stream: Stream[F, Byte] =>
decompressed(entity.headers, stream)
.through(decodeWithCharset(charset))
.compile.string
}

val f = decompressed(m)
.through(decodeWithCharset(charset))
.compile.string
val f = string
.flatMap { str =>
if (logger.isTraceEnabled) {
logger.trace(s">>> Headers: ${m.headers}")
logger.trace(s">>> Headers: ${entity.headers}")
logger.trace(s">>> JSON: $str")
}

Expand All @@ -57,7 +67,7 @@ class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends
}

override def encode[A <: Message](message: A): Entity[F] = {
val string = jsonPrinter.print(message)
val string = printer.print(message)

if (logger.isTraceEnabled) {
logger.trace(s"<<< JSON: $string")
Expand All @@ -72,23 +82,28 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {

private val logger: Logger = LoggerFactory.getLogger(getClass)

override val mediaType: MediaType = MediaTypes.`application/proto`
private val base64dec = Base64.getUrlDecoder

override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val f = toInputStreamResource(decompressed(m)).use { is =>
Async[F].delay {
val message = cmp.parseFrom(is)
override val mediaType: MediaType = MediaTypes.`application/proto`

if (logger.isTraceEnabled) {
logger.trace(s">>> Headers: ${m.headers}")
logger.trace(s">>> Proto: ${message.toProtoString}")
}
override def decode[A <: Message](entity: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val f = entity.message match {
case str: String =>
Async[F].delay(base64dec.decode(str.getBytes(entity.charset.nioCharset)))
.flatMap(arr => Async[F].delay(cmp.parseFrom(arr)))
case stream: Stream[F, Byte] =>
toInputStreamResource(decompressed(entity.headers, stream))
.use(is => Async[F].delay(cmp.parseFrom(is)))
}

message
EitherT.right(f.map { message =>
if (logger.isTraceEnabled) {
logger.trace(s">>> Headers: ${entity.headers}")
logger.trace(s">>> Proto: ${message.toProtoString}")
}
}

EitherT.right(f)
message
})
}

override def encode[A <: Message](message: A): Entity[F] = {
Expand All @@ -104,10 +119,10 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {

}

def decompressed[F[_] : Compression](m: Media[F]): fs2.Stream[F, Byte] = {
val encoding = m.headers.get[`Content-Encoding`].map(_.contentCoding)
def decompressed[F[_] : Compression](headers: Headers, body: Stream[F, Byte]): Stream[F, Byte] = {
val encoding = headers.get[`Content-Encoding`].map(_.contentCoding)

m.body.through(encoding match {
body.through(encoding match {
case Some(ContentCoding.gzip) =>
Compression[F].gunzip().andThen(_.flatMap(_.content))
case _ =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package org.ivovk.connect_rpc_scala.http

import org.http4s.dsl.impl.QueryParamDecoderMatcher
import org.http4s.{Charset, MediaType, ParseFailure, QueryParamDecoder}

import java.net.URLDecoder
import org.http4s.{MediaType, ParseFailure, QueryParamDecoder}

object QueryParams {

Expand All @@ -15,9 +13,6 @@ object QueryParams {

object EncodingQP extends QueryParamDecoderMatcher[MediaType]("encoding")(encodingQPDecoder)

private val messageQPDecoder = QueryParamDecoder[String]
.map(URLDecoder.decode(_, Charset.`UTF-8`.nioCharset))

object MessageQP extends QueryParamDecoderMatcher[String]("message")(messageQPDecoder)
object MessageQP extends QueryParamDecoderMatcher[String]("message")

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.ivovk.connect_rpc_scala.http

import cats.MonadThrow
import fs2.Stream
import org.http4s.{Charset, Headers, Media}
import org.http4s.headers.`Content-Type`
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}

object RequestEntity {
def apply[F[_]](m: Media[F]): RequestEntity[F] =
RequestEntity(m.body, m.headers)
}

/**
* Encoded message and headers with the knowledge how this message can be decoded.
* Similar to [[org.http4s.Media]], but extends the message with `String` type representing message that is
* passed in a query parameter.
*/
case class RequestEntity[F[_]](
message: String | Stream[F, Byte],
headers: Headers,
) {
def contentType: Option[`Content-Type`] =
headers.get[`Content-Type`]

def charset: Charset =
contentType.flatMap(_.charset).getOrElse(Charset.`UTF-8`)

def as[A <: Message](using M: MonadThrow[F], codec: MessageCodec[F], cmp: Companion[A]): F[A] =
M.rethrow(codec.decode(this).value)

def fold[A](string: String => A)(stream: Stream[F, Byte] => A): A =
message match {
case s: String =>
string(s)
case b: Stream[F, Byte] =>
stream(b)
}
}