import os
from typing import List, Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
class Encoder(nn.Module):
    """Fully connected encoder mapping gene expression to an L2-normalized latent embedding."""

    def __init__(
        self,
        n_genes: int,
        latent_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.5,
        input_dropout: float = 0.4,
        residual: bool = False,
    ):
        """Constructor.

        Args:
            n_genes:
                The number of genes in the gene space, representing the input dimensions.
            latent_dim:
                The latent space dimensions.
            hidden_dim:
                A list of hidden layer dimensions, describing the number of layers and their
                dimensions. Hidden layers are constructed in the order of the list for the
                encoder and in reverse for the decoder. Defaults to [1024, 1024].
            dropout:
                The dropout rate for hidden layers.
            input_dropout:
                The dropout rate for the input layer.
            residual:
                Use residual connections.

        Raises:
            ValueError: If ``residual`` is True and the hidden layers do not all have the
                same dimension (residual addition requires matching widths).
        """
        super().__init__()
        # Avoid the shared-mutable-default pitfall; copy caller-provided lists.
        hidden_dim = [1024, 1024] if hidden_dim is None else list(hidden_dim)
        self.latent_dim = latent_dim
        self.network = nn.ModuleList()
        self.residual = residual
        if self.residual and len(set(hidden_dim)) != 1:
            # Explicit error instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                "Residual connections require all hidden layers to have the same dimension."
            )
        for i in range(len(hidden_dim)):
            if i == 0:  # input layer
                self.network.append(
                    nn.Sequential(
                        nn.Dropout(p=input_dropout),
                        nn.Linear(n_genes, hidden_dim[i]),
                        nn.BatchNorm1d(hidden_dim[i]),
                        nn.PReLU(),
                    )
                )
            else:  # hidden layers
                self.network.append(
                    nn.Sequential(
                        nn.Dropout(p=dropout),
                        nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
                        nn.BatchNorm1d(hidden_dim[i]),
                        nn.PReLU(),
                    )
                )
        # output layer
        self.network.append(nn.Linear(hidden_dim[-1], latent_dim))

    def forward(self, x) -> torch.Tensor:
        """Forward.

        Args:
            x:
                Input tensor corresponding to input layer, shaped (batch, n_genes).

        Returns:
            Output tensor corresponding to output layer, L2-normalized per row.
        """
        for i, layer in enumerate(self.network):
            # Residual connections skip the first (input) and last (projection) layers,
            # where input and output widths differ.
            if self.residual and (0 < i < len(self.network) - 1):
                x = layer(x) + x
            else:
                x = layer(x)
        # Unit-normalize embeddings so downstream similarity is cosine-like.
        return F.normalize(x, p=2, dim=1)

    def save_state(self, filename: str):
        """Save model state.

        Args:
            filename:
                Filename to save the model state.
        """
        torch.save({"state_dict": self.state_dict()}, filename)

    def load_state(self, filename: str, use_gpu: bool = False):
        """Load model state.

        Args:
            filename:
                Filename containing the model state.
            use_gpu:
                Boolean indicating whether or not to use GPUs.
        """
        if not use_gpu:
            # Remap GPU-saved tensors onto CPU so checkpoints load on CPU-only hosts.
            ckpt = torch.load(filename, map_location=torch.device("cpu"))
        else:
            ckpt = torch.load(filename)
        self.load_state_dict(ckpt["state_dict"])
