Added custom template folder
This commit is contained in:
parent
bb3f51e5a2
commit
17f24bd7d2
|
@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`,
|
|||
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
|
||||
If you've already computed alignments for the query, you have the option to
|
||||
skip the expensive alignment computation here with
|
||||
`--use_precomputed_alignments`.
|
||||
`--use_precomputed_alignments`. If you wish to use a specific template as input,
|
||||
you can use the argument `--use_custom_template`, which then will read all .cif
|
||||
files in `template_mmcif_dir`. Make sure the chains of interest have the identifier _A_
|
||||
and have the same length as the input sequence.
|
||||
|
||||
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
|
||||
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
|
||||
|
|
|
@ -23,8 +23,19 @@ import tempfile
|
|||
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
|
||||
from openfold.data.templates import get_custom_template_features, empty_template_feats
|
||||
from openfold.data import (
|
||||
templates,
|
||||
parsers,
|
||||
mmcif_parsing,
|
||||
msa_identifiers,
|
||||
msa_pairing,
|
||||
feature_processing_multimer,
|
||||
)
|
||||
from openfold.data.templates import (
|
||||
get_custom_template_features,
|
||||
empty_template_feats,
|
||||
CustomHitFeaturizer,
|
||||
)
|
||||
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
|
||||
from openfold.np import residue_constants, protein
|
||||
|
||||
|
@ -38,7 +49,9 @@ def make_template_features(
|
|||
template_featurizer: Any,
|
||||
) -> FeatureDict:
|
||||
hits_cat = sum(hits.values(), [])
|
||||
if(len(hits_cat) == 0 or template_featurizer is None):
|
||||
if template_featurizer is None or (
|
||||
len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
|
||||
):
|
||||
template_features = empty_template_feats(len(input_sequence))
|
||||
else:
|
||||
templates_result = template_featurizer.get_templates(
|
||||
|
|
|
@ -22,6 +22,7 @@ import glob
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
|
@ -947,49 +948,58 @@ def _process_single_hit(
|
|||
|
||||
|
||||
def get_custom_template_features(
|
||||
mmcif_path: str,
|
||||
query_sequence: str,
|
||||
pdb_id: str,
|
||||
chain_id: str,
|
||||
kalign_binary_path: str):
|
||||
mmcif_path: str,
|
||||
query_sequence: str,
|
||||
pdb_id: str,
|
||||
chain_id: Optional[str] = "A",
|
||||
kalign_binary_path: Optional[str] = None,
|
||||
):
|
||||
if os.path.isfile(mmcif_path):
|
||||
template_paths = [Path(mmcif_path)]
|
||||
|
||||
with open(mmcif_path, "r") as mmcif_path:
|
||||
cif_string = mmcif_path.read()
|
||||
elif os.path.isdir(mmcif_path):
|
||||
template_paths = list(Path(mmcif_path).glob("*.cif"))
|
||||
else:
|
||||
logging.error("Custom template path %s does not exist", mmcif_path)
|
||||
raise ValueError(f"Custom template path {mmcif_path} does not exist")
|
||||
warnings = []
|
||||
template_features = dict()
|
||||
for template_path in template_paths:
|
||||
logging.info("Featurizing template: %s", template_path)
|
||||
# pdb_id only for error reporting, take file name
|
||||
pdb_id = Path(template_path).stem
|
||||
with open(template_path, "r") as mmcif_path:
|
||||
cif_string = mmcif_path.read()
|
||||
mmcif_parse_result = mmcif_parsing.parse(
|
||||
file_id=pdb_id, mmcif_string=cif_string
|
||||
)
|
||||
# chain_id defaults to A, should be changed?
|
||||
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
|
||||
mapping = {x: x for x, _ in enumerate(query_sequence)}
|
||||
|
||||
mmcif_parse_result = mmcif_parsing.parse(
|
||||
file_id=pdb_id, mmcif_string=cif_string
|
||||
)
|
||||
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
|
||||
|
||||
|
||||
mapping = {x:x for x, _ in enumerate(query_sequence)}
|
||||
|
||||
|
||||
features, warnings = _extract_template_features(
|
||||
mmcif_object=mmcif_parse_result.mmcif_object,
|
||||
pdb_id=pdb_id,
|
||||
mapping=mapping,
|
||||
template_sequence=template_sequence,
|
||||
query_sequence=query_sequence,
|
||||
template_chain_id=chain_id,
|
||||
kalign_binary_path=kalign_binary_path,
|
||||
_zero_center_positions=True
|
||||
)
|
||||
features["template_sum_probs"] = [1.0]
|
||||
|
||||
# TODO: clean up this logic
|
||||
template_features = {}
|
||||
for template_feature_name in TEMPLATE_FEATURES:
|
||||
template_features[template_feature_name] = []
|
||||
|
||||
for k in template_features:
|
||||
template_features[k].append(features[k])
|
||||
|
||||
for name in template_features:
|
||||
template_features[name] = np.stack(
|
||||
template_features[name], axis=0
|
||||
).astype(TEMPLATE_FEATURES[name])
|
||||
curr_features, curr_warnings = _extract_template_features(
|
||||
mmcif_object=mmcif_parse_result.mmcif_object,
|
||||
pdb_id=pdb_id,
|
||||
mapping=mapping,
|
||||
template_sequence=template_sequence,
|
||||
query_sequence=query_sequence,
|
||||
template_chain_id=chain_id,
|
||||
kalign_binary_path=kalign_binary_path,
|
||||
_zero_center_positions=True,
|
||||
)
|
||||
curr_features["template_sum_probs"] = [1.0]
|
||||
template_features = {
|
||||
curr_name: template_features.get(curr_name, []) + [curr_item]
|
||||
for curr_name, curr_item in curr_features.items()
|
||||
}
|
||||
warnings = warnings.append(curr_warnings)
|
||||
|
||||
template_features = {
|
||||
template_feature_name: np.stack(
|
||||
template_features[template_feature_name], axis=0
|
||||
).astype(template_feature_type)
|
||||
for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
|
||||
}
|
||||
return TemplateSearchResult(
|
||||
features=template_features, errors=None, warnings=warnings
|
||||
)
|
||||
|
@ -1188,6 +1198,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
|
|||
)
|
||||
|
||||
|
||||
class CustomHitFeaturizer(TemplateHitFeaturizer):
|
||||
"""Featurizer for templates given in folder.
|
||||
Chain of interest has to be chain A and of same residue size as input sequence."""
|
||||
def get_templates(
|
||||
self,
|
||||
query_sequence: str,
|
||||
hits: Sequence[parsers.TemplateHit],
|
||||
) -> TemplateSearchResult:
|
||||
"""Computes the templates for given query sequence (more details above)."""
|
||||
logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
|
||||
return get_custom_template_features(
|
||||
self._mmcif_dir,
|
||||
query_sequence=query_sequence,
|
||||
pdb_id="test",
|
||||
chain_id="A",
|
||||
kalign_binary_path=self._kalign_binary_path,
|
||||
)
|
||||
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
|
||||
def get_templates(
|
||||
self,
|
||||
|
|
|
@ -186,8 +186,15 @@ def main(args):
|
|||
)
|
||||
|
||||
is_multimer = "multimer" in args.config_preset
|
||||
|
||||
if is_multimer:
|
||||
is_custom_template = "use_custom_template" in args
|
||||
if is_custom_template:
|
||||
template_featurizer = templates.CustomHitFeaturizer(
|
||||
mmcif_dir=args.template_mmcif_dir,
|
||||
max_template_date="9999-12-31", # just dummy, not used
|
||||
max_hits=-1, # just dummy, not used
|
||||
kalign_binary_path=args.kalign_binary_path
|
||||
)
|
||||
elif is_multimer:
|
||||
template_featurizer = templates.HmmsearchHitFeaturizer(
|
||||
mmcif_dir=args.template_mmcif_dir,
|
||||
max_template_date=args.max_template_date,
|
||||
|
@ -205,11 +212,9 @@ def main(args):
|
|||
release_dates_path=args.release_dates_path,
|
||||
obsolete_pdbs_path=args.obsolete_pdbs_path
|
||||
)
|
||||
|
||||
data_processor = data_pipeline.DataPipeline(
|
||||
template_featurizer=template_featurizer,
|
||||
)
|
||||
|
||||
if is_multimer:
|
||||
data_processor = data_pipeline.DataPipelineMultimer(
|
||||
monomer_data_pipeline=data_processor,
|
||||
|
@ -222,7 +227,6 @@ def main(args):
|
|||
|
||||
np.random.seed(random_seed)
|
||||
torch.manual_seed(random_seed + 1)
|
||||
|
||||
feature_processor = feature_pipeline.FeaturePipeline(config.data)
|
||||
if not os.path.exists(output_dir_base):
|
||||
os.makedirs(output_dir_base)
|
||||
|
@ -292,7 +296,6 @@ def main(args):
|
|||
)
|
||||
|
||||
feature_dicts[tag] = feature_dict
|
||||
|
||||
processed_feature_dict = feature_processor.process_features(
|
||||
feature_dict, mode='predict', is_multimer=is_multimer
|
||||
)
|
||||
|
@ -379,6 +382,10 @@ if __name__ == "__main__":
|
|||
help="""Path to alignment directory. If provided, alignment computation
|
||||
is skipped and database path arguments are ignored."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_custom_template", action="store_true", default=False,
|
||||
help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_single_seq_mode", action="store_true", default=False,
|
||||
help="""Use single sequence embeddings instead of MSAs."""
|
||||
|
@ -466,5 +473,4 @@ if __name__ == "__main__":
|
|||
"""The model is being run on CPU. Consider specifying
|
||||
--model_device for better performance"""
|
||||
)
|
||||
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue