Skip to content

Commit 6b8b367

Browse files
author
Michael Gilbert
committed
[looptree] Refactor: rename LooptreeOutput to IslReuseAnalysisOutput
1 parent 0e6b274 commit 6b8b367

File tree

8 files changed

+51
-34
lines changed

8 files changed

+51
-34
lines changed

pytimeloop/fastfusion/fastmodel/fastmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sympy
66

77
from pytimeloop.looptree.equivalent_ranks import EquivalentGroups
8-
from pytimeloop.looptree.des import LooptreeOutput
8+
from pytimeloop.looptree.des import IslReuseAnalysisOutput
99

1010

1111
def compile_mapping(mapping,
@@ -58,7 +58,7 @@ def compile_mapping(mapping,
5858

5959
tile_shapes = []
6060

61-
output = LooptreeOutput()
61+
output = IslReuseAnalysisOutput()
6262

6363
latency = 1
6464
potential_tensor_access_multiplier = defaultdict(lambda: 1)

pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/tile_shape/tile_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .shape_subspace import ShapeSubspace
66

7-
from pytimeloop.looptree.des import LooptreeOutput
7+
from pytimeloop.looptree.des import IslReuseAnalysisOutput
88

99

1010
def explore_tile_shape(
@@ -56,7 +56,7 @@ def explore_tile_shape(
5656
if only_count:
5757
continue
5858

59-
result = LooptreeOutput()
59+
result = IslReuseAnalysisOutput()
6060
result.ops = call_with_arg(compiled_result.ops, shape)
6161
result.temporal_steps = call_with_arg(compiled_result.temporal_steps, shape)
6262
result.fanout = call_with_arg(compiled_result.fanout, shape)

pytimeloop/looptree/des.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,28 @@
1+
from dataclasses import dataclass, field
2+
13
import islpy as isl
24

35
import bindings
46

57

6-
class LooptreeOutput:
7-
def __init__(self):
8-
self.ops = {}
9-
self.fills = {}
10-
self.occupancy = {}
11-
self.op_occupancy = {}
12-
self.reads_to_peer = {}
13-
self.reads_to_parent = {}
14-
self.temporal_steps = {}
15-
self.fanout = {}
16-
self.op_intensity = {}
17-
18-
def __repr__(self):
19-
return (
20-
f'LooptreeOutput(' +
21-
f'ops={self.ops}, ' +
22-
f'occupancy={self.occupancy}, ' +
23-
f'reads_to_parent={self.reads_to_parent})'
24-
)
8+
@dataclass
9+
class IslReuseAnalysisOutput:
10+
ops: dict = field(default_factory=dict)
11+
fills: dict = field(default_factory=dict)
12+
occupancy: dict = field(default_factory=dict)
13+
op_occupancy: dict = field(default_factory=dict)
14+
reads_to_peer: dict = field(default_factory=dict)
15+
reads_to_parent: dict = field(default_factory=dict)
16+
temporal_steps: dict = field(default_factory=dict)
17+
fanout: dict = field(default_factory=dict)
18+
op_intensity: dict = field(default_factory=dict)
2519

2620

2721
def deserialize_looptree_output(
2822
looptree_output: bindings.looptree.LooptreeResult,
2923
isl_ctx: isl.Context
30-
) -> LooptreeOutput:
31-
output = LooptreeOutput()
24+
) -> IslReuseAnalysisOutput:
25+
output = IslReuseAnalysisOutput()
3226

3327
output.ops = {
3428
k: (dims, isl.PwQPolynomial.read_from_str(isl_ctx, v))

pytimeloop/looptree/latency/latency.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
reads_and_writes_from_fill_by_peer
77
)
88
from pytimeloop.looptree.latency.processors import LATENCY_PROCESSORS
9-
from pytimeloop.looptree.des import LooptreeOutput
9+
from pytimeloop.looptree.des import IslReuseAnalysisOutput
1010

1111
from bindings.looptree import SpatialTag
1212

1313

14-
def get_latency(looptree_results: LooptreeOutput,
14+
def get_latency(looptree_results: IslReuseAnalysisOutput,
1515
mapping,
1616
workload,
1717
arch,
@@ -34,7 +34,7 @@ def compute_latency(mapping, temporal_steps, workload):
3434
).to_python()
3535

3636

37-
def memory_latency(looptree_results: LooptreeOutput,
37+
def memory_latency(looptree_results: IslReuseAnalysisOutput,
3838
arch,
3939
mapping,
4040
workload,

pytimeloop/looptree/visualization/occupancy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import matplotlib.pyplot as plt
22

3-
from pytimeloop.looptree.des import LooptreeOutput
3+
from pytimeloop.looptree.des import IslReuseAnalysisOutput
44

55

6-
def plot_occupancy_graph(output: LooptreeOutput, workload):
6+
def plot_occupancy_graph(output: IslReuseAnalysisOutput, workload):
77
einsum_rank_to_shape = {
88
einsum: {
99
rank: workload.get_rank_shape(rank)

tests/looptree/test_model.py renamed to tests/looptree/test_cpp_reuse_analysis.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
import islpy as isl
66

7-
from bindings.config import Config
87
from bindings.looptree import *
98
from pytimeloop.looptree.des import deserialize_looptree_output
109
from .make_model_app import make_model_app
11-
from tests.util import TEST_TMP_DIR, gather_yaml_configs
10+
from tests.util import TEST_TMP_DIR
1211

1312
class LooptreeModelAppTest(unittest.TestCase):
1413
def test_model_with_two_level_mm(self):
@@ -52,7 +51,7 @@ def seconds(ds):
5251
)
5352

5453

55-
class TestLooptreeOutputDeserializer(unittest.TestCase):
54+
class TestIslReuseAnalysisOutputDeserializer(unittest.TestCase):
5655
def test_deserializer_with_two_level_mm(self):
5756
self.check_deserializer(
5857
Path(__file__).parent.parent / 'test_configs',
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
mapping:
2+
type: fused
3+
nodes:
4+
- type: storage
5+
target: 0
6+
dspace: [Filter1, Fmap1, Fmap2]
7+
- type: storage
8+
target: 1
9+
dspace: [Filter1]
10+
- type: temporal
11+
rank: P1
12+
- type: storage
13+
target: 2
14+
dspace: [Fmap2]
15+
- type: temporal
16+
rank: C1
17+
- type: spatial
18+
rank: M1
19+
- type: storage
20+
target: 3
21+
dspace: [Fmap1]
22+
- type: compute
23+
einsum: Fc1
24+
target: 4

tests/test_configs/three_level.arch.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ architecture:
88
- !Component
99
name: MainMemory
1010
class: DRAM
11-
attributes: {width: 256, block_size: 32, word_bits: 8, datawidth: 8}
11+
attributes: {depth: 1000000000, width: 256, block_size: 32, word_bits: 8, datawidth: 8}
1212
required_actions: ['read', 'write']
1313
- !Component
1414
name: GlobalBuffer

0 commit comments

Comments
 (0)