class Decoder(nn.Module):
    """Fully connected decoder reconstructing gene expression from a latent embedding."""

    def __init__(
        self,
        n_genes: int,
        latent_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.5,
        residual: bool = False,
    ):
        """Constructor.

        Args:
            n_genes:
                The number of genes in the gene space, representing the output dimensions.
            latent_dim:
                The latent space dimensions.
            hidden_dim:
                A list of hidden layer dimensions, describing the number of layers and their
                dimensions. Hidden layers are constructed in the order of the list for the
                encoder and in reverse for the decoder. Defaults to [1024, 1024].
            dropout:
                The dropout rate for hidden layers.
            residual:
                Use residual connections.

        Raises:
            ValueError: If ``residual`` is True and the hidden layers do not all have the
                same dimension (residual addition requires matching widths).
        """
        super().__init__()
        # Avoid the shared-mutable-default pitfall; copy caller-provided lists.
        hidden_dim = [1024, 1024] if hidden_dim is None else list(hidden_dim)
        self.latent_dim = latent_dim
        self.network = nn.ModuleList()
        self.residual = residual
        if self.residual and len(set(hidden_dim)) != 1:
            # Explicit error instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                "Residual connections require all hidden layers to have the same dimension."
            )
        for i in range(len(hidden_dim)):
            if i == 0:  # first hidden layer
                self.network.append(
                    nn.Sequential(
                        nn.Linear(latent_dim, hidden_dim[i]),
                        nn.BatchNorm1d(hidden_dim[i]),
                        nn.PReLU(),
                    )
                )
            else:  # other hidden layers
                self.network.append(
                    nn.Sequential(
                        nn.Dropout(p=dropout),
                        nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
                        nn.BatchNorm1d(hidden_dim[i]),
                        nn.PReLU(),
                    )
                )
        # reconstruction layer
        self.network.append(nn.Linear(hidden_dim[-1], n_genes))

    def forward(self, x) -> torch.Tensor:
        """Forward.

        Args:
            x:
                Input tensor corresponding to input layer, shaped (batch, latent_dim).

        Returns:
            Output tensor corresponding to output layer, shaped (batch, n_genes).
        """
        for i, layer in enumerate(self.network):
            # Residual connections skip the first and last layers, where widths differ.
            if self.residual and (0 < i < len(self.network) - 1):
                x = layer(x) + x
            else:
                x = layer(x)
        return x

    def save_state(self, filename: str):
        """Save model state.

        Args:
            filename:
                Filename to save the model state.
        """
        torch.save({"state_dict": self.state_dict()}, filename)

    def load_state(self, filename: str, use_gpu: bool = False):
        """Load model state.

        Args:
            filename:
                Filename containing the model state.
            use_gpu:
                Boolean indicating whether or not to use GPUs.
        """
        if not use_gpu:
            # Remap GPU-saved tensors onto CPU so checkpoints load on CPU-only hosts.
            ckpt = torch.load(filename, map_location=torch.device("cpu"))
        else:
            ckpt = torch.load(filename)
        self.load_state_dict(ckpt["state_dict"])
