Source code for libam.datasets._dataset

from pathlib import Path
from typing import Any, Callable

import networkx as nx

import libam
from libam.datasets._registry import fetch


class Dataset:
    """
    Handle to a dataset identified by a registry key.

    Construction only records the configuration; the file(s) are fetched
    each time :meth:`graph` or :meth:`graphpair` is called.
    """

    def __init__(
        self,
        filename: str,
        loader: Callable[[Path], Any],
        parser: Callable[..., libam.GraphPair],
        members: list[str] | None = None,
    ) -> None:
        """
        Record how to fetch, load and parse a dataset.

        :param filename: Registry key of the file to load, or the folder
            prefix (ending in ``/``) for multi-file datasets.
        :param loader: Loader function turning the resolved path into graph
            data.
        :param parser: Parser function turning the loaded data into a
            GraphPair.
        :param members: For folder datasets, the file names inside
            ``filename`` that must be fetched; the loader receives the
            folder path.
        """
        self._filename = filename
        self._loader = loader
        self._parser = parser
        # Copy so that later mutation of the caller's list cannot change
        # which files this dataset fetches.
        self._members = None if members is None else list(members)

    def __repr__(self) -> str:
        """Return the registry key (or folder prefix) of the dataset."""
        return self._filename

    def _resolve(self) -> Path:
        """
        Fetch the dataset (downloading on first use) and return its local path.

        :returns: Path of the fetched file for single-file datasets; for
            folder datasets, the parent directory of the first fetched
            member.
        :raises ValueError: If ``members`` was given but is empty, since
            there is then no file from which to derive the folder path.
        """
        if self._members is None:
            return Path(fetch(self._filename))
        if not self._members:
            raise ValueError(
                f"Dataset {self._filename!r} was declared with an empty "
                "'members' list; a folder dataset needs at least one "
                "member file to fetch."
            )
        # Every member is fetched, but only the first one's location is
        # used to derive the folder handed to the loader.
        # NOTE(review): this assumes all members end up in the same local
        # directory -- confirm against the registry's layout.
        paths = [fetch(f"{self._filename}{member}") for member in self._members]
        return Path(paths[0]).parent

    def graph(self) -> nx.Graph:
        """
        Fetch the dataset and return whatever the loader produces for it.

        :returns: The loader's result. Depending on the loader this may be a
            simple graph object, or a tuple of graph objects and their
            ground truth, or graph objects, features, and ground truth.
        """
        return self._loader(self._resolve())

    def graphpair(self) -> libam.GraphPair:
        """
        Fetch and load the dataset, then run the parser on the loaded data.

        :returns: The parser's result for the loaded data.
        """
        return self._parser(self._loader(self._resolve()))