Skip to content

Commit 4567661

Browse files
committed
Implement own np.take and remove 'reflected lists'
1 parent e7269d8 commit 4567661

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

src/classy_blocks/optimize/iteration.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dataclasses
22
from typing import List
33

4-
from classy_blocks.util.constants import VBIG, VSMALL
4+
from classy_blocks.util.constants import TOL, VBIG, VSMALL
55
from classy_blocks.util.tools import report
66

77

@@ -129,4 +129,8 @@ def converged(self) -> bool:
129129
print("Tolerance reached, stopping optimization.")
130130
return True
131131

132+
if self.iterations[-1].final_quality < TOL:
133+
print("Nothing left to optimize.")
134+
return True
135+
132136
return False

src/classy_blocks/optimize/junction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
22
from typing import List, Optional, Set
33

4+
import numpy as np
5+
46
from classy_blocks.base.exceptions import ClampExistsError
57
from classy_blocks.cbtyping import NPPointListType, NPPointType
68
from classy_blocks.optimize.cell import CellBase, HexCell
@@ -60,4 +62,4 @@ def quality(self) -> float:
6062
else:
6163
quality_function = get_quad_quality
6264

63-
return sum(quality_function(self.points, cell.indexes) for cell in self.cells)
65+
return sum(quality_function(self.points, np.array(cell.indexes, dtype=np.int32)) for cell in self.cells)

src/classy_blocks/optimize/quality.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from typing import List, Tuple
1+
from typing import Tuple
22

33
import numba # type:ignore
44
import numpy as np
5+
from nptyping import Int32, NDArray, Shape
56

67
from classy_blocks.cbtyping import NPPointListType, NPPointType, NPVectorType
78
from classy_blocks.util.constants import VSMALL
89

10+
NPIndexType = NDArray[Shape["*, 1"], Int32]
11+
912

1013
@numba.jit(nopython=True, cache=True)
1114
def scale_quality(base: float, exponent: float, factor: float, value: float) -> float:
@@ -27,6 +30,19 @@ def scale_aspect(ratio: float) -> float:
2730
return scale_quality(3, 2.5, 3, np.log10(ratio))
2831

2932

33+
@numba.jit(nopython=True, cache=True)
34+
def take(points: NPPointListType, indexes: NPIndexType):
35+
n_points = len(indexes)
36+
dim = points.shape[1]
37+
result = np.empty((n_points, dim), dtype=points.dtype)
38+
39+
for i in range(n_points):
40+
for j in range(dim):
41+
result[i, j] = points[indexes[i], j]
42+
43+
return result
44+
45+
3046
@numba.jit(nopython=True, cache=True)
3147
def get_center_point(points: NPPointListType) -> NPPointType:
3248
return np.sum(points, axis=0) / len(points)
@@ -103,8 +119,8 @@ def get_quad_angle_quality(quad_points: NPPointListType) -> float:
103119

104120

105121
@numba.jit(nopython=True, cache=True)
106-
def get_quad_quality(grid_points: NPPointListType, cell_indexes: List[int]) -> float:
107-
cell_points = np.take(grid_points, cell_indexes, axis=0)
122+
def get_quad_quality(grid_points: NPPointListType, cell_indexes: NPIndexType) -> float:
123+
cell_points = take(grid_points, cell_indexes)
108124
cell_center, cell_normal, cell_aspect = get_quad_normal(cell_points)
109125

110126
# non-ortho
@@ -119,8 +135,8 @@ def get_quad_quality(grid_points: NPPointListType, cell_indexes: List[int]) -> f
119135

120136

121137
@numba.jit(nopython=True, cache=True)
122-
def get_hex_quality(grid_points: NPPointListType, cell_indexes: List[int]) -> float:
123-
cell_points = np.take(grid_points, cell_indexes, axis=0)
138+
def get_hex_quality(grid_points: NPPointListType, cell_indexes: NPIndexType) -> float:
139+
cell_points = take(grid_points, cell_indexes)
124140
cell_center = get_center_point(cell_points)
125141

126142
side_indexes = np.array([[0, 1, 2, 3], [7, 6, 5, 4], [4, 0, 3, 7], [6, 2, 1, 5], [0, 4, 5, 1], [7, 3, 2, 6]])
@@ -136,7 +152,7 @@ def get_hex_quality(grid_points: NPPointListType, cell_indexes: List[int]) -> fl
136152
# For this kind of optimization it is quite sufficient to take
137153
# only the latter as it's not much different and we'll optimize
138154
# other cells too.
139-
side_points = np.take(cell_points, side, axis=0)
155+
side_points = take(cell_points, side)
140156
side_center, side_normal, side_aspect = get_quad_normal(side_points)
141157
center_vector = cell_center - side_center
142158

0 commit comments

Comments
 (0)