Skip to content

Commit aa7855b

Browse files
this works-- pip install -r requirements; pip install -e .; pytest
1 parent 905b53f commit aa7855b

File tree

5 files changed

+313
-71
lines changed

5 files changed

+313
-71
lines changed

condo/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .adapter_mmd import AdapterMMD
2+
from .adapter_wd import AdapterWD
3+
from .adapter_gaussian_ot import AdapterGaussianOT
4+
from .condo_adapter_kld import ConDoAdapterKLD
5+
from .condo_adapter_mmd import ConDoAdapterMMD
6+
from .condo_adapter_wd import ConDoAdapterWD
7+
from .product_prior import product_prior
8+
9+
__version__ = "0.8.0"

condo/utils.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
from copy import deepcopy
2+
3+
import miceforest as mf
4+
import numpy as np
5+
import sklearn.utils as skut
6+
import torch
7+
8+
9+
class AdapterDataset(torch.utils.data.Dataset):
    """Paired source/target dataset backed by two pre-computed arrays.

    Sample ``i`` is the pair ``(S_list[i], T_list[i])``. Both arrays are
    converted to tensors with ``torch.from_numpy`` and must have the same
    length along their first axis.
    """

    def __init__(
        self,
        S_list: np.ndarray,
        T_list: np.ndarray,
    ):
        # Each array has n_bootstraps * bootsize entries along axis 0, and each
        # entry has shape (n_mice_impute, d). Only the entry counts have to
        # agree; the full shapes are deliberately not compared.
        n_source, n_target = S_list.shape[0], T_list.shape[0]
        assert n_source == n_target
        self.S_list = torch.from_numpy(S_list)
        self.T_list = torch.from_numpy(T_list)

    def __len__(self):
        """Return the number of paired samples."""
        return self.S_list.size(0)

    def __getitem__(self, idx):
        """Return the ``idx``-th (source, target) pair of 2-D matrices.

        One pair is a single "sample": the MMD is computed between the two
        matrices, and a batch loss is the sum over the samples in the batch.
        """
        source_sample = self.S_list[idx, :, :]
        target_sample = self.T_list[idx, :, :]
        return source_sample, target_sample

    def dtype(self):
        """Return the torch dtype of the stored source tensor."""
        return self.S_list.dtype
