Skip to content

Commit 289450b

Browse files
authored
Merge pull request #76 from EducationalTestingService/feature/choose-svd-method
Add multiple SVD methods and address other issues
2 parents 38827d5 + 2bfd6ed commit 289450b

11 files changed

+403
-62
lines changed

factor_analyzer/factor_analyzer.py

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sklearn.utils import check_array
3232
from sklearn.utils.validation import check_is_fitted
3333

34+
POSSIBLE_SVDS = ['randomized', 'lapack']
3435

3536
POSSIBLE_IMPUTATIONS = ['mean', 'median', 'drop']
3637

@@ -176,10 +177,19 @@ class FactorAnalyzer(BaseEstimator, TransformerMixin):
176177
If missing values are present in the data, either use
177178
list-wise deletion ('drop') or impute the column median
178179
('median') or column mean ('mean').
180+
Defaults to 'median'
179181
is_corr_matrix : bool, optional
180182
Set to true if the `data` is the correlation
181183
matrix.
182184
Defaults to False.
185+
svd_method : {'lapack', 'randomized'}, optional
186+
The SVD method to use when ``method='principal'``.
187+
If 'lapack', use standard SVD from ``scipy.linalg``.
188+
If 'randomized', use faster ``randomized_svd``
189+
function from scikit-learn. The latter should only
190+
be used if the number of columns is greater than or
191+
equal to the number of rows in the dataset.
192+
Defaults to 'randomized'
183193
rotation_kwargs, optional
184194
Additional key word arguments
185195
are passed to the rotation method.
@@ -249,32 +259,18 @@ def __init__(self,
249259
is_corr_matrix=False,
250260
bounds=(0.005, 1),
251261
impute='median',
262+
svd_method='randomized',
252263
rotation_kwargs=None):
253264

254-
rotation = rotation.lower() if isinstance(rotation, str) else rotation
255-
if rotation not in POSSIBLE_ROTATIONS + [None]:
256-
raise ValueError(f"The rotation must be one of the following: {POSSIBLE_ROTATIONS + [None]}")
257-
258-
method = method.lower()
259-
if method not in POSSIBLE_METHODS:
260-
raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS + [None]}")
261-
262-
impute = impute.lower()
263-
if impute not in POSSIBLE_IMPUTATIONS:
264-
raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS + [None]}")
265-
266-
if method == 'principal' and is_corr_matrix:
267-
raise ValueError('The principal method is only implemented using '
268-
'the full data set, not the correlation matrix.')
269-
270265
self.n_factors = n_factors
271266
self.rotation = rotation
272267
self.method = method
273268
self.use_smc = use_smc
274269
self.bounds = bounds
275270
self.impute = impute
276271
self.is_corr_matrix = is_corr_matrix
277-
self.rotation_kwargs = {} if rotation_kwargs is None else rotation_kwargs
272+
self.svd_method = svd_method
273+
self.rotation_kwargs = rotation_kwargs
278274

