Skip to content

Commit b303c71

Browse files
committed
Rewrite InflationGrader: get_count()
1 parent 559ad0b commit b303c71

File tree

9 files changed

+487
-234
lines changed

9 files changed

+487
-234
lines changed

src/classy_blocks/grading/autograding/grader.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from typing import get_args
2-
3-
from classy_blocks.grading.autograding.params import (
4-
ChopParams,
5-
FixedCountParams,
6-
InflationGraderParams,
7-
SimpleGraderParams,
8-
SmoothGraderParams,
9-
)
1+
from typing import Tuple, get_args
2+
3+
from classy_blocks.grading.autograding.params.base import ChopParams
4+
from classy_blocks.grading.autograding.params.fixed import FixedCountGraderParams
5+
from classy_blocks.grading.autograding.params.inflation import InflationGraderParams
6+
from classy_blocks.grading.autograding.params.simple import SimpleGraderParams
7+
from classy_blocks.grading.autograding.params.smooth import SmoothGraderParams
108
from classy_blocks.grading.autograding.probe import Probe
119
from classy_blocks.grading.autograding.row import Row
1210
from classy_blocks.grading.chop import Chop
@@ -38,28 +36,33 @@ def __init__(self, mesh: Mesh, params: ChopParams):
3836
self.mesh.assemble()
3937
self.probe = Probe(self.mesh)
4038

39+
def check_at_wall(self, row: Row) -> Tuple[bool, bool]:
40+
"""Returns True if any block on given row has a wall patch
41+
(at start and/or end, respectively)."""
42+
start = False
43+
end = False
44+
45+
# Check if there are blocks at the wall;
46+
for entry in row.entries:
47+
for wire in entry.wires:
48+
# TODO: cache WireInfo
49+
info = self.probe.get_wire_info(wire, entry.block)
50+
if info.starts_at_wall:
51+
start = True
52+
if info.ends_at_wall:
53+
end = True
54+
55+
return start, end
56+
4157
def set_counts(self, row: Row, take: ChopTakeType) -> None:
4258
if row.count > 0:
4359
# stuff, pre-defined by the user
4460
return
4561

46-
# at_wall: List[Entry] = []
47-
48-
# Check if there are blocks at the wall;
49-
# for entry in row.entries:
50-
# for wire in entry.wires:
51-
# # TODO: cache WireInfo
52-
# info = self.probe.get_wire_info(wire, entry.block)
53-
# if info.starts_at_wall or info.ends_at_wall:
54-
# at_wall.append(entry)
55-
5662
length = row.get_length(take)
63+
start_at_wall, end_at_wall = self.check_at_wall(row)
5764

58-
# if len(at_wall) > 0:
59-
# # find out whether one or two sides are to be counted
60-
# pass
61-
62-
row.count = self.params.get_count(length)
65+
row.count = self.params.get_count(length, start_at_wall, end_at_wall)
6366

6467
def grade_squeezed(self, row: Row) -> None:
6568
for entry in row.entries:
@@ -118,7 +121,7 @@ class FixedCountGrader(GraderBase):
118121
useful during mesh building and some tutorial cases"""
119122

120123
def __init__(self, mesh: Mesh, count: int = 8):
121-
super().__init__(mesh, FixedCountParams(count))
124+
super().__init__(mesh, FixedCountGraderParams(count))
122125

123126

124127
class SimpleGrader(GraderBase):

src/classy_blocks/grading/autograding/params.py

Lines changed: 0 additions & 209 deletions
This file was deleted.

src/classy_blocks/grading/autograding/params/__init__.py

Whitespace-only changes.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import abc
2+
from typing import List, Optional
3+
4+
from classy_blocks.grading.autograding.probe import WireInfo
5+
from classy_blocks.grading.chop import Chop
6+
7+
CellSizeType = Optional[float]
8+
9+
10+
def sum_length(start_size: float, count: int, c2c_expansion: float) -> float:
11+
"""Returns absolute length of the chop"""
12+
length = 0.0
13+
size = start_size
14+
15+
for _ in range(count):
16+
length += size
17+
size *= c2c_expansion
18+
19+
return length
20+
21+
22+
class ChopParams(abc.ABC):
23+
@abc.abstractmethod
24+
def get_count(self, length: float, start_at_wall: bool, end_at_wall: bool) -> int:
25+
"""Calculates count based on given length and position"""
26+
27+
@abc.abstractmethod
28+
def is_squeezed(self, count: int, info: WireInfo) -> bool:
29+
"""Returns True if cells have to be 'squished' together (thinner than prescribed in params)"""
30+
31+
@abc.abstractmethod
32+
def get_chops(self, count: int, info: WireInfo) -> List[Chop]:
33+
"""Fixes cell count but modifies chops so that proper cell sizing will be obeyed"""
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import dataclasses
2+
from typing import List
3+
4+
from classy_blocks.grading.autograding.params.base import ChopParams
5+
from classy_blocks.grading.chop import Chop
6+
7+
8+
@dataclasses.dataclass
9+
class FixedCountGraderParams(ChopParams):
10+
count: int = 8
11+
12+
def get_count(self, _length, _start_at_wall, _end_at_wall):
13+
return self.count
14+
15+
def is_squeezed(self, _count, _info) -> bool:
16+
return True # grade everything in first pass
17+
18+
def get_chops(self, count, _info) -> List[Chop]:
19+
return [Chop(count=count)]

0 commit comments

Comments
 (0)