Skip to content

Commit 8a70278

Browse files
fix reputation system
1 parent 70ad813 commit 8a70278

File tree

8 files changed

+104
-141
lines changed

8 files changed

+104
-141
lines changed

nebula/addons/reputation/reputation.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -97,35 +97,38 @@ def __init__(self, engine: "Engine", config: "Config"):
9797
self._addr = engine.addr
9898
self._log_dir = engine.log_dir
9999
self._idx = engine.idx
100-
self.connection_metrics = []
100+
self.connection_metrics = {}
101101

102102
neighbors: str = self._config.participant["network_args"]["neighbors"]
103-
self.connection_metrics = {}
104103
for nei in neighbors.split():
105104
self.connection_metrics[f"{nei}"] = Metrics()
106105

107-
self._with_reputation = self._config.participant["defense_args"]["with_reputation"]
108-
self._reputation_metrics = self._config.participant["defense_args"]["reputation_metrics"]
109-
self._initial_reputation = float(self._config.participant["defense_args"]["initial_reputation"])
110-
self._weighting_factor = self._config.participant["defense_args"]["weighting_factor"]
111-
self._weight_model_arrival_latency = float(
112-
self._config.participant["defense_args"]["weight_model_arrival_latency"]
113-
)
114-
self._weight_model_similarity = float(self._config.participant["defense_args"]["weight_model_similarity"])
115-
self._weight_num_messages = float(self._config.participant["defense_args"]["weight_num_messages"])
116-
self._weight_fraction_params_changed = float(
117-
self._config.participant["defense_args"]["weight_fraction_params_changed"]
118-
)
106+
reputation_config = self._config.participant["defense_args"]["reputation"]
107+
self._enabled = reputation_config["enabled"]
108+
self._metrics = reputation_config["metrics"]
109+
self._initial_reputation = float(reputation_config["initial_reputation"])
110+
self._weighting_factor = reputation_config["weighting_factor"]
111+
112+
# Extract weights from metrics if using static weighting
113+
if self._weighting_factor == "static":
114+
self._weight_model_arrival_latency = float(self._metrics.get("modelArrivalLatency", {}).get("weight", 0.25))
115+
self._weight_model_similarity = float(self._metrics.get("modelSimilarity", {}).get("weight", 0.25))
116+
self._weight_num_messages = float(self._metrics.get("numMessages", {}).get("weight", 0.25))
117+
self._weight_fraction_params_changed = float(self._metrics.get("fractionParametersChanged", {}).get("weight", 0.25))
118+
else:
119+
self._weight_model_arrival_latency = 0.25
120+
self._weight_model_similarity = 0.25
121+
self._weight_num_messages = 0.25
122+
self._weight_fraction_params_changed = 0.25
119123

120-
msg = f"Reputation system: {self._with_reputation}"
121-
msg += f"\nReputation metrics: {self._reputation_metrics}"
124+
msg = f"Reputation system: {self._enabled}"
125+
msg += f"\nReputation metrics: {self._metrics}"
122126
msg += f"\nInitial reputation: {self._initial_reputation}"
123127
msg += f"\nWeighting factor: {self._weighting_factor}"
124-
if self._weighting_factor == "static":
125-
msg += f"\nWeight model arrival latency: {self._weight_model_arrival_latency}"
126-
msg += f"\nWeight model similarity: {self._weight_model_similarity}"
127-
msg += f"\nWeight number of messages: {self._weight_num_messages}"
128-
msg += f"\nWeight fraction of parameters changed: {self._weight_fraction_params_changed}"
128+
msg += f"\nWeight model arrival latency: {self._weight_model_arrival_latency}"
129+
msg += f"\nWeight model similarity: {self._weight_model_similarity}"
130+
msg += f"\nWeight number of messages: {self._weight_num_messages}"
131+
msg += f"\nWeight fraction of parameters changed: {self._weight_fraction_params_changed}"
129132
print_msg_box(msg=msg, indent=2, title="Defense information")
130133

