1
1
from quart import Quart , request , Response
2
- from hypercorn .config import Config
2
+ from hypercorn .config import Config as HyperConfig
3
3
from hypercorn .asyncio import serve
4
4
import asyncio
5
5
import yaml
@@ -17,22 +17,36 @@ class RESTAdapter(ProtocolAdapter):
17
17
"""REST protocol adapter implementation using Quart and Hypercorn."""
18
18
19
19
def _get_config (self ) -> Dict [str , Any ]:
20
+ """Get REST configuration from config manager."""
20
21
return self .config_manager .get_rest_config ()
21
22
22
23
def __init__ (self , config_manager : ConfigManager ):
24
+ """Initialize REST adapter with configuration.
25
+
26
+ Args:
27
+ config_manager: Configuration manager instance
28
+ """
23
29
super ().__init__ (config_manager )
24
30
self .app = Quart (__name__ )
25
31
self ._setup_routes ()
26
- self ._active_streams = {} # client_id -> asyncio.Queue
32
+ self .server = None
33
+ self ._active_streams = {} # Store active streams by client_id
34
+ # Main event loop
35
+ self ._loop : Optional [asyncio .AbstractEventLoop ] = None
27
36
self ._running = False
28
- self ._server_task : Optional [asyncio .Task ] = None
29
37
logger .debug ("REST - Adapter initialized with config: host=%s, port=%s" ,
30
38
self .config ['host' ], self .config ['port' ])
31
39
32
40
def _setup_routes (self ) -> None :
41
+ """Set up the streaming endpoint."""
33
42
self .app .post (self .config ['endpoint' ])(self ._handle_streaming_message )
34
43
35
44
async def _handle_streaming_message (self ) -> Response :
45
+ """Handle incoming messages with streaming response.
46
+
47
+ Returns:
48
+ Response: Streaming response with simulation results
49
+ """
36
50
content_type = request .headers .get ('content-type' , '' )
37
51
body = await request .get_data ()
38
52
@@ -58,6 +72,7 @@ async def _handle_streaming_message(self) -> Response:
58
72
producer = simulation .get ('client_id' , 'unknown' )
59
73
consumer = simulation .get ('simulator' , 'unknown' )
60
74
75
+ # Add bridge metadata
61
76
message ['bridge_meta' ] = {
62
77
'protocol' : 'rest' ,
63
78
'producer' : producer ,
@@ -67,14 +82,15 @@ async def _handle_streaming_message(self) -> Response:
67
82
logger .debug (
68
83
"REST - Processing message from producer: %s, simulator: %s" ,
69
84
producer , consumer )
70
-
85
+ # Use SignalManager to send the signal
71
86
signal ('message_received_input_rest' ).send (
72
87
message = message ,
73
88
producer = producer ,
74
89
consumer = consumer ,
75
90
protocol = 'rest'
76
91
)
77
92
93
+ # Create a queue for this client's messages
78
94
queue = asyncio .Queue ()
79
95
self ._active_streams [producer ] = queue
80
96
@@ -85,12 +101,23 @@ async def _handle_streaming_message(self) -> Response:
85
101
)
86
102
87
103
def _parse_message (self , body : bytes , content_type : str ) -> Dict [str , Any ]:
104
+ """Parse message body based on content type.
105
+
106
+ Args:
107
+ body: Raw message body
108
+ content_type: Content type header
109
+
110
+ Returns:
111
+ Dict[str, Any]: Parsed message
112
+ """
88
113
if 'yaml' in content_type :
89
114
logger .debug ("REST - Attempting to parse message as YAML" )
90
115
return yaml .safe_load (body )
91
116
elif 'json' in content_type :
92
117
logger .debug ("REST - Attempting to parse message as JSON" )
93
118
return json .loads (body )
119
+
120
+ # Fallback: try YAML, then JSON, then raw text
94
121
try :
95
122
logger .debug (
96
123
"REST - Attempting to parse message as YAML (fallback)" )
@@ -108,95 +135,131 @@ def _parse_message(self, body: bytes, content_type: str) -> Dict[str, Any]:
108
135
}
109
136
110
137
async def _generate_response (
111
- self , producer : str , queue : asyncio .Queue
112
- ) -> AsyncGenerator [str , None ]:
138
+ self , producer : str , queue : asyncio .Queue ) -> AsyncGenerator [str , None ]:
139
+ """Generate streaming response.
140
+
141
+ Args:
142
+ producer: Client ID
143
+ queue: Message queue for this client
144
+
145
+ Yields:
146
+ str: JSON-encoded messages
147
+ """
113
148
try :
149
+ # Send initial acknowledgment
114
150
yield json .dumps ({"status" : "processing" }) + "\n "
151
+ # Keep the connection open and wait for results
115
152
while True :
116
153
try :
117
154
result = await asyncio .wait_for (queue .get (), timeout = 600 )
118
155
yield json .dumps (result ) + "\n "
119
156
except asyncio .TimeoutError :
120
- yield json .dumps ({
121
- "status" : "timeout" ,
122
- "error" : "No response received within timeout"
123
- }) + "\n "
157
+ yield json .dumps ({"status" : "timeout" , "error" : "No response received within timeout" }) + "\n "
124
158
break
125
159
except Exception as e :
126
160
logger .error ("REST - Error in stream: %s" , e )
127
161
yield json .dumps ({"status" : "error" , "error" : str (e )}) + "\n "
128
162
break
129
163
finally :
130
- self ._active_streams .pop (producer , None )
164
+ # Clean up when the stream ends
165
+ if producer in self ._active_streams :
166
+ del self ._active_streams [producer ]
131
167
132
168
async def send_result (self , producer : str , result : Dict [str , Any ]) -> None :
169
+ """Send a result message to a specific client.
170
+
171
+ Args:
172
+ producer: Client ID
173
+ result: Result message to send
174
+ """
133
175
if producer in self ._active_streams :
134
176
await self ._active_streams [producer ].put (result )
135
177
else :
136
178
logger .warning (
137
- "REST - No active stream found for producer: %s" ,
138
- producer )
139
-
140
- async def _start_server (self ):
141
- config = Config ()
142
- config .bind = [f"{ self .config ['host' ]} :{ self .config ['port' ]} " ]
143
- config .certfile = self .config .get ('certfile' )
144
- config .keyfile = self .config .get ('keyfile' )
145
- config .worker_class = "asyncio"
179
+ "REST - No active stream found for producer: %s" , producer )
180
+
181
+ async def _start_server (self ) -> None :
182
+ """Start the Hypercorn server."""
183
+ self ._loop = asyncio .get_running_loop () # Save main event loop
184
+
185
+ config = HyperConfig ()
186
+ config .errorlog = logger # Use the main logger for error logs
187
+ config .accesslog = logger # Use the main logger for access logs
188
+ config .bind = ["%s:%s" % (self .config ['host' ], self .config ['port' ])]
146
189
config .use_reloader = False
190
+ config .worker_class = "asyncio"
191
+ config .alpn_protocols = ["h2" , "http/1.1" ]
147
192
148
- logger .debug ("REST - Starting Hypercorn server" )
193
+ if self .config .get ('certfile' ) and self .config .get ('keyfile' ):
194
+ config .certfile = self .config ['certfile' ]
195
+ config .keyfile = self .config ['keyfile' ]
149
196
await serve (self .app , config )
150
197
151
- def start (self ):
152
- """Start the REST server in the current asyncio event loop."""
153
- if self ._running :
154
- logger .warning ("REST - Adapter already running" )
155
- return
156
- self ._running = True
198
+ def start (self ) -> None :
199
+ """Start the REST server."""
200
+ logger .debug (
201
+ "REST - Starting adapter on %s:%s" ,
202
+ self .config ['host' ], self .config ['port' ])
157
203
try :
158
- loop = asyncio .get_event_loop ()
159
- except RuntimeError :
160
- loop = asyncio .new_event_loop ()
161
- asyncio .set_event_loop (loop )
162
-
163
- self ._server_task = loop .create_task (self ._start_server ())
164
- logger .debug ("REST - Server started as asyncio task" )
165
-
166
- def stop (self ):
167
- """Stop the REST adapter."""
168
- logger .debug ("REST - Stopping adapter" )
169
- self ._running = False
170
- if self ._server_task :
171
- self ._server_task .cancel ()
204
+ asyncio .run (self ._start_server ())
205
+ self ._running = True
206
+ except Exception as e :
207
+ logger .error ("REST - Error starting server: %s" , e )
208
+ raise
172
209
173
210
def send_result_sync (self , producer : str , result : Dict [str , Any ]) -> None :
211
+ """Synchronous wrapper for sending result messages.
212
+
213
+ Args:
214
+ producer: Client ID
215
+ result: Result message to send
216
+ """
174
217
if producer not in self ._active_streams :
175
218
logger .warning (
176
- "REST - No active stream found for producer: %s. Available streams: %s" ,
219
+ "REST - No active stream found for producer: %s. "
220
+ "Available streams: %s" ,
177
221
producer , list (self ._active_streams .keys ())
178
222
)
179
223
return
180
224
181
- try :
182
- loop = asyncio .get_event_loop ()
183
- if loop .is_running ():
184
- future = asyncio .run_coroutine_threadsafe (
185
- self .send_result (producer , result ),
186
- loop
187
- )
225
+ if self ._loop and self ._loop .is_running ():
226
+ # Use run_coroutine_threadsafe to execute coroutine in main loop
227
+ future = asyncio .run_coroutine_threadsafe (
228
+ self .send_result (producer , result ),
229
+ self ._loop
230
+ )
231
+ try :
232
+ # Optional: wait for result with short timeout
188
233
future .result (timeout = 5 )
189
- else :
190
- logger .error (
191
- "REST - Event loop not running; cannot send result." )
192
- except Exception as e :
193
- logger .error ("REST - Error sending result: %s" , e )
234
+ except Exception as e :
235
+ logger .error ("REST - Error sending result: %s" , e )
236
+ else :
237
+ logger .error ("REST - Event loop not running; cannot send result." )
238
+
239
+ def stop (self ) -> None :
240
+ """Stop the REST server."""
241
+ logger .debug ("REST - Stopping adapter" )
242
+ self ._running = False
243
+ if self .server :
244
+ self .server .close ()
194
245
195
246
def _handle_message (self , message : Dict [str , Any ]) -> None :
196
- # REST handled by Quart routes
247
+ """Handle incoming messages (required by ProtocolAdapter).
248
+
249
+ Args:
250
+ message: Message to handle
251
+ """
252
+ # For REST, this is handled by the Quart endpoint
197
253
pass
198
254
199
- def publish_result_message_rest (self , sender , ** kwargs ):
255
+ def publish_result_message_rest (self , sender , ** kwargs ): # pylint: disable=unused-argument
256
+ """
257
+ Publish result message via REST adapter.
258
+
259
+ Args:
260
+ message: Message payload to send
261
+ destination: REST endpoint destination
262
+ """
200
263
try :
201
264
message = kwargs .get ('message' , {})
202
265
destination = message .get ('destinations' , [])[0 ]
0 commit comments