32+
33+
34+
class AdapterDatasetConDo(torch.utils.data.Dataset):
    """Dataset whose samples are generated on demand by MICE imputation.

    For each index, ``batch_size`` rows of ``Z_test_`` are drawn with
    probabilities ``W_test``. Feature values for those rows are then imputed
    ``n_mice_impute`` times with a miceforest ``ImputationKernel``, once
    fitted alongside the source data ``(Xs, Zs_)`` and once alongside the
    target data ``(Xt, Zt_)``.
    """

    def __init__(
        self,
        Xs,
        Xt,
        Zs_,
        Zt_,
        Z_test_,
        W_test,
        n_mice_impute,
        n_mice_iters,
        n_samples,
        batch_size,
    ):
        # Data arrays.
        self.Xs, self.Xt = Xs, Xt
        self.Zs_, self.Zt_ = Zs_, Zt_
        self.Z_test_, self.W_test = Z_test_, W_test
        # Imputation / sampling configuration.
        self.n_mice_impute = n_mice_impute
        self.n_mice_iters = n_mice_iters
        self.n_samples = n_samples
        self.batch_size = batch_size
        # torch dtype corresponding to the numpy dtype of Xs.
        self.mydtype = torch.from_numpy(Xs).dtype

    def __len__(self):
        """Return the configured number of samples."""
        return self.n_samples

    def _impute_batch(self, X, Z_, bZ_test_, batch_size, d, dtype, seed):
        """Impute the first ``d`` columns for the rows of ``bZ_test_``.

        The observed rows ``[X | Z_]`` are stacked on top of rows whose first
        ``d`` columns are NaN and whose remaining columns are ``bZ_test_``.
        Returns an array of shape ``(batch_size, n_mice_impute, d)`` holding
        the imputed values from every completed dataset.
        """
        observed_rows = np.concatenate([X, Z_], axis=1)
        missing_rows = np.concatenate(
            [np.full((batch_size, d), np.nan), bZ_test_], axis=1
        )
        kernel = mf.ImputationKernel(
            np.concatenate([observed_rows, missing_rows]),
            datasets=self.n_mice_impute,
            save_all_iterations=False,
            random_state=seed,
        )
        kernel.mice(self.n_mice_iters)

        n_observed = X.shape[0]
        completed = np.zeros((batch_size, self.n_mice_impute, d), dtype=dtype)
        for imp in range(self.n_mice_impute):
            # Keep only the appended (originally missing) rows and the
            # feature columns.
            completed[:, imp, :] = kernel.complete_data(dataset=imp)[n_observed:, :d]
        return completed

    def __getitem__(self, idx):
        """Return a ``(S_complete, T_complete)`` tensor pair for ``idx``.

        ``idx`` seeds both the row sampling and the source imputer; the target
        imputer is seeded with ``idx + 1234``. The feature count and output
        dtype are taken from ``Xs`` for both domains.
        """
        batch_size = self.batch_size
        dtype = self.Xs.dtype
        rng = skut.check_random_state(idx)
        d = self.Xs.shape[1]

        chosen_rows = rng.choice(
            self.Z_test_.shape[0], size=batch_size, p=self.W_test.ravel()
        )
        bZ_test_ = self.Z_test_[chosen_rows, :]

        S_complete = self._impute_batch(
            self.Xs, self.Zs_, bZ_test_, batch_size, d, dtype, idx
        )
        T_complete = self._impute_batch(
            self.Xt, self.Zt_, bZ_test_, batch_size, d, dtype, idx + 1234
        )
        return torch.from_numpy(S_complete), torch.from_numpy(T_complete)

    def dtype(self):
        """Return the torch dtype derived from ``Xs`` at construction time."""
        return self.mydtype
112+
113+
114+
class EarlyStopping:
    """Track the best loss seen so far and flag when training should stop.

    Call the instance once per epoch. Whenever the loss strictly improves, the
    model's ``state_dict`` is deep-copied and the patience counter resets;
    otherwise the counter increments. ``early_stop`` becomes ``True`` once the
    counter reaches ``patience``.

    Attributes:
        patience: number of consecutive non-improving calls tolerated.
        counter: current number of consecutive non-improving calls.
        early_stop: whether the patience budget has been exhausted.
        loss_min: lowest loss observed so far (``inf`` before the first call).
        epoch_min: epoch at which ``loss_min`` was observed, or ``None`` if no
            improvement has been recorded yet.
        state_dict: deep copy of the model state at the best loss, or of the
            initial model when one is passed to the constructor.
    """

    def __init__(self, patience, model=None):
        self.patience = patience
        self.counter = 0
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.loss_min = np.inf
        # Defined up front so reading it before any improvement does not
        # raise AttributeError.
        self.epoch_min = None
        self.state_dict = None
        if model is not None:
            self.state_dict = deepcopy(model.state_dict())

    def __call__(self, loss, model, epoch):
        """Record ``loss`` for ``epoch`` and update the stopping state."""
        if loss < self.loss_min:
            self.loss_min = loss
            self.epoch_min = epoch
            # Deep copy so later optimizer steps cannot alter the snapshot.
            self.state_dict = deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
