File tree Expand file tree Collapse file tree 4 files changed +11
-6
lines changed Expand file tree Collapse file tree 4 files changed +11
-6
lines changed Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def glb_storage(mapping, unfused_tensors):
63
63
explore_uneven = True ,
64
64
add_split_at_tensors = set (),
65
65
must_have_terminal_storage = True ,
66
- apply_lrp_after_loop_idx = last_fused_loop_idx + 1 )
66
+ apply_lrp_after_loop_idx = last_fused_loop_idx )
67
67
68
68
def tile_shape_optimization (mapping ):
69
69
for partial_mapping in infer_smallest_tile_shape (mapping ,
Original file line number Diff line number Diff line change @@ -78,7 +78,7 @@ def make_storage(
78
78
any_irrelevant_loop = any_irrelevant_loop or not is_relevant
79
79
80
80
auto_lower = i > apply_lrp_after_loop_idx
81
- auto_lower = False
81
+ # auto_lower = False
82
82
83
83
if not auto_lower or (last_is_relevant and not is_relevant ):
84
84
tensor_choices .append (i )
@@ -89,7 +89,7 @@ def make_storage(
89
89
break
90
90
91
91
auto_lower = len (mapping ) > apply_lrp_after_loop_idx
92
- auto_lower = False
92
+ # auto_lower = False
93
93
94
94
# Lowest possible storage node
95
95
if (not auto_lower or last_is_relevant ) and not (tensor_must_be_fully_reused and any_irrelevant_loop ):
Original file line number Diff line number Diff line change @@ -154,7 +154,7 @@ def fuse_sims(
154
154
print (s )
155
155
elif DO_PRINT :
156
156
for a in left [k ]:
157
- print (f"\t No match for { k } |||| { a . tiling_str () } " )
157
+ print (f"\t No match for { a . tiling } " )
158
158
159
159
if not combined :
160
160
print (f'No valid combinations found.' )
@@ -173,7 +173,7 @@ def fuse_sims(
173
173
for k in right :
174
174
if k not in left :
175
175
for b in right [k ]:
176
- print (f"\t REVERSE: No match for { k } |||| { b . tiling_str () } " )
176
+ print (f"\t REVERSE: No match for { b . tiling } " )
177
177
178
178
left = combined
179
179
print (f"Number of buckets: { len (left )} " )
Original file line number Diff line number Diff line change @@ -137,6 +137,11 @@ class Tiling:
137
137
loops : tuple [Loop , ...]
138
138
tensors : fzs [TensorStorage ]
139
139
tags : fzs [Any ] = fzs ()
140
+
141
+ def __post_init__ (self ):
142
+ assert isinstance (self .tensors , frozenset )
143
+ assert isinstance (self .loops , tuple )
144
+ assert isinstance (self .tags , frozenset )
140
145
141
146
@cached_property
142
147
def tensor_names (self ) -> set [str ]:
@@ -172,7 +177,7 @@ def clear_dead_tensors(
172
177
if keep_loops
173
178
else self .loops [: self .shared_loop_index (live_tensors ) + 1 ]
174
179
)
175
- tensors = tuple (t for t in self .tensors if t .tensor_id in live_tensors )
180
+ tensors = frozenset (t for t in self .tensors if t .tensor_id in live_tensors )
176
181
return Tiling (loops , tensors , self .tags )
177
182
178
183
def __lt__ (self , other ):
You can’t perform that action at this time.
0 commit comments