diff --git a/environment.yml b/environment.yml index 0dfb9db..b8a2082 100644 --- a/environment.yml +++ b/environment.yml @@ -27,4 +27,5 @@ dependencies: - typing-extensions==3.10.0.2 - pytorch_lightning==1.5.10 - wandb==0.12.21 + - modelcif==0.7 - git+https://github.com/NVIDIA/dllogger.git diff --git a/notebooks/OpenFold.ipynb b/notebooks/OpenFold.ipynb index b816e91..8f00178 100755 --- a/notebooks/OpenFold.ipynb +++ b/notebooks/OpenFold.ipynb @@ -121,10 +121,11 @@ " %env PATH=/opt/conda/bin:{PATH}\n", "\n", " # Install the required versions of all dependencies.\n", + " %shell conda install -y -q conda==4.13.0\n", " %shell conda install -y -q -c conda-forge -c bioconda \\\n", " kalign2=2.04 \\\n", " hhsuite=3.3.0 \\\n", - " python=3.7 \\\n", + " python=3.8 \\\n", " 2>&1 1>/dev/null\n", " %shell pip install -q \\\n", " ml-collections==0.1.0 \\\n", @@ -180,15 +181,12 @@ " %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n", " %shell /usr/bin/python3 -m pip install -q ./openfold\n", "\n", - " if(relax_prediction):\n", - " %shell conda install -y -q -c conda-forge \\\n", - " openmm=7.5.1 \\\n", - " pdbfixer=1.7\n", - " \n", - " # Apply OpenMM patch.\n", - " %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n", - " patch -p0 < /content/openfold/lib/openmm.patch && \\\n", - " popd\n", + " %shell conda install -y -q -c conda-forge openmm=7.5.1\n", + " # Apply OpenMM patch.\n", + " %shell pushd /opt/conda/lib/python3.8/site-packages/ && \\\n", + " patch -p0 < /content/openfold/lib/openmm.patch && \\\n", + " popd\n", + " %shell conda install -y -q -c conda-forge pdbfixer=1.7\n", "\n", " if(weight_set == 'AlphaFold'):\n", " %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n", @@ -222,8 +220,8 @@ "import unittest.mock\n", "import sys\n", "\n", - "sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n", - "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n", + "sys.path.insert(0, '/usr/local/lib/python3.8/site-packages/')\n", + "sys.path.append('/opt/conda/lib/python3.8/site-packages')\n", "\n", "# Allows us to skip installing these packages\n", "unnecessary_modules = [\n", @@ -247,6 +245,14 @@ "import numpy as np\n", "import py3Dmol\n", "import torch\n", + "import shutil\n", + "\n", + "# Prevent shell magic being broken by openmm, prevent this cryptic error:\n", + "# \"NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968\"\n", + "import locale\n", + "def getpreferredencoding(do_setlocale = True):\n", + " return \"UTF-8\"\n", + "locale.getpreferredencoding = getpreferredencoding\n", "\n", "# A filthy hack to avoid slow Linear layer initialization\n", "import openfold.model.primitives\n", @@ -267,9 +273,8 @@ "from openfold.data.tools import jackhmmer\n", "from openfold.model import model\n", "from openfold.np import protein\n", - "if(relax_prediction):\n", - " from openfold.np.relax import relax\n", - " from openfold.np.relax import utils\n", + "from openfold.np.relax import relax\n", + "from openfold.np.relax.utils import overwrite_b_factors\n", "from openfold.utils.import_weights import import_jax_weights_\n", "from openfold.utils.tensor_utils import tensor_tree_map\n", "\n", @@ -571,14 +576,13 @@ " relaxed_pdb, _, _ = amber_relaxer.process(\n", " prot=unrelaxed_proteins[best_model_name]\n", " )\n", - "\n", - " # Write out the prediction\n", - " pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n", - " with open(pred_output_path, 'w') as f:\n", - " f.write(relaxed_pdb)\n", - "\n", " best_pdb = relaxed_pdb\n", "\n", + " # Write out the prediction\n", + " pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n", + " with open(pred_output_path, 'w') as f:\n", + " f.write(best_pdb)\n", + "\n", " pbar.update(n=1) # Finished AMBER relax.\n", "\n", "# Construct multiclass b-factors to indicate confidence bands\n", @@ -590,7 +594,7 @@ " banded_b_factors.append(idx)\n", " break\n", "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n", - "to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n", + "to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)\n", "\n", "# --- Visualise the prediction & confidence ---\n", "show_sidechains = True\n", @@ -688,7 +692,7 @@ "\n", "\n", "# --- Download the predictions ---\n", - "!zip -q -r {output_dir}.zip {output_dir}\n", + "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n", "files.download(f'{output_dir}.zip')" ], "execution_count": null, diff --git a/openfold/model/triangular_multiplicative_update.py b/openfold/model/triangular_multiplicative_update.py index 3885e4c..ff10cea 100644 --- a/openfold/model/triangular_multiplicative_update.py +++ b/openfold/model/triangular_multiplicative_update.py @@ -392,8 +392,13 @@ class TriangleMultiplicativeUpdate(nn.Module): b = mask b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.linear_b_p(z) - - if(is_fp16_enabled()): + + # Prevents overflow of torch.matmul in combine projections in + # reduced-precision modes + a = a / a.std() + b = b / b.std() + + if(is_fp16_enabled()): with torch.cuda.amp.autocast(enabled=False): x = self._combine_projections(a.float(), b.float()) else: diff --git a/openfold/np/protein.py b/openfold/np/protein.py index 6f6ae36..352eb75 100644 --- a/openfold/np/protein.py +++ b/openfold/np/protein.py @@ -23,6 +23,13 @@ import string from openfold.np import residue_constants from Bio.PDB import PDBParser import numpy as np +import modelcif +import modelcif.model +import modelcif.dumper +import modelcif.reference +import modelcif.protocol +import modelcif.alignment +import modelcif.qa_metric FeatureDict = Mapping[str, np.ndarray] @@ -87,8 +94,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: Args: pdb_str: The contents of the pdb file - chain_id: If chain_id is specified (e.g. A), then only that chain is - parsed. Else, all chains are parsed. + chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain + is parsed. Returns: A new `Protein` parsed from the pdb contents. @@ -184,7 +191,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0 ] groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]]) - + atoms = ['N', 'CA', 'C'] aatype = None atom_positions = None @@ -267,7 +274,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str: """ out_pdb_lines = [] lines = pdb_str.split('\n') - + remark = prot.remark if(remark is not None): out_pdb_lines.append(f"REMARK {remark}") @@ -387,7 +394,7 @@ def to_pdb(prot: Protein) -> str: 0 ] # Protein supports only C, N, O, S, this works. charge = "" - + chain_tag = "A" if(chain_index is not None): chain_tag = chain_tags[chain_index[i]] @@ -436,6 +443,134 @@ def to_pdb(prot: Protein) -> str: return '\n'.join(pdb_lines) + '\n' # Add terminating newline. +def to_modelcif(prot: Protein) -> str: + """ + Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates + will be treated as the same polymer entity. But note that if chains differ in modelled regions, + no attempt is made at identifying them as a single polymer entity. + + Args: + prot: The protein to convert to PDB. + + Returns: + ModelCIF string. + """ + + restypes = residue_constants.restypes + ["X"] + atom_types = residue_constants.atom_types + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + chain_index = prot.chain_index + + n = aatype.shape[0] + if chain_index is None: + chain_index = [0 for i in range(n)] + + system = modelcif.System(title='OpenFold prediction') + + # Finding chains and creating entities + seqs = {} + seq = [] + last_chain_idx = None + for i in range(n): + if last_chain_idx is not None and last_chain_idx != chain_index[i]: + seqs[last_chain_idx] = seq + seq = [] + seq.append(restypes[aatype[i]]) + last_chain_idx = chain_index[i] + # finally add the last chain + seqs[last_chain_idx] = seq + + # now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions) + unique_seqs = {} + for chain_idx, seq_list in seqs.items(): + seq = "".join(seq_list) + if seq in unique_seqs: + unique_seqs[seq].append(chain_idx) + else: + unique_seqs[seq] = [chain_idx] + + # adding 1 entity per unique sequence + entities_map = {} + for key, value in unique_seqs.items(): + model_e = modelcif.Entity(key, description='Model subunit') + for chain_idx in value: + entities_map[chain_idx] = model_e + + chain_tags = string.ascii_uppercase + asym_unit_map = {} + for chain_idx in set(chain_index): + # Define the model assembly + chain_id = chain_tags[chain_idx] + asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id) + asym_unit_map[chain_idx] = asym + modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly') + + class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT): + name = "pLDDT" + software = None + description = "Predicted lddt" + + class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT): + name = "pLDDT" + software = None + description = "Global pLDDT, mean of per-residue pLDDTs" + + class _MyModel(modelcif.model.AbInitioModel): + def get_atoms(self): + # Add all atom sites. + for i in range(n): + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + element = atom_name[0] # Protein supports only C, N, O, S, this works. + yield modelcif.model.Atom( + asym_unit=asym_unit_map[chain_index[i]], type_symbol=element, + seq_id=residue_index[i], atom_id=atom_name, + x=pos[0], y=pos[1], z=pos[2], + het=False, biso=b_factor, occupancy=1.00) + + def add_scores(self): + # local scores + plddt_per_residue = {} + for i in range(n): + for mask, b_factor in zip(atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + # add 1 per residue, not 1 per atom + if chain_index[i] not in plddt_per_residue: + # first time a chain index is seen: add the key and start the residue dict + plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor} + if residue_index[i] not in plddt_per_residue[chain_index[i]]: + plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor + plddts = [] + for chain_idx in plddt_per_residue: + for residue_idx in plddt_per_residue[chain_idx]: + plddt = plddt_per_residue[chain_idx][residue_idx] + plddts.append(plddt) + self.qa_metrics.append( + _LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt)) + # global score + self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts)))) + + # Add the model and modeling protocol to the file and write them out: + model = _MyModel(assembly=modeled_assembly, name='Best scoring model') + model.add_scores() + + model_group = modelcif.model.ModelGroup([model], name='All models') + system.model_groups.append(model_group) + + fh = io.StringIO() + modelcif.dumper.write(fh, [system]) + return fh.getvalue() + + def ideal_atom_mask(prot: Protein) -> np.ndarray: """Computes an ideal atom mask. diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py index 2487fe6..c32a44b 100644 --- a/openfold/np/relax/amber_minimize.py +++ b/openfold/np/relax/amber_minimize.py @@ -524,9 +524,6 @@ def run_pipeline( _check_residues_are_well_defined(prot) pdb_string = clean_protein(prot, checks=checks) - # We keep the input around to restore metadata deleted by the relaxer - input_prot = prot - exclude_residues = exclude_residues or [] exclude_residues = set(exclude_residues) violations = np.inf diff --git a/openfold/np/relax/relax.py b/openfold/np/relax/relax.py index 155e379..2711c76 100644 --- a/openfold/np/relax/relax.py +++ b/openfold/np/relax/relax.py @@ -57,7 +57,7 @@ class AmberRelaxation(object): self._use_gpu = use_gpu def process( - self, *, prot: protein.Protein + self, *, prot: protein.Protein, cif_output: bool ) -> Tuple[str, Dict[str, Any], np.ndarray]: """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" out = amber_minimize.run_pipeline( @@ -89,5 +89,11 @@ class AmberRelaxation(object): ] min_pdb = protein.add_pdb_headers(prot, min_pdb) + output_str = min_pdb + if cif_output: + # TODO the model cif will be missing some metadata like headers (PARENTs and + # REMARK with some details of the run, like num of recycles) + final_prot = protein.from_pdb_string(min_pdb) + output_str = protein.to_modelcif(final_prot) - return min_pdb, debug_data, violations + return output_str, debug_data, violations diff --git a/openfold/utils/script_utils.py b/openfold/utils/script_utils.py index c5dfc8a..626dd0a 100644 --- a/openfold/utils/script_utils.py +++ b/openfold/utils/script_utils.py @@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult return unrelaxed_protein -def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name): +def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name, cif_output): amber_relaxer = relax.AmberRelaxation( use_gpu=(model_device != "cpu"), **config.relax, @@ -239,7 +239,8 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out if "cuda" in model_device: device_no = model_device.split(":")[-1] os.environ["CUDA_VISIBLE_DEVICES"] = device_no - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) + # the struct_str will contain either a PDB-format or a ModelCIF format string + struct_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein, cif_output=cif_output) os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices relaxation_time = time.perf_counter() - t @@ -247,10 +248,13 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json")) # Save the relaxed PDB. + suffix = "_relaxed.pdb" + if cif_output: + suffix = "_relaxed.cif" relaxed_output_path = os.path.join( - output_directory, f'{output_name}_relaxed.pdb' + output_directory, f'{output_name}{suffix}' ) with open(relaxed_output_path, 'w') as fp: - fp.write(relaxed_pdb_str) + fp.write(struct_str) logger.info(f"Relaxed output written to {relaxed_output_path}...") \ No newline at end of file diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index 51e7a67..5937524 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -1,6 +1,6 @@ # Copyright 2021 AlQuraishi Laboratory # Copyright 2021 DeepMind Technologies Limited -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".") torch_major_version = int(torch_versions[0]) torch_minor_version = int(torch_versions[1]) if( - torch_major_version > 1 or + torch_major_version > 1 or (torch_major_version == 1 and torch_minor_version >= 12) ): # Gives a large speedup on Ampere-class GPUs @@ -79,7 +79,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer): ) if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)): logger.info(f"Generating alignments for {tag}...") - + os.makedirs(local_alignment_dir) alignment_runner = data_pipeline.AlignmentRunner( @@ -157,8 +157,8 @@ def main(args): config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) - if (args.trace_model): - if (not config.data.predict.fixed_size): + if(args.trace_model): + if(not config.data.predict.fixed_size): raise ValueError( "Tracing requires that fixed_size mode be enabled in the config" ) @@ -230,10 +230,10 @@ def main(args): random_seed = args.data_random_seed if random_seed is None: random_seed = random.randrange(2**32) - + np.random.seed(random_seed) torch.manual_seed(random_seed + 1) - + feature_processor = feature_pipeline.FeaturePipeline(config.data) if not os.path.exists(output_dir_base): os.makedirs(output_dir_base) @@ -249,7 +249,7 @@ def main(args): fasta_path = os.path.join(args.fasta_dir, fasta_file) with open(fasta_path, "r") as fp: data = fp.read() - + tags, seqs = parse_fasta(data) if ((not is_multimer) and len(tags) != 1): @@ -280,10 +280,10 @@ def main(args): output_name = f'{tag}_{args.config_preset}' if args.output_postfix is not None: output_name = f'{output_name}_{args.output_postfix}' - + # Does nothing if the alignments have already been computed precompute_alignments(tags, seqs, alignment_dir, args, is_multimer) - + feature_dict = feature_dicts.get(tag, None) if(feature_dict is None): feature_dict = generate_feature_dict( @@ -308,64 +308,70 @@ def main(args): ) processed_feature_dict = { - k:torch.as_tensor(v, device=args.model_device) + k:torch.as_tensor(v, device=args.model_device) for k,v in processed_feature_dict.items() } + if (args.trace_model): + if (rounded_seqlen > cur_tracing_interval): + logger.info( + f"Tracing model at {rounded_seqlen} residues..." + ) + t = time.perf_counter() + trace_model_(model, processed_feature_dict) + tracing_time = time.perf_counter() - t + logger.info( + f"Tracing time: {tracing_time}" + ) + cur_tracing_interval = rounded_seqlen - if (args.trace_model): - if (rounded_seqlen > cur_tracing_interval): - logger.info( - f"Tracing model at {rounded_seqlen} residues..." - ) - t = time.perf_counter() - trace_model_(model, processed_feature_dict) - tracing_time = time.perf_counter() - t - logger.info( - f"Tracing time: {tracing_time}" - ) - cur_tracing_interval = rounded_seqlen + out = run_model(model, processed_feature_dict, tag, args.output_dir) - out = run_model(model, processed_feature_dict, tag, args.output_dir) - - # Toss out the recycling dimensions --- we don't need them anymore - processed_feature_dict = tensor_tree_map( - lambda x: np.array(x[..., -1].cpu()), - processed_feature_dict - ) - out = tensor_tree_map(lambda x: np.array(x.cpu()), out) - unrelaxed_protein = prep_output( - out, - processed_feature_dict, - feature_dict, - feature_processor, - args.config_preset, - args.multimer_ri_gap, - args.subtract_plddt - ) - - unrelaxed_output_path = os.path.join( - output_directory, f'{output_name}_unrelaxed.pdb' - ) - - with open(unrelaxed_output_path, 'w') as fp: - fp.write(protein.to_pdb(unrelaxed_protein)) - - logger.info(f"Output written to {unrelaxed_output_path}...") - - if not args.skip_relaxation: - # Relax the prediction. - logger.info(f"Running relaxation on {unrelaxed_output_path}...") - relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name) - - if args.save_outputs: - output_dict_path = os.path.join( - output_directory, f'{output_name}_output_dict.pkl' + # Toss out the recycling dimensions --- we don't need them anymore + processed_feature_dict = tensor_tree_map( + lambda x: np.array(x[..., -1].cpu()), + processed_feature_dict ) - with open(output_dict_path, "wb") as fp: - pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) + out = tensor_tree_map(lambda x: np.array(x.cpu()), out) - logger.info(f"Model output written to {output_dict_path}...") + unrelaxed_protein = prep_output( + out, + processed_feature_dict, + feature_dict, + feature_processor, + args.config_preset, + args.multimer_ri_gap, + args.subtract_plddt + ) + + unrelaxed_file_suffix = "_unrelaxed.pdb" + if args.cif_output: + unrelaxed_file_suffix = "_unrelaxed.cif" + unrelaxed_output_path = os.path.join( + output_directory, f'{output_name}{unrelaxed_file_suffix}' + ) + + with open(unrelaxed_output_path, 'w') as fp: + if args.cif_output: + fp.write(protein.to_modelcif(unrelaxed_protein)) + else: + fp.write(protein.to_pdb(unrelaxed_protein)) + + logger.info(f"Output written to {unrelaxed_output_path}...") + + if not args.skip_relaxation: + # Relax the prediction. + logger.info(f"Running relaxation on {unrelaxed_output_path}...") + relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output) + + if args.save_outputs: + output_dict_path = os.path.join( + output_directory, f'{output_name}_output_dict.pkl' + ) + with open(output_dict_path, "wb") as fp: + pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) + + logger.info(f"Model output written to {output_dict_path}...") if __name__ == "__main__": @@ -447,12 +453,16 @@ if __name__ == "__main__": "--long_sequence_inference", action="store_true", default=False, help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details""" ) + parser.add_argument( + "--cif_output", action="store_true", default=False, + help="Output predicted models in ModelCIF format instead of PDB format (default)" + ) add_data_args(parser) args = parser.parse_args() if(args.jax_param_path is None and args.openfold_checkpoint_path is None): args.jax_param_path = os.path.join( - "openfold", "resources", "params", + "openfold", "resources", "params", "params_" + args.config_preset + ".npz" ) diff --git a/scripts/download_cameo.py b/scripts/download_cameo.py index 815a34b..11e5cec 100644 --- a/scripts/download_cameo.py +++ b/scripts/download_cameo.py @@ -57,9 +57,8 @@ def main(args): seq = mmcif_object.chain_to_seqres[chain_id] - if(args.max_seqlen > 0): - if(len(seq) > len(seq)): - continue + if(args.max_seqlen > 0 and len(seq) > args.max_seqlen): + continue fasta_file = '\n'.join([ f">{pdb_id}_{chain_id}", diff --git a/setup.py b/setup.py index af4ac63..a2f3b19 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ import os from setuptools import setup, Extension, find_packages import subprocess +import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from scripts.utils import get_nvidia_cc @@ -37,7 +38,7 @@ extra_cuda_flags = [ ] def get_cuda_bare_metal_version(cuda_dir): - if cuda_dir==None: + if cuda_dir==None or torch.version.cuda==None: print("CUDA is not found, cpu version is installed") return None, -1, 0 else: diff --git a/thread_sequence.py b/thread_sequence.py index e67fcdd..48b409f 100644 --- a/thread_sequence.py +++ b/thread_sequence.py @@ -106,7 +106,7 @@ def main(args): logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Running relaxation on {unrelaxed_output_path}...") - relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name) + relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False) if __name__ == "__main__": parser = argparse.ArgumentParser()