Commit 27f25ef

Allow extra inputs to model (#5)
* Allow extra inputs to model
* Serialization of inputs
* Python interface for inputs
* Documentation for inputs
* Fixed compilation error on Python 3.12 and earlier
1 parent 96231bf commit 27f25ef

12 files changed (+513 −1 lines)

README.md

Lines changed: 48 additions & 0 deletions
@@ -159,6 +159,54 @@ on the `Context`.

context.setParameter("k", 5.0)
```

## Extra Inputs

You can also specify extra inputs that should be passed to the model. Unlike global parameters,
which are always scalars, extra inputs can be tensors of any size and shape. On the other hand,
their values are fixed at Context creation time and can only be changed by reinitializing the
Context.

This example is similar to the one above, but `k` is now a vector containing a different force
constant for every particle.

```python
class ForceWithInput(torch.nn.Module):
    def forward(self, positions, k):
        positions.grad = None
        r2 = torch.sum(positions*positions, dim=1)
        energy = torch.sum(k*r2)
        energy.backward()
        forces = -positions.grad
        return energy, forces

torch.onnx.export(model=ForceWithInput(),
                  args=(torch.ones(1, 3, requires_grad=True), torch.ones(1)),
                  f="ForceWithInput.onnx",
                  input_names=["positions", "k"],
                  output_names=["energy", "forces"],
                  dynamic_axes={"positions":[0], "forces":[0], "k":[0]})
```

Notice that we included `"k": [0]` in `dynamic_axes` when exporting the model. This allows its length
to be variable, so we can use the model for systems with any number of particles.

Here is how we create the OnnxForce.

```python
import openmmonnx
force = OnnxForce("ForceWithInput.onnx")
force.addInput(openmmonnx.FloatInput("k", k, [len(k)]))
```

The three arguments to the FloatInput constructor are the name of the input (matching the name we
specified in `input_names` when exporting the model), a list or array containing the values, and
the shape of the tensor. In this case the tensor has one dimension, so the shape argument contains
only a single value. Higher dimensional tensors are also allowed. In that case, the second
argument should contain the values in flattened order.

In addition to FloatInput, which specifies a tensor of 32-bit floating point values, there is also
an IntegerInput class, which specifies a tensor of 32-bit integer values.
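For illustration, here is a minimal sketch of adding an integer-valued input and a two-dimensional input, continuing the example above. The input names (`types`, `weights`) and their values are hypothetical, and `openmmonnx.IntegerInput` is assumed to take the same three constructor arguments as FloatInput, matching the C++ declaration shown further down this commit.

```python
# Hypothetical extra input: one 32-bit integer per particle, e.g. an atom type index.
types = [0, 1, 1, 0]
force.addInput(openmmonnx.IntegerInput("types", types, [len(types)]))

# A higher dimensional tensor passes its values in flattened order: shape [2, 3] needs 6 values.
weights = [1.0, 2.0, 3.0,
           4.0, 5.0, 6.0]
force.addInput(openmmonnx.FloatInput("weights", weights, [2, 3]))
```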
## Execution Providers

ONNX Runtime supports a variety of backends that can be used to compute the neural network. They

openmmapi/include/OnnxForce.h

Lines changed: 139 additions & 0 deletions
@@ -46,6 +46,9 @@ namespace OnnxPlugin {

class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
public:
    class Input;
    class IntegerInput;
    class FloatInput;
    /**
     * This is an enumeration of ONNX execution providers.
     */
@@ -89,6 +92,7 @@ class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
     * @param properties optional map of properties
     */
    OnnxForce(const std::vector<uint8_t>& model, const std::map<std::string, std::string>& properties={});
    ~OnnxForce();
    /**
     * Get the binary representation of the model in ONNX format.
     */
@@ -161,6 +165,32 @@ class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
     * @param defaultValue the default value of the parameter
     */
    void setGlobalParameterDefaultValue(int index, double defaultValue);
    /**
     * Get the number of extra tensors to pass to the model.
     */
    int getNumInputs() const;
    /**
     * Add an extra tensor that should be passed to the model. The Input object should have
     * been created on the heap with the "new" operator. The OnnxForce takes over ownership
     * of it, and deletes it when the OnnxForce itself is deleted.
     *
     * @param input the tensor to pass in. This should be the appropriate subclass for
     *              the type of value.
     * @return the index of the input that was added
     */
    int addInput(Input* input);
    /**
     * Get a const reference to an extra input that is passed to the model.
     *
     * @param index the index of the input to return
     */
    const Input& getInput(int index) const;
    /**
     * Get a writable reference to an extra input that is passed to the model.
     *
     * @param index the index of the input to return
     */
    Input& getInput(int index);
    /**
     * Set the value of a property.
     *
@@ -183,9 +213,118 @@ class OPENMM_EXPORT_ONNX OnnxForce : public OpenMM::Force {
    ExecutionProvider provider;
    bool periodic;
    std::vector<GlobalParameterInfo> globalParameters;
    std::vector<Input*> inputs;
    std::map<std::string, std::string> properties;
};

/**
 * An Input defines a tensor that should be passed to the model. This is an abstract class. Subclasses
 * define particular types of tensors.
 */
class OnnxForce::Input {
public:
    virtual ~Input() {
    }
    /**
     * Get the name of the input.
     */
    const std::string& getName() const {
        return name;
    }
    /**
     * Get the shape of the tensor.
     */
    const std::vector<int>& getShape() const {
        return shape;
    }
    /**
     * Set the shape of the tensor.
     */
    void setShape(const std::vector<int>& shape) {
        this->shape = shape;
    }
protected:
    Input() {
    }
    Input(const std::string& name, const std::vector<int>& shape) : name(name), shape(shape) {
    }
private:
    std::string name;
    std::vector<int> shape;
};

/**
 * A tensor containing integer values that should be passed to the model.
 */
class OnnxForce::IntegerInput : public Input {
public:
    /**
     * Create an IntegerInput.
     *
     * @param name   the name of the input
     * @param values the values contained by the tensor, in flattened order
     * @param shape  the shape of the tensor
     */
    IntegerInput(const std::string& name, const std::vector<int>& values, const std::vector<int>& shape) : Input(name, shape), values(values) {
    }
    /**
     * Get a const reference to the values contained in the tensor, in flattened order.
     */
    const std::vector<int>& getValues() const {
        return values;
    }
    /**
     * Get a writable reference to the values contained in the tensor, in flattened order.
     */
    std::vector<int>& getValues() {
        return values;
    }
    /**
     * Set the values contained in the tensor, in flattened order.
     */
    void setValues(const std::vector<int>& values) {
        this->values = values;
    }
private:
    std::vector<int> values;
};

/**
 * A tensor containing float values that should be passed to the model.
 */
class OnnxForce::FloatInput : public Input {
public:
    /**
     * Create a FloatInput.
     *
     * @param name   the name of the input
     * @param values the values contained by the tensor, in flattened order
     * @param shape  the shape of the tensor
     */
    FloatInput(const std::string& name, const std::vector<float>& values, const std::vector<int>& shape) : Input(name, shape), values(values) {
    }
    /**
     * Get a const reference to the values contained in the tensor, in flattened order.
     */
    const std::vector<float>& getValues() const {
        return values;
    }
    /**
     * Get a writable reference to the values contained in the tensor, in flattened order.
     */
    std::vector<float>& getValues() {
        return values;
    }
    /**
     * Set the values contained in the tensor, in flattened order.
     */
    void setValues(const std::vector<float>& values) {
        this->values = values;
    }
private:
    std::vector<float> values;
};

/**
 * This is an internal class used to record information about a global parameter.
 * @private
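For orientation, here is a minimal Python sketch of how these accessors surface in the scripting layer added by this commit. It assumes the Python bindings expose getNumInputs, getInput, and getName under the same names as the C++ declarations above; only addInput and FloatInput appear explicitly in the README.

```python
import openmmonnx

# Assumes openmmonnx.OnnxForce mirrors the C++ constructor taking a model file path.
force = openmmonnx.OnnxForce("ForceWithInput.onnx")
index = force.addInput(openmmonnx.FloatInput("k", [1.0, 2.0, 3.0], [3]))

# Assumed to mirror the C++ accessors declared above.
assert force.getNumInputs() == 1
assert force.getInput(index).getName() == "k"
```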

openmmapi/include/internal/OnnxForceImpl.h

Lines changed: 3 additions & 0 deletions
@@ -56,12 +56,15 @@ class OPENMM_EXPORT_ONNX OnnxForceImpl : public OpenMM::CustomCPPForceImpl {
    double computeForce(OpenMM::ContextImpl& context, const std::vector<OpenMM::Vec3>& positions, std::vector<OpenMM::Vec3>& forces);
private:
    const OnnxForce& owner;
    void validateInput(const std::string& name, const std::vector<int>& shape, int size);
    Ort::Env env;
    Ort::Session session;
    std::vector<Ort::Value> inputTensors, outputTensors;
    std::vector<const char*> inputNames;
    std::vector<int> particleIndices;
    std::vector<float> positionVec, paramVec;
    std::vector<OnnxForce::IntegerInput> integerInputs;
    std::vector<OnnxForce::FloatInput> floatInputs;
    float boxVectors[9];
};

openmmapi/src/OnnxForce.cpp

Lines changed: 24 additions & 0 deletions
@@ -52,6 +52,11 @@ OnnxForce::OnnxForce(const std::vector<uint8_t>& model, const map<string, string
    initProperties(properties);
}

OnnxForce::~OnnxForce() {
    for (Input* input : inputs)
        delete input;
}

void OnnxForce::initProperties(const std::map<std::string, std::string>& properties) {
    const std::map<std::string, std::string> defaultProperties = {{"UseGraphs", "false"}, {"DeviceIndex", "0"}};
    this->properties = defaultProperties;

@@ -123,6 +128,25 @@ void OnnxForce::setGlobalParameterDefaultValue(int index, double defaultValue) {
    globalParameters[index].defaultValue = defaultValue;
}

int OnnxForce::getNumInputs() const {
    return inputs.size();
}

int OnnxForce::addInput(OnnxForce::Input* input) {
    inputs.push_back(input);
    return inputs.size() - 1;
}

const OnnxForce::Input& OnnxForce::getInput(int index) const {
    ASSERT_VALID_INDEX(index, inputs);
    return *inputs[index];
}

OnnxForce::Input& OnnxForce::getInput(int index) {
    ASSERT_VALID_INDEX(index, inputs);
    return *inputs[index];
}

void OnnxForce::setProperty(const string& name, const string& value) {
    if (properties.find(name) == properties.end())
        throw OpenMMException("OnnxForce: Unknown property '" + name + "'");

openmmapi/src/OnnxForceImpl.cpp

Lines changed: 41 additions & 0 deletions
@@ -32,6 +32,7 @@
#include "internal/OnnxForceImpl.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h"
#include <sstream>

using namespace OnnxPlugin;
using namespace OpenMM;

@@ -125,6 +126,46 @@ void OnnxForceImpl::initialize(ContextImpl& context) {
        inputTensors.emplace_back(Value::CreateTensor<float>(memoryInfo, &paramVec[i], 1, paramShape, 1));
        inputNames.push_back(owner.getGlobalParameterName(i).c_str());
    }

    // Process extra inputs.

    for (int i = 0; i < owner.getNumInputs(); i++) {
        const OnnxForce::IntegerInput* integerInput = dynamic_cast<const OnnxForce::IntegerInput*>(&owner.getInput(i));
        if (integerInput != nullptr) {
            validateInput(integerInput->getName(), integerInput->getShape(), integerInput->getValues().size());
            integerInputs.push_back(OnnxForce::IntegerInput(integerInput->getName(), integerInput->getValues(), integerInput->getShape()));
        }
        const OnnxForce::FloatInput* floatInput = dynamic_cast<const OnnxForce::FloatInput*>(&owner.getInput(i));
        if (floatInput != nullptr) {
            validateInput(floatInput->getName(), floatInput->getShape(), floatInput->getValues().size());
            floatInputs.push_back(OnnxForce::FloatInput(floatInput->getName(), floatInput->getValues(), floatInput->getShape()));
        }
    }
    for (OnnxForce::IntegerInput& input : integerInputs) {
        vector<int64_t> shape;
        for (int i : input.getShape())
            shape.push_back(i);
        inputTensors.emplace_back(Value::CreateTensor<int>(memoryInfo, input.getValues().data(), input.getValues().size(), shape.data(), shape.size()));
        inputNames.push_back(input.getName().c_str());
    }
    for (OnnxForce::FloatInput& input : floatInputs) {
        vector<int64_t> shape;
        for (int i : input.getShape())
            shape.push_back(i);
        inputTensors.emplace_back(Value::CreateTensor<float>(memoryInfo, input.getValues().data(), input.getValues().size(), shape.data(), shape.size()));
        inputNames.push_back(input.getName().c_str());
    }
}

void OnnxForceImpl::validateInput(const string& name, const vector<int>& shape, int size) {
    int expected = 1;
    for (int i : shape)
        expected *= i;
    if (expected != size) {
        stringstream message;
        message<<"Incorrect length for input '"<<name<<"'. Expected "<<expected<<" elements, found "<<size<<".";
        throw OpenMMException(message.str());
    }
}

map<string, double> OnnxForceImpl::getDefaultParameters() {
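The validation above means a mismatched input only fails once the Context is built, since that is when OnnxForceImpl::initialize() runs. A minimal sketch of triggering it from Python (the values are hypothetical; it assumes the error surfaces as an OpenMMException when the Context is created):

```python
import openmmonnx

force = openmmonnx.OnnxForce("ForceWithInput.onnx")
# Shape [2, 2] implies 4 elements, but only 3 values are supplied.
force.addInput(openmmonnx.FloatInput("k", [1.0, 2.0, 3.0], [2, 2]))
# Creating a Context that includes this force is expected to raise:
#   "Incorrect length for input 'k'. Expected 4 elements, found 3."
```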
