gents.evaluation.model_based package
Module contents
- gents.evaluation.model_based.context_fid(ori_data: ndarray, gen_data: ndarray, device: str = 'cpu', ts2vec_path: str | None = None, train_data: ndarray | None = None)
Calculate context-FID.
Context-FID is an FID-like metric for evaluating how realistic the generated time series are compared to the real time series. It requires training a time series representation learning model (TS2Vec) on each dataset. The FID is then computed on the representations produced by the trained model.
- Parameters:
ori_data (np.ndarray) – Time series test dataset.
gen_data (np.ndarray) – Generated time series.
device (str, optional) – Computing device. Defaults to “cpu”.
ts2vec_path (str, optional) – Path for saving or loading the TS2Vec model. If ts2vec_path is given but train_data is not, we will try to load the TS2Vec model from ts2vec_path. If both ts2vec_path and train_data are given, we will train a TS2Vec model and save it to ts2vec_path.
train_data (np.ndarray, optional) – Time series training dataset, used for training the TS2Vec model. If train_data is given but ts2vec_path is not, we will train a TS2Vec model without saving it.
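The final FID step can be sketched as follows, assuming the real and generated series have already been encoded into fixed-length embeddings. In context-FID these embeddings would come from the trained TS2Vec encoder; here any arrays of shape `(n_samples, dim)` work. `frechet_distance` is a hypothetical helper, not part of `gents`. It computes the standard Fréchet distance between two Gaussians fitted to the embeddings, ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}), using only NumPy.

```python
import numpy as np


def frechet_distance(emb_real: np.ndarray, emb_gen: np.ndarray) -> float:
    """Hypothetical helper: FID between two embedding sets of shape (n_samples, dim)."""
    # Fit a Gaussian (mean, covariance) to each set of embeddings.
    mu1, mu2 = emb_real.mean(axis=0), emb_gen.mean(axis=0)
    s1 = np.cov(emb_real, rowvar=False)
    s2 = np.cov(emb_gen, rowvar=False)

    # Symmetric square root of s1 via eigendecomposition (s1 is PSD).
    vals, vecs = np.linalg.eigh(s1)
    s1_half = (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

    # tr(sqrt(s1 @ s2)) equals tr(sqrt(s1^{1/2} s2 s1^{1/2})),
    # whose argument is symmetric PSD, so eigvalsh can be used safely.
    m = s1_half @ s2 @ s1_half
    tr_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(m), 0.0, None)).sum()

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * tr_sqrt)
```

Identical embedding sets give a distance of (numerically) zero, and the value grows as the means or covariances drift apart. With the library installed, the documented entry point would instead be called as `context_fid(ori_data, gen_data, device="cpu", train_data=train_data)`. How `gents` pools TS2Vec outputs into embeddings is not stated here, so the sketch should not be read as its implementation.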
- gents.evaluation.model_based.discriminative_score(ori_data: ndarray, generated_data: ndarray, device: str)
Discriminative score.
Discriminative score evaluates whether the generated time series can be distinguished from the real time series by a post-hoc trained GRU network.
The generated data are labeled as fake and the real data as real. A GRU classifier is then trained on 80% of the combined dataset to distinguish them.
The classification accuracy is reported on the remaining 20% of the data.
- Parameters:
ori_data (np.ndarray) – Real time series data.
generated_data (np.ndarray) – Generated time series data.
device (str) – Computing device.
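The evaluation protocol described above (label real vs. generated, train on 80% of the combined data, report accuracy on the held-out 20%) can be sketched as follows. To keep the sketch dependency-free, a logistic regression on per-channel mean and standard-deviation features stands in for the GRU classifier. `discriminative_accuracy` and its parameters are hypothetical, the split is shuffled here (an assumption), and inputs are assumed to have shape `(n_samples, seq_len, n_features)`.

```python
import numpy as np


def discriminative_accuracy(real, fake, train_frac=0.8, epochs=500, lr=0.1, seed=0):
    """Hypothetical sketch: logistic regression stands in for the GRU classifier."""
    rng = np.random.default_rng(seed)

    # Label real series 1 and generated series 0.
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([np.ones(len(real)), np.zeros(len(fake))])

    # Simple per-channel summary features over time (a GRU would learn these instead).
    feats = np.concatenate([x.mean(axis=1), x.std(axis=1)], axis=1)

    # Shuffle, then split 80% train / 20% test.
    idx = rng.permutation(len(x))
    n_train = int(train_frac * len(x))
    tr, te = idx[:n_train], idx[n_train:]

    # Standardize using training-set statistics only.
    mu = feats[tr].mean(axis=0)
    sd = feats[tr].std(axis=0) + 1e-8
    feats = (feats - mu) / sd

    # Full-batch gradient descent on the logistic loss.
    w, b = np.zeros(feats.shape[1]), 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(feats[tr] @ w + b)))
        grad = p - y[tr]
        w -= lr * feats[tr].T @ grad / n_train
        b -= lr * grad.mean()

    # Accuracy on the held-out 20%.
    pred = (feats[te] @ w + b) > 0
    return float((pred == y[te]).mean())
```

With equal numbers of real and generated samples, chance accuracy is 0.5. Accuracy near 0.5 therefore suggests the classifier cannot separate the two sets, while accuracy near 1.0 indicates easily detectable differences.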
- gents.evaluation.model_based.predictive_score(ori_data: ndarray, generated_data: ndarray, device: str)
Predictive score.
Predictive score evaluates the usefulness of the generated time series for forecasting.
The generated time series are used to train a GRU forecasting model, which is then tested on the real data.
The test MAE is reported.
- Parameters:
ori_data (np.ndarray) – Real time series data.
generated_data (np.ndarray) – Generated time series data.
device (str) – Computing device.
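The train-on-generated, test-on-real protocol can be sketched as below. A linear one-step-ahead forecaster fitted by least squares stands in for the GRU. `predictive_mae` is hypothetical, the forecasting target (predict all features at step t+1 from the features at step t) is an assumption, and inputs are assumed to have shape `(n_samples, seq_len, n_features)`.

```python
import numpy as np


def predictive_mae(real, generated):
    """Hypothetical sketch: fit a linear forecaster on generated data, score MAE on real data."""

    def pairs(data):
        # One-step-ahead pairs (x_t -> x_{t+1}), pooled over all series.
        n_feat = data.shape[2]
        x = data[:, :-1, :].reshape(-1, n_feat)
        y = data[:, 1:, :].reshape(-1, n_feat)
        return x, y

    # Train only on generated data: least squares with a bias column.
    x_gen, y_gen = pairs(generated)
    design = np.hstack([x_gen, np.ones((len(x_gen), 1))])
    coef, *_ = np.linalg.lstsq(design, y_gen, rcond=None)

    # Test on real data and report mean absolute error.
    x_real, y_real = pairs(real)
    pred = np.hstack([x_real, np.ones((len(x_real), 1))]) @ coef
    return float(np.abs(pred - y_real).mean())
```

Lower MAE is better: it means a model trained only on generated data transfers well to real data. If the generated series follow different dynamics than the real ones, the forecaster learns the wrong dynamics and the MAE on real data increases.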