Skip to content

Commit d7a4db6

Browse files
authored
Merge pull request #50 from has2k1/single-dispatch
Use singledispatch to generate ptypes and remove save_ptype parameter
2 parents 07741fe + 19cf095 commit d7a4db6

17 files changed

+212
-160
lines changed

examples/coffeeratings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
lr_fit = LinearRegression().fit(X_train, y_train)
1818

1919
# create vetiver model
20-
v = vetiver.VetiverModel(lr_fit, save_ptype = True, ptype_data=X_train, model_name = "v")
20+
v = vetiver.VetiverModel(lr_fit, ptype_data=X_train, model_name = "v")
2121

2222
# version model via pin
2323
from pins import board_folder

vetiver/handlers/_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
except ImportError:
88
torch_exists = False
99

10-
def create_translator(model, ptype_data, save_ptype):
10+
def create_translator(model, ptype_data):
1111
"""check for model type to handle prediction
1212
1313
Parameters
@@ -22,10 +22,10 @@ def create_translator(model, ptype_data, save_ptype):
2222
"""
2323
if torch_exists:
2424
if isinstance(model, torch.nn.Module):
25-
return pytorch_vt.TorchHandler(model, ptype_data, save_ptype)
25+
return pytorch_vt.TorchHandler(model, ptype_data)
2626

2727
if isinstance(model, sklearn.base.BaseEstimator):
28-
return sklearn_vt.SKLearnHandler(model, ptype_data, save_ptype)
28+
return sklearn_vt.SKLearnHandler(model, ptype_data)
2929

3030
else:
3131
raise NotImplementedError

vetiver/handlers/pytorch_vt.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ class TorchHandler:
1717
model : nn.Module
1818
a trained torch model
1919
"""
20-
def __init__(self, model, ptype_data, save_ptype):
20+
def __init__(self, model, ptype_data):
2121
self.model = model
2222
self.ptype_data = ptype_data
23-
self.save_ptype = save_ptype
2423

2524
def create_description(self):
2625
"""Create description for torch model
@@ -48,14 +47,13 @@ def ptype(self):
4847
----------
4948
ptype_data : pd.DataFrame, np.ndarray, or None
5049
Training data to create ptype
51-
save_ptype : bool
5250
5351
Returns
5452
-------
5553
ptype : pd.DataFrame or None
5654
Zero-row DataFrame for storing data types
5755
"""
58-
ptype = vetiver_create_ptype(self.ptype_data, self.save_ptype)
56+
ptype = vetiver_create_ptype(self.ptype_data)
5957

6058
return ptype
6159

@@ -90,15 +88,10 @@ def handler_predict(self, input_data, check_ptype):
9088
prediction = self.model(torch.from_numpy(input_data))
9189

9290
# do not check ptype
93-
else:
94-
batch = True
95-
if not isinstance(input_data, list):
96-
batch = False
97-
input_data = input_data.split(",") # user delimiter ?
98-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
99-
if not batch:
100-
input_data = input_data.reshape(1, -1)
101-
prediction = self.model(torch.from_numpy(input_data))
91+
else:
92+
input_data = torch.tensor(input_data)
93+
prediction = self.model(input_data)
94+
10295
else:
10396
raise ImportError("Cannot import `torch`.")
10497

vetiver/handlers/sklearn_vt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ class SKLearnHandler:
1212
a trained sklearn model
1313
"""
1414

15-
def __init__(self, model, ptype_data, save_ptype):
15+
def __init__(self, model, ptype_data):
1616
self.model = model
1717
self.ptype_data = ptype_data
18-
self.save_ptype = save_ptype
1918

2019
def create_description(self):
2120
"""Create description for sklearn model
@@ -42,14 +41,13 @@ def ptype(self):
4241
----------
4342
ptype_data : pd.DataFrame, np.ndarray, or None
4443
Training data to create ptype
45-
save_ptype : bool
4644
4745
Returns
4846
-------
4947
ptype : pd.DataFrame or None
5048
Zero-row DataFrame for storing data types
5149
"""
52-
ptype = vetiver_create_ptype(self.ptype_data, self.save_ptype)
50+
ptype = vetiver_create_ptype(self.ptype_data)
5351
return ptype
5452

