55import orjson as json
66from gevent import Greenlet , Timeout
77from locust import User , task
8+ from locust .env import Environment
89from orjson import JSONDecodeError
910from websocket import WebSocket , WebSocketConnectionClosedException , create_connection
11+
1012from chainbench .util .jsonrpc import RpcCall
1113
1214
1315class WSSubscription :
14- def __init__ (
15- self ,
16- subscribe_method : str ,
17- subscribe_params : dict | list ,
18- unsubscribe_method : str
19- ):
16+ def __init__ (self , subscribe_method : str , subscribe_params : dict | list , unsubscribe_method : str ):
2017 self .subscribe_rpc_call : RpcCall = RpcCall (subscribe_method , subscribe_params )
2118 self .unsubscribe_method : str = unsubscribe_method
2219 self .subscribed : bool = False
@@ -38,7 +35,7 @@ def subscription_id(self):
3835
3936
4037class WSRequest :
41- def __init__ (self , rpc_call : RpcCall , start_time : int , subscription_index : int = None ):
38+ def __init__ (self , rpc_call : RpcCall , start_time : int , subscription_index : int | None = None ):
4239 self .rpc_call = rpc_call
4340 self .start_time = start_time
4441 self .subscription_index = subscription_index
@@ -52,7 +49,7 @@ class WssJrpcUser(User):
5249 subscriptions : list [WSSubscription ] = []
5350 subscription_ids_to_index : dict [str | int , int ] = {}
5451
55- def __init__ (self , environment ):
52+ def __init__ (self , environment : Environment ):
5653 super ().__init__ (environment )
5754 self ._ws : WebSocket | None = None
5855 self ._ws_greenlet : Greenlet | None = None
@@ -108,7 +105,7 @@ def get_notification_name(self, parsed_response: dict):
108105 # Override this method to return the name of the notification if this is not correct
109106 return parsed_response ["method" ]
110107
111- def on_message (self , message ):
108+ def on_message (self , message : str | bytes ):
112109 try :
113110 parsed_json : dict = json .loads (message )
114111 if "error" in parsed_json :
@@ -181,15 +178,24 @@ def receive_loop(self):
181178 self .logger .error ("Connection closed by server, trying to reconnect..." )
182179 self .on_start ()
183180
184- def send (self , rpc_call : RpcCall = None , method : str = None , params : dict | list = None , subscription_index : int = None ):
181+ def send (
182+ self ,
183+ rpc_call : RpcCall | None = None ,
184+ method : str | None = None ,
185+ params : dict | list | None = None ,
186+ subscription_index : int | None = None ,
187+ ):
188+ def _get_args ():
189+ if rpc_call :
190+ return rpc_call
191+ elif method :
192+ return RpcCall (method , params )
193+ else :
194+ raise ValueError ("Either rpc_call or method must be provided" )
195+
196+ rpc_call = _get_args ()
185197 self .logger .debug (f"Sending: { rpc_call or method } " )
186- rpc = {
187- (None , None ): None ,
188- (None , method ): RpcCall (method , params ),
189- (rpc_call , None ): rpc_call ,
190- }
191198
192- rpc_call = rpc [(rpc_call , method )]
193199 if rpc_call is None :
194200 raise ValueError ("Either rpc_call or method must be provided" )
195201
@@ -202,4 +208,5 @@ def send(self, rpc_call: RpcCall = None, method: str = None, params: dict | list
202208 )
203209 json_body = json .dumps (rpc_call .request_body ())
204210 self .logger .debug (f"WSReq: { json_body } " )
205- self ._ws .send (json_body )
211+ if self ._ws :
212+ self ._ws .send (json_body )
0 commit comments