Skip to content

Commit 8e8aa56

Browse files
Added debug metrics
1 parent 14aa08f commit 8e8aa56

File tree

6 files changed

+73
-67
lines changed

6 files changed

+73
-67
lines changed

pytimeloop/fastfusion/mapper/mapper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytimeloop.fastfusion.mapper.logging import make_queue_and_listener
2222
from pytimeloop.fastfusion.mapper.per_einsum_mapper import get_top_loop_jobs, mapper_place_fusion_level
2323
from pytimeloop.fastfusion.sim import Tiling, Loop, TensorStorage
24-
from pytimeloop.fastfusion.pareto import LOGSTRING, MAPPING, STATS, DICT_COLUMNS
24+
from pytimeloop.fastfusion.pareto import LOGSTRING, MAPPING, STATS, DICT_COLUMNS, TENSORS
2525
from pytimeloop.fastfusion.mapper.process_results import Metrics
2626

2727
from pytimeloop.timeloopfe.v4 import Ert
@@ -153,8 +153,12 @@ def _convert_stats(from_einsum: int, to_einsum: int, stats, rank_renaming, tenso
153153
stats = deepcopy(stats)
154154
for s in stats:
155155
for d in DICT_COLUMNS:
156-
s[d][to_einsum] = s[d].pop(from_einsum)
157-
s[MAPPING][to_einsum] = s[MAPPING][to_einsum].rename(rank_renaming, tensor_renaming)
156+
if d in s:
157+
s[d][to_einsum] = s[d].pop(from_einsum)
158+
if MAPPING in s:
159+
s[MAPPING][to_einsum] = s[MAPPING][to_einsum].rename(rank_renaming, tensor_renaming)
160+
if TENSORS in s:
161+
s[TENSORS][to_einsum] = [t.rename(rank_renaming, tensor_renaming) for t in s[TENSORS][to_einsum]]
158162
return stats
159163

160164

pytimeloop/fastfusion/mapper/mapper_snowcat.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,12 @@ def _convert_stats(from_einsum: int, to_einsum: int, stats, rank_renaming, tenso
138138
stats = deepcopy(stats)
139139
for s in stats:
140140
for d in DICT_COLUMNS:
141-
s[d][to_einsum] = s[d].pop(from_einsum)
142-
s[MAPPING][to_einsum] = s[MAPPING][to_einsum].rename(rank_renaming, tensor_renaming)
143-
s[TENSORS][to_einsum] = [t.rename(rank_renaming, tensor_renaming) for t in s[TENSORS][to_einsum]]
141+
if d in s:
142+
s[d][to_einsum] = s[d].pop(from_einsum)
143+
if MAPPING in s:
144+
s[MAPPING][to_einsum] = s[MAPPING][to_einsum].rename(rank_renaming, tensor_renaming)
145+
if TENSORS in s:
146+
s[TENSORS][to_einsum] = [t.rename(rank_renaming, tensor_renaming) for t in s[TENSORS][to_einsum]]
144147
return stats
145148

146149

pytimeloop/fastfusion/mapper/process_results.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Metrics(Flag):
3030
# OCCUPANCY = auto()
3131
OFF_CHIP_ACCESSES = auto()
3232
OP_INTENSITY = auto()
33+
DEBUG = auto()
3334

3435
@classmethod
3536
def all_metrics(cls):
@@ -99,8 +100,9 @@ def record_storage(node):
99100
if storage.tensor_id not in found_tensors:
100101
found_tensors.add(storage.tensor_id)
101102
backing_storages.append(storage)
102-
103-
logstring.append(f"Strg({node['dspace']} in {node['target']})")
103+
104+
if Metrics.DEBUG in metrics:
105+
logstring.append(f"Strg({node['dspace']} in {node['target']})")
104106

105107
def record_loop(node):
106108
nonlocal cur_idx
@@ -118,7 +120,8 @@ def record_loop(node):
118120
)
119121
ranks_remaining[node["rank"]] = tile_shape
120122
full_tiling.append(loop)
121-
logstring.append(f"{node['type'][0].upper()}{node['rank']} size {tile_shape}")
123+
if Metrics.DEBUG in metrics:
124+
logstring.append(f"{node['type'][0].upper()}{node['rank']} size {tile_shape}")
122125

123126
logstring = []
124127
full_tiling = []
@@ -157,16 +160,17 @@ def record_loop(node):
157160
if Metrics.ENERGY in metrics:
158161
results["Energy"] = energy
159162

160-
offchip_ac = 0
161-
for (level, tensor, einsum), count in accesses.items():
162-
if level == 0:
163-
offchip_ac += count
164-
logstring.append(f"Ac_{level}_{tensor}={count:.2e}")
165-
166163
if Metrics.OFF_CHIP_ACCESSES in metrics:
164+
offchip_ac = 0
165+
for (level, tensor, einsum), count in accesses.items():
166+
if level == 0:
167+
offchip_ac += count
167168
results["Offchip Accesses"] = offchip_ac
168-
169-
logstring.append(f"{result.fanout}")
169+
if Metrics.DEBUG in metrics:
170+
logstring.append(f"Ac_{level}_{tensor}={count:.2e}")
171+
172+
if Metrics.DEBUG in metrics:
173+
logstring.append(f"{result.fanout}")
170174

171175
# Only record non-backing reservations. We'll reserve backing storage later
172176
# when we free the tensors & we know all operations for which the tensor must
@@ -179,26 +183,25 @@ def record_loop(node):
179183
results[key] += r.tile_size
180184
# logstring.append(f"{r}")
181185

182-
if Metrics.LATENCY in metrics:
186+
if Metrics.LATENCY in metrics and Metrics.DEBUG in metrics:
183187
logstring.append(f"L={results['Latency']:.2e}")
184188

185-
if Metrics.ENERGY in metrics:
189+
if Metrics.ENERGY in metrics and Metrics.DEBUG in metrics:
186190
logstring.append(f"E={results['Energy']:.2e}")
187191

188192
if Metrics.OP_INTENSITY in metrics:
189193
results["Op_Intensity"] = result.op_intensity[1]
190194

191-
logstring.append(f"Results: {results}")
192-
results[LOGSTRING] = {einsum_id: str(logstring)}
193-
results[MAPPING] = {einsum_id: tiling_full}
194-
results[TENSORS] = {einsum_id: backing_storages}
195-
results[STATS] = {
196-
einsum_id: {k: v for k, v in results.items() if k not in RESERVED_COLUMNS}
197-
}
198-
results[IN_PROGRESS_STATS] = {einsum_id: {}}
199-
results[MAPPING_HASH] = {einsum_id: hash((einsum_id, tiling_compatibility))}
200-
results[TAGS] = {einsum_id: tiling_compatibility.tags}
195+
if metrics.DEBUG in metrics:
196+
logstring.append(f"Results: {results}")
197+
results[LOGSTRING] = {einsum_id: str(logstring)}
198+
results[STATS] = {einsum_id: {k: v for k, v in results.items() if k not in RESERVED_COLUMNS}}
199+
results[TAGS] = {einsum_id: tiling_compatibility.tags}
200+
results[MAPPING_HASH] = {einsum_id: hash((einsum_id, tiling_compatibility))}
201+
results[IN_PROGRESS_STATS] = {einsum_id: {k: v for k, v in results.items() if k not in RESERVED_COLUMNS}}
202+
results[TENSORS] = {einsum_id: backing_storages}
201203

204+
results[MAPPING] = {einsum_id: tiling_full}
202205
key = (tiling_compatibility, fzs(results.keys()))
203206

204207
is_pareto = True

pytimeloop/fastfusion/pareto.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,12 @@ def draw_looptree(row: pd.DataFrame, live_tensors: set[int]):
218218

219219
def check_correctness(data: pd.DataFrame, live_tensors: set[int]):
220220
from pytimeloop.fastfusion.plot.looptree import tilings2looptree
221+
from pytimeloop.fastfusion.sim import TensorStorage
221222

222223
def fail(index):
223224
draw_looptree(data.iloc[index], live_tensors)
224-
all_tensors = set(t for tn in r[TENSORS].values() for t in tn)
225+
all_tensors = set(t for tn in r[MAPPING].values() for t in tn.tensors)
226+
all_tensors = TensorStorage.get_backing_stores(all_tensors)
225227
for t in sorted(all_tensors):
226228
print(f"{t.__repr__()},")
227229

@@ -230,8 +232,6 @@ def fail(index):
230232
looptree = tilings2looptree(
231233
r[MAPPING],
232234
r[STATS],
233-
r[TENSORS],
234-
r[IN_PROGRESS_STATS],
235235
skip_backing_tensors_in_right_branch=live_tensors,
236236
still_live_tensors=live_tensors,
237237
)
@@ -248,8 +248,6 @@ def fail(index):
248248
looptree = tilings2looptree(
249249
r[MAPPING],
250250
r[STATS],
251-
r[TENSORS],
252-
r[IN_PROGRESS_STATS],
253251
skip_backing_tensors_in_right_branch=live_tensors,
254252
still_live_tensors=live_tensors,
255253
)
@@ -337,6 +335,8 @@ def merge_cross(
337335
df = makepareto(df)
338336

339337
for k in DICT_COLUMNS:
338+
if k not in left.columns:
339+
continue
340340
c0, c1 = k + MERGE_SUFFIXES[0], k + MERGE_SUFFIXES[1]
341341
df[k] = (
342342
df.apply(lambda row: {**row[c0], **row[c1]}, axis=1) if len(df) > 0 else []
@@ -345,10 +345,9 @@ def merge_cross(
345345

346346
cols = [c for c in df.columns if c not in DICT_COLUMNS]
347347

348-
if True:
348+
if IN_PROGRESS_STATS in df.columns:
349349
first_row = df.iloc[0]
350-
tensors = first_row[TENSORS]
351-
einsums = list(tensors.keys())
350+
einsums = list(first_row[IN_PROGRESS_STATS].keys())
352351
last = einsums[-1]
353352
for i, r in df[cols].iterrows():
354353
df.at[i, IN_PROGRESS_STATS][last] = r.to_dict()
@@ -422,11 +421,12 @@ def alloc(self, resource_name: str, size: int, above_loop_index: int):
422421
def add_tensor(self, tensor):
423422
if len(self.data) == 0:
424423
return
425-
last_einsum = list(self.data.iloc[0][TENSORS].keys())[-1]
426-
if tensor in self.data[TENSORS].iloc[0][last_einsum]:
427-
return
428-
for t in self.data[TENSORS]:
429-
t[last_einsum].append(tensor)
424+
if TENSORS in self.data:
425+
last_einsum = list(self.data.iloc[0][TENSORS].keys())[-1]
426+
if tensor in self.data[TENSORS].iloc[0][last_einsum]:
427+
return
428+
for t in self.data[TENSORS]:
429+
t[last_einsum].append(tensor)
430430

431431
def copy(self) -> "Pareto":
432432
return Pareto(self.data.copy())

pytimeloop/fastfusion/plot/interactive.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@
88
import pandas as pd
99
import plotly.express as px
1010

11-
from pytimeloop.fastfusion.sim import Loop, Tiling
11+
from pytimeloop.fastfusion.sim import Loop, TensorStorage, Tiling
1212
from pytimeloop.fastfusion.util import expfmt
1313
from pytimeloop.fastfusion.plot.looptree import tilings2svg
1414
from pytimeloop.fastfusion.pareto import MAPPING, STATS, TENSORS, IN_PROGRESS_STATS, MAPPING_HASH
1515

1616
import pandas as pd
1717

1818
def mapping2svg(mapping: pd.Series):
19-
return SVG(tilings2svg(mapping[MAPPING], mapping[STATS], mapping[TENSORS], mapping[IN_PROGRESS_STATS]))
19+
return SVG(tilings2svg(mapping[MAPPING], mapping.get(STATS, None)))
2020

2121
def mapping2svg(mapping: pd.Series):
22-
return SVG(tilings2svg(mapping[MAPPING], mapping[STATS], mapping[TENSORS], mapping[IN_PROGRESS_STATS]))
22+
return SVG(tilings2svg(mapping[MAPPING], mapping.get(STATS, None)))
2323

2424
def diplay_mappings_on_fig(fig: plotly.graph_objs.FigureWidget, data: dict[str, pd.DataFrame]):
2525
# fig = go.FigureWidget(fig)
@@ -33,13 +33,10 @@ def display_mapping(trace, points, selector):
3333
d = data[trace.name]
3434
index = points.point_inds[0]
3535
display(mapping2svg(d.iloc[index]))
36-
all_tensors = set(
37-
t for tn in d.iloc[index][TENSORS].values() for t in tn
38-
)
39-
for t in sorted(all_tensors):
36+
backing_tensors = set(t for tn in d.iloc[index][MAPPING].values() for t in tn.tensors)
37+
backing_tensors = TensorStorage.get_backing_stores(backing_tensors)
38+
for t in sorted(backing_tensors):
4039
print(f"{t.__repr__()},")
41-
for k, v in d.iloc[index][MAPPING_HASH].items():
42-
print(f"{k}: {v},")
4340
for t in sorted(list(d.iloc[index][MAPPING].values())[-1].tensors):
4441
print(f"{t.__repr__()},")
4542
for v in d.iloc[index][MAPPING].values():
@@ -53,13 +50,10 @@ def display_mapping_2(trace, points, selector):
5350
d = data[trace.name]
5451
index = points.point_inds[0]
5552
display(mapping2svg(d.iloc[index]))
56-
all_tensors = set(
57-
t for tn in d.iloc[index][TENSORS].values() for t in tn
58-
)
59-
for t in sorted(all_tensors):
53+
backing_tensors = set(t for tn in d.iloc[index][MAPPING].values() for t in tn.tensors)
54+
backing_tensors = TensorStorage.get_backing_stores(backing_tensors)
55+
for t in sorted(backing_tensors):
6056
print(f"{t.__repr__()},")
61-
for k, v in d.iloc[index][MAPPING_HASH].items():
62-
print(f"{k}: {v},")
6357
for t in sorted(list(d.iloc[index][MAPPING].values())[-1].tensors):
6458
print(f"{t.__repr__()},")
6559
for v in d.iloc[index][MAPPING].values():

pytimeloop/fastfusion/plot/looptree.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,23 @@ def get_all_storages(self, _entry: bool = True) -> list[TensorStorage]:
8080
def get_backing_stores(self):
8181
return TensorStorage.get_backing_stores(self.get_all_storages())
8282

83-
def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], backing_tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any], skip_backing_tensors_in_right_branch: Iterable[str] = (), still_live_tensors: set[str] = (), shared_loop_index: int = -1):
83+
def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], skip_backing_tensors_in_right_branch: Iterable[str] = (), still_live_tensors: set[str] = (), shared_loop_index: int = -1):
8484
prev_tilings = []
8585
root = Node()
8686
einsum_ids = list(mappings.keys())
87+
8788
assert set(einsum_ids) == set(stats.keys())
88-
assert set(einsum_ids) == set(backing_tensors.keys())
89+
90+
8991

