Skip to content

Commit 13f8905

Browse files
new scaling test
1 parent 3d6f8fe commit 13f8905

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tests/test_scaling.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from src.preprocessing.scaling import discretize_power, normalize_power
6+
7+
# --------------------------------------------------------------------------- #
8+
# normalize_power #
9+
# --------------------------------------------------------------------------- #
10+
11+
12+
def test_normalize_power_basic():
13+
df = pd.DataFrame({"power": [0.0, 2.0, 4.0]})
14+
out = normalize_power(df.copy())
15+
16+
expected = pd.Series([0.0, 0.5, 1.0], name="power")
17+
pd.testing.assert_series_equal(out["power"], expected, check_dtype=False)
18+
19+
20+
def test_normalize_power_constant_values():
21+
df = pd.DataFrame({"power": [3.3, 3.3, 3.3]})
22+
out = normalize_power(df.copy(), eps=1e-12)
23+
24+
assert out["power"].eq(0.0).all()
25+
26+
27+
def test_normalize_power_empty_df_raises():
28+
empty = pd.DataFrame(columns=["power"])
29+
with pytest.raises(ValueError):
30+
normalize_power(empty)
31+
32+
33+
# --------------------------------------------------------------------------- #
34+
# discretize_power #
35+
# --------------------------------------------------------------------------- #
36+
37+
38+
@pytest.mark.parametrize(
39+
"value,expected_state",
40+
[
41+
(0.00, 0),
42+
(0.02, 1),
43+
(0.05, 2),
44+
(0.20, 4),
45+
(0.83, 9),
46+
],
47+
)
48+
def test_discretize_power_states(value, expected_state):
49+
df = pd.DataFrame({"power": [value]})
50+
out = discretize_power(df.copy())
51+
52+
assert out.loc[0, "state"] == expected_state
53+
54+
55+
def test_discretize_power_preserves_power_column():
56+
values = np.linspace(0, 1, 6)
57+
df = pd.DataFrame({"power": values})
58+
out = discretize_power(df.copy())
59+
60+
pd.testing.assert_series_equal(out["power"], pd.Series(values, name="power"))

0 commit comments

Comments
 (0)