7. Retrosynthesis
dataset: USPTO subset (https://pubs.acs.org/doi/full/10.1021/acscentsci.7b00064)
Organic synthesis, the design of synthetic routes to organic compounds, is a crucial discipline that holds relevance not only in the field of chemistry but also across many other scientific domains and industries, including medicine and pharmaceuticals. It enables access to new molecules for a wide range of applications, including drug development and materials science.
The two core approaches involved in the synthesis of new molecules are: forward reaction prediction and retrosynthetic reaction prediction. Forward reaction prediction is the inference of the potential products of a given set of reactants, reagents, and reaction conditions. Retrosynthetic reaction prediction is simply the inverse process of that – starting with the target molecule and reasoning backward to determine how it might be constructed. It involves recursively deconstructing a target compound into simpler precursors that can ultimately be sourced or synthesized. This backward reasoning process that proposes plausible synthetic routes may yield multiple valid pathways for a single molecule. This one-to-many mapping reflects the complexity and flexibility of chemical synthesis.
Retrosynthesis is a critical link between digital molecule design and real-world synthesis. Once the structure of a molecule is determined “in silico”, retrosynthetic analysis helps determine whether and how it can be made using available chemical building blocks. In the context of drug discovery and materials science, this approach accelerates the development pipeline by guiding chemists toward feasible and efficient synthetic routes. The ability to identify valid reaction sequences with minimal experimentation is a valuable asset, especially when time or resources are limited.
7.1 Retrosynthetic Planning & Single-step Retrosynthesis
Complete retrosynthesis planning provides a series of reactions that sequentially breaks up a complex target molecule into smaller and simpler pieces until all of the pieces are commercially available.
This process is often represented as a tree structure: the root is the target molecule, and branches represent reaction steps that simplify the molecule progressively. The goal is to trace a path from the target to a set of readily available ingredients, ensuring that each intermediate step is chemically viable.
Each of these intermediate steps can be viewed as an independent retrosynthesis problem, referred to as single-step retrosynthesis. In other words, the retrosynthetic process can be broken down into discrete stages, each addressing a single chemical transformation in which the target molecule is split into one or more simpler reactants.
However, even inferring a single retrosynthetic step is not trivial. Both single-step retrosynthesis and full retrosynthetic planning have historically relied on the expertise, domain knowledge, and experience of chemists, as well as costly trial and error.
7.2 Computational Approaches to Retrosynthesis
The application of computational techniques to retrosynthesis analysis has emerged as an active and challenging area of research. In recent years, the accumulation of chemical synthesis data, coupled with advances in deep learning, has accelerated the development of computer-assisted synthesis processes, particularly for single-step retrosynthesis.
7.2.1 Types of Computational Approaches
Computational methods for retrosynthesis can be broadly categorized into two main types: template-based and template-free approaches.
Template-Based Methods
Template-based methods rely on predefined collections of reaction rules. These methods typically treat single-step retrosynthesis as a classification task, selecting the most suitable reaction template to generate a given target product.
Despite their structured approach, template-based methods have notable limitations. They cannot predict retrosynthetic outcomes for target molecules that involve novel synthesis patterns not covered by the existing template library. Moreover, updating these libraries to incorporate newly discovered synthesis knowledge can be tedious and labor-intensive.
Template-Free Methods
In contrast, template-free methods predict the reactants of a target molecule without depending on predefined reaction templates. Because they are not constrained by a fixed rule set, these approaches are particularly valued for their ability to generalize to novel reactions and unfamiliar scenarios. Machine learning models leveraged in this approach include:
- Seq2Seq models
- Transformer-based models
- Graph Neural Networks (GNNs)
7.2.2 Challenges in Template-Free Methods
While template-free methods facilitate the discovery of novel synthesis routes, they also introduce new challenges. One key issue is that the reactant SMILES generated by these models may be chemically invalid or commercially unviable or unavailable. Although some studies have identified potential value in these invalid SMILES outputs, a substantial body of research has focused on mitigating this problem through mechanisms such as syntax post-checkers and semi-template-based methods.
Another limitation of SMILES-based methods is their inability to effectively capture molecules’ structural information such as atomic properties, bond features, and adjacency relationships. In addition, current models struggle to fully exploit the potential of multiple molecular descriptors. This challenge often necessitates trade-offs between computational efficiency and predictive accuracy. For example, molecular fingerprinting emphasizes detailed structural features, while SMILES provides more global molecular information. When using only one descriptor, important molecular characteristics may be lost.
Despite recent progress, the complexity of machine-supported retrosynthesis continues to motivate further research into data-driven strategies for synthesis planning.
A significant dimension yet to be fully addressed is empirical chemical knowledge, particularly the influence of catalysts and solvents, since the same set of reactants can yield different products depending on the solvent or catalyst used. Considerations of reaction conditions, catalyst availability, and associated costs remain active and important areas for future research.
7.2.3 SMILES (Simplified Molecular Input Line Entry System)
Retrosynthesis prediction is complex and heavily dependent on molecular descriptors such as SMILES and molecular fingerprints. The Simplified Molecular Input Line Entry System (SMILES) is a string-based notation used to represent molecular structures and reactions. A more detailed introduction to SMILES can be found in Chapter 2.
Each element in a SMILES string can be interpreted as a token in a machine translation model. By converting chemical structures into their equivalent SMILES strings, several models can be employed for reaction prediction. Notable among these are the sequence-to-sequence (seq2seq) model and the Transformer model, which is based on attention mechanisms. These models will be discussed in detail in upcoming sections.
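As a toy illustration (the reaction below is an assumed example, not drawn from the dataset), a reaction written in SMILES can be split at the ">>" delimiter, and retrosynthesis then amounts to “translating” the product string back into the reactant string:
# Toy illustration: esterification written as a reaction SMILES (assumed example, not from the dataset)
rxn = "CC(=O)O.CCO>>CC(=O)OCC"   # acetic acid + ethanol >> ethyl acetate
reactants, product = rxn.split(">>")
# Forward prediction maps reactants -> product; retrosynthesis "translates" the other way:
source = product    # model input:  CC(=O)OCC
target = reactants  # model output: CC(=O)O.CCO
print(source, "->", target)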
7.3 Seq2Seq LSTM
The first machine learning-based template-free approach to retrosynthesis we will discuss is the use of a Seq2Seq LSTM. This is a sequence-to-sequence architecture built using two LSTMs, one serving as the encoder and the other as the decoder. (See Chapter 6 for information on RNN and LSTM, and Chapter 3 for information on neural networks.)
The Seq2Seq LSTM architecture is used for tasks where one sequence of data needs to be transformed into another sequence, especially when the input and output sequences can have different lengths. Retrosynthesis can be framed as a sequence-to-sequence task, a “translation” from product to reactant, when molecules are represented using SMILES strings.
7.3.1 Choosing Seq2Seq LSTM for Retrosynthesis Tasks
While Seq2Seq LSTM is generally considered inferior to more modern approaches like Transformers or Graph Neural Networks (also presented in this chapter), it does have certain advantages:
- Simple to implement and train: especially if working with tokenized SMILES strings (reactants → product or vice versa).
- Captures sequence dependencies: LSTMs are good at modeling dependencies in sequences, which helps with SMILES syntax.
- Works decently for small datasets: on datasets like USPTO-50K, it gives reasonable results without huge compute.
- Useful in resource-constrained settings: it is handy for prototyping, building lightweight models, working under compute or dataset-size constraints, or when applying post-processing (e.g., syntax correction, beam search) to improve outputs.
Overall, Seq2Seq LSTM is a conceptually clean and easy-to-train baseline for retrosynthesis prediction using SMILES strings.
7.3.2 Seq2Seq LSTM for Retrosynthesis: Application Code
Complete code: Click here
Dataset: The USPTO-50k dataset was used for this model. Link: Click here
Step 1: Download the data files and upload them to Colab
The provided link has downloadable raw files split into raw_train.csv, raw_val.csv, and raw_test.csv. Download the zip, extract the files, and upload them into the Colab notebook.
Step 2: Install and Import Required Libraries
# Install RDKit, transformers and any other missing libraries in Google Colab
!pip install rdkit
!pip install transformers
!pip install wandb # Optional: Used for hyperparameter tuning or logging models
import pandas as pd
from rdkit import Chem
from transformers import RobertaTokenizerFast
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import wandb # Optional: Used for hyperparameter tuning or logging models
Hyperparameter tuning: In machine learning, hyperparameters are configuration values that govern the training process of a model. They are used during the learning phase but are not part of the final trained model. Examples include the learning rate, number of layers, hidden dimensions, dropout rate, and batch size. As these values can have a significant impact on model performance, choosing the right set of hyperparameters is essential for achieving an optimal model. This selection process is known as hyperparameter tuning.
Note: To use the transformers library, add your HuggingFace token to the Colab notebook. The HuggingFace token can be found under “Settings » Access Tokens” when logged in to HuggingFace (more information here). To add the key to Colab, click the key icon on the left side panel of the notebook and paste the token in the value field. Name the token “HF_TOKEN” and toggle notebook access for the key.
If not using Colab, the following lines of code can be used to access the token:
# If using Colab, add HF_TOKEN to Colab Secrets
# else, add HF_TOKEN to python env and use the following lines of code
import os
access_token = os.environ.get('HF_TOKEN')
Step 3: Data loading and Processing
# Data loading and Processing
# Paths to data files
train_file = "raw_train.csv"
val_file = "raw_val.csv"
test_file = "raw_test.csv"
# Load the data
train_df = pd.read_csv(train_file)
val_df = pd.read_csv(val_file)
test_df = pd.read_csv(test_file)
# File structure: "ID, Class, reactants>>products"
# Extract only the third column: "reactants>>products"
train_rxns = train_df.iloc[:, 2]
val_rxns = val_df.iloc[:, 2]
test_rxns = test_df.iloc[:, 2]
# Sanity check: Print some of the loaded data
print(train_rxns.head())
print(val_rxns.head())
print(test_rxns.head())
Step 4: Separate reactants and products
# Separate reactants and products
def extract_pairs(reaction_series):
    inputs = []
    outputs = []
    for rxn in reaction_series:
        try:
            reactants, products = rxn.split(">>")
            # Delimiter ">>" used in our data to separate reactants and products
            inputs.append(products.strip())   # make products our input
            outputs.append(reactants.strip()) # make reactants the output
        except ValueError:
            continue  # skip malformed lines
    return inputs, outputs
# Call extract_pairs for train, validation, and test sets, and store reactants and products separately
train_X, train_y = extract_pairs(train_rxns)
val_X, val_y = extract_pairs(val_rxns)
test_X, test_y = extract_pairs(test_rxns)
# Sanity check: Expected True
print(train_y[0]+">>"+train_X[0]==train_rxns[0])
Step 5: Canonicalize SMILES using RDKit
With the SMILES representation, a single molecule can be represented by more than one valid SMILES string. This means that the same molecule could appear multiple times in the dataset with different SMILES, and models might overfit or mislearn due to inconsistent representations. To solve this issue, we use SMILES canonicalization, which converts different valid SMILES strings that represent the same molecule into a unique, standardized form (called canonical SMILES).
# Canonicalise SMILES
def canonicalize(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except:
        return None

def canonicalize_pairs(X, y):
    canon_X, canon_y = [], []
    for prod, react in zip(X, y):
        # Canonicalize product
        canon_prod_parts = []
        for p in prod.split('.'):
            # '.' is the delimiter used to separate multiple reactants/
            # multiple products in SMILES representation
            c = canonicalize(p)
            if c:
                canon_prod_parts.append(c)
        c_prod = '.'.join(canon_prod_parts)
        # Canonicalize reactants
        canon_react_parts = []
        for r in react.split('.'):
            c = canonicalize(r)
            if c:
                canon_react_parts.append(c)
        c_react = '.'.join(canon_react_parts)
        if c_prod and c_react:
            canon_X.append(c_prod)
            canon_y.append(c_react)
    return canon_X, canon_y
# Replace train, test, eval product-reactant pairs with canonicalized forms
train_X, train_y = canonicalize_pairs(train_X, train_y)
val_X, val_y = canonicalize_pairs(val_X, val_y)
test_X, test_y = canonicalize_pairs(test_X, test_y)
# Sanity check
print(train_X[0])
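As a quick illustration of what canonicalization does (the SMILES strings below are illustrative examples, not taken from the dataset), two different valid SMILES for ethanol should map to the same canonical string:
# Two different, valid SMILES for ethanol (illustrative example)
print(canonicalize("OCC"))    # expected: CCO
print(canonicalize("C(C)O"))  # expected: CCO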
Step 6: Tokenize SMILES
# Tokenize SMILES
tokenizer = RobertaTokenizerFast.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
# This is a fast tokenizer implementation of the RoBERTa tokenizer, specifically designed for use with SMILES strings
# You may also use another tokenizer of your choice
# Create function for tokenization
def tokenize_smiles_bpe(smiles_list, tokenizer, max_length=600):
    encodings = tokenizer(smiles_list,
                          padding='max_length',
                          truncation=True,
                          max_length=max_length,
                          return_tensors='pt')
    return encodings['input_ids']
# Tokenize encoder (product SMILES)
train_enc_input = tokenize_smiles_bpe(train_X, tokenizer)
val_enc_input = tokenize_smiles_bpe(val_X, tokenizer)
test_enc_input = tokenize_smiles_bpe(test_X, tokenizer)
# Tokenize decoder (reactant SMILES)
train_dec_input = tokenize_smiles_bpe(train_y, tokenizer)
val_dec_input = tokenize_smiles_bpe(val_y, tokenizer)
test_dec_input = tokenize_smiles_bpe(test_y, tokenizer)
# Sanity check
print(train_enc_input.shape)
print(train_dec_input.shape)
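Optionally, you can decode one of the tokenized tensors back to a string to confirm the round trip (a small check added here for illustration; the decoded string should match the canonical SMILES up to truncation at max_length, though exact formatting can depend on the tokenizer's decoding behaviour):
# Optional: decode the first tokenized product back to text and compare with the original
decoded = tokenizer.decode(train_enc_input[0], skip_special_tokens=True)
print(decoded)
print(train_X[0])  # should match the decoded string (up to truncation)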
Step 7: Define Helper Functions
create_dataloader is a utility function used to wrap input and target tensors into a DataLoader object, which handles batching, shuffling, etc. Additionally, this function ensures that both enc_inputs and dec_inputs are PyTorch tensors; if they are not already tensors, it converts them.
# Dataset wrapper
def create_dataloader(enc_inputs, dec_inputs, batch_size):
    inputs = enc_inputs if isinstance(enc_inputs, torch.Tensor) else torch.tensor(enc_inputs, dtype=torch.long)
    targets = dec_inputs if isinstance(dec_inputs, torch.Tensor) else torch.tensor(dec_inputs, dtype=torch.long)
    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
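For example (the batch size below is chosen arbitrarily for illustration), wrapping the training tensors and pulling one batch looks like this:
# Example usage: one batch of product/reactant token tensors
example_loader = create_dataloader(train_enc_input, train_dec_input, batch_size=32)
batch_src, batch_tgt = next(iter(example_loader))
print(batch_src.shape, batch_tgt.shape)  # expected: torch.Size([32, 600]) for each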
Next, we define the compute_accuracy function, which calculates the token-level accuracy of the model's predictions while ignoring padded positions. Token-level accuracy is used during the training and evaluation loops because sequence-level comparisons may be too harsh for the model to learn from. During testing, sequence-level checking may be used, e.g., exact match or top-k accuracy.
# Helper function
def compute_accuracy(predictions, targets, pad_token_id=0):
    preds = predictions.argmax(dim=2)  # shape: (batch_size, seq_len)
    mask = targets != pad_token_id     # ignore padding
    correct = (preds == targets) & mask
    accuracy = correct.sum().float() / mask.sum().float()
    return accuracy.item()
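A quick toy check (shapes and values below are illustrative) shows how the function is called; with random logits, the token-level accuracy should be close to zero:
# Toy check: random logits vs. random non-padding targets
dummy_logits = torch.randn(2, 5, tokenizer.vocab_size)          # (batch_size, seq_len, vocab_size)
dummy_targets = torch.randint(1, tokenizer.vocab_size, (2, 5))  # avoid the pad id (0)
print(compute_accuracy(dummy_logits, dummy_targets))            # ~0.0 for random predictions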
We also define device to select the appropriate hardware to run the model on, namely GPU if available, otherwise CPU.
# Device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Finally, we may optionally use wandb to tune hyperparameters during training by specifying ranges for them, as shown in the following code. Note that the sweep values shown are for demonstration purposes only.
# WandB Sweep Options Configuration (if tuning hyperparameters)
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.01},
        'embed_dim': {'values': [128, 256]},
        'hidden_dim': {'values': [256, 512]},
        'num_layers': {'values': [2, 3, 4]},
        'dropout': {'distribution': 'uniform', 'min': 0.2, 'max': 0.7},
        'epochs': {'distribution': 'int_uniform', 'min': 2, 'max': 25},
    }
}
Step 8: Define Model Components
# LSTM Components
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True, bidirectional=False)

    def forward(self, x):
        embedded = self.embedding(x)
        outputs, hidden = self.lstm(embedded)
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True, bidirectional=False)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden):
        # x: (batch_size, 1)
        embedded = self.embedding(x)                  # (batch_size, 1, embed_dim)
        output, hidden = self.lstm(embedded, hidden)  # output: (batch_size, 1, hidden_dim)
        prediction = self.fc_out(output)              # (batch_size, 1, vocab_size)
        return prediction, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, tokenizer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer

    def forward(self, src, trg=None, max_len=600, teacher_forcing=False):
        batch_size = src.size(0)
        outputs = []
        # Start decoding from the beginning-of-sequence token
        input_token = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=torch.long, device=src.device)
        _, hidden = self.encoder(src)
        for _ in range(max_len):
            output, hidden = self.decoder(input_token, hidden)
            outputs.append(output)
            input_token = output.argmax(2)  # greedy decoding: feed back the most likely token
        return torch.cat(outputs, dim=1)
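Before wiring these components into the training loop, a small smoke test (the dimensions below are illustrative, not the tuned values) helps confirm that the tensor shapes line up:
# Smoke test: tiny encoder/decoder, random token ids, 10 greedy decoding steps
vocab_size = tokenizer.vocab_size
enc = Encoder(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3)
dec = Decoder(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3)
test_model = Seq2Seq(enc, dec, tokenizer)
dummy_src = torch.randint(0, vocab_size, (2, 50))  # batch of 2 "products", 50 tokens each
with torch.no_grad():
    out = test_model(dummy_src, max_len=10)
print(out.shape)  # expected: torch.Size([2, 10, vocab_size])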
Step 9: Create the Training, Evaluation, and Test Functions
The train_epoch function handles one training pass over the training dataset (one epoch).
def train_epoch(model, dataloader, criterion, optimizer, vocab_size, max_len):
    model.train()
    total_loss = 0
    total_acc = 0
    for enc_input, dec_target in dataloader:
        enc_input, dec_target = enc_input.to(device), dec_target.to(device)
        optimizer.zero_grad()
        batch_size = enc_input.size(0)
        dec_input = torch.full((batch_size, 1), fill_value=tokenizer.bos_token_id, dtype=torch.long, device=device)
        outputs = []
        encoder_outputs, hidden = model.encoder(enc_input)
        decoder_hidden = hidden
        for t in range(max_len):
            output, decoder_hidden = model.decoder(dec_input, decoder_hidden)
            outputs.append(output)
            # Greedy sampling (no teacher forcing)
            top1 = output.argmax(2)  # (batch_size, 1)
            dec_input = top1.detach()
        outputs = torch.cat(outputs, dim=1)  # (batch_size, max_len, vocab_size)
        loss = criterion(outputs.view(-1, vocab_size), dec_target.view(-1))
        accuracy = compute_accuracy(outputs, dec_target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += accuracy
    return total_loss / len(dataloader), total_acc / len(dataloader)
Next, the evaluate function runs a validation pass over the validation dataset and returns the average loss and accuracy over the evaluation set.
def evaluate(model, dataloader, criterion, vocab_size, max_len):
    model.eval()
    total_loss = 0
    total_acc = 0
    with torch.no_grad():
        for enc_input, dec_target in dataloader:
            enc_input, dec_target = enc_input.to(device), dec_target.to(device)
            batch_size = enc_input.size(0)
            dec_input = torch.full((batch_size, 1), fill_value=tokenizer.bos_token_id, dtype=torch.long, device=device)
            outputs = []
            encoder_outputs, hidden = model.encoder(enc_input)
            decoder_hidden = hidden
            for t in range(max_len):
                output, decoder_hidden = model.decoder(dec_input, decoder_hidden)
                outputs.append(output)
                top1 = output.argmax(2)
                dec_input = top1.detach()
            outputs = torch.cat(outputs, dim=1)
            loss = criterion(outputs.view(-1, vocab_size), dec_target.view(-1))
            accuracy = compute_accuracy(outputs, dec_target)
            total_loss += loss.item()
            total_acc += accuracy
    return total_loss / len(dataloader), total_acc / len(dataloader)
Finally, we have the test_exactmatch function, which performs exact-match checking of the model's output against the ground truth on the test dataset. This is a sequence-level accuracy check, as mentioned in Step 7, and is much stricter than the token-level checks used during training.
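The full implementation is available in the linked notebook; the sketch below is a minimal, assumed version (greedy decoding followed by string comparison of the decoded SMILES) that matches the call made in the training loop of Step 10:
# Minimal sketch (assumed implementation): greedy decoding + exact string match of decoded SMILES
def test_exactmatch(model, dataloader, tokenizer, max_len=600, pad_token_id=0):
    # pad_token_id is kept for signature compatibility; padding is stripped via skip_special_tokens
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for enc_input, dec_target in dataloader:
            enc_input, dec_target = enc_input.to(device), dec_target.to(device)
            outputs = model(enc_input, max_len=max_len)  # (batch_size, max_len, vocab_size)
            preds = outputs.argmax(dim=2)                # greedy token choices
            for pred_ids, target_ids in zip(preds, dec_target):
                pred_smiles = tokenizer.decode(pred_ids, skip_special_tokens=True)
                target_smiles = tokenizer.decode(target_ids, skip_special_tokens=True)
                correct += int(pred_smiles == target_smiles)
                total += 1
    return correct / total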
Step 10: Training Loop
The train function handles the complete training workflow for the Seq2Seq LSTM, with support for hyperparameter tuning via wandb. The function sets up the training, validation, and test data loaders. It loads hyperparameter values from the wandb.config object, which is defined through a sweep configuration (see sweep_config in Step 7). Alternatively, you may replace each config.<param> reference with a fixed value to define your hyperparameters manually.
For each epoch, the model is trained using train_epoch() and validated using evaluate(). The test function used is test_exactmatch(). Key metrics are logged to wandb for tracking. After training completes, the final model is saved to disk and also uploaded to wandb for record-keeping or future use.
# Training loop, with wandb
def train():
    wandb.init()
    config = wandb.config
    vocab_size = tokenizer.vocab_size
    batch_size = 64  # Batch size can also be made configurable and performance may change,
    # but parallelisation was used to utilise multiple GPUs, and memory limitations restricted us to a batch size of 64

    # Data loaders
    train_loader = create_dataloader(train_enc_input, train_dec_input, batch_size)
    val_loader = create_dataloader(val_enc_input, val_dec_input, batch_size)
    test_loader = create_dataloader(test_enc_input, test_dec_input, batch_size=256)  # testing doesn't rely on gradient
    # computation, so a larger batch size can be used

    # Model
    encoder = Encoder(
        vocab_size=vocab_size,
        embed_dim=config.embed_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        dropout=config.dropout,
    )
    decoder = Decoder(
        vocab_size=vocab_size,
        embed_dim=config.embed_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        dropout=config.dropout,
    )
    model = Seq2Seq(encoder, decoder, tokenizer).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(config.epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, vocab_size, max_len=600)
        val_loss, val_acc = evaluate(model, val_loader, criterion, vocab_size, max_len=600)
        test_acc = test_exactmatch(model, test_loader, tokenizer, max_len=600, pad_token_id=tokenizer.pad_token_id)

        # Log metrics in wandb
        # Remove/comment out the following call if not using wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "test_acc": test_acc
        })
        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save final model
    torch.save(model.state_dict(), "model_final.pth")
    wandb.save("model_final.pth")
    return model
(If using wandb) Step 11: Start wandb Sweep
wandb.login()
sweep_id = wandb.sweep(sweep_config, project="retro-lstm")
wandb.agent(sweep_id, function=train)
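By default, the agent keeps launching runs from the sweep; if you want to cap the number of trials, wandb.agent also accepts a count argument (the value below is illustrative):
# Optional: limit the sweep to a fixed number of runs
wandb.agent(sweep_id, function=train, count=10)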