Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Enhancements
(PR #5038)
* Moved distopia checking function to common import location in
MDAnalysisTest.util (PR #5038)
* Enables parallelization for `analysis.polymer.PersistenceLength` (Issue #4671, PR #5074)


Changes
* Removed undocumented and unused attribute
Expand Down
45 changes: 32 additions & 13 deletions package/MDAnalysis/analysis/polymer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .. import NoDataError
from ..core.groups import requires, AtomGroup
from ..lib.distances import calc_bonds
from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -236,8 +236,17 @@ class PersistenceLength(AnalysisBase):
Former ``results`` are now stored as ``results.bond_autocorrelation``.
:attr:`lb`, :attr:`lp`, :attr:`fit` are now stored in a
:class:`MDAnalysis.analysis.base.Results` instance.
.. versionchanged:: 2.10.0
introduced :meth:`get_supported_backends` allowing for parallel
execution on ``multiprocessing`` and ``dask`` backends.
"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "multiprocessing", "dask")

def __init__(self, atomgroups, **kwargs):
super(PersistenceLength, self).__init__(
atomgroups[0].universe.trajectory, **kwargs
Expand All @@ -249,15 +258,18 @@ def __init__(self, atomgroups, **kwargs):
chainlength = len(atomgroups[0])
if not all(l == chainlength for l in lens):
raise ValueError("Not all AtomGroups were the same size")
self.chainlength = chainlength

self._results = np.zeros(chainlength - 1, dtype=np.float32)
def _prepare(self):
self.results.raw_bond_autocorr = np.zeros(
self.chainlength - 1, dtype=np.float32
)

def _single_frame(self):
# could optimise this by writing a "self dot array"
# we're only using the upper triangle of np.inner
# function would accept a bunch of coordinates and spit out the
# decorrel for that
n = len(self._atomgroups[0])

for chain in self._atomgroups:
# Vector from each atom to next
Expand All @@ -266,8 +278,17 @@ def _single_frame(self):
vecs /= np.sqrt((vecs * vecs).sum(axis=1))[:, None]

inner_pr = np.inner(vecs, vecs)
for i in range(n - 1):
self._results[: (n - 1) - i] += inner_pr[i, i:]
for i in range(self.chainlength - 1):
self.results.raw_bond_autocorr[
: (self.chainlength - 1) - i
] += inner_pr[i, i:]

def _get_aggregator(self):
return ResultsGroup(
lookup={
"raw_bond_autocorr": ResultsGroup.ndarray_sum,
}
)

@property
def lb(self):
Expand Down Expand Up @@ -300,14 +321,12 @@ def fit(self):
return self.results.fit

def _conclude(self):
n = len(self._atomgroups[0])

norm = np.linspace(n - 1, 1, n - 1)
norm *= len(self._atomgroups) * self.n_frames

self.results.bond_autocorrelation = self._results / norm
norm = np.linspace(self.chainlength - 1, 1, self.chainlength - 1)
norm *= len(self._atomgroups) * self._trajectory.n_frames
self.results.bond_autocorrelation = (
self.results.raw_bond_autocorr / norm
)
self._calc_bond_length()

self._perform_fit()

def _calc_bond_length(self):
Expand Down Expand Up @@ -350,7 +369,7 @@ def plot(self, ax=None):
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()
_, ax = plt.subplots()
ax.plot(
self.results.x,
self.results.bond_autocorrelation,
Expand Down
9 changes: 9 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.analysis.density import DensityAnalysis
from MDAnalysis.analysis.lineardensity import LinearDensity
from MDAnalysis.analysis.polymer import PersistenceLength
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -185,3 +186,11 @@ def client_DensityAnalysis(request):
@pytest.fixture(scope="module", params=params_for_cls(LinearDensity))
def client_LinearDensity(request):
return request.param


# MDAnalysis.analysis.polymer


@pytest.fixture(scope="module", params=params_for_cls(PersistenceLength))
def client_PersistenceLength(request):
return request.param
30 changes: 17 additions & 13 deletions testsuite/MDAnalysisTests/analysis/test_persistencelength.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,33 @@
from MDAnalysisTests.datafiles import Plength, TRZ_psf, TRZ


def test_class_is_parallelizable():
assert polymer.PersistenceLength._analysis_algorithm_is_parallelizable


def test_supported_backends():
assert polymer.PersistenceLength.get_supported_backends() == (
"serial",
"multiprocessing",
"dask",
)


class TestPersistenceLength(object):
@staticmethod
@pytest.fixture()
def u():
return mda.Universe(Plength)

@staticmethod
@pytest.fixture()
def p(u):
def p(self, u):
ags = [r.atoms.select_atoms("name C* N*") for r in u.residues]

p = polymer.PersistenceLength(ags)
return p

@staticmethod
@pytest.fixture()
def p_run(p):
return p.run()
def p_run(self, p, client_PersistenceLength):
return p.run(**client_PersistenceLength)

def test_ag_ValueError(self, u):
ags = [u.atoms[:10], u.atoms[10:110]]
Expand Down Expand Up @@ -81,15 +90,11 @@ def test_raise_NoDataError(self, p):
def test_plot_ax_return(self, p_run):
"""Ensure that a matplotlib axis object is
returned when plot() is called."""
actual = p_run.plot()
expected = matplotlib.axes.Axes
assert isinstance(actual, expected)
assert isinstance(p_run.plot(), matplotlib.axes.Axes)

def test_plot_with_ax(self, p_run):
fig, ax = plt.subplots()

ax2 = p_run.plot(ax=ax)

assert ax2 is ax

def test_current_axes(self, p_run):
Expand All @@ -98,8 +103,7 @@ def test_current_axes(self, p_run):
assert ax2 is not ax

@pytest.mark.parametrize("attr", ("lb", "lp", "fit"))
def test(self, p, attr):
p_run = p.run(step=3)
def test(self, p_run, attr):
wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
getattr(p_run, attr) is p_run.results[attr]
Expand Down
Loading