Skip to content

Commit 926e043

Browse files
authored
Rename more batch to system (#233)
1 parent 16bf8f8 commit 926e043

File tree

14 files changed

+119
-119
lines changed

14 files changed

+119
-119
lines changed

examples/scripts/7_Others/7.3_Batched_neighbor_list.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,26 @@
2525
# Fix: Ensure pbc has the correct shape [n_systems, 3]
2626
pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool)
2727

28-
mapping, mapping_batch, shifts_idx = torch_nl_linked_cell(
28+
mapping, mapping_system, shifts_idx = torch_nl_linked_cell(
2929
cutoff, pos, cell, pbc_tensor, system_idx, self_interaction
3030
)
31-
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch)
31+
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_system)
3232
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
3333

3434
print(mapping.shape)
35-
print(mapping_batch.shape)
35+
print(mapping_system.shape)
3636
print(shifts_idx.shape)
3737
print(cell_shifts.shape)
3838
print(dds.shape)
3939

40-
mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2(
40+
mapping_n2, mapping_system_n2, shifts_idx_n2 = torch_nl_n2(
4141
cutoff, pos, cell, pbc_tensor, system_idx, self_interaction
4242
)
43-
cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2)
43+
cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_system_n2)
4444
dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2)
4545

4646
print(mapping_n2.shape)
47-
print(mapping_batch_n2.shape)
47+
print(mapping_system_n2.shape)
4848
print(shifts_idx_n2.shape)
4949
print(cell_shifts_n2.shape)
5050
print(dds_n2.shape)

examples/tutorials/low_level_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
"""
108108
`SimState` objects can be passed directly to the model and it will compute
109109
the properties of the systems in the batch. The properties will be returned
110-
either batchwise, like the energy, or atomwise, like the forces.
110+
either systemwise, like the energy, or atomwise, like the forces.
111111
112112
Note that the energy here refers to the potential energy of the system.
113113
"""
@@ -116,9 +116,9 @@
116116
model_outputs = model(state)
117117
print(f"Model outputs: {', '.join(list(model_outputs))}")
118118

119-
print(f"Energy is a batchwise property with shape: {model_outputs['energy'].shape}")
119+
print(f"Energy is a systemwise property with shape: {model_outputs['energy'].shape}")
120120
print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}")
121-
print(f"Stress is a batchwise property with shape: {model_outputs['stress'].shape}")
121+
print(f"Stress is a systemwise property with shape: {model_outputs['stress'].shape}")
122122

123123

124124
# %% [markdown]

examples/tutorials/state_tutorial.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@
7171

7272
# %% [markdown]
7373
"""
74-
SimState attributes fall into three categories: atomwise, batchwise, and global.
74+
SimState attributes fall into three categories: atomwise, systemwise, and global.
7575
7676
* Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`,
77-
`masses`, `atomic_numbers`, and `batch`. Names are plural.
78-
* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
77+
`masses`, `atomic_numbers`, and `system_idx`. Names are plural.
78+
* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
7979
the base SimState. Names are singular.
8080
* Global attributes have any other shape or type, just `pbc` here. Names are singular.
8181
@@ -112,7 +112,7 @@
112112
f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems"
113113
)
114114

115-
# we can see how the shapes of batchwise, atomwise, and global properties change
115+
# we can see how the shapes of atomwise, systemwise, and global properties change
116116
print(f"Positions shape: {multi_state.positions.shape}")
117117
print(f"Cell shape: {multi_state.cell.shape}")
118118
print(f"PBC: {multi_state.pbc}")
@@ -142,7 +142,7 @@
142142
143143
SimState supports many convenience operations for manipulating batched states. Slicing
144144
is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state
145-
containing only the first three batches. The other operations are available through the
145+
containing only the first three systems. The other operations are available through the
146146
`pop`, `split`, `clone`, and `to` methods.
147147
"""
148148

@@ -182,19 +182,19 @@
182182
# %% [markdown]
183183
"""
184184
185-
You can extract specific batches from a batched state using Python's slicing syntax.
185+
You can extract specific systems from a batched state using Python's slicing syntax.
186186
This is extremely useful for analyzing specific systems or for implementing complex
187187
workflows where different systems need separate processing:
188188
189189
The slicing interface follows Python's standard indexing conventions, making it
190190
intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the
191-
selected batches, maintaining all the necessary properties and relationships.
191+
selected systems, maintaining all the necessary properties and relationships.
192192
193193
Note the difference between these operations:
194-
- `split()` returns all batches as separate states but doesn't modify the original
195-
- `pop()` removes specified batches from the original state and returns them as
194+
- `split()` returns all systems as separate states but doesn't modify the original
195+
- `pop()` removes specified systems from the original state and returns them as
196196
separate states
197-
- `__getitem__` (slicing) creates a new state with specified batches without modifying
197+
- `__getitem__` (slicing) creates a new state with specified systems without modifying
198198
the original
199199
200200
This flexibility allows you to structure your simulation workflows in the most
@@ -203,7 +203,7 @@
203203
### Splitting and Popping Batches
204204
205205
SimState provides methods to split a batched state into separate states or to remove
206-
specific batches:
206+
specific systems:
207207
"""
208208

