Skip to content

Commit 43299fa

Browse files
committed
tests: updated test files for model wrapper functionality
1 parent 6f0db92 commit 43299fa

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

tests/test_linear_model_wrapper.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from app.models.linear_model_wrapper import LinearModelWrapper
5+
6+
7+
def test_linear_wrapper_fit_and_predict():
8+
"""
9+
Test that LinearModelWrapper can fit a model and return predictions.
10+
Ensures the fitted model exposes .predict() and matches expected length.
11+
"""
12+
X = pd.DataFrame({"x1": np.random.randn(100)})
13+
y = 3 * X["x1"] + np.random.randn(100)
14+
15+
model = LinearModelWrapper(X, y).fit()
16+
preds = model.predict()
17+
18+
assert len(preds) == len(y)
19+
assert hasattr(model, "model")
20+
21+
22+
def test_linear_wrapper_residuals_and_fitted():
23+
"""
24+
Verify that residuals + fitted values approximately equal the true target.
25+
Confirms internal math and data shape integrity.
26+
"""
27+
X = pd.DataFrame({"x1": np.random.randn(100)})
28+
y = 2 * X["x1"] + np.random.randn(100)
29+
30+
model = LinearModelWrapper(X, y).fit()
31+
residuals = model.residuals()
32+
fitted = model.fitted()
33+
34+
# residuals = y - y_pred
35+
np.testing.assert_allclose(y.values, residuals + fitted, rtol=1e-4)
36+
37+
38+
def test_linear_wrapper_summary():
39+
"""
40+
Confirm the summary() method returns expected keys and types.
41+
"""
42+
X = pd.DataFrame({"x1": np.random.randn(50)})
43+
y = X["x1"] + np.random.randn(50)
44+
45+
model = LinearModelWrapper(X, y).fit()
46+
summary = model.summary()
47+
48+
assert "model_type" in summary
49+
assert summary["model_type"].lower() == "linear regression"
50+
assert 0 <= summary["r_squared"] <= 1

tests/test_linearity.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# tests/test_linearity.py
2+
import numpy as np
3+
import pandas as pd
4+
25
from app.core import linearity
6+
from app.core.linearity import check_linearity
37
from app.data import simulated_data
8+
from app.models.linear_model_wrapper import LinearModelWrapper
49

510

611
def test_linearity_r_squared_threshold():
@@ -21,3 +26,15 @@ def test_linearity_plot_generation():
2126
result = linearity.check_linearity(df["x"], df["y"], return_plot=True)
2227
assert result.plot_base64 is not None
2328
assert result.plot_base64.startswith("iVBOR") # PNG header in base64
29+
30+
31+
def test_linearity_with_model_wrapper():
32+
"""
33+
Ensure check_linearity works when a pre-fit model_wrapper is provided.
34+
"""
35+
X = pd.DataFrame({"x1": np.random.randn(100)})
36+
y = 2 * X["x1"] + np.random.randn(100)
37+
wrapper = LinearModelWrapper(X, y).fit()
38+
39+
result = check_linearity(X, y, model_wrapper=wrapper)
40+
assert "r_squared" in result.details

tests/test_model_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from app.models.linear_model_wrapper import LinearModelWrapper
6+
from app.models.utils import get_model_wrapper
7+
8+
9+
def test_get_model_wrapper_linear():
10+
"""
11+
Verify get_model_wrapper returns a LinearModelWrapper for 'linear' input.
12+
"""
13+
X = pd.DataFrame({"x1": np.random.randn(30)})
14+
y = 2 * X["x1"] + np.random.randn(30)
15+
16+
wrapper = get_model_wrapper("linear", X, y)
17+
assert isinstance(wrapper, LinearModelWrapper)
18+
assert hasattr(wrapper, "predict")
19+
assert hasattr(wrapper, "residuals")
20+
21+
22+
def test_get_model_wrapper_invalid_type():
23+
"""
24+
Confirm get_model_wrapper raises ValueError for unknown model_type.
25+
"""
26+
X = pd.DataFrame({"x1": np.random.randn(30)})
27+
y = 2 * X["x1"] + np.random.randn(30)
28+
29+
with pytest.raises(ValueError, match="Unsupported model type"):
30+
get_model_wrapper("invalid_type", X, y)

0 commit comments

Comments
 (0)