Skip to content

Commit 29dab9a

Browse files
authored
Optimized asyncUnaryCall implementation with cancellation support (#45)
1 parent 74403b0 commit 29dab9a

File tree

2 files changed

+107
-18
lines changed

2 files changed

+107
-18
lines changed

core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ import io.grpc.stub.MetadataUtils
99
import org.http4s.dsl.Http4sDsl
1010
import org.http4s.{Header, MediaType, MessageFailure, Method, Response}
1111
import org.ivovk.connect_rpc_scala.Mappings.*
12-
import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry}
12+
import org.ivovk.connect_rpc_scala.grpc.{GrpcClientCalls, MethodName, MethodRegistry}
1313
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
1414
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
1515
import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry}
1616
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
1717
import org.slf4j.{Logger, LoggerFactory}
18-
import scalapb.grpc.ClientCalls
1918
import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TextFormat}
2019

2120
import java.util.concurrent.atomic.AtomicReference
@@ -103,32 +102,35 @@ class ConnectHandler[F[_] : Async](
103102
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message")
104103
}
105104

106-
Async[F].fromFuture(Async[F].delay {
107-
ClientCalls.asyncUnaryCall[GeneratedMessage, GeneratedMessage](
105+
val callOptions = CallOptions.DEFAULT
106+
.pipe(
107+
req.timeout match {
108+
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
109+
case None => identity
110+
}
111+
)
112+
113+
GrpcClientCalls
114+
.asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage](
108115
ClientInterceptors.intercept(
109116
channel,
110117
MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata),
111118
MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata),
112119
),
113120
method.descriptor,
114-
CallOptions.DEFAULT.pipe(
115-
req.timeout match {
116-
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
117-
case None => identity
118-
}
119-
),
121+
callOptions,
120122
message
121123
)
122-
}).map { response =>
123-
val headers = responseHeaderMetadata.get.toHeaders() ++
124-
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)
124+
.map { response =>
125+
val headers = responseHeaderMetadata.get.toHeaders() ++
126+
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)
125127

126-
if (logger.isTraceEnabled) {
127-
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
128-
}
128+
if (logger.isTraceEnabled) {
129+
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
130+
}
129131

130-
Response(Ok, headers = headers).withEntity(response)
131-
}
132+
Response(Ok, headers = headers).withEntity(response)
133+
}
132134
}
133135
.recover { case e =>
134136
val grpcStatus = e match {
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package org.ivovk.connect_rpc_scala.grpc
2+
3+
import cats.effect.Async
4+
import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors}
5+
import io.grpc.stub.{ClientCalls, StreamObserver}
6+
import io.grpc.{CallOptions, Channel, MethodDescriptor}
7+
8+
object GrpcClientCalls {
9+
10+
/**
11+
* Asynchronous unary call.
12+
*
13+
* Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation
14+
* and supports cancellation.
15+
*/
16+
def asyncUnaryCall[F[_] : Async, Req, Resp](
17+
channel: Channel,
18+
method: MethodDescriptor[Req, Resp],
19+
options: CallOptions,
20+
request: Req,
21+
): F[Resp] = {
22+
Async[F].async[Resp] { cb =>
23+
Async[F].delay {
24+
val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request)
25+
26+
Futures.addCallback(
27+
future,
28+
new FutureCallback[Resp] {
29+
def onSuccess(result: Resp): Unit = cb(Right(result))
30+
31+
def onFailure(t: Throwable): Unit = cb(Left(t))
32+
},
33+
MoreExecutors.directExecutor(),
34+
)
35+
36+
Some(Async[F].delay(future.cancel(true)))
37+
}
38+
}
39+
}
40+
41+
/**
42+
* Implementation that should be faster than the [[asyncUnaryCall]].
43+
*/
44+
def asyncUnaryCall2[F[_] : Async, Req, Resp](
45+
channel: Channel,
46+
method: MethodDescriptor[Req, Resp],
47+
options: CallOptions,
48+
request: Req,
49+
): F[Resp] = {
50+
Async[F].async[Resp] { cb =>
51+
Async[F].delay {
52+
val call = channel.newCall(method, options)
53+
54+
ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb))
55+
56+
Some(Async[F].delay(call.cancel("Cancelled", null)))
57+
}
58+
}
59+
}
60+
61+
/**
62+
* [[CallbackObserver]] either executes [[onNext]] -> [[onCompleted]] during the happy path or just [[onError]] in case of
63+
* an error.
64+
*/
65+
private class CallbackObserver[F[_] : Async, Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] {
66+
private var value: Option[Either[Throwable, Resp]] = None
67+
68+
override def onNext(value: Resp): Unit = {
69+
if this.value.isDefined then
70+
throw new IllegalStateException("Value already received")
71+
72+
this.value = Some(Right(value))
73+
}
74+
75+
override def onError(t: Throwable): Unit = {
76+
cb(Left(t))
77+
}
78+
79+
override def onCompleted(): Unit = {
80+
this.value match
81+
case Some(v) => cb(v)
82+
case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError")))
83+
}
84+
85+
}
86+
87+
}

0 commit comments

Comments
 (0)