Skip to content

Commit 8253552

Browse files
committed
rest adaper fix
1 parent 1d1db9c commit 8253552

File tree

1 file changed

+119
-56
lines changed

1 file changed

+119
-56
lines changed

simulation_bridge/src/protocol_adapters/rest/rest_adapter.py

Lines changed: 119 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from quart import Quart, request, Response
2-
from hypercorn.config import Config
2+
from hypercorn.config import Config as HyperConfig
33
from hypercorn.asyncio import serve
44
import asyncio
55
import yaml
@@ -17,22 +17,36 @@ class RESTAdapter(ProtocolAdapter):
1717
"""REST protocol adapter implementation using Quart and Hypercorn."""
1818

1919
def _get_config(self) -> Dict[str, Any]:
20+
"""Get REST configuration from config manager."""
2021
return self.config_manager.get_rest_config()
2122

2223
def __init__(self, config_manager: ConfigManager):
24+
"""Initialize REST adapter with configuration.
25+
26+
Args:
27+
config_manager: Configuration manager instance
28+
"""
2329
super().__init__(config_manager)
2430
self.app = Quart(__name__)
2531
self._setup_routes()
26-
self._active_streams = {} # client_id -> asyncio.Queue
32+
self.server = None
33+
self._active_streams = {} # Store active streams by client_id
34+
# Main event loop
35+
self._loop: Optional[asyncio.AbstractEventLoop] = None
2736
self._running = False
28-
self._server_task: Optional[asyncio.Task] = None
2937
logger.debug("REST - Adapter initialized with config: host=%s, port=%s",
3038
self.config['host'], self.config['port'])
3139

3240
def _setup_routes(self) -> None:
41+
"""Set up the streaming endpoint."""
3342
self.app.post(self.config['endpoint'])(self._handle_streaming_message)
3443

3544
async def _handle_streaming_message(self) -> Response:
45+
"""Handle incoming messages with streaming response.
46+
47+
Returns:
48+
Response: Streaming response with simulation results
49+
"""
3650
content_type = request.headers.get('content-type', '')
3751
body = await request.get_data()
3852

@@ -58,6 +72,7 @@ async def _handle_streaming_message(self) -> Response:
5872
producer = simulation.get('client_id', 'unknown')
5973
consumer = simulation.get('simulator', 'unknown')
6074

75+
# Add bridge metadata
6176
message['bridge_meta'] = {
6277
'protocol': 'rest',
6378
'producer': producer,
@@ -67,14 +82,15 @@ async def _handle_streaming_message(self) -> Response:
6782
logger.debug(
6883
"REST - Processing message from producer: %s, simulator: %s",
6984
producer, consumer)
70-
85+
# Use SignalManager to send the signal
7186
signal('message_received_input_rest').send(
7287
message=message,
7388
producer=producer,
7489
consumer=consumer,
7590
protocol='rest'
7691
)
7792

93+
# Create a queue for this client's messages
7894
queue = asyncio.Queue()
7995
self._active_streams[producer] = queue
8096

@@ -85,12 +101,23 @@ async def _handle_streaming_message(self) -> Response:
85101
)
86102

87103
def _parse_message(self, body: bytes, content_type: str) -> Dict[str, Any]:
104+
"""Parse message body based on content type.
105+
106+
Args:
107+
body: Raw message body
108+
content_type: Content type header
109+
110+
Returns:
111+
Dict[str, Any]: Parsed message
112+
"""
88113
if 'yaml' in content_type:
89114
logger.debug("REST - Attempting to parse message as YAML")
90115
return yaml.safe_load(body)
91116
elif 'json' in content_type:
92117
logger.debug("REST - Attempting to parse message as JSON")
93118
return json.loads(body)
119+
120+
# Fallback: try YAML, then JSON, then raw text
94121
try:
95122
logger.debug(
96123
"REST - Attempting to parse message as YAML (fallback)")
@@ -108,95 +135,131 @@ def _parse_message(self, body: bytes, content_type: str) -> Dict[str, Any]:
108135
}
109136

110137
async def _generate_response(
111-
self, producer: str, queue: asyncio.Queue
112-
) -> AsyncGenerator[str, None]:
138+
self, producer: str, queue: asyncio.Queue) -> AsyncGenerator[str, None]:
139+
"""Generate streaming response.
140+
141+
Args:
142+
producer: Client ID
143+
queue: Message queue for this client
144+
145+
Yields:
146+
str: JSON-encoded messages
147+
"""
113148
try:
149+
# Send initial acknowledgment
114150
yield json.dumps({"status": "processing"}) + "\n"
151+
# Keep the connection open and wait for results
115152
while True:
116153
try:
117154
result = await asyncio.wait_for(queue.get(), timeout=600)
118155
yield json.dumps(result) + "\n"
119156
except asyncio.TimeoutError:
120-
yield json.dumps({
121-
"status": "timeout",
122-
"error": "No response received within timeout"
123-
}) + "\n"
157+
yield json.dumps({"status": "timeout", "error": "No response received within timeout"}) + "\n"
124158
break
125159
except Exception as e:
126160
logger.error("REST - Error in stream: %s", e)
127161
yield json.dumps({"status": "error", "error": str(e)}) + "\n"
128162
break
129163
finally:
130-
self._active_streams.pop(producer, None)
164+
# Clean up when the stream ends
165+
if producer in self._active_streams:
166+
del self._active_streams[producer]
131167

