Skip to content

Commit e893dfe

Browse files
author
Michael Gilbert
committed
[looptree] Skeleton Python-based workload
1 parent 0f742fe commit e893dfe

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

pytimeloop/looptree/workload.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from collections.abc import Iterable
2+
3+
4+
class Einsum:
5+
def __init__(self):
6+
self.rank_variables: list[str] = []
7+
8+
9+
class Tensor:
10+
def __init__(self):
11+
self.ranks: list[str] = []
12+
pass
13+
14+
15+
class LooptreeWorkload:
16+
def __init__(self):
17+
self.einsums: dict[str, Einsum] = {}
18+
self.tensors: dict[str, Tensor] = {}
19+
20+
self._tensors_read_by_einsum: dict[str, set[str]] = {}
21+
self._tensors_written_by_einsum: dict[str, set[str]] = {}
22+
23+
self._einsums_reading_tensor: dict[str, set[str]] = {}
24+
self._einsums_writing_tensor: dict[str, set[str]] = {}
25+
26+
def add_einsum(self, einsum_name: str):
27+
if einsum_name in self.einsums:
28+
raise KeyError(f'{einsum_name} already exists')
29+
self.einsums[einsum_name] = Einsum()
30+
self._tensors_written_by_einsum[einsum_name] = set()
31+
self._tensors_read_by_einsum[einsum_name] = set()
32+
33+
def set_einsum_rank_variables(self,
34+
einsum_name: str,
35+
rank_variables: Iterable[str]):
36+
try:
37+
self.einsums[einsum_name].rank_variables = list(rank_variables)
38+
except KeyError:
39+
raise KeyError(f'Einsum {einsum_name} not in workload')
40+
41+
def set_einsum_shape(self, einsum_name: str, shape):
42+
raise NotImplementedError()
43+
44+
def add_tensor(self, tensor_name: str):
45+
if tensor_name in self.einsums:
46+
raise KeyError(f'{tensor_name} already exists')
47+
self.tensors[tensor_name] = Tensor()
48+
self._einsums_reading_tensor[tensor_name] = set()
49+
self._einsums_writing_tensor[tensor_name] = set()
50+
51+
def set_tensor_ranks(self, tensor_name: str, ranks: Iterable[str]):
52+
try:
53+
self.tensors[tensor_name].ranks = list(ranks)
54+
except KeyError:
55+
raise KeyError(f'Tensor {tensor_name} not in workload')
56+
57+
def set_tensor_shape(self, tensor_name: str, shape):
58+
raise NotImplementedError()
59+
60+
def set_projection(self,
61+
einsum_name: str,
62+
tensor_name: str,
63+
projection,
64+
is_output: bool):
65+
if is_output:
66+
self._tensors_written_by_einsum[einsum_name].add(tensor_name)
67+
self._einsums_writing_tensor[tensor_name].add(einsum_name)
68+
else:
69+
self._tensors_read_by_einsum[einsum_name].add(tensor_name)
70+
self._einsums_reading_tensor[tensor_name].add(einsum_name)
71+
72+
raise NotImplementedError()
73+
74+
def get_einsum_with_name(self, einsum_name: str):
75+
return self.einsums[einsum_name]
76+
77+
def get_tensor_with_name(self, tensor_name: str):
78+
return self.tensors[tensor_name]

0 commit comments

Comments
 (0)