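"""Training entry point for OpenFold.

Wraps the AlphaFold model in a PyTorch Lightning module, sets up the
OpenFold data module, optional DeepSpeed/DDP strategies, checkpointing,
early stopping, and Weights & Biases logging, then launches training.

Example invocation (script name, paths, and values below are placeholders;
the positional arguments and flags are defined in the parser at the bottom
of this file, and --gpus/--num_nodes come from the Lightning Trainer):

    python3 <this_script>.py \
        <train_data_dir> <train_alignment_dir> <template_mmcif_dir> \
        <output_dir> <max_template_date> \
        --config_preset initial_training \
        --seed 42 \
        --gpus 8 --num_nodes 1 \
        --checkpoint_every_epoch
"""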
import argparse
import logging
import os
import random
import sys
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
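    """PyTorch Lightning wrapper around the AlphaFold model.

    Bundles the model, the AlphaFold training loss, and an exponential
    moving average (EMA) of the model weights. Training steps use the raw
    weights; validation temporarily swaps in the EMA weights (see
    validation_step / validation_epoch_end).
    """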
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k,v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                v,
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

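    # Lightning invokes on_before_zero_grad() after each optimizer step and
    # before gradients are zeroed, so the EMA weights track every update.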
    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss and other metrics
        batch["use_clamped_fape"] = 0.  # validate with the unclamped FAPE loss
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
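        """Compute structure-quality metrics for a batch.

        Always reports Ca lDDT and Ca dRMSD. When superimposition_metrics is
        True (validation), the predicted Ca coordinates are additionally
        superimposed onto the ground truth to compute the alignment RMSD,
        GDT_TS, and GDT_HA.
        """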
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )

        metrics["drmsd_ca"] = drmsd_ca_score

        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> torch.optim.Adam:
        # return torch.optim.Adam(
        #     self.model.parameters(),
        #     lr=learning_rate,
        #     eps=eps
        # )
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
        )

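        # Lightning consumes this dictionary: "interval": "step" requests a
        # scheduler update after every optimizer step rather than per epoch.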
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step


def main(args):
    if(args.seed is not None):
        seed_everything(args.seed)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16")
    )

    model_module = OpenFoldWrapper(config)
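    # When resuming, recover the last global step so the LR scheduler can be
    # fast-forwarded. DeepSpeed ZeRO checkpoints are directories and are read
    # with the zero_to_fp32 helpers; ordinary checkpoints are single files
    # loaded with torch.load.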
    if(args.resume_from_ckpt):
        if(os.path.isdir(args.resume_from_ckpt)):
            last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        if(os.path.isdir(args.resume_from_ckpt)):
            sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        sd = {k[len("module."):]:v for k,v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    #data_module = DummyDataLoader("new_batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

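    # Choose the training strategy: DeepSpeed when a DeepSpeed config is
    # provided, plain DDP when multiple GPUs or nodes are requested, and no
    # distributed strategy otherwise.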
    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        )
        if(args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    if(args.wandb):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
                their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when the validation metric stops improving"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest improvement in the monitored validation metric
                that counts for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Whether to measure and log training performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
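    # Add the standard PyTorch Lightning Trainer flags (e.g. --gpus,
    # --num_nodes, --precision) to the parser.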
    parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser,
        [
            "--accelerator",
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    )

    args = parser.parse_args()

    if(args.seed is None and
            ((args.gpus is not None and args.gpus > 1) or
             (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)