Source code for PlanetAlign.datasets.airport

from typing import Union, Optional
from pathlib import Path
import os
import torch

from PlanetAlign.data import Dataset
from .utils import download_file_from_google_drive


class Airport(Dataset):
    """A pair of networks synthesized from the `American air-traffic network
    <https://arxiv.org/pdf/1704.03165>`_.

    Nodes represent airports, and an edge exists between two airports if
    there are commercial flights between them. The level of activity in each
    airport is used as node attributes. The two networks are noisy
    permutations of the original network, generated by randomly inserting
    10% edges (Airport1) and deleting 15% edges (Airport2) from the original
    network, respectively. There are 1,190 common nodes across the two
    networks in total.

    .. list-table::
       :widths: 10 10 10 10 10
       :header-rows: 1

       * - Graph
         - #nodes
         - #edges
         - #node attrs
         - #edge attrs
       * - Airport1
         - 1,190
         - 14,958
         - 4
         - 0
       * - Airport2
         - 1,190
         - 11,560
         - 4
         - 0

    :param root: Directory that contains (or, with ``download=True``, will
        receive) the ``airport.pt`` data file.
    :param download: If ``True``, the Google Drive download helper is invoked
        to save ``airport.pt`` into ``root`` before the existence check. The
        helper is called whenever this flag is set, whether or not the file
        is already present.
    :param train_ratio: Forwarded unchanged to the base ``Dataset``.
    :param dtype: Forwarded unchanged to the base ``Dataset``.
    :param seed: Forwarded unchanged to the base ``Dataset``.
    :raises RuntimeError: If ``airport.pt`` does not exist under ``root``
        after the optional download step.
    """

    # Google Drive file id and on-disk file name of the packaged dataset.
    _FILE_ID = '1XaeOcM9H_VbUplKBvUt3FD0WPeAsSXfq'
    _FILENAME = 'airport.pt'

    def __init__(self,
                 root: Union[str, Path],
                 download: bool = False,
                 train_ratio: Optional[float] = 0.2,
                 dtype: torch.dtype = torch.float32,
                 seed: Optional[int] = 0):
        if download:
            download_file_from_google_drive(
                remote_file_id=self._FILE_ID,
                save_filename=self._FILENAME,
                root=root)

        # Fail fast with an actionable message rather than letting the base
        # class hit the missing file (presumably it loads '<root>/airport.pt'
        # via name='airport' -- TODO confirm against Dataset).
        if not self._check_integrity(root):
            raise RuntimeError('Airport dataset not found or corrupted. You can use download=True to download it')

        super().__init__(root=root, name='airport', train_ratio=train_ratio, dtype=dtype, seed=seed)

    def _check_integrity(self, root):
        """Return ``True`` if the dataset file exists under ``root``.

        Only the file's existence is checked; its contents are not validated.
        """
        return os.path.exists(os.path.join(root, self._FILENAME))