Skip to content

Add point2index functionality for tensor_mesh #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion discretize/tensor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from discretize.base import BaseRectangularMesh, BaseTensorMesh
from discretize.operators import DiffOperators, InnerProducts
from discretize.mixins import InterfaceMixins, TensorMeshIO
from discretize.utils import mkvc
from discretize.utils import mkvc, as_array_n_by_dim
from discretize.utils.code_utils import deprecate_property

from .tensor_cell import TensorCell
Expand Down Expand Up @@ -756,6 +756,33 @@ def cell_boundary_indices(self):
indzu = self.gridCC[:, 2] == max(self.gridCC[:, 2])
return indxd, indxu, indyd, indyu, indzd, indzu

def point2index(self, locs): # NOQA D102
# Documentation inherited from discretize.base.BaseMesh

locs = as_array_n_by_dim(locs, self.dim)
# in each dimension do a sorted search within the nodes
# arrays to find the containing cell in that dimension
cell_bounds = [
self.nodes_x,
]
if self.dim > 1:
cell_bounds.append(self.nodes_y)
if self.dim == 3:
cell_bounds.append(self.nodes_z)

# subtract 1 here because given the nodes [0, 1], the point 0.5 would be inserted
# at index 1 to maintain the sorted list, but that corresponds to cell 0.
# clipping here ensures that anything outside the mesh will return the nearest cell.
multi_inds = tuple(
np.clip(np.searchsorted(n, p) - 1, 0, len(n) - 2)
for n, p in zip(cell_bounds, locs.T)
)
# and of course, we are fortran ordered in a tensor mesh.
if self.dim == 1:
return multi_inds[0]
else:
return np.ravel_multi_index(multi_inds, self.shape_cells, order="F")

def _repr_attributes(self):
"""Represent attributes of the mesh."""
attrs = {}
Expand Down
66 changes: 66 additions & 0 deletions tests/base/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
import numpy.testing as npt
import unittest
import discretize
from scipy.sparse.linalg import spsolve
Expand Down Expand Up @@ -321,5 +322,70 @@ def test_orderBackward(self):
self.orderTest()


@pytest.fixture(params=[1, 2, 3], ids=["dims-1", "dims-2", "dims-3"])
def random_tensor_mesh(request):
dim = request.param
rng = np.random.default_rng(440122)
shape = rng.integers(5, 10, dim)
cell_widths = [rng.uniform(3.0, 872634.321, n) for n in shape]
origin = rng.uniform(-101.031, 33.2, dim)

return discretize.TensorMesh(cell_widths, origin)


def test_tensor_point2index_inside_points(random_tensor_mesh):
mesh = random_tensor_mesh
dim = mesh.dim
m_origin = mesh.origin
m_extent = np.atleast_1d(np.max(mesh.nodes, axis=0))

nd = 15
points = np.stack(np.meshgrid(*np.linspace(m_origin, m_extent, nd).T), axis=-1)
points = points.reshape((-1, dim))

npt.assert_array_equal(mesh.is_inside(points), True)

cell_inds = mesh.point2index(points)
for icell, p in zip(cell_inds, points):
cell = mesh[icell]
c_origin, c_extent = cell.bounds.reshape((dim, 2)).T
dim_test = (p >= c_origin) & (p <= c_extent)
npt.assert_equal(dim_test, True)


def test_tensor_point2index_outside_points(random_tensor_mesh):
mesh = random_tensor_mesh
dim = mesh.dim
m_origin = mesh.origin
m_extent = np.atleast_1d(np.max(mesh.nodes, axis=0))
m_width = m_extent - m_origin

nd = 15
points = np.stack(
np.meshgrid(*np.linspace(m_origin - m_width * 2, m_extent + m_width * 2, nd).T),
axis=-1,
)
points = points.reshape((-1, dim))
outside_points = points[~mesh.is_inside(points)]

npt.assert_array_equal(mesh.is_inside(outside_points), False)

# manually check each point that is outside
cell_inds = mesh.point2index(outside_points)
for icell, p in zip(cell_inds, outside_points):
cell = mesh[icell]
c_origin, c_extent = cell.bounds.reshape((dim, 2)).T
dim_test = np.zeros(dim, bool)
for i in range(dim):
p_d = p[i]
if p_d < m_origin[i]:
dim_test[i] = p_d < c_origin[i]
elif p_d > m_extent[i]:
dim_test[i] = p_d > c_extent[i]
else:
dim_test[i] = p_d >= c_origin[i] and p_d <= c_extent[i]
npt.assert_equal(dim_test, True)


if __name__ == "__main__":
unittest.main()
Loading