Time Series Generation with Class Label Guidance

This tutorial shows how to generate time series conditioned on class labels.

Problem setting

Given a class label for a time series (e.g. patient status), \(\mathbf{c} \in \mathbb{R}^{C}\), we are interested in the conditional distribution \(p(\mathbf{x}_{\text{target}} \mid \mathbf{c})\). Sampling from this distribution yields time series \(\hat{\mathbf{x}}_{\text{target}} \in \mathbb{R}^{T \times D}\), where \(T\) is the sequence length and \(D\) the number of channels, that are consistent with the given class label.
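To make the notation concrete, the following is a minimal toy sketch of what "sampling given a class label" means. It is not part of gents: `one_hot` and `toy_conditional_sample` are hypothetical helpers, the one-hot encoding of \(\mathbf{c}\) is an assumption, and the mapping of class 0 to clockwise is an arbitrary illustrative choice.

```python
import math
import random


def one_hot(label, num_classes):
    """Encode an integer class label as a vector c of length C (assumed encoding)."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def toy_conditional_sample(c, seq_len=32, noise=0.02, rng=None):
    """Draw one toy series of shape [T, D=2] given a one-hot label c.

    Class 0 spirals clockwise, class 1 counterclockwise (illustrative choice).
    """
    rng = rng or random.Random()
    direction = -1.0 if c[0] == 1.0 else 1.0
    series = []
    for t in range(seq_len):
        angle = direction * 2.0 * math.pi * t / seq_len
        radius = (t + 1) / seq_len  # radius grows along the sequence
        x = radius * math.cos(angle) + rng.gauss(0.0, noise)
        y = radius * math.sin(angle) + rng.gauss(0.0, noise)
        series.append((x, y))
    return series


c = one_hot(1, num_classes=2)                  # label vector with C = 2
x_hat = toy_conditional_sample(c, rng=random.Random(0))
print(len(x_hat), len(x_hat[0]))               # T and D of the sampled series
```

A learned generative model replaces the hand-written rule in `toy_conditional_sample` with a distribution fitted to data, but the interface is the same: a label goes in, a \(T \times D\) series comes out.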

Implementation

1. Import modules

Print all models that support class label guidance

[1]:
import torch
from gents.dataset import Spiral2D
from gents.model import TimeVQVAE
from lightning import Trainer
from gents.evaluation import tsne_visual
from gents.evaluation import context_fid

import gents.model

class_model = []
for name in gents.model.__all__:
    model_cls = getattr(gents.model, name)
    if 'class' in model_cls.ALLOW_CONDITION:
        class_model.append(name)
class_model
/home/wcx/anaconda3/envs/gents/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.
Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.
[1]:
['VanillaVAE', 'TimeVQVAE', 'VanillaGAN', 'RCGAN', 'VanillaDDPM', 'VanillaMAF']
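The filtering pattern used above can be tried without installing anything. The sketch below uses hypothetical stand-in classes rather than the real gents models, and the contents of their `ALLOW_CONDITION` attributes are assumed for illustration. It also uses `getattr` with a default, a defensive variant that tolerates classes lacking the attribute.

```python
# Hypothetical stand-ins; these are NOT the real gents model classes.
class ToyClassModel:
    ALLOW_CONDITION = [None, "class"]  # assumed contents, for illustration


class ToyUnconditionalModel:
    ALLOW_CONDITION = [None]


class ToyLegacyModel:
    pass  # deliberately has no ALLOW_CONDITION attribute


registry = {
    "ToyClassModel": ToyClassModel,
    "ToyUnconditionalModel": ToyUnconditionalModel,
    "ToyLegacyModel": ToyLegacyModel,
}


def supports(model_cls, condition):
    """Return True if the class declares support for the given condition type."""
    allowed = getattr(model_cls, "ALLOW_CONDITION", None) or ()
    return condition in allowed


class_models = [name for name, cls in registry.items() if supports(cls, "class")]
print(class_models)  # ['ToyClassModel']
```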

2. Set up datamodule and model

Here, the Spiral2D data has two class labels: clockwise and counterclockwise. We set \(T=32\) for illustration. Note that condition='class' must be passed to both the datamodule and the model, and class_num=2 must be passed to the model.

[2]:
dm = Spiral2D(
    seq_len=32,
    batch_size=64,
    num_samples=3000,
    data_dir="../data",
    condition="class",
)
model = TimeVQVAE(
    seq_len=dm.seq_len,
    seq_dim=dm.seq_dim,
    condition="class",
    class_num=2,
)

3. Set up training

With Lightning (lightning / pytorch-lightning), one can easily configure:

  • GPU devices

  • Training epochs/steps

  • Callbacks

  • etc.

[3]:
trainer = Trainer(max_steps=3000, devices=[0], enable_progress_bar=False)
trainer.fit(model, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type                     | Params | Mode
-------------------------------------------------------------------
0 | encoder_l     | VQVAEEncoder             | 20.8 K | train
1 | encoder_h     | VQVAEEncoder             | 20.0 K | train
2 | vq_model_l    | VectorQuantize           | 2.2 K  | train
3 | vq_model_h    | VectorQuantize           | 2.2 K  | train
4 | decoder_l     | VQVAEDecoder             | 3.8 K  | train
5 | decoder_h     | VQVAEDecoder             | 2.9 K  | train
6 | transformer_l | BidirectionalTransformer | 656 K  | train
7 | transformer_h | BidirectionalTransformer | 786 K  | train
-------------------------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.978     Total estimated model params size (MB)
242       Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_steps=3000` reached.

4. Evaluation

Qualitative evaluation: t-SNE visualization

[4]:
# collect the real test sequences and their class labels
dm.setup("test")
real_data = torch.cat([batch["seq"] for batch in dm.test_dataloader()])
cond_data = torch.cat([batch["c"] for batch in dm.test_dataloader()])

# generate as many samples as there are real test sequences,
# conditioned on the same class labels
gen_data = model.sample(n_sample=len(real_data), condition=cond_data)  # [N, 32, 2]
tsne_visual(real_data, gen_data, cond_data)

[Figure: t-SNE visualization of real and generated samples (../_images/tutorials_class_label_guidance_8_0.png)]

Quantitative evaluation: Context-FID

[6]:
context_fid(
    real_data.numpy(),
    gen_data.numpy(),
    device="cuda:0",
    train_data=dm.train_ds.data.numpy(),
)
train a new ts2vec model
[6]:
np.float64(0.8106987850150742)
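For intuition about the number above: Context-FID is commonly defined as a Fréchet distance between Gaussians fitted to embeddings of real and generated series, with the embeddings produced by a TS2Vec encoder (hence the "train a new ts2vec model" log line). Lower values indicate closer distributions. The sketch below shows only the Fréchet distance step on synthetic embeddings. It is an illustration under these assumptions, not the gents implementation, and `frechet_distance` is a hypothetical helper name.

```python
import numpy as np


def frechet_distance(emb_real, emb_gen):
    """Frechet distance between Gaussians fitted to two embedding sets of shape [N, E]."""
    mu_r, mu_g = emb_real.mean(axis=0), emb_gen.mean(axis=0)
    cov_r = np.cov(emb_real, rowvar=False)
    cov_g = np.cov(emb_gen, rowvar=False)
    # tr((cov_r cov_g)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of cov_r @ cov_g, which are real and non-negative in theory;
    # clip small negative values caused by floating-point error.
    eigvals = np.linalg.eigvals(cov_r @ cov_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_sqrt)


rng = np.random.default_rng(0)
emb_a = rng.normal(size=(500, 8))   # stand-in for embeddings of real series
emb_b = emb_a + 1.0                 # same covariance, mean shifted by 1 per dimension

fd_same = frechet_distance(emb_a, emb_a)    # approximately 0
fd_shift = frechet_distance(emb_a, emb_b)   # approximately 8 (squared mean shift over 8 dims)
print(fd_same, fd_shift)
```

Identical embedding sets give a distance near zero, and a pure mean shift contributes its squared norm, which is why the score is read as "smaller is better".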