Added custom template folder

This commit is contained in:
rostro36 2024-02-20 09:57:12 +01:00
parent bb3f51e5a2
commit 17f24bd7d2
4 changed files with 100 additions and 51 deletions

View File

@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`,
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
`--use_precomputed_alignments`. If you wish to use a specific template as input,
you can use the argument `--use_custom_template`, which then will read all .cif
files in `template_mmcif_dir`. Make sure the chains of interest have the identifier _A_
and have the same length as the input sequence.
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,

View File

@ -23,8 +23,19 @@ import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data import (
templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.templates import (
get_custom_template_features,
empty_template_feats,
CustomHitFeaturizer,
)
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.np import residue_constants, protein
@ -38,7 +49,9 @@ def make_template_features(
template_featurizer: Any,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
if template_featurizer is None or (
len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
):
template_features = empty_template_feats(len(input_sequence))
else:
templates_result = template_featurizer.get_templates(

View File

@ -22,6 +22,7 @@ import glob
import json
import logging
import os
from pathlib import Path
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
@ -947,49 +948,58 @@ def _process_single_hit(
def get_custom_template_features(
mmcif_path: str,
query_sequence: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str):
mmcif_path: str,
query_sequence: str,
pdb_id: str,
chain_id: Optional[str] = "A",
kalign_binary_path: Optional[str] = None,
):
if os.path.isfile(mmcif_path):
template_paths = [Path(mmcif_path)]
with open(mmcif_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
elif os.path.isdir(mmcif_path):
template_paths = list(Path(mmcif_path).glob("*.cif"))
else:
logging.error("Custom template path %s does not exist", mmcif_path)
raise ValueError(f"Custom template path {mmcif_path} does not exist")
warnings = []
template_features = dict()
for template_path in template_paths:
logging.info("Featurizing template: %s", template_path)
# pdb_id only for error reporting, take file name
pdb_id = Path(template_path).stem
with open(template_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
# chain_id defaults to A, should be changed?
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x: x for x, _ in enumerate(query_sequence)}
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x:x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True
)
features["template_sum_probs"] = [1.0]
# TODO: clean up this logic
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
curr_features, curr_warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True,
)
curr_features["template_sum_probs"] = [1.0]
template_features = {
curr_name: template_features.get(curr_name, []) + [curr_item]
for curr_name, curr_item in curr_features.items()
}
warnings = warnings.append(curr_warnings)
template_features = {
template_feature_name: np.stack(
template_features[template_feature_name], axis=0
).astype(template_feature_type)
for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
}
return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings
)
@ -1188,6 +1198,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
)
class CustomHitFeaturizer(TemplateHitFeaturizer):
"""Featurizer for templates given in folder.
Chain of interest has to be chain A and of same residue size as input sequence."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
return get_custom_template_features(
self._mmcif_dir,
query_sequence=query_sequence,
pdb_id="test",
chain_id="A",
kalign_binary_path=self._kalign_binary_path,
)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,

View File

@ -186,8 +186,15 @@ def main(args):
)
is_multimer = "multimer" in args.config_preset
if is_multimer:
is_custom_template = "use_custom_template" in args
if is_custom_template:
template_featurizer = templates.CustomHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date="9999-12-31", # just dummy, not used
max_hits=-1, # just dummy, not used
kalign_binary_path=args.kalign_binary_path
)
elif is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
@ -205,11 +212,9 @@ def main(args):
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
@ -222,7 +227,6 @@ def main(args):
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
@ -292,7 +296,6 @@ def main(args):
)
feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer
)
@ -379,6 +382,10 @@ if __name__ == "__main__":
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser.add_argument(
"--use_custom_template", action="store_true", default=False,
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs."""
@ -466,5 +473,4 @@ if __name__ == "__main__":
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main(args)