class AutoEncoder(pl.LightningModule):
    """A LightningModule that trains an Encoder/Decoder pair with MSE reconstruction loss."""

    def __init__(
        self,
        n_genes: int,
        latent_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.5,
        input_dropout: float = 0.4,
        lr: float = 5e-3,
        residual: bool = False,
    ):
        """Constructor.

        Args:
            n_genes:
                The number of genes in the gene space, representing the input dimensions.
            latent_dim:
                The latent space dimensions. Defaults to 128.
            hidden_dim:
                A list of hidden layer dimensions, describing the number of layers and their
                dimensions. Hidden layers are constructed in the order of the list for the
                encoder and in reverse for the decoder. Defaults to [1024, 1024].
            dropout:
                The dropout rate for hidden layers.
            input_dropout:
                The dropout rate for the input layer.
            lr:
                The initial learning rate.
            residual:
                Use residual connections.
        """
        super().__init__()
        # network architecture
        self.n_genes = n_genes
        self.latent_dim = latent_dim
        # Avoid the shared-mutable-default pitfall; copy caller-provided lists.
        self.hidden_dim = [1024, 1024] if hidden_dim is None else list(hidden_dim)
        self.dropout = dropout
        self.input_dropout = input_dropout
        self.lr = lr
        self.residual = residual
        # networks
        self.encoder = Encoder(
            self.n_genes,
            latent_dim=self.latent_dim,
            hidden_dim=self.hidden_dim,
            dropout=self.dropout,
            input_dropout=self.input_dropout,
            residual=self.residual,
        )
        # Decoder mirrors the encoder: hidden layers are applied in reverse order.
        self.decoder = Decoder(
            self.n_genes,
            latent_dim=self.latent_dim,
            hidden_dim=list(reversed(self.hidden_dim)),
            dropout=self.dropout,
            residual=self.residual,
        )
        self.mse_loss_fn = nn.MSELoss()
        self.scheduler = None
        # Collected per-step validation losses; reset each validation epoch.
        self.val_step_outputs = []

    def forward(self, x):
        """Forward.

        Args:
            x:
                Input tensor corresponding to input layer.

        Returns:
            Output tensor corresponding to the last encoder layer.
            Output tensor corresponding to the last decoder layer.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return z, x_hat

    def get_loss(self, batch):
        """Calculate the loss.

        Args:
            batch:
                A batch as defined by a pytorch DataLoader; expected to unpack as
                (cells, labels, studies, samples) — only ``cells`` is used here.

        Returns:
            The training (MSE reconstruction) loss.
        """
        cells, labels, studies, samples = batch
        _, reconstruction = self(cells)
        return self.mse_loss_fn(cells, reconstruction)

    def training_step(self, batch, batch_idx):
        """Pytorch-lightning training step."""
        loss = self.get_loss(batch)
        self.log("train loss", loss, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_start(self):
        """Pytorch-lightning validation epoch start."""
        super().on_validation_epoch_start()
        # Drop losses from the previous epoch before the new one accumulates.
        self.val_step_outputs = []

    def validation_step(self, batch, batch_idx):
        """Pytorch-lightning validation step."""
        if self.trainer.datamodule.val_dataset is None:
            return {}
        return self._eval_step(batch, prefix="val")

    def on_validation_epoch_end(self):
        """Pytorch-lightning validation epoch end evaluation."""
        if self.trainer.datamodule.val_dataset is None:
            return {}
        return self._eval_epoch(prefix="val")

    def _eval_step(self, batch, prefix: str):
        """Evaluation of validation or test step.

        Args:
            batch:
                A batch as defined by a pytorch DataLoader.
            prefix:
                A string prefix to label logs.

        Returns:
            A dictionary containing step evaluation metrics.
        """
        loss = self.get_loss(batch)
        losses = {
            f"{prefix}_loss": loss,
        }
        if prefix == "val":
            self.val_step_outputs.append(losses)
        return losses

    def _eval_epoch(self, prefix: str):
        """Evaluation of validation or test epoch.

        Args:
            prefix:
                A string prefix to label logs. Currently only ``"val"`` is supported.

        Returns:
            A dictionary containing epoch evaluation metrics.

        Raises:
            ValueError: If ``prefix`` is not ``"val"`` (previously an accidental NameError).
        """
        if prefix == "val":
            step_outputs = self.val_step_outputs
        else:
            raise ValueError(f"Unsupported eval prefix: {prefix!r}")
        # Convert the collected 0-dim loss tensors to floats before averaging.
        loss = torch.tensor(
            [float(step[f"{prefix}_loss"]) for step in step_outputs]
        ).mean()
        self.log(f"{prefix} loss", loss, logger=True)
        losses = {
            f"{prefix}_loss": loss,
        }
        return losses

    def save_all(
        self,
        model_path: str,
    ):
        """Save encoder and decoder states under ``model_path``.

        Args:
            model_path:
                Directory to save checkpoints into; created if it does not exist.
        """
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        # save model
        self.encoder.save_state(os.path.join(model_path, "encoder.ckpt"))
        self.decoder.save_state(os.path.join(model_path, "decoder.ckpt"))

    def load_state(
        self,
        encoder_filename: str,
        decoder_filename: str,
        use_gpu: bool = False,
    ):
        """Load model state.

        Args:
            encoder_filename:
                Filename containing the encoder model state.
            decoder_filename:
                Filename containing the decoder model state.
            use_gpu:
                Boolean indicating whether or not to use GPUs.
        """
        self.encoder.load_state(encoder_filename, use_gpu)
        self.decoder.load_state(decoder_filename, use_gpu)