How to Create a Molecular Dataset
A step-by-step guide to featurizing molecules using RDKit and storing them as a PyTorch Geometric dataset.
Introduction
If you’d like to follow along or run this notebook yourself, you can find both the notebook and dataset in my GitHub repository.
Imagine you want to train a neural network to predict the binding affinity of a ligand to a specific protein. You have a dataset of ligands, their 3D structures and binding affinities, but you don’t know how to featurize the ligands for this task. This interactive notebook will walk you through the steps of featurizing ligands using RDKit and organizing them into a torch_geometric dataset that can be used to train, for example, a Graph Neural Network.
For this notebook, I created a dataset of 12 small molecules with their 3D structures in the MOL2 format. We will use RDKit to extract the features of the atoms and bonds of these molecules and store them as molecular graphs using PyTorch Geometric, so let’s start by importing the necessary libraries.
import pandas as pd
import torch
import os
from torch_geometric.utils import one_hot
from torch_geometric.data import (
    InMemoryDataset,
    Data,
    download_url,
    extract_zip,
)
from rdkit import Chem
from rdkit.Chem import Draw
Draw.rdDepictor.SetPreferCoordGen(True)
from rdkit.Chem.Draw import IPythonConsole
import py3Dmol
Now, let’s download the dataset and extract the molecular structures into the data directory. This dataset includes a CSV file containing each molecule’s name, SMILES, activity, and purchasability, and also provides the corresponding 3D structures in MOL2 format.
! mkdir -p data
! wget -q https://github.com/vladislach/small-molecules/raw/main/structures.zip
! wget -q https://github.com/vladislach/small-molecules/raw/main/molecules.csv -O data/molecules.csv
! unzip -q -o structures.zip -d data/raw; rm structures.zip
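If shell commands are not available (for example, on Windows or outside a notebook), the extraction step can be approximated with Python's standard library. The sketch below is only an illustration: `extract_structures` is a hypothetical helper, and the demo runs on a throwaway archive with placeholder contents rather than the real `structures.zip`, which you could fetch first with `urllib.request.urlretrieve`.

```python
import tempfile
import zipfile
from pathlib import Path


def extract_structures(zip_path: Path, out_dir: Path) -> list:
    """Extract every .mol2 member of zip_path into out_dir and return their names."""
    out_dir.mkdir(parents=True, exist_ok=True)  # same effect as `mkdir -p`
    with zipfile.ZipFile(zip_path) as archive:
        members = [name for name in archive.namelist() if name.endswith(".mol2")]
        archive.extractall(out_dir, members=members)
    return sorted(members)


# Self-contained demo: build a small archive with placeholder files, then extract it.
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = Path(tmp) / "structures.zip"
    with zipfile.ZipFile(demo_zip, "w") as archive:
        archive.writestr("X-0131.mol2", "placeholder")
        archive.writestr("X-0179.mol2", "placeholder")
        archive.writestr("notes.txt", "ignored because it is not a .mol2 file")
    extracted = extract_structures(demo_zip, Path(tmp) / "data" / "raw")

print(extracted)
```

Filtering on the `.mol2` suffix is a design choice of this sketch, not something the shell commands above do.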
Next, let’s load the dataset and take a look at the CSV file to see the information we have about each molecule.
df = pd.read_csv('data/molecules.csv')
df.style.format({'activity': '{:.2f}'}) \
    .hide(axis='index') \
    .set_table_styles([{'selector': 'thead th', 'props': [('text-align', 'center')]}]) \
    .set_properties(**{'text-align': 'left'})
| id | name | smiles | activity | purchasable |
|---|---|---|---|---|
| X-0131 | Ibuprofen | CC(C)Cc1ccc([C@@H](C)C(=O)O)cc1 | 3.70 | 1 |
| X-0179 | Lidocaine | CCN(CC)CC(=O)Nc1c(C)cccc1C | 5.80 | 1 |
| X-0258 | Bimatoprost | CCNC(=O)CCC/C=C\C[C@@H]1[C@@H](/C=C/[C@@H](O)CCc2ccccc2)[C@H](O)C[C@@H]1O | 0.70 | 1 |
| X-1053 | Diclofenac | O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl | 12.20 | 1 |
| X-2717 | Metoprolol | COCCc1ccc(OCC(O)CNC(C)C)cc1 | 7.70 | 0 |
| X-1362 | Clopidogrel | COC(=O)[C@H](c1ccccc1Cl)N1CCc2sccc2C1 | 10.10 | 0 |
| X-0948 | Salbutamol | CC(C)(C)NCC(O)c1ccc(O)c(CO)c1 | 9.10 | 0 |
| X-2003 | Lenvatinib | COc1cc2nccc(Oc3ccc(NC(=O)NC4CC4)c(Cl)c3)c2cc1C(N)=O | 8.50 | 0 |
| X-1432 | Tafamidis | O=C(O)c1ccc2nc(-c3cc(Cl)cc(Cl)c3)oc2c1 | 15.50 | 1 |
| X-0024 | Naproxen | COc1ccc2cc([C@H](C)C(=O)O)ccc2c1 | 0.40 | 0 |
| X-0144 | Enzalutamide | CNC(=O)c1ccc(N2C(=S)N(c3ccc(C#N)c(C(F)(F)F)c3)C(=O)C2(C)C)cc1F | 6.10 | 1 |
| X-1155 | Linezolid | CC(=O)NC[C@H]1CN(c2ccc(N3CCOCC3)c(F)c2)C(=O)O1 | 2.40 | 1 |
As you can see, we have 12 molecules in the dataset, each with a unique ID, name, SMILES, activity, and purchasability. The molecules’ activity will be our target variable. The purchasability will not be used in this notebook, but you can use it to create a binary classification task if you want. Let’s start by converting the SMILES strings into RDKit Mol objects and visualizing each molecule.
mols = [Chem.MolFromSmiles(smiles) for smiles in df['smiles']]
opts = Draw.MolDrawOptions()
opts.legendFontSize = 14
Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(300, 270),
                     legends=df['name'].to_list(), drawOptions=opts, useSVG=True)
The id of each molecule can be used to access the corresponding 3D structure in the data/raw directory. Let’s visualize the 3D structure of Diclofenac as an example.
# Feel free to change the molecule name to visualize a different molecule.
# Named mol_id rather than id to avoid shadowing Python's built-in id().
mol_id = df.loc[df['name'] == 'Diclofenac', 'id'].values[0]
mol = Chem.MolFromMol2File(f'data/raw/{mol_id}.mol2')
mb = Chem.MolToMolBlock(mol)
viewer = py3Dmol.view(width=800, height=400)
viewer.addModel(mb, "mol")
viewer.setStyle({'stick': {}})
viewer.zoomTo()
viewer.show()
The structure above is interactive, so you can zoom in/out and rotate it.
Featurization
RDKit provides many possible features, but we will focus on the following:
- Atomic Number (one-hot encoded)
- Total Number of Neighboring Atoms, Including Hydrogens (one-hot encoded)
- Number of Connected Hydrogens (one-hot encoded)
- Hybridization (one-hot encoded)
- Aromaticity (True/False)
- Presence of the Atom in a Ring of a specific size (True/False for each size)
For a full list of available features, see the RDKit documentation. Now, let’s examine which values each feature takes across the atoms in the dataset, so that we know how to one-hot encode them later.
print(f"Atomic Numbers: {set(atom.GetAtomicNum() for mol in mols for atom in mol.GetAtoms())}")
print(f"Degrees: {set(atom.GetTotalDegree() for mol in mols for atom in mol.GetAtoms())}")
print(f"Number Hs: {set(atom.GetTotalNumHs() for mol in mols for atom in mol.GetAtoms())}")
print(f"Hybridization: {set(str(atom.GetHybridization()) for mol in mols for atom in mol.GetAtoms())}")
print(f"Aromaticity: {set(atom.GetIsAromatic() for mol in mols for atom in mol.GetAtoms())}")
print("Presence in Ring of Size: " +
      ", ".join([f"{n}: {set(atom.IsInRingSize(n) for mol in mols for atom in mol.GetAtoms())}" for n in range(3, 9)]))
Atomic Numbers: {6, 7, 8, 9, 16, 17}
Degrees: {1, 2, 3, 4}
Number Hs: {0, 1, 2, 3}
Hybridization: {'SP2', 'SP3', 'SP'}
Aromaticity: {False, True}
Presence in Ring of Size: 3: {False, True}, 4: {False}, 5: {False, True}, 6: {False, True}, 7: {False}, 8: {False}
As you can see, the dataset includes six different atom types: C, N, O, F, S, and Cl (hydrogens are not included here). We’ll create a mapping for each feature to one-hot encode them later. For example, the hybridization feature has three types: SP, SP2, and SP3. By storing them in a list, we can use each item’s index for one-hot encoding, so SP becomes [1, 0, 0], SP2 becomes [0, 1, 0], and SP3 becomes [0, 0, 1]. We’ll apply this approach to all relevant features and create mappings for both atom and bond features.
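To make the index-based encoding concrete, here is a minimal sketch in plain Python. The helper name `one_hot_from_list` is hypothetical and is not part of RDKit or PyTorch Geometric; the notebook itself uses the `one_hot` function from torch_geometric on tensors of indices.

```python
def one_hot_from_list(value, choices):
    """Return a one-hot list with a 1 at the position of `value` in `choices`."""
    index = choices.index(value)  # raises ValueError if value is not in choices
    return [1 if i == index else 0 for i in range(len(choices))]


hybridizations = ['SP', 'SP2', 'SP3']
for hybridization in hybridizations:
    print(hybridization, one_hot_from_list(hybridization, hybridizations))
```

The order of the list fixes the meaning of each position, so the same list must be used for every molecule.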
In total, we’ll have the following features per atom:
- 6 for atomic number
- 4 for degree
- 4 for number of hydrogens
- 3 for hybridization
- 1 for aromaticity
- 4 for presence in rings of size 3, 4, 5, and 6
Adding these up, we’ll generate 22 features for each atom.
x_map = {
    'atomic_num': [6, 7, 8, 9, 16, 17],
    'degree': [1, 2, 3, 4],
    'num_hs': [0, 1, 2, 3],
    'hybridization': ['SP', 'SP2', 'SP3']
}

e_map = {
    'bond_type': ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC']
}
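The count of 22 features, and the column range each feature occupies, can be derived from the mapping rather than tallied by hand. The sketch below assumes the features are concatenated in the order listed above; the names `widths`, `layout`, and `num_node_feats` are illustrative only.

```python
x_map = {
    'atomic_num': [6, 7, 8, 9, 16, 17],
    'degree': [1, 2, 3, 4],
    'num_hs': [0, 1, 2, 3],
    'hybridization': ['SP', 'SP2', 'SP3'],
}

# One column per category for the one-hot features.
widths = {name: len(choices) for name, choices in x_map.items()}
widths['aromatic'] = 1   # a single True/False column
widths['ring_size'] = 4  # one True/False column per ring size 3, 4, 5, 6

# Walk through the features in order and record each (start, stop) column range.
layout = {}
start = 0
for name, width in widths.items():
    layout[name] = (start, start + width)
    start += width
num_node_feats = start

print(layout)
print(num_node_feats)
```

Computing the layout this way keeps the slices in sync if a mapping is extended later.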
Now, let’s implement the feature extraction function. This function will take a molecule as input and return the relevant features we need. It will also one-hot encode these features using the mappings we created earlier and the one_hot function from torch_geometric. For more details on this function, you can refer to the PyTorch Geometric documentation.
def get_node_feats(rdmol: Chem.Mol) -> torch.Tensor:
    """Generates a feature tensor for each atom (node) in an RDKit molecule object (rdmol)."""
    feats = []

    # Atomic number (6 columns)
    feats.append(
        one_hot(
            torch.tensor([x_map['atomic_num'].index(atom.GetAtomicNum()) for atom in rdmol.GetAtoms()]),
            num_classes=len(x_map['atomic_num']), dtype=torch.float
        )
    )

    # Total degree, including hydrogens (4 columns)
    feats.append(
        one_hot(
            torch.tensor([x_map['degree'].index(atom.GetTotalDegree()) for atom in rdmol.GetAtoms()]),
            num_classes=len(x_map['degree']), dtype=torch.float
        )
    )

    # Number of attached hydrogens (4 columns)
    feats.append(
        one_hot(
            torch.tensor([x_map['num_hs'].index(atom.GetTotalNumHs()) for atom in rdmol.GetAtoms()]),
            num_classes=len(x_map['num_hs']), dtype=torch.float
        )
    )

    # Hybridization (3 columns)
    feats.append(
        one_hot(
            torch.tensor([x_map['hybridization'].index(str(atom.GetHybridization())) for atom in rdmol.GetAtoms()]),
            num_classes=len(x_map['hybridization']), dtype=torch.float
        )
    )

    # Aromaticity (1 column)
    feats.append(
        torch.tensor([atom.GetIsAromatic() for atom in rdmol.GetAtoms()], dtype=torch.float).view(-1, 1)
    )

    # Membership in rings of size 3, 4, 5, and 6 (1 column each)
    for i in range(3, 7):
        feats.append(
            torch.tensor([atom.IsInRingSize(i) for atom in rdmol.GetAtoms()], dtype=torch.float).view(-1, 1)
        )

    return torch.cat(feats, dim=-1)
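One caveat of looking values up with `list.index` is that it raises a `ValueError` for anything outside the mapping, such as a bromine atom in a molecule added later. A common variation, which is not what this notebook does, is to reserve one extra "unknown" column. The helpers below are hypothetical and shown in plain Python only to illustrate the idea.

```python
def safe_index(value, choices):
    """Return the position of value in choices, or len(choices) as a shared 'unknown' slot."""
    try:
        return choices.index(value)
    except ValueError:
        return len(choices)


def one_hot_with_unknown(value, choices):
    """One-hot encode value with one extra trailing column for unseen values."""
    vector = [0] * (len(choices) + 1)
    vector[safe_index(value, choices)] = 1
    return vector


atomic_nums = [6, 7, 8, 9, 16, 17]
known = one_hot_with_unknown(7, atomic_nums)     # nitrogen is in the list
unknown = one_hot_with_unknown(35, atomic_nums)  # bromine is not in the list
print(known)
print(unknown)
```

Adopting this would widen each one-hot feature by one column, so the total of 22 features per atom would change accordingly.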
Now, let’s test the function on a molecule and print the features for one of its atoms. We’ll also visualize the atom to see if the features make sense. Feel free to change the molecule and atom index to see different results.
rdmol = mols[df.query("name == 'Clopidogrel'").index[0]] # Try different molecules from the dataset
node_feats = get_node_feats(rdmol)
i = 12 # Atom index, can change it to see different atoms
IPythonConsole.drawOptions.addAtomIndices = True
display(Draw.MolsToGridImage([rdmol], molsPerRow=1, subImgSize=(250, 250), legends=[f"Atom #{i}"],
                             useSVG=True, highlightAtomLists=[[i]]))
print(f"node_feats shape: {node_feats.shape}\n")
print(f" Full feature vector: {node_feats[i]}\n")
print(f"One-hot encoded atomic number for Atom #{i}:")
print(f" Encoding: {node_feats[i, :6]}")
print(f" Corresponds to atomic number: {x_map['atomic_num'][node_feats[i, :6].argmax().item()]}\n")
print(f"One-hot encoded degree for Atom #{i}:")
print(f" Encoding: {node_feats[i, 6:10]}")
print(f" Corresponds to degree: {x_map['degree'][node_feats[i, 6:10].argmax().item()]}\n")
print(f"One-hot encoded number of Hs for Atom #{i}:")
print(f" Encoding: {node_feats[i, 10:14]}")
print(f" Corresponds to number of Hs: {x_map['num_hs'][node_feats[i, 10:14].argmax().item()]}\n")
print(f"One-hot encoded hybridization for Atom #{i}:")
print(f" Encoding: {node_feats[i, 14:17]}")
print(f" Corresponds to hybridization: {x_map['hybridization'][node_feats[i, 14:17].argmax().item()]}\n")
node_feats shape: torch.Size([21, 22])
Full feature vector: tensor([0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
0., 0., 0., 1.])
One-hot encoded atomic number for Atom #12:
Encoding: tensor([0., 1., 0., 0., 0., 0.])
Corresponds to atomic number: 7
One-hot encoded degree for Atom #12:
Encoding: tensor([0., 0., 1., 0.])
Corresponds to degree: 3
One-hot encoded number of Hs for Atom #12:
Encoding: tensor([1., 0., 0., 0.])
Corresponds to number of Hs: 0
One-hot encoded hybridization for Atom #12:
Encoding: tensor([0., 0., 1.])
Corresponds to hybridization: SP3
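The same decoding can be applied to the whole vector at once, including the aromaticity and ring-size columns that the printout above does not break down. This is a plain-Python sketch that copies the feature vector printed for Atom #12; the `argmax` helper and the `decoded` dictionary are illustrative names.

```python
x_map = {
    'atomic_num': [6, 7, 8, 9, 16, 17],
    'degree': [1, 2, 3, 4],
    'num_hs': [0, 1, 2, 3],
    'hybridization': ['SP', 'SP2', 'SP3'],
}

# Feature vector printed above for Atom #12 of Clopidogrel.
feature_vector = [0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
                  0., 0., 0., 1.]


def argmax(values):
    """Index of the largest entry (the position of the 1 in a one-hot slice)."""
    return max(range(len(values)), key=values.__getitem__)


decoded = {
    'atomic_num': x_map['atomic_num'][argmax(feature_vector[0:6])],
    'degree': x_map['degree'][argmax(feature_vector[6:10])],
    'num_hs': x_map['num_hs'][argmax(feature_vector[10:14])],
    'hybridization': x_map['hybridization'][argmax(feature_vector[14:17])],
    'is_aromatic': bool(feature_vector[17]),
    # Columns 18-21 flag membership in rings of size 3, 4, 5, and 6.
    'ring_sizes': [size for size, flag in zip(range(3, 7), feature_vector[18:22]) if flag],
}
print(decoded)
```

Read this way, the vector describes a non-aromatic SP3 nitrogen with three neighbors and no hydrogens that sits in a six-membered ring.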
Next, we need to obtain bond features, which represent the edges of our molecular graphs. In PyTorch Geometric, these are stored as edge_attr in the Data object. Before extracting the bond features (whether a given bond is single, double, triple, or aromatic), we need to get the edge_index from the molecule.
The edge_index is a tensor of shape [2, num_edges] that stores the indices of atoms connected by an edge and defines the graph’s connectivity. This standard representation of graph connectivity in PyTorch Geometric is very useful for GNNs. It also allows us to get bond features and store them in the correct order in the edge_attr tensor.
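As a rough illustration of this layout, the sketch below builds an edge_index by hand for a toy chain of three atoms, using nested lists instead of tensors. The adjacency matrix is made up for the example and does not correspond to any molecule in the dataset.

```python
# Adjacency matrix of a three-atom chain: atom 0 - atom 1 - atom 2.
adjacency = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]

# Collect (source, target) pairs row by row for every nonzero entry.
sources, targets = [], []
for i, row in enumerate(adjacency):
    for j, connected in enumerate(row):
        if connected:
            sources.append(i)
            targets.append(j)

edge_index = [sources, targets]  # shape [2, num_edges]
num_bonds = len(sources) // 2    # each bond appears once per direction

print(edge_index)
print(num_bonds)
```

Because the adjacency matrix is symmetric, every bond contributes two columns, one per direction. That is why the number of edges reported for a molecule is twice its number of bonds.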
If you’re unfamiliar with edge_index or the Data object in PyTorch Geometric, you might find this excellent tutorial by Chaitanya K. Joshi, Charlie Harris and Ramon Viñas Torné helpful. It’s part of the Geometric Deep Learning course, and it explains these concepts in depth. Additionally, you can refer to the official PyTorch Geometric documentation.
def get_edge_index(mol: Chem.Mol) -> torch.Tensor:
    """Generates an edge index tensor (edge connectivity) for an RDKit molecule object (mol)."""
    return torch.tensor(Chem.GetAdjacencyMatrix(mol), dtype=torch.long).to_sparse().indices()


def get_edge_attr(rdmol: Chem.Mol, edge_index: torch.Tensor) -> torch.Tensor:
    """Generates a feature tensor for each bond (edge) in an RDKit molecule object (rdmol)."""
    edge_attr = []
    # Iterate over (source, target) pairs so the features follow the order of edge_index.
    for i, j in edge_index.T:
        bond = rdmol.GetBondBetweenAtoms(i.item(), j.item())
        edge_attr.append(e_map['bond_type'].index(str(bond.GetBondType())))
    edge_attr = torch.tensor(edge_attr, dtype=torch.long)
    return one_hot(edge_attr, num_classes=len(e_map['bond_type']), dtype=torch.float)
Similar to the atom features, we’ll test the bond feature extraction function on a molecule and print the features for one of its bonds. We’ll also visualize the bond to see if the features make sense. Feel free to change the atoms’ indices to see different results.
edge_index = get_edge_index(rdmol)
edge_attr = get_edge_attr(rdmol, edge_index)
i, j = 7, 8 # Change i and j to see different bonds
display(Draw.MolsToGridImage([rdmol], molsPerRow=1, subImgSize=(250, 250), legends=[f"Bond between atoms {i} and {j}"],
                             useSVG=True, highlightAtomLists=[[i, j]], highlightBondLists=[[rdmol.GetBondBetweenAtoms(i, j).GetIdx()]]))
print(f'edge_index shape:\n{edge_index.shape}\n')
print(f'First 15 bonds in edge_index:\n{edge_index[:, :15]}\n')
print(f'edge_attr shape: {edge_attr.shape}\n')
bond_idx = torch.where((edge_index.T == torch.tensor([i, j])).all(dim=1))[0].item()
print(f'One-hot encoded bond type for bond between atoms {i} and {j}:')
print(f' Encoding: {edge_attr[bond_idx]}')
print(f' Corresponds to bond type: {e_map["bond_type"][edge_attr[bond_idx].argmax().item()]}')
edge_index shape:
torch.Size([2, 46])
First 15 bonds in edge_index:
tensor([[ 0, 1, 1, 2, 2, 2, 3, 4, 4, 4, 5, 5, 5, 6, 6],
[ 1, 0, 2, 1, 3, 4, 2, 2, 5, 12, 4, 6, 10, 5, 7]])
edge_attr shape: torch.Size([46, 4])
One-hot encoded bond type for bond between atoms 7 and 8:
Encoding: tensor([0., 0., 0., 1.])
Corresponds to bond type: AROMATIC
Molecular Dataset
Now that we know how to extract features for atoms and bonds, let’s create a torch_geometric dataset to store molecular graphs as Data objects. If you’re not familiar with PyTorch Geometric’s Data object, you can refer to the official documentation or the tutorial mentioned above.
The download method will fetch the dataset from the URL provided in the url attribute. This is similar to what we did earlier, but now we’re using the torch_geometric download_url function. The process method will read the dataset and create a list of Data objects, where each Data object represents a molecular graph. Each Data object will contain:
- The x attribute, which stores the atom features.
- The edge_index attribute, which stores the graph’s connectivity.
- The edge_attr attribute, which stores bond features.
- The y attribute, which stores the target value (in this case, the molecule’s activity).
- The pos attribute, which stores the 3D coordinates of the atoms extracted with RDKit.
We will skip transform, pre_transform, and pre_filter for now. If you’re interested in learning more about these, check the PyTorch Geometric documentation on creating datasets.
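It helps to know when download and process actually run. PyTorch Geometric calls download only if the files listed in raw_file_names are missing from the raw directory, and process only if the files listed in processed_file_names are missing from the processed directory. The toy class below imitates that caching flow with placeholder files. It is a simplified stand-in written for this explanation, not PyTorch Geometric code, and `CachedDatasetSketch` is a hypothetical name.

```python
import tempfile
from pathlib import Path


class CachedDatasetSketch:
    """Toy stand-in (not PyG code) for the download/process caching flow."""

    def __init__(self, root, raw_names, processed_names):
        self.raw_dir = Path(root) / "raw"
        self.processed_dir = Path(root) / "processed"
        self.raw_paths = [self.raw_dir / name for name in raw_names]
        self.processed_paths = [self.processed_dir / name for name in processed_names]
        self.calls = []  # records which steps actually ran

        # Download only when at least one raw file is missing.
        if not all(path.exists() for path in self.raw_paths):
            self.raw_dir.mkdir(parents=True, exist_ok=True)
            self.download()
        # Process only when at least one processed file is missing.
        if not all(path.exists() for path in self.processed_paths):
            self.processed_dir.mkdir(parents=True, exist_ok=True)
            self.process()

    def download(self):
        self.calls.append("download")
        for path in self.raw_paths:
            path.write_text("placeholder raw file")

    def process(self):
        self.calls.append("process")
        for path in self.processed_paths:
            path.write_text("placeholder processed file")


with tempfile.TemporaryDirectory() as root:
    first = CachedDatasetSketch(root, ["X-0131.mol2"], ["data.pt"])
    second = CachedDatasetSketch(root, ["X-0131.mol2"], ["data.pt"])

print(first.calls)
print(second.calls)
```

In this notebook the structures were already extracted into data/raw, so creating the dataset with root='data' should skip the download step. Likewise, if you change the featurization later, you would need to delete the processed directory for process to run again.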
Below is the implementation of the MolDataset class, which includes the feature extraction functions we implemented earlier as methods:
class MolDataset(InMemoryDataset):
    """
    A PyTorch Geometric dataset for small molecules. For each molecule, the dataset contains the following:
    - x: Node features (atom type, degree, number of Hs, hybridization, aromaticity, presence in rings of size 3-6)
    - edge_index: Edge connectivity
    - edge_attr: Edge features (bond type)
    - pos: Node positions
    - y: Target value (activity)
    - name: Molecule name
    """
    url = 'https://github.com/vladislach/small-molecules/raw/main/structures.zip'

    def __init__(self, root, csv, transform=None):
        # The CSV must be loaded first: super().__init__ may call download() and process().
        self.df = pd.read_csv(csv)
        super().__init__(root, transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f"{mol_id}.mol2" for mol_id in self.df['id']]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self) -> None:
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        data_list = []
        # raw_paths follows the order of raw_file_names, which follows the rows of the CSV.
        for path, (_, row) in zip(self.raw_paths, self.df.iterrows()):
            rdmol = Chem.MolFromMol2File(path)

            x = self.get_node_feats(rdmol)
            edge_index = self.get_edge_index(rdmol)
            edge_attr = self.get_edge_attr(rdmol, edge_index)
            pos = torch.tensor(rdmol.GetConformer().GetPositions(), dtype=torch.float)
            y = row['activity']
            name = row['name']

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos, name=name)
            data_list.append(data)

        self.save(data_list, self.processed_paths[0])

    def get_node_feats(self, rdmol: Chem.Mol) -> torch.Tensor:
        """Generates a feature tensor for each atom (node) in an RDKit molecule object (rdmol)."""
        feats = []
        feats.append(
            one_hot(
                torch.tensor([x_map['atomic_num'].index(atom.GetAtomicNum()) for atom in rdmol.GetAtoms()]),
                num_classes=len(x_map['atomic_num']), dtype=torch.float
            )
        )
        feats.append(
            one_hot(
                torch.tensor([x_map['degree'].index(atom.GetTotalDegree()) for atom in rdmol.GetAtoms()]),
                num_classes=len(x_map['degree']), dtype=torch.float
            )
        )
        feats.append(
            one_hot(
                torch.tensor([x_map['num_hs'].index(atom.GetTotalNumHs()) for atom in rdmol.GetAtoms()]),
                num_classes=len(x_map['num_hs']), dtype=torch.float
            )
        )
        feats.append(
            one_hot(
                torch.tensor([x_map['hybridization'].index(str(atom.GetHybridization())) for atom in rdmol.GetAtoms()]),
                num_classes=len(x_map['hybridization']), dtype=torch.float
            )
        )
        feats.append(
            torch.tensor([atom.GetIsAromatic() for atom in rdmol.GetAtoms()], dtype=torch.float).view(-1, 1)
        )
        for i in range(3, 7):
            feats.append(
                torch.tensor([atom.IsInRingSize(i) for atom in rdmol.GetAtoms()], dtype=torch.float).view(-1, 1)
            )
        return torch.cat(feats, dim=-1)

    def get_edge_index(self, rdmol: Chem.Mol) -> torch.Tensor:
        """Generates an edge index tensor (edge connectivity) for an RDKit molecule object (rdmol)."""
        return torch.tensor(Chem.GetAdjacencyMatrix(rdmol), dtype=torch.long).to_sparse().indices()

    def get_edge_attr(self, rdmol: Chem.Mol, edge_index: torch.Tensor) -> torch.Tensor:
        """Generates a feature tensor for each bond (edge) in an RDKit molecule object (rdmol)."""
        edge_attr = []
        for i, j in edge_index.T:
            bond = rdmol.GetBondBetweenAtoms(i.item(), j.item())
            edge_attr.append(e_map['bond_type'].index(str(bond.GetBondType())))
        edge_attr = torch.tensor(edge_attr, dtype=torch.long)
        return one_hot(edge_attr, num_classes=len(e_map['bond_type']), dtype=torch.float)
Now, let’s create an instance of the MolDataset class and iterate over the dataset to examine the Data objects it contains. PyTorch Geometric’s Data objects display the shape of each attribute, so we can verify that:
- The atom features (x) have shape [num_atoms, 22]
- The connectivity (edge_index) has shape [2, num_edges]
- The bond features (edge_attr) have shape [num_edges, 4]
- The target (y) has shape [1]
- The 3D coordinates (pos) have shape [num_atoms, 3]
dataset = MolDataset(root='data', csv='data/molecules.csv')
for data in dataset:
    print(data)
Data(x=[15, 22], edge_index=[2, 30], edge_attr=[30, 4], y=[1], pos=[15, 3], name='Ibuprofen')
Data(x=[17, 22], edge_index=[2, 34], edge_attr=[34, 4], y=[1], pos=[17, 3], name='Lidocaine')
Data(x=[30, 22], edge_index=[2, 62], edge_attr=[62, 4], y=[1], pos=[30, 3], name='Bimatoprost')
Data(x=[19, 22], edge_index=[2, 40], edge_attr=[40, 4], y=[1], pos=[19, 3], name='Diclofenac')
Data(x=[19, 22], edge_index=[2, 38], edge_attr=[38, 4], y=[1], pos=[19, 3], name='Metoprolol')
Data(x=[21, 22], edge_index=[2, 46], edge_attr=[46, 4], y=[1], pos=[21, 3], name='Clopidogrel')
Data(x=[17, 22], edge_index=[2, 34], edge_attr=[34, 4], y=[1], pos=[17, 3], name='Salbutamol')
Data(x=[30, 22], edge_index=[2, 66], edge_attr=[66, 4], y=[1], pos=[30, 3], name='Lenvatinib')
Data(x=[20, 22], edge_index=[2, 44], edge_attr=[44, 4], y=[1], pos=[20, 3], name='Tafamidis')
Data(x=[17, 22], edge_index=[2, 36], edge_attr=[36, 4], y=[1], pos=[17, 3], name='Naproxen')
Data(x=[32, 22], edge_index=[2, 68], edge_attr=[68, 4], y=[1], pos=[32, 3], name='Enzalutamide')
Data(x=[24, 22], edge_index=[2, 52], edge_attr=[52, 4], y=[1], pos=[24, 3], name='Linezolid')
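The shape relationships listed above can also be checked programmatically rather than by eye. The following is a minimal plain-Python sketch that works on nested lists; `check_graph_shapes` is a hypothetical helper, and the toy graph is invented for the example. For real Data objects you would compare tensor shapes instead.

```python
def check_graph_shapes(x, edge_index, edge_attr, pos):
    """Return a list of human-readable problems; an empty list means the shapes agree."""
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        return ["edge_index must have two rows of equal length"]

    problems = []
    num_atoms = len(x)
    num_edges = len(edge_index[0])
    if len(edge_attr) != num_edges:
        problems.append(f"edge_attr has {len(edge_attr)} rows, expected {num_edges}")
    if len(pos) != num_atoms:
        problems.append(f"pos has {len(pos)} rows, expected {num_atoms}")
    if any(len(xyz) != 3 for xyz in pos):
        problems.append("every position needs 3 coordinates")
    if any(not 0 <= index < num_atoms for row in edge_index for index in row):
        problems.append("edge_index refers to an atom index outside x")
    return problems


# Toy graph: a chain of 3 atoms with 22 node features and 4 bond-type features.
x = [[0.0] * 22 for _ in range(3)]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
edge_attr = [[1.0, 0.0, 0.0, 0.0] for _ in range(4)]
pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0]]

ok = check_graph_shapes(x, edge_index, edge_attr, pos)
broken = check_graph_shapes(x, edge_index, edge_attr[:3], pos)  # drop one edge feature row
print(ok)
print(broken)
```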
And voilà! We’ve successfully created a torch_geometric dataset that stores molecular graphs as Data objects with the atom and bond features we need. This dataset (although very small) is now ready for use in training a GNN model, for example, to predict the activity of molecules.