Skip to content

Commit 6f0db92

Browse files
committed
feat(core): introduce model abstraction layer with BaseModelWrapper
- Added BaseModelWrapper and LinearModelWrapper to support model-agnostic diagnostics
- Centralized model fitting using get_model_wrapper(), fit once for all checks
- Updated linearity check to use model wrapper and fallback gracefully
- Introduced --model-type CLI flag for model selection (currently supports 'linear')
- Enabled filtering of assumption checks by applicable model types via @register_assumption
- Printed model metadata in console report output (e.g. 'Model Type: Linear Regression')
- Reordered function parameters for clarity and consistency across checks
- Removed redundant classmethod import and updated docstrings for clarity
1 parent 986332e commit 6f0db92

File tree

8 files changed

+115
-22
lines changed

8 files changed

+115
-22
lines changed

app/core/dispatcher.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# app/core/dispatcher.py
2-
from typing import Dict
2+
from typing import Dict, Tuple
33

44
import pandas as pd
55

@@ -9,6 +9,8 @@
99
from app.core import normality # noqa: F401
1010
from app.core.registry import ASSUMPTION_CHECKS
1111
from app.core.types import AssumptionResult
12+
from app.models.base_model_wrapper import BaseModelWrapper
13+
from app.models.utils import get_model_wrapper
1214

1315
__all__ = ["check_assumption", "run_all_checks"]
1416

