MOOD Protocol
As model selection is often argued to improve generalization, we investigate which molecular splitting strategy best mimics the deployment distribution. The investigation measures how representative each candidate splitting method is of that distribution.
- Compute the distance of each molecule in the deployment set(s) to the training set. This step gives the “deployment-to-train” distribution, which is the target distance distribution that should be mimicked during model selection to generalize better during deployment. If the final model will be retrained on the full dataset before deployment, the distances must be computed with respect to the full dataset instead of just the training partition.
- Characterize each splitting method by splitting the dataset into train and test sets. Then, compute the distance of each test sample to the training set to get the “test-to-train” distribution. For small datasets, this step should be repeated with multiple seeds to get more reliable estimates of the test-to-train distribution before doing the final split that will be used for training.
- Score the different splitting methods by measuring the distance between their test-to-train distribution and the deployment-to-train distribution. Then, for model selection, select the splitting method with the lowest distance. Here, we use the Jensen-Shannon distance between the distributions.
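The three steps above can be sketched on toy data. This is a minimal illustration and not the library's implementation: it assumes 2-D points with Euclidean nearest-neighbor distances in place of molecular fingerprints, simple histograms as the distance distributions, and a hypothetical `extrapolation_split` helper as a stand-in for an out-of-distribution splitter. It uses only the standard library so it runs without dependencies.

```python
import math
import random


def nearest_distances(queries, references):
    """Distance from each query point to its nearest reference point."""
    return [min(math.dist(q, r) for r in references) for q in queries]


def histogram(values, bins, lo, hi):
    """Normalized histogram over `bins` equal-width bins on [lo, hi]."""
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        counts[min(int((v - lo) / width), bins - 1)] += 1
    total = sum(counts)
    return [c / total for c in counts]


def js_distance(p, q):
    """Jensen-Shannon distance (base 2) between two discrete distributions."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(a, b):
        return sum(x * math.log2(x / y) for x, y in zip(a, b) if x > 0)

    return math.sqrt(max(0.0, (kl(p, m) + kl(q, m)) / 2))


def random_split(n, test_fraction, seed):
    """Shuffle indices and return (train_indices, test_indices)."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = int(n * test_fraction)
    return idx[cut:], idx[:cut]


def extrapolation_split(points, test_fraction):
    """Hypothetical OOD split: hold out the points with the largest x + y."""
    order = sorted(range(len(points)), key=lambda i: sum(points[i]))
    cut = int(len(points) * (1 - test_fraction))
    return order[:cut], order[cut:]


def split_distances(points, train_idx, test_idx):
    """Test-to-train distances for one split."""
    train = [points[i] for i in train_idx]
    held_out = [points[i] for i in test_idx]
    return nearest_distances(held_out, train)


rng = random.Random(0)
dataset = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
deployment = [(rng.gauss(2, 1), rng.gauss(2, 1)) for _ in range(100)]

# Step 1: deployment-to-train distances (the target distribution).
deployment_dist = nearest_distances(deployment, dataset)

# Step 2: test-to-train distances per candidate; the random split is
# repeated over several seeds and the distances are pooled.
candidates = {
    "random": [
        d
        for seed in range(5)
        for d in split_distances(dataset, *random_split(len(dataset), 0.2, seed))
    ],
    "extrapolation": split_distances(dataset, *extrapolation_split(dataset, 0.2)),
}

# Step 3: score each candidate against the target on shared bins.
hi = max(max(deployment_dist), *(max(d) for d in candidates.values()))
target = histogram(deployment_dist, 20, 0.0, hi)
scores = {
    name: js_distance(histogram(d, 20, 0.0, hi), target)
    for name, d in candidates.items()
}
best = min(scores, key=scores.get)  # lowest distance is most representative
print(scores, best)
```

Binning every distribution on the same range keeps the histograms comparable; the number of bins (20 here) is an arbitrary choice for the sketch.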
This protocol is implemented in the MOODSplitter. See an example of how to use it below:
```python
%load_ext autoreload
%autoreload 2
import numpy as np
import datamol as dm
from sklearn.model_selection import ShuffleSplit
import splito
```
```python
# Load the training dataset
dataset = dm.data.solubility()
dataset_feat = [dm.to_fp(mol) for mol in dataset.mol]
# Load the deployment set
# Alternatively, you can also load an array of deployment-to-dataset distances
deployment_feat = [dm.to_fp(mol) for mol in dm.data.chembl_drugs()["smiles"]]
```
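To illustrate what an array of deployment-to-dataset distances contains, here is a minimal sketch that computes one nearest-neighbor distance per deployment molecule. It assumes binary fingerprints represented as sets of on-bit indices and the Jaccard (Tanimoto) distance; both are illustrative choices, and the metric and number of neighbors the library actually uses may differ. The helper names are hypothetical.

```python
def jaccard_distance(a, b):
    """1 - |a & b| / |a | b| for fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return 0.0 if union == 0 else 1.0 - len(a & b) / union


def deployment_to_dataset_distances(deployment_fps, dataset_fps):
    """Distance from each deployment fingerprint to its nearest dataset fingerprint."""
    return [
        min(jaccard_distance(d, t) for t in dataset_fps)
        for d in deployment_fps
    ]


# Toy fingerprints (assumed data, for illustration only).
dataset_fps = [{1, 2, 3, 4}, {2, 3, 5}, {7, 8, 9}]
deployment_fps = [{1, 2, 3}, {7, 8, 10, 11}, {20, 21}]

distances = deployment_to_dataset_distances(deployment_fps, dataset_fps)
print(distances)
```

The result has one entry per deployment molecule: small values mean the molecule has a close neighbor in the dataset, while values near 1 mean nothing similar was seen.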
```python
# Define the candidate splitters
# Since we use the scikit-learn interface, these can also be scikit-learn splitters
splitters = {
    "Random": ShuffleSplit(),
    "Scaffold": splito.ScaffoldSplit(dataset.mol.values),
    "Perimeter": splito.PerimeterSplit(),
    "MaxDissimilarity": splito.MaxDissimilaritySplit(),
}
splitter = splito.MOODSplitter(splitters)
```
```python
# Rank the splitting methods against the given deployment set
splitter.fit(X=np.stack(dataset_feat), X_deployment=np.stack(deployment_feat))
```
```
2023-09-22 08:57:15.795 | INFO | splito._mood_split:fit:308 - Ranked all different splitting methods:
              split  representativeness   best  rank
0            Random            0.375938  False   4.0
1          Scaffold            0.492793  False   3.0
2         Perimeter            0.526232  False   2.0
3  MaxDissimilarity            0.552740   True   1.0
2023-09-22 08:57:15.795 | INFO | splito._mood_split:fit:309 - Selected MaxDissimilarity as the most representative splitting method
```
| | split | representativeness | best | rank |
|---|---|---|---|---|
| 0 | Random | 0.375938 | False | 4.0 |
| 1 | Scaffold | 0.492793 | False | 3.0 |
| 2 | Perimeter | 0.526232 | False | 2.0 |
| 3 | MaxDissimilarity | 0.552740 | True | 1.0 |
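In the output above, rank 1.0 goes to the highest representativeness score. The sketch below reproduces the `rank` and `best` columns from the reported scores. It assumes, as the output suggests but the text does not state, that representativeness is a higher-is-better score; how the library derives it from the Jensen-Shannon distance is not shown here.

```python
# Representativeness scores copied from the output above.
scores = {
    "Random": 0.375938,
    "Scaffold": 0.492793,
    "Perimeter": 0.526232,
    "MaxDissimilarity": 0.552740,
}

# Sort from most to least representative and assign ranks starting at 1.
ordered = sorted(scores, key=scores.get, reverse=True)
ranks = {name: float(position + 1) for position, name in enumerate(ordered)}
best = ordered[0]

for name in scores:
    print(f"{name:<17}{scores[name]:<10}{name == best!s:<7}{ranks[name]}")
```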
With the given deployment set, the most representative splitting method, and therefore the one to use for model selection, is the MaxDissimilaritySplit.
- The End :-)