279275
# default matrices to None
280276
self.mean_ = None
@@ -288,6 +284,34 @@ def __init__(self,
288284
self.rotation_matrix_ = None
289285
self.weights_ = None
290286

287+
def _arg_checker(self):
288+
"""
289+
Check the input parameters to make sure they're properly formatted.
290+
We need to do this to ensure that the FactorAnalyzer class can be properly
291+
cloned when used with grid search CV, for example.
292+
"""
293+
self.rotation = self.rotation.lower() if isinstance(self.rotation, str) else self.rotation
294+
if self.rotation not in POSSIBLE_ROTATIONS + [None]:
295+
raise ValueError(f"The rotation must be one of the following: {POSSIBLE_ROTATIONS + [None]}")
296+
297+
self.method = self.method.lower() if isinstance(self.method, str) else self.method
298+
if self.method not in POSSIBLE_METHODS:
299+
raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS}")
300+
301+
self.impute = self.impute.lower() if isinstance(self.impute, str) else self.impute
302+
if self.impute not in POSSIBLE_IMPUTATIONS:
303+
raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS}")
304+
305+
self.svd_method = self.svd_method.lower() if isinstance(self.svd_method, str) else self.svd_method
306+
if self.svd_method not in POSSIBLE_SVDS:
307+
raise ValueError(f"The SVD method must be one of the following: {POSSIBLE_SVDS}")
308+
309+
if self.method == 'principal' and self.is_corr_matrix:
310+
raise ValueError('The principal method is only implemented using '
311+
'the full data set, not the correlation matrix.')
312+
313+
self.rotation_kwargs = {} if self.rotation_kwargs is None else self.rotation_kwargs
314+
291315
@staticmethod
292316
def _fit_uls_objective(psi, corr_mtx, n_factors):
293317
"""
@@ -472,8 +496,21 @@ def _fit_principal(self, X):
472496
X = X.copy()
473497
X = (X - X.mean(0)) / X.std(0)
474498

499+
# if the number of rows is less than the number of columns,
500+
# warn the user that the number of factors will be constrained
501+
nrows, ncols = X.shape
502+
if nrows < ncols and self.n_factors >= nrows:
503+
warnings.warn('The number of factors will be '
504+
'constrained to min(n_samples, n_features)'
505+
'={}.'.format(min(nrows, ncols)))
506+
475507
# perform the randomized singular value decomposition
476-
U, S, V = randomized_svd(X, self.n_factors)
508+
if self.svd_method == 'randomized':
509+
U, S, V = randomized_svd(X, self.n_factors)
510+
# otherwise, perform the full SVD
511+
else:
512+
U, S, V = np.linalg.svd(X, full_matrices=False)
513+
477514
corr_mtx = np.dot(X, V.T)
478515
loadings = np.array([[pearsonr(x, c)[0] for c in corr_mtx.T] for x in X.T])
479516
return loadings
@@ -577,6 +614,9 @@ def fit(self, X, y=None):
577614
[ 0.81533404, -0.12494695, 0.17639683]])
578615
"""
579616

617+
# check the input arguments
618+
self._arg_checker()
619+
580620
# check if the data is a data frame,
581621
# so we can convert it to an array
582622
if isinstance(X, pd.DataFrame):
@@ -650,11 +690,14 @@ def fit(self, X, y=None):
650690
phi = np.dot(np.dot(np.diag(signs), phi), np.diag(signs))
651691
structure = np.dot(loadings, phi) if self.rotation in OBLIQUE_ROTATIONS else None
652692

653-
# resort the factors according to their variance
654-
variance = self._get_factor_variance(loadings)[0]
655-
new_order = list(reversed(np.argsort(variance)))
656-
loadings = loadings[:, new_order].copy()
693+
# resort the factors according to their variance,
694+
# unless the method is principal
695+
if self.method != 'principal':
696+
variance = self._get_factor_variance(loadings)[0]
697+
new_order = list(reversed(np.argsort(variance)))
698+
loadings = loadings[:, new_order].copy()
657699

700+
# if the structure matrix exists, reorder
658701
if structure is not None:
659702
structure = structure[:, new_order].copy()
660703

