7. Retrosynthesis

Overview

This chapter explores the problem of retrosynthetic analysis and how it is approached using modern machine learning techniques. We begin by introducing the chemical basis of retrosynthesis, including the distinction between single-step and multi-step retrosynthetic planning. We then shift to computational perspectives, introducing both template-based methods and modern template-free approaches that leverage machine learning.

A core focus of the chapter is on template-free methods, which bypass manually encoded reaction rules and instead learn directly from data. We provide a comparative overview of several model architectures commonly used for single-step retrosynthesis prediction, including:

  • Seq2Seq LSTM models, which treat retrosynthesis as a translation task from product to reactants;
  • Transformer-based models, which incorporate attention mechanisms to model complex token dependencies in SMILES strings;
  • Graph Neural Networks (GNNs), which operate directly on molecular graph structures for richer chemical understanding.

The chapter also introduces challenges unique to SMILES-based modeling, including chemical validity, structural ambiguity, and descriptor limitations. These issues motivate ongoing research into hybrid representations and post-processing methods.

Throughout, we emphasize practical aspects such as data preprocessing, model training, and evaluation strategies. The chapter includes complete code examples, with a focus on the USPTO-50K dataset, to support hands-on implementation.

Why do we need this?

Currently, several commercial retrosynthesis solutions exist, such as CAS Retrosynthetic Analysis and Synthia, alongside open-source tools developed by researchers, including AiZynthFinder, Chemformer, and LocalRetro. However, these tools often come with limitations. For instance, Synthia is restricted to reactions available in the Sigma-Aldrich database, which can limit flexibility for novel or domain-specific synthesis goals. Additionally, while commercial software may be effective for general applications, a custom-trained model can be tailored to specialized datasets and use cases, including customization based on available reactants, reagents, or laboratory constraints.

Furthermore, many existing tools are rule-based, template-driven, or rely on decision trees (e.g., the CAS retrosynthetic analysis software), which may struggle to generalize beyond well-studied reaction templates or to handle novel compound classes. This is where machine learning–based models have a distinct advantage: they can learn directly from data, identify patterns beyond explicit rules, and adapt more flexibly to new tasks.

Summary of Key Terms in the Chapter

Term: Definition

seq2seq: a neural network architecture designed to map one sequence to another, often with different lengths.
LSTM: a type of recurrent neural network (RNN) that can learn long-range dependencies in sequences.
canonicalized SMILES: a unique, standardized version of a molecule’s SMILES representation. Because the same molecule can be described by multiple valid SMILES strings, canonicalization ensures consistency.
hyperparameters: variables that control the learning process of a model. They are not learned from the data during training; rather, they are provided before training and influence the training itself.
tokens: the fundamental units of text that a model processes. They can be characters, words, or subwords.
tokenization: the process of converting raw text into a sequence of tokens.
tensor: a multi-dimensional array used to store data in deep learning frameworks like PyTorch or TensorFlow.
token-level accuracy: a measure of how many predicted tokens match the ground-truth tokens, averaged over all positions in all sequences. It helps identify whether the model is getting most of the structure right, even if the entire prediction isn’t perfect.
exact match: a stricter metric that checks whether the entire predicted sequence (e.g., the reactants) is identical to the target sequence.
top-k accuracy: a measure of how often the correct output appears in the model’s top k predictions. For example, top-5 accuracy checks whether the ground-truth reactants are found among the five most likely outputs generated by the model.
top-1 accuracy: the special case of top-k accuracy with k = 1; it measures how often the model’s most likely (i.e., highest-ranked) prediction matches the ground truth exactly.
teacher forcing: a training strategy where the ground-truth token at each time step is fed into the decoder instead of the model’s own previous predicted token. It helps the model learn faster and more reliably during training.
attention: a mechanism to determine the importance of components in a sequence relative to other components in that sequence, allowing the model to focus on the most relevant parts of the sequence.
normalize: the process of adjusting values measured on different scales to a common scale, often between 0 and 1. For example, in our sequence comparison, it involves dividing a raw score by the length of the reference or target sequence to convert the score into a relative metric, making it comparable across examples of different lengths.
Levenshtein distance: a metric for measuring the minimum number of single-character edits (insertions, deletions, or substitutions) required to change one string into another. It is commonly used to evaluate the similarity between sequences.
greedy decoding: a decoding strategy where, at each time step, the model selects the token with the highest probability as the next token.
beam search decoding: a decoding algorithm that keeps track of the top k most probable partial sequences (beams) at each time step. It balances exploration and exploitation, allowing for better sequences than greedy decoding by considering multiple paths.
transformer: a neural network that transforms an input sequence into an output sequence, using self-attention mechanisms to learn contextual relationships between tokens across the entire sequence.

7.1 Retrosynthesis

Organic synthesis, the design of synthetic routes to organic compounds, is a crucial discipline that holds relevance not only in the field of chemistry but also across many other scientific domains and industries, including medicine and pharmaceuticals. It enables access to new molecules for a wide range of applications, including drug development and materials science.

The two core approaches involved in the synthesis of new molecules are: forward reaction prediction and retrosynthetic reaction prediction. Forward reaction prediction is the inference of the potential products of a given set of reactants, reagents, and reaction conditions. Retrosynthetic reaction prediction is simply the inverse process of that – starting with the target molecule and reasoning backward to determine how it might be constructed. It involves recursively deconstructing a target compound into simpler precursors that can ultimately be sourced or synthesized. This backward reasoning process that proposes plausible synthetic routes may yield multiple valid pathways for a single molecule. This one-to-many mapping reflects the complexity and flexibility of chemical synthesis.

Retrosynthesis is a critical link between digital molecule design and real-world synthesis. Once the structure of a molecule is determined “in silico”, retrosynthetic analysis helps determine whether and how it can be made using available chemical building blocks. In the context of drug discovery and materials science, this approach accelerates the development pipeline by guiding chemists toward feasible and efficient synthetic routes. The ability to identify valid reaction sequences with minimal experimentation is a valuable asset, especially when time or resources are limited.

7.1.1 Retrosynthetic Planning & Single-step Retrosynthesis

Complete retrosynthesis planning provides a series of reactions that sequentially breaks up a complex target molecule into smaller and simpler pieces until all of the pieces are commercially available.

This process is often represented as a tree structure: the root is the target molecule, and branches represent reaction steps that simplify the molecule progressively. The goal is to trace a path from the target to a set of readily available ingredients, ensuring that each intermediate step is chemically viable.
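
To make the tree framing concrete, the sketch below shows one minimal way such a search tree could be represented in code. The class and field names are illustrative assumptions rather than the data structure of any particular planning tool; real planners typically use AND/OR trees that also keep alternative disconnections for the same molecule.

from dataclasses import dataclass, field
from typing import List

@dataclass
class RetroNode:
    smiles: str                                    # molecule at this node (the root is the target)
    purchasable: bool = False                      # True if the molecule is commercially available
    precursors: List["RetroNode"] = field(default_factory=list)  # reactants from one chosen disconnection

    def is_solved(self) -> bool:
        # A node is solved if the molecule can be bought, or if it has been
        # disconnected and every precursor is itself solved.
        if self.purchasable:
            return True
        return bool(self.precursors) and all(p.is_solved() for p in self.precursors)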

Each of these intermediate steps can be viewed as an independent retrosynthesis problem, referred to as single-step retrosynthesis. In other words, the retrosynthetic process can be broken down into discrete stages, each addressing a single chemical transformation, and at each step the target molecule is split into one or more simpler reactants.


However, even inferring a single retrosynthetic step is not trivial. Both single-step prediction and full retrosynthetic planning have historically relied on the expertise, domain knowledge, and experience of chemists, as well as costly trial and error.

7.2 Computational Approaches to Retrosynthesis

The application of computational techniques to retrosynthesis analysis has emerged as an active and challenging area of research. In recent years, the accumulation of chemical synthesis data, coupled with advances in deep learning, has accelerated the development of computer-assisted synthesis processes, particularly for single-step retrosynthesis.

7.2.1 Types of Computational Approaches

Computational methods for retrosynthesis can be broadly categorized into two main types: template-based and template-free approaches. In this section, we provide a brief overview of both, but the detailed methods for single-step retrosynthesis presented in this chapter fall in the category of template-free methods.

Template-Based Methods

Template-based methods rely on predefined collections of reaction rules. These methods typically treat single-step retrosynthesis as a classification task, selecting the most suitable reaction template to generate a given target product.

Despite their structured approach, template-based methods have notable limitations. They cannot predict retrosynthetic outcomes for target molecules that involve novel synthesis patterns not covered by the existing template library. Moreover, updating these libraries to incorporate newly discovered synthesis knowledge can be tedious and labor-intensive.
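
As a concrete illustration of what a reaction template looks like, the snippet below applies a single hand-written retro-template (an amide disconnection encoded as reaction SMARTS) to a target molecule with RDKit. The template itself is a simplified example chosen for demonstration only; template-based systems select among thousands of such rules, usually extracted automatically from reaction databases.

from rdkit import Chem
from rdkit.Chem import AllChem

# A hand-written retro-template: amide >> carboxylic acid + amine (illustrative only)
retro_template = AllChem.ReactionFromSmarts(
    "[C:1](=[O:2])[NH:3][#6:4]>>[C:1](=[O:2])[OH].[NH2:3][#6:4]"
)

target = Chem.MolFromSmiles("CC(=O)Nc1ccccc1")   # acetanilide as the target product

for precursor_set in retro_template.RunReactants((target,)):
    smiles = []
    for mol in precursor_set:
        Chem.SanitizeMol(mol)                    # clean up valences after the graph edit
        smiles.append(Chem.MolToSmiles(mol))
    print(".".join(smiles))                      # expected: acetic acid + aniline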

Template-Free Methods

In contrast, template-free methods predict the reactants of a target molecule without depending on predefined reaction templates. Because they are not constrained by a fixed rule set, these approaches are particularly valued for their ability to generalize to novel reactions and unfamiliar scenarios. Machine learning models leveraged in this approach include:

  • Seq2Seq models
  • Transformer-based models
  • Graph Neural Networks (GNNs)

7.2.2 Challenges in Template-Free Methods

While template-free methods facilitate the discovery of novel synthesis routes, they also introduce new challenges. One key issue is that the reactant SMILES generated by these models may be chemically invalid, or may describe molecules that are commercially unavailable or impractical to obtain. Although some studies have identified potential value in these invalid SMILES outputs, a substantial body of research has focused on mitigating this problem through mechanisms such as syntax post-checkers and semi-template-based methods.
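
A simple syntax post-check can be implemented with RDKit by attempting to parse each candidate reactant set and keeping the highest-ranked candidate whose fragments all parse. The helper below is a minimal sketch of this idea; the function name is our own, and real post-processors typically also re-rank or repair near-miss strings.

from rdkit import Chem
from typing import List, Optional

def first_valid_candidate(candidates: List[str]) -> Optional[str]:
    """Return the highest-ranked candidate whose fragments all parse as molecules."""
    for smiles in candidates:
        fragments = smiles.split(".")
        # MolFromSmiles returns None (and logs a parse error) for invalid SMILES
        if all(Chem.MolFromSmiles(frag) is not None for frag in fragments):
            return smiles
    return None  # no chemically valid candidate in the beam

# Example: the top-ranked candidate is syntactically broken, the second is fine
print(first_valid_candidate(["CC(=O)(", "CC(=O)O.Nc1ccccc1"]))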

Another limitation of SMILES-based methods is their inability to effectively capture molecules’ structural information such as atomic properties, bond features, and adjacency relationships. In addition, current models struggle to fully exploit the potential of multiple molecular descriptors. This challenge often necessitates trade-offs between computational efficiency and predictive accuracy. For example, molecular fingerprinting emphasizes detailed structural features, while SMILES provides more global molecular information. When using only one descriptor, important molecular characteristics may be lost.
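
To make this descriptor trade-off concrete, the short example below computes two descriptors for the same molecule with RDKit: its canonical SMILES string (a global, sequential description) and a Morgan (ECFP-like) fingerprint (a fixed-length bit vector encoding local substructures). This is only an illustration of the two representations; which one is preferable depends on the downstream model.

from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles("CC(=O)Nc1ccccc1")      # acetanilide

# Global, sequential descriptor: one canonical SMILES string per molecule
print(Chem.MolToSmiles(mol))

# Local, structural descriptor: Morgan fingerprint as a 2048-bit vector (radius 2)
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
print(fp.GetNumOnBits(), "substructure bits set out of", fp.GetNumBits())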

Despite recent progress, the complexity of machine-supported retrosynthesis continues to motivate further research into data-driven strategies for synthesis planning.

A significant dimension yet to be fully addressed is empirical chemical knowledge, particularly the influence of catalysts and solvents, as the same set of reactants can yield different products depending on the solvent or catalyst used. Considerations of reaction conditions, catalyst availability, and associated costs remain active and important areas for future research.

7.2.3 SMILES (Simplified Molecular Input Line Entry System)

Retrosynthesis prediction is complex and heavily dependent on molecular descriptors such as SMILES and molecular fingerprints. The Simplified Molecular Input Line Entry System (SMILES) is a string-based notation used to represent molecular structures and reactions. A more detailed introduction to SMILES can be found in Chapter 2.

Each element in a SMILES string can be interpreted as a token in a machine translation model. By converting equivalent chemical structures into SMILES strings, several models can be employed for reaction prediction. Notable among these are the sequence-to-sequence (seq2seq) model and the Transformer model, which is based on attention mechanisms. These models will be discussed in detail in upcoming sections.
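
As a small illustration, the snippet below splits a SMILES string into tokens using a regular expression similar to those commonly used in SMILES language models. Treat the exact pattern as an illustrative assumption; the code walkthrough later in this chapter instead uses a pretrained BPE tokenizer.

import re

# Multi-character tokens such as "Cl", "Br", bracket atoms, and two-digit ring
# closures ("%12") are kept intact; everything else is a single character.
SMILES_TOKEN_PATTERN = re.compile(
    r"(\[[^\]]+\]|Br|Cl|Si|Se|se|@@|%\d{2}|[BCNOPSFIbcnops]|[=#\-\+\\/:~@\?\*\$\(\)\.]|\d)"
)

def tokenize_smiles(smiles: str):
    return SMILES_TOKEN_PATTERN.findall(smiles)

print(tokenize_smiles("CC(=O)Nc1ccccc1"))
# ['C', 'C', '(', '=', 'O', ')', 'N', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']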

7.3 Seq2Seq LSTM

The first machine learning-based template-free approach to retrosynthesis we will discuss is the use of a Seq2Seq LSTM. This is a sequence-to-sequence architecture built using two LSTMs, one serving as the encoder and the other as the decoder. (See Chapter 6 for information on RNN and LSTM, and Chapter 3 for information on neural networks.)

The Seq2Seq LSTM architecture is used for tasks where one sequence of data needs to be transformed into another sequence, especially when the input and output sequences can have different lengths. Retrosynthesis can be framed as a sequence-to-sequence task, a “translation” from product to reactant, when molecules are represented using SMILES strings.

7.3.1 Choosing Seq2Seq LSTM for Retrosynthesis Tasks

While Seq2Seq LSTM is generally considered inferior to more modern approaches like Transformers or Graph Neural Networks (also presented in this chapter), it does have certain advantages:

  • Simple to implement and train: Especially if working with tokenized SMILES strings (reactants >> product or vice versa);

  • Captures sequence dependencies: LSTMs are good at modeling dependencies in sequences, which helps with SMILES syntax.

  • Works decently for small datasets: On datasets like USPTO-50K, it gives reasonable results without huge compute.

  • Useful in resource-constrained settings: It is well suited to prototyping, building lightweight models, or working under compute and dataset-size constraints, and its outputs can be improved with post-processing (e.g., syntax correction, beam search).

Overall, Seq2Seq LSTM is a conceptually clean and easy-to-train baseline for retrosynthesis prediction using SMILES strings.

7.3.2 Seq2Seq LSTM for Retrosynthesis: Application Code

Complete code: Click here

Dataset: The USPTO-50k dataset was used for this model. A version of this dataset with reaction classes included, and with the train, validation, and test splits used in this tutorial, can be found here: Click here!

In this section, we provide a step-by-step process for creating an LSTM for single-step retrosynthesis. For this demonstration, we have chosen the small USPTO-50k dataset, which contains roughly 50,000 reactions, for ease of explanation and because of limited computational resources.

In this handbook, we have implemented a seq2seq LSTM that uses attention. It is trained using teacher forcing, with the teacher forcing ratio decayed at every epoch. Additionally, we included the reaction classes in training and inference. All three of these choices are optional, though they may influence model performance, and the Colab notebook above provides alternative code should you wish to omit any of them.

We tested this model on Reaction Class 10 rather than all reaction classes due to limitations on time and computational resources. Our model achieved a Top-1 accuracy of 69.56% and an average normalized Levenshtein distance of 0.0163, which is a strong result on this subset.

Step 1: Download the data files and upload them to Colab

The provided link contains downloadable raw files split into rxn_train.csv, rxn_val.csv, and rxn_test.csv, along with rxnclass10_test.csv. Download the zip, extract the files, and upload them to the Colab notebook.

Step 2: Install and Import Required Libraries

# Install RDKit, transformers, and any other missing libraries in Google Colab
!pip install rdkit
!pip install transformers
!pip install Levenshtein # Used in the test step to compute edit distances
!pip install wandb # Optional: Used for hyperparameter tuning or logging models

import random # Used for the teacher-forcing decision in the training loop
import pandas as pd
from rdkit import Chem
from transformers import RobertaTokenizerFast # See Note
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import wandb # Optional: Used for hyperparameter tuning or logging models (See below)

Hyperparameter tuning: In machine learning, hyperparameters are configuration values that govern the training process of a model. They are used during the learning phase but are not part of the final resulting model. Examples include the learning rate, number of layers, hidden dimensions, dropout rate, and batch size. As these values can have a significant impact on model performance, choosing the right set of hyperparameters is essential to achieving a well-performing model. This process of selection is known as hyperparameter tuning.

Note: To use the transformers library, add your HuggingFace token to the Colab notebook. The HuggingFace token can be found under “Settings > Access Tokens” when logged in to HuggingFace (more information here). To add the key to Colab, click on the key icon on the left side panel of the notebook and paste the token in the value field. Name the token “HF_TOKEN” and toggle notebook access for the key.

If not using Colab, the following lines of code can be used to access the token:

# If using Colab, add HF_TOKEN to Colab Secrets
# else, add HF_TOKEN to python env and use the following lines of code
import os
access_token = os.environ.get('HF_TOKEN')

Step 3: Data loading and Processing

# Data loading and Processing
# Paths to data files
train_file = "rxn_train.csv"
val_file = "rxn_val.csv"
test_file = "rxnclass10_test.csv" # Or, use rxn_test.csv to test model on entire test set

# Load the data
train_df = pd.read_csv(train_file)
val_df = pd.read_csv(val_file)
test_df = pd.read_csv(test_file)

# File structure: "ID, Class, reactants>>products"
# Extract only the third column: "reactant>>products"
train_rxns = train_df.iloc[:, 2]
val_rxns = val_df.iloc[:, 2]
test_rxns = test_df.iloc[:, 2]

# Extract the reaction classes (Second Column) (optional)
trxntype = train_df.iloc[:, 1]
vrxntype = val_df.iloc[:, 1]
tstrxntype = test_df.iloc[:, 1]

# Check: Print some of the loaded data
print(train_rxns.head())
print(val_rxns.head())
print(test_rxns.head())

Step 4: Separate reactants and products

Currently, the format of the loaded data (train_rxns, val_rxns, test_rxns) is reactants>>products. However, we need the products and reactants separately to serve as the input and output of our model. Given our data formatting, an easy way to separate them is to make use of the delimiter >>, as seen in the function below.

# Separate reactants and products
def extract_pairs(reaction_series):
    inputs = []
    outputs = []
    for rxn in reaction_series:
        try:
            reactants, products = rxn.split(">>")
            # Delimiter ">>" used in our data to separate reactants and products
            inputs.append(products.strip()) # make products our input
            outputs.append(reactants.strip()) # make reactants the output
        except ValueError:
            continue  # skip malformed lines
    return inputs, outputs

# Call extract_pairs for train, validation, and test sets, and store reactants and products separately
train_X, train_y = extract_pairs(train_rxns)
val_X, val_y = extract_pairs(val_rxns)
test_X, test_y = extract_pairs(test_rxns)

# Check: Expected True (reactants from separated pair + ">>" + product from separated pair == original reaction read from file)
print(train_y[0]+">>"+train_X[0]==train_rxns[0])

Step 5: Canonicalize SMILES using RDKit

With SMILES representation, a single molecule can be represented by more than one valid SMILES string. This means that the same molecule could appear multiple times in the dataset with different SMILES and models might overfit or mislearn due to inconsistent representations. To solve this issue, we use SMILES canonicalization, which converts different valid SMILES strings that represent the same molecule into a unique, standardized form (called canonical SMILES).

# Canonicalise SMILES
def canonicalize(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except:
        return None

# helper to canonicalise both the reactants and products
def canonicalize_pairs(X, y):
    canon_X, canon_y = [], []

    for prod, react in zip(X, y):

        # Canonicalize product
        canon_prod_parts = []
        for p in prod.split('.'):
            # '.' is the delimiter used to separate multiple reactants/
            # multiple products in SMILES representation
            c = canonicalize(p)
            if c:
                canon_prod_parts.append(c) 
        c_prod = '.'.join(canon_prod_parts) # join canonicalised products back to prod1.prod2 format

        # Canonicalize reactants, same process as for products
        canon_react_parts = []
        for r in react.split('.'):
            c = canonicalize(r)
            if c:
                canon_react_parts.append(c)
        c_react = '.'.join(canon_react_parts)
        if c_prod and c_react:
            canon_X.append(c_prod)
            canon_y.append(c_react)
    return canon_X, canon_y

# Replace train, test, eval product-reactant pairs with canonicalized forms
train_X, train_y = canonicalize_pairs(train_X, train_y)
val_X, val_y = canonicalize_pairs(val_X, val_y)
test_X, test_y = canonicalize_pairs(test_X, test_y)

# Sanity check
print(train_X[0])

Optional Step 6: Append Reaction Classes

This step adds the known reaction class to each input, which can potentially help the model learn better. It is entirely optional: in the real world we may or may not know the reaction type, and ML-based computational methods work either way, though there may be some differences in performance.

# Add Reaction Class to Input 
train_X = [f"[RX_TYPE_{rx_type}] {smiles}" for rx_type, smiles in zip(trxntype, train_X)]
val_X = [f"[RX_TYPE_{rx_type}] {smiles}" for rx_type, smiles in zip(vrxntype, val_X)]
test_X = [f"[RX_TYPE_{rx_type}] {smiles}" for rx_type, smiles in zip(tstrxntype, test_X)]

Step 7: Prepare to Tokenize

Add Beginning-of-Sequence (BOS) and End-of-Sequence (EOS) tokens. These signal the start and end of a sequence and help the model learn where to start predicted sequences and when to stop.

# Add BOS and EOS tokens manually if they are not already present
# only for train and val 
train_y = [f"<s> {seq}" if not seq.startswith("<s>") else seq for seq in train_y]
train_y = [f"{seq} </s>" if not seq.endswith("</s>") else seq for seq in train_y]

val_y = [f"<s> {seq}" if not seq.startswith("<s>") else seq for seq in val_y]
val_y = [f"{seq} </s>" if not seq.endswith("</s>") else seq for seq in val_y]

# print some to verify
for x in range(5):
    print("Some train e.g.: ", train_X[x], "from: ", train_y[x])
for x in range(5):
    print("Some val e.g.: ", val_X[x], "from: ", val_y[x])
for x in range(5):
    print("Some train e.g.: ", test_X[x], "from: ", test_y[x])

Next, define the tokenizer and some special tokens. These special tokens are not part of the chemical vocabulary (i.e., they are not among the tokens the model should predict as part of a SMILES string), but the tokenizer must be able to recognize them.

from transformers import RobertaTokenizerFast

# define tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
# This is a fast tokenizer implementation of the RoBERTa tokenizer, specifically designed for use with SMILES strings
# You may also use other tokenizer of choice

# define special tokens for tokenizer
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.pad_token="<pad>"
tokenizer.add_special_tokens({"additional_special_tokens": [f"[RXN_TYPE_{i}]" for i in range(1, 11)]}) # registers [RXN_TYPE_1] through [RXN_TYPE_10]; USPTO-50K has 10 reaction classes

# print bos, eos, pad token ids for reference
print("BOS token:", tokenizer.bos_token, "-> ID:", tokenizer.bos_token_id)
print("EOS token:", tokenizer.eos_token, "-> ID:", tokenizer.eos_token_id)
print("Pad token:", tokenizer.pad_token, "-> ID", tokenizer.pad_token_id)
print()

Step 8: Tokenize SMILES

In this step, we take the human-readable SMILES strings and convert them into machine-friendly encodings.

# Create function for tokenization
def tokenize_smiles_bpe(smiles_list, tokenizer, max_length=600):
    encodings = tokenizer(smiles_list,
                          padding='max_length',
                          truncation=True,
                          max_length=max_length,
                          return_tensors='pt')
    return encodings['input_ids'] 

# Tokenize encoder (product SMILES)
train_enc_input = tokenize_smiles_bpe(train_X, tokenizer)
val_enc_input   = tokenize_smiles_bpe(val_X, tokenizer)
test_enc_input  = tokenize_smiles_bpe(test_X, tokenizer)

# Tokenize decoder (reactant SMILES)
train_dec_input = tokenize_smiles_bpe(train_y, tokenizer)
val_dec_input   = tokenize_smiles_bpe(val_y, tokenizer)
test_dec_input  = tokenize_smiles_bpe(test_y, tokenizer)

# Some checks
# Print shape 
print(train_enc_input.shape)
print(train_dec_input.shape)

# Print first input_ids row (token IDs) as a list
print(train_enc_input[0].tolist())

# Decode back to SMILES tokens if you want to verify the round trip
print(tokenizer.decode(train_enc_input[0], skip_special_tokens=True))
print(train_dec_input.shape)

Step 9: Define Some Helpful Helpers

create_dataloader is a utility function used to wrap input and target tensors into a DataLoader object which handles batching, shuffling etc. Additionally, this function ensures that both enc_inputs and dec_inputs are PyTorch tensors. If they’re not already tensors, it converts them.

# Dataset wrapper
def create_dataloader(enc_inputs, dec_inputs, batch_size):
    inputs = enc_inputs if isinstance(enc_inputs, torch.Tensor) else torch.tensor(enc_inputs, dtype=torch.long)
    targets = dec_inputs if isinstance(dec_inputs, torch.Tensor) else torch.tensor(dec_inputs, dtype=torch.long)
    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

Next, we define the compute_accuracy function, which calculates the token-level accuracy of the model’s predictions while ignoring padded positions. Token-level accuracy is used during the training and evaluation loops because sequence-level comparisons would be too harsh a signal while the model is learning. During testing, sequence-level checks such as exact match or top-k accuracy may be used.

# Helper function
def compute_accuracy(predictions, targets, pad_token_id=0):
    preds = predictions.argmax(dim=2)  # shape: (batch_size, seq_len)
    mask = targets != pad_token_id     # ignore padding
    correct = (preds == targets) & mask
    accuracy = correct.sum().float() / mask.sum().float()
    return accuracy.item()

We also set device to select the hardware the model runs on: a GPU if one is available, otherwise the CPU.

# Device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Finally, we may optionally use wandb to tune hyperparameters during training by setting several possible ranges for these as seen in the following code. Note that the sweep values shown are simply for demonstration purposes.

# WandB Sweep Options Configuration (if tuning hyperparameters)
sweep_config = {
    'method': 'bayes', 
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'distribution': 'uniform', 'min':0.0001, 'max': 0.01},
        'embed_dim': {'values': [128, 256]},
        'hidden_dim': {'values': [ 256, 512]},
        'num_layers': {'values': [2,3,4]},
        'dropout': {'distribution': 'uniform', 'min':0.2, 'max': 0.7},
        'epochs': {'distribution': 'int_uniform', 'min': 2, 'max': 25},
    }
}

Step 10: Attention

Attention is key to a model that generalizes well and predicts reactant SMILES accurately. We noted a drastic improvement in model performance with the addition of attention, with sequence-level accuracy jumping from ~0% to ~69.6% on our test set!

# Add Attention Support
class myAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(myAttention, self).__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs):
        # hidden: (batch, 1, hidden_dim)
        # encoder_outputs: (batch, src_len, hidden_dim)

        batch_size = encoder_outputs.size(0)
        src_len = encoder_outputs.size(1)

        # Repeat decoder hidden state for every source token
        hidden = hidden.repeat(1, src_len, 1)  # (batch, src_len, hidden_dim)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # (batch, src_len, hidden_dim)

        v = self.v.repeat(batch_size, 1).unsqueeze(1)  # (batch, 1, hidden_dim)
        attn_weights = torch.bmm(v, energy.transpose(1, 2))  # (batch, 1, src_len)

        attn_weights = torch.softmax(attn_weights, dim=-1)  # (batch, 1, src_len)
        context = torch.bmm(attn_weights, encoder_outputs)  # (batch, 1, hidden_dim)

        return context, attn_weights  # context used to guide decoder

Step 11: Define Model Components

# LSTM Components
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True, bidirectional=False)

    def forward(self, x):
        embedded = self.embedding(x)
        outputs, hidden = self.lstm(embedded)
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.attn = myAttention(hidden_dim)
        self.lstm = nn.LSTM(embed_dim + hidden_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, input_token, hidden, encoder_outputs):
        # input_token: (batch_size, 1)
        # hidden: (h_n, c_n) from previous timestep
        # encoder_outputs: (batch_size, src_len, hidden_dim)

        embedded = self.embedding(input_token)  # (batch_size, 1, embed_dim)

        # Compute attention context vector from previous hidden state and encoder outputs
        decoder_hidden = hidden[0][-1].unsqueeze(1)  # (batch_size, 1, hidden_dim)
        context, attn_weights = self.attn(decoder_hidden, encoder_outputs)  # context: (batch_size, 1, hidden_dim)

        # Concatenate embedded input and context vector, then feed to LSTM
        inp = torch.cat((embedded, context), dim=2)  # (batch_size, 1, embed_dim + hidden_dim)
        output, hidden = self.lstm(inp, hidden)  # output: (batch_size, 1, hidden_dim)

        # Combine output with context and pass through final layer
        output_combined = torch.cat((output, context), dim=2)  # (batch_size, 1, 2*hidden_dim)
        prediction = self.fc_out(output_combined)  # (batch_size, 1, vocab_size)

        return prediction, hidden, attn_weights

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, tokenizer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer

    def forward(self, src, trg=None, max_len=600, teacher_forcing=False):
        batch_size = src.size(0)
        outputs = []
        input_token = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=torch.long, device=src.device)

        encoder_outputs, hidden = self.encoder(src)

        for _ in range(max_len):
            output, hidden, _ = self.decoder(input_token, hidden, encoder_outputs)
            outputs.append(output)
            input_token = output.argmax(2)  # greedy sampling

        return torch.cat(outputs, dim=1)  # (batch_size, max_len, vocab_size)

Step 12: Create the Training and Evaluation Functions

The train_epoch function handles one training pass over the training dataset (epoch).

Here, we have trained the model using teacher forcing.

Note on teacher forcing:

In teacher forcing, instead of the model using its own previous prediction, the teacher (i.e., the training program) prompts the student (the seq2seq LSTM) with the next correct token from the ground truth. However, we do not do this for every single token. Rather, we set a teacher forcing ratio to determine how much of the sequence will be “taught” to the model. In the example code below, we start with a ratio of 0.5, i.e., the model is taught 50% of the sequence. Relying too heavily on being fed correct tokens is risky, because at inference time (when there is no teacher forcing) a single incorrect token can make the predictions go haywire. To address this, we “wean” the model away from teacher forcing by gradually reducing the teacher forcing ratio to zero, using the decay rate.

def train_epoch(model, dataloader, criterion, optimizer, vocab_size, epoch, max_len):
    model.train()
    total_loss = 0
    total_acc = 0
    for enc_input, dec_target in dataloader:
        enc_input, dec_target = enc_input.to(device), dec_target.to(device)
        optimizer.zero_grad()

        batch_size = enc_input.size(0)
        dec_input = torch.full((batch_size, 1), fill_value=tokenizer.bos_token_id, dtype=torch.long, device=device)
        outputs = []

        encoder_outputs, hidden = model.encoder(enc_input)
        decoder_hidden = hidden

        tfri = 0.5  # initial teacher forcing ratio, set this as a hyperparameter if desired
        decay_rate = 0.05  # teacher forcing ratio reduces to 0 in ~10 epochs
        teacher_forcing_ratio = max(0.0, tfri - epoch * decay_rate)

        for t in range(max_len):
            output, decoder_hidden, attn_weights = model.decoder(dec_input, decoder_hidden, encoder_outputs) # For the no-attention variant, use: output, decoder_hidden = model.decoder(dec_input, decoder_hidden)
            outputs.append(output)

            use_teacher_forcing = random.random() < teacher_forcing_ratio

            if t < dec_target.size(1) and use_teacher_forcing:
                dec_input = dec_target[:, t].unsqueeze(1)  # (batch_size, 1) - ground truth token
            else:
                top1 = output.argmax(2)  # (batch_size, 1)
                dec_input = top1.detach()

        outputs = torch.cat(outputs, dim=1)  # (batch_size, max_len, vocab_size)
        loss = criterion(outputs.view(-1, vocab_size), dec_target.view(-1))
        accuracy = compute_accuracy(outputs, dec_target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += accuracy

    return total_loss / len(dataloader), total_acc/len(dataloader)

Next, we create the evaluate function which runs a validation pass over the validation dataset, and returns the average loss and accuracy over the validation set.

def evaluate(model, dataloader, criterion, vocab_size, max_len):
    model.eval()
    total_loss = 0
    total_acc = 0
    with torch.no_grad():
        for enc_input, dec_target in dataloader:
            enc_input, dec_target = enc_input.to(device), dec_target.to(device)

            batch_size = enc_input.size(0)
            dec_input = torch.full((batch_size, 1), fill_value=tokenizer.bos_token_id, dtype=torch.long, device=device)
            outputs = []

            encoder_outputs, hidden = model.encoder(enc_input)
            decoder_hidden = hidden

            for t in range(max_len):
                output, decoder_hidden, _ = model.decoder(dec_input, decoder_hidden, encoder_outputs)
                outputs.append(output)

                top1 = output.argmax(2)
                dec_input = top1.detach()

            outputs = torch.cat(outputs, dim=1)
            loss = criterion(outputs.view(-1, vocab_size), dec_target.view(-1))
            accuracy = compute_accuracy(outputs, dec_target)
            total_loss += loss.item()
            total_acc += accuracy

    return total_loss / len(dataloader), total_acc/ len(dataloader)

Step 13: Test Function

Finally, we have the test_beam_search function, which performs beam search decoding and exact-match checking of the model’s output against the ground truth for the test dataset. This is a sequence-level accuracy check, as mentioned in Step 9, and is much stricter than the token-level checks used during training. We also calculate the average normalized Levenshtein distance of the predictions.

Note on Beam Search Decoding:

Beam search is a search algorithm used in sequence generation tasks (like this one) to find the most likely output sequence given a model’s learned probability distribution over sequences. Unlike greedy decoding, which we use in our evaluation loop and which follows only one path (the best token at each step), beam search keeps track of the top k partial sequences at each step. In brief, greedy decoding chooses the best token at each step and can miss globally better sequences, whereas beam search explores multiple candidate sequences, improving the chance of finding higher-probability completions while remaining computationally feasible.
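
The tiny worked example below, with made-up probabilities over a two-token vocabulary, shows why keeping multiple beams can help: the greedy choice at step one locks the decoder out of the globally most probable two-token sequence, while scoring both continuations (which a beam of width 2 would also do here) recovers it.

# Toy two-step model over the vocabulary {"A", "B"} with made-up probabilities
step1 = {"A": 0.6, "B": 0.4}                                   # P(first token)
step2 = {"A": {"A": 0.51, "B": 0.49},                          # P(second token | first = "A")
         "B": {"A": 0.05, "B": 0.95}}                          # P(second token | first = "B")

# Greedy decoding: commit to the best token at each step
g1 = max(step1, key=step1.get)                                 # "A"
g2 = max(step2[g1], key=step2[g1].get)                         # "A"
print("greedy:", (g1, g2), step1[g1] * step2[g1][g2])          # ('A', 'A') with probability ~0.306

# Scoring every two-token sequence (a beam of width 2 would also find the winner)
scores = {(t1, t2): step1[t1] * step2[t1][t2] for t1 in step1 for t2 in step2[t1]}
best = max(scores, key=scores.get)
print("best:  ", best, scores[best])                           # ('B', 'B') with probability ~0.38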

Note on Average Normalized Levenshtein Distance:

The average normalized Levenshtein distance is a metric used to evaluate the similarity between predicted and target sequences. It is based on the Levenshtein distance, which counts the minimum number of single-character edits (insertions, deletions, or substitutions) required to transform one sequence into another. To normalize this value, the raw distance is divided by the length of the longer sequence between the prediction and the reference, resulting in a score between 0 (exact match) and 1 (completely different). This normalization makes comparisons fair across sequences of different lengths. By calculating the normalized Levenshtein distance for each pair of predicted and target sequences in a dataset and averaging the results, we obtain the average normalized Levenshtein distance. This metric is especially valuable because it captures how close predictions are to the correct answer even when they are not exact matches, offering a more nuanced view of model performance than strict exact match accuracy.
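
As a quick worked example of this metric, consider two hypothetical reactant strings that differ by a single functional group; the calculation below mirrors the normalization used in the test function that follows.

import Levenshtein  # installed via pip in Step 2

pred = "CC(=O)O.Nc1ccccc1"    # predicted reactants (acetic acid + aniline)
true = "CC(=O)Cl.Nc1ccccc1"   # ground-truth reactants (acetyl chloride + aniline)

dist = Levenshtein.distance(pred, true)                  # 2 single-character edits
norm = dist / max(1, max(len(pred), len(true)))          # divide by the longer string
print(dist, round(norm, 3))                              # 2 0.111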

# Test implementation with beam search decoding

# Remove eos, bos, pad tokens in case they remain
def strip_special_tokens(seq, bos_id, eos_id):
    if isinstance(seq, torch.Tensor):
        seq = seq.squeeze().tolist()
    # Remove BOS if present
    if bos_id in seq:
        seq = seq[seq.index(bos_id) + 1:]
    # Truncate at EOS if present
    if eos_id in seq:
        seq = seq[:seq.index(eos_id)]
    return seq

# Remove any stray leading fragment (separated by whitespace) from the decoded prediction;
# decoded SMILES contain no internal whitespace, so the second piece is the full prediction
def clean(pred_seq):
    pred_seq = pred_seq.split()
    if len(pred_seq) == 1:
        return pred_seq[0]
    else:
        return pred_seq[1]

import Levenshtein # used to calculate edit distance

def test_beam_search(model, test_loader, tokenizer, max_len=600, beam_width=5, pad_token_id=tokenizer.pad_token_id, batch_size=1):
    model.eval()
    total_sequences = 0
    exact_instr = 0 # For storing exact string matches between prediction and truth
    total_levenshtein = 0.0  # For normalized Levenshtein
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id
    vocab_size = tokenizer.vocab_size

    with torch.no_grad():
        for batch in test_loader:
            enc_input, dec_target = batch
            enc_input = enc_input.to(device)
            dec_target = dec_target.to(device)

            encoder_output, encoder_hidden = model.encoder(enc_input)
            pred_tokens = []

            for b in range(batch_size):
                hidden = (
                    encoder_hidden[0][:, b:b+1].contiguous(),
                    encoder_hidden[1][:, b:b+1].contiguous()
                )

                beams = [(
                    torch.tensor([[bos_token_id]], dtype=torch.long, device=device),
                    hidden,
                    0.0
                )]
                finished = []

                for _ in range(max_len):
                    all_candidates = []

                    for tokens_so_far, hidden_state, score_so_far in beams:
                        last_token = tokens_so_far[:, -1].unsqueeze(1)
                        output, hidden_out, _ = model.decoder(last_token, hidden_state, encoder_output)
                        logits = output[:, -1, :]
                        log_probs = torch.log_softmax(logits, dim=-1)

                        topk_log_probs, topk_indices = log_probs.topk(beam_width, dim=-1)

                        for k in range(topk_indices.size(1)):
                            next_token = topk_indices[0, k].item()
                            next_log_prob = topk_log_probs[0, k].item()
                            raw_score = score_so_far + next_log_prob

                            seq_len = tokens_so_far.size(1) + 1
                            alpha = 0.6
                            length_penalty = ((5 + seq_len) / 6) ** alpha # Length penalty to ensure model does not overly and incorrectly favor shorter predictions
                            normalized_score = raw_score / length_penalty

                            new_seq = torch.cat([
                                tokens_so_far,
                                torch.tensor([[next_token]], device=device)
                            ], dim=1)

                            new_hidden = (
                                hidden_out[0].contiguous(),
                                hidden_out[1].contiguous()
                            )

                            if next_token == eos_token_id:
                                finished.append((new_seq.squeeze(0).tolist(), normalized_score))
                            else:
                                all_candidates.append((new_seq, new_hidden, normalized_score))

                    beams = sorted(all_candidates, key=lambda x: -x[-1])[:beam_width]
                    if not beams:
                        break

                if finished:
                    best_seq = sorted(finished, key=lambda x: -x[-1])[0][0]
                elif beams:
                    best_seq = beams[0][0].squeeze(0).tolist()
                else:
                    best_seq = [bos_token_id]

                pred_tokens.append(best_seq)

            for i in range(batch_size):
                pred_seq = strip_special_tokens(pred_tokens[i], bos_token_id, eos_token_id)
                true_seq = strip_special_tokens(dec_target[i].cpu().tolist(), bos_token_id, eos_token_id)

                pred_str = clean(tokenizer.decode(pred_seq, skip_special_tokens=True))
                true_str = tokenizer.decode(true_seq, skip_special_tokens=True)

                # Print for debug purposes 
                if i < 5:
                    print("Pred:", pred_str)
                    print("True:", true_str)
                    print(f"Predicted token IDs: {pred_seq}")
                    print(f"Ground truth token IDs: {true_seq}")

                lev_dist = Levenshtein.distance(pred_str, true_str)
                norm_lev_dist = lev_dist / max(1, max(len(pred_str), len(true_str)))
                total_levenshtein += norm_lev_dist

                total_sequences += 1
                exact_instr += int(pred_str.strip() == true_str.strip())

    accuracy = exact_instr / total_sequences if total_sequences > 0 else 0.0
    avg_norm_levenshtein = total_levenshtein / total_sequences if total_sequences > 0 else 0.0

    print(f"[Beam Search Test - Seq Level] Exact Match Accuracy (Top-1): {accuracy:.4f}")
    print(f"[Beam Search Test - Seq Level] Avg Normalized Levenshtein: {avg_norm_levenshtein:.4f}")
    return accuracy, avg_norm_levenshtein

Step 14: Training Loop

The train function handles the complete training workflow for the Seq2Seq LSTM, with support for hyperparameter tuning via wandb. The function sets up the training, validation, and test data loaders. It loads hyperparameter values from the wandb.config object, which is defined through a sweep configuration (see Step 9 for sweep_config). Alternatively, you may replace a parameter (config.<param> here) with a fixed value to define your hyperparameters manually.

For each epoch, the model is trained using train_epoch(), and validation is performed using evaluate(). The test function used is test_beam_search. Key metrics are logged to wandb for tracking. After training completes, the final model is saved locally and also uploaded to wandb for record-keeping or future use.

# Training loop, with wandb
def train():
    wandb.init()
    config = wandb.config

    vocab_size = tokenizer.vocab_size
    batch_size = 64 # Batch size could also be made configurable and may affect performance,
    # but we used parallelization across multiple GPUs, and memory limitations restricted the batch size we could use

    # Data loaders
    train_loader = create_dataloader(train_enc_input, train_dec_input, batch_size)
    val_loader = create_dataloader(val_enc_input, val_dec_input, batch_size)
    test_loader = create_dataloader(test_enc_input, test_dec_input, batch_size=1) # the current test_beam_search can only handle a batch size of 1

    # Model
    encoder = Encoder(
        vocab_size=vocab_size,
        embed_dim=config.embed_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        dropout=config.dropout,
    )

    decoder = Decoder(
        vocab_size=vocab_size,
        embed_dim=config.embed_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        dropout=config.dropout,
    )

    model = Seq2Seq(encoder, decoder, tokenizer).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(config.epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, vocab_size, epoch, max_len=600)
        val_loss, val_acc = evaluate(model, val_loader, criterion, vocab_size, max_len=600)
        # Skip beam search testing during early epochs to save time
        if epoch >=3: # Adjust as required/ based on available resources
            print("Epoch number: ", epoch)
            test_acc, avg_norm_levenshtein = test_beam_search(model, test_loader, tokenizer)
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
                "test_acc": test_acc,
                "avg_norm_levenshtein": avg_norm_levenshtein
            })
        else:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
            })

        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} |Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save final model
    torch.save(model.state_dict(), "model_final.pth")
    wandb.save("model_final.pth")
    return model

(If using wandb) Step 15: Start wandb Sweep

wandb.login()
sweep_id = wandb.sweep(sweep_config, project="retro-lstm") # name your project as you please
wandb.agent(sweep_id, function = train) # set the function to the function that handles the entire model training, eval, test process

7.4 Transformer

7.5 Graph Neural Networks
