How to use the FIM-MJP Model#

Input to the model#

The model takes as input a dictionary containing at least three items, plus one additional keyword argument. The input dictionary should contain the following items (a minimal example of assembling such a dictionary follows the lists below):

  1. The observation grid with size [num_paths, grid_size], which holds the points in time at which an observation was recorded. The key in the dictionary is observation_grid and the data type is float.

  2. The observation values with size [num_paths, grid_size], which are the actually observed values (states) of the process. The key in the dictionary is observation_values and the data type is int.

  3. The sequence lengths with size [num_paths], which give the length of each observed sequence. The key in the dictionary is seq_lengths (as used in the evaluation code below) and the data type is int.

  4. The number of states of the process, an integer between 2 and 6; the maximum number of states supported by our model is 6. This is passed as the keyword argument n_states rather than as a dictionary entry.

Optionally, the dictionary can contain the following items:

  • The time normalization factors with size [num_paths], which are the factors by which time is normalized. The key in the dictionary is time_normalization_factors and the data type is float. If this item is not provided, the model normalizes time by the maximum time in the observation grid.

  • Items for calculating the loss:

    • The intensity matrix with size [num_paths, n_states, n_states], which is the ground-truth intensity matrix of the process. The key in the dictionary is intensity_matrices and the data type is float.

    • The initial distribution with size [num_paths, n_states], which is the initial distribution of the process. The key in the dictionary is initial_distributions and the data type is int.

    • The adjacency matrix with size [num_paths, n_states, n_states], which is the adjacency matrix of the process. The key in the dictionary is adjacency_matrices and the data type is int.
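
To make the keys, shapes, and dtypes concrete, here is a minimal sketch that assembles a synthetic input dictionary from random tensors. The values are placeholders; only the keys, shapes, and dtypes follow the lists above.

import torch

num_paths, grid_size, n_states = 300, 50, 6

# Synthetic placeholder batch; only the keys, shapes, and dtypes matter here
batch = {
    # Times at which observations were recorded (float), sorted along the grid
    "observation_grid": torch.sort(torch.rand(num_paths, grid_size), dim=-1).values,
    # Observed states (int), values in [0, n_states)
    "observation_values": torch.randint(0, n_states, (num_paths, grid_size)),
    # Length of each observed sequence (int)
    "seq_lengths": torch.full((num_paths,), grid_size, dtype=torch.long),
    # Optional: per-path time normalization factors (float)
    "time_normalization_factors": torch.ones(num_paths),
}
# The number of states is passed as a separate argument: model(batch, n_states=n_states)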

Output of the model#

The model returns a dictionary containing the following items (a short usage sketch follows the list):

  • The intensity matrix with size [num_paths, n_states, n_states], which is the estimated intensity matrix of the process. The key in the dictionary is intensity_matrices and the data type is float.

  • The initial distribution with size [num_paths, n_states], which is the estimated initial distribution of the process. The key in the dictionary is initial_distributions and the data type is float.

  • The adjacency matrix with size [num_paths, n_states, n_states], which is the estimated adjacency matrix of the process. The key in the dictionary is adjacency_matrices and the data type is int.

  • The losses, which is a dictionary of loss values (accessed as, e.g., output["losses"]["rmse_loss"] in the evaluation below). The key in the dictionary is losses and the values are floats.
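
Assuming a model loaded as fimmjp (see the next section) and an input batch assembled as in the sketch above, a forward pass and unpacking of the result looks roughly like this:

with torch.no_grad():
    output = fimmjp(batch, n_states=n_states)

intensity = output["intensity_matrices"]     # [num_paths, n_states, n_states], float
initial = output["initial_distributions"]    # [num_paths, n_states], float
adjacency = output["adjacency_matrices"]     # [num_paths, n_states, n_states], int
# "losses" is a dictionary; computing it requires the ground-truth items listed above
rmse = output["losses"]["rmse_loss"]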

Loading the data and our model#

from datasets import load_dataset
import torch
from collections import defaultdict
from fim.trainers.utils import get_accel_type
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

device = get_accel_type()

Dataset#

We also provide a synthetic dataset, the Discrete Flashing Ratchet (DFR) dataset, for evaluating the model.

# Loading the Discrete Flashing Ratchet (DFR) dataset from Huggingface
data = load_dataset("FIM4Science/mjp", download_mode="force_redownload", name="default")
data.set_format("torch")

Pretrained model#

# Loading the FIMMJP model from Huggingface
from fim.models.mjp import FIMMJP
fimmjp = FIMMJP.from_pretrained("FIM4Science/fim-mjp", trust_remote_code=True)
fimmjp = fimmjp.to(device)
fimmjp.eval()
FIMMJP(
  (gaussian_nll): GaussianNLLLoss()
  (init_cross_entropy): CrossEntropyLoss()
  (pos_encodings): SineTimeEncoding(
    (linear_embedding): Linear(in_features=1, out_features=1, bias=True)
    (periodic_embedding): Sequential(
      (0): Linear(in_features=1, out_features=249, bias=True)
      (1): SinActivation()
    )
  )
  (ts_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (path_attention): MultiHeadLearnableQueryAttention(
    (W_k): Linear(in_features=256, out_features=256, bias=False)
    (W_v): Linear(in_features=256, out_features=256, bias=False)
    (W_o): Linear(in_features=256, out_features=256, bias=False)
  )
  (intensity_matrix_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=257, out_features=128, bias=True)
      (activation_0): SELU()
      (dropout_0): Dropout(p=0.1, inplace=False)
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (dropout_1): Dropout(p=0.1, inplace=False)
      (output): Linear(in_features=128, out_features=60, bias=True)
    )
  )
  (initial_distribution_decoder): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=257, out_features=128, bias=True)
      (activation_0): SELU()
      (dropout_0): Dropout(p=0.1, inplace=False)
      (linear_1): Linear(in_features=128, out_features=128, bias=True)
      (activation_1): SELU()
      (dropout_1): Dropout(p=0.1, inplace=False)
      (output): Linear(in_features=128, out_features=6, bias=True)
    )
  )
)
# Copy the first dataset entry to the device, keeping the leading batch dimension
batch = {k: v.to(device) for k, v in data["train"][:1].items()}
# Prepare the evaluation protocol
n_paths_eval = [1, 30, 100, 300, 500, 1000, 5000]  # numbers of paths to subsample
total_n_paths = batch["observation_grid"].shape[1]
statistics = total_n_paths // 300  # repetitions per setting for mean/std estimates

Evaluate the model#

result = defaultdict(list)
with torch.no_grad():
    for n_paths in n_paths_eval:
        for _ in range(statistics):
            # Subsample a random subset of paths and evaluate on it
            paths_idx = torch.randperm(total_n_paths)[:n_paths]
            mini_batch = batch.copy()
            mini_batch["observation_grid"] = batch["observation_grid"][:, paths_idx]
            mini_batch["observation_values"] = batch["observation_values"][:, paths_idx]
            mini_batch["seq_lengths"] = batch["seq_lengths"][:, paths_idx]
            output = fimmjp(mini_batch, n_states=6)
            result[n_paths].append(output["losses"]["rmse_loss"].item())
# Aggregate the RMSE over repetitions for each number of paths
means = {n_paths: torch.tensor(losses).mean().item() for n_paths, losses in result.items()}
stds = {n_paths: torch.tensor(losses).std().item() for n_paths, losses in result.items()}

df_result = pd.DataFrame(
    {
        "# Paths during Evaluation": list(means.keys()),
        "RMSE": [f"{mean:.3f} ± {std:.3f}" for mean, std in zip(means.values(), stds.values())],
    }
)

df_result
   # Paths during Evaluation           RMSE
0                          1  0.654 ± 0.061
1                         30  0.320 ± 0.063
2                        100  0.196 ± 0.039
3                        300  0.169 ± 0.021
4                        500  0.162 ± 0.016
5                       1000  0.227 ± 0.009
6                       5000  0.724 ± 0.001