Adds `experiment_config_json` for setting custom configurations with a json.
This commit is contained in:
parent
f4df72173d
commit
d1fe024b20
|
@ -20,6 +20,7 @@ import os
|
|||
import pickle
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__file__)
|
||||
|
@ -179,6 +180,11 @@ def main(args):
|
|||
|
||||
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
|
||||
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
if args.trace_model:
|
||||
if not config.data.predict.fixed_size:
|
||||
raise ValueError(
|
||||
|
@ -452,6 +458,9 @@ if __name__ == "__main__":
|
|||
"--cif_output", action="store_true", default=False,
|
||||
help="Output predicted models in ModelCIF format instead of PDB format (default)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
|
||||
)
|
||||
add_data_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import argparse
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
|
@ -39,7 +40,6 @@ from scripts.zero_to_fp32 import (
|
|||
get_fp32_state_dict_from_zero_checkpoint,
|
||||
get_global_step_from_zero_checkpoint
|
||||
)
|
||||
from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
|
||||
|
||||
from openfold.utils.logger import PerformanceLoggingCallback
|
||||
|
||||
|
@ -59,6 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
|||
|
||||
self.cached_weights = None
|
||||
self.last_lr_step = -1
|
||||
self.save_hyperparameters
|
||||
|
||||
def forward(self, batch):
|
||||
return self.model(batch)
|
||||
|
@ -280,6 +281,11 @@ def main(args):
|
|||
train=True,
|
||||
low_prec=(str(args.precision) == "16")
|
||||
)
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
model_module = OpenFoldWrapper(config)
|
||||
|
||||
if args.resume_from_ckpt:
|
||||
|
@ -611,6 +617,9 @@ if __name__ == "__main__":
|
|||
"--distillation_alignment_index_path", type=str, default=None,
|
||||
help="Distillation alignment index. See the README for instructions."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
|
||||
)
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
|
||||
# Disable the initial validation pass
|
||||
|
|
Loading…
Reference in New Issue