131134
@property
@@ -205,30 +208,22 @@ def save_data(
205208
logging.exception("Error saving data")
206209

207210
async def setup(self):
208-
"""
209-
Setup the reputation system by subscribing to various events.
210-
211-
This function enables the reputation system and subscribes to events based on active metrics.
212-
"""
213-
if self._with_reputation:
214-
logging.info("Reputation system enabled")
211+
"""Set up the reputation system by subscribing to relevant events."""
212+
if self._enabled:
215213
await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self.on_round_start)
216214
await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.calculate_reputation)
217-
if self._reputation_metrics.get("model_similarity", False):
215+
if self._metrics.get("modelSimilarity", {}).get("enabled", False):
218216
await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.recollect_similarity)
219-
if self._reputation_metrics.get("fraction_parameters_changed", False):
217+
if self._metrics.get("fractionParametersChanged", {}).get("enabled", False):
220218
await EventManager.get_instance().subscribe_node_event(
221219
UpdateReceivedEvent, self.recollect_fraction_of_parameters_changed
222220
)
223-
if self._reputation_metrics.get("num_messages", False):
221+
if self._metrics.get("numMessages", {}).get("enabled", False):
224222
await EventManager.get_instance().subscribe(("model", "update"), self.recollect_number_message)
225223
await EventManager.get_instance().subscribe(("model", "initialization"), self.recollect_number_message)
226-
await EventManager.get_instance().subscribe(("control", "alive"), self.recollect_number_message)
227-
await EventManager.get_instance().subscribe(
228-
("federation", "federation_models_included"), self.recollect_number_message
229-
)
224+
await EventManager.get_instance().subscribe(("model", "aggregation"), self.recollect_number_message)
230225
await EventManager.get_instance().subscribe(("reputation", "share"), self.recollect_number_message)
231-
if self._reputation_metrics.get("model_arrival_latency", False):
226+
if self._metrics.get("modelArrivalLatency", {}).get("enabled", False):
232227
await EventManager.get_instance().subscribe_node_event(
233228
UpdateReceivedEvent, self.recollect_model_arrival_latency
234229
)
@@ -250,7 +245,7 @@ def init_reputation(
250245
logging.error("init_reputation | No federation nodes provided")
251246
return
252247

253-
if self._with_reputation:
248+
if self._enabled:
254249
neighbors = self.is_valid_ip(federation_nodes)
255250

256251
if not neighbors:
@@ -362,7 +357,7 @@ async def _calculate_dynamic_reputation(self, addr, neighbors):
362357
average_weights = {}
363358

364359
for metric_name in self.history_data.keys():
365-
if self._reputation_metrics.get(metric_name, False):
360+
if self._metrics.get(metric_name, False):
366361
valid_entries = [
367362
entry
368363
for entry in self.history_data[metric_name]
@@ -378,7 +373,7 @@ async def _calculate_dynamic_reputation(self, addr, neighbors):
378373
for nei in neighbors:
379374
metric_values = {}
380375
for metric_name in self.history_data.keys():
381-
if self._reputation_metrics.get(metric_name, False):
376+
if self._metrics.get(metric_name, False):
382377
for entry in self.history_data.get(metric_name, []):
383378
if (
384379
entry["round"] == self._engine.get_round()
@@ -1316,9 +1311,9 @@ async def calculate_reputation(self, ae: AggregationEvent):
13161311
ae (AggregationEvent): The event containing aggregated updates.
13171312
"""
13181313
(updates, _, _) = await ae.get_event_data()
1319-
if self._with_reputation:
1314+
if self._enabled:
13201315
logging.info(f"Calculating reputation at round {self._engine.get_round()}")
1321-
logging.info(f"Active metrics: {self._reputation_metrics}")
1316+
logging.info(f"Active metrics: {self._metrics}")
13221317
logging.info(f"rejected nodes at round {self._engine.get_round()}: {self.rejected_nodes}")
13231318

13241319
neighbors = set(await self._engine._cm.get_addrs_current_connections(only_direct=True))
@@ -1335,7 +1330,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
13351330
self._idx,
13361331
self._addr,
13371332
nei,
1338-
metrics_active=self._reputation_metrics,
1333+
metrics_active=self._metrics,
13391334
)
13401335

13411336
if self._weighting_factor == "dynamic":
@@ -1348,7 +1343,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
13481343
self._engine.get_round(),
13491344
self._addr,
13501345
nei,
1351-
self._reputation_metrics,
1346+
self._metrics,
13521347
)
13531348

13541349
if self._weighting_factor == "static" and self._engine.get_round() >= 5:
@@ -1368,7 +1363,7 @@ async def calculate_reputation(self, ae: AggregationEvent):
13681363
if self._weighting_factor == "dynamic" and self._engine.get_round() >= 5:
13691364
await self._calculate_dynamic_reputation(self._addr, neighbors)
13701365

1371-
if self._engine.get_round() < 5 and self._with_reputation:
1366+
if self._engine.get_round() < 5 and self._enabled:
13721367
federation = self._engine.config.participant["network_args"]["neighbors"].split()
13731368
self.init_reputation(
13741369
self._addr,
@@ -1571,7 +1566,7 @@ async def recollect_similarity(self, ure: UpdateReceivedEvent):
15711566
ure (UpdateReceivedEvent): The event data containing model update information.
15721567
"""
15731568
(decoded_model, weight, nei, round_num, local) = await ure.get_event_data()
1574-
if self._with_reputation and self._reputation_metrics.get("model_similarity"):
1569+
if self._enabled and self._metrics.get("model_similarity"):
15751570
if self._engine.config.participant["adaptive_args"]["model_similarity"]:
15761571
if nei != self._addr:
15771572
logging.info("🤖 handle_model_message | Checking model similarity")

nebula/config/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __setup_logging(self, log_filename):
9090
log_console_format = f"{CYAN}%(asctime)s - {self.participant['device_args']['name']} - [%(filename)s:%(lineno)d]{RESET}\n%(message)s"
9191

9292
console_handler = logging.StreamHandler()
93-
console_handler.setLevel(logging.INFO if self.participant["device_args"]["logging"] else logging.CRITICAL)
93+
console_handler.setLevel(logging.CRITICAL)
9494
console_handler.setFormatter(Formatter(log_console_format))
9595

9696
file_handler = FileHandler(f"{log_filename}.log", mode="w", encoding="utf-8")
@@ -126,7 +126,7 @@ def __set_training_logging(self):
126126
level = logging.DEBUG if self.participant["device_args"]["logging"] else logging.CRITICAL
127127

128128
console_handler = logging.StreamHandler()
129-
console_handler.setLevel(logging.INFO if self.participant["device_args"]["logging"] else logging.CRITICAL)
129+
console_handler.setLevel(logging.CRITICAL)
130130
console_handler.setFormatter(Formatter(log_console_format))
131131

132132
file_handler = FileHandler(f"{training_log_filename}.log", mode="w", encoding="utf-8")

nebula/controller/controller.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from fastapi.concurrency import asynccontextmanager
1515

1616
from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished
17-
from nebula.controller.scenarios import Scenario, ScenarioManagement
1817
from nebula.utils import DockerUtils
1918

2019

@@ -278,6 +277,8 @@ async def run_scenario(
278277

279278
import subprocess
280279

280+
from nebula.controller.scenarios import ScenarioManagement
281+
281282
# Manager for the actual scenario
282283
scenarioManagement = ScenarioManagement(scenario_data, user)
283284

@@ -314,6 +315,8 @@ async def stop_scenario(
314315
username: str = Body(..., embed=True),
315316
all: bool = Body(False, embed=True),
316317
):
318+
from nebula.controller.scenarios import ScenarioManagement
319+
317320
ScenarioManagement.stop_participants(scenario_name)
318321
DockerUtils.remove_containers_by_prefix(f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{username}-participant")
319322
DockerUtils.remove_docker_network(
@@ -348,6 +351,7 @@ async def remove_scenario(
348351
dict: A message indicating successful removal.
349352
"""
350353
from nebula.controller.database import remove_scenario_by_name
354+
from nebula.controller.scenarios import ScenarioManagement
351355

352356
try:
353357
remove_scenario_by_name(scenario_name)
@@ -415,6 +419,7 @@ async def update_scenario(
415419
dict: A message confirming the update.
416420
"""
417421
from nebula.controller.database import scenario_update_record
422+
from nebula.controller.scenarios import Scenario
418423

419424
try:
420425
scenario = Scenario.from_dict(scenario)

nebula/controller/database.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def initialize_databases(databases_dir):
142142
network_gateway TEXT,
143143
epochs TEXT,
144144
attack_params TEXT,
145-
with_reputation TEXT,
145+
reputation TEXT,
146146
random_geo TEXT,
147147
latitude TEXT,
148148
longitude TEXT,
@@ -190,7 +190,7 @@ async def initialize_databases(databases_dir):
190190
"network_gateway": "TEXT",
191191
"epochs": "TEXT",
192192
"attack_params": "TEXT",
193-
"with_reputation": "TEXT",
193+
"reputation": "TEXT",
194194
"random_geo": "TEXT",
195195
"latitude": "TEXT",
196196
"longitude": "TEXT",
@@ -576,7 +576,7 @@ def scenario_update_record(name, start_time, end_time, scenario, status, role, u
576576
network_gateway,
577577
epochs,
578578
attack_params,
579-
with_reputation,
579+
reputation,
580580
random_geo,
581581
latitude,
582582
longitude,
@@ -626,7 +626,7 @@ def scenario_update_record(name, start_time, end_time, scenario, status, role, u
626626
scenario.network_gateway,
627627
scenario.epochs,
628628
json.dumps(scenario.attack_params),
629-
scenario.with_reputation,
629+
json.dumps(scenario.reputation),
630630
scenario.random_geo,
631631
scenario.latitude,
632632
scenario.longitude,
@@ -673,7 +673,7 @@ def scenario_update_record(name, start_time, end_time, scenario, status, role, u
673673
network_gateway = ?,
674674
epochs = ?,
675675
attack_params = ?,
676-
with_reputation = ?,
676+
reputation = ?,
677677
random_geo = ?,
678678
latitude = ?,
679679
longitude = ?,
@@ -719,11 +719,8 @@ def scenario_update_record(name, start_time, end_time, scenario, status, role, u
719719
scenario.network_subnet,
720720
scenario.network_gateway,
721721
scenario.epochs,
722-
scenario.poisoned_node_percent,
723-
scenario.poisoned_sample_percent,
724-
scenario.poisoned_noise_percent,
725722
json.dumps(scenario.attack_params),
726-
scenario.with_reputation,
723+
json.dumps(scenario.reputation),
727724
scenario.random_geo,
728725
scenario.latitude,
729726
scenario.longitude,

0 commit comments

Comments
 (0)