Solving a real problem using our model: The Discrete Flashing Ratchet (DFR)#
The following is adapted from [BCS+24] and is a detailed, step-by-step example of solving a physics problem with our trained model from Hugging Face. We proceed as follows:
Understanding the problem
Loading and understanding the data
Using our model to infer the transition rates
Using the inferred transition rates to solve the original problem
In statistical physics, the ratchet effect refers to the rectification of thermal fluctuations into directed motion to produce work, and goes all the way back to Feynman [FLSH65].
Here we consider a simple example thereof, in which a Brownian particle, immersed in a thermal bath at unit temperature, moves on a one-dimensional lattice. The particle is subject to a linear, periodic and asymmetric potential of maximum height \(2V\) that is switched on and off at a constant rate \(r\). When the potential is switched on, it takes three possible values, which correspond to three of the states of the system, and the particle jumps among them with rates \(f_{ij}^{\text{on}}\).
When the potential is switched off, the particle jumps freely with rates \(f_{ij}^{\text{off}}\).
We can therefore think of the system as a six-state system, as illustrated here:
Similar to [RoldanP10], we now define the transition rates as

\[
f_{ij}^{\text{on}} = e^{-V(j-i)/2}, \qquad f_{ij}^{\text{off}} = B, \qquad i \neq j,
\]

where \(B\) is the free jump rate when the potential is off, and switching between the corresponding on and off states occurs at rate \(r\).
Given these specifics, we consider the parameter set \((V, r, B) = (1, 1, 1)\) together with the dataset simulated by [SSanchez23]. From these data we want to recover the theoretical transition rates in the form of a \(q\)-matrix.[1]
import numpy as np

np.set_printoptions(linewidth=100)

# Rates with the potential on: f_ij^on = exp(-V*(j - i)/2) with V = 1;
# the term "-1*(i==j)" just zeroes the diagonal for now.
rates_on = np.array([[np.exp(-0.5 * (j - i)) - 1 * (i == j) for j in range(3)] for i in range(3)])
# Rates with the potential off: f_ij^off = B = 1 for i != j.
rates_off = np.ones(shape=(3, 3))
np.fill_diagonal(rates_off, 0)
# Switching the potential couples state i (on) with state i (off) at rate r = 1.
id3 = np.eye(3)

q_matrix = np.zeros(shape=(6, 6))
q_matrix[:3, :3] = rates_on
q_matrix[3:, 3:] = rates_off
q_matrix[3:, :3] = id3
q_matrix[:3, 3:] = id3
# The diagonal is fixed by requiring every row to sum to zero.
diagonal = -np.sum(q_matrix, axis=1)
np.fill_diagonal(q_matrix, diagonal)
q_matrix
array([[-1.9744101 , 0.60653066, 0.36787944, 1. , 0. , 0. ],
[ 1.64872127, -3.25525193, 0.60653066, 0. , 1. , 0. ],
[ 2.71828183, 1.64872127, -5.3670031 , 0. , 0. , 1. ],
[ 1. , 0. , 0. , -3. , 1. , 1. ],
[ 0. , 1. , 0. , 1. , -3. , 1. ],
[ 0. , 0. , 1. , 1. , 1. , -3. ]])
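As a quick sanity check (a small sketch, not part of the original walkthrough), we can verify that this is a proper transition rate matrix: every row sums to zero and all off-diagonal entries are non-negative.
# A valid transition rate matrix has non-negative off-diagonal entries
# and rows that sum to zero.
assert np.allclose(q_matrix.sum(axis=1), 0.0)
assert np.all(q_matrix - np.diag(np.diag(q_matrix)) >= 0)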
Loading and exploring the data#
We start by loading the simulated data from Hugging Face:
# Loading the Discrete Flashing Ratchet (DFR) dataset from Hugging Face
from datasets import load_dataset
import torch
data = load_dataset("FIM4Science/mjp", download_mode="force_redownload", trust_remote_code=True, name="DFR_V=1")
data.set_format("torch")
Repo card metadata block was not found. Setting CardData to empty.
The model will later use the ‘observation_grid’, ‘observation_values’, ‘seq_lengths’ features to estimate the transition rates.
Here the observation grid is constant over all paths, consisting of 100 evenly spaced points on \([0,1]\). The sequence length is therefore 100 for every path.
data["train"]["observation_grid"][0,0,:,0]
tensor([0.0000, 0.0101, 0.0202, 0.0303, 0.0404, 0.0505, 0.0606, 0.0707, 0.0808,
0.0909, 0.1010, 0.1111, 0.1212, 0.1313, 0.1414, 0.1515, 0.1616, 0.1717,
0.1818, 0.1919, 0.2020, 0.2121, 0.2222, 0.2323, 0.2424, 0.2525, 0.2626,
0.2727, 0.2828, 0.2929, 0.3030, 0.3131, 0.3232, 0.3333, 0.3434, 0.3535,
0.3636, 0.3737, 0.3838, 0.3939, 0.4040, 0.4141, 0.4242, 0.4343, 0.4444,
0.4545, 0.4646, 0.4747, 0.4848, 0.4949, 0.5051, 0.5152, 0.5253, 0.5354,
0.5455, 0.5556, 0.5657, 0.5758, 0.5859, 0.5960, 0.6061, 0.6162, 0.6263,
0.6364, 0.6465, 0.6566, 0.6667, 0.6768, 0.6869, 0.6970, 0.7071, 0.7172,
0.7273, 0.7374, 0.7475, 0.7576, 0.7677, 0.7778, 0.7879, 0.7980, 0.8081,
0.8182, 0.8283, 0.8384, 0.8485, 0.8586, 0.8687, 0.8788, 0.8889, 0.8990,
0.9091, 0.9192, 0.9293, 0.9394, 0.9495, 0.9596, 0.9697, 0.9798, 0.9899,
1.0000])
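As a further quick check (a minimal sketch, assuming the seq_lengths feature is laid out per path as described above), we can confirm that every path indeed has 100 observations:
seq_lengths = data["train"]["seq_lengths"][0]
print(torch.unique(seq_lengths))  # expected: a single value, 100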
The observation values contain the state of the processes for all paths and all time points.
Warning
In practice, such state labels are rarely observed directly; in most cases they have to be computed as part of the preprocessing.
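For illustration only, here is a hypothetical preprocessing sketch: the arrays positions and potential_on are made-up raw measurements, and the mapping to six states simply mirrors the (position, on/off) encoding used for the \(q\)-matrix above.
# Hypothetical raw measurements: lattice positions in {0, 1, 2} and an on/off indicator.
positions = np.array([0, 1, 2, 2, 1])      # made-up particle positions
potential_on = np.array([1, 1, 0, 0, 1])   # made-up potential state (1 = on, 0 = off)

# Encode (position, on/off) as one of six states: 0-2 = potential on, 3-5 = potential off.
state_labels = positions + 3 * (1 - potential_on)
print(state_labels)  # [0 1 5 5 1]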
We can visualize the first 3 paths to get a feeling for the processes:
import matplotlib.pyplot as plt

# Plot the first three sample paths over the observation grid.
ts = data["train"]["observation_grid"][0, 0, :, 0]
three_paths = data["train"]["observation_values"][0, :3, :, 0]
for i in range(3):
    plt.plot(ts, three_paths[i], "x--")
plt.show()

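Beyond looking at individual trajectories, a simple way to get a feel for the data is the empirical distribution over the six states (a small sketch, not part of the original walkthrough):
# Empirical occupation frequencies of the six states over all paths and time points.
states = data["train"]["observation_values"][0, :, :, 0].long().flatten()
counts = torch.bincount(states, minlength=6)
print(counts / counts.sum())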
Inferring transition rates#
To infer the \(q\)-matrix we first load our trained model:
from transformers import AutoModel
from fim.trainers.utils import get_accel_type
device = get_accel_type()
fimmjp = AutoModel.from_pretrained("FIM4Science/fim-mjp", trust_remote_code=True)
fimmjp = fimmjp.to(device)
fimmjp.eval()
FIMMJP(
(gaussian_nll): GaussianNLLLoss()
(init_cross_entropy): CrossEntropyLoss()
(pos_encodings): DeltaTimeEncoding()
(ts_encoder): RNNEncoder(
(rnn): LSTM(8, 256, batch_first=True, bidirectional=True)
)
(path_attention): MultiHeadLearnableQueryAttention(
(W_k): Linear(in_features=512, out_features=128, bias=False)
(W_v): Linear(in_features=512, out_features=128, bias=False)
)
(intensity_matrix_decoder): MLP(
(layers): Sequential(
(linear_0): Linear(in_features=2049, out_features=128, bias=True)
(activation_0): SELU()
(linear_1): Linear(in_features=128, out_features=128, bias=True)
(activation_1): SELU()
(output): Linear(in_features=128, out_features=60, bias=True)
)
)
(initial_distribution_decoder): MLP(
(layers): Sequential(
(linear_0): Linear(in_features=2049, out_features=128, bias=True)
(activation_0): SELU()
(linear_1): Linear(in_features=128, out_features=128, bias=True)
(activation_1): SELU()
(output): Linear(in_features=128, out_features=6, bias=True)
)
)
)
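For orientation (a small aside, not part of the original walkthrough), we can count the model's trainable parameters:
n_params = sum(p.numel() for p in fimmjp.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")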
As noted in [BCS+24][2], it suffices to look at a small context window of 300 paths with 50 observation values each. We therefore infer the transition rates batchwise, averaging over repeated random subsets of 300 paths, to demonstrate how little data is needed.
# Copy the data to the device, dropping the ground-truth fields we seek to infer.
batch = {
    k: v.to(device)
    for k, v in data["train"][:1].items()
    if k not in ["intensity_matrices", "adjacency_matrices", "initial_distributions"]
}

n_paths = 300       # paths per context window
total_n_paths = batch["observation_grid"].shape[1]
statistics = 50     # number of random batches to average over

outputs = []
with torch.no_grad():
    for _ in range(statistics):
        paths_idx = torch.randperm(total_n_paths)[:n_paths]
        mini_batch = batch.copy()
        mini_batch["observation_grid"] = batch["observation_grid"][:, paths_idx]
        mini_batch["observation_values"] = batch["observation_values"][:, paths_idx]
        mini_batch["seq_lengths"] = batch["seq_lengths"][:, paths_idx]
        # We are currently only interested in the intensity matrices,
        # not in the predicted variances or initial conditions.
        outputs.append(fimmjp(mini_batch, n_states=6)["intensity_matrices"])

mean_rates = torch.mean(torch.stack(outputs, dim=0), dim=0)[0].cpu().numpy()
This results in the following inferred transition rates:
mean_rates = torch.mean(torch.stack(outputs, dim=0), dim=0)[0].cpu().numpy()
print(mean_rates)
plt.matshow(mean_rates)
plt.colorbar()
plt.show()
[[-1.9315803 0.55702436 0.35664162 0.9428668 0.03558088 0.03946645]
[ 1.6503025 -3.3356545 0.6199239 0.06110187 0.94765836 0.0566683 ]
[ 2.5707161 1.601183 -5.3578224 0.17986627 0.15248567 0.85357094]
[ 0.90139735 0.05238242 0.06600926 -3.0131218 1.031255 0.9620779 ]
[ 0.04738419 0.9359199 0.06654382 0.9602639 -2.99806 0.9879483 ]
[ 0.07564461 0.05701077 0.9734346 0.9929121 0.9550975 -3.0540998 ]]

and the following variances across the 50 batchwise estimates:
variances = torch.var(torch.stack(outputs, dim=0), dim=0)[0].cpu().numpy()
print(variances)
plt.matshow(variances)
plt.colorbar()
plt.show()
[[8.7043829e-03 3.9266837e-03 2.4207723e-03 5.3723077e-03 2.1619326e-05 3.9558490e-05]
[2.3196101e-02 4.5218397e-02 8.4911650e-03 7.3515883e-05 9.9087860e-03 5.6897985e-05]
[4.9169671e-02 4.8791450e-02 1.5976055e-01 6.1710132e-04 4.6605940e-04 2.3276092e-02]
[5.5284980e-03 4.0480354e-05 6.8849971e-05 2.0662460e-02 8.7698111e-03 5.0109020e-03]
[3.3746906e-05 5.5495598e-03 6.3126106e-05 7.6852222e-03 2.7397871e-02 1.1959122e-02]
[1.1993035e-04 6.5733715e-05 1.3381124e-02 1.3250175e-02 1.1286872e-02 4.0514071e-02]]

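To judge how stable the batchwise estimates are, we can relate the spread across the 50 random batches to the mean rates, for example via the relative standard deviation of the off-diagonal entries (a small sketch, not part of the original walkthrough):
# Relative spread (standard deviation / mean) of the off-diagonal rate estimates.
std_rates = np.sqrt(variances)
off_diag = ~np.eye(6, dtype=bool)
rel_spread = std_rates[off_diag] / np.abs(mean_rates[off_diag])
print(rel_spread.max())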
Using the previously computed theoretical transition rates, we can look at the pointwise error:
error=mean_rates-q_matrix
print(error)
plt.matshow(error)
plt.colorbar()
plt.show()
[[ 0.0428298 -0.0495063 -0.01123782 -0.0571332 0.03558088 0.03946645]
[ 0.00158126 -0.08040257 0.01339323 0.06110187 -0.05234164 0.0566683 ]
[-0.14756569 -0.04753821 0.00918068 0.17986627 0.15248567 -0.14642906]
[-0.09860265 0.05238242 0.06600926 -0.01312184 0.03125501 -0.03792208]
[ 0.04738419 -0.06408012 0.06654382 -0.03973609 0.00194001 -0.0120517 ]
[ 0.07564461 0.05701077 -0.02656537 -0.00708789 -0.0449025 -0.0540998 ]]

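To condense the pointwise error into a single number, one can look at, for example, the largest absolute deviation and the relative Frobenius error (a small sketch, not part of the original walkthrough):
print("max abs error:", np.abs(error).max())
print("relative Frobenius error:", np.linalg.norm(error) / np.linalg.norm(q_matrix))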
Inferring the parameters#
We will now use the inferred transition rates to recover the parameters used to generate the dataset. Recall the formula we previously used to calculate the theoretical transition rates:

\[
f_{ij}^{\text{on}} = e^{-V(j-i)/2}, \qquad f_{ij}^{\text{off}} = B, \qquad i \neq j.
\]

The \(q\)-matrix (transition rate matrix) is therefore given by

\[
q = \begin{pmatrix} f^{\text{on}} & r\, I_3 \\ r\, I_3 & f^{\text{off}} \end{pmatrix}, \qquad q_{ii} = -\sum_{j \neq i} q_{ij},
\]

since the transition rates between a state \((i, \text{on})\) and the corresponding state \((i, \text{off})\) (and vice versa) are given by the parameter \(r\).
This system is evidently overdetermined. For simplicity we assume normally distributed, independent errors, which results in the following MLEs for each of the parameters (these are exactly the quantities computed in the code below, with the \(q\)-matrix indexed from 1):

\[
\hat{B} = \frac{1}{6}\sum_{4 \le i \ne j \le 6} \hat{q}_{ij}, \qquad
\hat{r} = \frac{1}{6}\sum_{i=1}^{3}\left(\hat{q}_{i,\,i+3} + \hat{q}_{i+3,\,i}\right), \qquad
\hat{V} = \frac{1}{6}\left(2\ln\frac{\hat{q}_{21}\,\hat{q}_{32}}{\hat{q}_{12}\,\hat{q}_{23}} + \ln\frac{\hat{q}_{31}}{\hat{q}_{13}}\right).
\]
We can easily recover \(B\) from the off-diagonal entries of the lower-right ("off") block:
# Take a copy of the "off" block so that zeroing its diagonal does not modify mean_rates.
lower_matrix = mean_rates[3:, 3:].copy()
np.fill_diagonal(lower_matrix, 0)
B_hat = lower_matrix.sum() / 6
B_hat
np.float32(0.9815925)
Similarly, \(r\) describes the transition rates from a state \((i, \text{on})\) to the corresponding state \((i, \text{off})\) and vice versa:
r_hat=sum([mean_rates[i,i+3]+mean_rates[i+3,i] for i in range(3)])/6
r_hat
np.float32(0.92580795)
And \(V\) can be recovered from the upper-left ("on") block \((q_{ij})_{i,j=1}^{3}\):
V_hat = (-2 * (np.log(mean_rates[0, 1]) + np.log(mean_rates[1, 2]))
         + 2 * (np.log(mean_rates[1, 0]) + np.log(mean_rates[2, 1]))
         - np.log(mean_rates[0, 2]) + np.log(mean_rates[2, 0])) / 6
V_hat
np.float32(1.0075369)
This results in the following inferred parameters:
(V_hat,r_hat,B_hat)
(np.float32(1.0075369), np.float32(0.92580795), np.float32(0.9815925))
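Finally, we can compare these estimates with the true parameters \((V, r, B) = (1, 1, 1)\) used to generate the data (a small closing check, not part of the original walkthrough):
true_params = np.array([1.0, 1.0, 1.0])
estimates = np.array([V_hat, r_hat, B_hat])
print("absolute errors:", np.abs(estimates - true_params))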
Bibliography#
David Berghaus, Kostadin Cvejoski, Patrick Seifner, Cesar Ojeda, and Ramses J Sanchez. Foundation inference models for Markov jump processes. In The Thirty-eighth Annual Conference on Neural Information Processing Systems, 2024. URL: https://openreview.net/forum?id=f4v7cmm5sC.
Richard P Feynman, Robert B Leighton, Matthew Sands, and Everett M Hafner. The Feynman Lectures on Physics, Vol. I. American Journal of Physics, 33(9):750–752, 1965.
É. Roldán and J. M. R. Parrondo. Estimating dissipation from single stationary trajectories. Physical Review Letters, 105(15):150607, 2010.
Patrick Seifner and Ramsés J Sánchez. Neural Markov jump processes. In International Conference on Machine Learning, 30523–30552. PMLR, 2023.