22
22
from anollm .anollm_utils import _array_to_dataframe
23
23
from anollm .anollm_dataset import AnoLLMDataset , AnoLLMDataCollator
24
24
25
+ from safetensors .torch import save_model , load_model
25
26
26
27
class AnoLLM :
27
28
"""AnoLLM Class
@@ -265,7 +266,7 @@ def decision_function(
265
266
shift_attention_mask_batch = attn_mask [..., 1 :].contiguous ()
266
267
267
268
if feature_wise :
268
- score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).cpu ().numpy () # batch * (ori_seq_len -1)
269
+ score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).cpu ().to ( torch . float32 ). numpy () # batch * (ori_seq_len -1)
269
270
270
271
for i in range (len (encoded_batch )):
271
272
for j in range (n_col ):
@@ -274,7 +275,7 @@ def decision_function(
274
275
col_idx = col_indices_batch [i ][j ]
275
276
anomaly_scores [start_idx + i , col_idx , perm_idx ] = score_batch [i , start_pos :end_pos ].sum ()
276
277
elif len (self .textual_columns ) > 0 :
277
- score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).cpu ().numpy () # batch * (ori_seq_len -1)
278
+ score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).cpu ().to ( torch . float32 ). numpy () # batch * (ori_seq_len -1)
278
279
for i in range (len (encoded_batch )):
279
280
score_single = 0
280
281
for j in range (n_col ):
@@ -287,7 +288,7 @@ def decision_function(
287
288
score_single += score_batch [i , start_pos :end_pos ].sum ()
288
289
anomaly_scores [start_idx + i , perm_idx ] = score_single
289
290
else :
290
- score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).sum (1 ) # remove normalization
291
+ score_batch = (loss_fct (shift_logits .transpose (1 , 2 ), shift_labels ) * shift_attention_mask_batch ).to ( torch . float32 ). sum (1 ) # remove normalization
291
292
anomaly_scores [start_idx :end_idx , perm_idx ] = score_batch .cpu ().numpy ()
292
293
start_idx = end_idx
293
294
@@ -309,24 +310,14 @@ def save_state_dict(self, path: str):
309
310
else :
310
311
os .mkdir (directory )
311
312
312
- state_dict = self .model .state_dict ()
313
- new_state_dict = OrderedDict ()
314
- for k , v in state_dict .items ():
315
- name = k [7 :] # remove `module.`
316
- new_state_dict [name ] = v
317
- # Save the model with the modified state dict
318
- torch .save (new_state_dict , path )
313
+ model_to_save = self .model .module
314
+ save_model (model_to_save , path )
319
315
320
316
def load_from_state_dict(self, path: str):
    """Load AnoLLM model weights from a file written by ``save_state_dict``.

    ``save_state_dict`` serialises ``self.model.module`` (the unwrapped
    network), so the stored tensor names carry no ``module.`` prefix. The
    weights are therefore loaded into the unwrapped network as well, which
    keeps saving and loading symmetric whether or not ``self.model`` is
    currently wrapped.

    Args:
        path: path where AnoLLM model is saved
    """
    # If self.model exposes the real network as `.module` (presumably a
    # DistributedDataParallel-style wrapper -- confirm against the training
    # setup), load into that inner module; otherwise load into self.model
    # directly, exactly as before.
    target_model = getattr(self.model, "module", self.model)
    load_model(target_model, path)
323
+
0 commit comments