9092
# If a tensor appears in non-back-to-back Einsums, then we need to store it for
9193
# all Einsums in between
92-
tensors_lifetimes = {e: [] for e in backing_tensors}
93-
all_tensors = set().union(*[set(t) for t in backing_tensors.values()])
94+
tensors_lifetimes = {e: [] for e in einsum_ids}
95+
all_tensors = set().union(*[set(t.tensors) for t in mappings.values()])
9496
backers = TensorStorage.get_backing_stores(all_tensors)
9597
for t in all_tensors:
96-
first_appearance = min(i for i, ts in enumerate(backing_tensors.values()) if t in ts)
97-
last_appearance = max(i for i, ts in enumerate(backing_tensors.values()) if t in ts)
98+
first_appearance = min(i for i, ts in enumerate(mappings.values()) if t in ts.tensors)
99+
last_appearance = max(i for i, ts in enumerate(mappings.values()) if t in ts.tensors)
98100
if t.tensor_id in still_live_tensors:
99101
last_appearance = len(einsum_ids) - 1
100102
for i, l in enumerate(tensors_lifetimes.values()):
@@ -151,13 +153,13 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], backing
151153

152154
return root
153155

154-
def tilings2svg(mappings: dict[str, Tiling], stats: dict[str, Any], tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any]):
155-
root = tilings2looptree(mappings, stats, tensors, partial_stats)
156+
def tilings2svg(mappings: dict[str, Tiling], stats: dict[str, Any], ):
157+
root = tilings2looptree(mappings, stats)
156158
graph = pydot.Dot(graph_type="digraph", ranksep="0.2", nodesep="0.2")
157159
root.to_pydot(graph)
158160
return graph.create_svg()
159161

160-
def tilings2yaml(mappings: dict[str, Tiling], stats: dict[str, Any], tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any]):
161-
root = tilings2looptree(mappings, stats, tensors, partial_stats)
162+
def tilings2yaml(mappings: dict[str, Tiling], stats: dict[str, Any]):
163+
root = tilings2looptree(mappings, stats)
162164
return root.to_yaml()
163165

0 commit comments

Comments
 (0)