6
6
7
7
import pandas as pd
8
8
from joblib import delayed
9
- from tqdm import tqdm
10
9
11
10
from pytimeloop .fastfusion .sim import SIM
12
- from pytimeloop .fastfusion .pareto import Pareto , check_correctness
11
+ from pytimeloop .fastfusion .pareto import Pareto
13
12
from pytimeloop .fastfusion .util import parallel , debugger_active
14
13
15
14
@@ -28,6 +27,8 @@ def mapping2sims(einsum_to_result: Mapping):
28
27
for einsum_id , compat_dict in einsum_to_result .items ():
29
28
r [einsum_id ] = [paretofy (k , v ) for k , v in compat_dict .items ()]
30
29
return list (r .values ())
30
+ def paretofy (k , v ):
31
+ return SIM (k , Pareto (pd .DataFrame (v ).fillna (0 )))
31
32
32
33
33
34
prev_time = 0
@@ -84,19 +85,20 @@ def consolidate(
84
85
85
86
86
87
def fuse_sims (
87
- sims : list [SIM ],
88
+ sims : dict [ str , list [SIM ] ],
88
89
resource2capacity : dict = None ,
89
90
return_nmappings_nbuckets : bool = False ,
90
- pre_filter : bool = True
91
91
):
92
92
nmappings = []
93
93
nbuckets = []
94
- resource2capacity = resource2capacity or {}
95
- sims = [s for s in sims ]
96
94
97
- for i , s in enumerate (sims ):
98
- print (f'SIM { i } tensors: { s [0 ].tensor_names } ' )
99
-
95
+ sims = list (sims .items ())
96
+
97
+ for einsum_id , s in sims :
98
+ print (f'SIM { einsum_id } tensors: { s [0 ].tensor_names } ' )
99
+
100
+ # TODO: Lookahead by one SIM. If we're going to create a tiling that has loops
101
+ # that are not in the ranks of the next SIM, we should drop that tiling.
100
102
# if pre_filter:
101
103
# for i in range(len(sims) - 1):
102
104
# left, right = sims[i], sims[i + 1]
@@ -109,29 +111,30 @@ def fuse_sims(
109
111
# print(f'Filtered {len(left)} -> {len(left2)} SIMs from Einsum {i}')
110
112
# print(f'Filtered {len(right)} -> {len(right2)} SIMs from Einsum {i + 1}')
111
113
112
- left = sims .pop (0 )
113
-
114
114
init_print_time ()
115
-
116
- if not sims :
117
- sims = copy . deepcopy ( sims )
115
+ if len ( sims ) == 1 :
116
+ left = copy . deepcopy ( sims [ 0 ][ 1 ])
117
+ sims = []
118
118
left = consolidate (
119
119
x = left ,
120
120
left = True ,
121
121
live_tensors = set (),
122
122
resource2capacity = resource2capacity ,
123
123
shared_tensors = set (),
124
124
)
125
-
126
- # TODO: Lookahead by one SIM. If we're going to create a tiling that has loops
127
- # that are not in the ranks of the next SIM, we should drop that tiling.
128
125
126
+ n_iterations = 0
127
+ total_iterations = len (sims )
128
+ left_einsum , left = sims .pop (0 )
129
129
while sims :
130
+ n_iterations += 1
130
131
nbuckets .append (len (left ))
131
132
nmappings .append (sum (len (s .mapping .data ) for s in left ))
132
133
133
- right = sims .pop (0 )
134
- live_tensors = set .union (set (), * [s [0 ].tensor_names for s in sims if s ])
134
+ right_einsum , right = sims .pop (0 )
135
+ print (f'\n Einsum { right_einsum } ({ n_iterations } /{ total_iterations } )' )
136
+
137
+ live_tensors = set .union (set (), * [s [0 ].tensor_names for _ , s in sims if s ])
135
138
shared_tensors = set (left [0 ].tensor_names ) & set (right [0 ].tensor_names )
136
139
137
140
right_tensors = right [0 ].tensor_names
@@ -144,23 +147,26 @@ def fuse_sims(
144
147
shared_tensors = shared_tensors ,
145
148
)
146
149
147
- left = SIM .combine_combineable (left , live_tensors | right_tensors )
148
- right = SIM .combine_combineable (right , live_tensors | left_tensors )
149
-
150
- print_time ("Combining" )
151
-
152
150
left = sorted (left , key = lambda x : len (x .mapping .data ), reverse = True )
153
151
right = sorted (right , key = lambda x : len (x .mapping .data ), reverse = True )
152
+ lr = parallel (
153
+ [delayed (lambda l : l .left_consolidate (live_tensors , resource2capacity , shared_tensors ))(l ) for l in left ] +
154
+ [delayed (lambda l : l .consolidate (live_tensors , resource2capacity , shared_tensors ))(l ) for l in right ],
155
+ pbar = f"Consolidating { left_einsum } <--> { right_einsum } " ,
156
+ )
157
+ left , right = lr [:len (left )], lr [len (left ):]
158
+ print_time (f"Consolidating" )
154
159
155
- left = parallel ([delayed (lambda l : l .left_consolidate (live_tensors , resource2capacity , shared_tensors ))(l ) for l in left ], pbar = "Left consolidate" )
156
- right = parallel ([delayed (lambda l : l .consolidate (live_tensors , resource2capacity , shared_tensors ))(l ) for l in right ], pbar = "Right consolidate" )
160
+ left = SIM .combine_combineable (left , live_tensors | right_tensors )
161
+ right = SIM .combine_combineable (right , live_tensors | left_tensors )
162
+ print_time (f"Combining" )
157
163
158
- print_time ("Consolidating" )
164
+ # left = parallel([delayed(lambda l: l.left_consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in left], pbar="Left consolidate")
165
+ # right = parallel([delayed(lambda l: l.consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in right], pbar="Right consolidate")
159
166
160
167
# Group left and right into buckets
161
168
right = SIM .group_right (right , left_tensors )
162
169
left = SIM .group_left (left , right_tensors )
163
-
164
170
print_time ("Grouping" )
165
171
166
172
for v in list (left .values ()) + list (right .values ()):
@@ -196,28 +202,27 @@ def fuse_sims(
196
202
197
203
if DELAY_MERGE :
198
204
combined = sorted (combined , key = lambda x : x ._predicted_mappings , reverse = True )
199
- for c , mapping in zip (combined , parallel ([c .mapping for c in combined ], pbar = 'Merging mappings' )):
205
+ for c , mapping in zip (combined , parallel ([c .mapping for c in combined ], pbar = f 'Merging mappings { left_einsum } <--> { right_einsum } ' )):
200
206
c .mapping = mapping
201
207
202
208
print_time ("Mapping merging" )
203
209
204
- print (
205
- f"\t Combining { sum (len (s ) for s in left )} ({ len (left )} ) x { sum (len (s ) for s in right )} ({ len (right )} ) -> { len (combined )} "
206
- )
210
+ print (f"\t Combining { sum (len (s ) for s in left )} ({ len (left )} ) x { sum (len (s ) for s in right )} ({ len (right )} ) -> { len (combined )} " )
207
211
# if DO_PRINT:
208
212
# for k in right:
209
213
# if k not in left:
210
214
# for b in right[k]:
211
215
# print(f"\tREVERSE: No match for {b.tiling}")
212
216
213
217
left = combined
214
- print (f"Number of buckets: { len (left )} " )
218
+ left_einsum = right_einsum
219
+ print (f"\t Number of buckets for Einsum { left_einsum } : { len (left )} " )
215
220
n_mappings = sum (len (s .mapping .data ) for s in left )
216
- print (f"Number of mappings: { n_mappings } " )
217
- print (f"Mappings per bucket: { n_mappings / len (left )} " )
221
+ print (f"\t Number of mappings for Einsum { left_einsum } : { n_mappings } " )
222
+ print (f"\t Mappings per bucket for Einsum { left_einsum } : { n_mappings / len (left )} " )
218
223
219
224
for s in left :
220
- s .left_consolidate (set () , resource2capacity )
225
+ s .left_consolidate (None , resource2capacity )
221
226
s_final = SIM .combine_combineable (left , set ())[0 ]
222
227
data = s_final .mapping .data
223
228
# check_correctness(data, set())
@@ -227,7 +232,3 @@ def fuse_sims(
227
232
if return_nmappings_nbuckets :
228
233
return data , nmappings , nbuckets
229
234
return data
230
-
231
-
232
- def paretofy (k , v ):
233
- return SIM (k , Pareto (pd .DataFrame (v ).fillna (0 )))
0 commit comments