How to use the FIM-MJP Model#

Input to the model#

The model takes as input a dictionary containing at least three items, plus one additional argument; a minimal example of assembling such a dictionary is sketched after the lists below. The input dictionary should contain the following items:

  1. The observation grid with size [num_paths, grid_size], which holds the points in time at which an observation was recorded. The key in the dictionary is observation_grid and the data type is float.

  2. The observation values with size [num_paths, grid_size], which are the observed values (states) of the process. The key in the dictionary is observation_values and the data type is int.

  3. The sequence lengths with size [num_paths], which are the lengths of the observed sequences. The key in the dictionary is seq_lengths and the data type is int.

  4. The dimension of the process, which is an integer between 2 and 6; the maximum number of states supported by our model is 6. This is not part of the dictionary but is passed as the argument n_states.

Optionally, the dictionary can contain the following items:

  • The time normalization factors with size [num_paths], which are the factors by which time is normalized. The key in the dictionary is time_normalization_factors and the data type is float. If this item is not provided, the model normalizes time by the maximum time in the observation grid.

  • Items for calculating the loss:

    • The intensity matrix with size [num_paths, n_states, n_states], which is the ground-truth intensity matrix of the process. The key in the dictionary is intensity_matrices and the data type is float.

    • The initial distribution with size [num_paths, n_states], which is the ground-truth initial distribution of the process. The key in the dictionary is initial_distributions and the data type is int.

    • The adjacency matrix with size [num_paths, n_states, n_states], which is the ground-truth adjacency matrix of the process. The key in the dictionary is adjacency_matrices and the data type is int.
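
As a reference, here is a minimal sketch of such an input dictionary filled with random data. The shapes follow the lists above; note that the evaluation code further below additionally carries a leading batch dimension, and depending on the model version the tensors may need one as well. All names other than the dictionary keys are illustrative.

import torch

num_paths, grid_size, n_states = 100, 50, 6
example_batch = {
    # Sorted observation times for each path (float)
    "observation_grid": torch.rand(num_paths, grid_size).sort(dim=-1).values,
    # Observed state at each observation time (int)
    "observation_values": torch.randint(0, n_states, (num_paths, grid_size)),
    # Length of each observed sequence (int)
    "seq_lengths": torch.full((num_paths,), grid_size),
    # Optional: explicit time normalization factors (float)
    # "time_normalization_factors": torch.ones(num_paths),
}
# The dictionary is then passed to the model together with n_states:
# output = fimmjp(example_batch, n_states=n_states)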

Output of the model#

The model returns a dictionary containing the following items:

  • The inferred intensity matrix with size [num_paths, n_states, n_states]. The key in the dictionary is intensity_matrices and the data type is float.

  • The inferred initial distribution with size [num_paths, n_states]. The key in the dictionary is initial_distributions and the data type is float.

  • The adjacency matrix with size [num_paths, n_states, n_states]. The key in the dictionary is adjacency_matrices and the data type is int.

  • The losses of the model, returned as a dictionary of individual loss values (for example rmse_loss, which is used in the evaluation below). The key in the dictionary is losses and the values are float.
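
Assuming a model fimmjp and an input batch prepared as in the sections below, the returned dictionary can be inspected along these lines (a sketch; the exact set of keys may differ between model versions):

with torch.no_grad():
    output = fimmjp(batch, n_states=6)
# Inferred quantities
print(output["intensity_matrices"].shape)
print(output["initial_distributions"].shape)
# Dictionary of losses; only returned when the ground-truth items are provided
print(output["losses"].keys())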

Loading the data and our model#

from datasets import load_dataset
from transformers import AutoModel
import torch
from collections import defaultdict
from fim.trainers.utils import get_accel_type
import pandas as pd

# Select the compute device
device = get_accel_type()

Dataset#

We also provide a synthetic dataset for evaluating the model.

# Loading the Discrete Flashing Ratchet (DFR) dataset from Huggingface
data = load_dataset("FIM4Science/mjp", download_mode="force_redownload", trust_remote_code=True, name="DFR_V=1")
data.set_format("torch")
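
To get an overview of the dataset, one can inspect the fields of a single training example (a sketch; the exact field names and shapes depend on the dataset revision):

# Print each field of the first training example together with its shape
for key, value in data["train"][:1].items():
    print(key, tuple(value.shape))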

Pretrained model#

# Loading the FIMMJP model from Huggingface
fimmjp = AutoModel.from_pretrained("FIM4Science/fim-mjp", trust_remote_code=True)

fimmjp = fimmjp.to(device)
fimmjp.eval()
FIMMJP(
  (gaussian_nll): GaussianNLLLoss()
  (init_cross_entropy): CrossEntropyLoss()
  (pos_encodings): DeltaTimeEncoding()
  (ts_encoder): RNNEncoder(
    (rnn): LSTM(8, 256, batch_first=True, bidirectional=True)
  )
  (path_attention): MultiHeadLearnableQueryAttention(
    (W_k): Linear(in_features=512, out_features=128, bias=False)
    (W_v): Linear(in_features=512, out_features=128, bias=False)
  )
  (intensity_matrix_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=2049, out_features=128, bias=True)
      (activation_0): SELU()
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (output): Linear(in_features=128, out_features=60, bias=True)
    )
  )
  (initial_distribution_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=2049, out_features=128, bias=True)
      (activation_0): SELU()
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (output): Linear(in_features=128, out_features=6, bias=True)
    )
  )
)
# Copy one training example (a full set of paths) to the device
batch = {k: v.to(device) for k, v in data["train"][:1].items()}
# Prepare the evaluation: path counts to test
n_paths_eval = [1, 30, 100, 300, 500, 1000, 5000]
total_n_paths = batch["observation_grid"].shape[1]
# Number of random path subsets drawn per evaluated path count
statistics = total_n_paths // 300

Evaluate the model#
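
For each path count in n_paths_eval, we draw statistics random subsets of paths from the full batch, run the model on each subset, and report the RMSE of the inferred intensity matrix as mean ± standard deviation over these repetitions.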

result = defaultdict(list)
with torch.no_grad():
    for n_paths in n_paths_eval:
        for _ in range(statistics):
            # Randomly sample a subset of n_paths paths from the full batch
            paths_idx = torch.randperm(total_n_paths)[:n_paths]
            mini_batch = batch.copy()
            mini_batch["observation_grid"] = batch["observation_grid"][:, paths_idx]
            mini_batch["observation_values"] = batch["observation_values"][:, paths_idx]
            mini_batch["seq_lengths"] = batch["seq_lengths"][:, paths_idx]
            # Infer the intensity matrix and record the RMSE against the ground truth
            output = fimmjp(mini_batch, n_states=6)
            result[n_paths].append(output["losses"]["rmse_loss"].item())
means = {n_paths: torch.tensor(losses).mean().item() for n_paths, losses in result.items()}
stds = {n_paths: torch.tensor(losses).std().item() for n_paths, losses in result.items()}

df_result = pd.DataFrame(
    {
        "# Paths during Evaluation": list(means.keys()),
        "RMSE": [f"{mean:.3f} ± {std:.3f}" for mean, std in zip(means.values(), stds.values())],
    }
)

df_result
  # Paths during Evaluation           RMSE
0                         1  0.676 ± 0.054
1                        30  0.299 ± 0.049
2                       100  0.173 ± 0.029
3                       300  0.119 ± 0.017
4                       500  0.114 ± 0.014
5                      1000  0.229 ± 0.019
6                      5000  0.882 ± 0.001
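
The RMSE decreases as more observed paths are provided, reaching its minimum at around 300–500 paths, and increases again for substantially larger path counts.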