Skip to content

Commit 8f3ea96

Browse files
Bugfixes
1 parent 909e13b commit 8f3ea96

File tree

10 files changed

+56
-83
lines changed

10 files changed

+56
-83
lines changed

pytimeloop/fastfusion/filter_mappings.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from pytimeloop.fastfusion.sim import Tag, TensorStorage, Tiling
2+
from pytimeloop.fastfusion.sim import TensorStorage, Tiling
33

44
def get_ffmt_tag_mha(
55
einsum_id: str,
@@ -19,6 +19,8 @@ def get_ffmt_tag_mha(
1919
"AV": ["QK_QK_to_AV", "AV_AV_to_Z"], # NOTE: AV IS MISSING
2020
"Z": ["AV_AV_to_Z", "Z_Z_to_n"],
2121
}
22+
if einsum_id not in einsum_id_to_input_output:
23+
return ("FFMT_VALID",)
2224
a, b = einsum_id_to_input_output[einsum_id]
2325

2426
tags = []
@@ -103,7 +105,7 @@ def get_tileflow_tag_mha(
103105
):
104106
# Valid iff it's an even mapping
105107
storage2level = defaultdict(set)
106-
for ts in tiling.tensors:
108+
for ts in backing_storages:
107109
storage2level[ts.backer_id].add(ts.above_loop_index)
108110
if all(len(s) == 1 for s in storage2level.values()):
109111
return ("TILEFLOW_VALID",)

pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/storage.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -69,27 +69,30 @@ def make_storage(
6969
else:
7070
min_i = last_storage_idx + 1
7171

72-
no_irrelevant_loop = True
72+
any_irrelevant_loop = False
7373
for i in range(min_i, len(mapping)):
7474
node = mapping[i]
7575
if node["type"] == "temporal":
7676
is_relevant = node["rank"] in relevant_ranks
7777

78-
no_irrelevant_loop = no_irrelevant_loop and not is_relevant
78+
any_irrelevant_loop = any_irrelevant_loop or not is_relevant
7979

80-
auto_lower = has_storage and i > apply_lrp_after_loop_idx
80+
auto_lower = i > apply_lrp_after_loop_idx
81+
auto_lower = False
8182

8283
if not auto_lower or (last_is_relevant and not is_relevant):
8384
tensor_choices.append(i)
8485

8586
last_is_relevant = is_relevant
8687

87-
if not is_relevant and tensor_must_be_fully_reused:
88+
if tensor_must_be_fully_reused and any_irrelevant_loop:
8889
break
90+
91+
auto_lower = len(mapping) > apply_lrp_after_loop_idx
92+
auto_lower = False
8993

9094
# Lowest possible storage node
91-
if last_is_relevant and (not tensor_must_be_fully_reused
92-
or no_irrelevant_loop):
95+
if (not auto_lower or last_is_relevant) and not (tensor_must_be_fully_reused and any_irrelevant_loop):
9396
tensor_choices.append(len(mapping))
9497

9598
if tensor_id in can_retain_tensors:

pytimeloop/fastfusion/mapper/process_results.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -142,21 +142,12 @@ def record_loop(node):
142142
tiling=tiling_full,
143143
rank_name_to_shared_name=rank_name_to_shared_name,
144144
)
145-
# print(tiling_full)
146145

147146
tiling_compatibility = Tiling(
148147
loops=tuple(full_tiling[:n_fused_loops]),
149148
tensors=frozenset(backing_storages),
150149
tags=fzs().union(*([set()] + [set(t(**tagger_args)) for t in tag_with]))
151150
)
152-
153-
if "FFMT_VALID" in tiling_compatibility.tags:
154-
print(tiling_compatibility)
155-
# assert max(t.above_loop_index for t in backing_storages) == len(tiling_compatibility.loops), (
156-
# f"\n\ttiling_compatibility: {tiling_compatibility} "
157-
# f"\n\tbacking_storages: {backing_storages} "
158-
# f"\n\ttiling_full: {tiling_full}"
159-
# )
160151

161152
results = {}
162153

pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,9 @@ def fuse_sims(
156156
for a in left[k]:
157157
print(f"\tNo match for {k} |||| {a.tiling_str()}")
158158

159+
if not combined:
160+
print(f'No valid combinations found.')
161+
159162
print_time("Bucket merging")
160163

161164
f = parallel if DELAY_MERGE else lambda x: x

pytimeloop/fastfusion/pareto.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -198,17 +198,29 @@ def _free_to_loop_index(
198198
def paretofy_by(data: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
199199
return data[paretoset(data[columns])].reset_index(drop=True)
200200

201+
def draw_looptree(row: pd.DataFrame, live_tensors: set[int]):
202+
from pytimeloop.fastfusion.plot.looptree import tilings2looptree
203+
204+
looptree = tilings2looptree(
205+
row[MAPPING],
206+
row[STATS],
207+
row[TENSORS],
208+
row[IN_PROGRESS_STATS],
209+
skip_backing_tensors_in_right_branch=live_tensors,
210+
still_live_tensors=live_tensors,
211+
)
212+
import pydot
201213

214+
graph = pydot.Dot(graph_type="digraph", ranksep="0.2", nodesep="0.2")
215+
looptree.to_pydot(graph)
216+
with open(f"test.png", "wb") as f:
217+
f.write(graph.create_png())
218+
202219
def check_correctness(data: pd.DataFrame, live_tensors: set[int]):
203220
from pytimeloop.fastfusion.plot.looptree import tilings2looptree
204221

205-
def fail(looptree):
206-
import pydot
207-
208-
graph = pydot.Dot(graph_type="digraph", ranksep="0.2", nodesep="0.2")
209-
looptree.to_pydot(graph)
210-
with open(f"test.png", "wb") as f:
211-
f.write(graph.create_png())
222+
def fail(index):
223+
draw_looptree(data.iloc[index], live_tensors)
212224
all_tensors = set(t for tn in r[TENSORS].values() for t in tn)
213225
for t in sorted(all_tensors):
214226
print(f"{t.__repr__()},")

pytimeloop/fastfusion/plot/interactive.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ def plotly_show(
9090
fig.add_scatter(x=v[x], y=v[y], name=k, line={"shape": 'hv'})
9191
else:
9292
data.sort_values(by=[x, y], inplace=True)
93-
fig.add_scatter(x=data[x], y=data[y], line={"shape": 'hv'})
93+
fig.add_scatter(x=data[x], y=data[y], name="", line={"shape": 'hv'})
9494
data = {"" : data}
9595
if title is not None:
9696
fig.update_layout(title=title)

pytimeloop/fastfusion/plot/looptree.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ def get_backing_stores(self):
8181
return TensorStorage.get_backing_stores(self.get_all_storages())
8282

8383
def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], backing_tensors: dict[str, list[TensorStorage]], partial_stats: dict[str, Any], skip_backing_tensors_in_right_branch: Iterable[str] = (), still_live_tensors: set[str] = (), shared_loop_index: int = -1):
84-
prev_tiling = None
84+
prev_tilings = []
8585
root = Node()
8686
einsum_ids = list(mappings.keys())
8787
assert set(einsum_ids) == set(stats.keys())
@@ -105,9 +105,10 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], backing
105105
for i, einsum_id in enumerate(einsum_ids):
106106
skip_backing_tensors = () #if i < len(einsum_ids) - 1 else skip_backing_tensors_in_right_branch
107107
tiling = mappings[einsum_id]
108-
index = (
109-
prev_tiling.shared_loop_index(tiling.tensor_names) if prev_tiling else -1
110-
)
108+
if not prev_tilings:
109+
index = -1
110+
else:
111+
index = max(p.shared_loop_index(tiling.tensor_names) for p in prev_tilings)
111112
n = root.access_level(index)
112113
loops = tiling.loops[index:] if index != -1 else tiling.loops
113114
for l in loops:
@@ -128,13 +129,13 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], backing
128129
for tensor in storages:
129130
n = root.access_level(tensor.above_loop_index)
130131
# TODO if tensor not in n.this_level or tensor not in backers:
131-
if tensor not in n.this_level:
132+
if tensor not in n.this_level or tensor not in backers:
132133
n.this_level.append(tensor)
133134

134135
root.add_stats(stats[einsum_id])
135136
# for k, v in partial_stats[einsum_id].items():
136137
# last_level.append(f"_PARTIAL {k}: {expfmt(v)}")
137-
prev_tiling = tiling
138+
prev_tilings.append(tiling)
138139

139140
n = root
140141
skip_backing_tensors_in_right_branch= set(skip_backing_tensors_in_right_branch)

pytimeloop/fastfusion/sim.py

Lines changed: 6 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -132,38 +132,11 @@ def __eq__(self, value):
132132
return False
133133
return True
134134

135-
class Tag:
136-
def __eq__(self, value: "Tag") -> bool:
137-
# Match base Tag
138-
if type(value) == Tag:
139-
return True
140-
# If other is a non-base Tag or anything else, let
141-
# the other object handle the comparison
142-
return value.__eq__(self)
143-
144-
def __hash__(self) -> int:
145-
return hash("Tag")
146-
147-
def merge_next(self, n: "Tag") -> "Tag":
148-
return n
149-
150-
@staticmethod
151-
def get_wildcard():
152-
return Wildcard()
153-
154-
class Wildcard(Tag):
155-
def __eq__(self, value: "Tag") -> bool:
156-
return True
157-
158-
def __hash__(self) -> int:
159-
return hash("Wildcard")
160-
161-
162135
@dataclass(frozen=True)
163136
class Tiling:
164137
loops: tuple[Loop, ...]
165138
tensors: fzs[TensorStorage]
166-
tags: fzs[Tag] = fzs()
139+
tags: fzs[Any] = fzs()
167140

168141
@cached_property
169142
def tensor_names(self) -> set[str]:
@@ -209,9 +182,7 @@ def __str__(self):
209182
return self.__repr__()
210183

211184
def __repr__(self):
212-
if type(self.tags) == Tag:
213-
return f"Tiling({self.loops}, {self.tensors})"
214-
return f"Tiling({self.loops}, {self.tensors}, {self.tags})"
185+
return f"Tiling({self.loops.__repr__()}, {self.tensors.__repr__()}, {self.tags.__repr__()})"
215186

216187
def merge_next(self, n: "Tiling", live_tensors: set[str]) -> "Tiling":
217188
tensors = fzs(t for t in (n.tensors | self.tensors) if t.tensor_id in live_tensors)
@@ -248,7 +219,7 @@ def matches_permutation(self, permutation: list[str]) -> bool:
248219
def has_tensor(self, *tensors: TensorStorage) -> bool:
249220
return all(any(t == tensor for t in self.tensors) for tensor in tensors)
250221

251-
def set_tags(self, *new_tags: Tag) -> "Tiling":
222+
def set_tags(self, *new_tags: Any) -> "Tiling":
252223
return Tiling(self.loops, self.tensors, fzs(new_tags))
253224

254225
class SIM:
@@ -410,20 +381,12 @@ def group_right(
410381
) -> dict[tuple[Tiling, ...], list["SIM"]]:
411382
return SIM._group(sims, live_tensors)
412383

413-
def set_tags(self, *tags: Tag) -> "SIM":
384+
def set_tags(self, *tags: Any) -> "SIM":
414385
self.tiling = self.tiling.set_tags(*tags)
415386

416-
@attribute
417-
def tags(self) -> fzs[Tag]:
387+
@property
388+
def tags(self) -> fzs[Any]:
418389
return self.tiling.tags
419-
420-
@attribute
421-
def tensors(self) -> fzs[TensorStorage]:
422-
return self.tiling.tensors
423-
424-
@attribute
425-
def loops(self) -> tuple[Loop, ...]:
426-
return self.tiling.loops
427390

428391

429392
import unittest

pytimeloop/fastfusion/util.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,7 @@
99

1010
class fzs(frozenset):
1111
def __repr__(self):
12-
try:
13-
return f"fzs({', '.join(x.__repr__() for x in sorted(self))}"
14-
except:
15-
return f"fzs({', '.join(x.__repr() for x in self)})"
12+
return f"fzs({', '.join(sorted(x.__repr__() for x in self))})"
1613

1714
def __str__(self):
1815
return self.__repr__()

pytimeloop/looptree/mapping_utilities.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,10 @@ def get_last_storage_node(mapping, tensor):
6262

6363

6464
def get_last_fused_loop_idx(mapping, intermediate_tensors):
65-
last_i = 0
65+
intermedaites_remaining = set(intermediate_tensors)
6666
for i, node in enumerate(mapping):
67-
if node['type'] == 'storage' \
68-
and any(t in node['dspace'] for t in intermediate_tensors):
69-
last_i = i
70-
return last_i
67+
if node['type'] == 'storage':
68+
intermedaites_remaining -= set(node['dspace'])
69+
if not intermedaites_remaining:
70+
return i
71+
return float('inf')

0 commit comments

Comments
 (0)