# Source code for PlanetAlign.data.dataset

import copy
import os
from pathlib import Path
from typing import Union
import torch
from torch_geometric.data import Data


class BaseData:
    r"""A dataset object storing multiple graphs and ground-truth anchor links for alignment.

    The anchor links are validated on construction and then split once into
    ``train_data`` and ``test_data`` using a permutation seeded with ``seed``.

    Parameters
    ----------
    graphs: list[torch_geometric.data.Data]
        List of PyTorch Geometric Data objects representing the graphs.
    anchor_links: torch.Tensor
        2D tensor containing the anchor links between the graphs, with one
        column per graph.
    name: str
        Name of the dataset.
    train_ratio: float, optional
        Ratio of anchor links to be used for training. Default is 0.2.
    dtype: torch.dtype, optional
        Data type of the tensors, either torch.float32 or torch.float64.
        Default is torch.float32. The value is only stored on the instance;
        this class does not cast the supplied graphs.
    seed: int, optional
        Random seed for shuffling the anchor links. Default is 0.

    Attributes
    ----------
    pyg_graphs: list[torch_geometric.data.Data]
        The graphs passed in as ``graphs``.
    train_data, test_data: torch.Tensor
        Rows of the anchor links selected for training and testing.

    Raises
    ------
    AssertionError
        If the graphs and anchor links fail the checks in ``_validate``.
    """

    def __init__(self,
                 graphs: list[Data],
                 anchor_links: torch.Tensor,
                 name: str,
                 train_ratio: float = 0.2,
                 dtype: torch.dtype = torch.float32,
                 seed: int = 0):
        self.pyg_graphs = graphs
        self._anchor_links = anchor_links
        self.name = name
        self.train_ratio = train_ratio
        self.seed = seed
        self.dtype = dtype

        self._validate()
        self.train_data, self.test_data = self.train_test_split()

    def train_test_split(self):
        """Split the anchor links into a training part and a test part.

        The rows are shuffled with a permutation drawn after seeding the
        global torch RNG with ``self.seed``; the first
        ``int(num_anchor * train_ratio)`` shuffled rows form the training set.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(train_links, test_links)``.
        """
        num_anchor = self._anchor_links.shape[0]

        # Seed temporarily so the split is reproducible, and restore the
        # caller's RNG state even if drawing the permutation fails.
        # NOTE(review): only the state returned by torch.get_rng_state() is
        # restored; confirm whether torch.manual_seed also touches other
        # device generators that should be preserved.
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(self.seed)
            perm = torch.randperm(num_anchor)
        finally:
            torch.set_rng_state(rng_state)

        train_size = int(num_anchor * self.train_ratio)
        return self._anchor_links[perm[:train_size]], self._anchor_links[perm[train_size:]]

    def _check_integrity(self, root):
        """Placeholder for an integrity check of ``root``; not implemented."""
        raise NotImplementedError

    def _validate(self):
        """Check that the graphs and the anchor links are mutually consistent.

        Raises
        ------
        AssertionError
            If any structural check fails; the message names the violated
            condition and, where applicable, the offending graph index.
        """
        assert hasattr(self, 'pyg_graphs') and hasattr(self, '_anchor_links'), \
            'Dataset has not been loaded yet, wrong place for validation'
        assert isinstance(self.pyg_graphs, (list, tuple)), 'Graphs must be stored in a list or tuple'
        assert len(self.pyg_graphs) > 1, 'At least two graphs are required for alignment'
        assert all(isinstance(g, Data) for g in self.pyg_graphs), 'Each graph must be a PyG Data object'
        assert isinstance(self._anchor_links, torch.Tensor), 'Anchor links must be stored in a PyTorch tensor'
        assert self._anchor_links.dim() == 2, 'Anchor links must be a 2D tensor'
        assert len(self.pyg_graphs) == self._anchor_links.shape[1], \
            'Number of PyG graphs and anchor links dimension do not match'

        # ``Tensor.max()`` raises on an empty tensor, so the range checks are
        # skipped when there is nothing to check.
        has_anchors = self._anchor_links.shape[0] > 0
        for i, g in enumerate(self.pyg_graphs):
            assert getattr(g, 'edge_index', None) is not None, f'Edge index is missing in graph {i}'
            if g.edge_index.numel() > 0:
                assert g.edge_index.max() < g.num_nodes, \
                    f'Node index must be less than the number of nodes in graph {i}'
            if has_anchors:
                assert self._anchor_links[:, i].max() < g.num_nodes, \
                    f'Anchor link must be less than the number of nodes in graph {i}'
            if g.x is not None:
                assert g.x.shape[0] == g.num_nodes, \
                    f'Number of nodes and node attributes must match in graph {i}'
            if g.edge_attr is not None:
                assert g.edge_attr.shape[0] == g.edge_index.shape[1], \
                    f'Number of edges and edge attributes must match in graph {i}'

    def __str__(self):
        """Return a fixed-width table summarising the graphs and the split."""
        graphs = self.pyg_graphs
        heavy_rule = '=' * 60
        light_rule = '-' * 60

        def row(label, cells):
            # One table row: 20-char left-aligned label, then 15-char cells.
            return f"{label:<20}" + "".join(cells)

        # Graphs supplied by the caller are not required to carry a ``name``
        # attribute; fall back to a positional label instead of failing.
        names = []
        for i, g in enumerate(graphs):
            graph_name = getattr(g, 'name', None)
            names.append(f'graph_{i}' if graph_name is None else str(graph_name))

        count_label = f'Count (ratio: {self.train_ratio})'
        lines = [
            f"Dataset: {self.name}",
            heavy_rule,
            row('Graphs', (f"{n:>15}" for n in names)),
            light_rule,
            row('# Nodes', (f"{g.num_nodes:>15,}" for g in graphs)),
            row('# Edges', (f"{g.num_edges:>15,}" for g in graphs)),
            row('# Node Attributes', (f"{g.num_node_features:>15,}" for g in graphs)),
            row('# Edge Attributes', (f"{g.num_edge_features:>15,}" for g in graphs)),
            heavy_rule,
            f"{'Anchor Links':<20}{'Train':>15}{'Test':>15}",
            light_rule,
            f"{count_label:<20}{self.train_data.shape[0]:>15,}{self.test_data.shape[0]:>15,}",
            heavy_rule,
        ]
        return "\n".join(lines)

    def clone(self):
        """Return a deep copy of this dataset."""
        return copy.deepcopy(self)