@@ -21,7 +23,7 @@ def check_assumption(
2123
2224
Args:
2325
name (str): assumption name
24-
X (pd.Series): Predictor (1D)
26+
X (pd.Series or pd.DataFrame): Predictor values (1D or multivariate)
2527
y (pd.Series): Response (1D)
2628
return_plot (bool, optional): Whether to return base64-encoded
2729
PNG of the plot. Defaults to False.
@@ -42,13 +44,13 @@ def check_assumption(
4244

4345

4446
def run_all_checks(
45-
X: pd.Series, y: pd.Series, return_plot: bool = False
46-
) -> Dict[str, AssumptionResult]:
47+
X: pd.Series, y: pd.Series, model_type=None, return_plot: bool = False
48+
) -> Tuple[Dict[str, AssumptionResult], BaseModelWrapper]:
4749
"""
4850
Run all registered assumption checks and return a dictionary of results.
4951
5052
Args:
51-
X (pd.Series): Predictor (1D)
53+
X (pd.Series or pd.DataFrame): Predictor values (1D or multivariate)
5254
y (pd.Series): Response (1D)
5355
return_plot (bool, optional): Whether to return base64-encoded
5456
PNG of the plot. Defaults to False.
@@ -62,6 +64,10 @@ def run_all_checks(
6264
if isinstance(X, pd.Series):
6365
X = X.to_frame()
6466

67+
model_wrapper = get_model_wrapper(model_type, X, y)
68+
6569
for name, func in ASSUMPTION_CHECKS.items():
66-
results[name] = func(X, y, return_plot)
67-
return results
70+
if model_type not in getattr(func, "_model_types", ["linear"]):
71+
continue
72+
results[name] = func(X, y, model_wrapper=model_wrapper, return_plot=return_plot)
73+
return results, model_wrapper

app/core/linearity.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import matplotlib.pyplot as plt
1111
import pandas as pd
12-
from sklearn.linear_model import LinearRegression
1312
from sklearn.metrics import r2_score
1413

1514
from app.config import LINEARITY_R2_THRESHOLD, R2_SEVERITY_THRESHOLDS
@@ -20,9 +19,9 @@
2019
__all__ = ["check_linearity"]
2120

2221

23-
@register_assumption("linearity")
22+
@register_assumption("linearity", model_types=["linear"])
2423
def check_linearity(
25-
X: pd.Series, y: pd.Series, return_plot: bool = False
24+
X: pd.Series, y: pd.Series, return_plot: bool = False, model_wrapper=None
2625
) -> AssumptionResult:
2726
"""
2827
Check linearity assumption using:
@@ -59,11 +58,15 @@ def check_linearity(
5958
)
6059
X = X.iloc[:, 0] # Convert to Series
6160

61+
# Fall back to fitting a linear model when no model_wrapper is supplied
62+
if model_wrapper is None:
63+
from app.models.utils import get_model_wrapper
64+
65+
model_wrapper = get_model_wrapper("linear", X, y)
66+
6267
# Fit simple linear model to input data
63-
X_reshaped = X.values.reshape(-1, 1)
64-
model = LinearRegression().fit(X_reshaped, y)
65-
y_pred = model.predict(X_reshaped)
66-
residuals = y - y_pred
68+
residuals = model_wrapper.residuals()
69+
y_pred = model_wrapper.fitted()
6770

6871
# Coefficient of determination (R²) measures goodness of fit
6972
r2 = r2_score(y, y_pred)

app/core/registry.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
AssumptionCheck = Callable[[pd.Series, pd.Series, bool], AssumptionResult]
1515

1616

17-
def register_assumption(name: str) -> Callable[[AssumptionCheck], AssumptionCheck]:
17+
def register_assumption(
18+
name: str, model_types: list = ["linear"]
19+
) -> Callable[[AssumptionCheck], AssumptionCheck]:
1820
"""
1921
Decorator to register an assumption check function under a given name.
2022
@@ -26,6 +28,8 @@ def register_assumption(name: str) -> Callable[[AssumptionCheck], AssumptionChec
2628
"""
2729

2830
def decorator(func: AssumptionCheck) -> AssumptionCheck:
31+
func._assumption_name = name
32+
func._model_types = model_types
2933
ASSUMPTION_CHECKS[name] = func
3034
return func
3135

app/models/base_model_wrapper.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class BaseModelWrapper(ABC):
    """Common interface for the fitted models used by the assumption checks.

    A wrapper stores the training data and, once ``fit`` has run, the
    underlying fitted model. Concrete subclasses provide the model-specific
    logic; diagnostics only rely on the methods declared here.

    Attributes:
        X: Predictor data the model is fitted on.
        y: Response data the model is fitted on.
        model: The fitted model object, or ``None`` until ``fit`` is called.
    """

    def __init__(self, X, y):
        """Store the training data without fitting anything yet.

        Args:
            X: Predictor data.
            y: Response data.
        """
        self.X = X
        self.y = y
        # Declared here so "not fitted yet" is an explicit, inspectable state
        # rather than a missing attribute on the instance.
        self.model = None

    @abstractmethod
    def fit(self):
        """Fit the underlying model on ``self.X`` and ``self.y``.

        Implementations are expected to return the wrapper itself so that
        construction and fitting can be chained.
        """
        ...

    @abstractmethod
    def predict(self):
        """Return the model's predictions for the stored predictors."""
        ...

    @abstractmethod
    def residuals(self):
        """Return the residuals of the fitted model."""
        ...

    @abstractmethod
    def fitted(self):
        """Return the fitted values of the model."""
        ...

    def summary(self):
        """Return model metadata as a dictionary.

        The base implementation has nothing to report and returns an empty
        dict; subclasses override it to expose details such as the model type.
        """
        return {}

app/models/linear_model_wrapper.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import statsmodels.api as sm
2+
3+
from app.models.base_model_wrapper import BaseModelWrapper
4+
5+
6+
class LinearModelWrapper(BaseModelWrapper):
    """Linear regression wrapper fitted with statsmodels OLS.

    The design matrix is the stored predictors passed through
    ``sm.add_constant``, so the same matrix is used for fitting and for
    prediction.
    """

    def __init__(self, X, y):
        """Store the training data; the model stays unfitted until ``fit``.

        Args:
            X: Predictor data.
            y: Response data.
        """
        super().__init__(X, y)
        # Explicit "not fitted" marker checked by _fitted_model().
        self.model = None

    def _design_matrix(self):
        """Return the stored predictors passed through ``sm.add_constant``."""
        return sm.add_constant(self.X)

    def _fitted_model(self):
        """Return the fitted results object, failing clearly if unfitted.

        Raises:
            AttributeError: If ``fit`` has not been called. The type matches
                what an access to the missing ``model`` attribute raised
                before this guard existed, so existing handlers keep working.
        """
        if self.model is None:
            raise AttributeError(
                "LinearModelWrapper has not been fitted; call fit() first."
            )
        return self.model

    def fit(self):
        """Fit OLS of ``y`` on the design matrix and return the wrapper."""
        self.model = sm.OLS(self.y, self._design_matrix()).fit()
        return self

    def predict(self):
        """Return the fitted model's predictions for the stored predictors."""
        return self._fitted_model().predict(self._design_matrix())

    def residuals(self):
        """Return the residuals (``resid``) of the fitted model."""
        return self._fitted_model().resid

    def fitted(self):
        """Return the fitted values (``fittedvalues``) of the fitted model."""
        return self._fitted_model().fittedvalues

    def summary(self):
        """Return the model type label and the R-squared of the fit."""
        return {
            "model_type": "Linear Regression",
            "r_squared": self._fitted_model().rsquared,
        }

app/models/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from app.models.base_model_wrapper import BaseModelWrapper
2+
from app.models.linear_model_wrapper import LinearModelWrapper
3+
4+
5+
def get_model_wrapper(model_type: str, X, y) -> BaseModelWrapper:
    """Build and fit the model wrapper for the requested model type.

    Args:
        model_type: Identifier of the model to fit. Only ``"linear"`` is
            currently supported.
        X: Predictor data passed to the wrapper.
        y: Response data passed to the wrapper.

    Returns:
        The fitted wrapper for ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    # Dispatch table: supporting a new model means adding one entry here.
    # Every supported type maps to a real wrapper class, so the function can
    # never fall through and hand ``None`` back to the caller.
    wrappers = {"linear": LinearModelWrapper}

    try:
        wrapper_cls = wrappers[model_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are just as unsupported.
        raise ValueError(f"Unsupported model type: {model_type}") from None

    return wrapper_cls(X, y).fit()

app/report.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
def generate_report(
1616
X,
1717
y,
18+
model_type=None,
1819
return_plot: bool = False,
1920
output_format: str = "console",
2021
verbose: bool = False,
@@ -23,7 +24,7 @@ def generate_report(
2324
Generate an assumption diagnostic report using the registered checks.
2425
2526
Args:
26-
X (pd.Series): Predictor values.
27+
X (pd.Series or pd.DataFrame): Predictor values (1D or multivariate)
2728
y (pd.Series): Response values.
2829
return_plot (bool, optional): Include base64-encoded plots in results.
2930
output_format (str): 'console', 'json', or 'markdown'.
@@ -32,10 +33,12 @@ def generate_report(
3233
Raises:
3334
ValueError: If the output_format is not recognized.
3435
"""
35-
results = run_all_checks(X, y, return_plot=return_plot)
36+
results, model_wrapper = run_all_checks(
37+
X, y, model_type=model_type, return_plot=return_plot
38+
)
3639

3740
if output_format == "console":
38-
print_console_report(results, verbose=verbose)
41+
print_console_report(results, model_wrapper=model_wrapper, verbose=verbose)
3942
elif output_format == "json":
4043
export_to_json(results)
4144
elif output_format == "markdown":
@@ -44,7 +47,7 @@ def generate_report(
4447
raise ValueError("Unsupported output format")
4548

4649

47-
def print_console_report(results, verbose: bool = False):
50+
def print_console_report(results, model_wrapper, verbose: bool = False):
4851
"""
4952
Print a structured Rich panel for each assumption result.
5053
@@ -54,6 +57,11 @@ def print_console_report(results, verbose: bool = False):
5457
"""
5558
console = Console()
5659
console.rule("[bold yellow]Assumption Check Report")
60+
61+
# Print model metadata
62+
model_info = model_wrapper.summary().get("model_type", "Unknown")
63+
console.print(f"[bold cyan]Model Type:[/bold cyan] {model_info}")
64+
5765
for name, result in results.items():
5866

5967
# Determine pass/fail icon and panel title
@@ -198,13 +206,22 @@ def export_to_markdown(results, filename: str = None):
198206

199207

200208
if __name__ == "__main__":
201-
parser = argparse.ArgumentParser(description="Run regression assumption checks.")
209+
210+
parser = argparse.ArgumentParser(
211+
description="Run statistical assumption checks for supervised models."
212+
)
202213
parser.add_argument(
203214
"--data",
204215
choices=list_simulations().keys(),
205216
default="linear",
206217
help="Which simulated dataset to run assumption checks on.",
207218
)
219+
parser.add_argument(
220+
"--model-type",
221+
choices=["linear"],
222+
default="linear",
223+
help="Which model to fit for diagnostics.",
224+
)
208225
parser.add_argument(
209226
"--format",
210227
choices=["console", "json", "markdown"],
@@ -222,6 +239,10 @@ def export_to_markdown(results, filename: str = None):
222239

223240
args = parser.parse_args()
224241

242+
diagnostic_context = {
243+
"model_type": args.model_type,
244+
}
245+
225246
data_func = list_simulations()[args.data]
226247
df = data_func(seed=42)
227248

@@ -230,5 +251,10 @@ def export_to_markdown(results, filename: str = None):
230251
y = df["y"]
231252

232253
generate_report(
233-
X, y, return_plot=args.plot, output_format=args.format, verbose=args.verbose
254+
X,
255+
y,
256+
model_type=args.model_type,
257+
return_plot=args.plot,
258+
output_format=args.format,
259+
verbose=args.verbose,
234260
)

tests/test_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_dispatch_all_assumptions():
2020
Test dispatcher's run_all_checks().
2121
"""
2222
df = simulated_data.generate_linear_data(n_samples=300, seed=42)
23-
results = dispatcher.run_all_checks(df["x"], df["y"])
23+
results, _ = dispatcher.run_all_checks(df["x"], df["y"], model_type="linear")
2424
assert "linearity" in results
2525
assert "homoscedasticity" in results
2626
assert results["linearity"].passed

0 commit comments

Comments
 (0)