Skip to content

Commit 9b28c51

Browse files
committed
[0.1.1] - 2025-07-07
### Added - Support for CSV files where the `ClassValue` column contains descriptive text after the class code (e.g., `C_1 - nezasazena uroda`). - Automatic detection and processing of any number of classes (`C_1`, `C_2`, `C_3`, ...). ### Fixed - Fixed a crash when loading CSV files where `ClassValue` did not exactly match `C_1`, `C_2`, etc. - Improved error messages for invalid or empty data in CSV files. ### Changed - Metrics calculation is now robust to various CSV formats and works for any number of classes.
1 parent 824015d commit 9b28c51

File tree

3 files changed

+101
-16
lines changed

3 files changed

+101
-16
lines changed

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
# Changelog
22

3+
## [0.1.1] - 2025-07-07
4+
5+
### Added
6+
- Support for CSV files where the `ClassValue` column contains descriptive text after the class code (e.g., `C_1 - nezasazena uroda`).
7+
- Automatic detection and processing of any number of classes (`C_1`, `C_2`, `C_3`, ...).
8+
9+
### Fixed
10+
- Fixed a crash when loading CSV files where `ClassValue` did not exactly match `C_1`, `C_2`, etc.
11+
- Improved error messages for invalid or empty data in CSV files.
12+
13+
### Changed
14+
- Metrics calculation is now robust to various CSV formats and works for any number of classes.
15+
16+
---
17+
318
## [0.1.0] - 2025-06-22
419

520
### Added

core/metrics.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,64 @@
22
from pathlib import Path
33
from openpyxl import Workbook
44
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score
5-
from .translations import TRANSLATIONS
5+
from .translations import TRANSLATIONS, get_class_names
66

77
def compute_metrics(df, language='cs'):
88
"""Compute metrics from confusion matrix data"""
99
df = df.copy()
1010
df.columns = df.columns.astype(str)
1111

12-
# Filter for C_1 and C_2 rows
13-
df_cm = df[df['ClassValue'].isin(['C_1', 'C_2'])]
12+
# Find C_* columns (handle cases like "C_1 - nezasazena uroda")
13+
c_columns = [col for col in df.columns if col.startswith('C_') and '_' in col]
14+
if not c_columns:
15+
raise ValueError("No C_* columns found in the CSV file")
16+
17+
# Sort columns to ensure C_1, C_2, C_3, etc. order
18+
c_columns.sort(key=lambda x: int(x.split('_')[1].split()[0]) if x.split('_')[1].split()[0].isdigit() else 0)
19+
20+
# Filter for rows that have ClassValue matching the C_* pattern
21+
# Look for ClassValue entries that start with C_ and contain a number
22+
class_values = []
23+
for col in c_columns:
24+
class_num = col.split('_')[1].split()[0] # Extract number from C_1, C_2, etc.
25+
if class_num.isdigit():
26+
class_values.append(f'C_{class_num}')
27+
28+
if not class_values:
29+
raise ValueError("No valid class values found in ClassValue column")
30+
31+
df_cm = df[df['ClassValue'].astype(str).str.startswith(tuple(class_values))]
32+
33+
if df_cm.empty:
34+
raise ValueError("No rows found with matching ClassValue entries")
1435

15-
# Handle decimal commas
16-
df_cm['C_1'] = df_cm['C_1'].astype(str).str.replace(',', '.').astype(float).astype(int)
17-
df_cm['C_2'] = df_cm['C_2'].astype(str).str.replace(',', '.').astype(float).astype(int)
36+
# Handle decimal commas and convert to numeric
37+
for col in c_columns:
38+
df_cm[col] = df_cm[col].astype(str).str.replace(',', '.').astype(float).astype(int)
1839

19-
cm = df_cm[['C_1', 'C_2']].to_numpy()
40+
# Create confusion matrix from the C_* columns
41+
cm = df_cm[c_columns].to_numpy()
42+
43+
if cm.size == 0 or cm.shape[0] == 0:
44+
raise ValueError("Confusion matrix is empty")
2045

21-
y_true = [0] * int(cm[0, :].sum()) + [1] * int(cm[1, :].sum())
22-
y_pred = [0] * int(cm[0, 0]) + [1] * int(cm[0, 1]) + [0] * int(cm[1, 0]) + [1] * int(cm[1, 1])
46+
# Create y_true and y_pred arrays
47+
y_true = []
48+
y_pred = []
49+
50+
for i, row in enumerate(cm):
51+
# Add true labels (class i repeated by the sum of that row)
52+
row_sum = int(row.sum())
53+
y_true.extend([i] * row_sum)
54+
55+
# Add predicted labels
56+
for j, count in enumerate(row):
57+
y_pred.extend([j] * int(count))
58+
59+
if not y_true or not y_pred:
60+
raise ValueError("No valid predictions found in the data")
2361

62+
# Calculate metrics
2463
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
2564
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
2665
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
@@ -31,12 +70,33 @@ def compute_metrics(df, language='cs'):
3170
avg_recall = round(recall_score(y_true, y_pred, average='macro', zero_division=0), 3)
3271
avg_f1 = round(f1_score(y_true, y_pred, average='macro', zero_division=0), 3)
3372

34-
class_names = TRANSLATIONS[language]['class_names']
35-
return [
36-
[class_names[0], round(precision[0], 3), round(recall[0], 3), round(f1[0], 3), accuracy, kappa],
37-
[class_names[1], round(precision[1], 3), round(recall[1], 3), round(f1[1], 3), accuracy, kappa],
38-
[class_names[2], avg_precision, avg_recall, avg_f1, accuracy, kappa]
39-
]
73+
# Generate class names based on the number of classes found
74+
class_names = get_class_names(len(c_columns), language)
75+
76+
# Create results for each class
77+
results = []
78+
for i in range(len(c_columns)):
79+
if i < len(precision):
80+
results.append([
81+
class_names[i] if i < len(class_names) else f"Class {i+1}",
82+
round(precision[i], 3),
83+
round(recall[i], 3),
84+
round(f1[i], 3),
85+
accuracy,
86+
kappa
87+
])
88+
89+
# Add average row
90+
results.append([
91+
class_names[-1] if len(class_names) > len(c_columns) else "Average",
92+
avg_precision,
93+
avg_recall,
94+
avg_f1,
95+
accuracy,
96+
kappa
97+
])
98+
99+
return results
40100

41101
def export_to_excel(input_path, output_path, language='cs'):
42102
"""Export metrics to Excel file"""

core/translations.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,14 @@
118118
'excel_metrics_sheet': 'Metrics',
119119
'excel_data_sheet': 'Data'
120120
}
121-
}
121+
}
122+
123+
def get_class_names(num_classes, language='cs'):
124+
"""Generate class names based on the number of classes found"""
125+
if language == 'cs':
126+
class_names = [f"C_{i+1}" for i in range(num_classes)]
127+
class_names.append("Průměr")
128+
else: # English
129+
class_names = [f"C_{i+1}" for i in range(num_classes)]
130+
class_names.append("Average")
131+
return class_names

0 commit comments

Comments
 (0)