class Dataset:
    r"""A dataset object storing multiple graphs and ground-truth anchor links for alignment.

    The data is read from the file ``<root>/<name>.pt``, validated, and the
    anchor links are split once into ``train_data`` and ``test_data`` using a
    permutation seeded with ``seed``.

    Parameters
    ----------
    root: str or pathlib.Path
        Root directory where the dataset is stored.
    name: str
        Name of the dataset.
    train_ratio: float, optional
        Ratio of anchor links to be used for training. Must be in [0, 1).
        Default is 0.2.
    dtype: torch.dtype, optional
        Data type of the tensors, either torch.float32 or torch.float64.
        Default is torch.float32. Applied to node and edge attributes when
        the dataset is loaded.
    seed: int, optional
        Random seed for shuffling the anchor links. Default is 0.

    Attributes
    ----------
    pyg_graphs: list[torch_geometric.data.Data]
        The graphs built by ``load_dataset``.
    train_data, test_data: torch.Tensor
        Rows of the anchor links selected for training and testing.

    Raises
    ------
    AssertionError
        If ``train_ratio`` or ``dtype`` is out of range, or if the loaded
        graphs and anchor links fail the checks in ``_validate``.
    """

    def __init__(self,
                 root: Union[str, Path],
                 name: str,
                 train_ratio: float = 0.2,
                 dtype: torch.dtype = torch.float32,
                 seed: int = 0):
        assert 0 <= train_ratio < 1, 'Training ratio must be in [0, 1)'
        assert dtype in [torch.float32, torch.float64], 'Data type must be either torch.float32 or torch.float64'

        self.root = Path(root)
        self.name = name
        self.train_ratio = train_ratio
        self.seed = seed
        self.dtype = dtype

        # Flush so the progress message is visible while the file is loading;
        # without a trailing newline it could otherwise stay buffered.
        print(f'Loading dataset "{self.name} ({self.__class__.__name__})"...', end=' ', flush=True)
        self.pyg_graphs, self._anchor_links = self.load_dataset()
        print('Done')

        self._validate()
        self.train_data, self.test_data = self.train_test_split()

    def load_dataset(self) -> tuple[list[Data], torch.Tensor]:
        """Load ``<root>/<name>.pt`` and build one PyG graph per stored graph.

        The file must hold a dictionary with the keys ``'graphs'`` (graph
        names), ``'number_of_nodes'``, ``'edges'`` and ``'anchor_links'``.
        The keys ``'node_attributes'`` and ``'edge_attributes'`` are optional;
        when present, their per-graph entries are cast to ``self.dtype``.
        All per-graph entries are indexed by the graph's position in
        ``'graphs'``.

        Returns
        -------
        tuple[list[torch_geometric.data.Data], torch.Tensor]
            The graphs and the anchor links stored in the file.
        """
        data_dict = torch.load(self.root / f'{self.name}.pt', weights_only=True)
        has_node_attr = 'node_attributes' in data_dict
        has_edge_attr = 'edge_attributes' in data_dict

        pyg_graphs = []
        for gid, gname in enumerate(data_dict['graphs']):
            x = data_dict['node_attributes'][gid].to(self.dtype) if has_node_attr else None
            edge_attr = data_dict['edge_attributes'][gid].to(self.dtype) if has_edge_attr else None
            pyg_graphs.append(Data(name=gname,
                                   num_nodes=data_dict['number_of_nodes'][gid],
                                   x=x,
                                   edge_index=data_dict['edges'][gid],
                                   edge_attr=edge_attr))

        return pyg_graphs, data_dict['anchor_links']

    def train_test_split(self):
        """Split the anchor links into a training part and a test part.

        The rows are shuffled with a permutation drawn after seeding the
        global torch RNG with ``self.seed``; the first
        ``int(num_anchor * train_ratio)`` shuffled rows form the training set.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(train_links, test_links)``.
        """
        num_anchor = self._anchor_links.shape[0]

        # Seed temporarily so the split is reproducible, and restore the
        # caller's RNG state even if drawing the permutation fails.
        # NOTE(review): only the state returned by torch.get_rng_state() is
        # restored; confirm whether torch.manual_seed also touches other
        # device generators that should be preserved.
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(self.seed)
            perm = torch.randperm(num_anchor)
        finally:
            torch.set_rng_state(rng_state)

        train_size = int(num_anchor * self.train_ratio)
        return self._anchor_links[perm[:train_size]], self._anchor_links[perm[train_size:]]

    def _check_integrity(self, root):
        """Placeholder for an integrity check of ``root``; not implemented."""
        raise NotImplementedError

    def _validate(self):
        """Check that the graphs and the anchor links are mutually consistent.

        Raises
        ------
        AssertionError
            If any structural check fails; the message names the violated
            condition and, where applicable, the offending graph index.
        """
        assert hasattr(self, 'pyg_graphs') and hasattr(self, '_anchor_links'), \
            'Dataset has not been loaded yet, wrong place for validation'
        assert isinstance(self.pyg_graphs, (list, tuple)), 'Graphs must be stored in a list or tuple'
        assert len(self.pyg_graphs) > 1, 'At least two graphs are required for alignment'
        assert all(isinstance(g, Data) for g in self.pyg_graphs), 'Each graph must be a PyG Data object'
        assert isinstance(self._anchor_links, torch.Tensor), 'Anchor links must be stored in a PyTorch tensor'
        assert self._anchor_links.dim() == 2, 'Anchor links must be a 2D tensor'
        assert len(self.pyg_graphs) == self._anchor_links.shape[1], \
            'Number of PyG graphs and anchor links dimension do not match'

        # ``Tensor.max()`` raises on an empty tensor, so the range checks are
        # skipped when there is nothing to check.
        has_anchors = self._anchor_links.shape[0] > 0
        for i, g in enumerate(self.pyg_graphs):
            assert getattr(g, 'edge_index', None) is not None, f'Edge index is missing in graph {i}'
            if g.edge_index.numel() > 0:
                assert g.edge_index.max() < g.num_nodes, \
                    f'Node index must be less than the number of nodes in graph {i}'
            if has_anchors:
                assert self._anchor_links[:, i].max() < g.num_nodes, \
                    f'Anchor link must be less than the number of nodes in graph {i}'
            if g.x is not None:
                assert g.x.shape[0] == g.num_nodes, \
                    f'Number of nodes and node attributes must match in graph {i}'
            if g.edge_attr is not None:
                assert g.edge_attr.shape[0] == g.edge_index.shape[1], \
                    f'Number of edges and edge attributes must match in graph {i}'

    def __str__(self):
        """Return a fixed-width table summarising the graphs and the split."""
        graphs = self.pyg_graphs
        heavy_rule = '=' * 60
        light_rule = '-' * 60

        def row(label, cells):
            # One table row: 20-char left-aligned label, then 15-char cells.
            return f"{label:<20}" + "".join(cells)

        count_label = f'Count (ratio: {self.train_ratio})'
        lines = [
            f"Dataset: {self.name}",
            heavy_rule,
            row('Graphs', (f"{g.name:>15}" for g in graphs)),
            light_rule,
            row('# Nodes', (f"{g.num_nodes:>15,}" for g in graphs)),
            row('# Edges', (f"{g.num_edges:>15,}" for g in graphs)),
            row('# Node Attributes', (f"{g.num_node_features:>15,}" for g in graphs)),
            row('# Edge Attributes', (f"{g.num_edge_features:>15,}" for g in graphs)),
            heavy_rule,
            f"{'Anchor Links':<20}{'Train':>15}{'Test':>15}",
            light_rule,
            f"{count_label:<20}{self.train_data.shape[0]:>15,}{self.test_data.shape[0]:>15,}",
            heavy_rule,
        ]
        return "\n".join(lines)

    def clone(self):
        """Return a deep copy of this dataset."""
        return copy.deepcopy(self)
if __name__ == '__main__':
    # Manual smoke test: load every ``.pt`` file found in the local dataset
    # directory in float64 and print its summary table.
    data_root = '../../datasets/pyg'
    for entry in os.listdir(data_root):
        if not entry.endswith('.pt'):
            continue
        # The dataset name is the file name without its ``.pt`` extension.
        stem = entry[:-len('.pt')]
        loaded = Dataset(root=data_root, name=stem, train_ratio=0.2, dtype=torch.float64)
        print(loaded)