@@ -97,35 +97,38 @@ def __init__(self, engine: "Engine", config: "Config"):
97
97
self ._addr = engine .addr
98
98
self ._log_dir = engine .log_dir
99
99
self ._idx = engine .idx
100
- self .connection_metrics = []
100
+ self .connection_metrics = {}
101
101
102
102
neighbors : str = self ._config .participant ["network_args" ]["neighbors" ]
103
- self .connection_metrics = {}
104
103
for nei in neighbors .split ():
105
104
self .connection_metrics [f"{ nei } " ] = Metrics ()
106
105
107
- self ._with_reputation = self ._config .participant ["defense_args" ]["with_reputation" ]
108
- self ._reputation_metrics = self ._config .participant ["defense_args" ]["reputation_metrics" ]
109
- self ._initial_reputation = float (self ._config .participant ["defense_args" ]["initial_reputation" ])
110
- self ._weighting_factor = self ._config .participant ["defense_args" ]["weighting_factor" ]
111
- self ._weight_model_arrival_latency = float (
112
- self ._config .participant ["defense_args" ]["weight_model_arrival_latency" ]
113
- )
114
- self ._weight_model_similarity = float (self ._config .participant ["defense_args" ]["weight_model_similarity" ])
115
- self ._weight_num_messages = float (self ._config .participant ["defense_args" ]["weight_num_messages" ])
116
- self ._weight_fraction_params_changed = float (
117
- self ._config .participant ["defense_args" ]["weight_fraction_params_changed" ]
118
- )
106
+ reputation_config = self ._config .participant ["defense_args" ]["reputation" ]
107
+ self ._enabled = reputation_config ["enabled" ]
108
+ self ._metrics = reputation_config ["metrics" ]
109
+ self ._initial_reputation = float (reputation_config ["initial_reputation" ])
110
+ self ._weighting_factor = reputation_config ["weighting_factor" ]
111
+
112
+ # Extract weights from metrics if using static weighting
113
+ if self ._weighting_factor == "static" :
114
+ self ._weight_model_arrival_latency = float (self ._metrics .get ("modelArrivalLatency" , {}).get ("weight" , 0.25 ))
115
+ self ._weight_model_similarity = float (self ._metrics .get ("modelSimilarity" , {}).get ("weight" , 0.25 ))
116
+ self ._weight_num_messages = float (self ._metrics .get ("numMessages" , {}).get ("weight" , 0.25 ))
117
+ self ._weight_fraction_params_changed = float (self ._metrics .get ("fractionParametersChanged" , {}).get ("weight" , 0.25 ))
118
+ else :
119
+ self ._weight_model_arrival_latency = 0.25
120
+ self ._weight_model_similarity = 0.25
121
+ self ._weight_num_messages = 0.25
122
+ self ._weight_fraction_params_changed = 0.25
119
123
120
- msg = f"Reputation system: { self ._with_reputation } "
121
- msg += f"\n Reputation metrics: { self ._reputation_metrics } "
124
+ msg = f"Reputation system: { self ._enabled } "
125
+ msg += f"\n Reputation metrics: { self ._metrics } "
122
126
msg += f"\n Initial reputation: { self ._initial_reputation } "
123
127
msg += f"\n Weighting factor: { self ._weighting_factor } "
124
- if self ._weighting_factor == "static" :
125
- msg += f"\n Weight model arrival latency: { self ._weight_model_arrival_latency } "
126
- msg += f"\n Weight model similarity: { self ._weight_model_similarity } "
127
- msg += f"\n Weight number of messages: { self ._weight_num_messages } "
128
- msg += f"\n Weight fraction of parameters changed: { self ._weight_fraction_params_changed } "
128
+ msg += f"\n Weight model arrival latency: { self ._weight_model_arrival_latency } "
129
+ msg += f"\n Weight model similarity: { self ._weight_model_similarity } "
130
+ msg += f"\n Weight number of messages: { self ._weight_num_messages } "
131
+ msg += f"\n Weight fraction of parameters changed: { self ._weight_fraction_params_changed } "
129
132
print_msg_box (msg = msg , indent = 2 , title = "Defense information" )
130
133
131
134
@property
@@ -205,30 +208,22 @@ def save_data(
205
208
logging .exception ("Error saving data" )
206
209
207
210
async def setup (self ):
208
- """
209
- Setup the reputation system by subscribing to various events.
210
-
211
- This function enables the reputation system and subscribes to events based on active metrics.
212
- """
213
- if self ._with_reputation :
214
- logging .info ("Reputation system enabled" )
211
+ """Set up the reputation system by subscribing to relevant events."""
212
+ if self ._enabled :
215
213
await EventManager .get_instance ().subscribe_node_event (RoundStartEvent , self .on_round_start )
216
214
await EventManager .get_instance ().subscribe_node_event (AggregationEvent , self .calculate_reputation )
217
- if self ._reputation_metrics .get ("model_similarity " , False ):
215
+ if self ._metrics .get ("modelSimilarity" , {}). get ( "enabled " , False ):
218
216
await EventManager .get_instance ().subscribe_node_event (UpdateReceivedEvent , self .recollect_similarity )
219
- if self ._reputation_metrics .get ("fraction_parameters_changed " , False ):
217
+ if self ._metrics .get ("fractionParametersChanged" , {}). get ( "enabled " , False ):
220
218
await EventManager .get_instance ().subscribe_node_event (
221
219
UpdateReceivedEvent , self .recollect_fraction_of_parameters_changed
222
220
)
223
- if self ._reputation_metrics .get ("num_messages " , False ):
221
+ if self ._metrics .get ("numMessages" , {}). get ( "enabled " , False ):
224
222
await EventManager .get_instance ().subscribe (("model" , "update" ), self .recollect_number_message )
225
223
await EventManager .get_instance ().subscribe (("model" , "initialization" ), self .recollect_number_message )
226
- await EventManager .get_instance ().subscribe (("control" , "alive" ), self .recollect_number_message )
227
- await EventManager .get_instance ().subscribe (
228
- ("federation" , "federation_models_included" ), self .recollect_number_message
229
- )
224
+ await EventManager .get_instance ().subscribe (("model" , "aggregation" ), self .recollect_number_message )
230
225
await EventManager .get_instance ().subscribe (("reputation" , "share" ), self .recollect_number_message )
231
- if self ._reputation_metrics .get ("model_arrival_latency " , False ):
226
+ if self ._metrics .get ("modelArrivalLatency" , {}). get ( "enabled " , False ):
232
227
await EventManager .get_instance ().subscribe_node_event (
233
228
UpdateReceivedEvent , self .recollect_model_arrival_latency
234
229
)
@@ -250,7 +245,7 @@ def init_reputation(
250
245
logging .error ("init_reputation | No federation nodes provided" )
251
246
return
252
247
253
- if self ._with_reputation :
248
+ if self ._enabled :
254
249
neighbors = self .is_valid_ip (federation_nodes )
255
250
256
251
if not neighbors :
@@ -362,7 +357,7 @@ async def _calculate_dynamic_reputation(self, addr, neighbors):
362
357
average_weights = {}
363
358
364
359
for metric_name in self .history_data .keys ():
365
- if self ._reputation_metrics .get (metric_name , False ):
360
+ if self ._metrics .get (metric_name , False ):
366
361
valid_entries = [
367
362
entry
368
363
for entry in self .history_data [metric_name ]
@@ -378,7 +373,7 @@ async def _calculate_dynamic_reputation(self, addr, neighbors):
378
373
for nei in neighbors :
379
374
metric_values = {}
380
375
for metric_name in self .history_data .keys ():
381
- if self ._reputation_metrics .get (metric_name , False ):
376
+ if self ._metrics .get (metric_name , False ):
382
377
for entry in self .history_data .get (metric_name , []):
383
378
if (
384
379
entry ["round" ] == self ._engine .get_round ()
@@ -1316,9 +1311,9 @@ async def calculate_reputation(self, ae: AggregationEvent):
1316
1311
ae (AggregationEvent): The event containing aggregated updates.
1317
1312
"""
1318
1313
(updates , _ , _ ) = await ae .get_event_data ()
1319
- if self ._with_reputation :
1314
+ if self ._enabled :
1320
1315
logging .info (f"Calculating reputation at round { self ._engine .get_round ()} " )
1321
- logging .info (f"Active metrics: { self ._reputation_metrics } " )
1316
+ logging .info (f"Active metrics: { self ._metrics } " )
1322
1317
logging .info (f"rejected nodes at round { self ._engine .get_round ()} : { self .rejected_nodes } " )
1323
1318
1324
1319
neighbors = set (await self ._engine ._cm .get_addrs_current_connections (only_direct = True ))
@@ -1335,7 +1330,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
1335
1330
self ._idx ,
1336
1331
self ._addr ,
1337
1332
nei ,
1338
- metrics_active = self ._reputation_metrics ,
1333
+ metrics_active = self ._metrics ,
1339
1334
)
1340
1335
1341
1336
if self ._weighting_factor == "dynamic" :
@@ -1348,7 +1343,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
1348
1343
self ._engine .get_round (),
1349
1344
self ._addr ,
1350
1345
nei ,
1351
- self ._reputation_metrics ,
1346
+ self ._metrics ,
1352
1347
)
1353
1348
1354
1349
if self ._weighting_factor == "static" and self ._engine .get_round () >= 5 :
@@ -1368,7 +1363,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
1368
1363
if self ._weighting_factor == "dynamic" and self ._engine .get_round () >= 5 :
1369
1364
await self ._calculate_dynamic_reputation (self ._addr , neighbors )
1370
1365
1371
- if self ._engine .get_round () < 5 and self ._with_reputation :
1366
+ if self ._engine .get_round () < 5 and self ._enabled :
1372
1367
federation = self ._engine .config .participant ["network_args" ]["neighbors" ].split ()
1373
1368
self .init_reputation (
1374
1369
self ._addr ,
@@ -1571,7 +1566,7 @@ async def recollect_similarity(self, ure: UpdateReceivedEvent):
1571
1566
ure (UpdateReceivedEvent): The event data containing model update information.
1572
1567
"""
1573
1568
(decoded_model , weight , nei , round_num , local ) = await ure .get_event_data ()
1574
- if self ._with_reputation and self ._reputation_metrics .get ("model_similarity" ):
1569
+ if self ._enabled and self ._metrics .get ("model_similarity" ):
1575
1570
if self ._engine .config .participant ["adaptive_args" ]["model_similarity" ]:
1576
1571
if nei != self ._addr :
1577
1572
logging .info ("🤖 handle_model_message | Checking model similarity" )
0 commit comments