From d1fe024b20c673193a8570357e46b92cc1caacd1 Mon Sep 17 00:00:00 2001
From: Jennifer
Date: Wed, 13 Mar 2024 01:21:50 -0400
Subject: [PATCH] Adds `experiment_config_json` for setting custom configurations with a JSON file.

---
 run_pretrained_openfold.py |  9 +++++++++
 train_openfold.py          | 11 ++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py
index 10f991d..699da57 100644
--- a/run_pretrained_openfold.py
+++ b/run_pretrained_openfold.py
@@ -20,6 +20,7 @@ import os
 import pickle
 import random
 import time
+import json
 
 logging.basicConfig()
 logger = logging.getLogger(__file__)
@@ -179,6 +180,11 @@ def main(args):
 
     config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
 
+    if args.experiment_config_json:
+        with open(args.experiment_config_json, 'r') as f:
+            custom_config_dict = json.load(f)
+        config.update_from_flattened_dict(custom_config_dict)
+
     if args.trace_model:
         if not config.data.predict.fixed_size:
             raise ValueError(
@@ -452,6 +458,9 @@ if __name__ == "__main__":
         "--cif_output", action="store_true", default=False,
         help="Output predicted models in ModelCIF format instead of PDB format (default)"
     )
+    parser.add_argument(
+        "--experiment_config_json", default="", help="Path to a JSON file with custom config values to override config settings",
+    )
     add_data_args(parser)
     args = parser.parse_args()
 
diff --git a/train_openfold.py b/train_openfold.py
index b7aab43..64bb635 100644
--- a/train_openfold.py
+++ b/train_openfold.py
@@ -2,6 +2,7 @@ import argparse
 import logging
 import os
 import sys
+import json
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
@@ -39,7 +40,6 @@ from scripts.zero_to_fp32 import (
     get_fp32_state_dict_from_zero_checkpoint,
     get_global_step_from_zero_checkpoint
 )
-from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
 
 from openfold.utils.logger import PerformanceLoggingCallback
 
@@ -59,6 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
 
         self.cached_weights = None
         self.last_lr_step = -1
+        self.save_hyperparameters()
 
     def forward(self, batch):
         return self.model(batch)
@@ -280,6 +281,11 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
+    if args.experiment_config_json:
+        with open(args.experiment_config_json, 'r') as f:
+            custom_config_dict = json.load(f)
+        config.update_from_flattened_dict(custom_config_dict)
+
     model_module = OpenFoldWrapper(config)
 
     if args.resume_from_ckpt:
@@ -611,6 +617,9 @@ if __name__ == "__main__":
         "--distillation_alignment_index_path", type=str, default=None,
         help="Distillation alignment index. See the README for instructions."
     )
+    parser.add_argument(
+        "--experiment_config_json", default="", help="Path to a JSON file with custom config values to override config settings",
+    )
     parser = pl.Trainer.add_argparse_args(parser)
 
     # Disable the initial validation pass
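
Note on usage: the new flag expects the JSON file to hold a flattened dict, i.e. keys are dot-separated paths into the OpenFold config and values are the overrides, which `config.update_from_flattened_dict(...)` then applies. Below is a minimal sketch of that mechanism, assuming the config object behaves like an `ml_collections.ConfigDict` (as in openfold/config.py); the two keys shown are illustrative examples, not a required schema.

    import json
    import ml_collections

    # Stand-in for the object returned by model_config(); the real OpenFold
    # config is much larger. These two entries exist only for illustration.
    config = ml_collections.ConfigDict({
        "data": {"common": {"max_recycling_iters": 3}},
        "globals": {"chunk_size": 4},
    })

    # What a file passed via --experiment_config_json might contain.
    overrides_json = '{"data.common.max_recycling_iters": 1, "globals.chunk_size": 128}'
    custom_config_dict = json.loads(overrides_json)

    # Same call the patch makes after json.load()-ing the file.
    config.update_from_flattened_dict(custom_config_dict)
    assert config.data.common.max_recycling_iters == 1
    assert config.globals.chunk_size == 128

On the command line this would pair with something like `--experiment_config_json my_overrides.json`, where my_overrides.json is a hypothetical file containing the flattened key/value pairs shown above.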