Adding cif_output as CLI argument

This commit is contained in:
Jose Duarte 2023-02-20 14:08:05 -08:00
parent 4f662f832c
commit 9db8dc369b
1 changed file with 30 additions and 20 deletions

View File

@ -1,6 +1,6 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
torch_major_version > 1 or
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
# Gives a large speedup on Ampere-class GPUs
@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
@ -141,13 +141,13 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if(args.trace_model):
if(not config.data.predict.fixed_size):
raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config"
)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
@ -165,10 +165,10 @@ def main(args):
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(2**32)
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
@ -183,7 +183,7 @@ def main(args):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read()
tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
@ -206,10 +206,10 @@ def main(args):
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Does nothing if the alignments have already been computed
precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = feature_dicts.get(tag, None)
if(feature_dict is None):
feature_dict = generate_feature_dict(
@ -234,7 +234,7 @@ def main(args):
)
processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device)
k:torch.as_tensor(v, device=args.model_device)
for k,v in processed_feature_dict.items()
}
@ -255,30 +255,36 @@ def main(args):
# Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map(
lambda x: np.array(x[..., -1].cpu()),
lambda x: np.array(x[..., -1].cpu()),
processed_feature_dict
)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output(
out,
processed_feature_dict,
feature_dict,
feature_processor,
out,
processed_feature_dict,
feature_dict,
feature_processor,
args.config_preset,
args.multimer_ri_gap,
args.subtract_plddt
)
unrelaxed_file_suffix = "_unrelaxed.pdb"
if args.cif_output:
unrelaxed_file_suffix = "_unrelaxed.cif"
unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb'
output_directory, f'{output_name}{unrelaxed_file_suffix}'
)
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
if args.cif_output:
fp.write(protein.to_modelcif(unrelaxed_protein))
else:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
@ -373,12 +379,16 @@ if __name__ == "__main__":
"--long_sequence_inference", action="store_true", default=False,
help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
)
parser.add_argument(
"--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
add_data_args(parser)
args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"openfold", "resources", "params",
"params_" + args.config_preset + ".npz"
)