Source code for libam.graph.graph_pair

from __future__ import annotations
from dataclasses import dataclass

from functools import cached_property

import networkx as nx
import numpy as np
import scipy.sparse as sps

from .._logging import logger
from libam.graph._generate import apply_noise, permute_graph
from libam.graph._similarities_preprocess import create_L, create_S


@dataclass
class GraphPair:
    """A matched source/target graph pair with lazily computed representations.

    Algorithms access whatever format they need (adjacency matrix, node-pair
    candidates, edge-pair adjacency) via properties; each representation is
    computed once on first access and cached until the pair is mutated by
    ``permute()``, ``shuffle_labels()`` or ``add_noise()``.

    Node features (``src_features``, ``tar_features``) are optional arrays of
    shape ``[n, d]``. When they are provided, algorithms that support node
    features can use them. The caller is responsible for ensuring feature rows
    remain consistent with node indices after any permutation or re-labeling
    performed outside this class.

    Attributes:
        src: Source graph.
        tar: Target graph.
        ground_truth: Optional ``(src_to_tar, tar_to_src)`` index arrays.
        src_features: Optional per-node features of ``src``.
        tar_features: Optional per-node features of ``tar``.
    """

    src: nx.Graph
    tar: nx.Graph
    ground_truth: tuple[np.ndarray, np.ndarray] | None = None
    src_features: np.ndarray | None = None  # shape [n_src, d]
    tar_features: np.ndarray | None = None  # shape [n_tar, d]

    @classmethod
    def from_graph(
        cls,
        G: nx.Graph,
        features: np.ndarray | None = None,
    ) -> GraphPair:
        """Create a graph pair from a single base graph with an identity ground truth.

        Both src and tar start as copies of G. Use .permute() and .add_noise()
        to apply transformations afterward.

        Args:
            G: The graph both sides of the pair are copied from.
            features: Optional features accompanying the graph's nodes. Each
                side of the pair receives its own copy.

        Returns:
            A new pair whose ground truth maps node ``i`` to node ``i``.
        """
        n = G.number_of_nodes()
        identity: tuple[np.ndarray, np.ndarray] = (np.arange(n), np.arange(n))
        # Independent copies so later in-place edits of one side cannot leak
        # into the other.
        src_feat = None if features is None else features.copy()
        tar_feat = None if features is None else features.copy()
        return cls(G.copy(), G.copy(), identity, src_feat, tar_feat)

    @classmethod
    def from_graphs(
        cls,
        src: nx.Graph,
        tar: nx.Graph,
        ground_truth: tuple[np.ndarray, np.ndarray] | None = None,
        src_features: np.ndarray | None = None,
        tar_features: np.ndarray | None = None,
    ) -> GraphPair:
        """Create a pair from two already-constructed graphs.

        The arguments are stored as given (no copies are made).
        """
        return cls(src, tar, ground_truth, src_features, tar_features)

    @staticmethod
    def _permuted_tar_features(
        tar_features: np.ndarray,
        previous_ground_truth: tuple[np.ndarray, np.ndarray] | None,
        tar_to_src: np.ndarray,
    ) -> np.ndarray:
        """Re-index tar feature rows for a freshly permuted target graph.

        Row ``j`` of the result is the feature row that belonged to the tar
        node previously matched with src node ``tar_to_src[j]``.

        Args:
            tar_features: Feature rows indexed by the *current* tar labels.
            previous_ground_truth: The ``(src_to_tar, tar_to_src)`` pair that
                was valid before the new permutation, or None.
            tar_to_src: New mapping from tar index to src index.

        Returns:
            A new array of feature rows indexed by the *new* tar labels.
        """
        new_tar_to_src = np.asarray(tar_to_src)
        if previous_ground_truth is not None:
            old_src_to_tar = np.asarray(previous_ground_truth[0])
            # Only compose when the previous mapping covers every src node;
            # otherwise fall through to direct indexing below.
            if old_src_to_tar.shape == new_tar_to_src.shape:
                return tar_features[old_src_to_tar[new_tar_to_src]]
        # No usable previous mapping: treat the current rows as src-ordered.
        return tar_features[new_tar_to_src]

    def permute(self) -> GraphPair:
        """Derive a synthetic target by randomly permuting src node labels.

        Intended use: synthetic benchmarks created via GraphPair.from_graph(G).
        Replaces self.tar with a permuted copy of self.src and sets
        self.ground_truth to the corresponding node mapping. Existing
        tar_features are re-indexed through the previous ground truth so they
        stay aligned even when permute() or shuffle_labels() was called before.

        Warning:
            Logs a warning if src and tar are not the same graph, since this
            operation discards tar. For pairs built from two distinct
            real-world graphs, use shuffle_labels() instead.

        Returns:
            self, to allow chaining.
        """
        if not nx.utils.graphs_equal(self.src, self.tar):
            logger.warning(
                "permute() replaces self.tar with a permuted copy of self.src, "
                "but src and tar appear to be different graphs. "
                "Did you mean shuffle_labels()?"
            )

        src_edges = np.array(self.src.edges())
        n = self.src.number_of_nodes()
        tar_edges, ground_truth = permute_graph(src_edges, n)

        self.tar = nx.Graph(tar_edges.tolist())
        # An edge list drops isolated nodes; restore the full node set.
        self.tar.add_nodes_from(range(n))

        if self.tar_features is not None:
            _, tar_to_src = ground_truth
            # Must run before self.ground_truth is overwritten: the old
            # mapping tells us which current row belongs to which src node.
            self.tar_features = self._permuted_tar_features(
                self.tar_features, self.ground_truth, tar_to_src
            )

        self.ground_truth = ground_truth
        self._invalidate_cache()
        return self

    def shuffle_labels(self) -> GraphPair:
        """Randomly relabel nodes in both src and tar to remove label-based shortcuts.

        Intended use: pairs built from two distinct real-world graphs via
        GraphPair.from_graphs(src, tar). Applies independent random
        permutations to each graph's node labels so algorithms cannot exploit
        coincidentally matching node labels. Any existing ground truth and
        node features are updated to remain valid.

        Warning:
            Logs a warning if src and tar are the same graph, since for
            synthetic benchmarks permute() is the appropriate operation.

        Returns:
            self, to allow chaining.
        """
        if nx.utils.graphs_equal(self.src, self.tar):
            logger.warning(
                "shuffle_labels() permutes both graphs independently, "
                "but src and tar appear to be the same graph. "
                "Did you mean permute()?"
            )

        n_src = self.src.number_of_nodes()
        n_tar = self.tar.number_of_nodes()
        # p_src[i] / p_tar[j] is the new label of old node i / j.
        p_src = np.random.permutation(n_src)
        p_tar = np.random.permutation(n_tar)

        self.src = nx.relabel_nodes(self.src, dict(enumerate(p_src.tolist())))
        self.tar = nx.relabel_nodes(self.tar, dict(enumerate(p_tar.tolist())))

        if self.ground_truth is not None:
            src_to_tar, tar_to_src = self.ground_truth
            # Old src i matched old tar src_to_tar[i]; express that match in
            # the new labels on both sides.
            new_src_to_tar = np.empty_like(src_to_tar)
            new_src_to_tar[p_src] = p_tar[src_to_tar]
            new_tar_to_src = np.empty_like(tar_to_src)
            new_tar_to_src[p_tar] = p_src[tar_to_src]
            self.ground_truth = (new_src_to_tar, new_tar_to_src)

        if self.src_features is not None:
            # Scatter: the row of old node i moves to row p_src[i].
            new_src_features = np.empty_like(self.src_features)
            new_src_features[p_src] = self.src_features
            self.src_features = new_src_features

        if self.tar_features is not None:
            new_tar_features = np.empty_like(self.tar_features)
            new_tar_features[p_tar] = self.tar_features
            self.tar_features = new_tar_features

        self._invalidate_cache()
        return self

    def add_noise(
        self,
        source_noise: float = 0.0,
        target_noise: float = 0.0,
        refill: bool = False,
    ) -> GraphPair:
        """Remove (and optionally refill) edges from src and/or tar.

        Both graphs are rebuilt from the noisy edge lists; each keeps its own
        node count, so isolated nodes are preserved and no extra nodes are
        added when src and tar differ in size.

        Args:
            source_noise: Fraction of source edges to remove.
            target_noise: Fraction of target edges to remove.
            refill: If True, removed edges are replaced with random ones.

        Returns:
            self, to allow chaining.
        """
        src_edges = np.array(self.src.edges())
        tar_edges = np.array(self.tar.edges())
        n_src = self.src.number_of_nodes()
        # Captured before self.tar is replaced so its node set can be restored.
        n_tar = self.tar.number_of_nodes()
        n_edges = src_edges.shape[0]

        # apply_noise receives the src node and edge counts, as before.
        src_edges, tar_edges = apply_noise(
            src_edges, tar_edges, n_src, n_edges, source_noise, target_noise, refill
        )

        self.src = nx.Graph(src_edges.tolist())
        self.tar = nx.Graph(tar_edges.tolist())
        # Edge lists drop isolated nodes; restore each graph's own node set.
        self.src.add_nodes_from(range(n_src))
        self.tar.add_nodes_from(range(n_tar))

        self._invalidate_cache()
        return self

    def _invalidate_cache(self) -> None:
        """Discard all cached representations so they are recomputed on next access."""
        # cached_property stores its value in the instance __dict__ under the
        # property name, so removing the key forces recomputation.
        for attr in ("src_adjacency", "tar_adjacency", "L", "S", "_L_coo"):
            self.__dict__.pop(attr, None)

    # --- Core representations ---

    @cached_property
    def src_adjacency(self) -> np.ndarray:
        """n x n adjacency matrix for src, rows/cols ordered 0..n-1."""
        n = self.src.number_of_nodes()
        return nx.to_numpy_array(self.src, nodelist=range(n))

    @cached_property
    def tar_adjacency(self) -> np.ndarray:
        """n x n adjacency matrix for tar, rows/cols ordered 0..n-1."""
        n = self.tar.number_of_nodes()
        return nx.to_numpy_array(self.tar, nodelist=range(n))

    # --- Preprocessing ---

    @cached_property
    def L(self) -> sps.csr_matrix:
        """Candidate match scores: L[i, j] = how likely src node i aligns to tar node j (degree-based)."""
        return create_L(self.src_adjacency, self.tar_adjacency)

    @cached_property
    def S(self) -> sps.csr_matrix:
        """Edge-pair adjacency over L's non-zeros: S[e1, e2] = 1 if edges e1 and e2 share a matched endpoint."""
        return create_S(
            sps.csr_matrix(self.src_adjacency),
            sps.csr_matrix(self.tar_adjacency),
            self.L,
        )

    @cached_property
    def _L_coo(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Row indices, column indices and values of L's non-zero entries."""
        li, lj, w = sps.find(self.L)
        return li, lj, w

    @property
    def li(self) -> np.ndarray:
        """Src node indices of L's non-zero entries."""
        return self._L_coo[0]

    @property
    def lj(self) -> np.ndarray:
        """Tar node indices of L's non-zero entries."""
        return self._L_coo[1]

    @property
    def w(self) -> np.ndarray:
        """Weights of L's non-zero entries."""
        return self._L_coo[2]