3. Machine Learning Models

Machine learning (ML) is a subfield of artificial intelligence (AI) focused on developing algorithms and statistical models that enable computers to learn from and make decisions based on data. Unlike traditional programming, where explicit instructions are given, machine learning systems identify patterns and insights from large datasets, improving their performance over time through experience.

ML encompasses various techniques, including supervised learning, where models are trained on labeled data to predict outcomes; unsupervised learning, which involves discovering hidden patterns or groupings within unlabeled data; and reinforcement learning, where models learn optimal actions through trial and error in dynamic environments. These methods are applied across diverse domains, from natural language processing and computer vision to recommendation systems and autonomous vehicles, revolutionizing how technology interacts with the world.

3.1 Decision Tree and Random Forest

3.1.1 Decision Trees

Decision Trees are intuitive and powerful models used in machine learning to make predictions and decisions. Think of it like playing a game of 20 questions, where each question helps you narrow down the possibilities. Decision trees function similarly; they break down a complex decision into a series of simpler questions based on the data.

Each question, referred to as a “decision,” relies on a specific characteristic or feature of the data. For instance, if you’re trying to determine whether a fruit is an apple or an orange, the initial question might be, “Is the fruit’s color red or orange?” Depending on the answer, you might follow up with another question—such as, “Is the fruit’s size small or large?” This questioning process continues until you narrow it down to a final answer (e.g., the fruit is either an apple or an orange).

In a decision tree, these questions are represented as nodes, and the possible answers lead to different branches. The final outcomes are represented at the end of each branch, known as leaf nodes. One of the key advantages of decision trees is their clarity and ease of understanding—much like a flowchart. However, they can also be prone to overfitting, especially when dealing with complex datasets that have many features. Overfitting occurs when a model performs exceptionally well on training data but fails to generalize to new or unseen data.

In summary, decision trees offer an intuitive approach to making predictions and decisions, but caution is required to prevent them from becoming overly complicated and tailored too closely to the training data.
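To make this concrete, here is a minimal sketch of a decision tree built with scikit-learn on a tiny, made-up fruit dataset (the feature values and labels are invented purely for illustration). It prints the learned "questions" in flowchart form and classifies a new fruit.

# Minimal sketch: a decision tree on a toy fruit dataset (values are illustrative only)
from sklearn.tree import DecisionTreeClassifier, export_text

# Features: [color (0 = red, 1 = orange), diameter in cm]
X = [[0, 7.0], [0, 7.5], [1, 8.0], [1, 8.5], [0, 6.8], [1, 7.9]]
y = ["apple", "apple", "orange", "orange", "apple", "orange"]

tree = DecisionTreeClassifier(max_depth=2, random_state=42)
tree.fit(X, y)

# Print the learned splits, much like reading a flowchart
print(export_text(tree, feature_names=["color", "diameter"]))

# Classify a new fruit: red, 7.2 cm in diameter
print(tree.predict([[0, 7.2]]))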

3.1.2 Random Forest

Random Forests address the limitations of decision trees by utilizing an ensemble of multiple trees instead of relying on a single one. Imagine you’re gathering opinions about a game outcome from a group of people; rather than trusting just one person’s guess, you ask everyone and then take the most common answer. This is the essence of how a Random Forest operates.

In a Random Forest, numerous decision trees are constructed, each making its own predictions. However, a key difference is that each tree is built using a different subset of the data and considers different features of the data. This technique, known as bagging (Bootstrap Aggregating), allows each tree to provide a unique perspective, which collectively leads to a more reliable prediction.

When making a final prediction, the Random Forest aggregates the predictions from all the trees. For classification tasks, it employs majority voting to determine the final class label, while for regression tasks, it averages the results.

Random Forests typically outperform individual decision trees because they are less likely to overfit the data. By combining multiple trees, they achieve a balance between model complexity and predictive performance on unseen data.
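As a minimal sketch of this effect, the code below compares a single, fully grown decision tree with a 100-tree Random Forest on a synthetic dataset generated with scikit-learn; on data like this, the forest typically matches or beats the lone tree on the held-out test set.

# Sketch: single decision tree vs. Random Forest on synthetic data
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=20, n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

single_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
forest = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

print("Single tree test accuracy:", single_tree.score(X_test, y_test))
print("Random Forest test accuracy:", forest.score(X_test, y_test))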

Real-Life Analogy

Consider Andrew, who wants to decide on a destination for his year-long vacation. He starts by asking his close friends for suggestions. The first friend asks Andrew about his past travel preferences, using his answers to recommend a destination. This is akin to a decision tree approach—one friend following a rule-based decision process.

Next, Andrew consults more friends, each of whom poses different questions to gather recommendations. Finally, Andrew chooses the places suggested most frequently by his friends, mirroring the Random Forest algorithm’s method of aggregating multiple decision trees’ outputs.

3.1.3 Implementing Random Forest on the BBBP Dataset

This guide demonstrates how to implement a Random Forest algorithm in Python using the BBBP (Blood–Brain Barrier Permeability) dataset. The BBBP dataset is used in cheminformatics to predict whether a compound can cross the blood-brain barrier based on its chemical structure.

The dataset contains SMILES (Simplified Molecular Input Line Entry System) strings representing chemical compounds, and a target column that indicates whether the compound is permeable to the blood-brain barrier or not.

The goal is to predict whether a given chemical compound will cross the blood-brain barrier, based on its molecular structure. This guide walks you through downloading the dataset, processing it, and training a Random Forest model.

Step 1: Install RDKit (Required for SMILES to Fingerprint Conversion)

We need to use the RDKit library, which is essential for converting SMILES strings into molecular fingerprints, a numerical representation of the molecule.

# Install the RDKit package via conda-forge
!pip install -q condacolab
import condacolab
condacolab.install()

# Now install RDKit
!mamba install -c conda-forge rdkit -y

# Import RDKit and check if it's installed successfully
from rdkit import Chem
print("RDKit is successfully installed!")

Step 2: Download the BBBP Dataset from Kaggle

The BBBP dataset is hosted on Kaggle, a popular platform for datasets and machine learning competitions. To access the dataset, you need a Kaggle account and an API key for authentication. Here’s how you can set it up:

Step 2.1: Create a Kaggle Account
  1. Visit Kaggle and create an account if you don’t already have one.
  2. Once you’re logged in, go to your profile by clicking on your profile picture in the top right corner, and select My Account.
Step 2.2: Set Up the Kaggle API Key
  1. Scroll down to the section labeled API on your account page.
  2. Click on the button “Create New API Token”. This will download a file named kaggle.json to your computer.
  3. Keep this file safe! It contains your API key, which you’ll use to authenticate when downloading datasets.
Step 2.3: Upload the Kaggle API Key

Once you have the kaggle.json file, you need to upload it to your Python environment:

  1. If you’re using a notebook environment like Google Colab, use the code below to upload the file:
# Upload the kaggle.json file
from google.colab import files
uploaded = files.upload()

# Move the file to the right directory for authentication
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
  2. If you’re using a local Jupyter Notebook: Place the kaggle.json file in a folder named .kaggle within your home directory:
    • On Windows: Place it in C:\Users\<YourUsername>\.kaggle.
    • On Mac/Linux: Place it in ~/.kaggle.
Step 2.4: Install the Required Libraries

To interact with Kaggle and download the dataset, you need the Kaggle API client. Install it with the following command:

!pip install kaggle
Step 2.5: Download the BBBP Dataset

Now that the API key is set up, you can download the dataset using the Kaggle API:

# Download the BBBP dataset using the Kaggle API 
!kaggle datasets download -d priyanagda/bbbp-smiles 
# Unzip the downloaded file 
!unzip bbbp-smiles.zip -d bbbp_dataset

This code will:

  1. Download the dataset into your environment.
  2. Extract the dataset files into a folder named bbbp_dataset.
Step 2.6: Verify the Download

After downloading, check the dataset files to confirm that everything is in place:

# List the files in the dataset folder 
import os 
dataset_path = "bbbp_dataset" 
files = os.listdir(dataset_path) 
print("Files in the dataset:", files)

By following these steps, you will have successfully downloaded and extracted the BBBP dataset, ready for further analysis and processing.

Step 3: Load the BBBP Dataset

After downloading the dataset, we’ll load the BBBP dataset into a pandas DataFrame. The dataset contains the SMILES strings and the target variable (p_np), which indicates whether the compound can cross the blood-brain barrier (binary classification: 1 for permeable, 0 for non-permeable).

import pandas as pd

# Load the BBBP dataset (Step 2.5 extracted it into the bbbp_dataset folder;
# adjust the path and filename if yours differ)
data = pd.read_csv("bbbp_dataset/bbbp.csv")  # Assuming the file is named bbbp.csv
print("Dataset Head:")
print(data.head())

Step 4: Convert SMILES to Molecular Fingerprints

To use the SMILES strings for modeling, we need to convert them into molecular fingerprints. This process turns the chemical structures into a numerical format that can be fed into machine learning models. We’ll use RDKit to generate these fingerprints using the Morgan Fingerprint method.

from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

# Function to convert SMILES to molecular fingerprints
def featurize_molecule(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
    else:
        return None

# Apply featurization to the dataset
features = [featurize_molecule(smi) for smi in data['smiles']]  # Replace 'smiles' with the actual column name if different
features = [list(fp) if fp is not None else np.zeros(1024) for fp in features]  # Handle missing data by filling with zeros
X = np.array(features)
y = data['p_np']  # Target column (1 for permeable, 0 for non-permeable)

The diagram below provides a visual representation of what this code does:

Smiles Diagram

Figure: SMILES to Molecular Fingerprints Conversion Process

Step 5: Split Data into Training and Testing Sets

To evaluate the model, we need to split the data into training and testing sets. The train_test_split function from scikit-learn will handle this. We’ll use 80% of the data for training and 20% for testing.

from sklearn.model_selection import train_test_split

# Split data into train and test sets (80% training, 20% testing)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

The diagram below provides a visual representation of what this code does:

Train Test Split Diagram

Figure: Data Splitting Process for Training and Testing

Step 6: Train the Random Forest Model

We’ll use the RandomForestClassifier from scikit-learn to build the model. A Random Forest is an ensemble method that uses multiple decision trees to make predictions. The more trees (n_estimators) we use, the more robust the model tends to be, but the longer it takes to train. In recent versions of scikit-learn, n_estimators defaults to 100; for more complex datasets, higher values such as 500 or 1000 may improve performance.

from sklearn.ensemble import RandomForestClassifier

# Train a Random Forest classifier
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)

The diagram below provides a visual explanation of what is going on here:

Random Forest Decision Tree Diagram

Figure: Random Forest Algorithm Structure

Step 7: Evaluate the Model

After training the model, we’ll use the test data to evaluate its performance. We will print the accuracy and the classification report to assess the model’s precision, recall, and F1 score.

from sklearn.metrics import accuracy_score, classification_report

# Predictions on the test set
y_pred = rf_model.predict(X_test)

# Evaluate accuracy and performance
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:", classification_report(y_test, y_pred))

Model Performance and Parameters

  • Accuracy: The proportion of correctly predicted instances out of all instances.
  • Classification Report: Provides additional metrics like precision, recall, and F1 score.

In this case, we achieved an accuracy score of ~87%.

Key Hyperparameters:

  • n_estimators: The number of trees in the Random Forest. More trees generally lead to better performance but also require more computational resources.
  • test_size: The proportion of data used for testing. A larger test size gives a more reliable evaluation but reduces the amount of data used for training.
  • random_state: Ensures reproducibility by initializing the random number generator to a fixed seed.

Conclusion

This guide demonstrated how to implement a Random Forest model to predict the Blood–Brain Barrier Permeability (BBBP) using the BBBP dataset. By converting SMILES strings to molecular fingerprints and using a Random Forest classifier, we were able to achieve an accuracy score of around 87%.

Adjusting parameters like the number of trees (n_estimators) or the split ratio (test_size) can help improve the model’s performance. Feel free to experiment with these parameters and explore other machine learning models for this task!

3.1.4 Approaching Random Forest Problems

When tackling a classification or regression problem using the Random Forest algorithm, a systematic approach can enhance your chances of success. Here’s a step-by-step guide to effectively solve any Random Forest problem:

  1. Understand the Problem Domain: Begin by thoroughly understanding the problem you are addressing. Identify the nature of the data and the specific goal—whether it’s classification (e.g., predicting categories) or regression (e.g., predicting continuous values). Familiarize yourself with the dataset, including the features (independent variables) and the target variable (dependent variable).

  2. Data Collection and Preprocessing: Gather the relevant dataset and perform necessary preprocessing steps. This may include handling missing values, encoding categorical variables, normalizing or standardizing numerical features, and removing any outliers. Proper data cleaning ensures that the model learns from quality data.

  3. Exploratory Data Analysis (EDA): Conduct an exploratory data analysis to understand the underlying patterns, distributions, and relationships within the data. Visualizations, such as scatter plots, histograms, and correlation matrices, can provide insights that inform feature selection and model tuning.

  4. Feature Selection and Engineering: Identify the most relevant features for the model. This can be achieved through domain knowledge, statistical tests, or feature importance metrics from preliminary models. Consider creating new features through feature engineering to enhance model performance.

  5. Model Training and Parameter Tuning: Split the dataset into training and testing sets, typically using an 80-20 or 70-30 ratio. Train the Random Forest model using the training data, adjusting parameters such as the number of trees (n_estimators), the maximum depth of the trees (max_depth), and the minimum number of samples required to split an internal node (min_samples_split). Utilize techniques like grid search or random search to find the optimal hyperparameters.

  6. Model Evaluation: Once trained, evaluate the model’s performance on the test set using appropriate metrics. For classification problems, metrics such as accuracy, precision, recall, F1 score, and ROC-AUC are valuable. For regression tasks, consider metrics like mean absolute error (MAE), mean squared error (MSE), and R-squared.

  7. Interpretation and Insights: Analyze the model’s predictions and feature importance to derive actionable insights. Understanding which features contribute most to the model can guide decision-making and further improvements in the model or data collection.

  8. Iterate and Improve: Based on the evaluation results, revisit the previous steps to refine your model. This may involve further feature engineering, collecting more data, or experimenting with different algorithms alongside Random Forest to compare performance.

  9. Deployment: Once satisfied with the model’s performance, prepare it for deployment. Ensure the model can process incoming data and make predictions in a real-world setting, and consider implementing monitoring tools to track its performance over time.

By following this structured approach, practitioners can effectively leverage the Random Forest algorithm to solve a wide variety of problems, ensuring thorough analysis, accurate predictions, and actionable insights. A short tuning sketch illustrating step 5 follows below.
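Step 5 above mentions hyperparameter tuning via grid search. The sketch below shows what that might look like with scikit-learn’s GridSearchCV; the parameter grid and the synthetic data are placeholders to be replaced by your own dataset and search ranges.

# Sketch: tuning Random Forest hyperparameters with a small grid search
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Placeholder data — substitute your own feature matrix X and labels y
X, y = make_classification(n_samples=300, n_features=10, random_state=42)

param_grid = {
    "n_estimators": [100, 300],
    "max_depth": [None, 10],
    "min_samples_split": [2, 5],
}

search = GridSearchCV(RandomForestClassifier(random_state=42),
                      param_grid, cv=5, scoring="accuracy")
search.fit(X, y)

print("Best parameters:", search.best_params_)
print("Best cross-validated accuracy:", search.best_score_)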

3.1.5 Strengths and Weaknesses of Random Forest

Strengths:

  • Robustness: Random Forests are less prone to overfitting compared to individual decision trees, making them more reliable for new data.

  • Versatility: They can handle both classification and regression tasks effectively.

  • Feature Importance: Random Forests provide insights into the significance of each feature in making predictions.

Weaknesses:

  • Complexity: The model can become complex, making it less interpretable than single decision trees.

  • Resource Intensive: Training a large number of trees can require significant computational resources and time.

  • Slower Predictions: While individual trees are quick to predict, aggregating predictions from multiple trees can slow down the prediction process.


Section 3.1 – Quiz Questions

1) Factual Questions

Question 1

What is the primary reason a Decision Tree might perform very well on training data but poorly on new, unseen data?

A. Underfitting
B. Data leakage
C. Overfitting
D. Regularization

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: Decision Trees can easily overfit the training data by creating very complex trees that capture noise instead of general patterns. This hurts their performance on unseen data.

Question 2

In a Decision Tree, what do the internal nodes represent?

A. Possible outcomes
B. Splitting based on a feature
C. Aggregation of multiple trees
D. Random subsets of data

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: Internal nodes represent decision points where the dataset is split based on the value of a specific feature (e.g., "Is the fruit color red or orange?").

Question 3

Which of the following best explains the Random Forest algorithm?

A. A single complex decision tree trained on all the data
B. Many decision trees trained on identical data to improve depth
C. Many decision trees trained on random subsets of the data and features
D. A clustering algorithm that separates data into groups

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: Random Forests use bagging to train multiple decision trees on different random subsets of the data and different random subsets of features, making the ensemble more robust.

Question 4

When training a Random Forest for a classification task, how is the final prediction made?

A. By taking the median of the outputs
B. By taking the average of probability outputs
C. By majority vote among trees’ predictions
D. By selecting the tree with the best accuracy

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: For classification problems, the Random Forest algorithm uses majority voting — the class most predicted by the individual trees becomes the final prediction.

2) Conceptual Questions

Question 5

You are given a dataset containing information about chemical compounds, with many categorical features (such as “molecular class” or “bond type”).
Would using a Random Forest model be appropriate for this dataset?

A. No, Random Forests cannot handle categorical data.
B. Yes, Random Forests can naturally handle datasets with categorical variables after encoding.
C. No, Random Forests only work on images.
D. Yes, but only if the dataset has no missing values.

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: Random Forests can handle categorical data after simple preprocessing, such as label encoding or one-hot encoding. They are robust to different feature types, including numerical and categorical.

Question 6

Suppose you have your molecule fingerprints stored in variables X and your labels (0 or 1 for BBBP) stored in y.
Which of the following correctly splits the data into 80% training and 20% testing sets?

A.

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.2, random_state=42)

B.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

C.

X_train, X_test = train_test_split(X, y, test_size=0.8, random_state=42)

D.

X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: We use train_test_split from sklearn.model_selection. Setting test_size=0.2 reserves 20% of the data for testing and leaves 80% for training, and the function returns the splits in the fixed order X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42). Option A reverses the train/test sizing, option C sets test_size=0.8 (leaving only 20% of the data for training), and option D scrambles the return order (the train and test features come first, then the train and test labels).
▶ Show Solution Code

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

3.2 Neural Network

A neural network is a computational model inspired by the neural structure of the human brain, designed to recognize patterns and learn from data. It consists of layers of interconnected nodes, or neurons, which process input data through weighted connections.

Structure: Neural networks typically include an input layer, one or more hidden layers, and an output layer. Each neuron in a layer is connected to neurons in the adjacent layers. The input layer receives data, the hidden layers transform this data through various operations, and the output layer produces the final prediction or classification.

Functioning: Data is fed into the network, where each neuron applies an activation function to its weighted sum of inputs. These activation functions introduce non-linearity, allowing the network to learn complex patterns. The output of the neurons is then passed to the next layer until the final prediction is made.

Learning Process: Neural networks learn through a process called training. During training, the network adjusts the weights of connections based on the error between its predictions and the actual values. This is achieved using algorithms like backpropagation and optimization techniques such as gradient descent, which iteratively updates the weights to minimize the prediction error.
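As a tiny numerical illustration of the “Functioning” step above, the snippet below computes the output of a single neuron by hand; the inputs, weights, and bias are arbitrary example values.

import numpy as np

# One neuron: weighted sum of inputs plus a bias, passed through an activation
inputs = np.array([0.5, -1.2, 3.0])   # arbitrary example inputs
weights = np.array([0.8, 0.1, -0.4])  # learned importance of each input
bias = 0.2

z = np.dot(weights, inputs) + bias    # weighted sum plus bias
output = max(0.0, z)                  # ReLU activation: pass positives, block negatives

print("Weighted sum:", z)
print("Neuron output (after ReLU):", output)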

3.2.1 Biological and Conceptual Foundations of Neural Networks

Neural networks are a class of machine learning models designed to learn patterns from data in order to make predictions or classifications. Their structure and behavior are loosely inspired by how the human brain processes information: through a large network of connected units that transmit signals to each other. Although artificial neural networks are mathematical rather than biological, this analogy provides a helpful starting point for understanding how they function.

The Neural Analogy

In a biological system, neurons receive input signals from other neurons, process those signals, and send output to downstream neurons. Similarly, an artificial neural network is composed of units called “neurons” or “nodes” that pass numerical values from one layer to the next. Each of these units receives inputs, processes them using a simple rule, and forwards the result.

This structure allows the network to build up an understanding of the input data through multiple layers of transformations. As information flows forward through the network—layer by layer—it becomes increasingly abstract. Early layers may focus on basic patterns in the input, while deeper layers detect more complex or chemically meaningful relationships.

Layers of a Neural Network

Neural networks are organized into three main types of layers:

  • Input Layer: This is where the network receives the data. In chemistry applications, this might include molecular fingerprints, structural descriptors, or other numerical representations of a molecule.
  • Hidden Layers: These are the internal layers where computations happen. The network adjusts its internal parameters to best relate the input to the desired output.
  • Output Layer: This layer produces the final prediction. For example, it might output a predicted solubility value, a toxicity label, or the probability that a molecule is biologically active.

The depth (number of layers) and width (number of neurons in each layer) of a network affect its capacity to learn complex relationships.

Why Chemists Use Neural Networks

Many molecular properties—such as solubility, lipophilicity, toxicity, and biological activity—are influenced by intricate, nonlinear combinations of atomic features and substructures. These relationships are often difficult to express with a simple equation or rule.

Neural networks are especially useful in chemistry because:

  • They can learn from large, complex datasets without needing detailed prior knowledge about how different features should be weighted.
  • They can model nonlinear relationships, such as interactions between molecular substructures, electronic effects, and steric hindrance.
  • They are flexible and can be applied to a wide range of tasks, from predicting reaction outcomes to screening drug candidates.

How Learning Happens

Unlike hardcoded rules, neural networks improve through a process of learning:

  1. Prediction: The network uses its current understanding to make a guess about the output (e.g., predicting a molecule’s solubility).
  2. Feedback: It compares its prediction to the known, correct value.
  3. Adjustment: It updates its internal parameters to make better predictions next time.

This process repeats over many examples, gradually improving the model’s accuracy. Over time, the network can generalize—making reliable predictions on molecules it has never seen before.

3.2.2 The Structure of a Neural Network

Completed and Compiled Code: Click Here

The structure of a neural network refers to how its components are organized and how information flows from the input to the output. Understanding this structure is essential for applying neural networks to chemical problems, where numerical data about molecules must be transformed into meaningful predictions—such as solubility, reactivity, toxicity, or classification into chemical groups.

Basic Building Blocks

A typical neural network consists of three types of layers:

  1. Input Layer

This is the first layer and represents the data you give the model. In chemistry, this might include:

  • Molecular fingerprints (e.g., Morgan or ECFP4)
  • Descriptor vectors (e.g., molecular weight, number of rotatable bonds)
  • Graph embeddings (in more advanced architectures)

Each input feature corresponds to one “neuron” in this layer. The network doesn’t modify the data here; it simply passes it forward.

  2. Hidden Layers

These are the core of the network. They are composed of interconnected neurons that process the input data through a series of transformations. Each neuron:

  • Multiplies each input by a weight (a learned importance factor)
  • Adds the results together, along with a bias term
  • Passes the result through an activation function to determine the output

Multiple hidden layers can extract increasingly abstract features. For example:

  • First hidden layer: detects basic structural motifs (e.g., aromatic rings)
  • Later hidden layers: model higher-order relationships (e.g., presence of specific pharmacophores)

The depth of a network (number of hidden layers) increases its capacity to model complex patterns, but also makes it more challenging to train.

  3. Output Layer

This layer generates the final prediction. The number of output neurons depends on the type of task:

  • One neuron for regression (e.g., predicting solubility)
  • One neuron with a sigmoid function for binary classification (e.g., active vs. inactive)
  • Multiple neurons with softmax for multi-class classification (e.g., toxicity categories)
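As a hedged illustration, the Keras snippets below show how each of these output layers might be declared; the layer sizes and the three-class example are assumptions chosen only for demonstration.

from tensorflow.keras import layers

# Regression (e.g., predicting a solubility value): one linear output neuron
regression_output = layers.Dense(1)

# Binary classification (e.g., active vs. inactive): one neuron with sigmoid
binary_output = layers.Dense(1, activation="sigmoid")

# Multi-class classification (e.g., three toxicity categories): softmax over the classes
multiclass_output = layers.Dense(3, activation="softmax")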

Activation Functions

The activation function introduces non-linearity to the model. Without it, the network would behave like a linear regression model, unable to capture complex relationships. Common activation functions include:

  • ReLU (Rectified Linear Unit): Returns 0 for negative inputs and the input itself for positive values. Efficient and widely used.
  • Sigmoid: Squeezes inputs into the range (0,1), useful for probabilities.
  • Tanh: Similar to sigmoid but outputs values between -1 and 1, often used in earlier layers.

These functions allow neural networks to model subtle chemical relationships, such as how a substructure might enhance activity in one molecular context but reduce it in another.

Forward Pass: How Data Flows Through the Network

The process of making a prediction is known as the forward pass. Here’s what happens step-by-step:

  1. Each input feature (e.g., molecular weight = 300) is multiplied by a corresponding weight.
  2. The weighted inputs are summed and combined with a bias.
  3. The result is passed through the activation function.
  4. The output becomes the input to the next layer.

This process repeats until the final output is produced.
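The following sketch performs this forward pass by hand for one hypothetical molecule and a tiny two-layer network; all weights and biases are arbitrary illustrative numbers.

import numpy as np

def relu(x):
    return np.maximum(0, x)

# Step 1: input features for one molecule: [molecular weight, LogP, rotatable bonds]
x = np.array([300.0, 2.5, 4.0])

# Arbitrary weights and biases: a 3 -> 2 hidden layer and a 2 -> 1 output layer
W1 = np.array([[0.01, -0.02, 0.03],
               [0.02,  0.01, -0.01]])
b1 = np.array([0.1, -0.1])
W2 = np.array([[0.5, -0.3]])
b2 = np.array([0.05])

# Steps 2-3: weighted sum plus bias, then activation, for the hidden layer
h = relu(W1 @ x + b1)

# Step 4: the hidden outputs become the inputs to the output layer
y_hat = W2 @ h + b2
print("Hidden layer output:", h)
print("Predicted value:", y_hat[0])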

Building a Simple Neural Network for Molecular Property Prediction

Let’s build a minimal neural network that takes molecular descriptors as input and predicts a continuous chemical property, such as aqueous solubility. We’ll use TensorFlow and Keras.

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Example molecular descriptors for 5 hypothetical molecules:
# Features: [Molecular Weight, LogP, Number of Rotatable Bonds]
X = np.array([
    [180.1, 1.2, 3],
    [310.5, 3.1, 5],
    [150.3, 0.5, 2],
    [420.8, 4.2, 8],
    [275.0, 2.0, 4]
])

# Target values: Normalized aqueous solubility
y = np.array([0.82, 0.35, 0.90, 0.20, 0.55])

# Define a simple feedforward neural network
model = models.Sequential([
    layers.Input(shape=(3,)),              # 3 input features per molecule
    layers.Dense(8, activation='relu'),    # First hidden layer
    layers.Dense(4, activation='relu'),    # Second hidden layer
    layers.Dense(1)                        # Output layer (regression)
])

# Compile the model
model.compile(optimizer='adam', loss='mse')  # Mean Squared Error for regression

# Train the model
model.fit(X, y, epochs=100, verbose=0)

# Predict on new data
new_molecule = np.array([[300.0, 2.5, 6]])
predicted_solubility = model.predict(new_molecule)
print("Predicted Solubility:", predicted_solubility[0][0])

Results

Predicted Solubility: 13.366545

(With only five training molecules and unscaled input features, the prediction can fall well outside the 0-1 range of the training targets; Section 3.2.5 shows how feature scaling and a larger dataset lead to more sensible outputs.)

What This Code Does:

  • Inputs are numerical molecular descriptors (easy for chemists to relate to).
  • The model learns a pattern from these descriptors to predict solubility.
  • Layers are built exactly as explained: input → hidden (ReLU) → output.
  • The output is a single continuous number, suitable for regression tasks.

Practice Problem 3: Neural Network Warm-Up

Using the logic from the code above:

  1. Replace the input features with the following descriptors:
    • [350.2, 3.3, 5], [275.4, 1.8, 4], [125.7, 0.2, 1]
  2. Create a new NumPy array called X_new with those values.
  3. Use the trained model to predict the solubility of each new molecule.
  4. Print the outputs with a message like: “Predicted solubility for molecule 1: 0.67”
# Step 1: Create new molecular descriptors for prediction
X_new = np.array([
    [350.2, 3.3, 5],
    [275.4, 1.8, 4],
    [125.7, 0.2, 1]
])

# Step 2: Use the trained model to predict solubility
predictions = model.predict(X_new)

# Step 3: Print each result with a message
for i, prediction in enumerate(predictions):
    print(f"Predicted solubility for molecule {i + 1}: {prediction[0]:.2f}")

Discussion: What Did We Just Do?

In this practice problem, we used a trained neural network to predict the solubility of three new chemical compounds based on simple molecular descriptors. Each molecule was described using three features:

  1. Molecular weight
  2. LogP (a measure of lipophilicity)
  3. Number of rotatable bonds

The model, having already learned patterns from prior data during training, applied its internal weights and biases to compute a prediction for each molecule.

Predicted solubility for molecule 1: 0.38  
Predicted solubility for molecule 2: 0.55  
Predicted solubility for molecule 3: 0.91

These values are the model’s predicted normalized solubilities, with higher numbers indicating molecules the model expects to be more soluble. While we don’t yet know how the model arrived at these exact numbers (that comes in the next section), this exercise demonstrates a key advantage of neural networks:

  • Once trained, they can generalize to unseen data—making predictions for new molecules quickly and efficiently.

3.2.3 How Neural Networks Learn: Backpropagation and Loss Functions

Completed and Compiled Code: Click Here

In the previous section, we saw how a neural network can take molecular descriptors as input and generate predictions, such as aqueous solubility. However, this raises an important question: how does the network learn to make accurate predictions in the first place? The answer lies in two fundamental concepts: the loss function and backpropagation.

Loss Function: Measuring the Error

The loss function is a mathematical expression that quantifies how far off the model’s predictions are from the actual values. It acts as a feedback mechanism—telling the network how well or poorly it’s performing.

In regression tasks like solubility prediction, a common loss function is Mean Squared Error (MSE):

\[\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2\]

Where:

  • $\hat{y}_i$ is the predicted solubility
  • $y_i$ is the true solubility
  • $n$ is the number of samples

MSE penalizes larger errors more severely than smaller ones, which is especially useful in chemical property prediction where large prediction errors can have significant consequences.
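As a quick numerical check of the formula, using made-up predicted and true solubilities:

import numpy as np

y_true = np.array([0.82, 0.35, 0.90])   # made-up true solubilities
y_pred = np.array([0.75, 0.50, 0.88])   # made-up model predictions

# Average of the squared errors: ((-0.07)**2 + (0.15)**2 + (-0.02)**2) / 3
mse = np.mean((y_pred - y_true) ** 2)
print("MSE:", mse)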

Gradient Descent: Minimizing the Loss

Once the model calculates the loss, it needs to adjust its internal weights to reduce that loss. This optimization process is called gradient descent.

Gradient descent updates the model’s weights in the opposite direction of the gradient of the loss function:

\[w_{\text{new}} = w_{\text{old}} - \alpha \cdot \frac{\partial \text{Loss}}{\partial w}\]

Where:

  • $w$ is a weight in the network
  • $\alpha$ is the learning rate, a small scalar that determines the step size

This iterative update helps the model gradually “descend” toward a configuration that minimizes the prediction error.
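A minimal numeric sketch of this update rule for a single weight, using an assumed gradient value:

# One gradient descent update for a single weight (all numbers are illustrative)
w_old = 0.50
learning_rate = 0.01          # alpha
grad = 2.3                    # assumed value of dLoss/dw at the current weight

w_new = w_old - learning_rate * grad
print("Updated weight:", w_new)  # 0.50 - 0.01 * 2.3 = 0.477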

Backpropagation: Updating the Network

Backpropagation is the algorithm that computes how to adjust the weights.

  1. It begins by computing the prediction and measuring the loss.
  2. Then, it calculates how much each neuron contributed to the final error by applying the chain rule from calculus.
  3. Finally, it adjusts all weights by propagating the error backward from the output layer to the input layer.

Over time, the network becomes better at associating input features with the correct output properties.
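To make the chain rule concrete, here is a minimal sketch that backpropagates the error through a single sigmoid neuron with a squared-error loss; the starting values are arbitrary and chosen only for illustration.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One training example with arbitrary values
x, y_true = 1.5, 1.0      # input feature and target
w, b = 0.4, 0.0           # initial weight and bias
lr = 0.1                  # learning rate

# Forward pass
z = w * x + b
a = sigmoid(z)            # prediction
loss = (a - y_true) ** 2  # squared error for this single example

# Backward pass: chain rule, dLoss/dw = dLoss/da * da/dz * dz/dw
dloss_da = 2 * (a - y_true)
da_dz = a * (1 - a)
dz_dw = x
grad_w = dloss_da * da_dz * dz_dw

# Gradient descent update
w = w - lr * grad_w
print(f"Loss: {loss:.4f}, gradient: {grad_w:.4f}, updated weight: {w:.4f}")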

Intuition for Chemists

Think of a chemist optimizing a synthesis route. After a failed reaction, they adjust parameters (temperature, solvent, reactants) based on what went wrong. With enough trials and feedback, they achieve better yields.

A neural network does the same—after each “trial” (training pass), it adjusts its internal settings (weights) to improve its “yield” (prediction accuracy) the next time.

Visualizing Loss Reduction During Training

This code demonstrates how a simple neural network learns over time by minimizing error through backpropagation and gradient descent. It also visualizes the loss curve to help you understand how training progresses.

# 3.2.3 Example: Visualizing Loss Reduction During Training

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Simulated training data: [molecular_weight, logP, rotatable_bonds]
X_train = np.array([
    [350.2, 3.3, 5],
    [275.4, 1.8, 4],
    [125.7, 0.2, 1],
    [300.1, 2.5, 3],
    [180.3, 0.5, 2]
])

# Simulated solubility labels (normalized between 0 and 1)
y_train = np.array([0.42, 0.63, 0.91, 0.52, 0.86])

# Define a simple neural network
model = Sequential()
model.add(Dense(10, input_dim=3, activation='relu'))
model.add(Dense(1, activation='sigmoid'))  # Sigmoid keeps the output in (0, 1), matching the normalized targets

# Compile the model using MSE (Mean Squared Error) loss
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model and record loss values
history = model.fit(X_train, y_train, epochs=100, verbose=0)

# Plot the training loss over time
plt.plot(history.history['loss'])
plt.title('Loss Reduction During Training')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.grid(True)
plt.show()

This example demonstrates:

  • How the network calculates and minimizes the loss function (MSE)
  • How backpropagation adjusts weights over time
  • How loss consistently decreases with each epoch

Practice Problem: Observe the Learning Curve

Reinforce the concepts of backpropagation and gradient descent by modifying the model to exaggerate or dampen learning behavior.

  1. Change the optimizer from “adam” to “sgd” and observe how the loss reduction changes.
  2. Add validation_split=0.2 to model.fit() to visualize both training and validation loss.
  3. Plot both loss curves using matplotlib.
# Add validation and switch optimizer
model.compile(optimizer='sgd', loss='mean_squared_error')

history = model.fit(X_train, y_train, epochs=100, validation_split=0.2, verbose=0)

# Plot training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training vs Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

You should observe:

  1. Slower convergence when using SGD vs. Adam.
  2. Validation loss potentially diverging if overfitting begins.

3.2.4 Activation Functions

Completed and Compiled Code: Click Here

Activation functions are a key component of neural networks that allow them to model complex, non-linear relationships between inputs and outputs. Without activation functions, no matter how many layers we add, a neural network would essentially behave like a linear model. For chemists, this would mean failing to capture the non-linear relationships between molecular descriptors and properties such as solubility, reactivity, or binding affinity.

What Is an Activation Function?

An activation function is applied to the output of each neuron in a hidden layer. It determines whether that neuron should “fire” (i.e., pass information to the next layer) and to what degree.

Think of it like a valve in a chemical reaction pathway: the valve can allow the signal to pass completely, partially, or not at all—depending on the condition (input value). This gating mechanism allows neural networks to build more expressive models that can simulate highly non-linear chemical behavior.

Common Activation Functions (with Intuition)

Here are the most widely used activation functions and how you can interpret them in chemical modeling contexts:

1. ReLU (Rectified Linear Unit)

\[\text{ReLU}(x) = \max(0,x)\]

Behavior: Passes positive values as-is; blocks negative ones.
Analogy: A pH-dependent gate that opens only if the environment is basic (positive).
Use: Fast to compute; ideal for hidden layers in large models.

2. Sigmoid

\[\text{Sigmoid}(x) = \frac{1}{1 + e^{-x}}\]

Behavior: Maps input to a value between 0 and 1.
Analogy: Represents probability or confidence — useful when you want to interpret the output as “likelihood of solubility” or “chance of toxicity”.
Use: Often used in the output layer for binary classification.

3. Tanh (Hyperbolic Tangent)

\[\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]

Behavior: Outputs values between -1 and 1, centered around 0.
Analogy: Models systems with directionality — such as positive vs. negative binding affinity.
Use: Sometimes preferred over sigmoid in hidden layers.

Why Are They Important?

Without activation functions, neural networks would be limited to computing weighted sums—essentially doing linear algebra. This would be like trying to model the melting point of a compound using only molecular weight: too simplistic for real-world chemistry.

Activation functions allow networks to “bend” input-output mappings, much like how a catalyst changes the energy profile of a chemical reaction.

Comparing ReLU and Sigmoid Activation Functions

This code visually compares how ReLU and Sigmoid behave across a range of inputs. Understanding the shapes of these activation functions helps chemists choose the right one for a neural network layer depending on the task (e.g., regression vs. classification).

# 3.2.4 Example: Comparing ReLU vs Sigmoid Activation Functions

import numpy as np
import matplotlib.pyplot as plt

# Define ReLU and Sigmoid activation functions
def relu(x):
    return np.maximum(0, x)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Input range
x = np.linspace(-10, 10, 500)

# Compute function outputs
relu_output = relu(x)
sigmoid_output = sigmoid(x)

# Plot the functions
plt.figure(figsize=(10, 6))
plt.plot(x, relu_output, label='ReLU', linewidth=2)
plt.plot(x, sigmoid_output, label='Sigmoid', linewidth=2)
plt.axhline(0, color='gray', linestyle='--', linewidth=0.5)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.5)
plt.title('Activation Function Comparison: ReLU vs Sigmoid')
plt.xlabel('Input (x)')
plt.ylabel('Activation Output')
plt.legend()
plt.grid(True)
plt.show()

This example demonstrates:

  • ReLU outputs 0 for any negative input and increases linearly for positive inputs. This makes it ideal for deep layers in large models where speed and sparsity are priorities.
  • Sigmoid smoothly maps all inputs to values between 0 and 1. This is useful for binary classification tasks, such as predicting whether a molecule is toxic or not.
  • Why this matters in chemistry: Choosing the right activation function can affect whether your neural network correctly learns properties like solubility, toxicity, or reactivity. For instance, sigmoid may be used in the output layer when predicting probabilities, while ReLU is preferred in hidden layers to retain training efficiency.

3.2.5 Training a Neural Network for Chemical Property Prediction

Completed and Compiled Code: Click Here

In the previous sections, we explored how neural networks are structured and how they learn. In this final section, we’ll put everything together by training a neural network on a small dataset of molecules to predict aqueous solubility — a property of significant importance in drug design and formulation.

Rather than using high-level abstractions, we’ll walk through the full training process: from preparing chemical data to building, training, evaluating, and interpreting a neural network model.

Chemical Context

Solubility determines how well a molecule dissolves in water, which affects its absorption and distribution in biological systems. Predicting this property accurately can save time and cost in early drug discovery. By using features like molecular weight, lipophilicity (LogP), and number of rotatable bonds, we can teach a neural network to approximate this property from molecular descriptors.

Step-by-Step Training Example

Goal: Predict normalized solubility values from 3 molecular descriptors:

  • Molecular weight
  • LogP
  • Number of rotatable bonds
# 3.2.5 Example: Training a Neural Network for Solubility Prediction

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Step 1: Simulated chemical data
X = np.array([
    [350.2, 3.3, 5],
    [275.4, 1.8, 4],
    [125.7, 0.2, 1],
    [300.1, 2.5, 3],
    [180.3, 0.5, 2],
    [410.0, 4.1, 6],
    [220.1, 1.2, 3],
    [140.0, 0.1, 1]
])
y = np.array([0.42, 0.63, 0.91, 0.52, 0.86, 0.34, 0.70, 0.95])  # Normalized solubility

# Step 2: Normalize features using MinMaxScaler
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Step 3: Train-test split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.25, random_state=42)

# Step 4: Build the neural network
model = Sequential()
model.add(Dense(16, input_dim=3, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))  # Output layer for regression (normalized range)

# Step 5: Compile and train
model.compile(optimizer='adam', loss='mean_squared_error')
history = model.fit(X_train, y_train, epochs=100, verbose=0)

# Step 6: Evaluate performance
loss = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss (MSE): {loss:.4f}")

# Step 7: Plot training loss
plt.plot(history.history['loss'])
plt.title("Training Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.grid(True)
plt.show()

Interpreting the Results

  • The network gradually learns to predict solubility based on three molecular features.
  • The loss value shows the mean squared error on the test set—lower values mean better predictions.
  • The loss curve demonstrates whether the model is converging (flattening loss) or struggling (oscillating loss).

Summary

This section demonstrated how a basic neural network can be trained on molecular descriptors to predict solubility. While our dataset was small and artificial, the same principles apply to real-world cheminformatics datasets.

You now understand:

  • How to process input features from molecules
  • How to build and train a simple feedforward neural network
  • How to interpret loss, predictions, and model performance

This hands-on foundation prepares you to tackle more complex models like convolutional and graph neural networks in the next sections.


Section 3.2 – Quiz Questions

1) Factual Questions

Question 1

Which of the following best describes the role of the hidden layers in a neural network predicting chemical properties?

A. They store the molecular structure for visualization.
B. They transform input features into increasingly abstract representations.
C. They calculate the final solubility or toxicity score directly.
D. They normalize the input data before processing begins.

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: Hidden layers apply weights, biases, and activation functions to extract increasingly complex patterns (e.g., substructures, steric hindrance) from the input molecular data.

Question 2

Suppose you’re predicting aqueous solubility using a neural network. Which activation function in the hidden layers would be most suitable to introduce non-linearity efficiently, especially with large chemical datasets?

A. Softmax
B. Linear
C. ReLU
D. Sigmoid

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: ReLU is widely used in hidden layers for its computational efficiency and ability to handle vanishing gradient problems in large datasets.

Question 3

In the context of molecular property prediction, which of the following sets of input features is most appropriate for the input layer of a neural network?

A. IUPAC names and structural diagrams
B. Raw SMILES strings and melting points as text
C. Numerical descriptors like molecular weight, LogP, and rotatable bonds
D. Hand-drawn chemical structures and reaction mechanisms

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: Neural networks require numerical input. Molecular descriptors are quantifiable features that encode structural, electronic, and steric properties.

Question 4

Your neural network performs poorly on new molecular data but does very well on training data. Which of the following is most likely the cause?

A. The model lacks an output layer
B. The training set contains irrelevant descriptors
C. The network is overfitting due to too many parameters
D. The input layer uses too few neurons

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: Overfitting occurs when a model memorizes the training data but fails to generalize. This is common in deep networks with many parameters and not enough regularization or data diversity.

2) Conceptual Questions

Question 5

You are building a neural network to predict binary activity (active vs inactive) of molecules based on three features: [Molecular Weight, LogP, Rotatable Bonds].
Which code correctly defines the output layer for this classification task?

A. layers.Dense(1)
B. layers.Dense(1, activation='sigmoid')
C. layers.Dense(2, activation='relu')
D. layers.Dense(3, activation='softmax')

▶ Click to show answer Correct Answer: B
▶ Click to show explanation Explanation: For binary classification, you need a single neuron with a sigmoid activation function to output a probability between 0 and 1.

Question 6

Why might a chemist prefer a neural network over a simple linear regression model for predicting molecular toxicity?

A. Neural networks can run faster than linear models.
B. Toxicity is not predictable using any mathematical model.
C. Neural networks can model nonlinear interactions between substructures.
D. Neural networks use fewer parameters and are easier to interpret.

▶ Click to show answer Correct Answer: C
▶ Click to show explanation Explanation: Chemical toxicity often arises from complex, nonlinear interactions among molecular features—something neural networks can capture but linear regression cannot.

3.3 Graph Neural Network

Graph Neural Networks (GNNs) offer a new and powerful way to handle molecular machine learning. Traditional neural networks are good at working with fixed-size inputs, such as images or sequences. However, molecules are different in nature. They are best represented as graphs, where atoms are nodes and chemical bonds are edges. This kind of graph structure has always been central in chemistry, appearing in everything from simple Lewis structures to complex reaction pathways. GNNs make it possible for computers to work directly with this kind of data structure.

Unlike images or text, molecules do not follow a regular shape or order. This makes it hard for conventional neural networks to process them effectively. Convolutional neural networks (CNNs) are designed for image data, and recurrent neural networks (RNNs) are built for sequences, but neither is suited to the irregular and highly connected structure of molecules. As a result, older models often fail to capture how atoms are truly linked inside a molecule.

Figure: GNN Overview - Traditional Methods vs Graph Neural Networks. Comparison between traditional neural network approaches and Graph Neural Networks for molecular machine learning. The flowchart illustrates why molecules require graph-based methods and how GNNs preserve structural information that traditional methods lose.

Before GNNs were introduced, chemists used what are known as molecular descriptors. These are numerical features based on molecular structure, such as how many functional groups a molecule has or how its atoms are arranged in space. These descriptors were used as input for machine learning models. However, they often lose important information about the exact way atoms are connected. This loss of detail limits how well the models can predict molecular behavior.

GNNs solve this problem by learning directly from the molecular graph. Instead of relying on handcrafted features, GNNs use the structure itself to learn what matters. Each atom gathers information from its neighbors in the graph, which helps the model understand the molecule as a whole. This approach leads to more accurate predictions and also makes the results easier to interpret.

In short, GNNs allow researchers to build models that reflect the true structure of molecules. They avoid the limitations of older methods by directly using the connections between atoms, offering a more natural and powerful way to predict molecular properties.

Figure: GNN Processing Pipeline. Step-by-step visualization of how Graph Neural Networks process molecular graphs. The pipeline shows the flow from input molecular graph through message passing and aggregation to final property prediction.

3.3.1 What Are Graph Neural Networks?

Completed and Compiled Code: Click Here

Why Are Molecules Naturally Graphs?

Let’s start with the most fundamental question: Why do we say molecules are graphs?

Imagine a water molecule (H₂O). If you’ve taken chemistry, you know it looks like this:

H - O - H

This is already a graph! Let’s break it down:

  • Nodes (vertices): The atoms - one oxygen (O) and two hydrogens (H)
  • Edges (connections): The chemical bonds - two O-H bonds

Figure: Molecules as Graphs Concept. Visualization showing how molecules naturally form graph structures. Water (H₂O) and ethanol (C₂H₆O) are shown in both chemical notation and graph representation, demonstrating that atoms are nodes and bonds are edges.

Now consider a slightly more complex molecule - ethanol (drinking alcohol):

    H H
    | |
H - C-C - O - H
    | |
    H H

Again, we have a graph:

  • 9 nodes: 2 carbons, 1 oxygen, 6 hydrogens
  • 8 edges: All the chemical bonds connecting these atoms

Here’s the key insight: A molecule’s properties depend heavily on how its atoms are connected. Water dissolves salt because its bent O-H-O structure creates a polar molecule. Diamond is hard while graphite is soft - both are pure carbon, but connected differently!

The Molecular Property Prediction Challenge

Before GNNs, how did computers predict molecular properties like solubility, toxicity, or drug effectiveness? Scientists would calculate numerical “descriptors” - features like:

  • Molecular weight (sum of all atom weights)
  • Number of oxygen atoms
  • Number of rotatable bonds
  • Surface area

But this approach has a fundamental flaw. Consider these two molecules:

Figure: Traditional Descriptors vs Graph Neural Networks. Comparison between traditional molecular descriptors and Graph Neural Networks. Traditional methods lose connectivity information by converting molecules into numerical features, while GNNs preserve the full molecular structure through direct graph processing.

Molecule A:  H-O-C-C-C-C-O-H     (linear structure)
Molecule B:  H-O-C-C-O-H         (branched structure)
                 |
                 C-C

Traditional descriptors might count:

  • Both have 2 oxygens ✓
  • Both have similar molecular weights ✓
  • Both have OH groups ✓

Yet their properties could be vastly different! The traditional approach loses the connectivity information - it treats molecules as “bags of atoms” rather than structured entities.
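To make this concrete, the sketch below uses RDKit to compare two real C₄H₁₀O₂ isomers that roughly correspond to the drawings above (1,4-butanediol for the linear molecule and 2,3-butanediol for the branched one). Their molecular weight and oxygen count are identical, even though the atoms are connected differently.

from rdkit import Chem
from rdkit.Chem import Descriptors

# Two C4H10O2 isomers: same atoms, different connectivity
linear = Chem.MolFromSmiles("OCCCCO")        # 1,4-butanediol (linear)
branched = Chem.MolFromSmiles("CC(O)C(C)O")  # 2,3-butanediol (branched)

for name, mol in [("linear", linear), ("branched", branched)]:
    n_oxygens = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() == "O")
    print(name,
          "MW =", round(Descriptors.MolWt(mol), 2),
          "| oxygens =", n_oxygens)

# Both molecules print the same molecular weight and oxygen count,
# yet they are connected differently — exactly the information simple descriptors miss.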

Enter Graph Neural Networks

GNNs solve this problem elegantly. They process molecules as they truly are - graphs where:

| Graph Component | Chemistry Equivalent | Example in Ethanol |
|---|---|---|
| Node | Atom | C, C, O, H, H, H, H, H, H |
| Edge | Chemical bond | C-C, C-O, C-H bonds |
| Node features | Atomic properties | Carbon has 4 bonds, Oxygen has 2 |
| Edge features | Bond properties | Single bond, double bond |
| Graph | Complete molecule | The entire ethanol structure |

How GNNs Learn from Molecular Graphs

The magic of GNNs lies in message passing - atoms “talk” to their neighbors through bonds. Let’s see how this works step by step:

Figure: Message Passing in Graph Neural Networks. Step-by-step visualization of the message passing mechanism in GNNs. The figure shows how information propagates through the molecular graph over multiple iterations, allowing each atom to understand its role within the larger molecular context.

Step 0: Initial State Each atom starts knowing only about itself:

Carbon-1: "I'm carbon with 4 bonds"
Carbon-2: "I'm carbon with 4 bonds"  
Oxygen:   "I'm oxygen with 2 bonds"

Step 1: First Message Pass Atoms share information with neighbors:

Carbon-1: "I'm carbon connected to another carbon and 3 hydrogens"
Carbon-2: "I'm carbon between another carbon and an oxygen"
Oxygen:   "I'm oxygen connected to a carbon and a hydrogen"

Step 2: Second Message Pass Information spreads further:

Carbon-1: "I'm in an ethyl group (CH3CH2-)"
Carbon-2: "I'm the connection point to an OH group"
Oxygen:   "I'm part of an alcohol (-OH) group"

After enough message passing, each atom understands its role in the entire molecular structure!

Why Molecular Property Prediction Matters

Molecular property prediction is at the heart of modern drug discovery and materials science. Consider these real-world applications:

GNN Applications in Science and Industry Real-world applications of molecular property prediction using GNNs across six domains: drug discovery, environmental science, materials design, toxicity prediction, battery research, and agriculture.

  1. Drug Discovery: Will this molecule pass through the blood-brain barrier?
  2. Environmental Science: How long will this chemical persist in water?
  3. Materials Design: What’s the melting point of this new polymer?

Traditional experiments to measure these properties are expensive and time-consuming. If we can predict properties from structure alone, we can:

  • Screen millions of virtual compounds before synthesizing any
  • Identify promising drug candidates faster
  • Avoid creating harmful compounds

Representing Molecules as Graphs: A Step-by-Step Guide

Let’s implement a simple example to see how we represent molecules as graphs in code.

We’ll walk step-by-step through a basic molecular graph construction pipeline using RDKit, a popular cheminformatics toolkit in Python. You’ll learn how to load molecules, add hydrogens, inspect atoms and bonds, and prepare graph-based inputs for further learning.

Feature Extraction Pipeline Complete pipeline for converting molecular SMILES strings into graph representations suitable for GNN processing. The workflow shows six stages: from SMILES input through RDKit molecule creation, node/edge feature extraction, to final graph object construction.

1. Load a molecule and include hydrogen atoms

To start, we need to load a molecule using RDKit. RDKit provides a function Chem.MolFromSmiles() to create a molecule object from a SMILES string (a standard text representation of molecules). However, by default, hydrogen atoms are not included explicitly in the molecule. To use GNNs effectively, we want all atoms explicitly shown, so we also call Chem.AddHs() to add them in.

Let’s break down the functions we’ll use:

  • Chem.MolFromSmiles(smiles_str): Creates an rdkit.Chem.rdchem.Mol object from a SMILES string. This object represents the molecule internally as atoms and bonds.

  • mol.GetNumAtoms(): Returns the number of atoms currently present in the molecule object (by default, RDKit does not include H atoms unless you explicitly add them).

  • Chem.AddHs(mol): Returns a new molecule object with explicit hydrogen atoms added to the input mol.

Molecule Loading Process

Code: Basic molecule to graph conversion

from rdkit import Chem
import numpy as np

# Step 1: Create a molecule object from the SMILES string for water ("O" means one oxygen atom)
water = Chem.MolFromSmiles("O")

# Count how many atoms are present (will be 1 — only the oxygen)
print(f"Number of atoms: {water.GetNumAtoms()}")  # Output: 1

# Step 2: Add explicit hydrogen atoms
water = Chem.AddHs(water)

# Count again — now we should see 3 atoms (1 O + 2 H)
print(f"Number of atoms with H: {water.GetNumAtoms()}")  # Output: 3
  1. Initial Atom Count: Initially, the molecule object only includes the oxygen atom, as hydrogen atoms are not explicitly represented by default. Therefore, GetNumAtoms() returns 1.
  2. Adding Hydrogen Atoms: After calling Chem.AddHs(water), the molecule object is updated to include explicit hydrogen atoms. This is essential for a complete representation of the molecule.
  3. Final Atom Count: The final count of atoms is 3, which includes one oxygen atom and two hydrogen atoms. This accurately reflects the molecular structure of water (H₂O).

By explicitly adding hydrogen atoms, we ensure that the molecular graph representation is comprehensive and suitable for further processing in GNNs.

2. Access the bond structure (graph edges)

Once we have the molecule, we want to know which atoms are connected—this is the basis for constructing a graph. RDKit stores this as a list of Bond objects, which we can retrieve using mol.GetBonds().

Let’s break down the functions used here:

  • mol.GetBonds(): Returns a list of bond objects in the molecule. Each bond connects two atoms.

  • bond.GetBeginAtomIdx() and bond.GetEndAtomIdx(): These return the indices (integers) of the two atoms that are connected by the bond.

  • mol.GetAtomWithIdx(idx).GetSymbol(): This retrieves the chemical symbol (e.g. “H”, “O”) of the atom at a given index.

Bond Structure Extraction Extracting bond connectivity from RDKit molecule object. Each bond connects two atoms identified by their indices, forming the edges of our molecular graph.

Code: Extracting graph connectivity

# Print all bonds in the molecule in the form: Atom(index) -- Atom(index)
print("Water molecule connections:")
for bond in water.GetBonds():
    atom1_idx = bond.GetBeginAtomIdx()  # e.g., 0
    atom2_idx = bond.GetEndAtomIdx()    # e.g., 1
    atom1 = water.GetAtomWithIdx(atom1_idx).GetSymbol()  # e.g., "O"
    atom2 = water.GetAtomWithIdx(atom2_idx).GetSymbol()  # e.g., "H"
    print(f"  {atom1}({atom1_idx}) -- {atom2}({atom2_idx})")

# Output:
# Water molecule connections:
#   O(0) -- H(1)
#   O(0) -- H(2)
  1. Bond Retrieval: The mol.GetBonds() function returns a list of bond objects in the molecule. Each bond object represents a connection between two atoms.
  2. Atom Indices: For each bond, bond.GetBeginAtomIdx() and bond.GetEndAtomIdx() return the indices of the two atoms connected by the bond. These indices correspond to the positions of the atoms in the molecule object.
  3. Atom Symbols: The mol.GetAtomWithIdx(idx).GetSymbol() function retrieves the chemical symbol (e.g., “H” for hydrogen, “O” for oxygen) of the atom at a given index. This helps in identifying the types of atoms involved in each bond.
  4. Connectivity Representation: The output shows the connectivity of the water molecule as:
    • O(0) -- H(1)
    • O(0) -- H(2)

This indicates that the oxygen atom (index 0) is bonded to two hydrogen atoms (indices 1 and 2). This connectivity information is crucial for constructing the graph representation of the molecule, where atoms are nodes and bonds are edges.

3. Extract simple atom-level features

Each atom will become a node in our graph, and we often associate it with a feature vector. To keep things simple, we start with just the atomic number.

Here’s what each function does:

  • atom.GetAtomicNum(): Returns the atomic number (integer) for the element, e.g., 1 for hydrogen, 8 for oxygen.

  • mol.GetAtoms(): Returns a generator over all Atom objects in the molecule.

  • atom.GetSymbol(): Returns the chemical symbol (“H”, “O”, etc.), useful for printing/debugging.

Feature Vector Visualization Atom-to-feature mapping. The atomic number provides a simple yet effective initial representation for each node in the molecular graph.

Code: Atom feature extraction

# For each atom, we print its atomic number
def get_atom_features(atom):
    # Atomic number is a simple feature used in many models
    return [atom.GetAtomicNum()]

# Apply to all atoms in the molecule
for i, atom in enumerate(water.GetAtoms()):
    features = get_atom_features(atom)
    symbol = atom.GetSymbol()
    print(f"Atom {i} ({symbol}): features = {features}")

# Output
# Atom 0 (O): features = [8]
# Atom 1 (H): features = [1]
# Atom 2 (H): features = [1]
  1. Feature Extraction Function:
    • The get_atom_features(atom) function extracts the atomic number of each atom using atom.GetAtomicNum(). This is a simple yet powerful feature for distinguishing between different elements.
    • The atomic number is a unique identifier for each element: 1 for hydrogen (H) and 8 for oxygen (O).
  2. Iterating Over Atoms:
    • The mol.GetAtoms() function returns a generator that iterates over all Atom objects in the molecule.
    • For each atom, we retrieve its atomic number and store it as a feature vector (a list containing a single element).
  3. Output Explanation:
    • The output lists each atom in the molecule along with its atomic number:
      • Atom 0 (O): features = [8]: The oxygen atom (index 0) has an atomic number of 8.
      • Atom 1 (H): features = [1]: The first hydrogen atom (index 1) has an atomic number of 1.
      • Atom 2 (H): features = [1]: The second hydrogen atom (index 2) also has an atomic number of 1.
  4. Significance:
    • These atomic numbers serve as the initial node features for the molecular graph. In more advanced models, additional features (e.g., degree, hybridization, electronegativity) can be included to capture more complex chemical properties.
    • By representing each atom with its atomic number, we provide a basic yet meaningful input for graph neural networks to learn from the structural and chemical properties of the molecule.

In summary, this code demonstrates how to extract simple yet essential features from each atom in a molecule, laying the foundation for constructing informative node attributes in molecular graphs.

4. Build the undirected edge list

Now we extract the list of bonds as pairs of indices. Since GNNs typically use undirected graphs, we store each bond in both directions (i → j and j → i).

Functions involved:

  • bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() (as above)
  • We simply collect [i, j] and [j, i] into a list of edges.

Edge List Construction Building an undirected edge list from molecular bonds. Each bond generates two directed edges (i→j and j→i) to ensure bidirectional message passing in the GNN. The complete edge list for the water molecule: bidirectional edges enable information flow in both directions during GNN message passing.

Code: Edge extraction

def get_edge_list(mol):
    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])  # undirected graph: both directions
    return edges

# Run on water molecule
water_edges = get_edge_list(water)
print("Water edges:", water_edges)

# Output
# Water edges: [[0, 1], [1, 0], [0, 2], [2, 0]]
  1. Function Definition:
    • get_edge_list(mol): This function takes an RDKit molecule object (mol) as input and returns a list of edges representing the connectivity between atoms.
  2. Edge Extraction:
    • mol.GetBonds(): This method retrieves a list of bond objects from the molecule. Each bond object represents a connection between two atoms.
    • For each bond, bond.GetBeginAtomIdx() and bond.GetEndAtomIdx() are used to get the indices of the two atoms connected by the bond. These indices are integers that uniquely identify atoms within the molecule.
    • The edge list is constructed by appending both [i, j] and [j, i] to the edges list. This ensures that the graph is undirected, which is essential for GNNs. In an undirected graph, the relationship between nodes is bidirectional, meaning that if atom i is connected to atom j, then atom j is also connected to atom i.
  3. Output:
    • The function returns the complete edge list, which includes all pairs of connected atoms in both directions.
    • For the water molecule (H₂O), the output is:
      Water edges: [[0, 1], [1, 0], [0, 2], [2, 0]]
      
      • [0, 1] and [1, 0]: These pairs represent the bond between the oxygen atom (index 0) and the first hydrogen atom (index 1).
      • [0, 2] and [2, 0]: These pairs represent the bond between the oxygen atom (index 0) and the second hydrogen atom (index 2).

Each pair represents one connection (bond) between atoms. Including both directions ensures that during message passing, information can flow freely from each node to all its neighbors.
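
Putting the four steps together, here is a minimal sketch (reusing the water molecule and the get_atom_features / get_edge_list helpers defined above) that assembles the node feature matrix and edge index as plain NumPy arrays; frameworks such as PyTorch Geometric wrap the same information in their own graph objects, as we will see later:

import numpy as np

# Node features: one row per atom (here just the atomic number)
node_features = np.array([get_atom_features(atom) for atom in water.GetAtoms()])

# Edge index: transpose the edge list into a [2, num_directed_edges] array
edge_index = np.array(get_edge_list(water)).T

print("Node features:\n", node_features)   # [[8], [1], [1]]
print("Edge index:\n", edge_index)         # [[0 1 0 2]
                                           #  [1 0 2 0]]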

RDKit Molecular Structures Using RDKit, we can get the chemical structures and corresponding graph statistics for common molecules (water, ethanol, benzene, and aspirin). Each molecule is shown with its 2D structure alongside graph metrics including node count, edge count, and atom type distribution.

Summary: The Power of Molecular Graphs

Let’s recap what we’ve learned:

  1. Molecules are naturally graphs - atoms are nodes, bonds are edges
  2. Traditional methods lose structural information - they treat molecules as bags of features
  3. GNNs preserve molecular structure - they process the actual connectivity
  4. Message passing allows context learning - atoms learn from their chemical environment
  5. Property prediction becomes structure learning - the model learns which structural patterns lead to which properties

In the next section, we’ll dive deep into how message passing actually works, building our understanding step by step until we can implement a full molecular property predictor.

3.3.2 Message Passing and Graph Convolutions

Completed and Compiled Code: Click Here

At the core of a Graph Neural Network (GNN) is the idea of message passing. The goal is to simulate an important phenomenon in chemistry: how electronic effects propagate through molecular structures via chemical bonds. This is something that happens in real molecules, and GNNs try to mimic it through mathematical and computational means.

Let’s first look at a chemistry example. When a fluorine atom is added to a molecule, its high electronegativity doesn’t just affect the atom it is directly bonded to. It causes that carbon atom to become slightly positive, which in turn affects its other bonds, and so on. The effect ripples outward through the structure.

This is exactly the kind of structural propagation that message passing in GNNs is designed to model.

Chemical Effects Propagation Comparison of electronic effects propagation in real molecules (left) versus GNN simulation (right). The fluorine atom’s electronegativity creates a ripple effect through the carbon chain, which GNNs capture through iterative message passing.

The structure of message passing: what happens at each GNN layer?

Even though the idea sounds intuitive, we need a well-defined set of mathematical steps for the computer to execute. In a GNN, each layer usually follows three standard steps.

Message Passing Three Steps The three standard steps of message passing in GNNs: (1) Message Construction - neighbors create messages based on their features and edge properties, (2) Message Aggregation - all incoming messages are combined using sum, mean, or attention, (3) State Update - nodes combine their current state with aggregated messages to produce new representations.

Step 1: Message Construction

For every node $i$, we consider all its neighbors $j$ and create a message $m_{ij}$ to describe what information node $j$ wants to send to node $i$.

This message often includes:

  • Information about node $j$ itself
  • Information about the bond between $i$ and $j$ (e.g., single, double, aromatic)

Importantly, we don’t just pass raw features. Instead, we use learnable functions (like neural networks) to transform the input into something more meaningful for the task.

Step 2: Message Aggregation

Once node $i$ receives messages from all neighbors, it aggregates them into a single combined message $m_i$.

The simplest aggregation method is to sum all incoming messages:

\[m_i = \sum_{j \in N(i)} m_{ij}\]

Here, $N(i)$ is the set of all neighbors of node $i$. This step is like saying: “I listen to all my neighbors and combine what they told me.”

However, in real chemistry, not all neighbors are equally important:

  • A double bond may influence differently than a single bond
  • An oxygen atom might carry more weight than a hydrogen atom


That’s why advanced GNNs often use weighted aggregation or attention mechanisms to adjust how each neighbor contributes.

Aggregation Functions Different aggregation functions in GNNs. Sum preserves total signal strength, Mean normalizes by node degree, Max captures the strongest signal, and Attention weights messages by learned importance scores.
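
As a quick illustration (toy numbers only), here is how sum, mean, and max aggregation behave on the same set of incoming messages:

import numpy as np

# Three toy 2-dimensional messages arriving at one node from its neighbors
messages = np.array([[0.2,  1.0],
                     [0.4, -0.5],
                     [0.6,  0.5]])

print("sum: ", messages.sum(axis=0))   # preserves total signal strength
print("mean:", messages.mean(axis=0))  # normalizes by the number of neighbors
print("max: ", messages.max(axis=0))   # keeps the strongest signal per dimension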

Step 3: State Update

Finally, node $i$ uses two inputs:

  • Its current feature vector $h_i^{(t)}$
  • The aggregated message $m_i$

These are combined to produce an updated node representation for the next layer:

\[h_i^{(t+1)} = \text{Update}(h_i^{(t)}, m_i)\]

This update is usually implemented with a small neural network, such as a multilayer perceptron (MLP). It learns how to combine a node’s old information with the new input from its neighbors to produce something more useful.
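
One common choice (a sketch, not the only option) is to concatenate the current state with the aggregated message and feed the result through a small MLP:

import torch
import torch.nn as nn

hidden_dim = 8
update_mlp = nn.Sequential(          # a tiny Update() function
    nn.Linear(2 * hidden_dim, hidden_dim),
    nn.ReLU(),
)

h_i = torch.randn(1, hidden_dim)     # current node state h_i^(t)
m_i = torch.randn(1, hidden_dim)     # aggregated message m_i
h_next = update_mlp(torch.cat([h_i, m_i], dim=-1))
print(h_next.shape)                  # torch.Size([1, 8]), the new h_i^(t+1)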

In summary, at each GNN layer, every atom (node) listens to its neighbors and updates its understanding of the molecule. After several layers of message passing, each node’s embedding captures not just its local features, but also the broader context of the molecular structure.

Graph Convolutions: Making It Concrete

The term “graph convolution” comes from analogy with Convolutional Neural Networks (CNNs) in computer vision. In CNNs, filters slide over local neighborhoods of pixels. In GNNs, we also aggregate information from “neighbors”, but now neighbors are defined by molecular or structural connectivity, not spatial proximity.

In Graph Convolutional Networks (GCNs), message passing at each layer is defined by the following formula:

\[h_i^{(t+1)} = \sigma \left( \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{d_i d_j}} W h_j^{(t)} \right)\]
  • $h_i^{(t)}$: Feature of node $i$ at layer $t$
  • $W$: Learnable weight matrix
  • $d_i$: Degree (number of neighbors) of node $i$
  • $\sigma$: Activation function (e.g. ReLU)

This formula averages and transforms neighbor features while normalizing based on node degrees.

GCN Formula Breakdown

According to the formula, GCN message passing follows four steps at each layer: add self-loops (each node includes itself in $\mathcal{N}(i) \cup \{i\}$), transform the neighbor features with the learnable weight matrix $W$, normalize each contribution by $\frac{1}{\sqrt{d_i d_j}}$ and sum, then apply the activation $\sigma$.
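
Below is a minimal NumPy sketch of these four steps on a toy 3-node graph (random weights, purely illustrative); the PyG layer introduced next does the same thing with learnable parameters:

import numpy as np

# Toy graph: node 0 bonded to nodes 1 and 2 (like the O in water)
A = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])
H = np.eye(3)                          # toy node features
W = np.random.randn(3, 2) * 0.1        # toy weight matrix (3 -> 2)

A_hat = A + np.eye(3)                              # step 1: add self-loops
D_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(1)))  # node degrees for normalization
H_msg = D_inv_sqrt @ A_hat @ D_inv_sqrt @ (H @ W)  # steps 2-3: transform, normalize, sum
H_next = np.maximum(H_msg, 0.0)                    # step 4: ReLU activation
print(H_next.shape)                                # (3, 2)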

To execute this in Python, we use PyTorch Geometric (PyG), where the most basic GNN layer is GCNConv. Let’s go through each part of the code.

PyTorch Geometric Components

Component                            Purpose
torch.tensor(...)                    Creates dense tensors (like NumPy arrays) for node features or edge indices
x                                    Node feature matrix; shape = [num_nodes, num_node_features]
edge_index                           Edge list in COO format, shape [2, num_edges]; first row: source nodes, second row: target nodes
torch_geometric.data.Data            Creates a graph object holding x, edge_index, and optionally edge/node labels
GCNConv(in_channels, out_channels)   A GCN layer that performs message passing + aggregation + update
conv(x, edge_index)                  Applies one layer of graph convolution and returns updated node features

Each part of our code works according to this Flowchart:

Flowchart: node features (x) + edge index → Data object → GCNConv layer → updated node embeddings

Algorithmic Idea

The flowchart above shows how raw atomic features are transformed by a single GCN (Graph Convolutional Network) layer into learned, structure-aware embeddings. Below, we unpack this transformation step-by-step, both conceptually and computationally.

1. Node Features Each atom is represented by a 3-dimensional feature vector that encodes its identity (a one-hot encoding in the simplest case). For example:

  • [1, 0, 0] could represent carbon (C)
  • [0, 1, 0] oxygen (O)
  • [0, 0, 1] nitrogen (N)


In our toy molecule with 4 atoms, the input node feature matrix has shape [4, 3]: 4 atoms, each with 3 features.

import torch

x = torch.tensor([
    [1, 0, 0],  # Atom 0: C
    [0, 1, 0],  # Atom 1: O
    [1, 1, 0],  # Atom 2: hybrid or multi-type
    [0, 0, 1]   # Atom 3: N
], dtype=torch.float)

This matrix x is the starting point: a basic, structure-free description of the molecule.

2. Edge Index To describe how atoms are bonded, we define the connectivity of the molecular graph using a list of directed edges. Each chemical bond is entered twice to support bidirectional message passing.

The resulting tensor edge_index has shape [2, E], where E is the total number of directed edges. In our case, E = 8.

edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 0],  # Source atoms
    [1, 0, 2, 1, 3, 2, 0, 3]   # Target atoms
], dtype=torch.long)

Here, for example, the edge (0, 1) and (1, 0) encode a bond between atoms 0 and 1.


3. Graph Data PyTorch Geometric uses a unified object to represent graph-structured data. The node features x and the connectivity edge_index are bundled into a Data object:

from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index)

At this point:

  • data.x has shape [4, 3]
  • data.edge_index has shape [2, 8]

This fully defines a small undirected graph of 4 atoms.


4. GCN Layer We now define a graph convolutional layer that will transform each 3-dimensional atom vector into a 2-dimensional learned embedding.

This is not just a dimensionality reduction step. Rather, it’s a learned transformation:

  • It combines each atom’s own features with information from its bonded neighbors
  • It applies a shared learnable weight matrix $W \in \mathbb{R}^{3 \times 2}$
  • It normalizes contributions based on node degrees
  • It applies a non-linear activation function (e.g. ReLU)


This transformation implements the formula:

\[h_i^{(t+1)} = \sigma \left( \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{d_i d_j}} \, W h_j^{(t)} \right)\]

Where:

  • $h_j^{(t)}$ is the 3-dimensional feature vector of neighbor $j$
  • $W$ maps from 3 → 2 dimensions
  • $\sigma$ is a nonlinearity (e.g. ReLU)
  • The normalization factor $\frac{1}{\sqrt{d_i d_j}}$ accounts for node degrees

The layer is defined as:

from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels=3, out_channels=2)

This means: input is a [4, 3] matrix, and output will be a [4, 2] matrix. Each atom now gets a 2-dimensional embedding.

Why go from 3 to 2 dimensions? We are not blindly compressing the input — rather, we are learning a more compact, expressive representation that fuses both identity and structure. The dimensionality is a design choice: you could use 2, 8, 128… depending on downstream task complexity. In this toy case, 2 is used for visualization and simplicity.

5. Forward Pass We now execute the forward pass of the GCN. Internally, the layer:

  • For each atom, gathers features from its neighbors and itself
  • Applies the learned transformation and aggregates
  • Outputs a new feature vector per atom

output = conv(data.x, data.edge_index)

After this operation:

  • Input shape: [4, 3]
  • Output shape: [4, 2]
  • Each row in the output is a learned embedding vector for an atom


6. Output Features

print("Updated Node Features After Message Passing:")
print(output)

Result

Updated Node Features After Message Passing:
tensor([[ 0.2851, -0.0017],
        [ 0.6568, -0.4519],
        [ 0.6180,  0.1266],
        [ 0.2807, -0.3559]], grad_fn=<AddBackward0>)

Interpretation

Shape: The output has shape [4, 2]:

  • 4 rows → 4 atoms
  • 2 columns → each atom’s learned embedding

What does each embedding mean?

Each row (e.g. [0.2851, -0.0017]) represents an atom’s updated feature — not just its raw type (C, O, N), but also how it is situated in the molecular graph:

  • Who its neighbors are
  • What types they are
  • How strongly it’s connected (degree)

This embedding is no longer one-hot or fixed — it is learned from data, and will improve with training.

Why is this useful?

You can now:

  • Feed these embeddings into another neural network to predict molecular properties
  • Use them to classify atom roles (e.g., is this a reaction site?)
  • Visualize molecule structure in 2D/3D via t-SNE or PCA

About grad_fn=<AddBackward0>

This line tells you that the output is part of the autograd computation graph in PyTorch. That means it supports backpropagation: gradients will flow back through the GCN layer during training to update $W$.
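
As a small continuation of the snippet above (a sketch with made-up target values, reusing the output and conv objects already defined), we can trigger backpropagation and confirm that gradients reach the layer's weight matrix $W$, stored as conv.lin.weight in PyG:

# Pretend every embedding should be zero -- dummy targets, illustration only
target = torch.zeros_like(output)
loss = ((output - target) ** 2).mean()   # simple MSE on the node embeddings

loss.backward()                          # gradients flow back through the GCN layer
print(conv.lin.weight.grad.shape)        # torch.Size([2, 3]) -- one gradient per entry of W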

Variants of Graph Convolutions

Different GNN models define the message passing process differently:

Graph Convolutional Networks (GCNs) Use simple averaging with normalization. Very stable and interpretable. Good for small graphs with clean structure.

GraphSAGE Introduces neighbor sampling, which makes it scalable to large graphs. You can also choose the aggregation function (mean, max, LSTM, etc.).

Graph Attention Networks (GATs) Use attention to assign different weights to different neighbors. This is very helpful in chemistry, where some bonds are more important (e.g. polar bonds).

Message Passing Neural Networks (MPNNs) A general and expressive framework. Can use edge features, which is important in molecules (e.g. bond type, aromaticity). Many SOTA chemistry models (e.g., D-MPNN) are built on this.
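
To see how interchangeable these variants are in practice, here is a short sketch (reusing the toy data object from the PyG example above) that applies GCN, GraphSAGE, and GAT layers to the same graph; MPNN-style layers such as PyG's NNConv additionally require an edge network, so they are omitted here for brevity:

from torch_geometric.nn import GCNConv, SAGEConv, GATConv

gcn  = GCNConv(3, 2)                  # normalized averaging over neighbors
sage = SAGEConv(3, 2, aggr="mean")    # choose the aggregation function
gat  = GATConv(3, 2, heads=1)         # attention-weighted neighbors

for layer in (gcn, sage, gat):
    out = layer(data.x, data.edge_index)
    print(type(layer).__name__, tuple(out.shape))   # each yields a (4, 2) embedding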

GNN Variants Figure 3.3.10: Comparison of different GNN architectures. GCN uses simple normalized averaging, GraphSAGE samples neighbors for scalability, GAT employs attention mechanisms for weighted aggregation, and MPNN provides a general framework incorporating edge features.

Chemical Intuition Behind Message Passing

To understand how message passing in graph neural networks actually captures chemical effects, let’s walk through a concrete example: the molecule para-nitrophenol, which features two chemically distinct groups — a nitro group (NO₂) and a hydroxyl group (OH) — placed at opposite ends of a benzene ring.

Chemically speaking, this setup forms a classic “push-pull” system: the nitro group is strongly electron-withdrawing, while the hydroxyl group is electron-donating. This dynamic tension in electron distribution plays a key role in determining the molecule’s acidity, reactivity, and overall behavior. The power of message passing lies in its ability to gradually capture this electron flow, layer by layer.

GNN Layer Expansion Figure 3.3.11: Receptive field expansion in GNNs. Each layer increases a node’s awareness by one hop. Starting from self-awareness (Layer 0), nodes progressively integrate information from 1-hop neighbors, 2-hop neighbors, and eventually the entire molecular graph.

GNN Message Passing: Neighborhood Expansion per Layer

  • Layer 0: Self (Each atom only knows its own features)
    • e.g., O knows it’s oxygen; N knows it’s nitrogen
  • Layer 1: 1-hop neighbors (Directly bonded atoms)
    • O learns about the carbon it’s attached to
    • N in NO₂ learns about its adjacent O atoms
  • Layer 2: 2-hop neighbors (Neighbors of neighbors)
    • O learns about atoms bonded to its neighboring carbon
    • Benzene carbons begin to capture influence from NO₂ and OH
  • Layer 3: 3-hop neighborhood (Extended molecular context)
    • Carbons across the ring begin to “feel” opposing substituent effects
    • Push-pull interactions emerge in representation
  • Layer 4+: Global context (Full molecule representation)
    • Every atom integrates information from the entire molecule
    • Final features encode global electronic and structural effects

Summary

In this section, we explored message passing and graph convolutions, the fundamental mechanisms that enable Graph Neural Networks to learn from molecular structures. The key insight is that GNNs mimic how electronic effects propagate through chemical bonds in real molecules.

The message passing framework follows three standard steps at each layer:

  1. Message Construction: Nodes create messages for their neighbors using learnable functions
  2. Message Aggregation: Each node combines incoming messages (via sum, attention, etc.)
  3. State Update: Nodes update their representations by combining current features with aggregated messages

Through the lens of para-nitrophenol, we saw how this iterative process gradually expands each atom’s “awareness” from local to global context. Starting with only self-knowledge, atoms progressively integrate information from direct neighbors, then second-hop neighbors, until eventually capturing full molecular context including competing electronic effects.

Different GNN architectures (GCN, GraphSAGE, GAT, MPNN) offer various approaches to this process, each with distinct advantages for chemical applications. The choice depends on factors like molecular size, importance of edge features, and computational constraints.

The power of message passing lies in its ability to bridge structure and function. By allowing atoms to “communicate” through bonds, GNNs learn representations that encode not just molecular topology, but also the chemical behaviors that emerge from that structure — acidity, reactivity, electronic distribution, and more. This makes GNNs particularly well-suited for molecular property prediction and drug discovery tasks where understanding chemical context is crucial.

3.3.3 GNNs for Molecular Property Prediction

Completed and Compiled Code: Click Here

Experimental Pipeline Overview

What We’re Building: A Molecular Solubility Predictor

Why Predict Molecular Solubility?

  • Drug Discovery: Poor solubility is one of the leading reasons drug candidates fail – we need to predict it early!
  • Cost Savings: Lab testing costs $1000+/molecule; our model predicts in milliseconds
  • Real Impact: Better solubility = better drug absorption = more effective medicines


What you’ll learn: How to turn molecules into graphs and use AI to predict their properties

Stage                  Input                   Process                     Output
1. Data Loading        ESOL CSV file           pandas parsing              SMILES + log S values
2. Molecular Encoding  SMILES strings (text)   RDKit → graph conversion    Node features + edge indices
3. Model Building      Graph structures        3-layer GCN                 Molecular embeddings
4. Training            Batched graphs          Adam optimizer + MSE loss   Trained parameters
5. Prediction          New SMILES              Forward pass                Solubility (log S)


Key Metrics We’ll Track:

  • RMSE: Root Mean Squared Error (expect ~1.9 log S)
  • R²: Variance explained (expect ~0.22)
  • MAE: Mean Absolute Error (expect ~1.6 log S)
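
For reference, here is a small sketch (with made-up numbers, not results from our model) showing how these three metrics are computed with scikit-learn:

import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Illustrative true vs. predicted log S values
y_true = np.array([-2.0, -3.5, 0.5, -7.0])
y_pred = np.array([-2.4, -3.0, 0.1, -5.5])

rmse = np.sqrt(mean_squared_error(y_true, y_pred))   # root of the average squared error
mae  = mean_absolute_error(y_true, y_pred)           # average absolute error
r2   = r2_score(y_true, y_pred)                      # fraction of variance explained
print(f"RMSE = {rmse:.2f}, MAE = {mae:.2f}, R² = {r2:.2f}")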

Key Terms Explained

Term           Simple Explanation                                       Why It Matters
SMILES         Text code for molecules (like “H₂O” but more detailed)   Computers can’t see molecules, so we use text
Node Features  Properties of each atom (e.g. element type, charge)      AI needs numbers to work with
Edge Indices   Which atoms are connected by bonds                       Structure determines properties
GCN            Graph Convolutional Network – AI for graph data          Molecules are naturally graphs!


Step 1: Understanding Molecular Solubility as a Graph Learning Problem

The Chemistry Behind Solubility

Why is solubility prediction hard? Solubility emerges from a delicate balance of intermolecular forces:

\[\mathrm{Solubility} \;\propto\; \frac{\text{Solute–Solvent interactions}}{\text{Solute–Solute interactions}}\]

What This Formula Really Means:

  • Numerator (top): How well the molecule “likes” water
  • Denominator (bottom): How much the molecule “likes” itself
  • Result: If molecules prefer water over themselves → high solubility!

Example: Sugar dissolves because it forms hydrogen bonds with water better than with other sugar molecules


Key factors:

  • Hydrogen bonding: –OH, –NH groups increase water solubility
  • Hydrophobic effect: Long carbon chains decrease solubility
  • Molecular size: Larger molecules → harder to solvate
  • Aromaticity: π-systems are hydrophobic

Why Graph Neural Networks?

Traditional machine learning uses fixed-size molecular fingerprints, losing structural information. GNNs preserve the full molecular graph:

Molecule as Graph:

  • Nodes: Atoms with features (element, charge, aromaticity)
  • Edges: Chemical bonds (single, double, triple, aromatic)
  • Message Passing: Atoms “communicate” through bonds


Intuition:

  • Just like social networks!

    • People (atoms) have properties (age, interests)
    • Friendships (bonds) connect people
    • Information spreads through connections
    • Your friends influence your behavior!

Step 1.1: Package Imports and Setup

We need several specialized libraries to handle different aspects of the pipeline:

Step  Library            Purpose               Why We Need It
1     PyTorch            Deep Learning Core    Handles neural networks and gradients
2     PyTorch Geometric  Graph Operations      Special tools for graph neural networks
3     RDKit              Chemistry Processing  Understands molecules and bonds
4     NumPy / Pandas     Data Handling         Efficient array and table operations
5     Scikit-learn       Evaluation Metrics    Measure how good our predictions are
# Deep Learning Framework
import torch
import torch.nn as nn
import torch.nn.functional as F

# Graph Neural Network Operations
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader

# Chemistry Toolkit
from rdkit import Chem

# Data Processing & Visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import requests
import io
from sklearn.metrics import mean_squared_error, r2_score


Step 1.2: Feature Extraction

Each atom’s properties influence molecular behavior. We encode 5 key atomic features:

Feature        Chemical Meaning                  Value Range      Why It Affects Solubility
Atomic Number  Element identity (C=6, N=7, O=8)  1-118            O, N form H-bonds with water
Degree         Number of bonds                   0-4 (typically)  More bonds = less flexible
Formal Charge  Electronic state                  -2 to +2         Charged = water-loving
Aromaticity    In aromatic ring?                 0 or 1           Aromatic = water-fearing
H Count        Hydrogen bonding potential        0-4              More H = more H-bonds

Implementation: The get_atom_features function extracts these properties from RDKit atom objects:

def get_atom_features(atom):
    """
    Extract numerical features from RDKit atom object.
    These features capture the chemical environment of each atom.
    
    Think of this as creating an "ID card" for each atom with 5 key facts!
    """
    return [
        atom.GetAtomicNum(),        # What element? (C=6, N=7, O=8, etc.)
        atom.GetDegree(),           # How many bonds? (connectivity)
        atom.GetFormalCharge(),     # Is it charged? (+1, 0, -1, etc.)
        int(atom.GetIsAromatic()),  # In benzene-like ring? (0=no, 1=yes)
        atom.GetTotalNumHs()        # How many hydrogens attached?
    ]
Atom feature extraction: an RDKit atom object goes in, and a list of 5 numbers comes out, e.g. [8, 2, 0, 0, 0] for the oxygen in water.

Why These 5 Features?

  • Atomic Number: Oxygen atoms love water, carbon atoms don’t
  • Degree (Connectivity): Highly connected atoms are “buried” in the molecule
  • Formal Charge: Charged molecules dissolve like salt in water
  • Aromaticity: Benzene rings are oily, not watery
  • H Count: More hydrogens = more hydrogen bonding with water


Step 1.3: Bond Connectivity Extraction

Chemical bonds are bidirectional - electrons are shared between atoms. We need to represent this bidirectionality:

Bond extraction: loop over mol.GetBonds(), read each bond’s begin/end atom indices, and add both [i, j] and [j, i] to the edge list.

def get_bond_connections(mol):
    """
    Convert molecular bonds to directed edges.
    Bidirectional to allow information flow in both directions.
    
    Why bidirectional? In real molecules, electrons flow both ways!
    """
    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.extend([[i, j], [j, i]])  # Add both directions
    return edges


Step 2: Loading and Exploring the ESOL Dataset

Step 2.1: Dataset Overview

Why ESOL? The ESOL (Estimated SOLubility) dataset is a gold standard because:

  • 1,128 molecules with experimental (not computed) solubility
  • Spans 13+ orders of magnitude ($10^{13}$-fold range!)
  • Diverse chemical space: alkanes, aromatics, heterocycles, etc.

Data Loading Implementation:

# Download ESOL dataset from DeepChem repository
url = "https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv"
response = requests.get(url)
data = pd.read_csv(io.StringIO(response.text))

# Extract molecular data
smiles_list = data['smiles'].tolist()
solubility_values = data['measured log solubility in mols per litre'].tolist()

print(f"Dataset contains {len(smiles_list)} molecules")
print(f"Solubility range: {min(solubility_values):.2f} to {max(solubility_values):.2f} log S")


Result Interpretation:

Dataset contains 1128 molecules
Solubility range: -11.60 to 1.58 log S

What do these numbers mean?

  • log S = –11.60 → Solubility = $10^{-11.60}$ mol/L (extremely insoluble)
  • log S = 1.58 → Solubility = $10^{1.58}$ mol/L (very soluble)
  • Range: 13.18 log units = $10^{13.18}$ ≈ 15 trillion-fold difference!

In Real-World Context:

  • Very soluble (log S > 0): Like sugar in water – dissolves easily
  • Moderately soluble (–3 < log S < 0): Like alcohol – mixes well
  • Poorly soluble (log S < –3): Like oil – forms a separate layer
  • Practically insoluble (log S < –6): Like plastic – never dissolves

Step 2.2: Examining Example Molecules

Let’s examine some specific molecules to understand the dataset diversity:

print("\nExample molecules from the dataset:")
print("-" * 60)
print(f"{'SMILES':<40} {'Solubility (log S)':<20}")
print("-" * 60)

for i in range(5):
    print(f"{smiles_list[i]:<40} {solubility_values[i]:<20.2f}")


Result Interpretation:

Example molecules from the dataset:
------------------------------------------------------------
SMILES                                   Solubility (log S)  
------------------------------------------------------------
OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O  -0.77               
Cc1occc1C(=O)Nc2ccccc2                   -3.30               
CC(C)=CCCC(C)=CC(=O)                     -2.06               
c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43       -7.87               
c1ccsc1                                  -1.33               

Chemical Interpretation:

  • Row 1: Complex sugar derivative with multiple –OH groups → relatively soluble (log S = –0.77)
  • Row 4: Large polycyclic aromatic hydrocarbon → very insoluble (log S = –7.87)
  • Row 5: Small thiophene heterocycle → moderate solubility (log S = –1.33)

The dataset covers a wide range of molecular complexity and functional groups.

Step 2.3: Visualizing Solubility Distribution

Understanding the distribution helps us assess potential modeling challenges:

plt.figure(figsize=(10, 6))
plt.hist(solubility_values, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Log Solubility (log S)', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.title('Distribution of Solubility Values in ESOL Dataset', fontsize=14)
plt.grid(True, alpha=0.3)
plt.show()

Distribution of Solubility Values in ESOL Dataset

Result Interpretation:

  • Central tendency: Distribution peaks around -2 to -3 log S
  • Skewness: Left-skewed distribution (more insoluble molecules)
  • Chemical reality: Most organic molecules have limited water solubility
  • Modeling implication: Model may perform better on common solubility ranges than extremes

Step 3: Converting Molecules to Graph Representations

Step 3.1: Testing Feature Extraction on Water

We validate our feature extraction by testing on water (H₂O), the simplest molecule:

# Parse water molecule
water_smiles = "O"
water = Chem.MolFromSmiles(water_smiles)
water = Chem.AddHs(water)  # Add explicit hydrogens

print("Water molecule (H2O) atom features:")
print("-" * 50)
for i, atom in enumerate(water.GetAtoms()):
    features = get_atom_features(atom)
    symbol = atom.GetSymbol()
    print(f"Atom {i} ({symbol}): {features}")


Result Interpretation:

Water molecule (H2O) atom features:
--------------------------------------------------
Atom 0 (O): [8, 2, 0, 0, 0]
Atom 1 (H): [1, 1, 0, 0, 0]
Atom 2 (H): [1, 1, 0, 0, 0]

Feature Vector Breakdown:

Oxygen: [8, 2, 0, 0, 0]

  • 8 = Atomic number (element oxygen)
  • 2 = Degree (bonded to 2 hydrogen atoms)
  • 0 = No formal charge (neutral)
  • 0 = Not aromatic (water is not aromatic)
  • 0 = No implicit hydrogens (all are explicit)

Hydrogen: [1, 1, 0, 0, 0]

  • 1 = Atomic number (element hydrogen)
  • 1 = Degree (bonded to 1 oxygen atom)
  • Remaining features are all zero

Step 3.2: Testing Bond Extraction on Ethanol

Test on a more complex molecule to see bond connectivity:

Ethanol (CCO, CH₃-CH₂-OH) should give 9 atoms (2 C + 1 O + 6 H), 8 undirected bonds, and therefore 16 directed edges.

# Test on ethanol
ethanol_smiles = "CCO"
ethanol = Chem.MolFromSmiles(ethanol_smiles)
ethanol = Chem.AddHs(ethanol)

connections = get_bond_connections(ethanol)

print(f"Ethanol molecule (C2H6O):")
print(f"  Number of atoms: {ethanol.GetNumAtoms()}")
print(f"  Number of bonds: {len(connections)//2}")
print(f"  Number of directed edges: {len(connections)}")


Result:

Ethanol molecule (C2H6O):
  Number of atoms: 9
  Number of bonds: 8
  Number of directed edges: 16

Step 3.3: Detailed Edge Analysis

Why bidirectional? Electrons are shared in both directions and message passing needs two-way flow, so a single C-C bond becomes two directed edges (C→C and C←C).

print("\nFirst few connections (atom index pairs):")
for i, (src, dst) in enumerate(connections[:6]):
    src_symbol = ethanol.GetAtomWithIdx(src).GetSymbol()
    dst_symbol = ethanol.GetAtomWithIdx(dst).GetSymbol()
    print(f"  Edge {i}: {src}({src_symbol}) → {dst}({dst_symbol})")


Result:

First few connections (atom index pairs):
  Edge 0: 0(C) → 1(C)
  Edge 1: 1(C) → 0(C)
  Edge 2: 1(C) → 2(O)
  Edge 3: 2(O) → 1(C)
  Edge 4: 0(C) → 3(H)
  Edge 5: 3(H) → 0(C)

Chemical Interpretation: Each bond appears twice (e.g., C→C and C←C) to enable bidirectional message passing in the GNN. This reflects the quantum mechanical reality that electrons are shared between atoms.


Step 3.4: Complete Molecule-to-Graph Conversion

Implementation Strategy:

Step  Process                         Output
1     SMILES String → RDKit Molecule  Molecule object
2     Add Hydrogens                   Complete molecule
3     Extract Features                Node feature matrix
4     Extract Bonds                   Edge index
5     Create PyG Data                 Graph object
The molecule_to_graph function takes a SMILES string (and an optional solubility value), runs parse → add H → extract features → extract bonds → package, and returns a PyG Data object, or None if parsing fails.

def molecule_to_graph(smiles, solubility=None):
    """
    Complete pipeline: SMILES → Molecular Graph
    
    Uses RDKit for chemistry, PyTorch Geometric for graph structure
    """
    # Parse SMILES
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    
    # Add hydrogens (affects H-bonding)
    mol = Chem.AddHs(mol)

Step 3.4.1: Feature Extraction Part

For each atom we call get_atom_features() (a list of 5 numbers), collect the lists, and convert them into a PyTorch tensor of shape [n_atoms, 5]:

    # Extract atom features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append(get_atom_features(atom))
    
    # Convert to tensor
    x = torch.tensor(atom_features, dtype=torch.float)

Step 3.4.2: Edge Construction Part

We get the bond connections, add a self-loop [0, 0] for the single-atom edge case, and transpose the list into PyG’s COO format of shape [2, n_edges]:

    # Extract bonds
    edge_list = get_bond_connections(mol)
    if len(edge_list) == 0:  # Single atom
        edge_list = [[0, 0]]  # Self-loop
    
    # Create edge index (transpose for PyG format)
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

Step 3.4.3: Data Object Creation

Finally, we assemble the Data object from the required fields (x, edge_index) and attach the optional target y when a label is provided:

    # Create Data object
    data = Data(x=x, edge_index=edge_index)
    
    # Add label if provided
    if solubility is not None:
        data.y = torch.tensor([solubility], dtype=torch.float)
    
    return data


Step 3.5: Testing the Conversion Pipeline

Test on molecules of varying complexity:

We test three molecules of increasing complexity: water (O, 3 atoms / 2 bonds), ethanol (CCO, 9 atoms / 8 bonds), and benzene (c1ccccc1, 12 atoms / 12 bonds).

test_molecules = [
    ("O", "Water"),
    ("CCO", "Ethanol"),
    ("c1ccccc1", "Benzene")
]

print("Testing molecule to graph conversion:")
print("-" * 60)

for smiles, name in test_molecules:
    graph = molecule_to_graph(smiles, solubility=0.0)
    if graph:
        print(f"{name} ({smiles}):")
        print(f"  Atoms: {graph.x.shape[0]}")
        print(f"  Features per atom: {graph.x.shape[1]}")
        print(f"  Bonds: {graph.edge_index.shape[1] // 2}")
        print(f"  Graph object: {graph}")
        print()


Result Interpretation:

Testing molecule to graph conversion:
------------------------------------------------------------
Water (O):
  Atoms: 3
  Features per atom: 5
  Bonds: 2
  Graph object: Data(x=[3, 5], edge_index=[2, 4], y=[1])

Ethanol (CCO):
  Atoms: 9
  Features per atom: 5
  Bonds: 8
  Graph object: Data(x=[9, 5], edge_index=[2, 16], y=[1])

Benzene (c1ccccc1):
  Atoms: 12
  Features per atom: 5
  Bonds: 12
  Graph object: Data(x=[12, 5], edge_index=[2, 24], y=[1])

How to interpret “Data(x=[12, 5], edge_index=[2, 24], y=[1])” ?

  • x: Node feature matrix [num_atoms, num_features]
  • edge_index: COO-format edges [2, num_edges]
  • y: Target property (solubility)

Example – Benzene:

  • 6 carbon atoms + 6 hydrogen atoms = 12 nodes
  • 6 C–C bonds + 6 C–H bonds = 12 undirected bonds
  • 12 undirected bonds × 2 directions = 24 directed edges


What is COO Format? COO (COOrdinate) format stores edges as pairs of node indices:

  • First row: source nodes [0, 1, 1, 2, …]
  • Second row: target nodes [1, 0, 2, 1, …]
  • Each column is one edge: (0→1), (1→0), (1→2), (2→1), …

Step 4: Building the Graph Neural Network Architecture

Step 4.1: Review the GNN Design Principles

Message Passing Framework: As we have seen before, each GCN layer performs the following operation:

\[h_i^{(l+1)} = \sigma\bigl(W^{(l)} \cdot \mathrm{AGG}(\{h_j^{(l)} : j \in N(i) \cup \{i\}\})\bigr)\]


Breaking Down This Formula

  • $h_i^{(l)}$ — What atom i knows at layer l
  • $N(i)$ — Atom i’s neighbors (bonded atoms)
  • $\mathrm{AGG}()$ — Combine information from neighbors (usually average)
  • $W^{(l)}$ — Learnable transformation (the “smart” part)
  • $\sigma$ — Activation function (adds non-linearity)

Step 4.2: Set up the Model Architecture

Design Choice         Architecture Intuition
3 GCN layers          Why 3 layers? Most chemical effects happen within 3 bonds; captures 3-hop neighborhoods (sufficient for most molecular patterns)
64 hidden dimensions  Why 64 dimensions? Enough to capture complexity, not too much to overfit; balances expressiveness vs. overfitting
Global mean pooling   Why mean pooling? Average all atom features to get one molecule feature; aggregates variable-sized molecules to a fixed representation
Single output         The final value is log solubility, a continuous scalar target

Step 4.3: Class Definition and Initialization

class MolecularGNN(nn.Module):
    """
    Graph Neural Network for molecular property prediction
    
    Architecture Flow:
    Atoms (5 features) → GCN layers → Molecular embedding (64) → Solubility (1)
    """
    
    def __init__(self, num_features=5, hidden_dim=64, num_layers=3):
        super(MolecularGNN, self).__init__()
        
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Build GCN layers
        self.convs = nn.ModuleList()

Step 4.3.1: Layer Construction

        # First layer: 5 → 64
        self.convs.append(GCNConv(num_features, hidden_dim))
        
        # Hidden layers: 64 → 64
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        
        # Output layer: 64 → 1
        self.predictor = nn.Linear(hidden_dim, 1)


Step 4.3.2: Forward Pass Implementation

Principle: The forward pass implements message passing followed by pooling:

    def forward(self, x, edge_index, batch):
        """
        Forward propagation through GNN
        
        Args:
            x: Node features [num_atoms_in_batch, 5]
            edge_index: Edge connectivity [2, num_edges_in_batch]
            batch: Maps atoms to molecules [num_atoms_in_batch]
        """
        # Message passing layers
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)  # Non-linearity
        
        # Aggregate atoms → molecules
        x = global_mean_pool(x, batch)
        
        # Predict property
        return self.predictor(x)


What Actually Happens in Forward Pass?

  1. Input: Each atom starts with 5 features
  2. Layer 1: Atoms exchange info with neighbors, transform to 64 features
  3. Layer 2: Exchange again, refine understanding
  4. Layer 3: Final exchange, atoms now “know” their 3-hop neighborhood
  5. Pooling: Average all atoms to get one molecule representation
  6. Prediction: Transform 64 features to 1 solubility value

Key Point: The network learns what information to exchange and how to transform it!

Step 4.4: Model Analysis

Let’s analyze the model architecture:

# Create model instance
model = MolecularGNN(num_features=5, hidden_dim=64, num_layers=3)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model architecture: MolecularGNN")
print(f"  Input features: 5 (per atom)")
print(f"  Hidden dimension: 64")
print(f"  Number of GCN layers: 3")
print(f"  Total parameters: {total_params:,}")

Result:

Model architecture: MolecularGNN
  Input features: 5 (per atom)
  Hidden dimension: 64
  Number of GCN layers: 3
  Total parameters: 8,769

Step 4.5: Detailed Parameter Breakdown

Listing each named parameter shows where the weights live: most sit in the 64×64 hidden GCN layers (4,096 weights each).

print("\nLayer-by-layer breakdown:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")

Result:

Layer-by-layer breakdown:
  convs.0.bias: torch.Size([64])
  convs.0.lin.weight: torch.Size([64, 5])
  convs.1.bias: torch.Size([64])
  convs.1.lin.weight: torch.Size([64, 64])
  convs.2.bias: torch.Size([64])
  convs.2.lin.weight: torch.Size([64, 64])
  predictor.weight: torch.Size([1, 64])
  predictor.bias: torch.Size([1])


Step 5: Preparing Training Data

Step 5.1: Dataset Conversion

Convert all molecules to graphs, handling potential failures:

num_molecules = 1000  # Use subset for faster training
graphs = []
failed_molecules = []

print(f"Converting {num_molecules} molecules to graphs...")

for i in range(num_molecules):
    smiles = smiles_list[i]
    solubility = solubility_values[i]
    
    graph = molecule_to_graph(smiles, solubility)
    if graph is not None:
        graphs.append(graph)
    else:
        failed_molecules.append((i, smiles))


Conversion Results:

print(f"Successfully converted: {len(graphs)} molecules")
print(f"Failed conversions: {len(failed_molecules)} molecules")

Result:

Converting 1000 molecules to graphs...
Successfully converted: 1000 molecules
Failed conversions: 0 molecules

100% success rate indicates high-quality SMILES strings in the ESOL dataset. RDKit successfully parsed all molecules.

Step 5.2: Train-Test Split

Always evaluate on unseen data to assess generalization. We use 80/20 split:

# 80/20 split (standard in ML)
train_size = int(0.8 * len(graphs))
train_graphs = graphs[:train_size]
test_graphs = graphs[train_size:]

print(f"\nDataset split:")
print(f"  Training: {len(train_graphs)} molecules")
print(f"  Testing: {len(test_graphs)} molecules")

Result:

Dataset split:
  Training: 800 molecules
  Testing: 200 molecules


Step 5.3: Creating DataLoaders

DataLoader automatically batches variable-sized graphs:

# Batch size 32 is typical for molecular property prediction
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

print(f"Data loaders created:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Test batches: {len(test_loader)}")
print(f"  Batch size: 32")

Result:

Data loaders created:
  Training batches: 25
  Test batches: 7
  Batch size: 32


Step 5.4: Batch Structure Analysis

The challenge: molecules have different sizes (water has 3 atoms, aspirin has 21), so how do we batch them? The solution: combine them into one big graph in which all atoms live side by side, a batch tensor tracks which molecule each atom belongs to, and no edges are added between molecules.

# Inspect batch structure
for batch in train_loader:
    print(f"\nExample batch:")
    print(f"  Total atoms in batch: {batch.x.shape[0]}")
    print(f"  Total molecules in batch: {batch.num_graphs}")
    print(f"  Batch tensor shape: {batch.batch.shape}")
    print(f"  Edge index shape: {batch.edge_index.shape}")
    break

Result:

Example batch:
  Total atoms in batch: 864
  Total molecules in batch: 32
  Batch tensor shape: torch.Size([864])
  Edge index shape: torch.Size([2, 1744])


Batching Mechanism Explained:

  • 32 molecules contain 864 atoms total (average ≈ 27 atoms/molecule)

  • Batch tensor:

    [0, 0, 0, ..., 1, 1, 1, ..., 31, 31, 31]
    

    Maps each atom to its parent molecule (0–31)

  • Edge index combines all molecular graphs into one large disconnected graph

  • Enables efficient parallel processing on GPU

After combination, PyTorch Geometric treats a batch of graphs as one big disconnected graph!

  • Molecule 1: atoms 0–20
  • Molecule 2: atoms 21–35
  • No edges between molecules
  • Process all at once = FAST!
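
Here is a tiny sketch (toy one-dimensional features, not real molecules) of how the batch vector routes atoms to their parent molecules during pooling:

import torch
from torch_geometric.nn import global_mean_pool

# 5 atoms in total: the first 3 belong to molecule 0, the last 2 to molecule 1
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
batch = torch.tensor([0, 0, 0, 1, 1])

print(global_mean_pool(x, batch))   # tensor([[ 2.], [15.]]) -- one row per molecule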

Step 6: Training the Model

Step 6.1: Training Components Setup

Step 6.1.1: Adam Optimizer Combines momentum with adaptive learning rates

\[\begin{aligned} m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\,g_t \\ v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\,g_t^2 \\ \theta_t &= \theta_{t-1} - \alpha\,\frac{m_t}{\sqrt{v_t} + \epsilon} \end{aligned}\]


  • $m_t$ (Momentum):
    This term computes an exponential moving average of the gradients. It combines the current gradient $g_t$ with the historical average $m_{t-1}$:

    \[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\]

    Think of it as “smoothing out” the direction of updates. Instead of reacting to each new gradient immediately, $m_t$ helps the model maintain a stable direction over time.
    Imagine pushing a ball downhill: the more consistent the slope, the faster it builds momentum. Similarly, $m_t$ gives you an accumulated sense of which direction consistently reduces the loss.

  • $v_t$ (Adaptive scaling):
    This term accumulates the squared gradients to estimate how large and volatile each parameter’s updates are:

    \[v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\]

    Unlike $m_t$, $v_t$ only tracks the magnitude of the gradients (not their direction). It tells you whether the updates for a certain parameter dimension have been large or small recently.
    If a direction has large or noisy gradients, $v_t$ becomes large, and the update will be smaller in that direction. Think of it as driving over bumpy terrain: the bumpier it is, the more cautiously you move forward.

  • $\theta_t$ (Final update):
    This is the final step that uses both $m_t$ and $v_t$ to update the model parameters:

    \[\theta_t = \theta_{t-1} - \alpha \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}\]

    The numerator ($m_t$) tells you which direction to move, and the denominator ($\sqrt{v_t}$) controls how big the step should be in that direction.
    If $v_t$ is large (unstable or steep), the denominator increases, shrinking the step size. If $v_t$ is small (smooth), the update is larger.
    The learning rate $\alpha$ scales the entire step size globally (e.g., 0.001), while $\epsilon$ is a small number added for numerical stability.

    This adaptive combination makes Adam very effective in complex landscapes — it balances speed and caution, moving quickly in flat regions and slowly where gradients are volatile.
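
To make the update rule concrete, here is a minimal NumPy sketch of the simplified Adam step shown above (single scalar parameter, no bias correction; all names and values are illustrative, not part of the GNN pipeline):

import numpy as np

# Illustrative single-parameter Adam step (simplified, no bias correction)
beta1, beta2, alpha, eps = 0.9, 0.999, 0.001, 1e-8
m, v, theta = 0.0, 0.0, 5.0          # state: momentum, scale, parameter

def adam_step(g, m, v, theta):
    m = beta1 * m + (1 - beta1) * g                  # momentum term m_t
    v = beta2 * v + (1 - beta2) * g**2               # adaptive scaling term v_t
    theta = theta - alpha * m / (np.sqrt(v) + eps)   # parameter update
    return m, v, theta

# One step with gradient g = 2*theta, as if minimizing theta^2
m, v, theta = adam_step(2 * theta, m, v, theta)
print(theta)  # slightly below 5.0: the parameter moves toward the minimum at 0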

Step 6.1.2: MSE Loss. Mean squared error (MSE) is the standard loss for regression tasks:

\[L = \frac{1}{n} \sum_{i=1}^{n}\bigl(y_{\text{pred},i} - y_{\text{true},i}\bigr)^2\]


Why MSE for Regression?

  • Measures average squared error
  • Penalizes big mistakes more than small ones
  • Always positive (squaring removes sign)
  • Has nice mathematical properties for optimization
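
As a quick sanity check, nn.MSELoss matches the formula above; a minimal sketch with made-up numbers:

import torch
import torch.nn as nn

pred = torch.tensor([1.0, 2.0, 4.0])
true = torch.tensor([1.0, 3.0, 2.0])

# Manual MSE: mean of squared differences = (0 + 1 + 4) / 3
manual = ((pred - true) ** 2).mean()
print(manual.item())                    # 1.666...
print(nn.MSELoss()(pred, true).item())  # same value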

Step 6.2: Training Setup

Training Setup Components:
  • Model: MolecularGNN with 8,769 parameters
  • Optimizer: Adam (lr=0.001) with adaptive learning rates
  • Loss: MSE loss for regression

# Initialize components
model = MolecularGNN(num_features=5, hidden_dim=64, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Device handling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print("Training setup:")
print(f"  Optimizer: Adam (lr=0.001)")
print(f"  Loss function: MSE")
print(f"  Device: {device}")

Result:

Training setup:
  Optimizer: Adam (lr=0.001)
  Loss function: MSE
  Device: cpu

1. How does the model actually learn?

  • Forward pass: Input → Model → Prediction
  • Loss calculation: How wrong was the prediction?
  • Backward pass: Compute gradients (derivatives) using chain rule
  • Update: Adjust weights to reduce loss


2. What is a gradient? The gradient tells us “which way to adjust each parameter to reduce error.” Think of it like hiking—the gradient points uphill, so we go the opposite way to reach the valley (minimum loss).

3. Why learning rate = 0.001?

  • Too large (0.1): Might overshoot the minimum
  • Too small (0.00001): Training takes forever
  • 0.001: Good default for Adam optimizer
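
A tiny illustration of the learning-rate trade-off, using plain gradient descent on f(x) = x² (the numbers are illustrative and separate from the GNN pipeline; what counts as "too large" depends on the loss surface):

# Gradient descent on f(x) = x^2, whose gradient is f'(x) = 2x
def descend(lr, steps=50, x=5.0):
    for _ in range(steps):
        x = x - lr * 2 * x
    return x

print(descend(0.001))  # ~4.5: moving the right way, but slowly
print(descend(0.1))    # ~0.00007: converges on this easy toy problem
print(descend(1.1))    # huge value: each step overshoots and diverges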


Step 6.3: Training Function Implementation

Each training epoch processes every batch once, following the five steps introduced earlier:

Training Loop Steps:
  1. Zero gradients (clear old values)
  2. Forward pass (get predictions)
  3. Compute the loss (how wrong were we?)
  4. Backward pass (compute gradients)
  5. Update weights (improve the model)

def train_epoch(model, loader, optimizer, criterion, device):
    """
    One epoch of training
    
    Key steps:
    1. Zero gradients (PyTorch accumulates by default)
    2. Forward pass through GNN
    3. Compute loss
    4. Backpropagate gradients
    5. Update weights
    """
    model.train()
    total_loss = 0
    
    for batch in loader:
        batch = batch.to(device)
        
        # Reset gradients
        optimizer.zero_grad()
        
        # Forward pass
        out = model(batch.x, batch.edge_index, batch.batch)
        
        # Compute loss
        loss = criterion(out.squeeze(), batch.y)
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimizer.step()
        
        total_loss += loss.item() * batch.num_graphs
    
    return total_loss / len(loader.dataset)


Now that we’ve seen how optimizer.zero_grad(), loss.backward(), and optimizer.step() appear in every training loop, it’s important to understand what each of them actually does under the hood. These functions may look simple, but they each wrap a number of essential operations that make training work correctly.

1. optimizer.zero_grad()

  • PyTorch accumulates gradients by default
  • Without this, gradients would add up across batches
  • Like clearing a calculator before a new calculation

2. loss.backward()

  • Computes gradient of loss w.r.t. each parameter
  • Uses automatic differentiation (chain rule)
  • Fills the .grad attribute of each parameter

3. optimizer.step()

  • Updates parameters using computed gradients
  • Applies the Adam update rule
  • Parameters move in the direction that reduces loss
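
A minimal sketch (toy tensors, not the GNN) showing the effect of these three calls on a parameter's .grad attribute:

import torch

w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.Adam([w], lr=0.001)

loss = (w * 3.0 - 6.0).pow(2).mean()  # toy loss with minimum at w = 2
opt.zero_grad()                       # clears any previously accumulated .grad
loss.backward()                       # autograd fills w.grad via the chain rule
print(w.grad)                         # tensor([-18.]) = 2*(3w - 6)*3 at w = 1
opt.step()                            # Adam nudges w toward 2
print(w)                              # ~1.001 after one step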


Step 6.4: Evaluation Function

Evaluation without gradient computation saves memory:

Training vs Evaluation Mode:
  • Training mode: gradients are computed, parameters are updated, dropout is active
  • Evaluation mode: no gradients (saves memory), no parameter updates, dropout is inactive

def evaluate(model, loader, criterion, device):
    """Evaluation without gradient computation"""
    model.eval()
    total_loss = 0
    
    with torch.no_grad():  # Save memory
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(out.squeeze(), batch.y)
            total_loss += loss.item() * batch.num_graphs
    
    return total_loss / len(loader.dataset)


Step 6.5: Training Execution

Train for 50 epochs with periodic logging:

Training Process Overview:
  • Epochs: 50 full passes over the training set
  • Monitoring: track both training and test loss
  • Logging: print every 10 epochs to avoid clutter

num_epochs = 50
train_losses = []
test_losses = []

print("Starting training...")
print("-" * 60)

for epoch in range(num_epochs):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)
    
    # Evaluate
    test_loss = evaluate(model, test_loader, criterion, device)
    test_losses.append(test_loss)
    
    # Log progress
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

print("-" * 60)
print("Training completed!")


Result:

Starting training...
------------------------------------------------------------
Epoch  10 | Train Loss: 3.8593 | Test Loss: 4.1164
Epoch  20 | Train Loss: 3.7131 | Test Loss: 4.0329
Epoch  30 | Train Loss: 3.6504 | Test Loss: 4.0042
Epoch  40 | Train Loss: 3.5853 | Test Loss: 3.8130
Epoch  50 | Train Loss: 3.4851 | Test Loss: 3.7270
------------------------------------------------------------
Training completed!

Training Analysis:

  • Initial loss ≈ 10 (not shown) → Final loss ≈ 3.5
  • Test loss closely follows training loss (good generalization)
  • Loss of 3.7 corresponds to RMSE = $\sqrt{3.7}$ ≈ 1.92 log S
  • Small train–test gap (3.49 vs 3.73) indicates appropriate model capacity

Step 6.6: Visualizing Training Progress

Training Curves Analysis:
  • What to look for: both curves decreasing, test loss following training loss, no divergence
  • Warning signs: test loss increasing, a large gap developing, erratic behavior

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss', linewidth=2)
plt.plot(test_losses, label='Test Loss', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('MSE Loss', fontsize=12)
plt.title('Training Progress', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Training Progress

Curve Interpretation:

  • Epochs 1-10: Rapid loss decrease (learning basic patterns)
  • Epochs 10-30: Gradual improvement (fine-tuning)
  • Epochs 30-50: Plateau (approaching model capacity)
  • No overfitting: Test loss doesn’t increase

Reading Training Curves:

  • Good sign: Test loss follows training loss
  • Bad sign: Test loss increases while training decreases (overfitting)
  • Our case: Slight gap but both decrease = healthy learning!

Step 7: Model Evaluation

Step 7.1: Understanding the Evaluation Metrics

Metric definitions and interpretation:

  • RMSE $= \sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_{\text{pred},i} - y_{\text{true},i})^2}$: average error magnitude, i.e. the typical prediction error in log S units
  • MAE $= \frac{1}{n}\sum_{i=1}^{n}\lvert y_{\text{pred},i} - y_{\text{true},i}\rvert$: typical absolute error, less sensitive to outliers than RMSE
  • R² $= 1 - \frac{SS_{\text{res}}}{SS_{\text{tot}}}$: fraction of variance explained (0 to 1), i.e. how much of the variation in solubility the model captures


Understanding R² in Detail

The R² formula:

\[R^2 = 1 - \frac{SS_{\text{res}}}{SS_{\text{tot}}}\]
  • $SS_{\text{tot}}$ = Total variance = $\sum (y_i - \bar{y})^2$
  • $SS_{\text{res}}$ = Residual variance = $\sum (y_i - \hat{y}_i)^2$
  • $\bar{y}$ = Mean of true values
  • $\hat{y}_i$ = Our predictions

Interpretation:

  • R² = 1.0: Perfect predictions
  • R² = 0.5: Model explains 50% of variance
  • R² = 0.0: No better than predicting the mean
  • R² < 0.0: Worse than predicting the mean!
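
The same quantity can be computed by hand and checked against sklearn's r2_score; a minimal sketch with made-up arrays:

import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.7])

ss_res = np.sum((y_true - y_pred) ** 2)           # residual sum of squares
ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # total sum of squares
print(1 - ss_res / ss_tot)        # manual R^2
print(r2_score(y_true, y_pred))   # same value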

Step 7.2: Prediction Extraction Function

Prediction Collection Process:
  1. Set eval mode (disables dropout)
  2. Loop over batches and collect predictions
  3. Return NumPy arrays for sklearn metrics

def get_predictions(model, loader, device):
    """Extract all predictions for evaluation"""
    model.eval()
    predictions = []
    true_values = []
    
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            predictions.extend(out.squeeze().cpu().numpy())
            true_values.extend(batch.y.cpu().numpy())
    
    return np.array(predictions), np.array(true_values)


Step 7.3: Metric Calculation

Evaluation Metrics Calculation:
  • RMSE: 1.931 log S (typical error size)
  • MAE: 1.602 log S (average absolute error)
  • R²: 0.219 (about 22% of the variance explained)

# Get test set predictions
test_preds, test_true = get_predictions(model, test_loader, device)

# Calculate metrics
rmse = np.sqrt(mean_squared_error(test_true, test_preds))
mae = np.mean(np.abs(test_true - test_preds))
r2 = r2_score(test_true, test_preds)

print("Model Performance on Test Set:")
print(f"  RMSE: {rmse:.3f} log S")
print(f"  MAE:  {mae:.3f} log S")
print(f"  R²:   {r2:.3f}")
print(f"\nInterpretation:")
print(f"  - On average, predictions are off by {mae:.2f} log units")
print(f"  - The model explains {r2*100:.1f}% of the variance in solubility")

Result:

Model Performance on Test Set:
  RMSE: 1.931 log S
  MAE:  1.602 log S
  R²:   0.219

Interpretation:
  - On average, predictions are off by 1.60 log units
  - The model explains 21.9% of the variance in solubility

Performance Reality Check:

  • MAE = 1.6 log units → $10^{1.6} \approx 40 \times$ error in concentration
  • $R^2 = 0.22$ means the model explains only 22% of variance
  • Why seemingly poor performance?

    • Only 5 simple atomic features
    • No bond features or 3D information
    • Solubility spans 13 orders of magnitude!
  • State-of-the-art models achieve $R^2 \approx 0.9$ with richer features

Is $R^2 = 0.22$ Actually Bad?

  • For production: Yes, too low for drug development
  • For learning: No! Shows our simple model captures real patterns
  • Context: Random guessing would give $R^2 \approx 0$
  • Improvement potential: Adding more features could reach $R^2 > 0.8$

Step 7.4: Prediction Visualization

Create scatter plot to visualize prediction quality:

Scatter Plot Components:
  • Points: test-set predictions for 200 molecules
  • Red line: perfect predictions (y = x)
  • Gray band: ±1 log S error, roughly a 10× difference in concentration

plt.figure(figsize=(8, 8))

# Scatter plot
plt.scatter(test_true, test_preds, alpha=0.6, edgecolors='black', linewidth=0.5)

# Perfect prediction line
min_val = min(test_true.min(), test_preds.min())
max_val = max(test_true.max(), test_preds.max())
plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect predictions')

# Error bands
plt.fill_between([min_val, max_val], [min_val-1, max_val-1], [min_val+1, max_val+1], 
                 alpha=0.2, color='gray', label='±1 log S error')

plt.xlabel('True Solubility (log S)', fontsize=12)
plt.ylabel('Predicted Solubility (log S)', fontsize=12)
plt.title(f'GNN Predictions vs True Values (R² = {r2:.3f})', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

GNN Predictions vs True Values

Scatter Plot Interpretation:

  • General trend captured: Points follow diagonal direction
  • High variance: Significant scatter around ideal line
  • Regression to mean: Extreme values pulled toward center
  • ±1 log band: Most predictions within acceptable error range

How to Read This Plot:

  • Perfect model: All points on red dashed line
  • Good model: Points clustered near the line
  • Our model: General trend but wide spread
  • Gray band: ±1 log S is approximately a 10× error in real concentration

Step 7.5: Error Analysis

Error Distribution Visualization:

Error Analysis Plots:
  • Left plot: error distribution; it should be centered at 0, and narrower is better
  • Right plot: error vs. true value; look for patterns, where random scatter is a good sign

# Calculate errors
errors = test_preds - test_true

plt.figure(figsize=(12, 5))

# Error distribution
plt.subplot(1, 2, 1)
plt.hist(errors, bins=30, edgecolor='black', alpha=0.7)
plt.axvline(x=0, color='red', linestyle='--', linewidth=2)
plt.xlabel('Prediction Error (log S)', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.title('Distribution of Prediction Errors', fontsize=14)
plt.grid(True, alpha=0.3)

Error vs True Value:

# Error vs true value
plt.subplot(1, 2, 2)
plt.scatter(test_true, errors, alpha=0.6)
plt.axhline(y=0, color='red', linestyle='--', linewidth=2)
plt.xlabel('True Solubility (log S)', fontsize=12)
plt.ylabel('Prediction Error (log S)', fontsize=12)
plt.title('Prediction Error vs True Value', fontsize=14)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Error Distribution and Analysis

Error Statistics:

Error Statistics Summary:
  • Central tendency: mean error -0.368 (slight bias); median absolute error ≈ 1.45 (typical error)
  • Spread: standard deviation 1.895; 95% of errors within ±3.5 log S

print(f"Error Statistics:")
print(f"  Mean error: {np.mean(errors):.3f} log S")
print(f"  Std deviation: {np.std(errors):.3f} log S")
print(f"  Median absolute error: {np.median(np.abs(errors)):.3f} log S")
print(f"  95% of errors within: ±{np.percentile(np.abs(errors), 95):.3f} log S")

Result:

Error Statistics:
  Mean error: -0.368 log S
  Std deviation: 1.895 log S
  Median absolute error: 1.448 log S
  95% of errors within: ±3.480 log S

Error Pattern Analysis:

  • Negative bias (–0.368): The model slightly underpredicts solubility
  • Normal distribution: No systematic failures detected
  • Heteroscedasticity: Larger errors occur at extreme solubilities
  • 95% confidence: Most errors are within ±3.5 log units

For very early drug discovery screening, even this coarse level of accuracy can help rank candidates, although (as noted above) it falls well short of production requirements.

Step 8: Making Predictions on New Molecules

Step 8.1: Prediction Pipeline Function

Create user-friendly prediction function:

Prediction Pipeline Steps:
  1. Convert the SMILES string to a graph
  2. Prepare a single-molecule batch
  3. Run model inference
  4. Return the predicted value

def predict_solubility(smiles, model, device):
    """
    Complete prediction pipeline for new molecules
    
    Steps:
    1. SMILES → Graph conversion
    2. Feature extraction
    3. Model inference
    4. Return prediction
    """
    # Convert to graph
    graph = molecule_to_graph(smiles)
    if graph is None:
        return None, "Invalid SMILES"
    
    # Prepare for model
    graph = graph.to(device)
    batch = torch.zeros(graph.x.size(0), dtype=torch.long).to(device)
    
    # Predict
    model.eval()
    with torch.no_grad():
        prediction = model(graph.x, graph.edge_index, batch)
    
    return prediction.item(), "Success"


Step 8.2: Testing on Known Molecules

Test predictions on common molecules:

Test Molecule Selection:
  • Categories: water and alcohols, common solvents, hydrocarbons, aromatics
  • Purpose: test diverse structures, check chemical intuition, validate model behavior

test_molecules = [
    ("O", "Water"),
    ("CCO", "Ethanol"),
    ("CC(=O)C", "Acetone"),
    ("c1ccccc1", "Benzene"),
    ("CC(=O)O", "Acetic acid"),
    ("CCCCCl", "1-Chlorobutane"),
    ("CC(C)C", "Isobutane"),
    ("C1CCCCC1", "Cyclohexane"),
    ("c1ccc(O)cc1", "Phenol"),
    ("CC(=O)Oc1ccccc1C(=O)O", "Aspirin")
]


Step 8.2.1: Prediction Loop

Prediction Results Format:
  • Output: molecule name, SMILES string, predicted log S, and status
  • Collection: store results in a list for visualization and pattern analysis

print("Predictions for common molecules:")
print("-" * 60)
print(f"{'Molecule':<20} {'SMILES':<25} {'Predicted log S':<15} {'Status'}")
print("-" * 60)

predictions_list = []
for smiles, name in test_molecules:
    pred, status = predict_solubility(smiles, model, device)
    if status == "Success":
        predictions_list.append((name, pred))
        print(f"{name:<20} {smiles:<25} {pred:>10.3f}      {status}")

Result:

Predictions for common molecules:
------------------------------------------------------------
Molecule             SMILES                    Predicted log S Status
------------------------------------------------------------
Water                O                              -1.126      Success
Ethanol              CCO                            -2.384      Success
Acetone              CC(=O)C                        -2.531      Success
Benzene              c1ccccc1                       -4.061      Success
Acetic acid          CC(=O)O                        -2.425      Success
1-Chlorobutane       CCCCCl                         -2.853      Success
Isobutane            CC(C)C                         -2.711      Success
Cyclohexane          C1CCCCC1                       -3.037      Success
Phenol               c1ccc(O)cc1                    -3.903      Success
Aspirin              CC(=O)Oc1ccccc1C(=O)O          -3.580      Success


Chemical Interpretation:

  • Water (-1.126): the most soluble prediction; water itself is fully miscible, so this case mainly serves as a sanity check
  • Small polar molecules (-2.4 to -2.5): Ethanol, acetone, acetic acid show similar moderate solubility
  • Hydrocarbons (-2.7 to -3.0): Isobutane and cyclohexane less soluble
  • Aromatics (< -3.5): Benzene and phenol least soluble, reflecting hydrophobic π-systems

Step 8.2.2: Visualizing Predictions

Create bar chart with solubility-based coloring:

Bar Chart Design:
  • Color scheme: green for log S > 0, yellow for -3 to 0, red for below -3
  • Features: values printed on the bars, rotated labels, grid for clarity
  • Purpose: visual comparison, grouping by solubility, quick insights

names, preds = zip(*predictions_list)

plt.figure(figsize=(12, 6))
bars = plt.bar(range(len(names)), preds, color='skyblue', edgecolor='navy', linewidth=1.5)

# Color by solubility level
for i, (bar, pred) in enumerate(zip(bars, preds)):
    if pred > 0:
        bar.set_color('lightgreen')
    elif pred < -3:
        bar.set_color('lightcoral')
    else:
        bar.set_color('lightyellow')

Chart Completion:

plt.xlabel('Molecule', fontsize=12)
plt.ylabel('Predicted Solubility (log S)', fontsize=12)
plt.title('GNN Solubility Predictions for Common Molecules', fontsize=14)
plt.xticks(range(len(names)), names, rotation=45, ha='right')
plt.grid(True, alpha=0.3, axis='y')
plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)

# Add values
for i, (bar, pred) in enumerate(zip(bars, preds)):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, 
            f'{pred:.2f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

GNN Solubility Predictions for Common Molecules

Bar Chart Interpretation:

  • Green bars: Would indicate high solubility (> 0 log S) - none in our test set
  • Yellow bars: Moderate solubility (-3 to 0 log S) - most small organics
  • Red bars: Low solubility (< -3 log S) - aromatics like benzene and phenol
  • Chemical sense: Rankings match chemical intuition despite simple model

Step 9: Analyzing Learned Patterns

Step 9.1: Structure-Activity Relationships Testing

Systematic testing of molecular features:

Structure-Activity Test Design:
  • Test 1, functional groups: start from hexane, add -OH, -COOH, and -NH₂, and check the effect on solubility
  • Test 2, chain length: C₂ → C₄ → C₆ → C₈, tracking the size effect; expect decreasing solubility
  • Test 3, aromaticity: benzene vs. cyclohexane, comparing the π-system effect; expect a large difference

functional_group_tests = [
    # Base molecule
    ("CCCCCC", "Hexane (hydrophobic)"),
    # Add polar groups
    ("CCCCCCO", "1-Hexanol (add -OH)"),
    ("CCCCCC(=O)O", "Hexanoic acid (add -COOH)"),
    ("CCCCCCN", "1-Hexylamine (add -NH2)"),
    # Size series
    ("CC", "Ethane"),
    ("CCCC", "Butane"),
    ("CCCCCCCC", "Octane"),
    # Aromaticity comparison
    ("c1ccccc1", "Benzene (aromatic)"),
    ("C1CCCCC1", "Cyclohexane (aliphatic)")
]


Step 9.2: Testing Loop

Analysis Purpose:
  • What we are testing: the model's chemical understanding, feature importance, and learned patterns
  • Expected insights: which features work well, where the model is limited, and what to improve

print("Analyzing functional group effects on solubility:")
print("-" * 70)
print(f"{'Description':<35} {'SMILES':<20} {'Predicted log S':<15}")
print("-" * 70)

for smiles, desc in functional_group_tests:
    pred, _ = predict_solubility(smiles, model, device)
    print(f"{desc:<35} {smiles:<20} {pred:>10.3f}")

Result:

Analyzing functional group effects on solubility:
----------------------------------------------------------------------
Description                         SMILES               Predicted log S
----------------------------------------------------------------------
Hexane (hydrophobic)                CCCCCC                     -2.808
1-Hexanol (add -OH)                 CCCCCCO                    -2.759
Hexanoic acid (add -COOH)           CCCCCC(=O)O                -2.795
1-Hexylamine (add -NH2)             CCCCCCN                    -2.775
Ethane                              CC                         -2.458
Butane                              CCCC                       -2.710
Octane                              CCCCCCCC                   -2.861
Benzene (aromatic)                  c1ccccc1                   -4.061
Cyclohexane (aliphatic)             C1CCCCC1                   -3.037

Key Findings:

  1. Weak Functional Group Effects:
    • Hexane → Hexanol: Only 0.049 log unit improvement
    • Expected: –OH should increase solubility by ~1–2 log units
    • Limitation: Our 5 features don’t capture hydrogen bonding strength
  2. Clear Size Trend:
    • C₂ (–2.458) → C₄ (–2.710) → C₆ (–2.808) → C₈ (–2.861)
    • $\Delta \log S \approx -0.07$ per CH₂ group on average (from C₂ to C₈)
    • Success: Model learned hydrophobic effect of alkyl chains
  3. Strong Aromaticity Effect:
    • Benzene vs. Cyclohexane: 1.024 log unit difference
    • Success: Model recognizes π-system hydrophobicity
    • Aromatic feature in our encoding is highly informative

Summary and Conclusions

What We Built

Complete GNN Pipeline:

  1. Data Processing: SMILES → graph conversion with RDKit
  2. Feature Engineering: 5 atomic features + bidirectional edges
  3. Model Architecture: 3-layer GCN with 8,769 parameters
  4. Training: 800 molecules, 50 epochs, Adam optimizer
  5. Deployment: Prediction function for new molecules


Performance Metrics:

  • RMSE: 1.93 log S
  • MAE: 1.60 log S
  • $R^2$: 0.22
  • 95% predictions within ± 3.5 log units

Limitations and Future Improvements

Current limitation → proposed solution (expected impact on R²):
  • Weak functional group discrimination → add H-bond donor/acceptor counts (+0.1-0.2 R²); see the featurization sketch below
  • No bond information → include bond type and conjugation (+0.15 R²)
  • Simple mean pooling → attention-based pooling (+0.1 R²)
  • Basic GCN layers → GAT or MPNN architectures (+0.2-0.3 R²)
  • No 3D information → add 3D coordinates (+0.1-0.15 R²)
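
As one hedged sketch of the first item, a richer atom featurizer could add hydrogen counts and ring membership on top of the five simple features used in this tutorial (the exact original feature set may differ; the RDKit calls below are standard atom methods, and the hydrogen count is only a crude proxy for hydrogen-bond donation):

from rdkit import Chem

def richer_atom_features(atom):
    """Example extended feature vector; the model above used only 5 features."""
    return [
        atom.GetAtomicNum(),        # element identity
        atom.GetDegree(),           # number of bonded neighbors
        atom.GetFormalCharge(),     # formal charge
        int(atom.GetIsAromatic()),  # aromaticity flag
        atom.GetTotalNumHs(),       # attached hydrogens (rough H-bond donor proxy)
        int(atom.IsInRing()),       # ring membership
    ]

mol = Chem.MolFromSmiles("CCO")  # ethanol
print([richer_atom_features(a) for a in mol.GetAtoms()])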

Key Takeaways

Scientific Insights:

  • GNNs naturally handle variable-sized molecules without fixed fingerprints
  • Message passing captures local chemical environments effectively
  • Even simple features yield chemically sensible predictions
  • Model learned size effects and aromaticity impact correctly

Practical Considerations:

  • $R^2 = 0.22$ is insufficient for production use but excellent for teaching
  • With proposed improvements, could reach $R^2 > 0.8$
  • Same architecture applies to any molecular property
  • Computational efficiency: ~1 ms per molecule prediction


Applications Beyond Solubility:

  • Drug–target binding affinity (change target to pIC50)
  • Toxicity prediction (classification instead of regression)
  • Material properties (glass transition, melting point)
  • Reaction outcome prediction (with reaction graphs)

What We Learned About Deep Learning

Key Concepts Demystified:

  • Backpropagation: Automatic calculation of gradients using chain rule
  • Optimization: Iteratively adjusting parameters to minimize loss
  • Message Passing: Atoms sharing information through bonds
  • Batching: Processing multiple graphs simultaneously for efficiency

What We Didn’t Cover (But Should Know):

  • Parameter Initialization: How weights start (Xavier, He initialization)
  • Learning Rate Scheduling: Adjusting learning rate during training
  • Regularization: Preventing overfitting (dropout, weight decay)
  • Advanced GNN Issues: Over-smoothing, gradient vanishing (covered in next section)

The beauty of GNNs for chemistry lies in their natural alignment with molecular structure. While our simple model achieves modest performance, it demonstrates the complete pipeline from molecules to predictions. This foundation, enhanced with richer features and advanced architectures, powers modern drug discovery and materials design platforms.


3.3.4 Challenges and Interpretability in GNNs


What We're Exploring: Fundamental Challenges in Graph Neural Networks

Why Study GNN Challenges?

  • Over-smoothing: Why deeper isn't always better - node features become indistinguishable
  • Interpretability: Understanding what the model learns - crucial for drug discovery
  • Real Impact: These challenges affect whether GNNs can be trusted in production

What you'll learn: The fundamental limitations of GNNs and current solutions to overcome them

  • Over-smoothing: node features converge and all atoms start to look the same; this limits network depth and prevents capturing long-range interactions. Solutions: residual and skip connections, normalization.
  • Interpretability: black-box predictions; we do not know why the model predicts what it does, so there is little trust and no guidance for drug design. Solutions: attention visualization, substructure explanations.

While GNNs have shown remarkable success in molecular property prediction, they face several fundamental challenges that limit their practical deployment. In this section, we’ll explore two critical issues: the over-smoothing phenomenon that limits network depth, and the interpretability challenge that makes it difficult to understand model predictions.

The Power of Depth vs. The Curse of Over-smoothing

In Graph Neural Networks (GNNs), adding more message-passing layers allows nodes (atoms) to gather information from increasingly distant parts of a graph (molecule). At first glance, it seems deeper networks should always perform better—after all, more layers mean more context. But in practice, there’s a major trade-off known as over-smoothing.

Understanding Over-smoothing

  • Message passing: atoms share information with their neighbors, like atoms "talking" through bonds
  • Receptive field: how far information travels; k layers reach a k-hop neighborhood
  • Over-smoothing: all nodes become similar, so different atoms can no longer be distinguished
  • Critical depth: typically around 3-5 layers; beyond this, performance drops

What to Demonstrate

Before we jump into the code, here’s what it’s trying to show:

We want to measure how similar node embeddings become as we increase the number of GCN layers. If all node vectors become nearly identical after several layers, that means the model is losing resolution—different atoms can’t be distinguished anymore. This is called over-smoothing.

Key Functions and Concepts:
  • GCNConv: graph convolution layer that aggregates neighbor features
  • F.relu(): non-linear activation that adds expressiveness
  • F.normalize(): L2 normalization, required for cosine similarity
  • torch.mm(): matrix multiplication used to compute the similarity matrix

Functions and Concepts Used

  • GCNConv (from torch_geometric.nn): This is a standard Graph Convolutional Network (GCN) layer. It performs message passing by aggregating neighbor features and updating node embeddings. It normalizes messages by node degrees to prevent high-degree nodes from dominating.

  • F.relu(): Applies a non-linear ReLU activation function after each GCN layer. This introduces non-linearity to the model, allowing it to learn more complex patterns.

  • F.normalize(..., p=2, dim=1): This normalizes node embeddings to unit length (L2 norm), which is required for cosine similarity calculation.

  • torch.mm(): Matrix multiplication is used here to compute the full cosine similarity matrix between normalized node embeddings.

  • Cosine similarity: Measures how aligned two vectors are (value close to 1 means very similar). By averaging all pairwise cosine similarities, we can track whether the node representations are collapsing into the same vector.

Graph Construction

We use a 6-node ring structure as a simple molecular graph. Each node starts with a unique identity (using identity matrix torch.eye(6) as input features), and all nodes are connected in a cycle:

Graph Construction Process:
  1. Create node features: a 6×6 identity matrix
  2. Define the ring topology: each node connects to 2 neighbors
  3. Make the edges bidirectional: 12 directed edges in total
  4. Result: a PyTorch Geometric Data object ready for the GNN

import torch
from torch_geometric.data import Data

# Each node has a unique 6D feature vector (identity matrix)
x = torch.eye(6)

# Define edges for a 6-node cycle (each edge is bidirectional)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5, 0, 5, 0, 1, 2, 3, 4]
], dtype=torch.long)

# Create PyTorch Geometric graph object
data = Data(x=x, edge_index=edge_index)

Over-smoothing Analysis

Now we apply the same GCN layer multiple times to simulate a deeper GNN. After each layer, we re-compute the node embeddings and compare them using cosine similarity:

Over-smoothing Measurement Process:
  • Apply GCN layers: stack 1 to 10 freshly initialized GCNConv layers
  • Compute similarity: cosine similarity between node embeddings, averaged over all pairs
  • Track convergence: plot similarity against depth and watch it approach 1

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv

def measure_smoothing(num_layers, data):
    """
    Apply num_layers GCNConv layers and measure
    how similar node embeddings become.
    """
    x = data.x
    for _ in range(num_layers):
        conv = GCNConv(x.size(1), x.size(1))
        x = F.relu(conv(x, data.edge_index))

    # Normalize embeddings for cosine similarity
    x_norm = F.normalize(x, p=2, dim=1)
    
    # Cosine similarity matrix
    similarity_matrix = torch.mm(x_norm, x_norm.t())
    
    # Exclude diagonal (self-similarity) when averaging
    n = x.size(0)
    mask = ~torch.eye(n, dtype=torch.bool)
    avg_similarity = similarity_matrix[mask].mean().item()
    
    return avg_similarity

# Run for different GNN depths
depths = [1, 3, 5, 10]
sims = []
for depth in depths:
    sim = measure_smoothing(depth, data)
    sims.append(sim)
    print(f"Depth {depth}: Average similarity = {sim:.3f}")

# Plot the smoothing effect
plt.plot(depths, sims, marker='o')
plt.xlabel("Number of GCN Layers")
plt.ylabel("Average Cosine Similarity")
plt.title("Over-smoothing Effect in GNNs")
plt.grid(True)
plt.show()

Output

Depth 1: Average similarity = 0.406
Depth 3: Average similarity = 0.995
Depth 5: Average similarity = 0.993
Depth 10: Average similarity = 1.000

Over-smoothing in GNNs

Interpretation of Results

  • 1 layer (similarity 0.406): nodes are still distinct; different atoms can be identified
  • 3 layers (0.995): nearly identical; atomic identity is being lost
  • 5 layers (0.993): effectively the same; little useful information remains
  • 10 layers (1.000): complete collapse; the embeddings no longer distinguish anything

As shown above, as the number of message-passing layers increases, node representations converge. Initially distinct feature vectors (left) become nearly indistinguishable after several layers (right), resulting in the loss of structural information. This phenomenon is known as over-smoothing and is a critical limitation of deep GNNs.

Interpretation

As we can see, even at just 3 layers, the node embeddings become nearly identical. By 10 layers, the model has effectively lost all ability to distinguish individual atoms. This is the core issue of over-smoothing—deep GNNs can blur out meaningful structural differences.

Solutions to Over-smoothing

  • Residual connections: skip connections preserve the original features (x = x + GCN(x)); very effective
  • Feature concatenation: combine features from multiple layers (concat(x₁, x₂, ...)); good for shallow networks
  • Batch normalization: normalize features after each GCN layer; moderate help
  • Jumping knowledge: aggregate the outputs of all layers (JK networks); state of the art

To mitigate this problem, modern GNNs use techniques like:

  • Residual connections (skip connections that reintroduce raw input)
  • Feature concatenation from earlier layers
  • Batch normalization or graph normalization
  • Jumping knowledge networks to combine representations from multiple layers

When working with molecular graphs, you should choose the depth of your GNN carefully. It should be deep enough to capture important substructures, but not so deep that you lose atomic-level details.
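
A minimal sketch of the first mitigation, a GCN block with a residual connection (illustrative class name; it assumes the input and hidden dimensions match so the skip addition is valid):

import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GCNConv

class ResidualGCNBlock(nn.Module):
    """x -> x + ReLU(GCN(x)): the skip path keeps a copy of the original features."""
    def __init__(self, dim):
        super().__init__()
        self.conv = GCNConv(dim, dim)

    def forward(self, x, edge_index):
        return x + F.relu(self.conv(x, edge_index))

Stacking several such blocks slows the collapse of node embeddings, because each layer's output retains its input alongside the aggregated neighborhood information.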

Interpretability in Molecular GNNs

Beyond the technical challenge of over-smoothing, GNNs face a critical issue of interpretability. When a model predicts that a molecule might be toxic or have specific properties, chemists need to understand which structural features drive that prediction. This “black box” nature of neural networks is particularly problematic in chemistry, where understanding structure-activity relationships is fundamental to rational drug design.

Why Interpretability Matters in Chemistry

  • Medicinal chemists need to understand structure-activity relationships (which groups increase potency?) to guide drug optimization
  • Regulatory bodies need safety justifications (why is it safe?) so toxicity predictions can be explained during approval
  • Researchers want scientific insight into new mechanisms, for example discovering new pharmacophores
  • Industry needs risk assessment and confidence in predictions (why invest in this molecule?) to allocate resources

Recent advances in GNN interpretability for molecular applications have taken several promising directions:

Attention-Based Methods:

Attention-Based Interpretability:
  • Method: Graph Attention Networks (GATs)
  • How it works: learn an importance weight α_ij for each edge
  • Visualization: highlight important bonds (thicker means more important)
  • Reference: Veličković et al., 2017, ICLR [1]

Graph Attention Networks (GATs) provide built-in interpretability through their attention mechanisms, allowing researchers to visualize which atoms or bonds the model considers most important for a given prediction [1,2]. This approach naturally aligns with chemical intuition about reactive sites and functional groups.
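
In PyTorch Geometric, attention coefficients can be read directly from a GATConv layer, which is one practical starting point for this kind of visualization; a minimal sketch reusing the ring-graph `data` object from the over-smoothing demo above (an untrained layer, so the weights are not yet meaningful):

from torch_geometric.nn import GATConv

gat = GATConv(in_channels=6, out_channels=8, heads=1)
out, (att_edge_index, alpha) = gat(data.x, data.edge_index,
                                   return_attention_weights=True)
# alpha[i] is the attention weight for edge att_edge_index[:, i];
# after training, larger values mark neighbors (bonds) the layer relies on more.
print(alpha.squeeze())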

Substructure-Based Explanations:

Substructure Mask Explanation (SME):
  • Innovation: fragment-based explanations rather than individual atoms
  • Alignment: matches chemical intuition about functional groups
  • Application: toxicophore detection, i.e. finding toxic substructures
  • Reference: Nature Communications, 2023, 14, 2585 [3]

The Substructure Mask Explanation (SME) method represents a significant advance by providing interpretations based on chemically meaningful molecular fragments rather than individual atoms or edges [3]. This approach uses established molecular segmentation methods to ensure explanations align with chemists’ understanding, making it particularly valuable for identifying pharmacophores and toxicophores.

Integration of Chemical Knowledge:

Pharmacophore-Integrated GNNs:
  • Concept: hierarchical modeling of multi-level molecular structure
  • Benefit 1: better performance, since domain knowledge helps
  • Benefit 2: natural interpretability at the pharmacophore level
  • Reference: Journal of Cheminformatics, 2022 [4]

Recent work has shown that incorporating pharmacophore information hierarchically into GNN architectures not only improves prediction performance but also enhances interpretability by explicitly modeling chemically meaningful substructures [4]. This bridges the gap between data-driven learning and domain expertise.

Gradient-Based Attribution:

SHAP for Molecular GNNs:
  • Method: SHapley Additive exPlanations, grounded in game theory
  • Advantage: a rigorous foundation with additive feature attributions
  • Output: feature importance scores per atom or bond
  • Reference: Lundberg & Lee, 2017, NeurIPS [5]

Methods like SHAP (SHapley Additive exPlanations) have been successfully applied to molecular property prediction, providing feature importance scores that help identify which molecular characteristics most influence predictions [5,6]. These approaches are particularly useful for understanding global model behavior across different molecular classes.

Comparative Studies:

GNNs vs Traditional Methods:
  • Performance: GNNs often excel at complex patterns, while descriptor-based models remain a good, well-understood baseline; which wins is task-dependent
  • Interpretability: challenging for GNNs and requires extra work, but built-in for descriptor models with known features; a hybrid approach is often recommended
  • Reference: Jiang et al., 2021, Journal of Cheminformatics [6]

Recent comparative studies have shown that while GNNs excel at learning complex patterns, traditional descriptor-based models often provide better interpretability through established chemical features, suggesting a potential hybrid approach combining both paradigms [6].

The Future: Interpretable-by-Design

The field is moving toward interpretable-by-design architectures rather than post-hoc explanation methods. As noted by researchers, some medicinal chemists value interpretability over raw accuracy if a small sacrifice in performance can significantly enhance understanding of the model's reasoning [3]. This reflects a broader trend in molecular AI toward building systems that augment rather than replace human chemical intuition.

  • Chemical hierarchy: multi-scale architectures (atom → group → molecule)
  • Explicit substructures: pharmacophore encoding, e.g. H-bond donors as nodes
  • Modular predictions: separate property modules, e.g. solubility and toxicity branches

References:

[1] Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., & Bengio, Y. (2017). Graph Attention Networks. International Conference on Learning Representations.

[2] Yuan, H., Yu, H., Gui, S., & Ji, S. (2022). Explainability in graph neural networks: A taxonomic survey. IEEE Transactions on Pattern Analysis and Machine Intelligence.

[3] Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking. (2023). Nature Communications, 14, 2585.

[4] Integrating concept of pharmacophore with graph neural networks for chemical property prediction and interpretation. (2022). Journal of Cheminformatics, 14, 52.

[5] Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. Advances in Neural Information Processing Systems, 30, 4765-4774.

[6] Jiang, D., Wu, Z., Hsieh, C. Y., Chen, G., Liao, B., Wang, Z., … & Hou, T. (2021). Could graph neural networks learn better molecular representation for drug discovery? A comparison study of descriptor-based and graph-based models. Journal of Cheminformatics, 13(1), 1-23.

Summary

Key Takeaways: Challenges and Solutions

  • Over-smoothing limits depth to about 3-5 layers and prevents capturing long-range interactions. Current solutions: residual connections, jumping knowledge, normalization. Future direction: novel architectures beyond message passing.
  • Interpretability problems lower trust and adoption and cannot guide design. Current solutions: attention visualization, SHAP values, substructure masking. Future direction: interpretable-by-design models built around the chemical hierarchy.

The Path Forward:

  • Balance accuracy with interpretability - Sometimes 90% accuracy with clear explanations beats 95% black box
  • Incorporate domain knowledge - Chemical principles should guide architecture design
  • Develop hybrid approaches - Combine GNN power with traditional descriptor interpretability
  • Focus on augmenting chemists - Tools should enhance, not replace, human expertise

The challenges facing molecular GNNs—over-smoothing and interpretability—are significant but surmountable. Over-smoothing limits the depth of networks we can effectively use, constraining the model’s ability to capture long-range molecular interactions. Meanwhile, the interpretability challenge affects trust and adoption in real-world applications where understanding model decisions is crucial.

Current solutions include architectural innovations like residual connections to combat over-smoothing, and various interpretability methods ranging from attention visualization to substructure-based explanations. The key insight is that effective molecular AI systems must balance predictive power with chemical interpretability, ensuring that models not only make accurate predictions but also provide insights that align with and enhance human understanding of chemistry.

As the field progresses, the focus is shifting from purely accuracy-driven models to systems that provide transparent, chemically meaningful explanations for their predictions. This evolution is essential for GNNs to fulfill their promise as tools for accelerating molecular discovery and understanding.


Section 3.3 – Quiz Questions

1) Factual Questions

Question 1

What is the primary advantage of using Graph Neural Networks (GNNs) over traditional neural networks for molecular property prediction?

A. GNNs require less computational resources
B. GNNs can directly process the graph structure of molecules
C. GNNs always achieve higher accuracy than other methods
D. GNNs work only with small molecules

Correct Answer: B

Explanation: GNNs can directly process molecules as graphs where atoms are nodes and bonds are edges, preserving the structural information that is crucial for determining molecular properties.

Question 2

In the message passing mechanism of GNNs, what happens during the aggregation step?

A. Node features are updated using a neural network
B. Messages from neighboring nodes are combined
C. Edge features are initialized
D. The final molecular prediction is made

Correct Answer: B

Explanation: During aggregation, all incoming messages from neighboring nodes are combined (typically by summing or averaging) to form a single aggregated message for each node.

Question 3

Which of the following molecular representations is most suitable as input for a Graph Neural Network?

A. SMILES string directly as text
B. 2D image of the molecular structure
C. Graph with nodes as atoms and edges as bonds
D. List of molecular descriptors only

Correct Answer: C

Explanation: GNNs are designed to work with graph-structured data where nodes represent atoms and edges represent chemical bonds, allowing the model to learn from the molecular connectivity.

Question 4

What is the “over-smoothing” problem in Graph Neural Networks?

A. The model becomes too complex to train
B. Node representations become increasingly similar in deeper networks
C. The model cannot handle large molecules
D. Training takes too much time

Correct Answer: B

Explanation: Over-smoothing occurs when deep GNNs make node representations increasingly similar across layers, losing the ability to distinguish between different atoms and their local environments.

2) Conceptual Questions

Question 5

You want to build a GNN to predict molecular solubility (a continuous value). Which combination of pooling and output layers would be most appropriate?

A.

# Mean pooling + regression output
x = global_mean_pool(x, batch)
output = nn.Linear(hidden_dim, 1)(x)

B.

# Max pooling + classification output  
x = global_max_pool(x, batch)
output = nn.Sequential(nn.Linear(hidden_dim, 2), nn.Softmax())(x)

C.

# No pooling + multiple outputs
output = nn.Linear(hidden_dim, num_atoms)(x)

D.

# Sum pooling + sigmoid output
x = global_add_pool(x, batch) 
output = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())(x)
Correct Answer: A

Explanation: For continuous property prediction (regression), we need to pool node features to get a molecular-level representation, then use a linear layer to output a single continuous value. Mean pooling is commonly used and effective for this purpose.

Code example: Complete GNN architecture for solubility prediction

# Complete GNN for solubility prediction
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class SolubilityGNN(nn.Module):
    def __init__(self, node_features, hidden_dim=64):
        super(SolubilityGNN, self).__init__()
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.predictor = nn.Linear(hidden_dim, 1)
    
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # Pool to molecular level
        return self.predictor(x)        # Single continuous output

Question 6

A chemist notices that their GNN model performs well on training molecules but poorly on a new set of structurally different compounds. What is the most likely cause and solution?

A. The model is too simple; add more layers
B. The model suffers from distribution shift; collect more diverse training data
C. The learning rate is too high; reduce it
D. The model has too many parameters; reduce model size

Correct Answer: B

Explanation: This scenario describes distribution shift, where the model was trained on one chemical space but tested on a different one. The solution is to include more diverse molecular structures in the training data to improve generalization.

Code example: Data augmentation for chemical space diversity

# Data augmentation to improve generalization
from rdkit import Chem

def augment_chemical_space(original_smiles_list):
    """Expand training data with structural diversity"""
    augmented_data = []
    
    for smiles in original_smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        
        # Add original
        augmented_data.append(smiles)
        
        # Add different SMILES representations
        for _ in range(3):
            random_smiles = Chem.MolToSmiles(mol, doRandom=True)
            augmented_data.append(random_smiles)
    
    return augmented_data

# Use diverse training data from multiple chemical databases
# (combine_datasets is a placeholder for whatever dataset-loading utility you use)
diverse_training_data = combine_datasets([
    'drug_molecules.csv',
    'natural_products.csv', 
    'synthetic_compounds.csv'
])
