# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import logging import math import numpy as np import os import pickle import random import time import json logging.basicConfig() logger = logging.getLogger(__file__) logger.setLevel(level=logging.INFO) import torch torch_versions = torch.__version__.split(".") torch_major_version = int(torch_versions[0]) torch_minor_version = int(torch_versions[1]) if ( torch_major_version > 1 or (torch_major_version == 1 and torch_minor_version >= 12) ): # Gives a large speedup on Ampere-class GPUs torch.set_float32_matmul_precision("high") torch.set_grad_enabled(False) from openfold.config import model_config from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data.tools import hhsearch, hmmsearch from openfold.np import protein from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model, prep_output, relax_protein) from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.trace_utils import ( pad_feature_dict_seq, trace_model_, ) from scripts.precompute_embeddings import EmbeddingGenerator from scripts.utils import add_data_args TRACING_INTERVAL = 50 def precompute_alignments(tags, seqs, alignment_dir, args): for tag, seq in zip(tags, seqs): tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") with open(tmp_fasta_path, "w") as fp: fp.write(f">{tag}\n{seq}") local_alignment_dir = os.path.join(alignment_dir, tag) if args.use_precomputed_alignments is None: logger.info(f"Generating alignments for {tag}...") os.makedirs(local_alignment_dir, exist_ok=True) if "multimer" in args.config_preset: template_searcher = hmmsearch.Hmmsearch( binary_path=args.hmmsearch_binary_path, hmmbuild_binary_path=args.hmmbuild_binary_path, database_path=args.pdb_seqres_database_path, ) else: template_searcher = hhsearch.HHSearch( binary_path=args.hhsearch_binary_path, databases=[args.pdb70_database_path], ) # In seqemb mode, use AlignmentRunner only to generate templates if args.use_single_seq_mode: alignment_runner = data_pipeline.AlignmentRunner( jackhmmer_binary_path=args.jackhmmer_binary_path, uniref90_database_path=args.uniref90_database_path, template_searcher=template_searcher, no_cpus=args.cpus, ) embedding_generator = EmbeddingGenerator() embedding_generator.run(tmp_fasta_path, alignment_dir) else: alignment_runner = data_pipeline.AlignmentRunner( jackhmmer_binary_path=args.jackhmmer_binary_path, hhblits_binary_path=args.hhblits_binary_path, uniref90_database_path=args.uniref90_database_path, mgnify_database_path=args.mgnify_database_path, bfd_database_path=args.bfd_database_path, uniref30_database_path=args.uniref30_database_path, uniclust30_database_path=args.uniclust30_database_path, uniprot_database_path=args.uniprot_database_path, template_searcher=template_searcher, use_small_bfd=args.bfd_database_path is None, no_cpus=args.cpus ) alignment_runner.run( tmp_fasta_path, local_alignment_dir ) else: logger.info( f"Using precomputed alignments for {tag} at {alignment_dir}..." ) # Remove temporary FASTA file os.remove(tmp_fasta_path) def round_up_seqlen(seqlen): return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL def generate_feature_dict( tags, seqs, alignment_dir, data_processor, args, ): tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") if "multimer" in args.config_preset: with open(tmp_fasta_path, "w") as fp: fp.write( '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) ) feature_dict = data_processor.process_fasta( fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, ) elif len(seqs) == 1: tag = tags[0] seq = seqs[0] with open(tmp_fasta_path, "w") as fp: fp.write(f">{tag}\n{seq}") local_alignment_dir = os.path.join(alignment_dir, tag) feature_dict = data_processor.process_fasta( fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir, seqemb_mode=args.use_single_seq_mode, ) else: with open(tmp_fasta_path, "w") as fp: fp.write( '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) ) feature_dict = data_processor.process_multiseq_fasta( fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir, ) # Remove temporary FASTA file os.remove(tmp_fasta_path) return feature_dict def list_files_with_extensions(dir, extensions): return [f for f in os.listdir(dir) if f.endswith(extensions)] def main(args): # Create the output directory os.makedirs(args.output_dir, exist_ok=True) if args.config_preset.startswith("seq"): args.use_single_seq_mode = True config = model_config( args.config_preset, long_sequence_inference=args.long_sequence_inference, use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention, ) if args.experiment_config_json: with open(args.experiment_config_json, 'r') as f: custom_config_dict = json.load(f) config.update_from_flattened_dict(custom_config_dict) if args.trace_model: if not config.data.predict.fixed_size: raise ValueError( "Tracing requires that fixed_size mode be enabled in the config" ) is_multimer = "multimer" in args.config_preset is_custom_template = "use_custom_template" in args and args.use_custom_template if is_custom_template: template_featurizer = templates.CustomHitFeaturizer( mmcif_dir=args.template_mmcif_dir, max_template_date="9999-12-31", # just dummy, not used max_hits=-1, # just dummy, not used kalign_binary_path=args.kalign_binary_path ) elif is_multimer: template_featurizer = templates.HmmsearchHitFeaturizer( mmcif_dir=args.template_mmcif_dir, max_template_date=args.max_template_date, max_hits=config.data.predict.max_templates, kalign_binary_path=args.kalign_binary_path, release_dates_path=args.release_dates_path, obsolete_pdbs_path=args.obsolete_pdbs_path ) else: template_featurizer = templates.HhsearchHitFeaturizer( mmcif_dir=args.template_mmcif_dir, max_template_date=args.max_template_date, max_hits=config.data.predict.max_templates, kalign_binary_path=args.kalign_binary_path, release_dates_path=args.release_dates_path, obsolete_pdbs_path=args.obsolete_pdbs_path ) data_processor = data_pipeline.DataPipeline( template_featurizer=template_featurizer, ) if is_multimer: data_processor = data_pipeline.DataPipelineMultimer( monomer_data_pipeline=data_processor, ) output_dir_base = args.output_dir random_seed = args.data_random_seed if random_seed is None: random_seed = random.randrange(2 ** 32) np.random.seed(random_seed) torch.manual_seed(random_seed + 1) feature_processor = feature_pipeline.FeaturePipeline(config.data) if not os.path.exists(output_dir_base): os.makedirs(output_dir_base) if args.use_precomputed_alignments is None: alignment_dir = os.path.join(output_dir_base, "alignments") else: alignment_dir = args.use_precomputed_alignments tag_list = [] seq_list = [] for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")): # Gather input sequences fasta_path = os.path.join(args.fasta_dir, fasta_file) with open(fasta_path, "r") as fp: data = fp.read() tags, seqs = parse_fasta(data) if not is_multimer and len(tags) != 1: print( f"{fasta_path} contains more than one sequence but " f"multimer mode is not enabled. Skipping..." ) continue # assert len(tags) == len(set(tags)), "All FASTA tags must be unique" tag = '-'.join(tags) tag_list.append((tag, tags)) seq_list.append(seqs) seq_sort_fn = lambda target: sum([len(s) for s in target[1]]) sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn) feature_dicts = {} if is_multimer and args.openfold_checkpoint_path: raise ValueError( '`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode') model_generator = load_models_from_command_line( config, args.model_device, args.openfold_checkpoint_path, args.jax_param_path, args.output_dir) for model, output_directory in model_generator: cur_tracing_interval = 0 for (tag, tags), seqs in sorted_targets: output_name = f'{tag}_{args.config_preset}' if args.output_postfix is not None: output_name = f'{output_name}_{args.output_postfix}' # Does nothing if the alignments have already been computed precompute_alignments(tags, seqs, alignment_dir, args) feature_dict = feature_dicts.get(tag, None) if feature_dict is None: feature_dict = generate_feature_dict( tags, seqs, alignment_dir, data_processor, args, ) if args.trace_model: n = feature_dict["aatype"].shape[-2] rounded_seqlen = round_up_seqlen(n) feature_dict = pad_feature_dict_seq( feature_dict, rounded_seqlen, ) feature_dicts[tag] = feature_dict processed_feature_dict = feature_processor.process_features( feature_dict, mode='predict', is_multimer=is_multimer ) processed_feature_dict = { k: torch.as_tensor(v, device=args.model_device) for k, v in processed_feature_dict.items() } if args.trace_model: if rounded_seqlen > cur_tracing_interval: logger.info( f"Tracing model at {rounded_seqlen} residues..." ) t = time.perf_counter() trace_model_(model, processed_feature_dict) tracing_time = time.perf_counter() - t logger.info( f"Tracing time: {tracing_time}" ) cur_tracing_interval = rounded_seqlen out = run_model(model, processed_feature_dict, tag, args.output_dir) # Toss out the recycling dimensions --- we don't need them anymore processed_feature_dict = tensor_tree_map( lambda x: np.array(x[..., -1].cpu()), processed_feature_dict ) out = tensor_tree_map(lambda x: np.array(x.cpu()), out) unrelaxed_protein = prep_output( out, processed_feature_dict, feature_dict, feature_processor, args.config_preset, args.multimer_ri_gap, args.subtract_plddt ) unrelaxed_file_suffix = "_unrelaxed.pdb" if args.cif_output: unrelaxed_file_suffix = "_unrelaxed.cif" unrelaxed_output_path = os.path.join( output_directory, f'{output_name}{unrelaxed_file_suffix}' ) with open(unrelaxed_output_path, 'w') as fp: if args.cif_output: fp.write(protein.to_modelcif(unrelaxed_protein)) else: fp.write(protein.to_pdb(unrelaxed_protein)) logger.info(f"Output written to {unrelaxed_output_path}...") if not args.skip_relaxation: # Relax the prediction. logger.info(f"Running relaxation on {unrelaxed_output_path}...") relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output) if args.save_outputs: output_dict_path = os.path.join( output_directory, f'{output_name}_output_dict.pkl' ) with open(output_dict_path, "wb") as fp: pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) logger.info(f"Model output written to {output_dict_path}...") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "fasta_dir", type=str, help="Path to directory containing FASTA files, one sequence per file" ) parser.add_argument( "template_mmcif_dir", type=str, ) parser.add_argument( "--use_precomputed_alignments", type=str, default=None, help="""Path to alignment directory. If provided, alignment computation is skipped and database path arguments are ignored.""" ) parser.add_argument( "--use_custom_template", action="store_true", default=False, help="""Use mmcif given with "template_mmcif_dir" argument as template input.""" ) parser.add_argument( "--use_single_seq_mode", action="store_true", default=False, help="""Use single sequence embeddings instead of MSAs.""" ) parser.add_argument( "--output_dir", type=str, default=os.getcwd(), help="""Name of the directory in which to output the prediction""", ) parser.add_argument( "--model_device", type=str, default="cpu", help="""Name of the device on which to run the model. Any valid torch device name is accepted (e.g. "cpu", "cuda:0")""" ) parser.add_argument( "--config_preset", type=str, default="model_1", help="""Name of a model config preset defined in openfold/config.py""" ) parser.add_argument( "--jax_param_path", type=str, default=None, help="""Path to JAX model parameters. If None, and openfold_checkpoint_path is also None, parameters are selected automatically according to the model name from openfold/resources/params""" ) parser.add_argument( "--openfold_checkpoint_path", type=str, default=None, help="""Path to OpenFold checkpoint. Can be either a DeepSpeed checkpoint directory or a .pt file""" ) parser.add_argument( "--save_outputs", action="store_true", default=False, help="Whether to save all model outputs, including embeddings, etc." ) parser.add_argument( "--cpus", type=int, default=4, help="""Number of CPUs with which to run alignment tools""" ) parser.add_argument( "--preset", type=str, default='full_dbs', choices=('reduced_dbs', 'full_dbs') ) parser.add_argument( "--output_postfix", type=str, default=None, help="""Postfix for output prediction filenames""" ) parser.add_argument( "--data_random_seed", type=int, default=None ) parser.add_argument( "--skip_relaxation", action="store_true", default=False, ) parser.add_argument( "--multimer_ri_gap", type=int, default=200, help="""Residue index offset between multiple sequences, if provided""" ) parser.add_argument( "--trace_model", action="store_true", default=False, help="""Whether to convert parts of each model to TorchScript. Significantly improves runtime at the cost of lengthy 'compilation.' Useful for large batch jobs.""" ) parser.add_argument( "--subtract_plddt", action="store_true", default=False, help=""""Whether to output (100 - pLDDT) in the B-factor column instead of the pLDDT itself""" ) parser.add_argument( "--long_sequence_inference", action="store_true", default=False, help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details""" ) parser.add_argument( "--cif_output", action="store_true", default=False, help="Output predicted models in ModelCIF format instead of PDB format (default)" ) parser.add_argument( "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting", ) parser.add_argument( "--use_deepspeed_evoformer_attention", action="store_true", default=False, help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.", ) add_data_args(parser) args = parser.parse_args() if args.jax_param_path is None and args.openfold_checkpoint_path is None: args.jax_param_path = os.path.join( "openfold", "resources", "params", "params_" + args.config_preset + ".npz" ) if args.model_device == "cpu" and torch.cuda.is_available(): logging.warning( """The model is being run on CPU. Consider specifying --model_device for better performance""" ) main(args)