Adds `experiment_config_json` for setting custom configurations with a json.

This commit is contained in:
Jennifer 2024-03-13 01:21:50 -04:00
parent f4df72173d
commit d1fe024b20
2 changed files with 19 additions and 1 deletions

View File

@ -20,6 +20,7 @@ import os
import pickle
import random
import time
import json
logging.basicConfig()
logger = logging.getLogger(__file__)
@ -179,6 +180,11 @@ def main(args):
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.trace_model:
if not config.data.predict.fixed_size:
raise ValueError(
@ -452,6 +458,9 @@ if __name__ == "__main__":
"--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
add_data_args(parser)
args = parser.parse_args()

View File

@ -2,6 +2,7 @@ import argparse
import logging
import os
import sys
import json
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
@ -39,7 +40,6 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
from openfold.utils.logger import PerformanceLoggingCallback
@ -59,6 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
self.cached_weights = None
self.last_lr_step = -1
self.save_hyperparameters
def forward(self, batch):
return self.model(batch)
@ -280,6 +281,11 @@ def main(args):
train=True,
low_prec=(str(args.precision) == "16")
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
model_module = OpenFoldWrapper(config)
if args.resume_from_ckpt:
@ -611,6 +617,9 @@ if __name__ == "__main__":
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
)
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass