Skip to content

Commit e39f20a

Browse files
Tile matching bugfix
1 parent 8f3ea96 commit e39f20a

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

pytimeloop/fastfusion/mapper/per_einsum_subspaces/snowcat.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def glb_storage(mapping, unfused_tensors):
6363
explore_uneven=True,
6464
add_split_at_tensors=set(),
6565
must_have_terminal_storage=True,
66-
apply_lrp_after_loop_idx=last_fused_loop_idx+1)
66+
apply_lrp_after_loop_idx=last_fused_loop_idx)
6767

6868
def tile_shape_optimization(mapping):
6969
for partial_mapping in infer_smallest_tile_shape(mapping,

pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def make_storage(
7878
any_irrelevant_loop = any_irrelevant_loop or not is_relevant
7979

8080
auto_lower = i > apply_lrp_after_loop_idx
81-
auto_lower = False
81+
# auto_lower = False
8282

8383
if not auto_lower or (last_is_relevant and not is_relevant):
8484
tensor_choices.append(i)
@@ -89,7 +89,7 @@ def make_storage(
8989
break
9090

9191
auto_lower = len(mapping) > apply_lrp_after_loop_idx
92-
auto_lower = False
92+
# auto_lower = False
9393

9494
# Lowest possible storage node
9595
if (not auto_lower or last_is_relevant) and not (tensor_must_be_fully_reused and any_irrelevant_loop):

pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def fuse_sims(
154154
print(s)
155155
elif DO_PRINT:
156156
for a in left[k]:
157-
print(f"\tNo match for {k} |||| {a.tiling_str()}")
157+
print(f"\tNo match for {a.tiling}")
158158

159159
if not combined:
160160
print(f'No valid combinations found.')
@@ -173,7 +173,7 @@ def fuse_sims(
173173
for k in right:
174174
if k not in left:
175175
for b in right[k]:
176-
print(f"\tREVERSE: No match for {k} |||| {b.tiling_str()}")
176+
print(f"\tREVERSE: No match for {b.tiling}")
177177

178178
left = combined
179179
print(f"Number of buckets: {len(left)}")

pytimeloop/fastfusion/sim.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ class Tiling:
137137
loops: tuple[Loop, ...]
138138
tensors: fzs[TensorStorage]
139139
tags: fzs[Any] = fzs()
140+
141+
def __post_init__(self):
142+
assert isinstance(self.tensors, frozenset)
143+
assert isinstance(self.loops, tuple)
144+
assert isinstance(self.tags, frozenset)
140 145

141 146
@cached_property
142 147
def tensor_names(self) -> set[str]:
@@ -172,7 +177,7 @@ def clear_dead_tensors(
172177
if keep_loops
173178
else self.loops[: self.shared_loop_index(live_tensors) + 1]
174179
)
175-
tensors = tuple(t for t in self.tensors if t.tensor_id in live_tensors)
180+
tensors = frozenset(t for t in self.tensors if t.tensor_id in live_tensors)
176 181
return Tiling(loops, tensors, self.tags)
177 182

178 183
def __lt__(self, other):

0 commit comments

Comments
 (0)