gents.evaluation.visualization package
Module contents
- gents.evaluation.visualization.imputation_visual(real_data: Tensor, gen_data: Tensor, cond_data: Tensor, data_mask: BoolTensor, max_viz_n_channel=3, save_root=None)
Visualize time series imputation results, including the 95% prediction interval.
- Parameters:
real_data (torch.Tensor) – Ground truth time series, with shape [B, seq_len, C].
gen_data (torch.Tensor) – Predicted time series scenarios, with shape [B, seq_len, C, N], where N is the number of scenarios.
cond_data (torch.Tensor) – Observed time series, with shape [B, seq_len, C]. Missing values should be set to NaN.
data_mask (torch.BoolTensor) – Data mask of the ground truth time series, with shape [B, seq_len, C].
max_viz_n_channel (int, optional) – Maximum number of channels to visualize. Defaults to 3.
save_root (str, optional) – File path to save the figure to; the extension should be .png, .pdf, etc. If None, the figure is not saved. Defaults to None.
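The sketch below builds placeholder inputs with the shapes documented above, for example to try imputation_visual without a trained model. The helper make_imputation_inputs, the random missingness pattern, the noisy stand-in scenarios and the all-True data_mask are illustrative assumptions, not part of gents; check the gents source for the exact data_mask convention.

```python
import torch

def make_imputation_inputs(real_data: torch.Tensor, missing_rate: float = 0.2, n_scenarios: int = 50):
    """Hypothetical helper: placeholder inputs shaped as imputation_visual documents."""
    B, seq_len, C = real_data.shape
    # Randomly choose entries to hide (True = missing in this sketch only).
    missing = torch.rand(B, seq_len, C) < missing_rate
    # Observed series: missing values set to NaN, as required for cond_data.
    cond_data = real_data.masked_fill(missing, float("nan"))
    # Stand-in for model output: N noisy copies stacked on the last dim -> [B, seq_len, C, N].
    gen_data = real_data.unsqueeze(-1) + 0.1 * torch.randn(B, seq_len, C, n_scenarios)
    # Placeholder ground-truth mask (all True); the real convention is an assumption here.
    data_mask = torch.ones(B, seq_len, C, dtype=torch.bool)
    return cond_data, gen_data, data_mask
```

With `real = torch.randn(4, 24, 3)` and `cond, gen, mask = make_imputation_inputs(real)`, the call would follow the documented argument order: `imputation_visual(real, gen, cond, mask, max_viz_n_channel=3, save_root="imputation.png")`.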
- gents.evaluation.visualization.predict_visual(real_data: Tensor, gen_data: Tensor, data_mask: BoolTensor, max_viz_n_channel: int = 3, save_root: str | None = None)
Visualize time series prediction results, including the 95% prediction interval.
- Parameters:
real_data (torch.Tensor) – Ground truth time series, with shape [B, obs_len + seq_len, C].
gen_data (torch.Tensor) – Predicted time series scenarios, with shape [B, seq_len, C, N], where N is the number of scenarios.
data_mask (torch.BoolTensor) – Data mask of the ground truth time series, with shape [B, obs_len + seq_len, C].
max_viz_n_channel (int, optional) – Maximum number of channels to visualize. Defaults to 3.
save_root (str, optional) – File path to save the figure to; the extension should be .png, .pdf, etc. If None, the figure is not saved. Defaults to None.
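A 95% interval of this kind can be derived from the N scenarios. As an illustration only (not necessarily how gents computes it), the empirical 2.5% and 97.5% quantiles over the last dimension of gen_data give such a band. Note that gen_data covers only the final seq_len steps of real_data. The helper prediction_interval and the example sizes are hypothetical.

```python
import torch

def prediction_interval(gen_data: torch.Tensor, level: float = 0.95):
    """Hypothetical helper: empirical quantile band over the scenario dim N (last dim)."""
    alpha = (1.0 - level) / 2.0
    q = torch.tensor([alpha, 1.0 - alpha], dtype=gen_data.dtype)
    # With a 1-D q, torch.quantile prepends a dim of size 2: [2, B, seq_len, C].
    bands = torch.quantile(gen_data, q, dim=-1)
    return bands[0], bands[1]

# Example sizes (hypothetical).
B, obs_len, seq_len, C, N = 2, 48, 24, 3, 100
real_data = torch.randn(B, obs_len + seq_len, C)  # [B, obs_len + seq_len, C]
gen_data = torch.randn(B, seq_len, C, N)          # [B, seq_len, C, N]

# gen_data forecasts only the last seq_len steps of real_data.
future = real_data[:, obs_len:, :]
lower, upper = prediction_interval(gen_data)
# Fraction of ground-truth points that fall inside the band.
coverage = ((future >= lower) & (future <= upper)).float().mean().item()
```

Quantiles are used here because they make no distributional assumption about the scenarios; a mean ± 1.96·std band would instead assume roughly Gaussian samples.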
- gents.evaluation.visualization.tsne_visual(real_data, generated_data, class_label_data=None, save_root=None, min_viz_samples=1000)
t-SNE visualization of generated and real time series.
- Parameters:
real_data (ArrayLike) – Real time series data, with shape [B, T, C].
generated_data (ArrayLike) – Generated time series data, with shape [B, T, C].
class_label_data (ArrayLike, optional) – Time series labels. If not None, with shape [B]. Defaults to None.
save_root (str, optional) – File path to save the figure to; the extension should be .png, .pdf, etc. If None, the figure is not saved. Defaults to None.
min_viz_samples (int, optional) – Number of data samples used in the visualization. Defaults to 1000.
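For intuition, a common way to build such a plot is to flatten each [T, C] series into one vector, subsample both sets, and embed them jointly in 2-D with t-SNE. The sketch below does this with scikit-learn's TSNE; it is an assumption about the general technique, not gents' implementation, and the helper tsne_embed and its max_samples argument are hypothetical (the exact subsampling behaviour of min_viz_samples is not described above).

```python
import numpy as np
from sklearn.manifold import TSNE

def tsne_embed(real_data, generated_data, max_samples=1000, seed=0):
    """Hypothetical sketch: joint 2-D t-SNE embedding of real and generated series."""
    real = np.asarray(real_data, dtype=np.float64)
    gen = np.asarray(generated_data, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Subsample both sets to the same number of series (at most max_samples).
    n = min(max_samples, len(real), len(gen))
    real = real[rng.choice(len(real), n, replace=False)]
    gen = gen[rng.choice(len(gen), n, replace=False)]
    # Flatten each [T, C] series into a single feature vector: [2n, T * C].
    X = np.concatenate([real.reshape(n, -1), gen.reshape(n, -1)], axis=0)
    # t-SNE requires perplexity to be smaller than the number of points.
    perplexity = min(30.0, (2 * n - 1) / 3)
    emb = TSNE(n_components=2, perplexity=perplexity, random_state=seed).fit_transform(X)
    # True for rows that came from real_data, False for generated rows.
    is_real = np.r_[np.ones(n, dtype=bool), np.zeros(n, dtype=bool)]
    return emb, is_real
```

The resulting `emb` can be scatter-plotted (for example with matplotlib) and colored by `is_real`; well-matched generated data should overlap the real points rather than form a separate cluster.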