Temporal missing pattern#

Similar to the last section, we will now look at the \(\texttt{FIM}\) imputation model. We will again generate some data, write a preprocessing function, and take a closer look at the arguments of this model. Remember that \(\texttt{FIM}\) tries to impute a whole range of values over a timeframe for which local simplicity can no longer be assumed.

from fim.models.imputation import FIMImputationWindowed
from datasets import load_dataset
from tutorial_helper import prepare_data
import torch

# Load the pretrained windowed imputation model from the Hugging Face Hub
model = FIMImputationWindowed.from_pretrained("FIM4Science/fim-windowed-imputation")

Imputing a temporal missing pattern#

We again start by generating some data, similar to the last two sections:

data = load_dataset("FIM4Science/roessler-example", download_mode="force_redownload", name="default")["train"]
data.set_format("torch")

# Reshape to [1, 1, 4096, 1] and add 5% multiplicative Gaussian noise to the observed values
ts = data["t"][:].reshape(1, 1, 4096, 1)
observed_x = data["x"][:].reshape(1, 1, 4096, 1) * (1 + torch.normal(0, 0.05, size=(1, 4096, 1)))
observed_v = data["x_prime"][:].reshape(1, 1, 4096, 1)
observed_a = data["x_double_prime"][:].reshape(1, 1, 4096, 1)

Unfortunately, the function needed to prepare our data for this model is a bit more lengthy. New users can either use this version, which assumes 3 windows (start window, imputation window, end window), use the version from tutorial_helper.py, which also supports 5 windows, or write their own preprocessing function. A custom function should mirror the output of this one, i.e. return a dictionary with the same keys and the dimensions described below. The output should therefore have the following keys with the indicated shapes, where wc is the window count, wlen the window length, wlen_locs the number of imputation locations, B the batch size and D the dimension (a minimal example batch is sketched after the table):

Key                       Shape
location_times            [B, wlen_locs, 1]
observation_times         [B, wc, wlen, D]
observation_values        [B, wc, wlen, D]
observation_mask          [B, wc, wlen, D]
linitial_conditions       [B, D]
rinitial_conditions       [B, D]
padding_mask_locations    [B, wlen_locs]
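
To make this concrete, here is a minimal sketch (not from the library; all sizes and values are hypothetical placeholders) that assembles such a batch dictionary by hand. The exact semantics of the two masks are our assumption and should be checked against the prepare_data implementation:

# Hypothetical sketch: assemble a dummy batch with the expected keys and shapes
B, wc, wlen, wlen_locs, D = 1, 3, 64, 25, 1
dummy_batch = {
    "location_times": torch.zeros(B, wlen_locs, 1),         # times at which values should be imputed
    "observation_times": torch.zeros(B, wc, wlen, D),       # time stamps of the observed windows
    "observation_values": torch.zeros(B, wc, wlen, D),      # observed values per window
    "observation_mask": torch.zeros(B, wc, wlen, D, dtype=torch.bool),      # marks entries to ignore (assumed semantics)
    "linitial_conditions": torch.zeros(B, D),                # state at the left boundary of the imputation window
    "rinitial_conditions": torch.zeros(B, D),                # state at the right boundary of the imputation window
    "padding_mask_locations": torch.zeros(B, wlen_locs, dtype=torch.bool),  # marks padded imputation locations (assumed semantics)
}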

With all that hard work done, we can now impute the missing values as before. Following the paper, we impute a contiguous block of missing values inside the observed series:

imp_start = 675
imp_end = 700
batch = prepare_data(ts, observed_x, imp_start=imp_start, imp_end=imp_end)
with torch.no_grad():
    output = model(batch)
    prediction_x = output["imputation_window"]["learnt"]
    prediction_velocity = output["imputation_window"]["drift"]

Similar to the \(\texttt{FIM-}\ell\) model, \(\texttt{FIM}\) returns more than just the learned values. Here is a complete overview of the structure of the output:

{
    "imputation_window": {
        "learnt": learnt_imp_solution,
        "target": batch.get("target_sample_path", None),
        "locations": locations,
        "drift": learnt_imp_drift,
        "drift_certainty": learnt_imp_certainty,
        "padding_mask_locations": batch.get("padding_mask_locations", None),
    },
    "observations": {
        "values": obs_values,
        "mask": obs_mask,
        "times": obs_times,
        "denoised_values": obs_values_processed.view(B, wc, wlen, D),
        "interpolation": interpolation_solution,
        "drift": interpolation_drift,
        "drift_certainty": interpolation_certainty,
    },
}
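
Besides the imputed values, we can, for example, read off the certainty the model assigns to its drift estimate and the denoised observation windows. Below is a short sketch using the keys listed above (the variable names are our own; output is the model output computed earlier):

# Sketch: access further fields of the model output
drift_certainty = output["imputation_window"]["drift_certainty"]   # certainty of the learnt drift on the imputation window
denoised_obs = output["observations"]["denoised_values"]           # denoised observations, shape [B, wc, wlen, D]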
Finally, we plot the imputed window together with the surrounding context and the ground truth:

import matplotlib.pyplot as plt

plt.plot(ts.flatten()[600:imp_start], observed_x.flatten()[600:imp_start], label="True x (context)", c="red", alpha=0.5)
plt.plot(ts.flatten()[imp_end:800], observed_x.flatten()[imp_end:800], c="red", alpha=0.5)
plt.plot(ts.flatten()[imp_start:imp_end], observed_x.flatten()[imp_start:imp_end], linestyle="dashed", label="True x", c="steelblue")
plt.plot(ts.flatten()[imp_start:imp_end], prediction_x.flatten(), label="Imputed x", c="black")
plt.xlabel("Time [s]")
plt.ylabel("x")
plt.legend()
plt.show()
[Figure: imputed values over the missing window, plotted against the surrounding context and the ground truth]
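
Since the model also returns the drift over the imputation window, we can compare it against the known derivative observed_v as a quick sanity check. The snippet below is our own sketch and assumes the drift output has one value per imputation location:

# Sketch: compare the estimated drift with the known derivative on the imputed window
plt.plot(ts.flatten()[imp_start:imp_end], observed_v.flatten()[imp_start:imp_end], linestyle="dashed", label="True velocity", c="steelblue")
plt.plot(ts.flatten()[imp_start:imp_end], prediction_velocity.flatten(), label="Estimated drift", c="black")
plt.xlabel("Time [s]")
plt.ylabel("dx/dt")
plt.legend()
plt.show()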