Merge main again
commit 736f27fdc8
@@ -27,4 +27,5 @@ dependencies:
   - typing-extensions==3.10.0.2
   - pytorch_lightning==1.5.10
   - wandb==0.12.21
+  - modelcif==0.7
   - git+https://github.com/NVIDIA/dllogger.git
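The one new entry here is modelcif==0.7, which backs the ModelCIF output added later in this diff. A minimal smoke test of the dependency, using only calls that appear further down:

import io
import modelcif
import modelcif.dumper

# Build an empty ModelCIF system and serialize it, mirroring the pattern
# used by to_modelcif() below.
system = modelcif.System(title='smoke test')
fh = io.StringIO()
modelcif.dumper.write(fh, [system])
print(fh.getvalue().splitlines()[0])  # first line of the mmCIF data block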
@@ -121,10 +121,11 @@
     " %env PATH=/opt/conda/bin:{PATH}\n",
     "\n",
     " # Install the required versions of all dependencies.\n",
     " %shell conda install -y -q conda==4.13.0\n",
     " %shell conda install -y -q -c conda-forge -c bioconda \\\n",
     "     kalign2=2.04 \\\n",
     "     hhsuite=3.3.0 \\\n",
-    "     python=3.7 \\\n",
+    "     python=3.8 \\\n",
     "     2>&1 1>/dev/null\n",
     " %shell pip install -q \\\n",
     "     ml-collections==0.1.0 \\\n",
@@ -180,15 +181,12 @@
     " %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
     " %shell /usr/bin/python3 -m pip install -q ./openfold\n",
     "\n",
-    " if(relax_prediction):\n",
-    "   %shell conda install -y -q -c conda-forge \\\n",
-    "       openmm=7.5.1 \\\n",
-    "       pdbfixer=1.7\n",
-    " \n",
-    " # Apply OpenMM patch.\n",
-    " %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
-    "     patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
-    "     popd\n",
+    " %shell conda install -y -q -c conda-forge openmm=7.5.1\n",
+    " # Apply OpenMM patch.\n",
+    " %shell pushd /opt/conda/lib/python3.8/site-packages/ && \\\n",
+    "     patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
+    "     popd\n",
+    " %shell conda install -y -q -c conda-forge pdbfixer=1.7\n",
     "\n",
     " if(weight_set == 'AlphaFold'):\n",
     "   %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
@@ -222,8 +220,8 @@
     "import unittest.mock\n",
     "import sys\n",
     "\n",
-    "sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
-    "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
+    "sys.path.insert(0, '/usr/local/lib/python3.8/site-packages/')\n",
+    "sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
     "\n",
     "# Allows us to skip installing these packages\n",
     "unnecessary_modules = [\n",
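For context, the `unnecessary_modules` list that follows this hunk feeds a standard Colab trick: registering mocks under module names so later imports succeed without an install. A minimal sketch of the technique (module names here are placeholders, not the notebook's actual list):

import sys
import unittest.mock

# Stub out modules the notebook never actually exercises, so that
# `import` statements referencing them resolve without installation.
unnecessary_modules = ["heavy_dep", "heavy_dep.submodule"]
for module_name in unnecessary_modules:
    sys.modules[module_name] = unittest.mock.MagicMock()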
@@ -247,6 +245,14 @@
     "import numpy as np\n",
     "import py3Dmol\n",
     "import torch\n",
+    "import shutil\n",
+    "\n",
+    "# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
+    "# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
+    "import locale\n",
+    "def getpreferredencoding(do_setlocale = True):\n",
+    "    return \"UTF-8\"\n",
+    "locale.getpreferredencoding = getpreferredencoding\n",
     "\n",
     "# A filthy hack to avoid slow Linear layer initialization\n",
     "import openfold.model.primitives\n",
@@ -267,9 +273,8 @@
     "from openfold.data.tools import jackhmmer\n",
     "from openfold.model import model\n",
     "from openfold.np import protein\n",
-    "if(relax_prediction):\n",
-    "    from openfold.np.relax import relax\n",
-    "    from openfold.np.relax import utils\n",
+    "from openfold.np.relax import relax\n",
+    "from openfold.np.relax.utils import overwrite_b_factors\n",
     "from openfold.utils.import_weights import import_jax_weights_\n",
     "from openfold.utils.tensor_utils import tensor_tree_map\n",
     "\n",
@@ -571,14 +576,13 @@
     "  relaxed_pdb, _, _ = amber_relaxer.process(\n",
     "      prot=unrelaxed_proteins[best_model_name]\n",
     "  )\n",
     "\n",
-    "  # Write out the prediction\n",
-    "  pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
-    "  with open(pred_output_path, 'w') as f:\n",
-    "      f.write(relaxed_pdb)\n",
-    "\n",
+    "  best_pdb = relaxed_pdb\n",
+    "\n",
+    "  # Write out the prediction\n",
+    "  pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
+    "  with open(pred_output_path, 'w') as f:\n",
+    "      f.write(best_pdb)\n",
     "\n",
     "  pbar.update(n=1) # Finished AMBER relax.\n",
     "\n",
     "# Construct multiclass b-factors to indicate confidence bands\n",
@@ -590,7 +594,7 @@
     "        banded_b_factors.append(idx)\n",
     "        break\n",
     "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
-    "to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
+    "to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)\n",
     "\n",
     "# --- Visualise the prediction & confidence ---\n",
     "show_sidechains = True\n",
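The banding loop above maps each residue's pLDDT-style confidence into a small integer band, which is then written into the B-factor column for colouring. A self-contained sketch with illustrative band edges (not necessarily the notebook's constants):

import numpy as np

PLDDT_BANDS = [(0, 50), (50, 70), (70, 90), (90, 100)]  # example edges
plddts = np.array([43.0, 66.5, 88.0, 97.2])

banded_b_factors = []
for score in plddts:
    for idx, (low, high) in enumerate(PLDDT_BANDS):
        if low <= score <= high:
            banded_b_factors.append(idx)
            break
print(banded_b_factors)  # [0, 1, 2, 3]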
@@ -688,7 +692,7 @@
     "\n",
     "\n",
     "# --- Download the predictions ---\n",
-    "!zip -q -r {output_dir}.zip {output_dir}\n",
+    "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
     "files.download(f'{output_dir}.zip')"
    ],
    "execution_count": null,
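The replacement swaps the `zip` shell command for the portable `shutil.make_archive`. Note that `base_name` is the archive path without its extension and that the function returns the path it wrote; a quick standalone check:

import os
import shutil
import tempfile

with tempfile.TemporaryDirectory() as output_dir:
    open(os.path.join(output_dir, "dummy.pdb"), "w").close()
    archive = shutil.make_archive(base_name="prediction", format="zip", root_dir=output_dir)
    print(archive)  # e.g. "prediction.zip" in the current directory
    os.remove(archive)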
@@ -392,8 +392,13 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
+        if(is_fp16_enabled()):
+            # Prevents overflow of torch.matmul in combine projections in
+            # reduced-precision modes
+            a = a / a.std()
+            b = b / b.std()
+
         if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
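Why the extra normalization: matmul accumulates products over the hidden dimension, so activations whose entries individually fit in fp16 can still overflow fp16's ~65504 maximum once summed. A minimal sketch of the guard pattern above (function and names are illustrative, not the module's real API):

import torch

def combine_projections_safe(a: torch.Tensor, b: torch.Tensor, fp16: bool) -> torch.Tensor:
    if fp16:
        # Shrink the operands so the accumulated products stay in range,
        # then leave autocast and do the contraction in fp32.
        a = a / a.std()
        b = b / b.std()
        with torch.cuda.amp.autocast(enabled=False):
            return torch.matmul(a.float(), b.float())
    return torch.matmul(a, b)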
@@ -23,6 +23,13 @@ import string
 from openfold.np import residue_constants
 from Bio.PDB import PDBParser
 import numpy as np
+import modelcif
+import modelcif.model
+import modelcif.dumper
+import modelcif.reference
+import modelcif.protocol
+import modelcif.alignment
+import modelcif.qa_metric
 
 
 FeatureDict = Mapping[str, np.ndarray]
@@ -87,8 +94,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
 
     Args:
         pdb_str: The contents of the pdb file
-        chain_id: If chain_id is specified (e.g. A), then only that chain is
-            parsed. Else, all chains are parsed.
+        chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain
+            is parsed.
 
     Returns:
         A new `Protein` parsed from the pdb contents.
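An illustrative use of the clarified parameter, assuming a PDB-format string `pdb_str` is already in hand:

from openfold.np import protein

prot_chain_a = protein.from_pdb_string(pdb_str, chain_id="A")  # chain A only
prot_full = protein.from_pdb_string(pdb_str)                   # all chains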
@@ -184,7 +191,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
         tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
     ]
     groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
 
     atoms = ['N', 'CA', 'C']
     aatype = None
     atom_positions = None
@@ -267,7 +274,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
     """
     out_pdb_lines = []
     lines = pdb_str.split('\n')
 
     remark = prot.remark
     if(remark is not None):
         out_pdb_lines.append(f"REMARK {remark}")
@@ -387,7 +394,7 @@ def to_pdb(prot: Protein) -> str:
             0
         ] # Protein supports only C, N, O, S, this works.
         charge = ""
 
         chain_tag = "A"
         if(chain_index is not None):
             chain_tag = chain_tags[chain_index[i]]
@@ -436,6 +443,134 @@ def to_pdb(prot: Protein) -> str:
     return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
 
 
+def to_modelcif(prot: Protein) -> str:
+    """
+    Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates
+    will be treated as the same polymer entity. But note that if chains differ in modelled regions,
+    no attempt is made at identifying them as a single polymer entity.
+
+    Args:
+        prot: The protein to convert to ModelCIF.
+
+    Returns:
+        ModelCIF string.
+    """
+
+    restypes = residue_constants.restypes + ["X"]
+    atom_types = residue_constants.atom_types
+
+    atom_mask = prot.atom_mask
+    aatype = prot.aatype
+    atom_positions = prot.atom_positions
+    residue_index = prot.residue_index.astype(np.int32)
+    b_factors = prot.b_factors
+    chain_index = prot.chain_index
+
+    n = aatype.shape[0]
+    if chain_index is None:
+        chain_index = [0 for i in range(n)]
+
+    system = modelcif.System(title='OpenFold prediction')
+
+    # Finding chains and creating entities
+    seqs = {}
+    seq = []
+    last_chain_idx = None
+    for i in range(n):
+        if last_chain_idx is not None and last_chain_idx != chain_index[i]:
+            seqs[last_chain_idx] = seq
+            seq = []
+        seq.append(restypes[aatype[i]])
+        last_chain_idx = chain_index[i]
+    # finally add the last chain
+    seqs[last_chain_idx] = seq
+
+    # now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions)
+    unique_seqs = {}
+    for chain_idx, seq_list in seqs.items():
+        seq = "".join(seq_list)
+        if seq in unique_seqs:
+            unique_seqs[seq].append(chain_idx)
+        else:
+            unique_seqs[seq] = [chain_idx]
+
+    # adding 1 entity per unique sequence
+    entities_map = {}
+    for key, value in unique_seqs.items():
+        model_e = modelcif.Entity(key, description='Model subunit')
+        for chain_idx in value:
+            entities_map[chain_idx] = model_e
+
+    chain_tags = string.ascii_uppercase
+    asym_unit_map = {}
+    for chain_idx in set(chain_index):
+        # Define the model assembly
+        chain_id = chain_tags[chain_idx]
+        asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id)
+        asym_unit_map[chain_idx] = asym
+    modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly')
+
+    class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
+        name = "pLDDT"
+        software = None
+        description = "Predicted lddt"
+
+    class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
+        name = "pLDDT"
+        software = None
+        description = "Global pLDDT, mean of per-residue pLDDTs"
+
+    class _MyModel(modelcif.model.AbInitioModel):
+        def get_atoms(self):
+            # Add all atom sites.
+            for i in range(n):
+                for atom_name, pos, mask, b_factor in zip(
+                    atom_types, atom_positions[i], atom_mask[i], b_factors[i]
+                ):
+                    if mask < 0.5:
+                        continue
+                    element = atom_name[0] # Protein supports only C, N, O, S, this works.
+                    yield modelcif.model.Atom(
+                        asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
+                        seq_id=residue_index[i], atom_id=atom_name,
+                        x=pos[0], y=pos[1], z=pos[2],
+                        het=False, biso=b_factor, occupancy=1.00)
+
+        def add_scores(self):
+            # local scores
+            plddt_per_residue = {}
+            for i in range(n):
+                for mask, b_factor in zip(atom_mask[i], b_factors[i]):
+                    if mask < 0.5:
+                        continue
+                    # add 1 per residue, not 1 per atom
+                    if chain_index[i] not in plddt_per_residue:
+                        # first time a chain index is seen: add the key and start the residue dict
+                        plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
+                    if residue_index[i] not in plddt_per_residue[chain_index[i]]:
+                        plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
+            plddts = []
+            for chain_idx in plddt_per_residue:
+                for residue_idx in plddt_per_residue[chain_idx]:
+                    plddt = plddt_per_residue[chain_idx][residue_idx]
+                    plddts.append(plddt)
+                    self.qa_metrics.append(
+                        _LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt))
+            # global score
+            self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))
+
+    # Add the model and modeling protocol to the file and write them out:
+    model = _MyModel(assembly=modeled_assembly, name='Best scoring model')
+    model.add_scores()
+
+    model_group = modelcif.model.ModelGroup([model], name='All models')
+    system.model_groups.append(model_group)
+
+    fh = io.StringIO()
+    modelcif.dumper.write(fh, [system])
+    return fh.getvalue()
+
+
 def ideal_atom_mask(prot: Protein) -> np.ndarray:
     """Computes an ideal atom mask.
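A minimal round trip through the new function, assuming a PDB-format prediction string `pdb_str` (this mirrors how the relax module uses it below):

from openfold.np import protein

prot = protein.from_pdb_string(pdb_str)
cif_str = protein.to_modelcif(prot)
with open("prediction.cif", "w") as fh:
    fh.write(cif_str)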
@@ -524,9 +524,6 @@ def run_pipeline(
     _check_residues_are_well_defined(prot)
     pdb_string = clean_protein(prot, checks=checks)
 
-    # We keep the input around to restore metadata deleted by the relaxer
-    input_prot = prot
-
     exclude_residues = exclude_residues or []
     exclude_residues = set(exclude_residues)
     violations = np.inf
|
@ -57,7 +57,7 @@ class AmberRelaxation(object):
|
|||
self._use_gpu = use_gpu
|
||||
|
||||
def process(
|
||||
self, *, prot: protein.Protein
|
||||
self, *, prot: protein.Protein, cif_output: bool
|
||||
) -> Tuple[str, Dict[str, Any], np.ndarray]:
|
||||
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
|
||||
out = amber_minimize.run_pipeline(
|
||||
|
@@ -89,5 +89,11 @@ class AmberRelaxation(object):
         ]
 
         min_pdb = protein.add_pdb_headers(prot, min_pdb)
+        output_str = min_pdb
+        if cif_output:
+            # TODO the model cif will be missing some metadata like headers (PARENTs and
+            #  REMARK with some details of the run, like num of recycles)
+            final_prot = protein.from_pdb_string(min_pdb)
+            output_str = protein.to_modelcif(final_prot)
 
-        return min_pdb, debug_data, violations
+        return output_str, debug_data, violations
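With the new keyword, callers choose the serialization at relax time. A sketch of the updated call site (relaxer construction elided; see relax_protein below for the real one):

# cif_output=True makes process() return ModelCIF text instead of PDB text.
struct_str, debug_data, violations = amber_relaxer.process(
    prot=unrelaxed_protein, cif_output=True
)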
@@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
     return unrelaxed_protein
 
 
-def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
+def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name, cif_output):
     amber_relaxer = relax.AmberRelaxation(
         use_gpu=(model_device != "cpu"),
         **config.relax,
@@ -239,7 +239,8 @@ relax_protein(config, model_device, unrelaxed_protein, output_directory, out
     if "cuda" in model_device:
         device_no = model_device.split(":")[-1]
         os.environ["CUDA_VISIBLE_DEVICES"] = device_no
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+    # the struct_str will contain either a PDB-format or a ModelCIF format string
+    struct_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein, cif_output=cif_output)
     os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
     relaxation_time = time.perf_counter() - t
@@ -247,10 +248,13 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out
     update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))
 
     # Save the relaxed PDB.
+    suffix = "_relaxed.pdb"
+    if cif_output:
+        suffix = "_relaxed.cif"
     relaxed_output_path = os.path.join(
-        output_directory, f'{output_name}_relaxed.pdb'
+        output_directory, f'{output_name}{suffix}'
     )
     with open(relaxed_output_path, 'w') as fp:
-        fp.write(relaxed_pdb_str)
+        fp.write(struct_str)
 
     logger.info(f"Relaxed output written to {relaxed_output_path}...")
@@ -1,6 +1,6 @@
 # Copyright 2021 AlQuraishi Laboratory
 # Copyright 2021 DeepMind Technologies Limited
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
 torch_major_version = int(torch_versions[0])
 torch_minor_version = int(torch_versions[1])
 if(
     torch_major_version > 1 or
     (torch_major_version == 1 and torch_minor_version >= 12)
 ):
     # Gives a large speedup on Ampere-class GPUs
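The version gate parses `torch.__version__` and only enables the Ampere-era fast path on torch 1.12 or newer. A self-contained sketch of the same check (the TF32 precision flag used here is an assumption about what the guard protects; this hunk only shows the guard itself):

import torch

major, minor = (int(v) for v in torch.__version__.split(".")[:2])
if major > 1 or (major == 1 and minor >= 12):
    # Assumed example of an Ampere-era speedup knob; available since torch 1.12,
    # which is exactly what the version guard ensures.
    torch.set_float32_matmul_precision("high")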
@@ -79,7 +79,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
     )
     if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
         logger.info(f"Generating alignments for {tag}...")
 
         os.makedirs(local_alignment_dir)
 
         alignment_runner = data_pipeline.AlignmentRunner(
@@ -157,8 +157,8 @@ def main(args):
 
     config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
 
-    if (args.trace_model):
-        if (not config.data.predict.fixed_size):
+    if(args.trace_model):
+        if(not config.data.predict.fixed_size):
             raise ValueError(
                 "Tracing requires that fixed_size mode be enabled in the config"
             )
@@ -230,10 +230,10 @@ def main(args):
     random_seed = args.data_random_seed
     if random_seed is None:
         random_seed = random.randrange(2**32)
 
     np.random.seed(random_seed)
     torch.manual_seed(random_seed + 1)
 
     feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
@@ -249,7 +249,7 @@ def main(args):
         fasta_path = os.path.join(args.fasta_dir, fasta_file)
         with open(fasta_path, "r") as fp:
             data = fp.read()
 
         tags, seqs = parse_fasta(data)
 
         if ((not is_multimer) and len(tags) != 1):
@@ -280,10 +280,10 @@ def main(args):
         output_name = f'{tag}_{args.config_preset}'
         if args.output_postfix is not None:
             output_name = f'{output_name}_{args.output_postfix}'
 
         # Does nothing if the alignments have already been computed
         precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)
 
         feature_dict = feature_dicts.get(tag, None)
         if(feature_dict is None):
             feature_dict = generate_feature_dict(
@@ -308,64 +308,70 @@ def main(args):
         )
 
         processed_feature_dict = {
             k:torch.as_tensor(v, device=args.model_device)
             for k,v in processed_feature_dict.items()
         }
 
         if (args.trace_model):
             if (rounded_seqlen > cur_tracing_interval):
                 logger.info(
                     f"Tracing model at {rounded_seqlen} residues..."
                 )
                 t = time.perf_counter()
                 trace_model_(model, processed_feature_dict)
                 tracing_time = time.perf_counter() - t
                 logger.info(
                     f"Tracing time: {tracing_time}"
                 )
                 cur_tracing_interval = rounded_seqlen
 
         out = run_model(model, processed_feature_dict, tag, args.output_dir)
 
         # Toss out the recycling dimensions --- we don't need them anymore
         processed_feature_dict = tensor_tree_map(
             lambda x: np.array(x[..., -1].cpu()),
             processed_feature_dict
         )
         out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
 
         unrelaxed_protein = prep_output(
             out,
             processed_feature_dict,
             feature_dict,
             feature_processor,
             args.config_preset,
             args.multimer_ri_gap,
             args.subtract_plddt
         )
 
-        unrelaxed_output_path = os.path.join(
-            output_directory, f'{output_name}_unrelaxed.pdb'
-        )
+        unrelaxed_file_suffix = "_unrelaxed.pdb"
+        if args.cif_output:
+            unrelaxed_file_suffix = "_unrelaxed.cif"
+        unrelaxed_output_path = os.path.join(
+            output_directory, f'{output_name}{unrelaxed_file_suffix}'
+        )
 
         with open(unrelaxed_output_path, 'w') as fp:
-            fp.write(protein.to_pdb(unrelaxed_protein))
+            if args.cif_output:
+                fp.write(protein.to_modelcif(unrelaxed_protein))
+            else:
+                fp.write(protein.to_pdb(unrelaxed_protein))
 
         logger.info(f"Output written to {unrelaxed_output_path}...")
 
         if not args.skip_relaxation:
             # Relax the prediction.
             logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-            relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
+            relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)
 
         if args.save_outputs:
            output_dict_path = os.path.join(
                 output_directory, f'{output_name}_output_dict.pkl'
             )
             with open(output_dict_path, "wb") as fp:
                 pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
 
             logger.info(f"Model output written to {output_dict_path}...")
 
 
 if __name__ == "__main__":
@@ -447,12 +453,16 @@ if __name__ == "__main__":
         "--long_sequence_inference", action="store_true", default=False,
         help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
     )
+    parser.add_argument(
+        "--cif_output", action="store_true", default=False,
+        help="Output predicted models in ModelCIF format instead of PDB format (default)"
+    )
     add_data_args(parser)
     args = parser.parse_args()
 
     if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
         args.jax_param_path = os.path.join(
             "openfold", "resources", "params",
             "params_" + args.config_preset + ".npz"
         )
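Tying the new flag together, a hypothetical invocation of the inference script (paths are placeholders; the script name assumes OpenFold's standard run_pretrained_openfold.py entry point):

python3 run_pretrained_openfold.py \
    fasta_dir/ data/pdb_mmcif/mmcif_files/ \
    --output_dir ./out \
    --model_device cuda:0 \
    --cif_output  # write *_unrelaxed.cif / *_relaxed.cif instead of .pdb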
@@ -57,9 +57,8 @@ def main(args):
 
         seq = mmcif_object.chain_to_seqres[chain_id]
 
-        if(args.max_seqlen > 0):
-            if(len(seq) > len(seq)):
-                continue
+        if(args.max_seqlen > 0 and len(seq) > args.max_seqlen):
+            continue
 
         fasta_file = '\n'.join([
             f">{pdb_id}_{chain_id}",
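The old guard compared `len(seq)` to itself and so never skipped anything; the rewritten predicate actually enforces the cap. A tiny check of the fixed logic with illustrative values:

max_seqlen = 256
for seq in ("A" * 100, "B" * 512):
    skipped = max_seqlen > 0 and len(seq) > max_seqlen
    print(len(seq), "skipped" if skipped else "kept")  # 100 kept, 512 skipped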
setup.py
@@ -16,6 +16,7 @@ import os
 from setuptools import setup, Extension, find_packages
 import subprocess
 
+import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 
 from scripts.utils import get_nvidia_cc
@@ -37,7 +38,7 @@ extra_cuda_flags = [
 ]
 
 def get_cuda_bare_metal_version(cuda_dir):
-    if cuda_dir==None:
+    if cuda_dir==None or torch.version.cuda==None:
         print("CUDA is not found, cpu version is installed")
         return None, -1, 0
     else:
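The widened guard treats a CPU-only torch build (torch.version.cuda is None) the same as a missing CUDA toolkit. A sketch of how a build script might branch on the sentinel return value (names mirror the hunk; the branch body is an assumption, not setup.py's actual logic):

import torch
from torch.utils.cpp_extension import CUDA_HOME

raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_major == -1:
    # No usable CUDA: fall back to CPU-only extensions.
    use_cuda_extensions = False
else:
    use_cuda_extensions = True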
@@ -106,7 +106,7 @@ def main(args):
     logger.info(f"Output written to {unrelaxed_output_path}...")
 
     logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-    relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
+    relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()