
Commit 2641599

Speed improvements
1 parent 830a109 commit 2641599

10 files changed: +216 -160 lines changed

pytimeloop/fastfusion/filter_mappings.py

Lines changed: 2 additions & 3 deletions
@@ -73,9 +73,8 @@ def get_ffmt_tag_mha(
             to_try += [(other_ranks[4:], (3, 4))]
             tags.append("FFMT_LAST")
         else: # Middle Einsum in a chain
-            if einsum_id == "AV":
-                a, b = b, a
-                other_ranks[-2], other_ranks[-1] = other_ranks[-1], other_ranks[-2]
+            a, b = b, a
+            other_ranks[-2], other_ranks[-1] = other_ranks[-1], other_ranks[-2]
             to_try += [(other_ranks[:4], (3, 4))]
             tags.append("FFMT_MIDDLE")

pytimeloop/fastfusion/mapper/per_einsum_mapper_snowcat.py

Lines changed: 16 additions & 5 deletions
@@ -6,12 +6,15 @@
 
 from combinatorics.dependent_product import dependent_product
 from combinatorics.splitter import split_dependent_product
+import pandas as pd
 
 from pytimeloop.fastfusion.fastmodel import compile_mapping
 from pytimeloop.fastfusion.mapper.constraints import *
 from pytimeloop.fastfusion.mapper.per_einsum_mapper import explore_tile_shape, process_result, get_hardware_levels
 from pytimeloop.fastfusion.mapper.per_einsum_subspaces.snowcat import make_subspaces
 from pytimeloop.fastfusion.mapper.per_einsum_subspaces.snowcat_ffmt import make_ffmt_subspaces
+from pytimeloop.fastfusion.pareto import Pareto, makepareto
+from pytimeloop.fastfusion.sim import SIM
 from pytimeloop.fastfusion.util import parallel
 from pytimeloop.looptree.equivalent_ranks import EquivalentGroups
 from pytimeloop.looptree.mapping_utilities import get_intermediate_tensors
@@ -159,7 +162,7 @@ def per_worker_exploration(*args):
         output_tensors=output_tensors,
         tag_with=tag_with,
     )
-    return einsum_id, result
+    return einsum_id, {k: makepareto(pd.DataFrame(v).fillna(0)) for k, v in result.items()}
 
 
 # # for pm in partial_mappings:
@@ -211,13 +214,21 @@ def per_einsum_mapper_snowcat(
                 dataflow_constraint=dataflow_constraint,
                 metrics=metrics,
                 tag_with=tag_with,
-            )
+                )
         )
     data = {einsum_id: defaultdict(list) for einsum_id in einsums_to_explore}
 
-    for einsum_id, result in parallel(jobs, return_as="generator_unordered", pbar="Generating Single-Einsum Mappings"):
+    for einsum_id, result in parallel(jobs, pbar="Generating Single-Einsum Mappings"):
         d = data[einsum_id]
         for k, v in result.items():
-            d[k[0]] += v
+            d[k[0]].append(v)
 
-    return data
+    def makesim(einsum_id, tiling, data):
+        return einsum_id, SIM(tiling, Pareto(pd.concat(data).fillna(0), skip_pareto=len(data) == 1))
+
+    data2 = defaultdict(list)
+    jobs = [delayed(makesim)(einsum_id, tiling, data) for einsum_id, tilings in data.items() for tiling, data in tilings.items()]
+    for einsum_id, sim in parallel(jobs, pbar="Generating SIMs"):
+        data2[einsum_id].append(sim)
+
+    return data2

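Note: the net effect of this file's changes is that each worker now Pareto-prunes its own results before returning them, and the per-tiling frames are concatenated and wrapped into SIM objects in a second parallel pass. The sketch below illustrates only the per-worker pruning idea with plain pandas; pareto_front, per_worker, and the "energy"/"latency" columns are illustrative stand-ins, not the repo's makepareto or SIM API.

import pandas as pd

def pareto_front(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    # Keep rows that are not dominated on two minimization objectives.
    # Sort by the first objective; a row survives only if it improves the
    # best-seen value of the second objective.
    df = df.sort_values(cols).reset_index(drop=True)
    keep, best = [], float("inf")
    for _, row in df.iterrows():
        if row[cols[1]] < best:
            keep.append(row)
            best = row[cols[1]]
    return pd.DataFrame(keep)

def per_worker(result: dict) -> dict:
    # Mirrors the shape of the change in per_worker_exploration: NaNs from
    # ragged records become 0, then each bucket is reduced to its frontier so
    # dominated mappings never cross the process boundary.
    return {k: pareto_front(pd.DataFrame(v).fillna(0), ["energy", "latency"])
            for k, v in result.items()}

raw = {"tiling_a": [{"energy": 3, "latency": 5},
                    {"energy": 2, "latency": 9},
                    {"energy": 4, "latency": 4}]}
print(per_worker(raw)["tiling_a"])  # all three rows here are mutually non-dominated
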
pytimeloop/fastfusion/mapper/per_einsum_subspaces/snowcat.py

Lines changed: 4 additions & 3 deletions
@@ -42,6 +42,8 @@ def fused_temporal_fors(mapping, unfused_tensors):
 
 def glb_storage(mapping, unfused_tensors):
     glb_fused_tensors = intermediate_tensors - unfused_tensors
+    last_fused_loop_idx = get_last_fused_loop_idx(mapping, intermediate_tensors)
+    # last_fused_loop_idx = None
     for partial_mapping in make_storage(mapping,
                                         level=1,
                                         must_retain_tensors=intermediate_tensors,
@@ -51,9 +53,8 @@ def glb_storage(mapping, unfused_tensors):
                                         explore_uneven=True,
                                         add_split_at_tensors=glb_fused_tensors,
                                         must_have_terminal_storage=False,
-                                        apply_lrp_after_loop_idx=None):
-        last_fused_loop_idx = get_last_fused_loop_idx(partial_mapping,
-                                                      intermediate_tensors)
+                                        apply_lrp_after_loop_idx=last_fused_loop_idx):
+        last_fused_loop_idx = get_last_fused_loop_idx(partial_mapping, intermediate_tensors)
         yield from make_storage(partial_mapping,
                                 level=1,
                                 must_retain_tensors=tensors - intermediate_tensors,

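Note: the change above computes the last fused loop index once from the incoming mapping and forwards it as apply_lrp_after_loop_idx, instead of passing None and only deriving the index inside the loop. A minimal sketch of that precompute-and-forward pattern, using placeholder names rather than the repo's make_storage/get_last_fused_loop_idx signatures:

from typing import Iterator, Optional

def last_marker_index(loops: list[str], marker: str = "fused") -> int:
    # Stand-in for get_last_fused_loop_idx: index of the last fused loop.
    return max((i for i, tok in enumerate(loops) if tok == marker), default=-1)

def make_storage_sketch(loops: list[str], apply_after_idx: Optional[int]) -> Iterator[list[str]]:
    # Stand-in for make_storage: the cutoff now arrives as an argument.
    yield loops + [f"glb_storage(after_loop={apply_after_idx})"]

def glb_storage_sketch(loops: list[str]) -> Iterator[list[str]]:
    cutoff = last_marker_index(loops)  # computed once, before exploring
    yield from make_storage_sketch(loops, apply_after_idx=cutoff)

print(list(glb_storage_sketch(["fused", "temporal", "fused", "temporal"])))
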
pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/storage.py

Lines changed: 4 additions & 4 deletions
@@ -43,8 +43,8 @@ def make_storage(
 
         retained_tensors = must_retain_tensors | set(also_retained_tensors)
         mapping.add_storage(level, retained_tensors)
-        if any(t in add_split_at_tensors for t in retained_tensors):
-            mapping.add_sequential()
+        # if any(t in add_split_at_tensors for t in retained_tensors):
+        #     mapping.add_sequential()
 
         if return_retained_tensors:
             yield mapping, retained_tensors
@@ -115,8 +115,8 @@ def make_storage(
     for idx, tensors_at_idx in sorted(idx_to_tensors.items(),
                                       key=lambda pair: pair[0],
                                       reverse=True):
-        if any(t in add_split_at_tensors for t in tensors_at_idx):
-            mapping.add_sequential(idx)
+        # if any(t in add_split_at_tensors for t in tensors_at_idx):
+        #     mapping.add_sequential(idx)
         mapping.add_storage(level, tensors_at_idx, idx)
         # Check for any irrelevant loops above the backing storage for a tensor
         for t in tensors_at_idx:

pytimeloop/fastfusion/mapper/process_results.py

Lines changed: 2 additions & 2 deletions
@@ -92,7 +92,7 @@ def record_storage(node):
             tensor_id_to_name[dspace],
             len(full_tiling),
             node["target"],
-            result.occupancy[(node["target"], dspace)],
+            int(result.occupancy[(node["target"], dspace)]),
         )
         all_storages.append(storage)
         if storage.tensor_id in intermediates_to_find:
@@ -178,7 +178,7 @@ def record_loop(node):
     for r in all_storages:
         r: TensorStorage
         if r not in backing_storages:
-            key = nameloop2col(r.backer_id, r.above_loop_index)
+            key = nameloop2col(r.backer_id, min(r.above_loop_index, n_fused_loops))
             results.setdefault(key, 0)
             results[key] += r.tile_size
             # logstring.append(f"{r}")

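Note: the second hunk clamps the loop index used for the reservation column, so a non-backing storage that sits below the fused loop nest is charged to the deepest fused loop rather than to a column that exists only for that Einsum's private loops. A rough sketch of that aggregation; nameloop2col's string format and the data layout here are made up, only the min() clamp comes from the diff.

from collections import namedtuple

TensorStorage = namedtuple("TensorStorage", "tensor_id backer_id above_loop_index tile_size")

def nameloop2col(backer_id, loop_idx):
    return f"RESOURCE_{backer_id}_LEVEL_{loop_idx}"  # placeholder column name

def accumulate(storages, backing_storages, n_fused_loops):
    results = {}
    for r in storages:
        if r in backing_storages:
            continue
        # Clamp: anything nested deeper than the fused loops is charged to the
        # deepest fused loop's column.
        key = nameloop2col(r.backer_id, min(r.above_loop_index, n_fused_loops))
        results[key] = results.get(key, 0) + r.tile_size
    return results

storages = [TensorStorage("A", 1, 5, 128), TensorStorage("B", 1, 2, 64)]
print(accumulate(storages, set(), n_fused_loops=3))
# {'RESOURCE_1_LEVEL_3': 128, 'RESOURCE_1_LEVEL_2': 64}
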
pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 41 additions & 40 deletions
@@ -6,10 +6,9 @@
 
 import pandas as pd
 from joblib import delayed
-from tqdm import tqdm
 
 from pytimeloop.fastfusion.sim import SIM
-from pytimeloop.fastfusion.pareto import Pareto, check_correctness
+from pytimeloop.fastfusion.pareto import Pareto
 from pytimeloop.fastfusion.util import parallel, debugger_active
 
 
@@ -28,6 +27,8 @@ def mapping2sims(einsum_to_result: Mapping):
     for einsum_id, compat_dict in einsum_to_result.items():
         r[einsum_id] = [paretofy(k, v) for k, v in compat_dict.items()]
     return list(r.values())
+def paretofy(k, v):
+    return SIM(k, Pareto(pd.DataFrame(v).fillna(0)))
 
 
 prev_time = 0
@@ -84,19 +85,20 @@ def consolidate(
 
 
 def fuse_sims(
-    sims: list[SIM],
+    sims: dict[str, list[SIM]],
     resource2capacity: dict = None,
     return_nmappings_nbuckets: bool = False,
-    pre_filter: bool = True
 ):
     nmappings = []
     nbuckets = []
-    resource2capacity = resource2capacity or {}
-    sims = [s for s in sims]
 
-    for i, s in enumerate(sims):
-        print(f'SIM {i} tensors: {s[0].tensor_names}')
-
+    sims = list(sims.items())
+
+    for einsum_id, s in sims:
+        print(f'SIM {einsum_id} tensors: {s[0].tensor_names}')
+
+    # TODO: Lookahead by one SIM. If we're going to create a tiling that has loops
+    # that are not in the ranks of the next SIM, we should drop that tiling.
     # if pre_filter:
    #     for i in range(len(sims) - 1):
    #         left, right = sims[i], sims[i + 1]
@@ -109,29 +111,30 @@ def fuse_sims(
     #             print(f'Filtered {len(left)} -> {len(left2)} SIMs from Einsum {i}')
     #             print(f'Filtered {len(right)} -> {len(right2)} SIMs from Einsum {i + 1}')
 
-    left = sims.pop(0)
-
     init_print_time()
-
-    if not sims:
-        sims = copy.deepcopy(sims)
+    if len(sims) == 1:
+        left = copy.deepcopy(sims[0][1])
+        sims = []
         left = consolidate(
            x=left,
            left=True,
            live_tensors=set(),
            resource2capacity=resource2capacity,
            shared_tensors=set(),
        )
-
-    # TODO: Lookahead by one SIM. If we're going to create a tiling that has loops
-    # that are not in the ranks of the next SIM, we should drop that tiling.
 
+    n_iterations = 0
+    total_iterations = len(sims)
+    left_einsum, left = sims.pop(0)
     while sims:
+        n_iterations += 1
         nbuckets.append(len(left))
         nmappings.append(sum(len(s.mapping.data) for s in left))
 
-        right = sims.pop(0)
-        live_tensors = set.union(set(), *[s[0].tensor_names for s in sims if s])
+        right_einsum, right = sims.pop(0)
+        print(f'\nEinsum {right_einsum} ({n_iterations}/{total_iterations})')
+
+        live_tensors = set.union(set(), *[s[0].tensor_names for _, s in sims if s])
         shared_tensors = set(left[0].tensor_names) & set(right[0].tensor_names)
 
         right_tensors = right[0].tensor_names
@@ -144,23 +147,26 @@ def fuse_sims(
            shared_tensors=shared_tensors,
        )
 
-        left = SIM.combine_combineable(left, live_tensors | right_tensors)
-        right = SIM.combine_combineable(right, live_tensors | left_tensors)
-
-        print_time("Combining")
-
         left = sorted(left, key=lambda x: len(x.mapping.data), reverse=True)
         right = sorted(right, key=lambda x: len(x.mapping.data), reverse=True)
+        lr = parallel(
+            [delayed(lambda l: l.left_consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in left] +
+            [delayed(lambda l: l.consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in right],
+            pbar=f"Consolidating {left_einsum} <--> {right_einsum}",
+        )
+        left, right = lr[:len(left)], lr[len(left):]
+        print_time(f"Consolidating")
 
-        left = parallel([delayed(lambda l: l.left_consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in left], pbar="Left consolidate")
-        right = parallel([delayed(lambda l: l.consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in right], pbar="Right consolidate")
+        left = SIM.combine_combineable(left, live_tensors | right_tensors)
+        right = SIM.combine_combineable(right, live_tensors | left_tensors)
+        print_time(f"Combining")
 
-        print_time("Consolidating")
+        # left = parallel([delayed(lambda l: l.left_consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in left], pbar="Left consolidate")
+        # right = parallel([delayed(lambda l: l.consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in right], pbar="Right consolidate")
 
         # Group left and right into buckets
         right = SIM.group_right(right, left_tensors)
         left = SIM.group_left(left, right_tensors)
-
         print_time("Grouping")
 
         for v in list(left.values()) + list(right.values()):
@@ -196,28 +202,27 @@ def fuse_sims(
 
         if DELAY_MERGE:
             combined = sorted(combined, key=lambda x: x._predicted_mappings, reverse=True)
-            for c, mapping in zip(combined, parallel([c.mapping for c in combined], pbar='Merging mappings')):
+            for c, mapping in zip(combined, parallel([c.mapping for c in combined], pbar=f'Merging mappings {left_einsum} <--> {right_einsum}')):
                 c.mapping = mapping
 
         print_time("Mapping merging")
 
-        print(
-            f"\tCombining {sum(len(s) for s in left)}({len(left)}) x {sum(len(s) for s in right)}({len(right)}) -> {len(combined)}"
-        )
+        print(f"\tCombining {sum(len(s) for s in left)}({len(left)}) x {sum(len(s) for s in right)}({len(right)}) -> {len(combined)}")
         # if DO_PRINT:
        #     for k in right:
        #         if k not in left:
        #             for b in right[k]:
        #                 print(f"\tREVERSE: No match for {b.tiling}")
 
         left = combined
-        print(f"Number of buckets: {len(left)}")
+        left_einsum = right_einsum
+        print(f"\tNumber of buckets for Einsum {left_einsum}: {len(left)}")
         n_mappings = sum(len(s.mapping.data) for s in left)
-        print(f"Number of mappings: {n_mappings}")
-        print(f"Mappings per bucket: {n_mappings / len(left)}")
+        print(f"\tNumber of mappings for Einsum {left_einsum}: {n_mappings}")
+        print(f"\tMappings per bucket for Einsum {left_einsum}: {n_mappings / len(left)}")
 
     for s in left:
-        s.left_consolidate(set(), resource2capacity)
+        s.left_consolidate(None, resource2capacity)
     s_final = SIM.combine_combineable(left, set())[0]
     data = s_final.mapping.data
     # check_correctness(data, set())
@@ -227,7 +232,3 @@ def fuse_sims(
     if return_nmappings_nbuckets:
        return data, nmappings, nbuckets
     return data
-
-
-def paretofy(k, v):
-    return SIM(k, Pareto(pd.DataFrame(v).fillna(0)))

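Note: two things changed inside the fuse_sims loop: consolidation now runs before combine_combineable, and the left-side and right-side consolidation jobs are submitted as one parallel batch instead of two. The sketch below shows just the batching pattern with plain joblib (the repo uses its own parallel() wrapper; the worker functions are stand-ins for SIM.left_consolidate and SIM.consolidate). Splitting the flat result list at len(left) works because joblib returns results in submission order.

from joblib import Parallel, delayed

def left_consolidate(x: int) -> int:   # stand-in for SIM.left_consolidate
    return x * 10

def consolidate(x: int) -> int:        # stand-in for SIM.consolidate
    return x + 1

def consolidate_both(left: list[int], right: list[int]) -> tuple[list[int], list[int]]:
    # One batch for both sides, so a single worker pool is kept busy across
    # the combined set of tasks.
    tasks = ([delayed(left_consolidate)(l) for l in left] +
             [delayed(consolidate)(r) for r in right])
    out = Parallel(n_jobs=-1)(tasks)
    # Results come back in submission order, so the flat list splits cleanly.
    return out[:len(left)], out[len(left):]

if __name__ == "__main__":
    new_left, new_right = consolidate_both([1, 2, 3], [4, 5])
    print(new_left, new_right)  # [10, 20, 30] [5, 6]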