209209
# %% [markdown]

tests/test_autobatching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_binning_auto_batcher(
149149
# Get batches until None is returned
150150
batches = list(batcher)
151151

152-
# Check we got the expected number of batches
152+
# Check we got the expected number of systems
153153
assert len(batches) == len(batcher.batched_states)
154154

155155
# Test restore_original_order

tests/test_integrators.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,66 +20,66 @@ def test_calculate_momenta_basic(device: torch.device):
2020
seed = 42
2121
dtype = torch.float64
2222

23-
# Create test inputs for 3 batches with 2 atoms each
23+
# Create test inputs for 3 systems with 2 atoms each
2424
n_atoms = 8
2525
positions = torch.randn(n_atoms, 3, dtype=dtype, device=device)
2626
masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5
27-
batch = torch.tensor(
27+
system_idx = torch.tensor(
2828
[0, 0, 1, 1, 2, 2, 3, 3], device=device
29-
) # 3 batches with 2 atoms each
29+
) # 3 systems with 2 atoms each
3030
kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)
3131

3232
# Run the function
33-
momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)
33+
momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed)
3434

3535
# Basic checks
3636
assert momenta.shape == positions.shape
3737
assert momenta.dtype == dtype
3838
assert momenta.device == device
3939

40-
# Check that each batch has zero center of mass momentum
40+
# Check that each system has zero center of mass momentum
4141
for b in range(4):
42-
batch_mask = batch == b
43-
batch_momenta = momenta[batch_mask]
44-
com_momentum = torch.mean(batch_momenta, dim=0)
42+
system_mask = system_idx == b
43+
system_momenta = momenta[system_mask]
44+
com_momentum = torch.mean(system_momenta, dim=0)
4545
assert torch.allclose(
4646
com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10
4747
)
4848

4949

5050
def test_calculate_momenta_single_atoms(device: torch.device):
51-
"""Test that calculate_momenta preserves momentum for batches with single atoms."""
51+
"""Test that calculate_momenta preserves momentum for systems with single atoms."""
5252
seed = 42
5353
dtype = torch.float64
5454

55-
# Create test inputs with some batches having single atoms
55+
# Create test inputs with some systems having single atoms
5656
positions = torch.randn(5, 3, dtype=dtype, device=device)
5757
masses = torch.rand(5, dtype=dtype, device=device) + 0.5
58-
batch = torch.tensor(
58+
system_idx = torch.tensor(
5959
[0, 1, 1, 2, 3], device=device
60-
) # Batches 0, 2, and 3 have single atoms
60+
) # systems 0, 2, and 3 have single atoms
6161
kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)
6262

6363
# Generate momenta and save the raw values before COM correction
6464
generator = torch.Generator(device=device).manual_seed(seed)
6565
raw_momenta = torch.randn(
6666
positions.shape, device=device, dtype=dtype, generator=generator
67-
) * torch.sqrt(masses * kT[batch]).unsqueeze(-1)
67+
) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1)
6868

6969
# Run the function
70-
momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)
70+
momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed)
7171

72-
# Check that single-atom batches have unchanged momenta
73-
for b in [0, 2, 3]: # Single atom batches
74-
batch_mask = batch == b
72+
# Check that single-atom systems have unchanged momenta
73+
for b in [0, 2, 3]: # Single atom systems
74+
system_mask = system_idx == b
7575
# The momentum should be exactly the same as the raw value for single atoms
76-
assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask])
76+
assert torch.allclose(momenta[system_mask], raw_momenta[system_mask])
7777

78-
# Check that multi-atom batches have zero COM
79-
for b in [1]: # Multi-atom batches
80-
batch_mask = batch == b
81-
batch_momenta = momenta[batch_mask]
82-
com_momentum = torch.mean(batch_momenta, dim=0)
78+
# Check that multi-atom systems have zero COM
79+
for b in [1]: # Multi-atom systems
80+
system_mask = system_idx == b
81+
system_momenta = momenta[system_mask]
82+
com_momentum = torch.mean(system_momenta, dim=0)
8383
assert torch.allclose(
8484
com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10
8585
)
@@ -378,7 +378,7 @@ def test_compute_cell_force_atoms_per_system():
378378
Covers fix in https://github.com/Radical-AI/torch-sim/pull/153."""
379379
from torch_sim.integrators.npt import _compute_cell_force
380380

381-
# Setup minimal state with two batches having 8:1 atom ratio
381+
# Setup minimal state with two systems having 8:1 atom ratio
382382
s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long)
383383

