gents.model package
Subpackages
Submodules
gents.model.base module
- class gents.model.base.BaseModel(seq_len: int, seq_dim: int, condition: str, **kwargs)
Bases: ABC, LightningModule
Base class for time series generative models in PyTorch Lightning.
- Parameters:
seq_len (int) – Target sequence length
seq_dim (int) – Target sequence dimension, for univariate time series, set as 1
condition (str) – Condition type, one of [None, ‘predict’, ‘impute’, ‘class’, ‘super_resolution’]. None stands for unconditional generation.
**kwargs – Additional arguments for the model
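The constructor arguments above can be sanity-checked before a model is built. The sketch below is a minimal, dependency-free illustration that assumes the condition values listed in the docstring are the complete set. `ModelConfig` and `DOCUMENTED_CONDITIONS` are hypothetical names, not part of gents, and the real `BaseModel` may validate its arguments differently (or not at all).

```python
from dataclasses import dataclass
from typing import Optional

# Condition types listed in the BaseModel docstring; None means unconditional.
# Assumption: this is the complete set accepted by gents models.
DOCUMENTED_CONDITIONS = (None, "predict", "impute", "class", "super_resolution")


@dataclass
class ModelConfig:
    """Hypothetical container mirroring BaseModel's constructor arguments."""

    seq_len: int
    seq_dim: int
    condition: Optional[str] = None

    def __post_init__(self) -> None:
        # Reject shapes that cannot describe a time series.
        if self.seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {self.seq_len}")
        if self.seq_dim <= 0:
            raise ValueError(f"seq_dim must be positive, got {self.seq_dim}")
        # Reject condition types outside the documented list.
        if self.condition not in DOCUMENTED_CONDITIONS:
            raise ValueError(
                f"condition must be one of {DOCUMENTED_CONDITIONS}, "
                f"got {self.condition!r}"
            )

    @property
    def is_univariate(self) -> bool:
        # The docstring asks for seq_dim=1 on univariate series.
        return self.seq_dim == 1


# Unconditional generation of univariate series of length 24.
cfg = ModelConfig(seq_len=24, seq_dim=1)
print(cfg.is_univariate)  # True

# Forecasting-style conditioning on a 3-channel series of length 48.
forecast_cfg = ModelConfig(seq_len=48, seq_dim=3, condition="predict")
```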
- ALLOW_CONDITION = Ellipsis
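In the base class, `ALLOW_CONDITION` is set to `Ellipsis` (`...`), which suggests a placeholder that concrete models override. The sketch below assumes that each subclass lists the condition types it supports and that the constructor rejects anything else. `SketchBaseModel` and `SketchForecaster` are hypothetical classes that drop the Lightning dependency, so this illustrates the pattern rather than the actual gents implementation.

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class SketchBaseModel(ABC):
    """Hypothetical stand-in for BaseModel, without PyTorch Lightning."""

    # Placeholder in the base class, mirroring ALLOW_CONDITION = Ellipsis.
    ALLOW_CONDITION = ...

    def __init__(self, seq_len: int, seq_dim: int, condition: Optional[str] = None):
        allowed = type(self).ALLOW_CONDITION
        # Assumption: concrete models must replace the placeholder.
        if allowed is Ellipsis:
            raise TypeError(f"{type(self).__name__} must override ALLOW_CONDITION")
        # Assumption: unsupported condition types are rejected at construction.
        if condition not in allowed:
            raise ValueError(
                f"{type(self).__name__} supports {allowed}, got {condition!r}"
            )
        self.seq_len = seq_len
        self.seq_dim = seq_dim
        self.condition = condition

    @abstractmethod
    def describe(self) -> str: ...


class SketchForecaster(SketchBaseModel):
    # This hypothetical model supports unconditional and forecasting modes only.
    ALLOW_CONDITION: List[Optional[str]] = [None, "predict"]

    def describe(self) -> str:
        return f"forecaster(condition={self.condition!r})"


model = SketchForecaster(seq_len=24, seq_dim=1, condition="predict")
print(model.describe())  # forecaster(condition='predict')
```

Keeping the allowed list on the class (rather than on each instance) lets callers inspect what a model supports before constructing it.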
- sample(n_sample: int = 1, condition: Tensor | int = None, **kwargs) → Tensor
Sample time series from the trained model in evaluation mode.
- Parameters:
n_sample (int, optional) – The number of samples. Defaults to 1.
condition (torch.Tensor | int, optional) – Condition tensor of shape (batch_size, seq_len or obs_len, seq_dim). If the model is conditioned on a class label, an int can be passed instead. Defaults to None.
**kwargs – Additional arguments for the sampling process, e.g. data_mask, t.
- Returns:
Sampled time series of shape (n_sample, seq_len, seq_dim) for unconditional generation, or (batch_size, seq_len, seq_dim, n_sample) for conditional generation.
- Return type:
torch.Tensor
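The two return layouts of `sample()` differ in rank and in where the sample axis sits. The helper below, `expected_sample_shape` (a hypothetical name, not a gents function), simply encodes the shapes stated in the docstring so calling code can check them. It covers tensor conditions only, since the docstring does not say what batch size an int class label produces.

```python
from typing import Optional, Tuple


def expected_sample_shape(
    n_sample: int,
    seq_len: int,
    seq_dim: int,
    condition_shape: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, ...]:
    """Return the output shape documented for BaseModel.sample().

    condition_shape is the shape of a tensor condition,
    (batch_size, seq_len or obs_len, seq_dim); None means unconditional.
    """
    if condition_shape is None:
        # Unconditional: samples are stacked on the leading axis.
        return (n_sample, seq_len, seq_dim)
    batch_size = condition_shape[0]
    # Conditional: one output per condition, samples on a trailing axis.
    return (batch_size, seq_len, seq_dim, n_sample)


print(expected_sample_shape(8, 24, 1))               # (8, 24, 1)
print(expected_sample_shape(8, 48, 3, (16, 36, 3)))  # (16, 48, 3, 8)
```

With a trained model whose seq_len is 48, a call such as `model.sample(n_sample=8, condition=cond)` on a condition of shape (16, 36, 3) should therefore return a tensor of shape (16, 48, 3, 8), and `out[..., k]` selects the k-th draw for every condition in the batch.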