Merge main again
Commit 736f27fdc8
@@ -27,4 +27,5 @@ dependencies:
   - typing-extensions==3.10.0.2
   - pytorch_lightning==1.5.10
   - wandb==0.12.21
+  - modelcif==0.7
   - git+https://github.com/NVIDIA/dllogger.git
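The only new dependency is modelcif, which backs the ModelCIF export added in the to_modelcif hunk further down. A minimal sketch of the parts of its API this merge relies on (signatures inferred from that hunk; treat them as assumptions about modelcif 0.7):

import io
import modelcif
import modelcif.model
import modelcif.dumper

# One entity (sequence), one asym unit (chain), one ab initio model.
system = modelcif.System(title='toy prediction')
entity = modelcif.Entity('ACDE', description='Model subunit')
asym = modelcif.AsymUnit(entity, details='Model subunit A', id='A')
assembly = modelcif.Assembly([asym], name='Modeled assembly')

class ToyModel(modelcif.model.AbInitioModel):
    def get_atoms(self):
        # A single C-alpha atom is enough to produce a valid file.
        yield modelcif.model.Atom(
            asym_unit=asym, type_symbol='C', seq_id=1, atom_id='CA',
            x=0.0, y=0.0, z=0.0, het=False, biso=50.0, occupancy=1.0)

system.model_groups.append(
    modelcif.model.ModelGroup([ToyModel(assembly=assembly, name='toy')],
                              name='All models'))

fh = io.StringIO()
modelcif.dumper.write(fh, [system])
print(fh.getvalue().splitlines()[0])  # first line of the CIF output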
@@ -121,10 +121,11 @@
 " %env PATH=/opt/conda/bin:{PATH}\n",
 "\n",
 " # Install the required versions of all dependencies.\n",
+" %shell conda install -y -q conda==4.13.0\n",
 " %shell conda install -y -q -c conda-forge -c bioconda \\\n",
 " kalign2=2.04 \\\n",
 " hhsuite=3.3.0 \\\n",
-" python=3.7 \\\n",
+" python=3.8 \\\n",
 " 2>&1 1>/dev/null\n",
 " %shell pip install -q \\\n",
 " ml-collections==0.1.0 \\\n",
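The Python bump from 3.7 to 3.8 has to be mirrored in every hardcoded site-packages path in the notebook (the pushd hunk below and the sys.path edits further down). A small sketch of one way to avoid the hardcoding; sys.version_info is standard library, the /opt/conda prefix is an assumption about the Colab image:

import sys

# Derive the interpreter-specific site-packages path instead of
# spelling out "python3.8" in several cells.
conda_site_packages = (
    f"/opt/conda/lib/python{sys.version_info.major}."
    f"{sys.version_info.minor}/site-packages"
)
print(conda_site_packages)  # e.g. /opt/conda/lib/python3.8/site-packages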
@@ -180,15 +181,12 @@
 " %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
 " %shell /usr/bin/python3 -m pip install -q ./openfold\n",
 "\n",
-" if(relax_prediction):\n",
-" %shell conda install -y -q -c conda-forge \\\n",
-" openmm=7.5.1 \\\n",
-" pdbfixer=1.7\n",
-" \n",
-" # Apply OpenMM patch.\n",
-" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
-" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
-" popd\n",
+" %shell conda install -y -q -c conda-forge openmm=7.5.1\n",
+" # Apply OpenMM patch.\n",
+" %shell pushd /opt/conda/lib/python3.8/site-packages/ && \\\n",
+" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
+" popd\n",
+" %shell conda install -y -q -c conda-forge pdbfixer=1.7\n",
 "\n",
 " if(weight_set == 'AlphaFold'):\n",
 " %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
@@ -222,8 +220,8 @@
 "import unittest.mock\n",
 "import sys\n",
 "\n",
-"sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
-"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
+"sys.path.insert(0, '/usr/local/lib/python3.8/site-packages/')\n",
+"sys.path.append('/opt/conda/lib/python3.8/site-packages')\n",
 "\n",
 "# Allows us to skip installing these packages\n",
 "unnecessary_modules = [\n",
@@ -247,6 +245,14 @@
 "import numpy as np\n",
 "import py3Dmol\n",
 "import torch\n",
+"import shutil\n",
+"\n",
+"# Prevent shell magic being broken by openmm, prevent this cryptic error:\n",
+"# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n",
+"import locale\n",
+"def getpreferredencoding(do_setlocale = True):\n",
+" return \"UTF-8\"\n",
+"locale.getpreferredencoding = getpreferredencoding\n",
 "\n",
 "# A filthy hack to avoid slow Linear layer initialization\n",
 "import openfold.model.primitives\n",
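The locale shim above works around openmm resetting the C locale, which breaks Colab's %shell magic with the quoted NotImplementedError. The same monkeypatch as plain Python, runnable outside the notebook:

import locale

def getpreferredencoding(do_setlocale=True):
    # Report UTF-8 no matter what the C locale was set to.
    return "UTF-8"

locale.getpreferredencoding = getpreferredencoding
assert locale.getpreferredencoding() == "UTF-8"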
@@ -267,9 +273,8 @@
 "from openfold.data.tools import jackhmmer\n",
 "from openfold.model import model\n",
 "from openfold.np import protein\n",
-"if(relax_prediction):\n",
-" from openfold.np.relax import relax\n",
-" from openfold.np.relax import utils\n",
+"from openfold.np.relax import relax\n",
+"from openfold.np.relax.utils import overwrite_b_factors\n",
 "from openfold.utils.import_weights import import_jax_weights_\n",
 "from openfold.utils.tensor_utils import tensor_tree_map\n",
 "\n",
@@ -571,14 +576,13 @@
 " relaxed_pdb, _, _ = amber_relaxer.process(\n",
 " prot=unrelaxed_proteins[best_model_name]\n",
 " )\n",
-"\n",
-" # Write out the prediction\n",
-" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
-" with open(pred_output_path, 'w') as f:\n",
-" f.write(relaxed_pdb)\n",
-"\n",
 " best_pdb = relaxed_pdb\n",
 "\n",
+" # Write out the prediction\n",
+" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
+" with open(pred_output_path, 'w') as f:\n",
+" f.write(best_pdb)\n",
+"\n",
 " pbar.update(n=1) # Finished AMBER relax.\n",
 "\n",
 "# Construct multiclass b-factors to indicate confidence bands\n",
@@ -590,7 +594,7 @@
 " banded_b_factors.append(idx)\n",
 " break\n",
 "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
-"to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
+"to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)\n",
 "\n",
 "# --- Visualise the prediction & confidence ---\n",
 "show_sidechains = True\n",
@@ -688,7 +692,7 @@
 "\n",
 "\n",
 "# --- Download the predictions ---\n",
-"!zip -q -r {output_dir}.zip {output_dir}\n",
+"shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
 "files.download(f'{output_dir}.zip')"
 ],
 "execution_count": null,
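Swapping the !zip shell-out for shutil.make_archive drops the dependency on a zip binary in the Colab image. Note that base_name='prediction' writes prediction.zip while the unchanged download line still references {output_dir}.zip; a sketch of one way to keep the two names aligned (plain Python, output_dir as a placeholder):

import shutil

output_dir = "prediction"  # hypothetical output directory

# Name the archive after the directory so a download of
# f"{output_dir}.zip" finds the file that was just written.
archive_path = shutil.make_archive(
    base_name=output_dir, format="zip", root_dir=output_dir
)
print(archive_path)  # prediction.zip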
@@ -392,8 +392,13 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
+        # Prevents overflow of torch.matmul in combine projections in
+        # reduced-precision modes
+        a = a / a.std()
+        b = b / b.std()
+
         if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
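Dividing a and b by their standard deviation keeps the subsequent matmul inside fp16's representable range (largest finite value: 65504). A toy demonstration of the failure mode and the rescaling, independent of OpenFold's module:

import torch

torch.manual_seed(0)
a = (torch.randn(256).abs() + 0.1) * 60.0
b = (torch.randn(256).abs() + 0.1) * 60.0

# Every term of this dot product is positive and O(10^3), so the fp16
# accumulation sails past the format's max finite value of 65504.
print((a.half() * b.half()).sum())        # tensor(inf, dtype=torch.float16)

# Rescaling by the std first, as in the hunk above, keeps it finite.
a_s, b_s = a / a.std(), b / b.std()
print((a_s.half() * b_s.half()).sum())    # finite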
@@ -23,6 +23,13 @@ import string
 from openfold.np import residue_constants
 from Bio.PDB import PDBParser
 import numpy as np
+import modelcif
+import modelcif.model
+import modelcif.dumper
+import modelcif.reference
+import modelcif.protocol
+import modelcif.alignment
+import modelcif.qa_metric
 
 
 FeatureDict = Mapping[str, np.ndarray]
@@ -87,8 +94,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
 
     Args:
         pdb_str: The contents of the pdb file
-        chain_id: If chain_id is specified (e.g. A), then only that chain is
-            parsed. Else, all chains are parsed.
+        chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain
+            is parsed.
 
     Returns:
         A new `Protein` parsed from the pdb contents.
@@ -184,7 +191,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
         tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
     ]
     groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
 
     atoms = ['N', 'CA', 'C']
     aatype = None
     atom_positions = None
@@ -267,7 +274,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
     """
     out_pdb_lines = []
     lines = pdb_str.split('\n')
 
     remark = prot.remark
     if(remark is not None):
         out_pdb_lines.append(f"REMARK {remark}")
@@ -387,7 +394,7 @@ def to_pdb(prot: Protein) -> str:
             0
         ]  # Protein supports only C, N, O, S, this works.
         charge = ""
 
         chain_tag = "A"
         if(chain_index is not None):
             chain_tag = chain_tags[chain_index[i]]
@@ -436,6 +443,134 @@ def to_pdb(prot: Protein) -> str:
     return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
 
 
+def to_modelcif(prot: Protein) -> str:
+    """
+    Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates
+    will be treated as the same polymer entity. But note that if chains differ in modelled regions,
+    no attempt is made at identifying them as a single polymer entity.
+
+    Args:
+        prot: The protein to convert to PDB.
+
+    Returns:
+        ModelCIF string.
+    """
+
+    restypes = residue_constants.restypes + ["X"]
+    atom_types = residue_constants.atom_types
+
+    atom_mask = prot.atom_mask
+    aatype = prot.aatype
+    atom_positions = prot.atom_positions
+    residue_index = prot.residue_index.astype(np.int32)
+    b_factors = prot.b_factors
+    chain_index = prot.chain_index
+
+    n = aatype.shape[0]
+    if chain_index is None:
+        chain_index = [0 for i in range(n)]
+
+    system = modelcif.System(title='OpenFold prediction')
+
+    # Finding chains and creating entities
+    seqs = {}
+    seq = []
+    last_chain_idx = None
+    for i in range(n):
+        if last_chain_idx is not None and last_chain_idx != chain_index[i]:
+            seqs[last_chain_idx] = seq
+            seq = []
+        seq.append(restypes[aatype[i]])
+        last_chain_idx = chain_index[i]
+    # finally add the last chain
+    seqs[last_chain_idx] = seq
+
+    # now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions)
+    unique_seqs = {}
+    for chain_idx, seq_list in seqs.items():
+        seq = "".join(seq_list)
+        if seq in unique_seqs:
+            unique_seqs[seq].append(chain_idx)
+        else:
+            unique_seqs[seq] = [chain_idx]
+
+    # adding 1 entity per unique sequence
+    entities_map = {}
+    for key, value in unique_seqs.items():
+        model_e = modelcif.Entity(key, description='Model subunit')
+        for chain_idx in value:
+            entities_map[chain_idx] = model_e
+
+    chain_tags = string.ascii_uppercase
+    asym_unit_map = {}
+    for chain_idx in set(chain_index):
+        # Define the model assembly
+        chain_id = chain_tags[chain_idx]
+        asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id)
+        asym_unit_map[chain_idx] = asym
+    modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly')
+
+    class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
+        name = "pLDDT"
+        software = None
+        description = "Predicted lddt"
+
+    class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
+        name = "pLDDT"
+        software = None
+        description = "Global pLDDT, mean of per-residue pLDDTs"
+
+    class _MyModel(modelcif.model.AbInitioModel):
+        def get_atoms(self):
+            # Add all atom sites.
+            for i in range(n):
+                for atom_name, pos, mask, b_factor in zip(
+                    atom_types, atom_positions[i], atom_mask[i], b_factors[i]
+                ):
+                    if mask < 0.5:
+                        continue
+                    element = atom_name[0]  # Protein supports only C, N, O, S, this works.
+                    yield modelcif.model.Atom(
+                        asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
+                        seq_id=residue_index[i], atom_id=atom_name,
+                        x=pos[0], y=pos[1], z=pos[2],
+                        het=False, biso=b_factor, occupancy=1.00)
+
+        def add_scores(self):
+            # local scores
+            plddt_per_residue = {}
+            for i in range(n):
+                for mask, b_factor in zip(atom_mask[i], b_factors[i]):
+                    if mask < 0.5:
+                        continue
+                    # add 1 per residue, not 1 per atom
+                    if chain_index[i] not in plddt_per_residue:
+                        # first time a chain index is seen: add the key and start the residue dict
+                        plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
+                    if residue_index[i] not in plddt_per_residue[chain_index[i]]:
+                        plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
+            plddts = []
+            for chain_idx in plddt_per_residue:
+                for residue_idx in plddt_per_residue[chain_idx]:
+                    plddt = plddt_per_residue[chain_idx][residue_idx]
+                    plddts.append(plddt)
+                    self.qa_metrics.append(
+                        _LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt))
+            # global score
+            self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))
+
+    # Add the model and modeling protocol to the file and write them out:
+    model = _MyModel(assembly=modeled_assembly, name='Best scoring model')
+    model.add_scores()
+
+    model_group = modelcif.model.ModelGroup([model], name='All models')
+    system.model_groups.append(model_group)
+
+    fh = io.StringIO()
+    modelcif.dumper.write(fh, [system])
+    return fh.getvalue()
+
+
 def ideal_atom_mask(prot: Protein) -> np.ndarray:
     """Computes an ideal atom mask.
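A quick usage sketch for the new converter, round-tripping a structure through the Protein dataclass (file names are hypothetical):

from openfold.np import protein

with open("prediction.pdb") as f:
    prot = protein.from_pdb_string(f.read())

with open("prediction.cif", "w") as f:
    f.write(protein.to_modelcif(prot))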
@@ -524,9 +524,6 @@ def run_pipeline(
     _check_residues_are_well_defined(prot)
     pdb_string = clean_protein(prot, checks=checks)
 
-    # We keep the input around to restore metadata deleted by the relaxer
-    input_prot = prot
-
     exclude_residues = exclude_residues or []
     exclude_residues = set(exclude_residues)
     violations = np.inf
@@ -57,7 +57,7 @@ class AmberRelaxation(object):
         self._use_gpu = use_gpu
 
     def process(
-        self, *, prot: protein.Protein
+        self, *, prot: protein.Protein, cif_output: bool
     ) -> Tuple[str, Dict[str, Any], np.ndarray]:
         """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
         out = amber_minimize.run_pipeline(
@@ -89,5 +89,11 @@
         ]
 
         min_pdb = protein.add_pdb_headers(prot, min_pdb)
+        output_str = min_pdb
+        if cif_output:
+            # TODO the model cif will be missing some metadata like headers (PARENTs and
+            # REMARK with some details of the run, like num of recycles)
+            final_prot = protein.from_pdb_string(min_pdb)
+            output_str = protein.to_modelcif(final_prot)
 
-        return min_pdb, debug_data, violations
+        return output_str, debug_data, violations
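process() now returns either PDB text or ModelCIF text in the first tuple slot, chosen per call. A sketch of the call site (amber_relaxer built as in relax_protein below; variables are illustrative):

# struct_str holds PDB text when cif_output=False, ModelCIF text otherwise.
struct_str, debug_data, violations = amber_relaxer.process(
    prot=unrelaxed_protein, cif_output=True
)
with open("prediction_relaxed.cif", "w") as fp:  # hypothetical path
    fp.write(struct_str)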
@@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
     return unrelaxed_protein
 
 
-def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
+def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name, cif_output):
     amber_relaxer = relax.AmberRelaxation(
         use_gpu=(model_device != "cpu"),
         **config.relax,
@@ -239,7 +239,8 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out
     if "cuda" in model_device:
         device_no = model_device.split(":")[-1]
         os.environ["CUDA_VISIBLE_DEVICES"] = device_no
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+    # the struct_str will contain either a PDB-format or a ModelCIF format string
+    struct_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein, cif_output=cif_output)
     os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
     relaxation_time = time.perf_counter() - t
@@ -247,10 +248,13 @@
     update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))
 
     # Save the relaxed PDB.
+    suffix = "_relaxed.pdb"
+    if cif_output:
+        suffix = "_relaxed.cif"
     relaxed_output_path = os.path.join(
-        output_directory, f'{output_name}_relaxed.pdb'
+        output_directory, f'{output_name}{suffix}'
    )
     with open(relaxed_output_path, 'w') as fp:
-        fp.write(relaxed_pdb_str)
+        fp.write(struct_str)
 
     logger.info(f"Relaxed output written to {relaxed_output_path}...")
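The suffix selection could equally be a one-line conditional expression; the if-form above matches the surrounding style:

suffix = "_relaxed.cif" if cif_output else "_relaxed.pdb"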
@@ -1,6 +1,6 @@
 # Copyright 2021 AlQuraishi Laboratory
 # Copyright 2021 DeepMind Technologies Limited
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
 torch_major_version = int(torch_versions[0])
 torch_minor_version = int(torch_versions[1])
 if(
     torch_major_version > 1 or
     (torch_major_version == 1 and torch_minor_version >= 12)
 ):
     # Gives a large speedup on Ampere-class GPUs
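The body of this if falls outside the hunk; presumably it calls torch.set_float32_matmul_precision("high"), the TF32-style switch introduced in PyTorch 1.12 that the comment alludes to. The same version gate written against packaging instead of manual string splitting (an alternative sketch, not the repo's code):

import torch
from packaging import version

if version.parse(torch.__version__.split("+")[0]) >= version.parse("1.12"):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")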
@@ -79,7 +79,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
     )
     if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
         logger.info(f"Generating alignments for {tag}...")
 
         os.makedirs(local_alignment_dir)
 
         alignment_runner = data_pipeline.AlignmentRunner(
@@ -157,8 +157,8 @@ def main(args):
 
     config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
 
-    if (args.trace_model):
-        if (not config.data.predict.fixed_size):
+    if(args.trace_model):
+        if(not config.data.predict.fixed_size):
             raise ValueError(
                 "Tracing requires that fixed_size mode be enabled in the config"
             )
@@ -230,10 +230,10 @@ def main(args):
     random_seed = args.data_random_seed
     if random_seed is None:
         random_seed = random.randrange(2**32)
 
     np.random.seed(random_seed)
     torch.manual_seed(random_seed + 1)
 
     feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
@@ -249,7 +249,7 @@ def main(args):
         fasta_path = os.path.join(args.fasta_dir, fasta_file)
         with open(fasta_path, "r") as fp:
             data = fp.read()
 
         tags, seqs = parse_fasta(data)
 
         if ((not is_multimer) and len(tags) != 1):
@@ -280,10 +280,10 @@ def main(args):
         output_name = f'{tag}_{args.config_preset}'
         if args.output_postfix is not None:
             output_name = f'{output_name}_{args.output_postfix}'
 
         # Does nothing if the alignments have already been computed
         precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)
 
         feature_dict = feature_dicts.get(tag, None)
         if(feature_dict is None):
             feature_dict = generate_feature_dict(
@@ -308,64 +308,70 @@ def main(args):
         )
 
         processed_feature_dict = {
             k:torch.as_tensor(v, device=args.model_device)
             for k,v in processed_feature_dict.items()
         }
 
         if (args.trace_model):
             if (rounded_seqlen > cur_tracing_interval):
                 logger.info(
                     f"Tracing model at {rounded_seqlen} residues..."
                 )
                 t = time.perf_counter()
                 trace_model_(model, processed_feature_dict)
                 tracing_time = time.perf_counter() - t
                 logger.info(
                     f"Tracing time: {tracing_time}"
                 )
                 cur_tracing_interval = rounded_seqlen
 
         out = run_model(model, processed_feature_dict, tag, args.output_dir)
 
         # Toss out the recycling dimensions --- we don't need them anymore
         processed_feature_dict = tensor_tree_map(
             lambda x: np.array(x[..., -1].cpu()),
             processed_feature_dict
         )
         out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
 
         unrelaxed_protein = prep_output(
             out,
             processed_feature_dict,
             feature_dict,
             feature_processor,
             args.config_preset,
             args.multimer_ri_gap,
             args.subtract_plddt
         )
 
-        unrelaxed_output_path = os.path.join(
-            output_directory, f'{output_name}_unrelaxed.pdb'
-        )
+        unrelaxed_file_suffix = "_unrelaxed.pdb"
+        if args.cif_output:
+            unrelaxed_file_suffix = "_unrelaxed.cif"
+        unrelaxed_output_path = os.path.join(
+            output_directory, f'{output_name}{unrelaxed_file_suffix}'
+        )
 
         with open(unrelaxed_output_path, 'w') as fp:
-            fp.write(protein.to_pdb(unrelaxed_protein))
+            if args.cif_output:
+                fp.write(protein.to_modelcif(unrelaxed_protein))
+            else:
+                fp.write(protein.to_pdb(unrelaxed_protein))
 
         logger.info(f"Output written to {unrelaxed_output_path}...")
 
         if not args.skip_relaxation:
             # Relax the prediction.
             logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-            relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
+            relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)
 
         if args.save_outputs:
             output_dict_path = os.path.join(
                 output_directory, f'{output_name}_output_dict.pkl'
             )
             with open(output_dict_path, "wb") as fp:
                 pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
 
             logger.info(f"Model output written to {output_dict_path}...")
 
 
 if __name__ == "__main__":
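Both output branches above pick between protein.to_modelcif and protein.to_pdb. A small alternative that makes the choice once (a sketch, not the repo's code):

writer = protein.to_modelcif if args.cif_output else protein.to_pdb
with open(unrelaxed_output_path, 'w') as fp:
    fp.write(writer(unrelaxed_protein))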
@@ -447,12 +453,16 @@ if __name__ == "__main__":
         "--long_sequence_inference", action="store_true", default=False,
         help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
     )
+    parser.add_argument(
+        "--cif_output", action="store_true", default=False,
+        help="Output predicted models in ModelCIF format instead of PDB format (default)"
+    )
     add_data_args(parser)
     args = parser.parse_args()
 
     if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
         args.jax_param_path = os.path.join(
             "openfold", "resources", "params",
             "params_" + args.config_preset + ".npz"
         )
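Because the flag is store_true with default=False, existing command lines keep producing PDB output. A self-contained check of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cif_output", action="store_true", default=False,
    help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
print(parser.parse_args([]).cif_output)                # False -> PDB
print(parser.parse_args(["--cif_output"]).cif_output)  # True  -> ModelCIF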
@@ -57,9 +57,8 @@ def main(args):
 
         seq = mmcif_object.chain_to_seqres[chain_id]
 
-        if(args.max_seqlen > 0):
-            if(len(seq) > len(seq)):
-                continue
+        if(args.max_seqlen > 0 and len(seq) > args.max_seqlen):
+            continue
 
         fasta_file = '\n'.join([
             f">{pdb_id}_{chain_id}",
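This hunk is a genuine bug fix rather than a cleanup: the old inner guard compares len(seq) to itself, so it can never be true and --max_seqlen silently filtered nothing. A two-line check of both guards (a max_seqlen of 50 is an arbitrary example):

seq = "A" * 100
assert not (len(seq) > len(seq))  # old condition: always False
assert len(seq) > 50              # new condition with max_seqlen = 50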
setup.py (3 changed lines):
@@ -16,6 +16,7 @@ import os
 from setuptools import setup, Extension, find_packages
 import subprocess
 
+import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 
 from scripts.utils import get_nvidia_cc
@@ -37,7 +38,7 @@ extra_cuda_flags = [
 ]
 
 def get_cuda_bare_metal_version(cuda_dir):
-    if cuda_dir==None:
+    if cuda_dir==None or torch.version.cuda==None:
         print("CUDA is not found, cpu version is installed")
         return None, -1, 0
     else:
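CUDA_HOME only reports that a CUDA toolkit exists on the machine; torch.version.cuda reports whether the installed torch wheel was built against CUDA. Requiring both, as the new condition does, avoids attempting a CUDA extension build on a CPU-only wheel. A standalone probe of the same predicate:

import torch
from torch.utils.cpp_extension import CUDA_HOME

# Build CUDA extensions only when both the toolkit and a
# CUDA-enabled torch build are present.
can_build_cuda = CUDA_HOME is not None and torch.version.cuda is not None
print(can_build_cuda)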
@@ -106,7 +106,7 @@ def main(args):
         logger.info(f"Output written to {unrelaxed_output_path}...")
 
         logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
+        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()