134+
135+
136+
class LinearAdapter(torch.nn.Module):
    """Learnable linear map applied to the last axis of a 3-D tensor.

    Two parameterisations are supported:

    * ``"location-scale"``: per-feature scale ``M`` and shift ``b``, both of
      length ``in_features`` (requires ``in_features == out_features``).
    * ``"affine"``: matrix ``M`` of shape ``(out_features, in_features)`` and
      shift ``b`` of length ``out_features``.

    Parameters are initialised to the identity map (ones / ``eye_`` for ``M``,
    zeros for ``b``).

    Raises:
        ValueError: if ``transform_type`` is not one of the two names above.
    """

    def __init__(
        self,
        transform_type: str,
        in_features: int,
        out_features: int,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.transform_type = transform_type
        self.in_features = in_features
        self.out_features = out_features

        if transform_type == "location-scale":
            assert in_features == out_features
            num_feats = in_features
            self.M = torch.nn.Parameter(torch.empty(num_feats, **factory_kwargs))
            self.b = torch.nn.Parameter(torch.empty(num_feats, **factory_kwargs))
        elif transform_type == "affine":
            self.M = torch.nn.Parameter(
                torch.empty((out_features, in_features), **factory_kwargs)
            )
            self.b = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            raise ValueError(f"invalid transform_type:{transform_type}")
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset ``M`` and ``b`` so that the adapter is the identity map."""
        if self.transform_type == "location-scale":
            torch.nn.init.ones_(self.M)
            torch.nn.init.zeros_(self.b)
        elif self.transform_type == "affine":
            torch.nn.init.eye_(self.M)
            torch.nn.init.zeros_(self.b)

    def forward(self, S: torch.Tensor) -> torch.Tensor:
        """Apply the map to every row of ``S``.

        ``S`` must be 3-D, ``(batch_size, n_mice_impute, ds)``. The first two
        axes are flattened for the transform and restored afterwards.
        """
        (batch_size, n_mice_impute, ds) = S.shape
        S_ = S.reshape(-1, ds)
        shift = self.b.reshape(1, -1)
        if self.transform_type == "location-scale":
            adaptedSsample = S_ * self.M.reshape(1, -1) + shift
        elif self.transform_type == "affine":
            adaptedSsample = S_ @ self.M.T + shift
        else:
            # Only reachable if transform_type was changed after construction;
            # fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"invalid transform_type:{self.transform_type}")
        return adaptedSsample.reshape(batch_size, n_mice_impute, -1)

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module repr."""
        return (
            f"transform_type={self.transform_type}, "
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}"
        )

    def get_M_b(self):
        """Return the current ``(M, b)`` parameters as NumPy arrays.

        The tensors are moved to the CPU first, because ``Tensor.numpy()``
        only works on CPU tensors and the constructor allows other devices.
        """
        best_M = self.M.detach().cpu().numpy()
        best_b = self.b.detach().cpu().numpy()
        return (best_M, best_b)