5553
def handler_startup():

vetiver/pin_read_write.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool=True):
2929
type = "joblib",
3030
description = model.description,
3131
metadata = {"required_pkgs": model.metadata.get("required_pkgs"),
32-
"save_ptype": model.save_ptype,
3332
"ptype": None if model.ptype == None else model.ptype().json()},
3433
versioned=versioned
3534
)
@@ -79,7 +78,6 @@ def vetiver_pin_read(board, name: str, version: str = None) -> VetiverModel:
7978
url = meta.user.get("url"), # None all the time, besides Connect
8079
required_pkgs = meta.user.get("required_pkgs")
8180
),
82-
save_ptype=meta.user.get("save_ptype"),
8381
ptype_data = json.loads(meta.user.get("ptype")) if meta.user.get("ptype") else None,
8482
versioned = True
8583
)

vetiver/ptype.py

Lines changed: 145 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
from functools import singledispatch
2+
try:
3+
from types import NoneType
4+
except ImportError:
5+
# python < 3.10
6+
NoneType = type(None)
7+
18
import pandas as pd
29
import numpy as np
310
from pydantic import BaseModel, create_model
411

12+
513
class NoAvailablePTypeError(Exception):
614
"""
715
Throw an error if we cannot create
@@ -18,72 +26,173 @@ def __init__(
1826

1927
class InvalidPTypeError(Exception):
2028
"""
21-
Throw an error if `save_ptype` is not
22-
True, False, or data.frame
29+
Throw an error if ptype cannot be recognised
2330
"""
2431

2532
def __init__(
2633
self,
27-
message="The `ptype_data` argument must be a pandas.DataFrame, a pydantic BaseModel, np.ndarray, or `save_ptype` must be FALSE.",
34+
message="`ptype_data` must be a pd.DataFrame, a pydantic BaseModel or np.ndarray",
2835
):
2936
self.message = message
3037
super().__init__(self.message)
3138

3239

33-
def vetiver_create_ptype(ptype_data, save_ptype: bool):
40+
CREATE_PTYPE_TPL = """\
41+
Failed to create a data prototype (ptype) from data of \
42+
type {_data_type}. If your datatype is not one of \
43+
(pd.DataFrame, pydantic.BaseModel, np.ndarry, dict), \
44+
you should write a function to create the ptype. Here is \
45+
a template for such a function: \
46+
47+
from pydantic import create_model
48+
from vetiver.ptype import vetiver_create_ptype
49+
50+
@vetiver_create_ptype.register
51+
def _(data: {_data_type}):
52+
data_dict = ... # convert data to a dictionary
53+
ptype = create_model("ptype", **data_dict)
54+
return ptype
55+
56+
If your datatype is a common type, please consider submitting \
57+
a pull request.
58+
"""
59+
60+
@singledispatch
61+
def vetiver_create_ptype(data):
3462
"""Create zero row structure to save data types
63+
3564
Parameters
3665
----------
37-
ptype_data :
38-
Data that represents what
39-
save_ptype : bool
40-
Whether or not ptype should be created
66+
data : object
67+
An object with information (data) whose layout is to be determined.
4168
4269
Returns
4370
-------
44-
ptype
71+
ptype : pydantic.main.BaseModel
4572
Data prototype
4673
4774
"""
48-
ptype = None
75+
raise InvalidPTypeError(
76+
message=CREATE_PTYPE_TPL.format(_data_type=type(data))
77+
)
4978

50-
if save_ptype == False:
51-
pass
52-
elif save_ptype == True:
53-
try:
54-
if isinstance(ptype_data, np.ndarray):
55-
ptype = _array_to_ptype(ptype_data[1])
56-
elif isinstance(ptype_data, dict):
57-
ptype = _dict_to_ptype(ptype_data)
58-
elif isinstance(ptype_data.construct(), BaseModel):
59-
ptype = ptype_data
60-
except AttributeError: # cannot construct basemodel
61-
if isinstance(ptype_data, pd.DataFrame):
62-
ptype = _df_to_ptype(ptype_data.iloc[1, :])
63-
else:
64-
raise InvalidPTypeError
6579

80+
@vetiver_create_ptype.register
81+
def _(data: pd.DataFrame):
82+
"""
83+
Create ptype for a pandas dataframe
84+
85+
Parameters
86+
----------
87+
data : DataFrame
88+
Pandas dataframe
89+
90+
Examples
91+
--------
92+
>>> from pydantic import BaseModel
93+
>>> df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
94+
>>> prototype = vetiver_create_ptype(df)
95+
>>> issubclass(prototype, BaseModel)
96+
True
97+
>>> prototype()
98+
ptype(x=1, y=4)
99+
100+
The data prototype created for the dataframe is equivalent to:
101+
102+
>>> class another_prototype(BaseModel):
103+
... class Config:
104+
... title = 'ptype'
105+
... x: int = 1
106+
... y: int = 4
107+
108+
>>> another_prototype()
109+
another_prototype(x=1, y=4)
110+
>>> another_prototype() == prototype()
111+
True
112+
113+
Changing the title using `class Config` ensures that the
114+
also json/schemas match.
115+
116+
>>> another_prototype.schema() == prototype.schema()
117+
True
118+
"""
119+
dict_data = data.iloc[0, :].to_dict()
120+
ptype = create_model("ptype", **dict_data)
66121
return ptype
67122

68123

69-
def _df_to_ptype(train_data):
124+
@vetiver_create_ptype.register
125+
def _(data: np.ndarray):
126+
"""
127+
Create ptype for a numpy array
70128
71-
dict_data = train_data.to_dict()
72-
ptype = create_model("ptype", **dict_data)
129+
Parameters
130+
----------
131+
data : ndarray
132+
2-Dimensional numpy array
133+
134+
Examples
135+
--------
136+
>>> arr = np.array([[1, 4], [2, 5], [3, 6]])
137+
>>> prototype = vetiver_create_ptype(arr)
138+
>>> prototype()
139+
ptype(0=1, 1=4)
140+
141+
>>> arr2 = np.array([[1, 'a'], [2, 'b'], [3, 'c']], dtype=object)
142+
>>> prototype2 = vetiver_create_ptype(arr2)
143+
>>> prototype2()
144+
ptype(0=1, 1='a')
145+
"""
146+
def _item(value):
147+
# pydantic needs python objects. .item() converts a numpy
148+
# scalar type to a python equivalent, and if the ndarray
149+
# is dtype=object, it may have python objects
150+
try:
151+
return value.item()
152+
except AttributeError:
153+
return value
73154

155+
dict_data = dict(enumerate(data[0], 0))
156+
# pydantic requires strings as indicies
157+
dict_data = {f"{key}": _item(value) for key, value in dict_data.items()}
158+
ptype = create_model("ptype", **dict_data)
74159
return ptype
75160

76161

77-
def _array_to_ptype(train_data):
78-
dict_data = dict(enumerate(train_data, 0))
162+
@vetiver_create_ptype.register
163+
def _(data: dict):
164+
"""
165+
Create ptype for a dict
79166
80-
# pydantic requires strings as indicies
81-
dict_data = {str(key): value.item() for key, value in dict_data.items()}
82-
ptype = create_model("ptype", **dict_data)
167+
Parameters
168+
----------
169+
data : dict
170+
Dictionary
171+
"""
172+
return create_model("ptype", **data)
83173

84-
return ptype
85174

175+
@vetiver_create_ptype.register
176+
def _(data: BaseModel):
177+
"""
178+
Create ptype for a pydantic BaseModel object
86179
87-
def _dict_to_ptype(train_data):
180+
Parameters
181+
----------
182+
data : pydantic.BaseModel
183+
Pydantic BaseModel
184+
"""
185+
return data
88186

89-
return create_model("ptype",**train_data)
187+
188+
@vetiver_create_ptype.register
189+
def _(data: NoneType):
190+
"""
191+
Create ptype for None
192+
193+
Parameters
194+
----------
195+
data : None
196+
None
197+
"""
198+
return None

0 commit comments

Comments
 (0)