Skip to content

Optimized asyncUnaryCall implementation with cancellation support #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import io.grpc.stub.MetadataUtils
import org.http4s.dsl.Http4sDsl
import org.http4s.{Header, MediaType, MessageFailure, Method, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.grpc.{GrpcClientCalls, MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry}
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
import org.slf4j.{Logger, LoggerFactory}
import scalapb.grpc.ClientCalls
import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TextFormat}

import java.util.concurrent.atomic.AtomicReference
Expand Down Expand Up @@ -103,32 +102,35 @@ class ConnectHandler[F[_] : Async](
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message")
}

Async[F].fromFuture(Async[F].delay {
ClientCalls.asyncUnaryCall[GeneratedMessage, GeneratedMessage](
val callOptions = CallOptions.DEFAULT
.pipe(
req.timeout match {
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
case None => identity
}
)

GrpcClientCalls
.asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage](
ClientInterceptors.intercept(
channel,
MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata),
MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata),
),
method.descriptor,
CallOptions.DEFAULT.pipe(
req.timeout match {
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
case None => identity
}
),
callOptions,
message
)
}).map { response =>
val headers = responseHeaderMetadata.get.toHeaders() ++
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)
.map { response =>
val headers = responseHeaderMetadata.get.toHeaders() ++
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}
if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}

Response(Ok, headers = headers).withEntity(response)
}
Response(Ok, headers = headers).withEntity(response)
}
}
.recover { case e =>
val grpcStatus = e match {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package org.ivovk.connect_rpc_scala.grpc

import cats.effect.Async
import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors}
import io.grpc.stub.{ClientCalls, StreamObserver}
import io.grpc.{CallOptions, Channel, MethodDescriptor}

object GrpcClientCalls {

/**
* Asynchronous unary call.
*
* Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation
* and supports cancellation.
*/
def asyncUnaryCall[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
Async[F].delay {
val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request)

Futures.addCallback(
future,
new FutureCallback[Resp] {
def onSuccess(result: Resp): Unit = cb(Right(result))

def onFailure(t: Throwable): Unit = cb(Left(t))
},
MoreExecutors.directExecutor(),
)

Some(Async[F].delay(future.cancel(true)))
}
}
}

/**
* Implementation that should be faster than the [[asyncUnaryCall]].
*/
def asyncUnaryCall2[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
Async[F].delay {
val call = channel.newCall(method, options)

ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb))

Some(Async[F].delay(call.cancel("Cancelled", null)))
}
}
}

/**
* [[CallbackObserver]] either executes [[onNext]] -> [[onCompleted]] during the happy path or just [[onError]] in case of
* an error.
*/
private class CallbackObserver[F[_] : Async, Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] {
private var value: Option[Either[Throwable, Resp]] = None

override def onNext(value: Resp): Unit = {
if this.value.isDefined then
throw new IllegalStateException("Value already received")

this.value = Some(Right(value))
}

override def onError(t: Throwable): Unit = {
cb(Left(t))
}

override def onCompleted(): Unit = {
this.value match
case Some(v) => cb(v)
case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError")))
}

}

}
Loading