Source code for PlanetAlign.datasets.pems

from typing import Union, Optional
from pathlib import Path
import os
import torch

from PlanetAlign.data import Dataset
from .utils import download_file_from_google_drive


class PeMS08(Dataset):
    """A pair of traffic networks synthesized from the `Performance Measurement System (PeMS) Data Source
    <https://dot.ca.gov/programs/traffic-operations/mpr/pems-source>`_.

    Nodes represent sensors and edges indicate traffic-flow correlations. Node attributes are averaged
    across all time intervals. The two networks are noisy permutations of the original network, generated
    by randomly inserting 10% edges (PeMS08-1) and deleting 15% edges (PeMS08-2) from the original
    network, respectively. There are in total 170 common nodes across the two networks.

    .. list-table::
       :widths: 10 10 10 10 10
       :header-rows: 1

       * - Graph
         - #nodes
         - #edges
         - #node attrs
         - #edge attrs
       * - PeMS08-1
         - 170
         - 301
         - 3
         - 0
       * - PeMS08-2
         - 170
         - 233
         - 3
         - 0

    Args:
        root: Directory that holds (or, with ``download=True``, will receive) the ``PeMS08.pt`` file.
        download: If ``True``, fetch ``PeMS08.pt`` from Google Drive into ``root`` before loading.
        train_ratio: Forwarded unchanged to :class:`Dataset`.
        dtype: Forwarded unchanged to :class:`Dataset`.
        seed: Forwarded unchanged to :class:`Dataset`.

    Raises:
        RuntimeError: If ``PeMS08.pt`` is not present in ``root`` (after the optional download).
    """

    def __init__(self,
                 root: Union[str, Path],
                 download: Optional[bool] = False,
                 train_ratio: Optional[float] = 0.2,
                 dtype: torch.dtype = torch.float32,
                 seed: Optional[int] = 0):
        if download:
            download_file_from_google_drive(
                remote_file_id='1mVEYcniOIueMfErkrpT1zA1XLptPnW5D',
                save_filename='PeMS08.pt',
                root=root)

        # Fail early with an actionable message instead of letting the base class
        # hit a missing file deeper inside its loading logic.
        if not self._check_integrity(root):
            raise RuntimeError('PeMS08 dataset not found or corrupted. You can use download=True to download it')

        super().__init__(root=root, name='PeMS08', train_ratio=train_ratio, dtype=dtype, seed=seed)

    def _check_integrity(self, root):
        """Return ``True`` if ``PeMS08.pt`` exists under ``root``.

        Only the file's existence is checked; its contents are not validated.
        """
        return os.path.exists(os.path.join(root, 'PeMS08.pt'))