from copy import deepcopy
from collections import defaultdict
+ import itertools

from joblib import delayed

@@ ... @@
from bindings.looptree import LooptreeWorkload, LooptreeWorkloadDependencyAnalyzer

- def per_einsum_mapper_snowcat(
+ def _per_einsum_mapper_snowcat(
    config,
    spec,
    explore_glb_uneven,
-     einsums_to_explore,
+     einsum_id,
    energy_dict,
    ffmt=False,
    ffmt_refetch_weights=True,
    dataflow_constraint=None,
    metrics=Metrics.all_metrics(),
    tag_with: tuple[callable] = (),
):
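+     """Build the mapspace-exploration jobs for a single Einsum.
+
+     Returns a list of joblib ``delayed`` tasks. Each task explores one
+     partition of the mapping space and returns an ``(einsum_id, result)``
+     pair, where ``result`` is a dict of lists populated by ``process_result``.
+     """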
-     data = {}
-     for einsum_id in einsums_to_explore:
-         workload = LooptreeWorkload.parse_cfg(config.root["problem"])
+     workload = LooptreeWorkload.parse_cfg(config.root["problem"])
+     analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
+     equivalent_groups = EquivalentGroups.from_workload(workload, analyzer)
+
+     einsum_id_to_name = workload.einsum_id_to_name()
+     rank_name_to_id = workload.dimension_name_to_id()
+     tensor_name_to_id = workload.data_space_name_to_id()
+
+     tensors = workload.tensors_read_by_einsum(einsum_id) \
+         | workload.tensors_written_by_einsum(einsum_id)
+     intermediate_tensors = tensors & get_intermediate_tensors(workload)
+     all_ranks = workload.einsum_ospace_dimensions(einsum_id)
+
+     bindings, max_fanout, max_capacity = get_hardware_levels(spec.architecture)
+
+     all_ranks = workload.einsum_ospace_dimensions(einsum_id)
+
+     tensor_to_relevant_ranks = {
+         tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
+         for tensor in tensors
+     }
+
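+     # Rank extents for this Einsum's operation space; get_rank_shape
+     # presumably returns inclusive (low, high) bounds, hence the +1.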
+     einsum_shape = {
+         rank_id: workload.get_rank_shape(rank_id)[1] + 1 for rank_id in all_ranks
+     }
+
+     tensor_to_relevant_ranks = {
+         tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
+         for tensor in tensors
+     }
+
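+     # Build the per-Einsum mapping subspaces; FFMT mode uses a dedicated
+     # generator and honors the weight-refetch flag.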
+     if not ffmt:
+         subspaces = make_subspaces(tensors,
+                                    intermediate_tensors,
+                                    tensor_to_relevant_ranks,
+                                    einsum_id,
+                                    workload,
+                                    dataflow_constraint[einsum_id])
+     else:
+         subspaces = make_ffmt_subspaces(tensors,
+                                         intermediate_tensors,
+                                         tensor_to_relevant_ranks,
+                                         einsum_id,
+                                         workload,
+                                         refetch_weights=ffmt_refetch_weights)
+
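+     # Split the dependent product of subspaces into an outer product that is
+     # enumerated here (one partial mapping per parallel task, at least
+     # n_jobs of them) and inner task spaces each worker sweeps locally.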
+     n_jobs = 32
+     parallelized_spaces, task_spaces = \
+         split_dependent_product(n_split_min=n_jobs, spaces=subspaces)
+
+     partial_mappings = list(dependent_product(parallelized_spaces))
+     partial_mappings = [x if isinstance(x, tuple) else (x,) for x in partial_mappings]
+     rank_id_to_name = {v: k for k, v in rank_name_to_id.items()}
+     tensor_id_to_name = {v: k for k, v in tensor_name_to_id.items()}
+     input_tensors = set(tensor_id_to_name[t] for t in workload.tensors_read_by_einsum(einsum_id))
+     output_tensors = set(tensor_id_to_name[t] for t in workload.tensors_written_by_einsum(einsum_id))
+     rank_name_to_shared_name = {
+         rank_id_to_name[k]: rank_id_to_name[v] for k, v in equivalent_groups.rank_to_group_id.items()
+     }
+
+     # successful_partial_mappings = []
+     # for p in partial_mappings:
+     #     partial_mapping = p[0]
+     #     found_storages = set()
+     #     fail = False
+     #     for i, p in enumerate(partial_mapping):
+     #         if p["type"] == "storage":
+     #             for t in set(p["dspace"]) - found_storages:
+     #                 for p2 in partial_mapping[:i]:
+     #                     if p2["type"] in ["temporal", "spatial"] and p2["rank"] not in tensor_to_relevant_ranks[t]:
+     #                         fail = True
+     #                         break
+     #             found_storages |= set(p["dspace"])
+     #         if len(found_storages) < len(tensors) or i == 0:
+     #             continue
+     #         prev = partial_mapping[i - 1]
+     #         for t in ["spatial"]:  # "temporal" DOESN'T WORK: WEIRD INTERACTIONS WITH LOOP RELEVANCE PRINCIPLE
+     #     if not fail:
+     #         successful_partial_mappings.append(p)
+     # partial_mappings = successful_partial_mappings
+
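+     # Each worker rebuilds its own analyzer (presumably because the binding
+     # does not pickle across processes), pins the first task space to the
+     # outer-product choice passed in via *args, then sweeps tile shapes.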
+     def per_worker_exploration(*args):
        analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
-         equivalent_groups = EquivalentGroups.from_workload(workload, analyzer)
-
-         einsum_id_to_name = workload.einsum_id_to_name()
-         rank_name_to_id = workload.dimension_name_to_id()
-         tensor_name_to_id = workload.data_space_name_to_id()
-
-         tensors = workload.tensors_read_by_einsum(einsum_id) \
-             | workload.tensors_written_by_einsum(einsum_id)
-         intermediate_tensors = tensors & get_intermediate_tensors(workload)
-         all_ranks = workload.einsum_ospace_dimensions(einsum_id)
-
-         bindings, max_fanout, max_capacity = get_hardware_levels(spec.architecture)
-
-         all_ranks = workload.einsum_ospace_dimensions(einsum_id)
-
-         tensor_to_relevant_ranks = {
-             tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
-             for tensor in tensors
-         }
-
-         einsum_shape = {
-             rank_id: workload.get_rank_shape(rank_id)[1] + 1 for rank_id in all_ranks
-         }
-
-         tensor_to_relevant_ranks = {
-             tensor: analyzer.einsum_dims_relevant_to_tensor(einsum_id, tensor)
-             for tensor in tensors
-         }
-
-         if not ffmt:
-             subspaces = make_subspaces(tensors,
-                                        intermediate_tensors,
-                                        tensor_to_relevant_ranks,
-                                        einsum_id,
-                                        workload,
-                                        dataflow_constraint[einsum_id])
-         else:
-             subspaces = make_ffmt_subspaces(tensors,
-                                             intermediate_tensors,
-                                             tensor_to_relevant_ranks,
-                                             einsum_id,
-                                             workload,
-                                             refetch_weights=ffmt_refetch_weights)
-
-         n_jobs = 32
-         parallelized_spaces, task_spaces = \
-             split_dependent_product(n_split_min=n_jobs, spaces=subspaces)
-
-         partial_mappings = list(dependent_product(parallelized_spaces))
-         partial_mappings = [x if isinstance(x, tuple) else (x,) for x in partial_mappings]
-         rank_id_to_name = {v: k for k, v in rank_name_to_id.items()}
-         tensor_id_to_name = {v: k for k, v in tensor_name_to_id.items()}
-         input_tensors = set(tensor_id_to_name[t] for t in workload.tensors_read_by_einsum(einsum_id))
-         output_tensors = set(tensor_id_to_name[t] for t in workload.tensors_written_by_einsum(einsum_id))
-         rank_name_to_shared_name = {
-             rank_id_to_name[k]: rank_id_to_name[v] for k, v in equivalent_groups.rank_to_group_id.items()
-         }
-
-         # successful_partial_mappings = []
-         # for p in partial_mappings:
-         #     partial_mapping = p[0]
-         #     found_storages = set()
-         #     fail = False
-         #     for i, p in enumerate(partial_mapping):
-         #         if p["type"] == "storage":
-         #             for t in set(p["dspace"]) - found_storages:
-         #                 for p2 in partial_mapping[:i]:
-         #                     if p2["type"] in ["temporal", "spatial"] and p2["rank"] not in tensor_to_relevant_ranks[t]:
-         #                         fail = True
-         #                         break
-         #             found_storages |= set(p["dspace"])
-         #         if len(found_storages) < len(tensors) or i == 0:
-         #             continue
-         #         prev = partial_mapping[i - 1]
-         #         for t in ["spatial"]:  # "temporal" DOESN'T WORK: WEIRD INTERACTIONS WITH LOOP RELEVANCE PRINCIPLE
-         #             if p["type"] == t and prev["type"] == t and p["rank"] < prev["rank"]:
-         #                 fail = True
-         #     if not fail:
-         #         successful_partial_mappings.append(p)
-         # partial_mappings = successful_partial_mappings
-
-         def per_worker_exploration(*args):
-             analyzer = LooptreeWorkloadDependencyAnalyzer(workload)
-             local_task_spaces = deepcopy(task_spaces)
-             local_task_spaces[0] = lambda: task_spaces[0](*args)
-             result = defaultdict(list)
-             for partial_mapping in dependent_product(local_task_spaces):
-                 _, compiled_results = compile_mapping(
-                     partial_mapping, workload, analyzer
-                 )
-                 tile_shape_explorer = explore_tile_shape(
+         local_task_spaces = deepcopy(task_spaces)
+         local_task_spaces[0] = lambda: task_spaces[0](*args)
+         result = defaultdict(list)
+         for partial_mapping in dependent_product(local_task_spaces):
+             _, compiled_results = compile_mapping(
+                 partial_mapping, workload, analyzer
+             )
+             tile_shape_explorer = explore_tile_shape(
+                 partial_mapping,
+                 einsum_shape,
+                 compiled_results,
+                 max_capacity,
+                 max_fanout,
+                 tensors=tensors,
+             )
+             # HACKY: Pop out the subspace object as the first in the iterator
+             shape_subspace = next(tile_shape_explorer)
+
+             count = 0
+             for shape, res in tile_shape_explorer:
+                 count += 1
+                 is_pareto, results, fulltiling = process_result(
+                     res,
+                     shape,
+                     result,
+                     einsum_id,
+                     intermediate_tensors,
                    partial_mapping,
-                     einsum_shape,
-                     compiled_results,
-                     max_capacity,
-                     max_fanout,
-                     tensors=tensors,
+                     bindings,
+                     workload,
+                     energy_dict,
+                     equivalent_groups,
+                     explore_fusion_uneven=explore_glb_uneven,
+                     einsum_shape=einsum_shape,
+                     metrics=metrics,
+                     einsum_id_to_name=einsum_id_to_name,
+                     rank_id_to_name=rank_id_to_name,
+                     tensor_id_to_name=tensor_id_to_name,
+                     rank_name_to_shared_name=rank_name_to_shared_name,
+                     input_tensors=input_tensors,
+                     output_tensors=output_tensors,
+                     tag_with=tag_with,
                )
-                 # HACKY: Pop out the subspace object as the first in the iterator
-                 shape_subspace = next(tile_shape_explorer)
-
-                 count = 0
-                 for shape, res in tile_shape_explorer:
-                     count += 1
-                     is_pareto, results, fulltiling = process_result(
-                         res,
-                         shape,
-                         result,
-                         einsum_id,
-                         intermediate_tensors,
-                         partial_mapping,
-                         bindings,
-                         workload,
-                         energy_dict,
-                         equivalent_groups,
-                         explore_fusion_uneven=explore_glb_uneven,
-                         einsum_shape=einsum_shape,
-                         metrics=metrics,
-                         einsum_id_to_name=einsum_id_to_name,
-                         rank_id_to_name=rank_id_to_name,
-                         tensor_id_to_name=tensor_id_to_name,
-                         rank_name_to_shared_name=rank_name_to_shared_name,
-                         input_tensors=input_tensors,
-                         output_tensors=output_tensors,
-                         tag_with=tag_with,
-                     )
-             return result
-
-         # for pm in partial_mappings:
-         #     per_worker_exploration(*pm)
-         results = parallel(delayed(per_worker_exploration)(*pm) for pm in partial_mappings)
-
-         data[einsum_id] = defaultdict(list)
-         for res in results:
-             for k, v in res.items():
-                 data[einsum_id][k[0]] += v
-
-     return data
+         return einsum_id, result
+
+
+     # # for pm in partial_mappings:
+     # #     per_worker_exploration(*pm)
+     # data[einsum_id] = defaultdict(list)
+     # for res in parallel(
+     #     [delayed(per_worker_exploration)(*pm) for pm in partial_mappings],
+     #     return_as_generator=True,
+     #     pbar=f"Generating data for Einsum {einsum_id}. {i+1}/{len(einsums_to_explore)}",
+     # ):
+     #     for k, v in res.items():
+     #         data[einsum_id][k[0]] += v
+
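+     # Defer execution: hand the delayed tasks back to the caller so one
+     # shared pool can schedule workers across all Einsums.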
+     return [delayed(per_worker_exploration)(*pm) for pm in partial_mappings]
+
+ def per_einsum_mapper_snowcat(
+     config,
+     spec,
+     explore_glb_uneven,
+     einsums_to_explore,
+     energy_dict,
+     ffmt=False,
+     ffmt_refetch_weights=True,
+     dataflow_constraint=None,
+     metrics=Metrics.all_metrics(),
+     tag_with: tuple[callable] = (),
+ ):
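+     """Explore mappings for every Einsum in ``einsums_to_explore``.
+
+     Collects the delayed tasks from ``_per_einsum_mapper_snowcat`` for all
+     Einsums, runs them in a single parallel pool, and merges each returned
+     ``(einsum_id, result)`` pair into a per-Einsum dict of lists.
+     """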
+     # return _per_einsum_mapper_snowcat(
+     #     config,
+     #     spec,
+     #     explore_glb_uneven,
+     #     einsums_to_explore,
+     #     energy_dict,
+     #     ffmt=ffmt,
+     #     ffmt_refetch_weights=ffmt_refetch_weights,
+     #     dataflow_constraint=dataflow_constraint,
+     #     metrics=metrics,
+     #     tag_with=tag_with,
+     # )
+
+     jobs = list(j for einsum_id in einsums_to_explore for j in _per_einsum_mapper_snowcat(
+         config,
+         spec,
+         explore_glb_uneven,
+         einsum_id,
+         energy_dict,
+         ffmt=ffmt,
+         ffmt_refetch_weights=ffmt_refetch_weights,
+         dataflow_constraint=dataflow_constraint,
+         metrics=metrics,
+         tag_with=tag_with,
+     ))
+     data = {einsum_id: defaultdict(list) for einsum_id in einsums_to_explore}
+
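+     # Stream results as workers finish; "generator_unordered" avoids
+     # blocking fast Einsums behind slow ones.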
+     for einsum_id, result in parallel(jobs, return_as="generator_unordered", pbar="Generating data for Einsums"):
+         d = data[einsum_id]
+         for k, v in result.items():
+             d[k[0]] += v
+
+     return data
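+
+ # Hypothetical usage sketch (the config/spec/energy_dict objects are
+ # illustrative, not part of this commit; `parallel` is the repo's
+ # joblib wrapper):
+ #
+ #     data = per_einsum_mapper_snowcat(
+ #         config, spec, explore_glb_uneven=True,
+ #         einsums_to_explore=einsum_ids, energy_dict=energy_dict,
+ #     )
+ #     mappings = data[einsum_ids[0]]  # defaultdict(list) of mapping records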