gents.model package

Subpackages

Submodules

gents.model.base module

class gents.model.base.BaseModel(seq_len: int, seq_dim: int, condition: str, **kwargs)

Bases: ABC, LightningModule

Base class for time series generative models in PyTorch Lightning.

Parameters:
  • seq_len (int) – Target sequence length

  • seq_dim (int) – Target sequence dimension (number of channels); set to 1 for univariate time series.

  • condition (str) – Condition type, chosen from [None, ‘predict’, ‘impute’, ‘class’, ‘super_resolution’]. None stands for unconditional generation.

  • **kwargs – Additional arguments for the model
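The condition argument accepts only the values listed above. The following is a minimal sketch of how such a check could be written in plain Python. SUPPORTED_CONDITIONS and check_condition are hypothetical names used for illustration. They are not part of gents, and the library's own validation may behave differently.

```python
from typing import Optional, Sequence

# Condition types listed in the BaseModel docstring; None means unconditional.
SUPPORTED_CONDITIONS = (None, "predict", "impute", "class", "super_resolution")


def check_condition(
    condition: Optional[str],
    allowed: Sequence[Optional[str]] = SUPPORTED_CONDITIONS,
) -> Optional[str]:
    """Return `condition` unchanged if it is allowed, else raise ValueError.

    Hypothetical helper for illustration only; not part of gents.
    """
    if condition not in allowed:
        raise ValueError(
            f"condition={condition!r} is not supported; choose from {list(allowed)}"
        )
    return condition
```

A concrete model that supports only a subset of tasks could pass a narrower allowed sequence, for example check_condition("impute", allowed=(None, "impute")).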

ALLOW_CONDITION = Ellipsis

sample(n_sample: int = 1, condition: Tensor | int = None, **kwargs) → Tensor

Sample time series from the trained model in evaluation mode.

Parameters:
  • n_sample (int, optional) – The number of samples. Defaults to 1.

  • condition (torch.Tensor | int, optional) – Condition tensor of shape (batch_size, seq_len or obs_len, seq_dim). For class-conditional models, this can be an int class label. Defaults to None.

  • kwargs – Additional arguments for the sampling process, e.g. data_mask or t.

Returns:

Sampled time series of shape (n_sample, seq_len, seq_dim) for unconditional generation, or (batch_size, seq_len, seq_dim, n_sample) for conditional generation.

Return type:

torch.Tensor
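The two output layouts above are easy to mix up, so the following is a small sketch that encodes the documented shape rule as a pure-Python function. expected_sample_shape is a hypothetical helper, not part of gents. It assumes a tensor condition whose first dimension is batch_size; the layout for an int class-label condition is not specified above and is not covered here.

```python
from typing import Optional, Tuple


def expected_sample_shape(
    n_sample: int,
    seq_len: int,
    seq_dim: int,
    condition_batch_size: Optional[int] = None,
) -> Tuple[int, ...]:
    """Shape that BaseModel.sample is documented to return.

    Hypothetical helper for illustration only. Pass condition_batch_size=None
    for unconditional generation, otherwise the batch_size of the condition tensor.
    """
    if condition_batch_size is None:
        # Unconditional: n_sample independent series stacked on the first axis.
        return (n_sample, seq_len, seq_dim)
    # Conditional: n_sample draws per conditioning series, stacked on the last axis.
    return (condition_batch_size, seq_len, seq_dim, n_sample)
```

Because conditional samples carry the draws on the last axis, a point estimate per conditioning series can be obtained by reducing over that axis, for example samples.mean(dim=-1) in PyTorch.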