 from pathlib import Path
 from openpyxl import Workbook
 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score
-from .translations import TRANSLATIONS
+from .translations import TRANSLATIONS, get_class_names
 
 def compute_metrics(df, language='cs'):
     """Compute metrics from confusion matrix data"""
     df = df.copy()
     df.columns = df.columns.astype(str)
 
-    # Filter for C_1 and C_2 rows
-    df_cm = df[df['ClassValue'].isin(['C_1', 'C_2'])]
+    # Find C_* columns (handle cases like "C_1 - nezasazena uroda")
+    c_columns = [col for col in df.columns if col.startswith('C_') and '_' in col]
+    if not c_columns:
+        raise ValueError("No C_* columns found in the CSV file")
+
+    # Sort columns to ensure C_1, C_2, C_3, etc. order
+    c_columns.sort(key=lambda x: int(x.split('_')[1].split()[0]) if x.split('_')[1].split()[0].isdigit() else 0)
+
+    # Filter for rows that have ClassValue matching the C_* pattern
+    # Look for ClassValue entries that start with C_ and contain a number
+    class_values = []
+    for col in c_columns:
+        class_num = col.split('_')[1].split()[0]  # Extract number from C_1, C_2, etc.
+        if class_num.isdigit():
+            class_values.append(f'C_{class_num}')
+
+    if not class_values:
+        raise ValueError("No valid class values found in ClassValue column")
+
+    df_cm = df[df['ClassValue'].astype(str).str.startswith(tuple(class_values))]
+
+    if df_cm.empty:
+        raise ValueError("No rows found with matching ClassValue entries")
 
-    # Handle decimal commas
-    df_cm['C_1'] = df_cm['C_1'].astype(str).str.replace(',', '.').astype(float).astype(int)
-    df_cm['C_2'] = df_cm['C_2'].astype(str).str.replace(',', '.').astype(float).astype(int)
+    # Handle decimal commas and convert to numeric
+    for col in c_columns:
+        df_cm[col] = df_cm[col].astype(str).str.replace(',', '.').astype(float).astype(int)
 
-    cm = df_cm[['C_1', 'C_2']].to_numpy()
+    # Create confusion matrix from the C_* columns
+    cm = df_cm[c_columns].to_numpy()
+
+    if cm.size == 0 or cm.shape[0] == 0:
+        raise ValueError("Confusion matrix is empty")
 
-    y_true = [0] * int(cm[0, :].sum()) + [1] * int(cm[1, :].sum())
-    y_pred = [0] * int(cm[0, 0]) + [1] * int(cm[0, 1]) + [0] * int(cm[1, 0]) + [1] * int(cm[1, 1])
+    # Create y_true and y_pred arrays
+    y_true = []
+    y_pred = []
+
+    for i, row in enumerate(cm):
+        # Add true labels: class i, repeated once per sample in row i
+        row_sum = int(row.sum())
+        y_true.extend([i] * row_sum)
+
+        # Add predicted labels: class j, repeated once per count in cell [i, j]
+        for j, count in enumerate(row):
+            y_pred.extend([j] * int(count))
+
+    if not y_true or not y_pred:
+        raise ValueError("No valid predictions found in the data")
 
+    # Calculate metrics
     precision = precision_score(y_true, y_pred, average=None, zero_division=0)
     recall = recall_score(y_true, y_pred, average=None, zero_division=0)
     f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
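
The label-expansion step in the new code is worth a note: row i of the confusion matrix contributes row.sum() copies of class i to y_true, and each cell [i, j] contributes count copies of class j to y_pred, so sklearn's sample-level metric functions can be reused unchanged. A self-contained sanity check with toy numbers (not taken from the project's data):

    import numpy as np
    from sklearn.metrics import precision_score

    # Toy confusion matrix: rows = true classes, columns = predicted classes.
    cm = np.array([[48, 2],
                   [3, 47]])

    y_true, y_pred = [], []
    for i, row in enumerate(cm):
        y_true.extend([i] * int(row.sum()))      # true class i, once per sample in row i
        for j, count in enumerate(row):
            y_pred.extend([j] * int(count))      # predicted class j, once per count in cell [i, j]

    # Per-class precision from the expanded labels matches the column-wise
    # definition TP / (TP + FP) computed directly on the matrix.
    p = precision_score(y_true, y_pred, average=None)
    assert np.allclose(p, cm.diagonal() / cm.sum(axis=0))
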
@@ -31,12 +70,33 @@ def compute_metrics(df, language='cs'):
     avg_recall = round(recall_score(y_true, y_pred, average='macro', zero_division=0), 3)
     avg_f1 = round(f1_score(y_true, y_pred, average='macro', zero_division=0), 3)
 
-    class_names = TRANSLATIONS[language]['class_names']
-    return [
-        [class_names[0], round(precision[0], 3), round(recall[0], 3), round(f1[0], 3), accuracy, kappa],
-        [class_names[1], round(precision[1], 3), round(recall[1], 3), round(f1[1], 3), accuracy, kappa],
-        [class_names[2], avg_precision, avg_recall, avg_f1, accuracy, kappa]
-    ]
+    # Generate class names based on the number of classes found
+    class_names = get_class_names(len(c_columns), language)
+
+    # Create results for each class
+    results = []
+    for i in range(len(c_columns)):
+        if i < len(precision):
+            results.append([
+                class_names[i] if i < len(class_names) else f"Class {i + 1}",
+                round(precision[i], 3),
+                round(recall[i], 3),
+                round(f1[i], 3),
+                accuracy,
+                kappa
+            ])
+
+    # Add average row
+    results.append([
+        class_names[-1] if len(class_names) > len(c_columns) else "Average",
+        avg_precision,
+        avg_recall,
+        avg_f1,
+        accuracy,
+        kappa
+    ])
+
+    return results
 
 def export_to_excel(input_path, output_path, language='cs'):
     """Export metrics to Excel file"""
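
Finally, a usage sketch showing the input shape compute_metrics now expects. The column headers, Czech descriptions, and counts are invented for illustration, and the import path depends on the project layout:

    import pandas as pd
    # from yourpackage.metrics import compute_metrics  # hypothetical import path

    # Toy 2x2 confusion matrix as it might arrive from a Czech-locale CSV export:
    # decimal commas in the counts, descriptive "C_n - ..." column headers.
    df = pd.DataFrame({
        'ClassValue': ['C_1 - nezasazena uroda', 'C_2 - zasazena uroda'],
        'C_1 - nezasazena uroda': ['48,0', '3,0'],
        'C_2 - zasazena uroda': ['2,0', '47,0'],
    })

    for row in compute_metrics(df, language='cs'):
        # [class name, precision, recall, F1, accuracy, kappa]
        print(row)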