
Commit 14aa08f

Parallelization improvements
1 parent e39f20a commit 14aa08f

File tree

6 files changed: +334 -206 lines changed
Lines changed: 188 additions & 142 deletions
@@ -1,5 +1,6 @@
 from copy import deepcopy
 from collections import defaultdict
+import itertools
 
 from joblib import delayed
 
@@ -19,159 +20,204 @@
 from bindings.looptree import LooptreeWorkload, LooptreeWorkloadDependencyAnalyzer
 
 
-def per_einsum_mapper_snowcat(
+def _per_einsum_mapper_snowcat(
     config,
     spec,
     explore_glb_uneven,
-    einsums_to_explore,
+    einsum_id,
     energy_dict,
     ffmt=False,
     ffmt_refetch_weights=True,
     dataflow_constraint=None,
     metrics=Metrics.all_metrics(),
     tag_with: tuple[callable] = (),
 ):
-    data = {}
-    for einsum_id in einsums_to_explore:
-        workload = LooptreeWorkload.parse_cfg(config.root["problem"])
+    workload = LooptreeWorkload.parse_cfg(config.root["problem"])
+    analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
+    equivalent_groups = EquivalentGroups.from_workload(workload, analyzer)
+
+    einsum_id_to_name = workload.einsum_id_to_name()
+    rank_name_to_id = workload.dimension_name_to_id()
+    tensor_name_to_id = workload.data_space_name_to_id()
+
+    tensors = workload.tensors_read_by_einsum(einsum_id) \
+              | workload.tensors_written_by_einsum(einsum_id)
+    intermediate_tensors = tensors & get_intermediate_tensors(workload)
+    all_ranks = workload.einsum_ospace_dimensions(einsum_id)
+
+    bindings, max_fanout, max_capacity = get_hardware_levels(spec.architecture)
+
+    all_ranks = workload.einsum_ospace_dimensions(einsum_id)
+
+    tensor_to_relevant_ranks = {
+        tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
+        for tensor in tensors
+    }
+
+    einsum_shape = {
+        rank_id: workload.get_rank_shape(rank_id)[1] + 1 for rank_id in all_ranks
+    }
+
+
+    tensor_to_relevant_ranks = {
+        tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
+        for tensor in tensors
+    }
+
+    if not ffmt:
+        subspaces = make_subspaces(tensors,
+                                   intermediate_tensors,
+                                   tensor_to_relevant_ranks,
+                                   einsum_id,
+                                   workload,
+                                   dataflow_constraint[einsum_id])
+    else:
+        subspaces = make_ffmt_subspaces(tensors,
+                                        intermediate_tensors,
+                                        tensor_to_relevant_ranks,
+                                        einsum_id,
+                                        workload,
+                                        refetch_weights=ffmt_refetch_weights)
+
+    n_jobs=32
+    parallelized_spaces, task_spaces = \
+        split_dependent_product(n_split_min=n_jobs, spaces=subspaces)
+
+    partial_mappings = list(dependent_product(parallelized_spaces))
+    partial_mappings = [x if isinstance(x, tuple) else (x,) for x in partial_mappings]
+    rank_id_to_name = {v: k for k, v in rank_name_to_id.items()}
+    tensor_id_to_name = {v: k for k, v in tensor_name_to_id.items()}
+    input_tensors = set(tensor_id_to_name[t] for t in workload.tensors_read_by_einsum(einsum_id))
+    output_tensors = set(tensor_id_to_name[t] for t in workload.tensors_written_by_einsum(einsum_id))
+    rank_name_to_shared_name = {
+        rank_id_to_name[k]: rank_id_to_name[v] for k, v in equivalent_groups.rank_to_group_id.items()
+    }
+
+    # successful_partial_mappings = []
+    # for p in partial_mappings:
+    #     partial_mapping = p[0]
+    #     found_storages = set()
+    #     fail = False
+    #     for i, p in enumerate(partial_mapping):
+    #         if p["type"] == "storage":
+    #             for t in set(p["dspace"]) - found_storages:
+    #                 for p2 in partial_mapping[:i]:
+    #                     if p2["type"] in ["temporal", "spatial"] and p2["rank"] not in tensor_to_relevant_ranks[t]:
+    #                         fail = True
+    #                         break
+    #             found_storages |= set(p["dspace"])
+    #         if len(found_storages) < len(tensors) or i == 0:
+    #             continue
+    #         prev = partial_mapping[i - 1]
+    #         for t in ["spatial"]: # "temporal", TEMPORAL DOESN"T WORK. WEIRD INTERACTIONS WITH LOOP RELEVANCE PRINCIPLEz
+    #     if not fail:
+    #         successful_partial_mappings.append(p)
+    # partial_mappings = successful_partial_mappings
+
+    def per_worker_exploration(*args):
         analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
-        equivalent_groups = EquivalentGroups.from_workload(workload, analyzer)
-
-        einsum_id_to_name = workload.einsum_id_to_name()
-        rank_name_to_id = workload.dimension_name_to_id()
-        tensor_name_to_id = workload.data_space_name_to_id()
-
-        tensors = workload.tensors_read_by_einsum(einsum_id) \
-                  | workload.tensors_written_by_einsum(einsum_id)
-        intermediate_tensors = tensors & get_intermediate_tensors(workload)
-        all_ranks = workload.einsum_ospace_dimensions(einsum_id)
-
-        bindings, max_fanout, max_capacity = get_hardware_levels(spec.architecture)
-
-        all_ranks = workload.einsum_ospace_dimensions(einsum_id)
-
-        tensor_to_relevant_ranks = {
-            tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
-            for tensor in tensors
-        }
-
-        einsum_shape = {
-            rank_id: workload.get_rank_shape(rank_id)[1] + 1 for rank_id in all_ranks
-        }
-
-
-        tensor_to_relevant_ranks = {
-            tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
-            for tensor in tensors
-        }
-
-        if not ffmt:
-            subspaces = make_subspaces(tensors,
-                                       intermediate_tensors,
-                                       tensor_to_relevant_ranks,
-                                       einsum_id,
-                                       workload,
-                                       dataflow_constraint[einsum_id])
-        else:
-            subspaces = make_ffmt_subspaces(tensors,
-                                            intermediate_tensors,
-                                            tensor_to_relevant_ranks,
-                                            einsum_id,
-                                            workload,
-                                            refetch_weights=ffmt_refetch_weights)
-
-        n_jobs=32
-        parallelized_spaces, task_spaces = \
-            split_dependent_product(n_split_min=n_jobs, spaces=subspaces)
-
-        partial_mappings = list(dependent_product(parallelized_spaces))
-        partial_mappings = [x if isinstance(x, tuple) else (x,) for x in partial_mappings]
-        rank_id_to_name = {v: k for k, v in rank_name_to_id.items()}
-        tensor_id_to_name = {v: k for k, v in tensor_name_to_id.items()}
-        input_tensors = set(tensor_id_to_name[t] for t in workload.tensors_read_by_einsum(einsum_id))
-        output_tensors = set(tensor_id_to_name[t] for t in workload.tensors_written_by_einsum(einsum_id))
-        rank_name_to_shared_name = {
-            rank_id_to_name[k]: rank_id_to_name[v] for k, v in equivalent_groups.rank_to_group_id.items()
-        }
-
-        # successful_partial_mappings = []
-        # for p in partial_mappings:
-        #     partial_mapping = p[0]
-        #     found_storages = set()
-        #     fail = False
-        #     for i, p in enumerate(partial_mapping):
-        #         if p["type"] == "storage":
-        #             for t in set(p["dspace"]) - found_storages:
-        #                 for p2 in partial_mapping[:i]:
-        #                     if p2["type"] in ["temporal", "spatial"] and p2["rank"] not in tensor_to_relevant_ranks[t]:
-        #                         fail = True
-        #                         break
-        #             found_storages |= set(p["dspace"])
-        #         if len(found_storages) < len(tensors) or i == 0:
-        #             continue
-        #         prev = partial_mapping[i - 1]
-        #         for t in ["spatial"]: # "temporal", TEMPORAL DOESN"T WORK. WEIRD INTERACTIONS WITH LOOP RELEVANCE PRINCIPLE
-        #             if p["type"] == t and prev["type"] == t and p["rank"] < prev["rank"]:
-        #                 fail = True
-        #     if not fail:
-        #         successful_partial_mappings.append(p)
-        # partial_mappings = successful_partial_mappings
-
-        def per_worker_exploration(*args):
-            analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
-            local_task_spaces = deepcopy(task_spaces)
-            local_task_spaces[0] = lambda : task_spaces[0](*args)
-            result = defaultdict(list)
-            for partial_mapping in dependent_product(local_task_spaces):
-                _, compiled_results = compile_mapping(
-                    partial_mapping, workload, analyzer
-                )
-                tile_shape_explorer = explore_tile_shape(
+        local_task_spaces = deepcopy(task_spaces)
+        local_task_spaces[0] = lambda : task_spaces[0](*args)
+        result = defaultdict(list)
+        for partial_mapping in dependent_product(local_task_spaces):
+            _, compiled_results = compile_mapping(
+                partial_mapping, workload, analyzer
+            )
+            tile_shape_explorer = explore_tile_shape(
+                partial_mapping,
+                einsum_shape,
+                compiled_results,
+                max_capacity,
+                max_fanout,
+                tensors=tensors,
+            )
+            # HACKY: Pop out the subspace object as the first in the iterator
+            shape_subspace = next(tile_shape_explorer)
+
+            count = 0
+            for shape, res in tile_shape_explorer:
+                count += 1
+                is_pareto, results, fulltiling = process_result(
+                    res,
+                    shape,
+                    result,
+                    einsum_id,
+                    intermediate_tensors,
                     partial_mapping,
-                    einsum_shape,
-                    compiled_results,
-                    max_capacity,
-                    max_fanout,
-                    tensors=tensors,
+                    bindings,
+                    workload,
+                    energy_dict,
+                    equivalent_groups,
+                    explore_fusion_uneven=explore_glb_uneven,
+                    einsum_shape=einsum_shape,
+                    metrics=metrics,
+                    einsum_id_to_name=einsum_id_to_name,
+                    rank_id_to_name=rank_id_to_name,
+                    tensor_id_to_name=tensor_id_to_name,
+                    rank_name_to_shared_name=rank_name_to_shared_name,
+                    input_tensors=input_tensors,
+                    output_tensors=output_tensors,
+                    tag_with=tag_with,
                 )
-                # HACKY: Pop out the subspace object as the first in the iterator
-                shape_subspace = next(tile_shape_explorer)
-
-                count = 0
-                for shape, res in tile_shape_explorer:
-                    count += 1
-                    is_pareto, results, fulltiling = process_result(
-                        res,
-                        shape,
-                        result,
-                        einsum_id,
-                        intermediate_tensors,
-                        partial_mapping,
-                        bindings,
-                        workload,
-                        energy_dict,
-                        equivalent_groups,
-                        explore_fusion_uneven=explore_glb_uneven,
-                        einsum_shape=einsum_shape,
-                        metrics=metrics,
-                        einsum_id_to_name=einsum_id_to_name,
-                        rank_id_to_name=rank_id_to_name,
-                        tensor_id_to_name=tensor_id_to_name,
-                        rank_name_to_shared_name=rank_name_to_shared_name,
-                        input_tensors=input_tensors,
-                        output_tensors=output_tensors,
-                        tag_with=tag_with,
-                    )
-            return result
-
-        # for pm in partial_mappings:
-        #     per_worker_exploration(*pm)
-        results = parallel(delayed(per_worker_exploration)(*pm) for pm in partial_mappings)
-
-        data[einsum_id] = defaultdict(list)
-        for res in results:
-            for k, v in res.items():
-                data[einsum_id][k[0]] += v
-
-    return data
+        return einsum_id, result
+
+
+    # # for pm in partial_mappings:
+    # #     per_worker_exploration(*pm)
+    # data[einsum_id] = defaultdict(list)
+    # for res in parallel(
+    #     [delayed(per_worker_exploration)(*pm) for pm in partial_mappings],
+    #     return_as_generator=True,
+    #     pbar=f"Generating data for Einsum {einsum_id}. {i+1}/{len(einsums_to_explore)}",
+    # ):
+    #     for k, v in res.items():
+    #         data[einsum_id][k[0]] += v
 
+    return [delayed(per_worker_exploration)(*pm) for pm in partial_mappings]
+
+def per_einsum_mapper_snowcat(
+    config,
+    spec,
+    explore_glb_uneven,
+    einsums_to_explore,
+    energy_dict,
+    ffmt=False,
+    ffmt_refetch_weights=True,
+    dataflow_constraint=None,
+    metrics=Metrics.all_metrics(),
+    tag_with: tuple[callable] = (),
+):
+    # return _per_einsum_mapper_snowcat(
+    #     config,
+    #     spec,
+    #     explore_glb_uneven,
+    #     einsums_to_explore,
+    #     energy_dict,
+    #     ffmt=ffmt,
+    #     ffmt_refetch_weights=ffmt_refetch_weights,
+    #     dataflow_constraint=dataflow_constraint,
+    #     metrics=metrics,
+    #     tag_with=tag_with,
+    # )
+
+    jobs = list(j for einsum_id in einsums_to_explore for j in _per_einsum_mapper_snowcat(
+        config,
+        spec,
+        explore_glb_uneven,
+        einsum_id,
+        energy_dict,
+        ffmt=ffmt,
+        ffmt_refetch_weights=ffmt_refetch_weights,
+        dataflow_constraint=dataflow_constraint,
+        metrics=metrics,
+        tag_with=tag_with,
+        )
+    )
+    data = {einsum_id: defaultdict(list) for einsum_id in einsums_to_explore}
+
+    for einsum_id, result in parallel(jobs, return_as="generator_unordered", pbar="Generating data for Einsums"):
+        d = data[einsum_id]
+        for k, v in result.items():
+            d[k[0]] += v
+
+    return data
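
For reference, a minimal sketch of the scheduling pattern this commit moves to, written against plain joblib. The repository's parallel(...) helper (with its pbar argument) is assumed to be a thin wrapper around joblib.Parallel; the names make_jobs, explore, and run_all below are hypothetical illustrations, not the repository's API. Each per-einsum factory returns a flat list of delayed jobs tagged with the einsum they belong to, the caller concatenates the lists across einsums so one worker pool is shared, and results are merged as they arrive from an unordered generator.

from collections import defaultdict

from joblib import Parallel, delayed


def make_jobs(einsum_id, n_partial_mappings=4):
    # Hypothetical stand-in for _per_einsum_mapper_snowcat: build one delayed
    # job per partial mapping; each job returns (einsum_id, result_dict).
    def explore(partial_mapping):
        return einsum_id, {(einsum_id, partial_mapping): [partial_mapping]}

    return [delayed(explore)(pm) for pm in range(n_partial_mappings)]


def run_all(einsums_to_explore):
    # Flatten per-einsum jobs into one task list served by a single worker pool.
    jobs = [j for e in einsums_to_explore for j in make_jobs(e)]
    data = {e: defaultdict(list) for e in einsums_to_explore}

    # return_as="generator_unordered" (joblib >= 1.4) yields each result as soon
    # as any worker finishes, regardless of submission order.
    for einsum_id, result in Parallel(n_jobs=-1, return_as="generator_unordered")(jobs):
        for k, v in result.items():
            data[einsum_id][k[0]] += v
    return data


if __name__ == "__main__":
    print(run_all([0, 1, 2]))

The design point mirrors the diff: instead of opening a separate parallel region per einsum, all work items are flattened into one task list, which keeps workers busy even when an individual einsum has only a few partial mappings.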

0 commit comments
