Skip to content

Commit d0c91cd

Browse files
committed
Fairchem v2 patch (#238)
1 parent b04ca55 commit d0c91cd

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch_sim/models/fairchem.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
4747
from torch_sim.typing import StateDict
4848

4949

50-
class FairChemModel(torch.nn.Module, ModelInterface):
50+
class FairChemModel(ModelInterface):
5151
"""FairChem model wrapper for computing atomistic properties.
5252
5353
Wraps FairChem models to compute energies, forces, and stresses. Can be
@@ -171,13 +171,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
171171
if state.device != self._device:
172172
state = state.to(self._device)
173173

174-
if state.batch is None:
175-
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)
174+
if state.system_idx is None:
175+
state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int)
176176

177177
# Convert SimState to AtomicData objects for efficient batch processing
178178
from ase import Atoms
179179

180-
natoms = torch.bincount(state.batch)
180+
natoms = torch.bincount(state.system_idx)
181181
atomic_data_list = []
182182

183183
for i, (n, c) in enumerate(

0 commit comments

Comments
 (0)