
Encoding Atomic Environments with PyTorch

PyTorch, Package Development, CI/CD, Data Visualization, Python

TorchMole: Atomic Environment Encoding

Overview

TorchMole is a PyTorch-based package I developed for encoding atomic environments in molecular and materials science applications. It provides efficient graph neural network implementations for learning representations of atomic structures.

Key Features

Architecture

Python: TorchMole Model Architecture
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_scatter import scatter_add  # sums per-atom rows by structure index

# GaussianBasis and AtomicEncoder.compute_edges are not shown in this excerpt.


class AtomicEncoder(nn.Module):
    """
    Graph neural network for encoding atomic environments.
    Uses message passing to aggregate neighbor information.
    """

    def __init__(self, hidden_dim=128, num_layers=3, cutoff=5.0):
        super().__init__()
        self.cutoff = cutoff

        # Embedding layers
        self.atom_embedding = nn.Embedding(100, hidden_dim)  # Indexed by atomic number
        self.distance_expansion = GaussianBasis(0, cutoff, num_gaussians=50)

        # Message passing layers
        self.interactions = nn.ModuleList([
            InteractionBlock(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])

        # Output network: per-atom scalar contribution
        self.output_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, atomic_numbers, positions, batch):
        # Compute neighbor list and per-edge features
        edge_index, edge_attr = self.compute_edges(positions, batch)

        # Initial embeddings
        x = self.atom_embedding(atomic_numbers)

        # Message passing
        for interaction in self.interactions:
            x = interaction(x, edge_index, edge_attr)

        # Aggregate per-atom outputs into one prediction per structure
        out = self.output_net(x)
        return scatter_add(out, batch, dim=0)


class InteractionBlock(MessagePassing):
    """Continuous-filter convolution for atomic interactions."""

    def __init__(self, in_dim, out_dim):
        super().__init__(aggr='add')
        # The extra 50 input features match num_gaussians in the distance expansion
        self.mlp = nn.Sequential(
            nn.Linear(in_dim + 50, out_dim),
            nn.SiLU(),
            nn.Linear(out_dim, out_dim)
        )
        self.update_net = nn.Sequential(
            nn.Linear(out_dim + in_dim, out_dim),
            nn.SiLU()
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # Combine each neighbor's features with the features of the connecting edge
        return self.mlp(torch.cat([x_j, edge_attr], dim=-1))

    def update(self, aggr_out, x):
        # Merge the summed messages back into each atom's own features
        return self.update_net(torch.cat([x, aggr_out], dim=-1))
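The listing above relies on two helpers it does not define, GaussianBasis and compute_edges. The sketch below shows one way they could look; it is my illustration under stated assumptions, not the TorchMole source. The Gaussian width rule is an assumption, and radius_edges is a hypothetical free-function stand-in for compute_edges that uses a dense O(N^2) distance matrix with no periodic boundaries, so it only suits small systems.

```python
import torch
import torch.nn as nn


class GaussianBasis(nn.Module):
    """Expand scalar distances onto evenly spaced Gaussian radial basis functions."""

    def __init__(self, start, stop, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        # Assumed width rule: tie the width to the center spacing so neighbors overlap
        self.gamma = 0.5 / (centers[1] - centers[0]).item() ** 2
        self.register_buffer("centers", centers)

    def forward(self, distances):
        # (num_edges,) -> (num_edges, num_gaussians)
        diff = distances.unsqueeze(-1) - self.centers
        return torch.exp(-self.gamma * diff ** 2)


def radius_edges(positions, batch, cutoff):
    """Hypothetical stand-in for compute_edges: pairs within `cutoff` in the same structure."""
    dist = torch.cdist(positions, positions)  # (N, N) pairwise distances
    same_structure = batch.unsqueeze(0) == batch.unsqueeze(1)
    not_self = ~torch.eye(len(positions), dtype=torch.bool, device=positions.device)
    mask = (dist < cutoff) & same_structure & not_self
    edge_index = mask.nonzero().t()  # (2, num_edges), same ordering as dist[mask]
    return edge_index, dist[mask]


# Two structures of two atoms each; only atoms 0 and 1 are within the cutoff
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                          [0.0, 0.0, 0.0], [9.0, 0.0, 0.0]])
batch = torch.tensor([0, 0, 1, 1])
edge_index, distances = radius_edges(positions, batch, cutoff=5.0)
edge_attr = GaussianBasis(0.0, 5.0, num_gaussians=50)(distances)
```

Atoms 0 and 2 share coordinates but belong to different structures, so the batch mask keeps them unconnected. The resulting edge_attr has 50 columns, which is the width InteractionBlock expects when it concatenates edge features onto x_j.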

Package Development

The package follows modern Python best practices:

YAML: GitHub Actions CI/CD
name: CI

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML does not parse 3.10 as the float 3.1
        python-version: ["3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          pip install -e ".[dev]"

      - name: Run tests
        run: pytest tests/ --cov=torchmole

      - name: Upload coverage
        uses: codecov/codecov-action@v3
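To give a sense of what the pytest step might exercise, here is a hypothetical test of the kind that could live under tests/. It is not taken from the TorchMole test suite. It checks that summing per-atom outputs into per-structure predictions does not depend on atom ordering, and it uses plain torch.Tensor.index_add_ as a stand-in for scatter_add so that it needs nothing beyond PyTorch; the name sum_pool is my own.

```python
import torch


def sum_pool(per_atom, batch, num_structures):
    """Sum per-atom rows into one row per structure (plain-PyTorch stand-in for scatter_add)."""
    out = torch.zeros(num_structures, per_atom.shape[1], dtype=per_atom.dtype)
    return out.index_add_(0, batch, per_atom)


def test_sum_pool_is_permutation_invariant():
    torch.manual_seed(0)
    per_atom = torch.randn(5, 1)            # one scalar output per atom
    batch = torch.tensor([0, 0, 0, 1, 1])   # structure index of each atom
    perm = torch.randperm(5)

    pooled = sum_pool(per_atom, batch, num_structures=2)
    pooled_shuffled = sum_pool(per_atom[perm], batch[perm], num_structures=2)

    assert pooled.shape == (2, 1)
    assert torch.allclose(pooled, pooled_shuffled, atol=1e-6)
```

Shuffling the atoms together with their batch indices should leave each structure's sum unchanged, which is the property that makes sum pooling a sensible readout for unordered sets of atoms.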

Results

TorchMole achieves competitive results on standard molecular property prediction benchmarks while providing a clean, extensible API for research applications.