|
1 |
| -from typing import get_args |
2 |
| - |
3 |
| -from classy_blocks.grading.autograding.params import ( |
4 |
| - ChopParams, |
5 |
| - FixedCountParams, |
6 |
| - InflationGraderParams, |
7 |
| - SimpleGraderParams, |
8 |
| - SmoothGraderParams, |
9 |
| -) |
| 1 | +from typing import Tuple, get_args |
| 2 | + |
| 3 | +from classy_blocks.grading.autograding.params.base import ChopParams |
| 4 | +from classy_blocks.grading.autograding.params.fixed import FixedCountGraderParams |
| 5 | +from classy_blocks.grading.autograding.params.inflation import InflationGraderParams |
| 6 | +from classy_blocks.grading.autograding.params.simple import SimpleGraderParams |
| 7 | +from classy_blocks.grading.autograding.params.smooth import SmoothGraderParams |
10 | 8 | from classy_blocks.grading.autograding.probe import Probe
|
11 | 9 | from classy_blocks.grading.autograding.row import Row
|
12 | 10 | from classy_blocks.grading.chop import Chop
|
@@ -38,28 +36,33 @@ def __init__(self, mesh: Mesh, params: ChopParams):
|
38 | 36 | self.mesh.assemble()
|
39 | 37 | self.probe = Probe(self.mesh)
|
40 | 38 |
|
| 39 | + def check_at_wall(self, row: Row) -> Tuple[bool, bool]: |
| 40 | + """Returns True if any block on given row has a wall patch |
| 41 | + (at start and/or end, respectively).""" |
| 42 | + start = False |
| 43 | + end = False |
| 44 | + |
| 45 | + # Check if there are blocks at the wall; |
| 46 | + for entry in row.entries: |
| 47 | + for wire in entry.wires: |
| 48 | + # TODO: cache WireInfo |
| 49 | + info = self.probe.get_wire_info(wire, entry.block) |
| 50 | + if info.starts_at_wall: |
| 51 | + start = True |
| 52 | + if info.ends_at_wall: |
| 53 | + end = True |
| 54 | + |
| 55 | + return start, end |
| 56 | + |
41 | 57 | def set_counts(self, row: Row, take: ChopTakeType) -> None:
|
42 | 58 | if row.count > 0:
|
43 | 59 | # stuff, pre-defined by the user
|
44 | 60 | return
|
45 | 61 |
|
46 |
| - # at_wall: List[Entry] = [] |
47 |
| - |
48 |
| - # Check if there are blocks at the wall; |
49 |
| - # for entry in row.entries: |
50 |
| - # for wire in entry.wires: |
51 |
| - # # TODO: cache WireInfo |
52 |
| - # info = self.probe.get_wire_info(wire, entry.block) |
53 |
| - # if info.starts_at_wall or info.ends_at_wall: |
54 |
| - # at_wall.append(entry) |
55 |
| - |
56 | 62 | length = row.get_length(take)
|
| 63 | + start_at_wall, end_at_wall = self.check_at_wall(row) |
57 | 64 |
|
58 |
| - # if len(at_wall) > 0: |
59 |
| - # # find out whether one or two sides are to be counted |
60 |
| - # pass |
61 |
| - |
62 |
| - row.count = self.params.get_count(length) |
| 65 | + row.count = self.params.get_count(length, start_at_wall, end_at_wall) |
63 | 66 |
|
64 | 67 | def grade_squeezed(self, row: Row) -> None:
|
65 | 68 | for entry in row.entries:
|
@@ -118,7 +121,7 @@ class FixedCountGrader(GraderBase):
|
118 | 121 | useful during mesh building and some tutorial cases"""
|
119 | 122 |
|
120 | 123 | def __init__(self, mesh: Mesh, count: int = 8):
|
121 |
| - super().__init__(mesh, FixedCountParams(count)) |
| 124 | + super().__init__(mesh, FixedCountGraderParams(count)) |
122 | 125 |
|
123 | 126 |
|
124 | 127 | class SimpleGrader(GraderBase):
|
|
0 commit comments