Adding output_cif as CLI argument
parent 4f662f832c
commit 9db8dc369b
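In short, this commit adds a --cif_output flag to the inference script: when the flag is set, unrelaxed predictions are written in ModelCIF format via protein.to_modelcif (suffix _unrelaxed.cif) instead of the default PDB format via protein.to_pdb (suffix _unrelaxed.pdb). A usage sketch, assuming the script is OpenFold's run_pretrained_openfold.py, that fasta_dir and template_mmcif_dir are its positional arguments, and that the preset name is illustrative (none of this is stated in the diff itself):

python run_pretrained_openfold.py fasta_dir/ template_mmcif_dir/ \
    --output_dir predictions/ \
    --config_preset model_1_ptm \
    --cif_output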
@@ -1,6 +1,6 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
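The line guarded by this version check is not shown in the hunk. On PyTorch 1.12+, the usual way to get the Ampere-class speedup the comment refers to is TF32 matmuls; a sketch of what such a guard typically enables (an assumption about the elided call, not text from this diff):

import torch

# TF32 matmul kernels give a large speedup on Ampere and newer GPUs.
# set_float32_matmul_precision was introduced in PyTorch 1.12, which is
# exactly what the version check above gates on.
torch.set_float32_matmul_precision("high")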
@@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
    local_alignment_dir = os.path.join(alignment_dir, tag)
    if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
        logger.info(f"Generating alignments for {tag}...")

        os.makedirs(local_alignment_dir)

        alignment_runner = data_pipeline.AlignmentRunner(
@@ -141,13 +141,13 @@ def main(args):
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

    if(args.trace_model):
        if(not config.data.predict.fixed_size):
            raise ValueError(
                "Tracing requires that fixed_size mode be enabled in the config"
            )

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
@@ -165,10 +165,10 @@ def main(args):
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
@@ -183,7 +183,7 @@ def main(args):
    # Gather input sequences
    with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
        data = fp.read()

    tags, seqs = parse_fasta(data)
    # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
    tag = '-'.join(tags)
@@ -206,10 +206,10 @@ def main(args):
    output_name = f'{tag}_{args.config_preset}'
    if args.output_postfix is not None:
        output_name = f'{output_name}_{args.output_postfix}'

    # Does nothing if the alignments have already been computed
    precompute_alignments(tags, seqs, alignment_dir, args)

    feature_dict = feature_dicts.get(tag, None)
    if(feature_dict is None):
        feature_dict = generate_feature_dict(
@@ -234,7 +234,7 @@ def main(args):
    )

    processed_feature_dict = {
        k:torch.as_tensor(v, device=args.model_device)
        for k,v in processed_feature_dict.items()
    }
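The comprehension above moves every processed feature onto the model device as a torch tensor. A minimal standalone sketch of the same pattern, with dummy feature names and shapes rather than the real OpenFold feature set:

import numpy as np
import torch

feature_dict = {
    "aatype": np.zeros((128,), dtype=np.int64),          # dummy per-residue feature
    "msa_feat": np.zeros((64, 128), dtype=np.float32),   # dummy MSA feature
}
device = "cuda:0" if torch.cuda.is_available() else "cpu"
processed_feature_dict = {
    k: torch.as_tensor(v, device=device) for k, v in feature_dict.items()
}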
@@ -255,30 +255,36 @@ def main(args):

    # Toss out the recycling dimensions --- we don't need them anymore
    processed_feature_dict = tensor_tree_map(
        lambda x: np.array(x[..., -1].cpu()),
        processed_feature_dict
    )
    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    unrelaxed_protein = prep_output(
        out,
        processed_feature_dict,
        feature_dict,
        feature_processor,
        args.config_preset,
        args.multimer_ri_gap,
        args.subtract_plddt
    )

+    unrelaxed_file_suffix = "_unrelaxed.pdb"
+    if args.cif_output:
+        unrelaxed_file_suffix = "_unrelaxed.cif"
    unrelaxed_output_path = os.path.join(
-        output_directory, f'{output_name}_unrelaxed.pdb'
+        output_directory, f'{output_name}{unrelaxed_file_suffix}'
    )

    with open(unrelaxed_output_path, 'w') as fp:
-        fp.write(protein.to_pdb(unrelaxed_protein))
+        if args.cif_output:
+            fp.write(protein.to_modelcif(unrelaxed_protein))
+        else:
+            fp.write(protein.to_pdb(unrelaxed_protein))

    logger.info(f"Output written to {unrelaxed_output_path}...")

    if not args.skip_relaxation:
        # Relax the prediction.
        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
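Read in isolation, the new logic amounts to: pick the output suffix from the flag, then dispatch to the matching serializer. A minimal sketch of that flow (write_unrelaxed_structure is a hypothetical helper, not something this commit adds, and the "from openfold.np import protein" import is assumed from the script):

import os
from openfold.np import protein  # assumed import; provides to_pdb and to_modelcif

def write_unrelaxed_structure(unrelaxed_protein, output_directory, output_name, cif_output):
    # Mirror the diff above: choose suffix and serializer based on --cif_output.
    suffix = "_unrelaxed.cif" if cif_output else "_unrelaxed.pdb"
    path = os.path.join(output_directory, f"{output_name}{suffix}")
    if cif_output:
        body = protein.to_modelcif(unrelaxed_protein)  # ModelCIF text
    else:
        body = protein.to_pdb(unrelaxed_protein)       # PDB text
    with open(path, "w") as fp:
        fp.write(body)
    return path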
@@ -373,12 +379,16 @@ if __name__ == "__main__":
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
+    parser.add_argument(
+        "--cif_output", action="store_true", default=False,
+        help="Output predicted models in ModelCIF format instead of PDB format (default)"
+    )
    add_data_args(parser)
    args = parser.parse_args()

    if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )
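Because the new flag uses argparse's store_true action, it is False by default (PDB output) and becomes True only when passed on the command line. A self-contained check of that behavior (illustrative parser, not the script's full argument set):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cif_output", action="store_true", default=False,
    help="Output predicted models in ModelCIF format instead of PDB format (default)"
)

print(parser.parse_args([]).cif_output)                # False -> writes *_unrelaxed.pdb
print(parser.parse_args(["--cif_output"]).cif_output)  # True  -> writes *_unrelaxed.cif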