Skip to content

Commit f5aa908

Browse files
author
Michael Gilbert
committed
[looptree] Refactor latency model
1 parent 6b8b367 commit f5aa908

File tree

17 files changed

+429
-103
lines changed

17 files changed

+429
-103
lines changed

pytimeloop/fastfusion/fastmodel/fastmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sympy
66

77
from pytimeloop.looptree.equivalent_ranks import EquivalentGroups
8-
from pytimeloop.looptree.des import IslReuseAnalysisOutput
8+
from pytimeloop.looptree.reuse.isl.des import IslReuseAnalysisOutput
99

1010

1111
def compile_mapping(mapping,

pytimeloop/fastfusion/mapper/per_einsum_subspaces/subspaces/tile_shape/tile_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .shape_subspace import ShapeSubspace
66

7-
from pytimeloop.looptree.des import IslReuseAnalysisOutput
7+
from pytimeloop.looptree.reuse.isl.des import IslReuseAnalysisOutput
88

99

1010
def explore_tile_shape(

pytimeloop/looptree/latency/latency.py

Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
from collections import defaultdict
2-
31
from pytimeloop.isl.singular import get_value_from_singular_qpolynomial
4-
from pytimeloop.looptree.accesses import (
5-
reads_and_writes_from_fill_by_parent,
6-
reads_and_writes_from_fill_by_peer
7-
)
82
from pytimeloop.looptree.latency.processors import LATENCY_PROCESSORS
9-
from pytimeloop.looptree.des import IslReuseAnalysisOutput
3+
from pytimeloop.looptree.reuse.isl.des import IslReuseAnalysisOutput
4+
from pytimeloop.looptree.latency.memory.isl import memory_latency
105

116
from bindings.looptree import SpatialTag
127

@@ -34,94 +29,6 @@ def compute_latency(mapping, temporal_steps, workload):
3429
).to_python()
3530

3631

37-
def memory_latency(looptree_results: IslReuseAnalysisOutput,
38-
arch,
39-
mapping,
40-
workload,
41-
bindings):
42-
reads, writes = reads_and_writes_from_fill_by_parent(
43-
looptree_results.fills,
44-
looptree_results.reads_to_parent,
45-
mapping,
46-
workload,
47-
per_unit=True
48-
)
49-
50-
peer_reads, peer_writes = reads_and_writes_from_fill_by_peer(
51-
looptree_results.reads_to_peer,
52-
mapping,
53-
workload,
54-
per_unit=True
55-
)
56-
57-
component_to_read_writes = defaultdict(lambda: [None, None])
58-
for level, component in bindings.items():
59-
read_count = sum(reads[key] for key in reads if key[0] == level)
60-
read_count += sum(peer_reads[key]
61-
for key in peer_reads if key[0] == level)
62-
write_count = sum(writes[key] for key in writes if key[0] == level)
63-
write_count += sum(peer_writes[key]
64-
for key in peer_writes if key[0] == level)
65-
if component not in component_to_read_writes:
66-
component_to_read_writes[component][0] = read_count
67-
component_to_read_writes[component][1] = write_count
68-
else:
69-
component_to_read_writes[component][0] += read_count
70-
component_to_read_writes[component][1] += write_count
71-
72-
component_latency = {}
73-
bandwidths = get_bandwidth(arch)
74-
for component, (reads, writes) in component_to_read_writes.items():
75-
read_bw, write_bw, shared_bw = bandwidths[component]
76-
77-
# For numerical stability
78-
read_bw += 1e-8
79-
write_bw += 1e-8
80-
shared_bw += 1e-8
81-
82-
# All shared bw for writing
83-
write_latency = writes / (write_bw + shared_bw)
84-
read_latency = reads / read_bw
85-
if write_latency >= read_latency:
86-
component_latency[component] = write_latency
87-
continue
88-
# All shared bw for reading
89-
write_latency = writes / write_bw
90-
read_latency = reads / (read_bw + shared_bw)
91-
if read_latency >= write_latency:
92-
component_latency[component] = read_latency
93-
continue
94-
# Shared bw shared for reading and writing
95-
component_latency[component] = (
96-
(reads + writes)
97-
/
98-
(read_bw + write_bw + shared_bw)
99-
)
100-
return component_latency
101-
102-
103-
def get_bandwidth(arch):
104-
component_bandwidths = {}
105-
for node in arch['nodes']:
106-
attributes = node.attributes
107-
n_rd_ports = attributes.get('n_rd_ports', 0)
108-
n_wr_ports = attributes.get('n_wr_ports', 0)
109-
n_rdwr_ports = attributes.get('n_rdwr_ports', 0)
110-
if n_rd_ports + n_wr_ports + n_rdwr_ports < 1:
111-
n_rdwr_ports = 1
112-
113-
width = attributes['width']
114-
datawidth = attributes['datawidth']
115-
width_in_words = width/datawidth
116-
117-
component_bandwidths[node['name']] = [
118-
n_rd_ports*width_in_words,
119-
n_wr_ports*width_in_words,
120-
n_rdwr_ports*width_in_words
121-
]
122-
return component_bandwidths
123-
124-
12532
def _compute_latency(mapping, top_idx: int, temporal_steps, workload):
12633
einsum_name_to_id = workload.einsum_name_to_id()
12734

pytimeloop/looptree/latency/memory/__init__.py

Whitespace-only changes.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import overload
2+
3+
from pytimeloop.looptree.reuse.isl.des import IslReuseAnalysisOutput
4+
from pytimeloop.looptree.latency.memory import isl
5+
6+
7+
# Maps a reuse-analysis result type to the function that turns that result
# into per-component memory latency.
# NOTE(review): `SummarizedAnalysisOutput` and `summarized` are not imported in
# the visible part of this module -- confirm their import paths and add them.
ANALYSIS_TYPE_TO_ANALYZER = {
    IslReuseAnalysisOutput: isl.memory_latency,
    SummarizedAnalysisOutput: summarized.memory_latency,
}


@overload
def calculate_memory_latency(reuse_analysis: IslReuseAnalysisOutput,
                             architecture,
                             mapping,
                             workload,
                             bindings):
    ...


@overload
def calculate_memory_latency(reuse_analysis: SummarizedAnalysisOutput,
                             architecture,
                             mapping,
                             workload,
                             bindings):
    ...


def calculate_memory_latency(reuse_analysis,
                             architecture,
                             mapping,
                             workload,
                             bindings):
    """Compute memory latency with the analyzer matching `reuse_analysis`.

    The analyzer is looked up in ANALYSIS_TYPE_TO_ANALYZER by the type of
    `reuse_analysis` (subclasses of a registered type are accepted) and is
    called with all five arguments unchanged.

    Returns:
        Whatever the selected analyzer returns.

    Raises:
        TypeError: if no analyzer is registered for the type of
            `reuse_analysis`.
    """
    for analysis_type, analyzer in ANALYSIS_TYPE_TO_ANALYZER.items():
        if isinstance(reuse_analysis, analysis_type):
            return analyzer(reuse_analysis,
                            architecture,
                            mapping,
                            workload,
                            bindings)
    raise TypeError(
        "unsupported reuse analysis type: "
        f"{type(reuse_analysis).__name__}"
    )
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from collections import defaultdict
2+
3+
from pytimeloop.looptree.accesses import (
4+
reads_and_writes_from_fill_by_parent,
5+
reads_and_writes_from_fill_by_peer
6+
)
7+
from pytimeloop.looptree.reuse.isl.des import IslReuseAnalysisOutput
8+
9+
10+
def _bandwidth_limited_latency(reads, writes, read_bw, write_bw, shared_bw):
    """Return the latency of one component from its traffic and bandwidths.

    `shared_bw` can serve either reads or writes. Three allocations are tried
    in order: all of it to writes, all of it to reads, and finally an even
    split in which reads and writes finish together.
    """
    # For numerical stability (avoids dividing by a zero bandwidth)
    read_bw += 1e-8
    write_bw += 1e-8
    shared_bw += 1e-8

    # All shared bw for writing: if writes still dominate, they set latency.
    write_latency = writes / (write_bw + shared_bw)
    read_latency = reads / read_bw
    if write_latency >= read_latency:
        return write_latency

    # All shared bw for reading: if reads still dominate, they set latency.
    write_latency = writes / write_bw
    read_latency = reads / (read_bw + shared_bw)
    if read_latency >= write_latency:
        return read_latency

    # Shared bw shared for reading and writing
    return (reads + writes) / (read_bw + write_bw + shared_bw)


def memory_latency(looptree_results: IslReuseAnalysisOutput,
                   arch,
                   mapping,
                   workload,
                   bindings):
    """Compute the memory latency of every bound component.

    Per-unit reads and writes are derived from the parent fills and the peer
    fills of `looptree_results`, summed per component through `bindings`
    (level -> component), and divided by the bandwidths that
    `get_bandwidth(arch)` reports for that component.

    Returns:
        dict mapping each component in `bindings` to its latency.
        NOTE(review): presumably in cycles -- confirm the bandwidth unit.
    """
    reads, writes = reads_and_writes_from_fill_by_parent(
        looptree_results.fills,
        looptree_results.reads_to_parent,
        mapping,
        workload,
        per_unit=True
    )

    peer_reads, peer_writes = reads_and_writes_from_fill_by_peer(
        looptree_results.reads_to_peer,
        mapping,
        workload,
        per_unit=True
    )

    def level_total(table, level):
        # The first element of each key is matched against the level.
        return sum(table[key] for key in table if key[0] == level)

    # Several levels may bind to the same component; accumulate from zero
    # instead of special-casing the first occurrence.
    component_to_read_writes = defaultdict(lambda: [0, 0])
    for level, component in bindings.items():
        read_count = level_total(reads, level)
        read_count += level_total(peer_reads, level)
        write_count = level_total(writes, level)
        write_count += level_total(peer_writes, level)

        totals = component_to_read_writes[component]
        totals[0] += read_count
        totals[1] += write_count

    bandwidths = get_bandwidth(arch)
    component_latency = {}
    for component, (n_reads, n_writes) in component_to_read_writes.items():
        read_bw, write_bw, shared_bw = bandwidths[component]
        component_latency[component] = _bandwidth_limited_latency(
            n_reads, n_writes, read_bw, write_bw, shared_bw
        )
    return component_latency
74+
75+
76+
def get_bandwidth(arch):
    """Return the port bandwidths of every node in `arch['nodes']`.

    Each node name maps to a three-element list
    `[read_only, write_only, read_write]`, where every entry is the matching
    port count multiplied by `width / datawidth`. Missing port counts default
    to 0; a node whose port counts sum to less than 1 is given a single
    read/write port.
    """
    port_keys = ('n_rd_ports', 'n_wr_ports', 'n_rdwr_ports')
    bandwidth_by_component = {}
    for node in arch['nodes']:
        attrs = node.attributes
        rd_ports, wr_ports, rdwr_ports = (
            attrs.get(key, 0) for key in port_keys
        )
        if rd_ports + wr_ports + rdwr_ports < 1:
            rdwr_ports = 1

        words_per_access = attrs['width'] / attrs['datawidth']

        bandwidth_by_component[node['name']] = [
            ports * words_per_access
            for ports in (rd_ports, wr_ports, rdwr_ports)
        ]
    return bandwidth_by_component
96+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def memory_latency(looptree_results, arch, mapping, workload, bindings):
    """Placeholder for this analysis variant; always raises.

    The signature mirrors the ISL-based `memory_latency`.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
File renamed without changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import sympy
2+
3+
4+
def lambdify(d, tile_shapes):
    """Turn the symbolic values of `d` into callables of `tile_shapes`.

    The first value decides how every entry is treated. If it is a tuple,
    each value `v` becomes `(v[0], sympy.lambdify(tile_shapes, v[1]))`;
    otherwise each value becomes `sympy.lambdify(tile_shapes, v)`. Keys are
    kept unchanged.

    An empty `d` yields an empty dict (previously this raised StopIteration
    while peeking at the first value).
    """
    if not d:
        return {}

    first_value = next(iter(d.values()))
    if isinstance(first_value, tuple):
        return {
            k: (v[0], sympy.lambdify(tile_shapes, v[1]))
            for k, v in d.items()
        }
    return {
        k: sympy.lambdify(tile_shapes, v)
        for k, v in d.items()
    }

0 commit comments

Comments
 (0)