384384
state = NPTLangevinState(

tests/test_neighbors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,13 @@ def test_torch_nl_implementations(
342342
)
343343

344344
# Get the neighbor list from the implementation being tested
345-
mapping, mapping_batch, shifts_idx = nl_implementation(
345+
mapping, mapping_system, shifts_idx = nl_implementation(
346346
cutoff, pos, row_vector_cell, pbc, batch, self_interaction
347347
)
348348

349349
# Calculate distances
350350
cell_shifts = transforms.compute_cell_shifts(
351-
row_vector_cell, shifts_idx, mapping_batch
351+
row_vector_cell, shifts_idx, mapping_system
352352
)
353353
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
354354
dds = np.sort(dds.numpy())
@@ -496,30 +496,30 @@ def test_strict_nl_edge_cases(
496496

497497
# Test with no cell shifts
498498
mapping = torch.tensor([[0], [1]], device=device, dtype=torch.long)
499-
batch_mapping = torch.tensor([0], device=device, dtype=torch.long)
499+
system_mapping = torch.tensor([0], device=device, dtype=torch.long)
500500
shifts_idx = torch.zeros((1, 3), device=device, dtype=torch.long)
501501

502502
new_mapping, new_batch, new_shifts = neighbors.strict_nl(
503503
cutoff=1.5,
504504
positions=pos,
505505
cell=cell,
506506
mapping=mapping,
507-
batch_mapping=batch_mapping,
507+
system_mapping=system_mapping,
508508
shifts_idx=shifts_idx,
509509
)
510510
assert len(new_mapping[0]) > 0 # Should find neighbors
511511

512512
# Test with different batch mappings
513513
mapping = torch.tensor([[0, 1], [1, 0]], device=device, dtype=torch.long)
514-
batch_mapping = torch.tensor([0, 1], device=device, dtype=torch.long)
514+
system_mapping = torch.tensor([0, 1], device=device, dtype=torch.long)
515515
shifts_idx = torch.zeros((2, 3), device=device, dtype=torch.long)
516516

517517
new_mapping, new_batch, new_shifts = neighbors.strict_nl(
518518
cutoff=1.5,
519519
positions=pos,
520520
cell=cell,
521521
mapping=mapping,
522-
batch_mapping=batch_mapping,
522+
system_mapping=system_mapping,
523523
shifts_idx=shifts_idx,
524524
)
525525
assert len(new_mapping[0]) > 0 # Should find neighbors
@@ -559,7 +559,7 @@ def test_neighbor_lists_time_and_memory(
559559
system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device)
560560
# Fix pbc tensor shape
561561
pbc = torch.tensor([[True, True, True]], device=device)
562-
mapping, mapping_batch, shifts_idx = nl_fn(
562+
mapping, mapping_system, shifts_idx = nl_fn(
563563
cutoff, pos, cell, pbc, system_idx, self_interaction=False
564564
)
565565
else:

tests/test_transforms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,9 +1183,9 @@ def test_compute_cell_shifts_basic() -> None:
11831183
"""Test compute_cell_shifts function."""
11841184
cell = torch.eye(3).unsqueeze(0) * 2.0
11851185
shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
1186-
batch_mapping = torch.tensor([0, 0])
1186+
system_mapping = torch.tensor([0, 0])
11871187

1188-
cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, batch_mapping)
1188+
cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, system_mapping)
11891189

11901190
expected = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
11911191
torch.testing.assert_close(cell_shifts, expected)
@@ -1272,16 +1272,16 @@ def test_build_linked_cell_neighborhood_basic() -> None:
12721272
cutoff = 1.5
12731273
n_atoms = torch.tensor([2, 2])
12741274

1275-
mapping, batch_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood(
1275+
mapping, system_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood(
12761276
positions, cell, pbc, cutoff, n_atoms, self_interaction=False
12771277
)
12781278

12791279
# Check that atoms in the same structure are neighbors
12801280
assert mapping.shape[1] >= 2 # At least 2 neighbor pairs
12811281

1282-
# Verify batch_mapping has correct length
1283-
assert batch_mapping.shape[0] == mapping.shape[1]
1282+
# Verify system_mapping has correct length
1283+
assert system_mapping.shape[0] == mapping.shape[1]
12841284

12851285
# Verify that there are neighbors from both batches
1286-
assert torch.any(batch_mapping == 0)
1287-
assert torch.any(batch_mapping == 1)
1286+
assert torch.any(system_mapping == 0)
1287+
assert torch.any(system_mapping == 1)

torch_sim/models/graphpes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra
6868
graphs = []
6969

7070
for i in range(state.n_systems):
71-
batch_mask = state.system_idx == i
72-
R = state.positions[batch_mask]
73-
Z = state.atomic_numbers[batch_mask]
71+
system_mask = state.system_idx == i
72+
R = state.positions[system_mask]
73+
Z = state.atomic_numbers[system_mask]
7474
cell = state.row_vector_cell[i]
7575
nl, shifts = vesin_nl_ts(
7676
R,

0 commit comments

Comments
 (0)