Refactor run_pretrained_openfold.py a little

This commit is contained in:
Gustaf Ahdritz 2022-06-22 23:30:43 -04:00
parent b4b849af15
commit 1fa6ffab77
1 changed file with 58 additions and 53 deletions

View File

@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return unrelaxed_protein
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor, prediction_dir):
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
data = fp.read()
def parse_fasta(data):
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
return tags, seqs
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if os.path.exists(unrelaxed_output_path):
return
precompute_alignments(tags, seqs, alignment_dir, args)
def generate_feature_dict(
tags,
seqs,
alignment_dir,
data_processor,
args,
):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1:
tag = tags[0]
seq = seqs[0]
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
# Remove temporary FASTA file
os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
return processed_feature_dict, tag, feature_dict
return feature_dict
def load_models_from_command_line(args, config):
@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(
model, path, version=args.model_name
model, path, version=args.config_preset
)
model = model.to(args.model_device)
logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
)
yield model, None
model_version = os.path.basename(
os.path.normpath(args.jax_param_path),
)
model_version = os.path.splitext(model_version)[0]
yield model, model_version
if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config)
@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
model = model.to(args.model_device)
logger.info(
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
)
yield model, checkpoint_basename
if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
@ -311,23 +308,40 @@ def main(args):
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir):
batch_data = generate_batch(
fasta_file,
args.fasta_dir,
alignment_dir,
data_processor,
feature_processor,
prediction_dir)
if batch_data is None:
# this file has already been processed
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
data = fp.read()
tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
# Output already exists
if os.path.exists(unrelaxed_output_path):
continue
batch, tag, feature_dict = batch_data
precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = generate_feature_dict(
tags,
seqs,
alignment_dir,
data_processor,
args,
)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
for model, model_version in load_models_from_command_line(args, config):
working_batch = deepcopy(batch)
out = run_model(model, working_batch, tag, args)
@ -339,21 +353,11 @@ def main(args):
out, working_batch, feature_dict, feature_processor, args
)
output_name = f'{tag}_{args.config_preset}'
if model_version is not None:
output_name = f'{output_name}_{model_version}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
@ -377,6 +381,7 @@ def main(args):
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
if args.save_outputs:
@ -388,6 +393,7 @@ def main(args):
logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@ -413,8 +419,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--config_preset", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
help="""Name of a model config preset defined in openfold/config.py"""
)
parser.add_argument(
"--jax_param_path", type=str, default=None,