195+
196+
197+
"""
198+
class LinearAdapter(torch.nn.Module):
199+
def __init__(
200+
self,
201+
transform_type: str,
202+
num_feats: int,
203+
device=None,
204+
dtype=None,
205+
) -> None:
206+
factory_kwargs = {"device": device, "dtype": dtype}
207+
super().__init__()
208+
self.transform_type = transform_type
209+
self.num_feats = num_feats
210+
211+
if transform_type == "location-scale":
212+
self.M = torch.nn.Parameter(torch.empty(num_feats, **factory_kwargs))
213+
self.b = torch.nn.Parameter(torch.empty(num_feats, **factory_kwargs))
214+
215+
elif transform_type == "affine":
216+
self.M = torch.nn.Parameter(
217+
torch.empty((num_feats, num_feats), **factory_kwargs)
218+
)
219+
self.b = torch.nn.Parameter(torch.empty(num_feats, **factory_kwargs))
220+
else:
221+
raise ValueError(f"invalid transform_type:{transform_type}")
222+
self.reset_parameters()
223+
224+
def reset_parameters(self) -> None:
225+
if self.transform_type == "location-scale":
226+
torch.nn.init.zeros_(self.M)
227+
torch.nn.init.zeros_(self.b)
228+
elif self.transform_type == "affine":
229+
torch.nn.init.zeros_(self.M)
230+
torch.nn.init.zeros_(self.b)
231+
232+
def forward(self, S: torch.Tensor) -> torch.Tensor:
233+
if self.transform_type == "location-scale":
234+
adaptedSsample = S * self.M.reshape(1, -1) + self.b.reshape(1, -1) + S
235+
elif self.transform_type == "affine":
236+
adaptedSsample = S @ self.M.T + self.b.reshape(1, -1) + S
237+
return adaptedSsample
238+
239+
def extra_repr(self) -> str:
240+
return "transform_type={}, num_feats={}".format(
241+
self.transform_type,
242+
self.num_feats,
243+
)
244+
245+
def get_M_b(self):
246+
best_M = self.M.detach().numpy()
247+
best_b = self.b.detach().numpy()
248+
if best_M.ndim == 1:
249+
best_M = best_M + 1.
250+
else:
251+
best_M = best_M + np.eye(self.num_feats, dtype=best_M.dtype)
252+
return (best_M, best_b)
253+
"""
254+
255+
256+
class RBF(torch.nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated between all rows of one tensor.

    Adapted from https://github.com/yiftachbeer/mmd_loss_pytorch
    """

    def __init__(self, n_kernels=1, mul_factor=2.0, bandwidth=None):
        super().__init__()
        # XXX n_kernels > 1 causes a segfault at torch.exp with torch==2.1.2 and numpy==1.26.3
        exponents = torch.arange(n_kernels) - n_kernels // 2
        self.bandwidth_multipliers = mul_factor ** exponents
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the fixed bandwidth, or derive one from ``L2_distances``.

        When no bandwidth was configured, the sum of all entries is divided by
        the number of off-diagonal entries, ``n * n - n``.
        """
        if self.bandwidth is not None:
            return self.bandwidth

        n_rows = L2_distances.shape[0]
        return L2_distances.data.sum() / (n_rows ** 2 - n_rows)

    def forward(self, X):
        """Return the ``(n, n)`` kernel matrix for the ``n`` rows of ``X``."""
        sq_dists = torch.cdist(X, X) ** 2
        # The bandwidth is computed from detached distances.
        base_bandwidth = self.get_bandwidth(sq_dists.detach())
        scales = (base_bandwidth * self.bandwidth_multipliers)[:, None, None]
        per_kernel = torch.exp(-sq_dists[None, ...] / scales)
        # Collapse the leading kernel axis.
        return per_kernel.sum(dim=0)
277+
278+
279+
class BatchMMDLoss(torch.nn.Module):
    """Sum of per-sample MMD terms over a batch of matrix pairs.

    Adapted from https://github.com/yiftachbeer/mmd_loss_pytorch

    Args:
        kernel: callable mapping a 2-D tensor of stacked rows to its square
            kernel matrix. Defaults to a new ``RBF()`` for each instance.
    """

    def __init__(self, kernel=None):
        super().__init__()
        # A default of ``RBF()`` in the signature would be built once at
        # definition time and shared by every instance; build it here instead.
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, allX, allY):
        """Return the summed MMD term over the batch.

        ``allX`` and ``allY`` are indexed as ``[i, :, :]``, so both must be
        3-D with the batch on the first axis. For each pair, the kernel matrix
        of the stacked rows is split into XX, XY and YY blocks, and
        ``mean(XX) - 2 * mean(XY) + mean(YY)`` is added to the total.
        """
        batch_size = allX.shape[0]
        mmd = torch.tensor(0.)

        for i in range(batch_size):
            X = allX[i, :, :]
            Y = allY[i, :, :]
            K = self.kernel(torch.vstack([X, Y]))

            X_size = X.shape[0]
            XX = K[:X_size, :X_size].mean()
            XY = K[:X_size, X_size:].mean()
            YY = K[X_size:, X_size:].mean()
            mmd = mmd + XX - 2 * XY + YY
        return mmd
300+

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
miceforest
1+
miceforest==5.7.0
22
numpy>=1.18.1
33
pandas>=1.0
44
pre-commit>=2.2.0

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def readme():
77
return readme_file.read()
88

99
configuration = {
10-
"name": "utrees",
10+
"name": "condo",
1111
"version": "0.8.0",
1212
"description": "Confounded domain adaptation",
1313
"long_description": readme(),
@@ -37,7 +37,9 @@ def readme():
3737
"maintainer_email": "mccarter.calvin@gmail.com",
3838
"packages": ["condo"],
3939
"install_requires": [
40+
"miceforest<6.0.0",
4041
"numpy",
42+
"pandas",
4143
"pytorch-minimize>=0.0.2",
4244
"scipy",
4345
"scikit-learn",

0 commit comments

Comments
 (0)