Example usage of the SWoTTeD module

In this notebook, we illustrate the functionalities of the SWoTTeD module. It uses the synthetic data generator that is available only in the git repository (not in the PyPI package). If you installed SWoTTeD with pip, you first need to get the gen_data.py module to generate synthetic datasets. A basic example that does not require this file is available in this notebook.

[ ]:
# environment setting: add the repository root to the path (needed to import tests.gen_data)
%pwd
import sys
sys.path.append("..")
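With the repository root on sys.path, it can help to check that the generator module is importable before running the import cell. The following standard-library sketch does that; module_available is a hypothetical helper (not part of SWoTTeD), and the module path tests.gen_data is taken from the import used in this notebook.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be found.

    Note: for a dotted name such as "tests.gen_data", find_spec imports the
    parent package ("tests") as a side effect.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # ImportError: the parent package of a dotted name is missing
        # ValueError: the module is already loaded but has no __spec__
        return False


if not module_available("tests.gen_data"):
    print("tests/gen_data.py not found: get it from the SWoTTeD git repository "
          "and add the repository root to sys.path")
```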
[ ]:
from torch.utils.data import DataLoader
from swotted import swottedModule, swottedTrainer
from swotted.utils import Subset, success_rate
from swotted.loss_metrics import *
from tests.gen_data import gen_synthetic_data

import matplotlib.pyplot as plt
import numpy as np

from omegaconf import OmegaConf

Generation of a synthetic dataset

In this section, we generate a synthetic dataset.

[ ]:
# Synthetic dataset parameters
K = 100     # number of patients
N = 10      # number of medical events (e.g., drugs)
T = 6       # length of stay (number of time steps)
R = 4       # number of phenotypes
Tw = 3      # length of the time window of a phenotype
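To make the role of Tw concrete, the sketch below computes the shapes implied by these parameters. It assumes that a window of length Tw starts at every time step where it fits entirely in the sequence, so that a pathway has T - Tw + 1 columns for a sequence of length T; with the values above, this gives 4 columns.

```python
# Synthetic dataset parameters (same values as above)
K, N, T, R, Tw = 100, 10, 6, 4, 3

# Each patient k is an N x T binary matrix (events x time steps).
# Assumption: one window starts at every time step where it fits entirely
# in the sequence, so a pathway has T - Tw + 1 columns.
n_windows = T - Tw + 1

shapes = {
    "patient matrix X[k]": (N, T),
    "phenotype Ph[r]": (N, Tw),
    "pathway W[k]": (R, n_windows),
}
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```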
[ ]:
# Generate synthetic data: W_ holds the hidden pathways, Ph_ the hidden phenotypes and X the patient matrices
W_, Ph_, X, params = gen_synthetic_data(
    K, N, T, R, Tw, sliding_window=True, noise=0.0, truncate=True
)

# Create the DataLoader that feeds the model trainer; the identity collate_fn keeps each batch as a list of matrices
train_loader = DataLoader(
    Subset(X, np.arange(len(X))),
    batch_size=50,
    shuffle=False,
    collate_fn=lambda x: x
)
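The identity collate_fn matters because the default PyTorch collate function tries to stack the items of a batch into a single tensor, which fails when sequences have different lengths. The pure-Python sketch below mimics what the loader above does with shuffle=False; identity_batches is a hypothetical helper, not part of PyTorch or SWoTTeD.

```python
from typing import List, Sequence


def identity_batches(data: Sequence, batch_size: int) -> List[list]:
    """Split `data` into consecutive batches (no shuffling) and keep each batch
    as a plain list, as collate_fn=lambda x: x does, so items of different
    shapes are never stacked together."""
    return [list(data[i:i + batch_size]) for i in range(0, len(data), batch_size)]


# Toy stand-ins for patient matrices of different lengths (hypothetical data)
sequences = [[[0, 1, 1]], [[1, 0]], [[1, 1, 0, 1]], [[0]], [[1, 0, 1]]]
batches = identity_batches(sequences, batch_size=2)
print([len(b) for b in batches])  # [2, 2, 1]
```

With K = 100 and batch_size=50, the loader defined above yields two batches of 50 patient matrices each.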
[ ]:
#Illustration of the phenotypes
fig, axs = plt.subplots(1, R)
for i in range(R):
    axs[i].imshow(1-Ph_[i], cmap="gray", vmin=0, vmax=1, interpolation="nearest")
    axs[i].set_ylabel("Drugs")
    axs[i].set_xlabel("Time")
    axs[i].set_title(f"Phenotype {i}")
[ ]:
# Illustration of the data matrix of one patient (X[0])
plt.imshow(1-X[0], interpolation="none", cmap="gray", vmax=1, vmin=0)

Definition of the model

The parameters of the SWoTTeD module are provided through a configuration dictionary. It uses OmegaConf to ease the handling of the nested dictionaries.

[ ]:
params = {}
params['model']={}
params['model']['non_succession']=0.5
params['model']['sparsity']=0.5
params['model']['rank']=R
params['model']['twl']=Tw
params['model']['N']=N
params['model']['metric']="Bernoulli"

#some additional parameters of the trainer
params['training']={}
params['training']['lr']=1e-2

#some additional parameters for the projection (decomposition on new sequences)
params['predict']={}
params['predict']['nepochs']=100
params['predict']['lr']=1e-2

config=OmegaConf.create(params)
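The cell above builds the dictionary key by key. An equivalent and arguably easier-to-read form is a single nested literal, sketched below with the same keys and values; only the layout differs.

```python
# Same configuration as above, written as one nested literal.
R, Tw, N = 4, 3, 10  # values from the synthetic dataset parameters

params = {
    "model": {
        "non_succession": 0.5,
        "sparsity": 0.5,
        "rank": R,
        "twl": Tw,
        "N": N,
        "metric": "Bernoulli",
    },
    "training": {"lr": 1e-2},   # parameters of the trainer
    "predict": {"nepochs": 100, "lr": 1e-2},  # projection on new sequences
}
# In the notebook, this dict is then passed to OmegaConf.create(params),
# which allows attribute access such as config.model.rank.
```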

The SWoTTeD module is implemented as a PyTorch Lightning module, and it has to be used with the corresponding trainer. Indeed, the SWoTTeD optimization problem is not a classical supervised task, and it has to be set up with knowledge of the size of the dataset.
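One way to see why the dataset size matters: besides the phenotypes shared by all patients, the decomposition estimates one pathway per patient, so the number of free parameters grows with K. The arithmetic below is a sketch for the synthetic setting; it assumes every sequence has length T and that a pathway has T - Tw + 1 columns.

```python
# Parameter count for the synthetic setting (values from the cells above).
K, N, T, R, Tw = 100, 10, 6, 4, 3

# Shared by all patients: R phenotypes of size N x Tw
n_phenotype_params = R * N * Tw            # 4 * 10 * 3 = 120

# One pathway per patient, R x (T - Tw + 1) under the sliding-window assumption
n_pathway_params = K * R * (T - Tw + 1)    # 100 * 4 * 4 = 1600

print(n_phenotype_params, n_pathway_params)
```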

[ ]:
# define the model
swotted = swottedModule(config)
# define the trainer
trainer = swottedTrainer(max_epochs=200)

Then, the classical fit function runs the optimization process.

[ ]:
#train the model
trainer.fit(model=swotted, train_dataloaders=train_loader)

Analysis of the results

Once fitted, the SWoTTeD model contains the phenotypes and the intermediary decompositions (pathways). These intermediary decompositions are not part of the model, but they are stored in the swotted object at the end of the optimization process.

Phenotype analysis

We start by comparing the extracted phenotypes with the hidden ones.

[ ]:
# reorder the discovered phenotypes to match the hidden ones, then compare them visually
reordered_pheno, reordered_pathways = swotted.reorderPhenotypes(Ph_, tw=Tw)
for i in range(R):
    plt.subplot(211)
    plt.imshow(Ph_[i], vmin=0, vmax=1, cmap="binary", interpolation="nearest")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title(f"Hidden phenotype {i}")
    plt.subplot(212)
    plt.imshow(reordered_pheno[i].detach().numpy(), vmin=0, vmax=1, cmap="binary", interpolation="nearest")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title(f"Discovered phenotype {i}")
    plt.show()
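Reordering is needed because the decomposition is only defined up to a permutation of the phenotypes: the r-th discovered phenotype has no reason to match the r-th hidden one. The NumPy sketch below shows one simple way to compute such a matching, a brute-force search over all R! permutations with an L1 cost; match_phenotypes is a hypothetical helper, and reorderPhenotypes may rely on a different similarity measure or search strategy.

```python
from itertools import permutations

import numpy as np


def match_phenotypes(hidden: np.ndarray, found: np.ndarray) -> tuple:
    """Return the permutation `perm` such that found[perm[i]] best matches hidden[i].

    hidden, found: arrays of shape (R, N, Tw). Matching cost: sum of absolute
    differences. Brute force over all R! permutations, fine for small R only.
    """
    R = hidden.shape[0]
    # cost[i, j] = L1 distance between hidden phenotype i and found phenotype j
    cost = np.abs(hidden[:, None] - found[None, :]).sum(axis=(2, 3))
    return min(permutations(range(R)), key=lambda p: cost[np.arange(R), list(p)].sum())


# Toy check: shuffle random binary phenotypes and recover their order
rng = np.random.default_rng(0)
hidden = (rng.random((4, 10, 3)) > 0.7).astype(float)
found = hidden[[2, 0, 3, 1]]          # found[j] = hidden[order[j]]
perm = match_phenotypes(hidden, found)
reordered = found[list(perm)]
print(np.array_equal(reordered, hidden))  # True
```

The rows of the pathway must be permuted with the same order (W[list(perm)]) so that pathways and phenotypes stay aligned. For a larger number of phenotypes, the Hungarian algorithm (scipy.optimize.linear_sum_assignment applied to the cost matrix) solves the same assignment problem without enumerating permutations.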

Decomposition

We now illustrate how to apply the SWoTTeD model to an (assumed) new sequence. Note that the SWoTTeD model is made only of the phenotypes. Applying the model projects the sequence as closely as possible onto the model phenotypes: the phenotypes are kept fixed and only the pathway of the new sequence is optimized. This outputs a pathway that can be compared to the expected one.
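The sketch below illustrates the principle of such a projection with NumPy, under stated assumptions: the phenotypes are fixed, the pathway is constrained to be nonnegative, windows follow a sliding-window convention (window t covers time steps t to t + Tw - 1), and the loss is the squared reconstruction error minimized by projected gradient descent. This is a swapped-in simplification: the model configured above uses the Bernoulli metric and its own optimizer (the predict nepochs and lr settings), and reconstruct and project here are hypothetical helpers, not the SWoTTeD API.

```python
import numpy as np


def reconstruct(W, Ph):
    """Sliding-window reconstruction (assumed convention): window t adds
    sum_r W[r, t] * Ph[r] to columns t .. t+Tw-1.
    W: (R, T - Tw + 1), Ph: (R, N, Tw) -> returns (N, T)."""
    _, N, Tw = Ph.shape
    X_hat = np.zeros((N, W.shape[1] + Tw - 1))
    for t in range(W.shape[1]):
        X_hat[:, t:t + Tw] += np.tensordot(W[:, t], Ph, axes=1)
    return X_hat


def project(X, Ph, n_epochs=500):
    """Estimate a nonnegative pathway W for a new sequence X (N x T) with the
    phenotypes Ph kept fixed, by projected gradient descent on
    0.5 * ||X - reconstruct(W, Ph)||^2."""
    R, N, Tw = Ph.shape
    n_windows = X.shape[1] - Tw + 1
    # Step size 1 / ||A||_F^2, where A is the linear map W -> reconstruct(W, Ph);
    # it is at most 1 / L (L = Lipschitz constant of the gradient), so the loss never increases.
    lr = 1.0 / (n_windows * np.sum(Ph ** 2))
    W = np.zeros((R, n_windows))
    for _ in range(n_epochs):
        residual = X - reconstruct(W, Ph)
        grad = np.zeros_like(W)
        for t in range(n_windows):
            # d loss / d W[r, t] = -<residual[:, t:t+Tw], Ph[r]>
            grad[:, t] = -np.tensordot(Ph, residual[:, t:t + Tw], axes=([1, 2], [0, 1]))
        W = np.maximum(W - lr * grad, 0.0)  # projection onto W >= 0
    return W


# Toy check with random binary phenotypes (hypothetical data, same sizes as above)
rng = np.random.default_rng(0)
Ph = (rng.random((4, 10, 3)) > 0.5).astype(float)
W_true = rng.random((4, 4))
X_new = reconstruct(W_true, Ph)
W_est = project(X_new, Ph)
print(np.linalg.norm(X_new - reconstruct(W_est, Ph)) / np.linalg.norm(X_new))
```

The printed value is the relative reconstruction error of the projected pathway; since the phenotypes are fixed, only R x (T - Tw + 1) values are optimized per sequence.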

[ ]:
# Make predictions with the trained model: it projects X[id] onto the phenotypes of the model
id=10
W=swotted(X[id:id+1])

# Reorder the rows of the pathway to match the order of the hidden phenotypes
_,W=swotted.reorderPhenotypes(Ph_,W)

#Visual comparison of the care pathways
plt.subplot(121)
plt.imshow(W_[id], vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.ylabel("Phenotypes")
plt.xlabel("time")
plt.title("Original pathway")
plt.subplot(122)
plt.imshow(W[0].detach().numpy(), vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.ylabel("Phenotypes")
plt.xlabel("time")
plt.title("Discovered pathway")
plt.show()

Reconstruction of the matrix from the phenotypes and the pathway

[ ]:
id=10
W=swotted(X[id:id+1])
# Reorder the rows of the pathway to match the order of the hidden phenotypes
rPh,rW=swotted.reorderPhenotypes(Ph_,W)
X_pred=swotted.model.reconstruct(rW[0], rPh)
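For reference, the reconstruction can be written as follows, under an assumed sliding-window convention where the window starting at time step s covers steps s to s + Tw - 1, and before any thresholding. Here P_r stands for the r-th phenotype (N x Tw, Ph_[r] in the code) and W for the pathway of the sequence (R x (T - Tw + 1)); the notation is ours, and swotted.model.reconstruct may differ in its details.

```latex
\hat{X}_{\cdot,t} \;=\; \sum_{r=1}^{R} \sum_{\tau=0}^{T_w-1} W_{r,\,t-\tau}\, P_{r,\cdot,\tau},
\qquad t = 0,\dots,T-1,
\qquad \text{with } W_{r,s} = 0 \text{ for } s < 0 \text{ or } s > T-T_w .
```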

[ ]:
#Visual comparison of the data and the reconstructed data
plt.subplot(121)
plt.imshow(X[id], vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.ylabel("Drugs")
plt.xlabel("Days")
plt.title("Original data")
plt.subplot(122)
plt.imshow(X_pred.detach().numpy(), vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.ylabel("Drugs")
plt.xlabel("Days")
plt.title("Reconstructed data")
plt.show()
[ ]:
rate = success_rate(X[id], X_pred)
rate
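The definition of success_rate is in swotted.utils and is not reproduced here. As an illustration of this kind of measure, the hypothetical helper below computes the fraction of cells where the reconstruction, thresholded at 0.5, agrees with the binary data; it is a sketch and not necessarily how success_rate is computed.

```python
import numpy as np


def binary_agreement(data, pred, threshold=0.5):
    """Fraction of cells where (pred >= threshold) equals the binary data.
    Hypothetical helper, for illustration only."""
    data = np.asarray(data, dtype=float)
    pred = np.asarray(pred, dtype=float)
    return float(np.mean((pred >= threshold) == (data >= 0.5)))


# Toy 2 x 3 example (hypothetical values): 4 of the 6 cells agree
data = [[1, 0, 1], [0, 0, 1]]
pred = [[0.9, 0.2, 0.4], [0.1, 0.6, 0.8]]
print(binary_agreement(data, pred))
```

With the variables of this notebook, it would be called as binary_agreement(X[id], X_pred.detach().numpy()); as in the plotting cell above, .detach() is needed before converting a tensor that requires gradients to NumPy.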

Forecast

We now illustrate how to apply the SWoTTeD model to predict the next events from the learned decomposition. We take one matrix and remove its last time step (its last column). Then, we apply the model to the truncated matrix to forecast the events at the last position.

The figure compares the real last position and the predicted one. The prediction is probabilistic.

[ ]:
id = 15
pred = swotted.forecast([X[id][:, :-1]])

# Visual comparison of the real and the forecast last time step
# (each image is a 1 x N row: the x-axis indexes the drugs)
plt.subplot(211)
plt.imshow(X[id][:, -1].unsqueeze(0), vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.xlabel("Drugs")
plt.title("Real last time step")
plt.subplot(212)
plt.imshow(pred[0].unsqueeze(0), vmin=0, vmax=1, cmap="binary", interpolation="nearest")
plt.xlabel("Drugs")
plt.title("Forecast")
plt.show()
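To quantify the comparison shown in the figure, a probabilistic forecast can be scored against the real binary values, for instance with the Brier score (the mean squared difference between predicted probabilities and outcomes). The sketch below is a generic NumPy implementation on toy values; brier_score is a hypothetical helper, and it assumes the forecast values are probabilities in [0, 1].

```python
import numpy as np


def brier_score(y_true, p_pred):
    """Mean squared difference between binary outcomes and predicted probabilities
    (0 is perfect; always predicting 0.5 scores 0.25)."""
    y_true = np.asarray(y_true, dtype=float)
    p_pred = np.asarray(p_pred, dtype=float)
    return float(np.mean((p_pred - y_true) ** 2))


# Toy last time step for N = 4 events (hypothetical values)
y_last = [1, 0, 0, 1]
p_last = [0.8, 0.1, 0.3, 0.6]
print(brier_score(y_last, p_last))
```

In this notebook, it could be applied to X[id][:, -1] and pred[0] after converting them to NumPy arrays (with .detach().numpy() if they are tensors that require gradients).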