How to use the FIM-MJP Model
Input to the model
The model takes an input dictionary containing at least three items, plus one additional argument. The input dictionary should contain the following items:
- The observation grid, with size `[num_paths, grid_size]`: the times at which an observation was recorded. The key in the dictionary is `observation_grid` and the data type is `float`.
- The observation values, with size `[num_paths, grid_size]`: the observed values (states) of the process. The key in the dictionary is `observation_values` and the data type is `int`.
- The sequence lengths, with size `[num_paths]`: the length of each observed sequence. The key in the dictionary is `seq_lengths` and the data type is `int`.
- The dimension of the process, an `integer` between 2 and 6 (the model supports at most 6 states). This is the additional argument rather than a dictionary item; the argument name is `n_states`.
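Recorded paths often have different lengths, while the grid and the values share a common `grid_size`. The sketch below shows one way to pad such paths and derive `seq_lengths`, the key used by the evaluation code later on this page. It is a minimal illustration, not part of the FIM library: the helper name `pad_paths` and the zero-padding convention are assumptions, and the nested lists would still need to be converted to tensors (for example with `torch.tensor`) before being passed to the model.

```python
from typing import Dict, List, Sequence, Tuple

# One path is a sequence of (observation time, observed state) pairs.
Path = Sequence[Tuple[float, int]]


def pad_paths(paths: Sequence[Path], pad_time: float = 0.0, pad_state: int = 0) -> Dict[str, List]:
    """Pad ragged paths to a common grid size (hypothetical helper, zero padding assumed)."""
    if not paths:
        raise ValueError("need at least one path")
    grid_size = max(len(path) for path in paths)
    grid, values, lengths = [], [], []
    for path in paths:
        n_pad = grid_size - len(path)
        grid.append([float(t) for t, _ in path] + [pad_time] * n_pad)
        values.append([int(s) for _, s in path] + [pad_state] * n_pad)
        lengths.append(len(path))
    return {
        "observation_grid": grid,      # [num_paths, grid_size], float
        "observation_values": values,  # [num_paths, grid_size], int
        "seq_lengths": lengths,        # [num_paths], int
    }


paths = [
    [(0.0, 0), (0.4, 2), (1.1, 1)],
    [(0.0, 1), (0.7, 0)],
]
inputs = pad_paths(paths)
print(inputs["seq_lengths"])  # [3, 2]
```

Keeping the true lengths alongside the padded arrays lets downstream code tell padding apart from real observations.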
Optionally, the dictionary can contain the following items:
- The time normalization factors, with size `[num_paths]`: the factor by which the time is normalized. The key in the dictionary is `time_normalization_factors` and the data type is `float`. If this item is not provided, the model normalizes the time by the maximum time in the observation grid.
- Items for calculating the loss:
  - The intensity matrix, with size `[num_paths, n_states, n_states]`: the intensity matrix of the process. The key in the dictionary is `intensity_matrices` and the data type is `float`.
  - The initial distribution, with size `[num_paths, n_states]`: the initial distribution of the process. The key in the dictionary is `initial_distributions` and the data type is `int`.
  - The adjacency matrix, with size `[num_paths, n_states, n_states]`: the adjacency matrix of the process. The key in the dictionary is `adjacency_matrices` and the data type is `int`.
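The default time normalization mentioned above can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: both function names are hypothetical, the maximum is taken per path over the observed (unpadded) entries, and the rescaling rule only reflects that rates carry units of inverse time. Whether the model already returns intensities in the original time units is not stated here, so check before rescaling.

```python
from typing import List, Sequence, Tuple


def normalize_times(
    grid: Sequence[Sequence[float]], seq_lengths: Sequence[int]
) -> Tuple[List[List[float]], List[float]]:
    """Divide each path's times by its largest observed time (assumed convention)."""
    normalized, factors = [], []
    for times, length in zip(grid, seq_lengths):
        factor = max(times[:length])  # ignore padded entries
        if factor <= 0:
            raise ValueError("each path needs a positive maximum time")
        normalized.append([t / factor for t in times])
        factors.append(factor)
    return normalized, factors


def rescale_intensities(matrix: Sequence[Sequence[float]], factor: float) -> List[List[float]]:
    """Convert rates per normalized time unit into rates per original time unit."""
    return [[rate / factor for rate in row] for row in matrix]


grid = [[0.0, 2.0, 4.0], [0.0, 5.0, 0.0]]  # the second path is padded
normalized, factors = normalize_times(grid, seq_lengths=[3, 2])
print(factors)  # [4.0, 5.0]

q_normalized = [[-2.0, 2.0], [4.0, -4.0]]
print(rescale_intensities(q_normalized, factors[0]))  # [[-0.5, 0.5], [1.0, -1.0]]
```

Dividing times by a factor c multiplies every rate by c, which is why the rates are divided by the same factor on the way back.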
Output of the model
The model returns a dictionary containing the following items:
- The intensity matrix, with size `[num_paths, n_states, n_states]`: the intensity matrix of the process. The key in the dictionary is `intensity_matrices` and the data type is `float`.
- The initial distribution, with size `[num_paths, n_states]`: the initial distribution of the process. The key in the dictionary is `initial_distributions` and the data type is `int`.
- The adjacency matrix, with size `[num_paths, n_states, n_states]`: the adjacency matrix of the process. The key in the dictionary is `adjacency_matrices` and the data type is `int`.
- The losses of the model. The key in the dictionary is `losses`; the evaluation code below reads the RMSE from `output["losses"]["rmse_loss"]`. The data type is `float`.
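To make a returned intensity matrix easier to read, the following sketch converts a single matrix into mean holding times and jump probabilities, using the standard relations for a Markov jump process. The function name and the example matrix are made up for illustration, and the sketch assumes that the diagonal holds the negative total exit rate of each state; if the model output leaves the diagonal unset, it would have to be filled in first.

```python
from typing import List, Tuple


def holding_times_and_jump_probs(q: List[List[float]]) -> Tuple[List[float], List[List[float]]]:
    """Return mean holding times 1 / (-q[i][i]) and jump probabilities q[i][j] / (-q[i][i])."""
    n = len(q)
    exit_rates = [-q[i][i] for i in range(n)]
    # An absorbing state (zero exit rate) is never left.
    holding = [1.0 / rate if rate > 0 else float("inf") for rate in exit_rates]
    jumps = [
        [0.0 if i == j or exit_rates[i] <= 0 else q[i][j] / exit_rates[i] for j in range(n)]
        for i in range(n)
    ]
    return holding, jumps


# Hypothetical 3-state intensity matrix; every row sums to zero.
q = [
    [-2.0, 1.5, 0.5],
    [1.0, -1.0, 0.0],
    [0.25, 0.25, -0.5],
]
holding, jumps = holding_times_and_jump_probs(q)
print(holding)   # [0.5, 1.0, 2.0]
print(jumps[0])  # [0.0, 0.75, 0.25]
```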
Loading the data and our model
from datasets import load_dataset
import torch
from collections import defaultdict
from fim.trainers.utils import get_accel_type
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
device = get_accel_type()
Dataset
We also provide a synthetic dataset for evaluating the model.
# Loading the Discrete Flashing Ratchet (DFR) dataset from Huggingface
data = load_dataset("FIM4Science/mjp", download_mode="force_redownload", name="default")
data.set_format("torch")
Repo card metadata block was not found. Setting CardData to empty.
Pretrained model
# Loading the FIMMJP model from Huggingface
from fim.models.mjp import FIMMJP
fimmjp = FIMMJP.from_pretrained("FIM4Science/fim-mjp", trust_remote_code=True)
fimmjp = fimmjp.to(device)
fimmjp.eval()
FIMMJP(
  (gaussian_nll): GaussianNLLLoss()
  (init_cross_entropy): CrossEntropyLoss()
  (pos_encodings): SineTimeEncoding(
    (linear_embedding): Linear(in_features=1, out_features=1, bias=True)
    (periodic_embedding): Sequential(
      (0): Linear(in_features=1, out_features=249, bias=True)
      (1): SinActivation()
    )
  )
  (ts_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (path_attention): MultiHeadLearnableQueryAttention(
    (W_k): Linear(in_features=256, out_features=256, bias=False)
    (W_v): Linear(in_features=256, out_features=256, bias=False)
    (W_o): Linear(in_features=256, out_features=256, bias=False)
  )
  (intensity_matrix_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=257, out_features=128, bias=True)
      (activation_0): SELU()
      (dropout_0): Dropout(p=0.1, inplace=False)
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (dropout_1): Dropout(p=0.1, inplace=False)
      (output): Linear(in_features=128, out_features=60, bias=True)
    )
  )
  (initial_distribution_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=257, out_features=128, bias=True)
      (activation_0): SELU()
      (dropout_0): Dropout(p=0.1, inplace=False)
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (dropout_1): Dropout(p=0.1, inplace=False)
      (output): Linear(in_features=128, out_features=6, bias=True)
    )
  )
)
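# Take the first dataset entry and copy its tensors to the device
batch = {k: v.to(device)[0] for k, v in data["train"][:1].items()}
# Numbers of paths to evaluate on
n_paths_eval = [1, 30, 100, 300, 500, 1000, 5000]
total_n_paths = batch["observation_grid"].shape[1]
# Number of random path subsets drawn for each entry of n_paths_eval
statistics = total_n_paths // 300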
Evaluate the model
result = defaultdict(list)
with torch.no_grad():
    for n_paths in n_paths_eval:
        for _ in range(statistics):
            # Draw a random subset of paths without replacement
            paths_idx = torch.randperm(total_n_paths)[:n_paths]
            mini_batch = batch.copy()
            mini_batch["observation_grid"] = batch["observation_grid"][:, paths_idx]
            mini_batch["observation_values"] = batch["observation_values"][:, paths_idx]
            mini_batch["seq_lengths"] = batch["seq_lengths"][:, paths_idx]
            output = fimmjp(mini_batch, n_states=6)
            result[n_paths].append(output["losses"]["rmse_loss"].item())
# Aggregate the RMSE over the repeated draws
means = {n_paths: torch.tensor(losses).mean().item() for n_paths, losses in result.items()}
stds = {n_paths: torch.tensor(losses).std().item() for n_paths, losses in result.items()}
df_result = pd.DataFrame(
    {
        "# Paths during Evaluation": list(means.keys()),
        "RMSE": [f"{mean:.3f} ± {std:.3f}" for mean, std in zip(means.values(), stds.values())],
    }
)
df_result
| | # Paths during Evaluation | RMSE |
|---|---|---|
| 0 | 1 | 0.654 ± 0.061 |
| 1 | 30 | 0.320 ± 0.063 |
| 2 | 100 | 0.196 ± 0.039 |
| 3 | 300 | 0.169 ± 0.021 |
| 4 | 500 | 0.162 ± 0.016 |
| 5 | 1000 | 0.227 ± 0.009 |
| 6 | 5000 | 0.724 ± 0.001 |
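The aggregation into "mean ± std" strings can also be reproduced without PyTorch or pandas. The sketch below uses only the standard library and made-up loss values (not results from the model); the function name `summarize` is hypothetical. `stdev` is the sample standard deviation, which matches the default behaviour of `torch.Tensor.std`.

```python
from statistics import mean, stdev
from typing import Dict, List


def summarize(result: Dict[int, List[float]]) -> Dict[int, str]:
    """Format each list of losses as 'mean ± std' with three decimals."""
    summary = {}
    for n_paths, losses in result.items():
        # The sample standard deviation is undefined for a single value.
        spread = stdev(losses) if len(losses) > 1 else float("nan")
        summary[n_paths] = f"{mean(losses):.3f} ± {spread:.3f}"
    return summary


# Made-up losses keyed by the number of paths used in each evaluation.
example = {1: [0.60, 0.70, 0.65], 30: [0.30, 0.34]}
summary = summarize(example)
print(summary[1])   # 0.650 ± 0.050
print(summary[30])  # 0.320 ± 0.028
```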