3-C. Machine Learning Models

3.3 Graph Neural Networks

Completed and Compiled Code (3.3.1–3.3.3): Click Here

A Graph Neural Network (GNN) is built for graph-structured data. Nodes carry features, edges carry relationships, and learning proceeds by passing messages along edges, updating node states, and then reading out a graph-level vector for a downstream task.

Structure. Layers alternate between (i) message passing—neighbors send transformed signals along edges—and (ii) state updates—each node revises its embedding with a nonlinear function. Stacking k layers exposes a node to its k-hop neighborhood. A final readout/pooling (e.g., sum/mean) aggregates node embeddings into a fixed-length molecular representation.
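
To make this concrete, here is a minimal sketch of one message-passing step and a sum readout in plain PyTorch. The toy tensors and the names (src, dst, W_msg, W_upd) are illustrative, not from any library:

import torch

# Toy graph: 4 nodes with 8-dim features, 5 directed edges
x = torch.randn(4, 8)
src = torch.tensor([0, 1, 2, 3, 0])   # message senders
dst = torch.tensor([1, 2, 3, 0, 2])   # message receivers

W_msg = torch.nn.Linear(8, 8)          # transforms outgoing signals
W_upd = torch.nn.Linear(16, 8)         # revises each node's state

# (i) message passing: neighbors send transformed signals along edges
messages = W_msg(x[src])                                # [5, 8], one per edge
agg = torch.zeros_like(x).index_add_(0, dst, messages)  # sum per receiver

# (ii) state update: combine old state with aggregated messages
x = torch.relu(W_upd(torch.cat([x, agg], dim=-1)))

# readout: sum-pool node states into one graph-level vector
graph_vec = x.sum(dim=0)
print(graph_vec.shape)  # torch.Size([8])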

Functioning. In molecular graphs, atoms = nodes and bonds = edges. Messages combine atom features with bond features, propagate to destination atoms, and are transformed (e.g., by MLPs or GRUs). After several layers, each atom’s embedding reflects its local chemical context (hybridization, aromaticity, ring membership, nearby heteroatoms).

Learning. Training minimizes a task loss (e.g., BCEWithLogits for classification) with gradient descent. Gradients flow through the graph topology, tuning how atoms attend to their neighbors and how the pooled representation supports the final prediction.

Roadmap.

3.3.1 looks at OGB-MOLHIV, the dataset and official split, and builds loaders with basic EDA.

3.3.2 implements an edge-aware MPNN (NNConv + GRU) and reads the curves/ROC.

3.3.3 compares GCN and MPNN on the same split and discusses the outcomes.


3.3.1 From Descriptors to Molecular Graphs: OGB-MOLHIV at a Glance

Why start here. Descriptor-only QSAR can flatten connectivity; a molecular GNN keeps the who-is-bonded-to-whom information. The OGB-MOLHIV benchmark (Hu et al., 2020) provides graph data (atoms, bonds) with an anti-HIV activity label and an official train/valid/test split—ideal for a clean, reproducible pipeline.

Key Functions and Concepts (Data Layer)

| Function / Object | What it is | Role |
|---|---|---|
| PygGraphPropPredDataset | OGB–PyG dataset wrapper | Auto-download + official split |
| DataLoader (PyG) | Mini-batches of graphs | Collates Data objects into DataBatch |
| DataBatch fields | x, edge_index, edge_attr, y | Atoms (9-dim), bonds (3-dim), labels |

3.3.1-A Load the dataset and the official split

What this cell does. Imports minimal packages, downloads/loads ogbg-molhiv, and fetches the official indices for train/valid/test.

# Minimal imports
import numpy as np, matplotlib.pyplot as plt, torch
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

# Load + split
dataset = PygGraphPropPredDataset(name="ogbg-molhiv")
split = dataset.get_idx_split()
train_set, valid_set, test_set = dataset[split["train"]], dataset[split["valid"]], dataset[split["test"]]

print(dataset)
print(f"Graphs: total={len(dataset)} | train/valid/test = {len(train_set)}/{len(valid_set)}/{len(test_set)}")

Results
PygGraphPropPredDataset(41127)
Graphs: total=41127 | train/valid/test = 32901/4113/4113

This confirms the canonical OGB split we use throughout.

3.3.1-B Quick EDA (label skew, nodes/edges per graph)

What this cell does. Builds three histograms: label distribution, atoms per molecule, and bonds per molecule—so we know the class balance and a reasonable message-passing depth.

labels = dataset.data.y.view(-1).cpu().numpy()
num_nodes = [g.num_nodes for g in dataset]
num_edges = [g.num_edges for g in dataset]

fig, axs = plt.subplots(1, 3, figsize=(12, 3.6), dpi=140)
axs[0].hist(labels, bins=[-0.5,0.5,1.5], rwidth=0.8, edgecolor="black"); axs[0].set_xticks([0,1])
axs[0].set_title("Label distribution"); axs[0].grid(alpha=0.3)
axs[1].hist(num_nodes, bins=40, edgecolor="black"); axs[1].set_title("Nodes per graph"); axs[1].grid(alpha=0.3)
axs[2].hist(num_edges, bins=40, edgecolor="black"); axs[2].set_title("Edges per graph"); axs[2].grid(alpha=0.3)
plt.tight_layout(); plt.show()

How to read it. Labels are heavily imbalanced (positives ≪ negatives), so ROC-AUC is a better main metric than accuracy, and class-weighted BCE is advisable. Graph sizes are mostly in the tens of atoms/bonds, so 2–4 message-passing layers are a reasonable starting point.
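
As a quick numerical check (reusing the labels array from the cell above; the printed values are whatever your download reports):

pos = int(labels.sum())
print(f"Positive rate: {labels.mean():.4f} ({pos} positives / {len(labels)} graphs)")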

Figure: molhiv_eda (label distribution, nodes per graph, edges per graph).

3.3.1-C DataLoaders + a batch sanity check

What this cell does. Creates loaders and prints one training batch to confirm feature shapes (we will reuse the loaders later).

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=256, shuffle=False)
test_loader  = DataLoader(test_set,  batch_size=256, shuffle=False)

batch = next(iter(train_loader))
print(batch)
print("num_graphs:", batch.num_graphs,
      "| node_feat_dim:", batch.x.size(-1),
      "| edge_feat_dim:", batch.edge_attr.size(-1),
      "| y shape:", tuple(batch.y.view(-1).shape))

Results
DataBatch(edge_index=[2, 3640], edge_attr=[3640, 3], x=[1711, 9], y=[64, 1], …)
num_graphs: 64 | node_feat_dim: 9 | edge_feat_dim: 3 | y shape: (64,)

This confirms the exact input dimensions we connect to the models in §3.3.2 and §3.3.3.


3.3.2 Message Passing as Chemical Reasoning (NNConv–GRU MPNN)

The Message Passing Neural Network (MPNN) (Gilmer et al., 2017) is a family of GNNs designed to learn directly from molecular graphs. Each atom updates its representation by aggregating messages from its neighboring atoms through bonds. In our implementation, bond features determine how messages are weighted, and a Gated Recurrent Unit (GRU) stabilizes multi-step updates.
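
In the notation of Gilmer et al. (2017), one message-passing step computes

$$m_v^{t+1} = \sum_{w \in \mathcal{N}(v)} M_t\left(h_v^t,\, h_w^t,\, e_{vw}\right), \qquad h_v^{t+1} = U_t\left(h_v^t,\, m_v^{t+1}\right),$$

where $M_t$ is the message function (here realized by NNConv, conditioned on the bond features $e_{vw}$), $U_t$ is the update function (here a GRU), and a readout $R(\{h_v^T\})$ pools the final atom states into a molecule-level vector.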

Goal. Build and train an edge-aware MPNN (using NNConv + GRU) for molecular property prediction on ogbg-molhiv, record training dynamics, and evaluate the ROC-AUC on the held-out test set.

Key Functions and Concepts (MPNN)

| Function / Concept | What it is | Role |
|---|---|---|
| NNConv | Edge-conditioned convolution | Uses bond features to parameterize message filters |
| EdgeNet | MLP for edge feature transformation | Maps edge attributes → filter matrices |
| GRU | Gated recurrent unit | Controls information flow across layers |
| global_add_pool | Readout layer | Aggregates atomic embeddings → molecule representation |

3.3.2-A Reproducibility, device, and class weight

This first cell ensures consistent runs and builds loaders for the model. Because the MOLHIV dataset is imbalanced, we compute a positive-class weight for use in BCEWithLogitsLoss.

import random, numpy as np, torch
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

# Fix randomness for reproducibility
def set_seed(s=1):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(1)

# Select device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset and official splits
dataset = PygGraphPropPredDataset(name="ogbg-molhiv")
split = dataset.get_idx_split()
train_loader = DataLoader(dataset[split["train"]], batch_size=64, shuffle=True)
valid_loader = DataLoader(dataset[split["valid"]], batch_size=256, shuffle=False)
test_loader  = DataLoader(dataset[split["test"]],  batch_size=256, shuffle=False)

# Compute class weight (for imbalance)
train_y = dataset.data.y.view(-1)[split["train"]].cpu().numpy()  # index the full label tensor (robust across PyG versions)
pos_weight = torch.tensor([(train_y==0).sum()/(train_y==1).sum()], dtype=torch.float, device=device)

D_x, D_e = dataset.num_node_features, dataset.num_edge_features

Explanation. This setup guarantees that experiments are repeatable. A pos_weight greater than 1 tells BCEWithLogitsLoss to penalize misclassified positive samples more heavily.
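
A toy calculation shows the effect (the weight of 30.0 below is illustrative, not the MOLHIV value):

import torch, torch.nn as nn

logit  = torch.tensor([0.0])   # an uncertain prediction (p = 0.5)
target = torch.tensor([1.0])   # a positive sample

plain    = nn.BCEWithLogitsLoss()
weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([30.0]))

print(plain(logit, target).item())     # ≈ 0.6931
print(weighted(logit, target).item())  # ≈ 20.79: this positive is penalized ~30x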

3.3.2-B Define the model: EdgeNet + NNConv + GRU

This section builds the network architecture. Each message-passing layer applies an edge-specific linear transformation followed by a GRU update and dropout for regularization.

import torch.nn as nn, torch.nn.functional as F
from torch_geometric.nn import NNConv, global_add_pool

class EdgeNet(nn.Module):
    """Small MLP that transforms edge features into NNConv filters."""
    def __init__(self, edge_in, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(edge_in, hidden*hidden),
            nn.ReLU(),
            nn.Linear(hidden*hidden, hidden*hidden)
        )
        self.hidden = hidden

    def forward(self, e):  # e: [E, edge_in]
        return self.net(e)  # Output: [E, hidden*hidden]


class MPNN(nn.Module):
    """Message Passing Neural Network with edge-conditioned messages."""
    def __init__(self, node_in, edge_in, hidden=128, layers=3, dropout=0.2):
        super().__init__()
        self.embed = nn.Linear(node_in, hidden)
        self.edge_net = EdgeNet(edge_in, hidden)
        self.convs = nn.ModuleList([
            NNConv(hidden, hidden, self.edge_net, aggr='add') for _ in range(layers)
        ])
        self.gru = nn.GRU(hidden, hidden)
        self.dropout = dropout
        self.readout = nn.Linear(hidden, 1)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x.float(), data.edge_index, data.edge_attr, data.batch
        # Safety guard: ensure bond features are floats
        if edge_attr is None:
            edge_attr = torch.zeros((edge_index.size(1), D_e), dtype=torch.float, device=x.device)
        else:
            edge_attr = edge_attr.float()

        # 1) Embed atom features
        x = self.embed(x)
        h = x.unsqueeze(0)

        # 2) Perform message passing and gated update
        for conv in self.convs:
            m = F.relu(conv(x, edge_index, edge_attr))
            m = F.dropout(m, p=self.dropout, training=self.training).unsqueeze(0)
            out, h = self.gru(m, h)
            x = out.squeeze(0)

        # 3) Aggregate to graph-level representation
        g = global_add_pool(x, batch)

        # 4) Linear readout → logit
        return self.readout(g).view(-1)

Explanation.

  • EdgeNet maps 3-dimensional bond descriptors to per-edge filter matrices (a single edge network shared across all NNConv layers).
  • NNConv uses those filters to compute neighbor messages.
  • GRU decides how much new information to incorporate versus retain from previous states.
  • global_add_pool sums over all atoms to yield a molecule-level vector.
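
As a quick sanity check (a sketch assuming the MPNN class from the cell above is defined; the toy graph below is made up), one molecule can be pushed through a tiny instance to confirm that a single logit per molecule comes out:

import torch
from torch_geometric.data import Data

# Toy "molecule": 5 atoms (9-dim features), 4 bonds (3-dim features)
toy = Data(
    x=torch.randn(5, 9),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    edge_attr=torch.randn(4, 3),
)
toy.batch = torch.zeros(5, dtype=torch.long)  # all atoms belong to graph 0

tiny = MPNN(node_in=9, edge_in=3, hidden=16, layers=2)
print(tiny(toy).shape)  # torch.Size([1]): one logit per molecule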

3.3.2-C Training, validation (AUC), and testing (ROC)

The training loop minimizes BCEWithLogitsLoss, monitors AUC on the validation split, and restores the best weights before evaluating on the test set.

from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt, numpy as np

def train_epoch(model, loader, opt, crit, clip=2.0):
    """Single epoch of training."""
    model.train(); total, n = 0.0, 0
    for data in loader:
        data = data.to(device)
        y = data.y.view(-1).float()
        opt.zero_grad()
        logits = model(data)
        loss = crit(logits, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
        total += loss.item() * y.numel(); n += y.numel()
    return total / n


@torch.no_grad()
def eval_auc(model, loader):
    """Compute ROC-AUC on a loader."""
    model.eval(); y_true, y_prob = [], []
    for data in loader:
        data = data.to(device)
        prob = torch.sigmoid(model(data)).cpu().numpy()
        y_true.append(data.y.view(-1).cpu().numpy())
        y_prob.append(prob)
    y_true = np.concatenate(y_true); y_prob = np.concatenate(y_prob)
    auc = roc_auc_score(y_true, y_prob)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    return auc, (fpr, tpr)


# Initialize model, optimizer, and loss
model = MPNN(D_x, D_e, hidden=128, layers=3, dropout=0.2).to(device)
opt   = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
crit  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

best_val, best_state = -1.0, None
train_curve, val_curve = [], []

# Train for 20 epochs
for ep in range(1, 21):
    tr = train_epoch(model, train_loader, opt, crit)
    va, _ = eval_auc(model, valid_loader)
    train_curve.append(tr); val_curve.append(va)
    if va > best_val:
        best_val = va
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
    print(f"Epoch {ep:02d} | train {tr:.4f} | valid AUC {va:.4f}")

# Restore best checkpoint and test
model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
test_auc, (fpr, tpr) = eval_auc(model, test_loader)
print(f"[MPNN] Best valid AUC = {best_val:.4f} | Test AUC = {test_auc:.4f}")

# Plot results
fig, axs = plt.subplots(1, 2, figsize=(11, 4), dpi=150)
axs[0].plot(train_curve, label="Train loss")
axs[0].plot(val_curve, label="Valid AUC")
axs[0].set_title("Training progress")
axs[0].set_xlabel("Epoch")
axs[0].grid(alpha=0.3)
axs[0].legend()
axs[1].plot(fpr, tpr, lw=2, label=f"MPNN (AUC={test_auc:.3f})")
axs[1].plot([0,1],[0,1],"--",color="gray")
axs[1].set_xlabel("False Positive Rate")
axs[1].set_ylabel("True Positive Rate")
axs[1].set_title("ROC on test")
axs[1].grid(alpha=0.3)
axs[1].legend()
plt.tight_layout(); plt.show()

Training Log (excerpt)
Epoch 17 … valid AUC 0.7089 → best
Test AUC = 0.7268

Interpretation.

  • Training loss decreases smoothly to ≈ 0.18.
  • Validation AUC fluctuates between 0.35 and 0.70 due to class imbalance and small-batch effects, peaking around 0.71.
  • The best test AUC ≈ 0.73 shows that the model captures useful substructural information even in a compact configuration.
  • The oscillations are typical for imbalanced molecular classification tasks; early stopping at the validation peak prevents overfitting.

Figure: mpnn_curves_roc (training curves, left; test ROC, right).

References (for §3.3–§3.3.2)

  • Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric.
  • Hu, W., et al. (2020). Open Graph Benchmark: Datasets for Machine Learning on Graphs.
  • Wu, Z., et al. (2018). MoleculeNet: A Benchmark for Molecular Machine Learning.

3.3.3 Same Split, Two Architectures: GCN vs Edge-aware MPNN

We now place two graph architectures side by side on the same OGB-MOLHIV split and with the same evaluation metric (ROC-AUC):

  • GCN (Kipf & Welling, 2017): neighborhood aggregation without an edge MLP; all bonds contribute uniformly.
  • MPNN (NNConv + GRU): messages are edge-conditioned by bond features; a GRU stabilizes multi-step updates.

The point is to contrast how the two formulations behave under an identical training recipe, not to assume one is inherently superior.

Key Functions and Concepts (Comparison)

| Function / Concept | What it is | Role |
|---|---|---|
| GCNConv | Graph convolution (no bond MLP) | Uniform treatment of bonds during aggregation |
| NNConv | Edge-conditioned convolution | Bond features → per-edge filters |
| GRU | Gated update across layers | Stabilizes multi-step message passing |
| ROC-AUC | Ranking under class imbalance | Threshold-free comparison |

Note. This subsection reuses the loaders and pos_weight from §3.3.2 (3.3.2-A). If you run cells independently, execute §3.3.1 and §3.3.2-A first.

3.3.3-A Define the two architectures succinctly

What this cell does. Implements a compact GCN baseline and the NNConv + GRU MPNN; both end with global_add_pool and a linear head that outputs a logit per molecule.

import torch.nn as nn, torch.nn.functional as F
from torch_geometric.nn import GCNConv, NNConv, global_add_pool

class GCN_Baseline(nn.Module):
    def __init__(self, in_dim, hidden=128, layers=3, dropout=0.2):
        super().__init__()
        self.embed = nn.Linear(in_dim, hidden)
        self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(layers)])
        self.drop  = dropout
        self.out   = nn.Linear(hidden, 1)
    def forward(self, data):
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch
        x = self.embed(x)
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.drop, training=self.training)
        g = global_add_pool(x, batch)
        return self.out(g).view(-1)

class EdgeNet(nn.Module):
    def __init__(self, edge_in, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(edge_in, hidden*hidden), nn.ReLU(),
            nn.Linear(hidden*hidden, hidden*hidden)
        ); self.hidden = hidden
    def forward(self, e): return self.net(e)

class MPNN_NNConv(nn.Module):
    def __init__(self, node_in, edge_in, hidden=128, layers=3, dropout=0.2):
        super().__init__()
        self.embed   = nn.Linear(node_in, hidden)
        self.edge_net= EdgeNet(edge_in, hidden)
        self.convs   = nn.ModuleList([NNConv(hidden, hidden, self.edge_net, aggr='add') for _ in range(layers)])
        self.gru     = nn.GRU(hidden, hidden)
        self.drop    = dropout
        self.out     = nn.Linear(hidden, 1)
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x.float(), data.edge_index, data.edge_attr, data.batch
        edge_attr = edge_attr.float() if edge_attr is not None else torch.zeros((edge_index.size(1), D_e), device=x.device)
        x = self.embed(x); h = x.unsqueeze(0)
        for conv in self.convs:
            m = F.relu(conv(x, edge_index, edge_attr))
            m = F.dropout(m, p=self.drop, training=self.training).unsqueeze(0)
            out, h = self.gru(m, h); x = out.squeeze(0)
        g = global_add_pool(x, batch)
        return self.out(g).view(-1)

3.3.3-B Shared training/evaluation routine

What this cell does. Trains either model with the same loop and hyperparameters, selects the best validation AUC checkpoint, and reports test AUC and ROC.

from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np, torch

def train_and_select(model, tag, train_loader, valid_loader, test_loader, pos_weight=None, epochs=15, lr=1e-3):
    device = next(model.parameters()).device
    opt  = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    crit = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    best_val, best = -1.0, None

    for ep in range(1, epochs+1):
        # Train one epoch
        model.train(); tot, n = 0.0, 0
        for data in train_loader:
            data = data.to(device); y = data.y.view(-1).float()
            opt.zero_grad(); loss = crit(model(data), y); loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()
            tot += loss.item() * y.numel(); n += y.numel()

        # Validate
        model.eval(); y_true, y_prob = [], []
        with torch.no_grad():
            for data in valid_loader:
                data = data.to(device)
                p = torch.sigmoid(model(data)).cpu().numpy()
                y_true.append(data.y.view(-1).cpu().numpy()); y_prob.append(p)
        y_true = np.concatenate(y_true); y_prob = np.concatenate(y_prob)
        val_auc = roc_auc_score(y_true, y_prob)
        print(f"[{tag}] epoch {ep:02d} | train {tot/n:.4f} | valid AUC {val_auc:.4f}")

        if val_auc > best_val:
            best_val = val_auc
            best = {k: v.cpu() for k, v in model.state_dict().items()}

    # Test with best checkpoint
    model.load_state_dict({k: v.to(device) for k, v in best.items()})
    model.eval(); y_true, y_prob = [], []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            p = torch.sigmoid(model(data)).cpu().numpy()
            y_true.append(data.y.view(-1).cpu().numpy()); y_prob.append(p)
    y_true = np.concatenate(y_true); y_prob = np.concatenate(y_prob)
    test_auc = roc_auc_score(y_true, y_prob)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    print(f"[{tag}] BEST valid AUC {best_val:.4f} | TEST AUC {test_auc:.4f}")
    return test_auc, (fpr, tpr)

3.3.3-C Run both models and visualize

What this cell does. Instantiates GCN and MPNN, runs the shared routine for 15 epochs, and produces ROC curves plus a bar chart of the two AUCs.

import matplotlib.pyplot as plt

gcn  = GCN_Baseline(D_x).to(device)
mpnn = MPNN_NNConv(D_x, D_e).to(device)

gcn_auc, (gcn_fpr, gcn_tpr)   = train_and_select(gcn,  "GCN",  train_loader, valid_loader, test_loader, pos_weight, epochs=15)
mpn_auc, (mpn_fpr, mpn_tpr)   = train_and_select(mpnn, "MPNN", train_loader, valid_loader, test_loader, pos_weight, epochs=15)

# Plot ROC and AUC bars
fig, ax = plt.subplots(1, 2, figsize=(11, 4), dpi=150)
ax[0].plot(gcn_fpr, gcn_tpr, lw=2, label=f"GCN (AUC={gcn_auc:.3f})")
ax[0].plot(mpn_fpr, mpn_tpr, lw=2, label=f"MPNN/NNConv (AUC={mpn_auc:.3f})")
ax[0].plot([0,1],[0,1],"--",color="gray"); ax[0].set_xlabel("FPR"); ax[0].set_ylabel("TPR")
ax[0].set_title("ROC on test"); ax[0].grid(alpha=0.3); ax[0].legend()

ax[1].bar(["GCN","MPNN"], [gcn_auc, mpn_auc], edgecolor="black", alpha=0.9)
for x, v in zip(["GCN","MPNN"], [gcn_auc, mpn_auc]):
    ax[1].text(x, v+0.01, f"{v:.3f}", ha="center", va="bottom")
ax[1].set_ylim(0, 1.05); ax[1].set_ylabel("AUC"); ax[1].set_title("AUC comparison")
ax[1].grid(axis="y", alpha=0.3)
plt.tight_layout(); plt.show()

Recorded results
GCN: [GCN] … BEST valid AUC 0.7212 | TEST AUC 0.6640
MPNN (quick comparison recipe): [MPNN] … BEST valid AUC 0.5505 | TEST AUC 0.5894

Figure: gcn_mpn (test ROC curves, left; AUC comparison, right).

How to read these outcomes (with the logs)

  • Under the shared 15-epoch recipe, the GCN run landed at test AUC ≈ 0.664, while the quick MPNN comparison run landed at ≈ 0.589.
  • In §3.3.2, a separate MPNN run—same split but with early-stopping at the sharp validation peak and 20 epochs—reached test AUC ≈ 0.727.
  • Together, these runs show how training choices (epoch budget, early-stopping point, class weighting, LR/scheduler) influence the two architectures differently. The comparison here is useful to see relative behavior under the same budget; §3.3.2 shows what the edge-aware model can achieve with a slightly longer/steadier training loop.

Optional dials to try (one at a time): epochs 30–40 for the MPNN; LR 5e-4 for the MPNN; hidden 256 (layers=3); ReduceLROnPlateau on validation AUC (see the sketch below); keep weight_decay=1e-5 and dropout=0.2.
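
For example, the scheduler dial could be wired into the loop of 3.3.3-B as follows (a sketch; the factor and patience values are arbitrary starting points):

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="max", factor=0.5, patience=3)  # mode="max": we track AUC, not loss

# inside the epoch loop, right after val_auc is computed:
scheduler.step(val_auc)  # halves the LR after 3 epochs without improvement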

References (for §3.3.3)

  • Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks (GCN).
  • Gilmer, J., et al. (2017). Neural Message Passing for Quantum Chemistry (MPNN).
  • Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (PyG).
  • Hu, W., et al. (2020). Open Graph Benchmark (OGB).


3.3.4 Challenges and Interpretability in GNNs

Completed and Compiled Code: Click Here

What We're Exploring: Fundamental Challenges in Graph Neural Networks

Why Study GNN Challenges?

  • Over-smoothing: Why deeper isn't always better - node features become indistinguishable
  • Interpretability: Understanding what the model learns - crucial for drug discovery
  • Real Impact: These challenges affect whether GNNs can be trusted in production

What you'll learn: The fundamental limitations of GNNs and current solutions to overcome them

| Challenge | What Happens | Why It Matters | Solutions |
|---|---|---|---|
| Over-smoothing | Node features converge; all atoms look the same | Limits network depth; can't capture long-range interactions | Residual connections; skip connections, normalization |
| Interpretability | Black-box predictions; don't know why it predicts | No trust in predictions; can't guide drug design | Attention visualization; substructure explanations |

While GNNs have shown remarkable success in molecular property prediction, they face several fundamental challenges that limit their practical deployment. In this section, we’ll explore two critical issues: the over-smoothing phenomenon that limits network depth, and the interpretability challenge that makes it difficult to understand model predictions.

The Power of Depth vs. The Curse of Over-smoothing

In Graph Neural Networks (GNNs), adding more message-passing layers allows nodes (atoms) to gather information from increasingly distant parts of a graph (molecule). At first glance, it seems deeper networks should always perform better—after all, more layers mean more context. But in practice, there’s a major trade-off known as over-smoothing.

Understanding Over-smoothing

| Concept | Simple Explanation | Molecular Context |
|---|---|---|
| Message Passing | Atoms share info with neighbors | Like atoms "talking" through bonds |
| Receptive Field | How far information travels | k layers = k-hop neighborhood |
| Over-smoothing | All nodes become similar | Can't distinguish different atoms |
| Critical Depth | ~3-5 layers typically | Beyond this, performance drops |

What to Demonstrate

Before we jump into the code, here’s what it’s trying to show:

We want to measure how similar node embeddings become as we increase the number of GCN layers. If all node vectors become nearly identical after several layers, that means the model is losing resolution—different atoms can’t be distinguished anymore. This is called over-smoothing.

Key Functions and Concepts

| Function | What it is | Role |
|---|---|---|
| GCNConv | Graph convolution layer | Aggregates neighbor features |
| F.relu() | Non-linear activation | Adds expressiveness |
| F.normalize() | L2 normalization | For cosine similarity |
| torch.mm() | Matrix multiplication | Computes similarity matrix |

Functions and Concepts Used

  • GCNConv (from torch_geometric.nn): This is a standard Graph Convolutional Network (GCN) layer. It performs message passing by aggregating neighbor features and updating node embeddings. It normalizes messages by node degrees to prevent high-degree nodes from dominating.

  • F.relu(): Applies a non-linear ReLU activation function after each GCN layer. This introduces non-linearity to the model, allowing it to learn more complex patterns.

  • F.normalize(..., p=2, dim=1): This normalizes node embeddings to unit length (L2 norm), which is required for cosine similarity calculation.

  • torch.mm(): Matrix multiplication is used here to compute the full cosine similarity matrix between normalized node embeddings.

  • Cosine similarity: Measures how aligned two vectors are (value close to 1 means very similar). By averaging all pairwise cosine similarities, we can track whether the node representations are collapsing into the same vector.
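
For two node embeddings $u$ and $v$, cosine similarity is

$$\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert_2 \, \lVert v \rVert_2}.$$

After L2 normalization every row of x_norm has unit length, so the dot products computed by torch.mm(x_norm, x_norm.t()) are exactly these cosines, one per node pair.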

Graph Construction

We use a 6-node ring structure as a simple molecular graph. Each node starts with a unique identity (using identity matrix torch.eye(6) as input features), and all nodes are connected in a cycle:

Graph Construction Process

  • Step 1: Create node features (identity matrix, 6×6)
  • Step 2: Define the ring topology (each node has 2 neighbors)
  • Step 3: Make edges bidirectional (12 directed edges total)
  • Result: a PyG Data object, ready for the GNN

import torch
from torch_geometric.data import Data

# Each node has a unique 6D feature vector (identity matrix)
x = torch.eye(6)

# Define edges for a 6-node cycle (each edge is bidirectional)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5, 0, 5, 0, 1, 2, 3, 4]
], dtype=torch.long)

# Create PyTorch Geometric graph object
data = Data(x=x, edge_index=edge_index)

Over-smoothing Analysis

Now we apply the same GCN layer multiple times to simulate a deeper GNN. After each layer, we re-compute the node embeddings and compare them using cosine similarity:

Over-smoothing Measurement Process

  • Apply GCN layers: stack 1–10 message-passing steps, instantiating a fresh, randomly initialized GCNConv at each step
  • Compute similarity: cosine similarity between all node pairs, averaged
  • Track convergence: plot average similarity vs. depth and watch it approach 1

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv

def measure_smoothing(num_layers, data):
    """
    Apply num_layers GCNConv layers and measure
    how similar node embeddings become.
    """
    x = data.x
    for _ in range(num_layers):
        conv = GCNConv(x.size(1), x.size(1))
        x = F.relu(conv(x, data.edge_index))

    # Normalize embeddings for cosine similarity
    x_norm = F.normalize(x, p=2, dim=1)
    
    # Cosine similarity matrix
    similarity_matrix = torch.mm(x_norm, x_norm.t())
    
    # Exclude diagonal (self-similarity) when averaging
    n = x.size(0)
    mask = ~torch.eye(n, dtype=torch.bool)
    avg_similarity = similarity_matrix[mask].mean().item()
    
    return avg_similarity

# Run for different GNN depths
depths = [1, 3, 5, 10]
sims = []
for depth in depths:
    sim = measure_smoothing(depth, data)
    sims.append(sim)
    print(f"Depth {depth}: Average similarity = {sim:.3f}")

# Plot the smoothing effect
plt.plot(depths, sims, marker='o')
plt.xlabel("Number of GCN Layers")
plt.ylabel("Average Cosine Similarity")
plt.title("Over-smoothing Effect in GNNs")
plt.grid(True)
plt.show()

Output

Depth 1: Average similarity = 0.406
Depth 3: Average similarity = 0.995
Depth 5: Average similarity = 0.993
Depth 10: Average similarity = 1.000

Figure: Over-smoothing in GNNs (average cosine similarity vs. number of GCN layers).

Interpretation of Results

| Depth | Similarity | What It Means | Practical Impact |
|---|---|---|---|
| 1 layer | 0.406 | Nodes still distinct | Can identify different atoms |
| 3 layers | 0.995 | Nearly identical | Losing atomic identity |
| 5 layers | 0.993 | Effectively the same | No useful information |
| 10 layers | 1.000 | Complete collapse | Model is useless |

As shown above, node representations converge as the number of message-passing layers increases: initially distinct feature vectors become nearly indistinguishable after several layers, and structural information is lost. This phenomenon is known as over-smoothing and is a critical limitation of deep GNNs.

Interpretation

As we can see, even at just 3 layers, the node embeddings become nearly identical. By 10 layers, the model has effectively lost all ability to distinguish individual atoms. This is the core issue of over-smoothing—deep GNNs can blur out meaningful structural differences.

Solutions to Over-smoothing

| Technique | How It Works | Implementation | Effectiveness |
|---|---|---|---|
| Residual Connections | Skip connections preserve original features | x = x + GCN(x) | Very effective |
| Feature Concatenation | Combine features from multiple layers | concat(x₁, x₂, ...) | Good for shallow nets |
| Batch Normalization | Normalize features per layer | BatchNorm after GCN | Moderate help |
| Jumping Knowledge | Aggregate all layer outputs | JK networks | State-of-the-art |

To mitigate this problem, modern GNNs use techniques like:

  • Residual connections (skip connections that reintroduce raw input)
  • Feature concatenation from earlier layers
  • Batch normalization or graph normalization
  • Jumping knowledge networks to combine representations from multiple layers

When working with molecular graphs, you should choose the depth of your GNN carefully. It should be deep enough to capture important substructures, but not so deep that you lose atomic-level details.
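
As an illustration of the first technique, the depth probe from above needs only a one-line change to add residual connections. This is a sketch reusing the imports, data object, and masking logic from the previous cells; exact similarity values will vary with the random initialization:

def measure_smoothing_residual(num_layers, data):
    """Same probe as measure_smoothing, but with skip connections."""
    x = data.x
    for _ in range(num_layers):
        conv = GCNConv(x.size(1), x.size(1))
        x = x + F.relu(conv(x, data.edge_index))  # residual: x + GCN(x)
    x_norm = F.normalize(x, p=2, dim=1)
    sim = torch.mm(x_norm, x_norm.t())
    mask = ~torch.eye(x.size(0), dtype=torch.bool)
    return sim[mask].mean().item()

for depth in [1, 3, 5, 10]:
    print(f"Depth {depth}: similarity = {measure_smoothing_residual(depth, data):.3f}")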

Interpretability in Molecular GNNs

Beyond the technical challenge of over-smoothing, GNNs face a critical issue of interpretability. When a model predicts that a molecule might be toxic or have specific properties, chemists need to understand which structural features drive that prediction. This “black box” nature of neural networks is particularly problematic in chemistry, where understanding structure-activity relationships is fundamental to rational drug design.

Why Interpretability Matters in Chemistry

| Stakeholder | Need | Example | Impact |
|---|---|---|---|
| Medicinal Chemists | Understand structure-activity relationships (SAR) | Which groups increase potency? | Guide drug optimization |
| Regulatory Bodies | Safety justification | Explain toxicity predictions | FDA approval |
| Researchers | Scientific insight, new mechanisms | Discover new pharmacophores | Advance knowledge |
| Industry | Risk assessment, confidence in predictions | Why invest in this molecule? | Resource allocation |

Recent advances in GNN interpretability for molecular applications have taken several promising directions:

Attention-Based Methods:

Attention-Based Interpretability

  • Method: Graph Attention Networks (GATs)
  • How it works: learn an importance weight α_ij for each edge
  • Visualization: highlight important bonds (thicker = more important)
  • Reference: Veličković et al., 2017 (ICLR)

Graph Attention Networks (GATs) provide built-in interpretability through their attention mechanisms, allowing researchers to visualize which atoms or bonds the model considers most important for a given prediction [1,2]. This approach naturally aligns with chemical intuition about reactive sites and functional groups.
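
A minimal sketch of extracting those attention weights with PyG's GATConv, reusing the 6-node ring graph data from the over-smoothing demo above (the layer is untrained here, so this shows the mechanics only, not meaningful chemistry):

from torch_geometric.nn import GATConv

gat = GATConv(in_channels=6, out_channels=8, heads=1)
out, (att_edge_index, alpha) = gat(
    data.x, data.edge_index, return_attention_weights=True)

# alpha[i] is the learned importance of edge i (self-loops are added by default)
for (s, d), a in zip(att_edge_index.t().tolist(), alpha.view(-1).tolist()):
    print(f"atom {s} -> atom {d}: attention = {a:.3f}")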

Substructure-Based Explanations:

Substructure Mask Explanation (SME)

  • Innovation: fragment-based explanations, not just individual atoms
  • Alignment: chemical intuition and functional groups
  • Application: toxicophore detection (finding toxic substructures)
  • Reference: Nature Communications, 2023, 14, 2585

The Substructure Mask Explanation (SME) method represents a significant advance by providing interpretations based on chemically meaningful molecular fragments rather than individual atoms or edges [3]. This approach uses established molecular segmentation methods to ensure explanations align with chemists’ understanding, making it particularly valuable for identifying pharmacophores and toxicophores.
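
For intuition, here is a simplified, occlusion-style sketch in the spirit of SME rather than the published algorithm: mask the atom features of a candidate fragment and record how much a trained model's prediction moves. It assumes a trained graph classifier such as the MPNN from §3.3.2 and a single-molecule batch mol (e.g., from a DataLoader with batch_size=1); fragment_importance is a hypothetical helper name:

import torch

@torch.no_grad()
def fragment_importance(model, mol, fragment_atoms):
    """Prediction shift when a fragment's atom features are masked out."""
    model.eval()
    base = torch.sigmoid(model(mol)).item()

    masked = mol.clone()
    masked.x = masked.x.clone()
    masked.x[fragment_atoms] = 0.0        # crude mask: erase the fragment's atom features
    shifted = torch.sigmoid(model(masked)).item()

    return base - shifted  # > 0: the fragment pushed the prediction up

# e.g., score the first three atoms as one candidate fragment:
# print(fragment_importance(model, mol, [0, 1, 2]))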

Integration of Chemical Knowledge:

Pharmacophore-Integrated GNNs

  • Concept: hierarchical modeling of multi-level molecular structure
  • Benefit 1: better performance (domain knowledge helps)
  • Benefit 2: natural interpretability at the pharmacophore level
  • Reference: Journal of Cheminformatics, 2022 [4]

Recent work has shown that incorporating pharmacophore information hierarchically into GNN architectures not only improves prediction performance but also enhances interpretability by explicitly modeling chemically meaningful substructures [4]. This bridges the gap between data-driven learning and domain expertise.

Gradient-Based Attribution:

SHAP for Molecular GNNs

  • Method: SHapley Additive exPlanations (game-theory based)
  • Advantage: rigorous foundation in additive feature attribution
  • Output: feature importance per atom/bond
  • Reference: Lundberg & Lee, 2017 (NeurIPS)

Methods like SHAP (SHapley Additive exPlanations) have been successfully applied to molecular property prediction, providing feature importance scores that help identify which molecular characteristics most influence predictions [5,6]. These approaches are particularly useful for understanding global model behavior across different molecular classes.

Comparative Studies:

GNNs vs Traditional Methods

| Aspect | GNNs | Descriptor-based | Recommendation |
|---|---|---|---|
| Performance | Often superior; learns complex patterns | Good baseline; well-understood | Task-dependent |
| Interpretability | Challenging; requires extra work | Built-in; known features | Hybrid approach |

Reference: Jiang et al., 2021, Journal of Cheminformatics

Recent comparative studies have shown that while GNNs excel at learning complex patterns, traditional descriptor-based models often provide better interpretability through established chemical features, suggesting a potential hybrid approach combining both paradigms [6].

The Future: Interpretable-by-Design

The field is moving toward interpretable-by-design architectures rather than post-hoc explanation methods. As noted by researchers, some medicinal chemists value interpretability over raw accuracy if a small sacrifice in performance can significantly enhance understanding of the model's reasoning [3]. This reflects a broader trend in molecular AI toward building systems that augment rather than replace human chemical intuition.

| Design Principle | Implementation | Example |
|---|---|---|
| Chemical hierarchy | Multi-scale architectures | Atom → Group → Molecule |
| Explicit substructures | Pharmacophore encoding | H-bond donors as nodes |
| Modular predictions | Separate property modules | Solubility + Toxicity branches |

Summary

Key Takeaways: Challenges and Solutions

| Challenge | Impact | Current Solutions | Future Directions |
|---|---|---|---|
| Over-smoothing | Limits depth to ~3-5 layers; can't capture long-range interactions | Residual connections, jumping knowledge, normalization | Novel architectures beyond message passing |
| Interpretability | Low trust and adoption; can't guide design | Attention visualization, SHAP values, substructure masking | Interpretable-by-design with chemical hierarchy |

The Path Forward:

  • Balance accuracy with interpretability - Sometimes 90% accuracy with clear explanations beats 95% black box
  • Incorporate domain knowledge - Chemical principles should guide architecture design
  • Develop hybrid approaches - Combine GNN power with traditional descriptor interpretability
  • Focus on augmenting chemists - Tools should enhance, not replace, human expertise

The challenges facing molecular GNNs—over-smoothing and interpretability—are significant but surmountable. Over-smoothing limits the depth of networks we can effectively use, constraining the model’s ability to capture long-range molecular interactions. Meanwhile, the interpretability challenge affects trust and adoption in real-world applications where understanding model decisions is crucial.

Current solutions include architectural innovations like residual connections to combat over-smoothing, and various interpretability methods ranging from attention visualization to substructure-based explanations. The key insight is that effective molecular AI systems must balance predictive power with chemical interpretability, ensuring that models not only make accurate predictions but also provide insights that align with and enhance human understanding of chemistry.

As the field progresses, the focus is shifting from purely accuracy-driven models to systems that provide transparent, chemically meaningful explanations for their predictions. This evolution is essential for GNNs to fulfill their promise as tools for accelerating molecular discovery and understanding.

References (3.3.4)
  • Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., & Bengio, Y. (2017). Graph Attention Networks. International Conference on Learning Representations.
  • Yuan, H., Yu, H., Gui, S., & Ji, S. (2022). Explainability in graph neural networks: A taxonomic survey. IEEE Transactions on Pattern Analysis and Machine Intelligence.
  • Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking. (2023). Nature Communications, 14, 2585.
  • Integrating concept of pharmacophore with graph neural networks for chemical property prediction and interpretation. (2022). Journal of Cheminformatics, 14, 52.
  • Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. Advances in Neural Information Processing Systems, 30, 4765-4774.
  • Jiang, D., Wu, Z., Hsieh, C. Y., Chen, G., Liao, B., Wang, Z., … & Hou, T. (2021). Could graph neural networks learn better molecular representation for drug discovery? A comparison study of descriptor-based and graph-based models. Journal of Cheminformatics, 13(1), 1-23.

Section 3.3 – Quiz Questions

1) Factual Questions

Question 1

What is the primary advantage of using Graph Neural Networks (GNNs) over traditional neural networks for molecular property prediction?

A. GNNs require less computational resources
B. GNNs can directly process the graph structure of molecules
C. GNNs always achieve higher accuracy than other methods
D. GNNs work only with small molecules

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: GNNs can directly process molecules as graphs where atoms are nodes and bonds are edges, preserving the structural information that is crucial for determining molecular properties.

Question 2

In the message passing mechanism of GNNs, what happens during the aggregation step?

A. Node features are updated using a neural network
B. Messages from neighboring nodes are combined
C. Edge features are initialized
D. The final molecular prediction is made

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: During aggregation, all incoming messages from neighboring nodes are combined (typically by summing or averaging) to form a single aggregated message for each node.
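▶ Click to see code: Sum aggregation in plain PyTorch (a toy illustration)

# Toy illustration of the aggregation step: messages arriving along
# 4 edges are summed per receiving node of a 3-node graph
import torch

messages = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
dst = torch.tensor([0, 0, 1, 2])  # receiving node of each edge

agg = torch.zeros(3, 1).index_add_(0, dst, messages)
print(agg.view(-1))  # tensor([3., 3., 4.]): node 0 summed two messages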

Question 3

Which of the following molecular representations is most suitable as input for a Graph Neural Network?

A. SMILES string directly as text
B. 2D image of the molecular structure
C. Graph with nodes as atoms and edges as bonds
D. List of molecular descriptors only

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: GNNs are designed to work with graph-structured data where nodes represent atoms and edges represent chemical bonds, allowing the model to learn from the molecular connectivity.

Question 4

What is the “over-smoothing” problem in Graph Neural Networks?

A. The model becomes too complex to train
B. Node representations become increasingly similar in deeper networks
C. The model cannot handle large molecules
D. Training takes too much time

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: Over-smoothing occurs when deep GNNs make node representations increasingly similar across layers, losing the ability to distinguish between different atoms and their local environments.

2) Conceptual Questions

Question 5

You want to build a GNN to predict molecular solubility (a continuous value). Which combination of pooling and output layers would be most appropriate?

A.

# Mean pooling + regression output
x = global_mean_pool(x, batch)
output = nn.Linear(hidden_dim, 1)(x)

B.

# Max pooling + classification output  
x = global_max_pool(x, batch)
output = nn.Sequential(nn.Linear(hidden_dim, 2), nn.Softmax(dim=1))(x)

C.

# No pooling + multiple outputs
output = nn.Linear(hidden_dim, num_atoms)(x)

D.

# Sum pooling + sigmoid output
x = global_add_pool(x, batch) 
output = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())(x)
▶ Click to show answer Correct Answer: A
▶ Click to show explanation Explanation: For continuous property prediction (regression), we need to pool node features to get a molecular-level representation, then use a linear layer to output a single continuous value. Mean pooling is commonly used and effective for this purpose.
▶ Click to see code: Complete GNN architecture for solubility prediction

# Complete GNN for solubility prediction
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class SolubilityGNN(nn.Module):
    def __init__(self, node_features, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.predictor = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # Pool to molecular level
        return self.predictor(x)        # Single continuous output

Question 6

A chemist notices that their GNN model performs well on training molecules but poorly on a new set of structurally different compounds. What is the most likely cause and solution?

A. The model is too simple; add more layers
B. The model suffers from distribution shift; collect more diverse training data
C. The learning rate is too high; reduce it
D. The model has too many parameters; reduce model size

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: This scenario describes distribution shift, where the model was trained on one chemical space but tested on a different one. The solution is to include more diverse molecular structures in the training data to improve generalization.
▶ Click to see code: Data augmentation for chemical space diversity

# Data augmentation to improve generalization
from rdkit import Chem

def augment_chemical_space(original_smiles_list):
    """Expand training data with structural diversity."""
    augmented_data = []

    for smiles in original_smiles_list:
        mol = Chem.MolFromSmiles(smiles)

        # Keep the original SMILES
        augmented_data.append(smiles)

        # Add alternative (randomized) SMILES strings for the same molecule
        for _ in range(3):
            random_smiles = Chem.MolToSmiles(mol, doRandom=True)
            augmented_data.append(random_smiles)

    return augmented_data

# Use diverse training data from multiple chemical databases
# (combine_datasets is a placeholder for your own data-loading helper)
diverse_training_data = combine_datasets([
    'drug_molecules.csv',
    'natural_products.csv',
    'synthetic_compounds.csv'
])