factor_analyzer/test_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def calculate_py_output(test_name,
3838
factors,
3939
method,
4040
rotation,
41+
svd_method='randomized',
4142
use_corr_matrix=False,
4243
top_dir=None):
4344
"""
@@ -54,6 +55,9 @@ def calculate_py_output(test_name,
5455
The rotation method
5556
rotation : str
5657
The type of rotation
58+
svd_method : str, optional
59+
The SVD method to use
60+
Defaults to 'randomized'
5761
use_corr_matrix : bool, optional
5862
Whether to use the correlation matrix.
5963
Defaults to False.
@@ -81,7 +85,7 @@ def calculate_py_output(test_name,
8185
rotation = None if rotation == 'none' else rotation
8286
method = {'uls': 'minres'}.get(method, method)
8387

84-
fa = FactorAnalyzer(n_factors=factors, method=method,
88+
fa = FactorAnalyzer(n_factors=factors, method=method, svd_method=svd_method,
8589
rotation=rotation, is_corr_matrix=use_corr_matrix)
8690
fa.fit(X)
8791

@@ -228,6 +232,11 @@ def check_close(data1, data2, rel_tol=0.0, abs_tol=0.1,
228232
data1 = normalize(data1, absolute)
229233
data2 = normalize(data2, absolute)
230234

235+
print(data1)
236+
print()
237+
print(data2)
238+
print('------')
239+
231240
err_msg = 'r - py: {} != {}'
232241
assert data1.shape == data2.shape, err_msg.format(data1.shape, data2.shape)
233242

@@ -253,6 +262,7 @@ def check_scenario(test_name,
253262
check_scores=False,
254263
check_structure=False,
255264
use_corr_matrix=False,
265+
svd_method='randomized',
256266
data_dir=None,
257267
expected_dir=None,
258268
rel_tol=0,
@@ -321,8 +331,10 @@ def check_scenario(test_name,
321331
if check_structure:
322332
output_types.append('structure')
323333

324-
r_output = collect_r_output(test_name, factors, method, rotation, output_types, expected_dir)
325-
py_output = calculate_py_output(test_name, factors, method, rotation, use_corr_matrix, data_dir)
334+
r_output = collect_r_output(test_name, factors, method, rotation,
335+
output_types, expected_dir)
336+
py_output = calculate_py_output(test_name, factors, method, rotation, svd_method,
337+
use_corr_matrix, data_dir)
326338

327339
for output_type in output_types:
328340

tests/data/test15.csv

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
pers01,pers02,pers03,pers04,pers05,pers06,pers07,pers08,pers09,pers10,pers11,pers12,pers13,pers14,pers15,pers16,pers17,pers18,pers19,pers20,pers21,pers22,pers23,pers24,pers25,pers26,pers27,pers28,pers29,pers30,pers31,pers32,pers33,pers34,pers35,pers36,pers37,pers38,pers39,pers40,pers41,pers42,pers43,pers44
2+
5,4,5,1,4,3,3,1,2,3,2,4,5,4,4,3,5,1,3,4,2,4,2,3,2,5,4,5,4,2,5,4,5,3,1,3,4,4,4,3,3,3,5,4
3+
1,1,5,2,1,2,5,1,5,1,5,3,5,4,1,2,1,3,5,1,5,5,1,3,1,2,3,5,3,3,5,5,5,5,5,3,1,1,2,3,3,5,2,1
4+
4,1,5,3,3,4,5,3,1,4,2,1,5,4,3,2,5,1,5,4,4,4,3,2,4,1,3,4,5,4,3,3,5,5,4,2,3,1,5,4,2,3,5,3
5+
4,2,5,1,4,3,4,4,4,5,4,1,4,5,3,3,4,2,5,4,4,4,3,2,4,2,1,5,3,4,4,4,4,4,3,3,3,4,5,5,3,5,2,4
6+
2,3,5,1,2,4,5,2,3,3,4,2,5,3,3,3,5,1,4,2,5,5,1,3,2,2,4,4,4,3,4,4,5,4,4,2,5,5,4,3,2,4,3,2
7+
1,1,5,4,3,4,4,2,1,4,3,3,5,5,3,2,4,1,5,3,5,4,1,3,3,1,2,5,5,4,5,3,4,2,3,1,1,3,5,4,2,4,3,5
8+
3,2,5,1,2,1,1,2,5,4,4,1,5,1,1,4,5,5,1,3,2,5,1,4,2,1,1,5,1,2,1,5,4,5,5,5,2,5,1,1,4,5,1,1
9+
5,2,4,2,4,1,4,3,3,5,4,1,4,3,2,4,4,4,2,5,2,4,3,3,1,1,3,4,3,3,2,5,4,4,2,4,2,4,2,3,2,5,4,4
10+
5,1,4,3,2,1,4,4,2,3,4,1,4,4,2,5,5,5,4,4,2,5,2,4,2,4,4,4,3,5,2,5,4,2,4,5,1,2,4,3,2,5,4,2
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
,x
2+
pers01,2.147636456
3+
pers02,3.05860696
4+
pers03,2.634980867
5+
pers04,4.167044602
6+
pers05,3.239628161
7+
pers06,2.855673949
8+
pers07,2.846150806
9+
pers08,2.342990801
10+
pers09,3.030327258
11+
pers10,2.975790993
12+
pers11,2.395566544
13+
pers12,2.220826511
14+
pers13,2.038321064
15+
pers14,3.265232908
16+
pers15,2.425621063
17+
pers16,3.061295019
18+
pers17,2.017025954
19+
pers18,3.049338556
20+
pers19,3.122061642
21+
pers20,2.728356141
22+
pers21,2.594510111
23+
pers22,3.782658794
24+
pers23,3.070056277
25+
pers24,4.111231609
26+
pers25,3.064079366
27+
pers26,1.832456264
28+
pers27,2.162362221
29+
pers28,1.758259243
30+
pers29,3.225984183
31+
pers30,3.166512309
32+
pers31,2.790080337
33+
pers32,3.175142177
34+
pers33,2.391267257
35+
pers34,2.403458399
36+
pers35,2.514121013
37+
pers36,2.983086234
38+
pers37,2.746601113
39+
pers38,3.495223869
40+
pers39,2.864928829
41+
pers40,3.571532646
42+
pers41,2.654887486
43+
pers42,2.126933791
44+
pers43,2.482551325
45+
pers44,3.105868335
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
,x
2+
1,14.25162169
3+
2,9.294533097
4+
3,6.498852485
5+
4,4.620861291
6+
5,3.059319317
7+
6,2.625005382
8+
7,1.991825606
9+
8,1.657981133
10+
9,2.69E-15
11+
10,1.67E-15
12+
11,1.22E-15
13+
12,9.84E-16
14+
13,6.64E-16
15+
14,5.94E-16
16+
15,5.04E-16
17+
16,4.43E-16
18+
17,4.28E-16
19+
18,3.86E-16
20+
19,3.62E-16
21+
20,2.83E-16
22+
21,1.92E-16
23+
22,1.41E-16
24+
23,1.16E-16
25+
24,7.91E-17
26+
25,2.42E-17
27+
26,1.45E-17
28+
27,-1.52E-17
29+
28,-3.58E-17
30+
29,-9.31E-17
31+
30,-2.07E-16
32+
31,-2.23E-16
33+
32,-2.53E-16
34+
33,-3.39E-16
35+
34,-4.00E-16
36+
35,-4.38E-16
37+
36,-4.54E-16
38+
37,-5.18E-16
39+
38,-5.69E-16
40+
39,-7.36E-16
41+
40,-8.59E-16
42+
41,-9.67E-16
43+
42,-1.20E-15
44+
43,-1.50E-15
45+
44,-3.05E-15
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
,MR1,MR2,MR3,MR4,MR5,MR6,MR7,MR8,MR9
2+
pers01,-0.04828462,0.872685807,0.239695145,0.1783874,-0.070266058,0.233832136,-0.127745213,0.266230396,0.109949125
3+
pers02,0.065050861,0.087111805,0.917108625,0.09829389,0.045889312,0.259212544,0.240458465,-0.101551737,-0.438491274
4+
pers03,0.335344379,-0.692295104,0.352908838,-0.448246064,-0.105717303,-0.075752288,0.072948636,0.246102405,-0.026867838
5+
pers04,0.369038127,0.065051128,-0.659481751,0.216562195,0.122770915,-0.550263386,-0.238255766,-0.055995953,0.650878605
6+
pers05,0.460360115,0.673786618,0.357376302,-0.321981352,0.22082246,0.141451804,-0.140351837,-0.119249891,0.088835419
7+
pers06,0.87171974,-0.336822521,0.112582984,-0.145152604,-0.211905593,-0.080225718,0.198171197,-0.047949025,0.530776467
8+
pers07,0.576151802,-0.206133915,-0.510224335,0.311050617,-0.18298515,0.43465408,0.004627584,-0.214586587,0.581738459
9+
pers08,0.013465102,0.72669014,-0.534236413,-0.20871943,-0.19468342,0.119898964,0.269297945,0.134048523,0.395215471
10+
pers09,-0.748916302,-0.361166248,0.065624042,-0.340057899,0.067146842,0.42644553,0.043870457,0.021173528,-0.771983064
11+
pers10,0.137766271,0.729061174,0.055714597,-0.601408949,-0.108200226,-0.133706349,0.019605779,-0.233932375,0.119752545
12+
pers11,-0.675924401,-0.273483264,-0.433155063,-0.064071358,0.143265802,0.359476466,0.248424031,-0.255228802,-0.544707831
13+
pers12,0.355725628,-0.531383218,0.427897805,0.322997529,0.543285483,-0.056315934,-0.072021539,-0.012227634,-0.211014824
14+
pers13,0.191984172,-0.752696697,0.397452888,0.0537552,-0.217179515,-0.397121986,-0.166976097,0.05456656,0.03461962
15+
pers14,0.761628997,0.018888667,-0.389796375,0.02405038,0.439146501,0.204641161,0.08654672,0.15756543,0.476955155
16+
pers15,0.841854905,0.251761235,0.412010313,0.053732916,0.042505833,0.029407325,0.226082677,0.038366163,0.402422072
17+
pers16,-0.64430474,0.655136206,0.025197088,0.259591771,0.01078642,-0.024935079,0.293030345,0.032253453,-0.309892468
18+
pers17,0.133885448,0.599835612,0.405812074,0.039838088,-0.411211008,-0.402493253,0.338007923,0.103231512,0.273106526
19+
pers18,-0.899018777,0.306544161,-0.263593624,0.086024622,0.073959402,-0.061269493,-0.070891304,0.081636764,-0.419225537
20+
pers19,0.690934528,-0.329647629,-0.551624951,0.0182762,0.064019853,0.196192485,0.15420495,0.207240475,0.606635059
21+
pers20,0.155163821,0.957389706,0.108910582,-0.089609788,0.022891194,-0.02826736,-0.193441382,0.026364585,0.239325582
22+
pers21,0.485799705,-0.70406665,-0.364281338,-0.193071399,-0.097510953,0.090140204,0.1946385,-0.206866024,0.285201141
23+
pers22,-0.681026118,-0.43653348,-0.152606385,0.349435475,-0.227476649,-0.024178065,0.372173269,0.096979075,-0.32989024
24+
pers23,0.298251743,0.738460588,-0.104605406,-0.18269442,-0.149645341,0.429224993,-0.312246211,0.131441227,0.397777692
25+
pers24,-0.735050328,0.022838942,0.070148594,0.42301257,0.17584153,-0.452732468,0.197949991,-0.015741936,-0.467726762
26+
pers25,0.612945178,0.179319088,-0.143944907,-0.536182675,-0.246095779,-0.144123534,0.178983737,0.412990281,0.615785245
27+
pers26,0.075493078,0.164186546,0.361855321,0.60623896,0.363141969,0.212132397,0.266854716,0.469888221,-0.107909185
28+
pers27,0.222367812,0.025808466,0.118818476,0.937412093,-0.130937606,0.1757883,0.031728968,-0.089298973,0.271281581
29+
pers28,-0.066730222,-0.434588822,0.269651574,-0.478297293,0.639249191,-0.095051394,-0.032966158,0.294003909,-0.528846167
30+
pers29,0.943968963,-0.072840218,-0.105035164,0.233586597,-0.059723312,-0.116892943,-0.06291015,-0.129743188,0.741046671
31+
pers30,0.329973251,0.309685066,-0.828180961,0.108154302,-0.016580396,-0.057603016,0.255376313,0.169769485,0.679006138
32+
pers31,0.663368245,-0.546831933,0.099996746,0.073101125,0.436956147,0.219166869,0.07969368,-0.016075172,0.068306007
33+
pers32,-0.896103875,0.087076721,-0.032225109,0.233865369,0.167003644,0.322598699,-0.024188096,-0.033746387,-0.67512874
34+
pers33,0.368927175,-0.55505924,0.330315318,0.404087257,-0.358903058,0.316887491,-0.186347221,0.139473596,0.227565878
35+
pers34,-0.297896347,-0.356547533,0.108006974,-0.380935237,-0.55636427,0.376008538,-0.417555873,0.045584353,-0.143322194
36+
pers35,-0.466026995,-0.521274333,-0.486631795,-0.090454241,-0.437389454,-0.144597074,0.10906771,0.204904425,0.03733149
37+
pers36,-0.849143547,0.423510789,0.024744724,0.136594471,0.030216843,0.09148576,-0.026121108,0.265251737,-0.457345574
38+
pers37,0.341668592,0.007337606,0.66756131,-0.008476775,-0.46117769,0.331304008,0.328795419,-0.083337506,0.029768105
39+
pers38,-0.263669214,0.20346738,0.640196465,-0.321078152,0.009810988,-0.038680089,0.470802886,-0.391010377,-0.606810668
40+
pers39,0.882113868,0.17203835,-0.16887364,-0.023210193,-0.013321485,-0.040546264,0.34238953,0.210163886,0.712370485
41+
pers40,0.791007623,0.146938,-0.379995394,-0.239363727,0.157728703,0.334133052,0.109024279,0.051138425,0.546598278
42+
pers41,-0.56246979,-0.273713014,0.436724979,-0.447447065,0.177598914,0.035751419,-0.068604277,0.424553004,-0.688649095
43+
pers42,-0.742317116,0.071452222,-0.469083975,-0.218919485,0.263651376,0.217525709,0.182249823,-0.160778989,-0.53833395
44+
pers43,0.587977857,0.446104906,0.144926452,0.575703527,-0.127295674,0.010140071,-0.290820054,0.044169119,0.615268623
45+
pers44,0.697943291,0.446797496,0.096467909,-0.230890491,0.408446905,-0.09793622,-0.066987721,-0.264051038,0.260242219

0 commit comments

Comments
 (0)