Skip to content

Commit 5e05568

Browse files
committed
lint
1 parent 74f9a44 commit 5e05568

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
types:
1111
- python
1212
args:
13-
- "--max-line-length=90"
13+
- "--max-line-length=100"
1414
- id: trailing-whitespace
1515
- id: end-of-file-fixer
1616
- id: check-yaml

vetiver/handlers/sklearn.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def handler_predict(self, input_data, check_prototype: bool, **kw):
3030
Test data
3131
check_prototype: bool
3232
prediction_type: str
33-
Type of prediction to make. One of "predict", "predict_proba", or "predict_log_proba".
34-
Default is "predict".
33+
Type of prediction to make. One of "predict", "predict_proba",
34+
or "predict_log_proba". Default is "predict".
3535
3636
Returns
3737
-------
@@ -40,7 +40,14 @@ def handler_predict(self, input_data, check_prototype: bool, **kw):
4040
"""
4141
prediction_type = kw.get("prediction_type", "predict")
4242
if prediction_type not in ["predict", "predict_proba", "predict_log_proba"]:
43-
raise ValueError('prediction_type must be "predict", "predict_proba", or "predict_log_proba"')
44-
45-
input_data = [input_data] if check_prototype and not isinstance(input_data, pd.DataFrame) else input_data
43+
raise ValueError(
44+
'prediction_type must be "predict", "predict_proba", \
45+
or "predict_log_proba"'
46+
)
47+
48+
input_data = (
49+
[input_data]
50+
if check_prototype and not isinstance(input_data, pd.DataFrame)
51+
else input_data
52+
)
4653
return getattr(self.model, prediction_type)(input_data).tolist()

vetiver/handlers/statsmodels.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def handler_predict(self, input_data, check_prototype, **kw):
4242
"""
4343
if not sm_exists:
4444
raise ImportError("Cannot import `statsmodels`")
45-
46-
input_data = input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data]
45+
46+
input_data = (
47+
input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data]
48+
)
4749
return self.model.predict(input_data).tolist()

0 commit comments

Comments
 (0)