1
+ # -*- coding: utf-8 -*-
2
+
3
+
4
+ from . import PCA
5
+ from ..Selection .MolSys import Trajectory
6
+ import numpy as np
7
+ import matplotlib .pyplot as plt
8
+ from matplotlib .colors import ListedColormap
9
+ from copy import copy
10
+
11
+ class CombinedPCA (PCA ):
12
+ def __init__ (self ,* pcas ):
13
+ super ().__init__ (pcas [0 ].select )
14
+ # self.__dict__=copy(pcas[0].__dict__)
15
+ self ._pos = np .concatenate ([pca0 .pos for pca0 in pcas ],axis = 0 )
16
+ self .pcas = [copy (pca ) for pca in pcas ]
17
+
18
+ self ._traj = Traj (self )
19
+
20
+ for pca ,Range in zip (self .pcas ,self .ranges ):
21
+ pca ._PC = self .PC
22
+ pca ._pcamp = self .PCamp [:,Range [0 ]:Range [1 ]]
23
+ pca ._lambda = self .Lambda
24
+
25
+ self .traj [0 ]
26
+
27
+
28
+ @property
29
+ def traj (self ):
30
+ return self ._traj
31
+
32
+ @property
33
+ def n_trajs (self ):
34
+ return len (self .trajs )
35
+
36
+ @property
37
+ def trajs (self ):
38
+ return [pca .traj for pca in self .pcas ]
39
+
40
+ @property
41
+ def lengths (self ):
42
+ return np .array ([len (traj ) for traj in self .trajs ],dtype = int )
43
+
44
+ @property
45
+ def ranges (self ):
46
+ return [(self .lengths [:n ].sum (),self .lengths [:n + 1 ].sum ()) for n in range (self .n_trajs )]
47
+
48
+ def hist_by_traj (self ,nmax :int = 3 ,cmap = 'Reds' ,cmap0 = 'jet' ,** kwargs ):
49
+ fig ,ax = plt .subplots (nmax ,self .n_trajs )
50
+ if nmax == 1 :ax = np .array ([ax ])
51
+
52
+ cm = plt .get_cmap (cmap )
53
+ colors = np .array ([cm (k ) for k in range (256 )])
54
+ colors [:,- 1 ]= np .linspace (0 ,1 ,257 )[1 :]** .25
55
+ cm = ListedColormap (colors )
56
+
57
+ for q ,ax0 in enumerate (ax .T ):
58
+ for ax00 ,n0 ,n1 in zip (ax0 ,range (nmax ),range (1 ,nmax + 1 )):
59
+ self .Hist .plot (n0 ,n1 ,ax = ax00 ,cmap = cmap0 ,** kwargs )
60
+ index = np .zeros (len (self .traj ),dtype = bool )
61
+ index [self .ranges [q ][0 ]:self .ranges [q ][1 ]]= True
62
+ self .Hist .plot (n0 ,n1 ,ax = ax00 ,cmap = cm ,index = index )
63
+
64
+ fig .tight_layout ()
65
+
66
+ return fig
67
+
68
+ def hist2struct (self ,nmax :int = 3 ,from_traj :bool = True ,select_str :str = 'protein' ,
69
+ ref_struct :bool = False ,cmap = 'Reds' ,cmap0 = 'jet' ,cmap_ch = 'gist_rainbow' ,n_colors = 10 ,** kwargs ):
70
+ fig ,ax = plt .subplots (nmax ,self .n_trajs )
71
+
72
+ cm = plt .get_cmap (cmap )
73
+ colors = np .array ([cm (k ) for k in range (255 )])
74
+ colors [:,- 1 ]= np .linspace (0 ,1 ,255 )** .25
75
+ cm = ListedColormap (colors )
76
+
77
+ for q ,ax0 in enumerate (ax .T ):
78
+ xlims = []
79
+ ylims = []
80
+ for ax00 ,n0 ,n1 in zip (ax0 ,range (nmax ),range (1 ,nmax + 1 )):
81
+ self .Hist .plot (n0 ,n1 ,ax = ax00 ,cmap = cmap0 ,** kwargs )
82
+ ax00 .set_xlim (ax00 .get_xlim ())
83
+ ax00 .set_ylim (ax00 .get_ylim ())
84
+ xlims .append (ax00 .get_xlim ())
85
+ ylims .append (ax00 .get_ylim ())
86
+ self .pcas [q ].Hist .hist2struct (nmax = nmax ,from_traj = from_traj ,select_str = select_str ,
87
+ ref_struct = ref_struct ,ax = ax0 .tolist (),
88
+ cmap = cm ,cmap_ch = cmap_ch ,n_colors = n_colors ,** kwargs )
89
+ for ax00 ,xlim ,ylim in zip (ax0 ,xlims ,ylims ):
90
+ ax00 .set_xlim (xlim )
91
+ ax00 .set_ylim (ylim )
92
+
93
+ return fig
94
+
95
+ def hist_t_depend (self ,nmax :int = 3 ,cmap = 'jet' ,cmap0 = 'Greys' ,step :int = 20 ,** kwargs ):
96
+ fig ,ax = plt .subplots (nmax ,self .n_trajs )
97
+ if nmax == 1 :ax = np .array ([ax ])
98
+
99
+ if isinstance (cmap ,str ):cmap = plt .get_cmap (cmap )
100
+
101
+
102
+ for q ,ax0 in enumerate (ax .T ):
103
+ for ax00 ,n0 ,n1 in zip (ax0 ,range (nmax ),range (1 ,nmax + 1 )):
104
+ self .Hist .plot (n0 ,n1 ,ax = ax00 ,cmap = cmap0 ,** kwargs )
105
+ index = np .zeros (len (self .traj ),dtype = bool )
106
+ index [self .ranges [q ][0 ]:self .ranges [q ][1 ]:step ]= True
107
+ cmap = cmap .resampled (index .sum ())
108
+ c = cmap (np .arange (index .sum ()))
109
+ ax00 .scatter (self .PCamp [n0 ][index ],self .PCamp [n1 ][index ],s = 1 ,c = c )
110
+
111
+ fig .tight_layout ()
112
+
113
+ return fig
114
+
115
+
116
+
117
+
118
+
119
+ @property
120
+ def filenames (self ):
121
+ return [pca .uni .filename for pca in self .pcas ]
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ class Traj ():
130
+ def __init__ (self ,combinePCA ):
131
+ self .cPCA = combinePCA
132
+ self ._index = 0
133
+ self ._traj_index = 0
134
+ self .mda_traj = self .trajs [0 ].mda_traj
135
+
136
+ @property
137
+ def trajs (self ):
138
+ return self .cPCA .trajs
139
+
140
+ @property
141
+ def pcas (self ):
142
+ return self .cPCA .pcas
143
+
144
+ @property
145
+ def lengths (self ):
146
+ return self .cPCA .lengths
147
+
148
+ def __getitem__ (self ,i ):
149
+ # q=np.argmax(i<np.cumsum(np.concatenate([[0],self.cPCA.lengths])))
150
+ q = np .argmax (i < np .cumsum (self .cPCA .lengths ))
151
+ pca = self .pcas [q ]
152
+ self .cPCA ._uni = pca .uni
153
+ self .cPCA ._atoms = pca .atoms
154
+ self .mda_traj = pca .traj .mda_traj
155
+ self ._traj_index = q
156
+ self ._index = i
157
+ return pca .traj [(i % len (self ))- self .lengths [:q ].sum ()]
158
+
159
+ @property
160
+ def traj_index (self ):
161
+ return self ._traj_index
162
+
163
+ @property
164
+ def index (self ):
165
+ return self ._index
166
+
167
+
168
+ def __len__ (self ):
169
+ return self .tf
170
+
171
+ @property
172
+ def t0 (self ):
173
+ return 0
174
+
175
+ @t0 .setter
176
+ def t0 (self ,t0 ):
177
+ pass
178
+
179
+ @property
180
+ def tf (self ):
181
+ return np .sum ([len (traj ) for traj in self .cPCA .trajs ])
182
+
183
+ @tf .setter
184
+ def tf (self ,tf ):
185
+ pass
186
+
187
+ @property
188
+ def step (self ):
189
+ return self .cPCA .trajs [0 ].step
190
+
191
+ @step .setter
192
+ def step (self ,step ):
193
+ pass
194
+
195
+ @property
196
+ def frame (self ):
197
+ return (self .mda_traj .frame - self .trajs [self ._traj_index ].t0 )// self .step
0 commit comments