132168
async def send_result(self, producer: str, result: Dict[str, Any]) -> None:
169+
"""Send a result message to a specific client.
170+
171+
Args:
172+
producer: Client ID
173+
result: Result message to send
174+
"""
133175
if producer in self._active_streams:
134176
await self._active_streams[producer].put(result)
135177
else:
136178
logger.warning(
137-
"REST - No active stream found for producer: %s",
138-
producer)
139-
140-
async def _start_server(self):
141-
config = Config()
142-
config.bind = [f"{self.config['host']}:{self.config['port']}"]
143-
config.certfile = self.config.get('certfile')
144-
config.keyfile = self.config.get('keyfile')
145-
config.worker_class = "asyncio"
179+
"REST - No active stream found for producer: %s", producer)
180+
181+
async def _start_server(self) -> None:
182+
"""Start the Hypercorn server."""
183+
self._loop = asyncio.get_running_loop() # Save main event loop
184+
185+
config = HyperConfig()
186+
config.errorlog = logger # Use the main logger for error logs
187+
config.accesslog = logger # Use the main logger for access logs
188+
config.bind = ["%s:%s" % (self.config['host'], self.config['port'])]
146189
config.use_reloader = False
190+
config.worker_class = "asyncio"
191+
config.alpn_protocols = ["h2", "http/1.1"]
147192

148-
logger.debug("REST - Starting Hypercorn server")
193+
if self.config.get('certfile') and self.config.get('keyfile'):
194+
config.certfile = self.config['certfile']
195+
config.keyfile = self.config['keyfile']
149196
await serve(self.app, config)
150197

151-
def start(self):
152-
"""Start the REST server in the current asyncio event loop."""
153-
if self._running:
154-
logger.warning("REST - Adapter already running")
155-
return
156-
self._running = True
198+
def start(self) -> None:
199+
"""Start the REST server."""
200+
logger.debug(
201+
"REST - Starting adapter on %s:%s",
202+
self.config['host'], self.config['port'])
157203
try:
158-
loop = asyncio.get_event_loop()
159-
except RuntimeError:
160-
loop = asyncio.new_event_loop()
161-
asyncio.set_event_loop(loop)
162-
163-
self._server_task = loop.create_task(self._start_server())
164-
logger.debug("REST - Server started as asyncio task")
165-
166-
def stop(self):
167-
"""Stop the REST adapter."""
168-
logger.debug("REST - Stopping adapter")
169-
self._running = False
170-
if self._server_task:
171-
self._server_task.cancel()
204+
asyncio.run(self._start_server())
205+
self._running = True
206+
except Exception as e:
207+
logger.error("REST - Error starting server: %s", e)
208+
raise
172209

173210
def send_result_sync(self, producer: str, result: Dict[str, Any]) -> None:
211+
"""Synchronous wrapper for sending result messages.
212+
213+
Args:
214+
producer: Client ID
215+
result: Result message to send
216+
"""
174217
if producer not in self._active_streams:
175218
logger.warning(
176-
"REST - No active stream found for producer: %s. Available streams: %s",
219+
"REST - No active stream found for producer: %s. "
220+
"Available streams: %s",
177221
producer, list(self._active_streams.keys())
178222
)
179223
return
180224

181-
try:
182-
loop = asyncio.get_event_loop()
183-
if loop.is_running():
184-
future = asyncio.run_coroutine_threadsafe(
185-
self.send_result(producer, result),
186-
loop
187-
)
225+
if self._loop and self._loop.is_running():
226+
# Use run_coroutine_threadsafe to execute coroutine in main loop
227+
future = asyncio.run_coroutine_threadsafe(
228+
self.send_result(producer, result),
229+
self._loop
230+
)
231+
try:
232+
# Optional: wait for result with short timeout
188233
future.result(timeout=5)
189-
else:
190-
logger.error(
191-
"REST - Event loop not running; cannot send result.")
192-
except Exception as e:
193-
logger.error("REST - Error sending result: %s", e)
234+
except Exception as e:
235+
logger.error("REST - Error sending result: %s", e)
236+
else:
237+
logger.error("REST - Event loop not running; cannot send result.")
238+
239+
def stop(self) -> None:
240+
"""Stop the REST server."""
241+
logger.debug("REST - Stopping adapter")
242+
self._running = False
243+
if self.server:
244+
self.server.close()
194245

195246
def _handle_message(self, message: Dict[str, Any]) -> None:
196-
# REST handled by Quart routes
247+
"""Handle incoming messages (required by ProtocolAdapter).
248+
249+
Args:
250+
message: Message to handle
251+
"""
252+
# For REST, this is handled by the Quart endpoint
197253
pass
198254

199-
def publish_result_message_rest(self, sender, **kwargs):
255+
def publish_result_message_rest(self, sender, **kwargs): # pylint: disable=unused-argument
256+
"""
257+
Publish result message via REST adapter.
258+
259+
Args:
260+
message: Message payload to send
261+
destination: REST endpoint destination
262+
"""
200263
try:
201264
message = kwargs.get('message', {})
202265
destination = message.get('destinations', [])[0]

0 commit comments

Comments
 (0)