Skip to content

Commit 376ffc6

Browse files
committed
Add option to restrict the particles the force applies to
1 parent 715d412 commit 376ffc6

File tree

10 files changed

+93
-16
lines changed

10 files changed

+93
-16
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ torch.onnx.export(model=PeriodicForce(),
108108
dynamic_axes={"positions":[0], "forces":[0]})
109109
```
110110

111+
## Applying to a Subset of Particles
112+
113+
In some cases one wants to model part of a system with a machine learning potential and the rest with a
114+
conventional force field. You can restrict which particles the `OnnxForce` acts on by calling `setParticleIndices()`.
115+
For example, the following applies it only to the first 50 particles in the system.
116+
117+
```python
118+
particles = list(range(50))
119+
force.setParticleIndices(particles)
120+
```
121+
122+
The `positions` tensor passed to the model will contain only the positions of the specified particles.
123+
Likewise, the `forces` tensor returned by the model should contain only the forces on those particles.
124+
111125
## Global Parameters
112126

113127
An `OnnxForce` can define global parameters that the model depends on. The model should have an additional

openmmapi/include/OnnxForce.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
101101
* Set the execution provider to be used for computing the model.
102102
*/
103103
void setExecutionProvider(ExecutionProvider provider);
104+
/**
105+
* Get the indices of the particles this force is applied to. If this is empty, the
106+
* force is applied to all particles in the system.
107+
*/
108+
const std::vector<int>& getParticleIndices() const;
109+
/**
110+
* Set the indices of the particles this force is applied to. If this is empty, the
111+
* force is applied to all particles in the system.
112+
*/
113+
void setParticleIndices(const std::vector<int>& indices);
104114
/**
105115
* Get whether this force uses periodic boundary conditions.
106116
*/
@@ -169,6 +179,7 @@ class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
169179
class GlobalParameterInfo;
170180
void initProperties(const std::map<std::string, std::string>& properties);
171181
std::vector<uint8_t> model;
182+
std::vector<int> particleIndices;
172183
ExecutionProvider provider;
173184
bool periodic;
174185
std::vector<GlobalParameterInfo> globalParameters;

openmmapi/include/internal/OnnxForceImpl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class OPENMM_EXPORT_ONNX OnnxForceImpl : public OpenMM::CustomCPPForceImpl {
6060
Ort::Session session;
6161
std::vector<Ort::Value> inputTensors, outputTensors;
6262
std::vector<const char*> inputNames;
63+
std::vector<int> particleIndices;
6364
std::vector<float> positionVec, paramVec;
6465
float boxVectors[9];
6566
};

openmmapi/src/OnnxForce.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ void OnnxForce::setExecutionProvider(OnnxForce::ExecutionProvider provider) {
7474
this->provider = provider;
7575
}
7676

77+
const vector<int>& OnnxForce::getParticleIndices() const {
78+
return particleIndices;
79+
}
80+
81+
void OnnxForce::setParticleIndices(const vector<int>& indices) {
82+
particleIndices = indices;
83+
}
84+
7785
bool OnnxForce::usesPeriodicBoundaryConditions() const {
7886
return periodic;
7987
}

openmmapi/src/OnnxForceImpl.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ OnnxForceImpl::~OnnxForceImpl() {
4747
void OnnxForceImpl::initialize(ContextImpl& context) {
4848
CustomCPPForceImpl::initialize(context);
4949

50+
// Record which particles the force is applied to.
51+
52+
particleIndices = owner.getParticleIndices();
53+
if (particleIndices.size() == 0) {
54+
int numParticles = context.getSystem().getNumParticles();
55+
for (int i = 0; i < numParticles; i++)
56+
particleIndices.push_back(i);
57+
}
58+
5059
// Select the execution provider and set options.
5160

5261
OnnxForce::ExecutionProvider provider = owner.getExecutionProvider();
@@ -97,11 +106,10 @@ void OnnxForceImpl::initialize(ContextImpl& context) {
97106

98107
const vector<uint8_t>& model = owner.getModel();
99108
session = Session(env, model.data(), model.size(), options);
100-
int numParticles = context.getSystem().getNumParticles();
101-
positionVec.resize(3*numParticles);
109+
positionVec.resize(3*particleIndices.size());
102110
paramVec.resize(owner.getNumGlobalParameters());
103111
auto memoryInfo = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
104-
int64_t positionsShape[] = {numParticles, 3};
112+
int64_t positionsShape[] = {static_cast<int64_t>(particleIndices.size()), 3};
105113
int64_t boxShape[] = {3, 3};
106114
int64_t paramShape[] = {1};
107115
int numInputs = 1+owner.getNumGlobalParameters();
@@ -129,11 +137,12 @@ map<string, double> OnnxForceImpl::getDefaultParameters() {
129137
double OnnxForceImpl::computeForce(ContextImpl& context, const vector<Vec3>& positions, vector<Vec3>& forces) {
130138
// Pass the current state to ONNX Runtime.
131139

132-
int numParticles = context.getSystem().getNumParticles();
140+
int numParticles = particleIndices.size();
133141
for (int i = 0; i < numParticles; i++) {
134-
positionVec[3*i] = (float) positions[i][0];
135-
positionVec[3*i+1] = (float) positions[i][1];
136-
positionVec[3*i+2] = (float) positions[i][2];
142+
int index = particleIndices[i];
143+
positionVec[3*i] = (float) positions[index][0];
144+
positionVec[3*i+1] = (float) positions[index][1];
145+
positionVec[3*i+2] = (float) positions[index][2];
137146
}
138147
if (owner.usesPeriodicBoundaryConditions()) {
139148
Vec3 box[3];
@@ -152,6 +161,6 @@ double OnnxForceImpl::computeForce(ContextImpl& context, const vector<Vec3>& pos
152161
const float* energy = outputTensors[0].GetTensorData<float>();
153162
const float* forceData = outputTensors[1].GetTensorData<float>();
154163
for (int i = 0; i < numParticles; i++)
155-
forces[i] = Vec3(forceData[3*i], forceData[3*i+1], forceData[3*i+2]);
164+
forces[particleIndices[i]] = Vec3(forceData[3*i], forceData[3*i+1], forceData[3*i+2]);
156165
return *energy;
157166
}

python/onnxplugin.i

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
namespace std {
3636
%template(vectorbyte) vector<unsigned char>;
37+
%template(vectorint) vector<int>;
3738
%template(property_map) map<std::string, std::string>;
3839
};
3940

@@ -53,6 +54,8 @@ public:
5354
const std::vector<uint8_t>& getModel() const;
5455
ExecutionProvider getExecutionProvider() const;
5556
void setExecutionProvider(ExecutionProvider provider);
57+
const std::vector<int>& getParticleIndices() const;
58+
void setParticleIndices(const std::vector<int>& indices);
5659
bool usesPeriodicBoundaryConditions() const;
5760
void setUsesPeriodicBoundaryConditions(bool periodic);
5861
int getNumGlobalParameters() const;

python/tests/TestOnnxForce.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def testConstructors():
1515

1616
@pytest.mark.parametrize('use_cv_force', [True, False])
1717
@pytest.mark.parametrize('platform', [mm.Platform.getPlatform(i).getName() for i in range(mm.Platform.getNumPlatforms())])
18-
def testForce(use_cv_force, platform):
18+
@pytest.mark.parametrize('particles', [[], [5,3,0]])
19+
def testForce(use_cv_force, platform, particles):
1920

2021
# Create a random cloud of particles.
2122
numParticles = 10
@@ -26,6 +27,7 @@ def testForce(use_cv_force, platform):
2627

2728
# Create a force
2829
force = openmmonnx.OnnxForce('../../tests/central.onnx', {'UseGraphs': 'false'})
30+
force.setParticleIndices(particles)
2931
assert force.getProperties()['UseGraphs'] == 'false'
3032
if use_cv_force:
3133
# Wrap OnnxForce into CustomCVForce
@@ -45,9 +47,14 @@ def testForce(use_cv_force, platform):
4547
state = context.getState(getEnergy=True, getForces=True)
4648

4749
# See if the energy and forces are correct. The network defines a potential of the form E(r) = |r|^2
50+
if len(particles) > 0:
51+
positions = positions[particles]
4852
expectedEnergy = np.sum(positions*positions)
4953
assert np.allclose(expectedEnergy, state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole))
50-
assert np.allclose(-2*positions, state.getForces(asNumpy=True))
54+
forces = state.getForces(asNumpy=True)
55+
if len(particles) > 0:
56+
forces = forces[particles]
57+
assert np.allclose(-2*positions, forces)
5158

5259
def testProperties():
5360
""" Test that the properties are correctly set and retrieved """

serialization/src/OnnxForceProxy.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ void OnnxForceProxy::serialize(const void* object, SerializationNode& node) cons
7070
node.setStringProperty("model", hexEncode(force.getModel()));
7171
node.setIntProperty("forceGroup", force.getForceGroup());
7272
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
73+
const vector<int>& indices = force.getParticleIndices();
74+
auto& indicesNode = node.createChildNode("ParticleIndices");
75+
for (int i = 0; i < indices.size(); i++)
76+
indicesNode.createChildNode("Particle").setIntProperty("index", indices[i]);
7377
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
7478
for (int i = 0; i < force.getNumGlobalParameters(); i++)
7579
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
@@ -86,6 +90,12 @@ void* OnnxForceProxy::deserialize(const SerializationNode& node) const {
8690
force->setForceGroup(node.getIntProperty("forceGroup"));
8791
force->setUsesPeriodicBoundaryConditions(node.getBoolProperty("usesPeriodic"));
8892
for (const SerializationNode& child : node.getChildren()) {
93+
if (child.getName() == "ParticleIndices") {
94+
vector<int> indices;
95+
for (auto& particle : child.getChildren())
96+
indices.push_back(particle.getIntProperty("index"));
97+
force->setParticleIndices(indices);
98+
}
8999
if (child.getName() == "GlobalParameters")
90100
for (auto& parameter : child.getChildren())
91101
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));

serialization/tests/TestSerializeOnnxForce.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void testSerialization() {
4949
force.addGlobalParameter("y", 2.221);
5050
force.setUsesPeriodicBoundaryConditions(true);
5151
force.setProperty("UseGraphs", "true");
52+
force.setParticleIndices({0, 2, 4});
5253

5354
// Serialize and then deserialize it.
5455

@@ -61,6 +62,7 @@ void testSerialization() {
6162
OnnxForce& force2 = *copy;
6263
ASSERT_EQUAL_CONTAINERS(force.getModel(), force2.getModel());
6364
ASSERT_EQUAL(force.getForceGroup(), force2.getForceGroup());
65+
ASSERT_EQUAL_CONTAINERS(force.getParticleIndices(), force2.getParticleIndices());
6466
ASSERT_EQUAL(force.getNumGlobalParameters(), force2.getNumGlobalParameters());
6567
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
6668
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));

tests/TestOnnxForce.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "openmm/System.h"
3737
#include "openmm/VerletIntegrator.h"
3838
#include "sfmt/SFMT.h"
39+
#include <algorithm>
3940
#include <cmath>
4041
#include <iostream>
4142
#include <vector>
@@ -44,7 +45,7 @@ using namespace OnnxPlugin;
4445
using namespace OpenMM;
4546
using namespace std;
4647

47-
void testForce(Platform& platform) {
48+
void testForce(Platform& platform, vector<int> particleIndices) {
4849
// Create a random cloud of particles.
4950

5051
const int numParticles = 10;
@@ -57,6 +58,10 @@ void testForce(Platform& platform) {
5758
positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*10;
5859
}
5960
OnnxForce* force = new OnnxForce("tests/central.onnx");
61+
force->setParticleIndices(particleIndices);
62+
if (particleIndices.size() == 0)
63+
for (int i = 0; i < numParticles; i++)
64+
particleIndices.push_back(i);
6065
system.addForce(force);
6166

6267
// Compute the forces and energy.
@@ -70,10 +75,16 @@ void testForce(Platform& platform) {
7075

7176
double expectedEnergy = 0;
7277
for (int i = 0; i < numParticles; i++) {
73-
Vec3 pos = positions[i];
74-
double r = sqrt(pos.dot(pos));
75-
expectedEnergy += r*r;
76-
ASSERT_EQUAL_VEC(pos*(-2.0), state.getForces()[i], 1e-5);
78+
if (find(particleIndices.begin(), particleIndices.end(), i) != particleIndices.end()) {
79+
Vec3 pos = positions[i];
80+
double r = sqrt(pos.dot(pos));
81+
expectedEnergy += r*r;
82+
ASSERT_EQUAL_VEC(pos*(-2.0), state.getForces()[i], 1e-5);
83+
}
84+
else {
85+
Vec3 zero;
86+
ASSERT_EQUAL_VEC(zero, state.getForces()[i], 1e-5);
87+
}
7788
}
7889
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
7990
}
@@ -164,7 +175,8 @@ void testGlobal(Platform& platform) {
164175
}
165176

166177
void testPlatform(Platform& platform) {
167-
testForce(platform);
178+
testForce(platform, {});
179+
testForce(platform, {0, 1, 2, 9, 5});
168180
testPeriodicForce(platform);
169181
testGlobal(platform);
170182
}

0 commit comments

Comments
 (0)