from typing import Union, Optional
from pathlib import Path
import os
import torch
from PlanetAlign.data import Dataset
from .utils import download_file_from_google_drive
class Airport(Dataset):
    """A pair of networks synthesized from the `American air-traffic network <https://arxiv.org/pdf/1704.03165>`_.

    Nodes represent airports, and an edge exists between two airports if there are commercial flights between
    them. The level of activity at each airport is used as the node attributes. The two networks are noisy
    permutations of the original network, generated by randomly inserting 10% of edges (Airport1) and deleting
    15% of edges (Airport2) from the original network, respectively. In total, the two networks share 1,190
    common nodes.

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - Graph
          - #nodes
          - #edges
          - #node attrs
          - #edge attrs
        * - Airport1
          - 1,190
          - 14,958
          - 4
          - 0
        * - Airport2
          - 1,190
          - 11,560
          - 4
          - 0
    """

    # Google Drive file ID of the packaged dataset, and the filename it is saved under / looked up in ``root``.
    # Kept in one place so the download target and the integrity check cannot drift apart.
    _GDRIVE_FILE_ID = '1XaeOcM9H_VbUplKBvUt3FD0WPeAsSXfq'
    _DATA_FILENAME = 'airport.pt'

    def __init__(self,
                 root: Union[str, Path],
                 download: Optional[bool] = False,
                 train_ratio: Optional[float] = 0.2,
                 dtype: torch.dtype = torch.float32,
                 seed: Optional[int] = 0):
        """Load the Airport dataset from ``root``, optionally downloading it first.

        Args:
            root: Directory that holds (or will receive) ``airport.pt``.
            download: If truthy, call the Google Drive downloader to save ``airport.pt`` into ``root``
                before the existence check. The downloader is called whenever this flag is set; whether an
                already-present file is re-fetched depends on ``download_file_from_google_drive``.
            train_ratio: Forwarded unchanged to the :class:`Dataset` base class (presumably the fraction of
                anchor links used for training — confirm against ``Dataset``).
            dtype: Forwarded unchanged to the :class:`Dataset` base class.
            seed: Forwarded unchanged to the :class:`Dataset` base class.

        Raises:
            RuntimeError: If ``airport.pt`` does not exist in ``root`` (after the optional download).
        """
        if download:
            download_file_from_google_drive(
                remote_file_id=self._GDRIVE_FILE_ID,
                save_filename=self._DATA_FILENAME,
                root=root)
        if not self._check_integrity(root):
            raise RuntimeError('Airport dataset not found or corrupted. You can use download=True to download it')
        super().__init__(root=root, name='airport', train_ratio=train_ratio, dtype=dtype, seed=seed)

    def _check_integrity(self, root):
        """Return True if the dataset file exists under ``root``.

        Only the file's existence is checked; its contents are not validated.
        """
        return os.path.exists(os.path.join(root, self._DATA_FILENAME))