Commit ad8c5c8

Author: Ganyu Teng
Update code to pass security scan

1 parent: d63a0ec

File tree: 4 files changed, +13 -23 lines

anollm/anollm.py
requirements.txt
scripts/exp2-odds/run_anollm.sh
src/data_utils.py

anollm/anollm.py

Lines changed: 8 additions & 17 deletions

@@ -22,6 +22,7 @@
 from anollm.anollm_utils import _array_to_dataframe
 from anollm.anollm_dataset import AnoLLMDataset, AnoLLMDataCollator

+from safetensors.torch import save_model, load_model

 class AnoLLM:
     """AnoLLM Class
@@ -265,7 +266,7 @@ def decision_function(
             shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

             if feature_wise:
-                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().numpy() # batch * (ori_seq_len -1)
+                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().to(torch.float32).numpy() # batch * (ori_seq_len -1)

                 for i in range(len(encoded_batch)):
                     for j in range(n_col):
@@ -274,7 +275,7 @@ def decision_function(
                         col_idx = col_indices_batch[i][j]
                         anomaly_scores[start_idx+i, col_idx, perm_idx] = score_batch[i, start_pos:end_pos].sum()
             elif len(self.textual_columns) > 0:
-                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().numpy() # batch * (ori_seq_len -1)
+                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().to(torch.float32).numpy() # batch * (ori_seq_len -1)
                 for i in range(len(encoded_batch)):
                     score_single = 0
                     for j in range(n_col):
@@ -287,7 +288,7 @@ def decision_function(
                         score_single += score_batch[i, start_pos:end_pos].sum()
                     anomaly_scores[start_idx+i, perm_idx] = score_single
             else:
-                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) # remove normalization
+                score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).to(torch.float32).sum(1) # remove normalization
                 anomaly_scores[start_idx:end_idx, perm_idx] = score_batch.cpu().numpy()
             start_idx = end_idx
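The new `.to(torch.float32)` casts are what make the `.numpy()` conversions safe under mixed precision: NumPy has no bfloat16 dtype, so calling `.numpy()` on a bf16 tensor raises a TypeError. A minimal sketch of the failure mode and the fix (shapes and values are illustrative, not taken from AnoLLM):

    import torch

    # Losses from a bf16 / mixed-precision model come back as bfloat16.
    score_batch = torch.randn(4, 7, dtype=torch.bfloat16)

    # score_batch.cpu().numpy() would raise:
    #   TypeError: Got unsupported ScalarType BFloat16
    scores = score_batch.cpu().to(torch.float32).numpy()  # cast, then convert
    print(scores.dtype)  # float32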

@@ -309,24 +310,14 @@ def save_state_dict(self, path: str):
         else:
             os.mkdir(directory)

-        state_dict = self.model.state_dict()
-        new_state_dict = OrderedDict()
-        for k, v in state_dict.items():
-            name = k[7:] # remove `module.`
-            new_state_dict[name] = v
-        # Save the model with the modified state dict
-        torch.save(new_state_dict, path)
+        model_to_save = self.model.module
+        save_model(model_to_save, path)

     def load_from_state_dict(self, path: str):
         """Load AnoLLM model from state_dict

         Args:
             path: path where AnoLLM model is saved
         """
-
-        if self.efficient_finetuning == 'lora':
-            self.model.to('cpu')
-            state_dict = torch.load(path, map_location=torch.device('cpu'))
-            self.model.load_state_dict(state_dict)
-        else:
-            self.model.load_state_dict(torch.load(path))
+        load_model(self.model, path)
+
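Both persistence methods now go through safetensors, which serializes raw tensors only; unlike `torch.load`, loading a checkpoint cannot execute pickled code, which is presumably what the security scan flagged. Note also that `self.model.module` unwraps the DataParallel/DDP wrapper directly, replacing the manual `k[7:]` stripping of the `module.` prefix in the deleted code. A minimal sketch of the safetensors round trip on a stand-in module (the `nn.Linear` and file name are illustrative):

    import torch.nn as nn
    from safetensors.torch import save_model, load_model

    model = nn.Linear(8, 2)                     # stand-in for the trained model
    save_model(model, "anollm.safetensors")     # writes tensors only, no pickle

    restored = nn.Linear(8, 2)                  # same architecture, fresh weights
    load_model(restored, "anollm.safetensors")  # fills the weights in place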

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -7,9 +7,9 @@ pandas==2.2.2
 scikit_learn==1.6.1
 scipy==1.13.1
 tqdm==4.66.4
-transformers==4.42.3
 ucimlrepo==0.0.7
 peft==0.11.1
 datasets==2.20.0
 wandb==0.17.4
-tf-keras==2.16.0
+tf-keras==2.16.0
+transformers==4.48.2

transformers is bumped from 4.42.3 to 4.48.2 and moved to the end of the file; the tf-keras line is deleted and re-added with identical content, which usually indicates a trailing-newline fix rather than a version change.

scripts/exp2-odds/run_anollm.sh

Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@ for model in 'smol' 'smol-360'; do
         --batch_size $batch_size --model $model --binning standard --wandb
     CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\
         --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard
-    exit
     wandb offline
     for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do
         CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\

This removes a stray `exit` that terminated the script after evaluating the first split, so `wandb offline` and the loop over the remaining splits never ran.

src/data_utils.py

Lines changed: 3 additions & 3 deletions

@@ -459,7 +459,7 @@ def flatten(l):
     for npz_file in os.listdir(dataset_root):
         if npz_file.startswith(str(n) + '_'):
             print(dataset_name, npz_file)
-            data = np.load(dataset_root / npz_file, allow_pickle=True)
+            data = np.load(dataset_root / npz_file, allow_pickle=False)
             break
     else:
         ValueError('{} is not found.'.format(dataset_name))
@@ -482,11 +482,11 @@ def load_adbench_data(dataset):
     Utils().download_datasets(repo='jihulab')

     if dataset == 'cardio':
-        return np.load(dataset_root / '6_cardio.npz', allow_pickle=True)
+        return np.load(dataset_root / '6_cardio.npz', allow_pickle=False)

     for npz_file in os.listdir(dataset_root):
         if dataset in npz_file.lower():
-            return np.load(dataset_root / npz_file, allow_pickle=True)
+            return np.load(dataset_root / npz_file, allow_pickle=False)
     else:
         ValueError('{} is not found.'.format(dataset))

(Unrelated to this commit: both `ValueError(...)` calls in the context lines construct the exception without `raise`, so the not-found case falls through silently.)
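Flipping `allow_pickle` to `False` is the standard `np.load` hardening that scanners flag: with `allow_pickle=True`, loading an untrusted `.npz` can unpickle object arrays, and unpickling can execute arbitrary code. Purely numeric arrays are unaffected; only object arrays are now rejected. A minimal sketch (file names and contents are illustrative):

    import numpy as np

    # Plain numeric arrays round-trip without pickle.
    np.savez("safe.npz", X=np.zeros((3, 2)), y=np.array([0, 1, 0]))
    print(np.load("safe.npz", allow_pickle=False)["X"].shape)  # (3, 2)

    # Object arrays require pickle, so allow_pickle=False rejects them.
    np.savez("unsafe.npz", meta=np.array([{"src": "adbench"}], dtype=object))
    try:
        np.load("unsafe.npz", allow_pickle=False)["meta"]
    except ValueError as err:
        print(err)  # Object arrays cannot be loaded when allow_pickle=False ...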
