Refactor run_pretrained_openfold.py a little
This commit is contained in:
parent
b4b849af15
commit
1fa6ffab77
|
@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
|
|||
return unrelaxed_protein
|
||||
|
||||
|
||||
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor, prediction_dir):
|
||||
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
|
||||
data = fp.read()
|
||||
|
||||
def parse_fasta(data):
|
||||
lines = [
|
||||
l.replace('\n', '')
|
||||
for prot in data.split('>') for l in prot.strip().split('\n', 1)
|
||||
][1:]
|
||||
l.replace('\n', '')
|
||||
for prot in data.split('>') for l in prot.strip().split('\n', 1)
|
||||
][1:]
|
||||
tags, seqs = lines[::2], lines[1::2]
|
||||
|
||||
tags = [t.split()[0] for t in tags]
|
||||
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
|
||||
tag = '-'.join(tags)
|
||||
|
||||
output_name = f'{tag}_{args.config_preset}'
|
||||
if args.output_postfix is not None:
|
||||
output_name = f'{output_name}_{args.output_postfix}'
|
||||
return tags, seqs
|
||||
|
||||
# Save the unrelaxed PDB.
|
||||
unrelaxed_output_path = os.path.join(
|
||||
prediction_dir, f'{output_name}_unrelaxed.pdb'
|
||||
)
|
||||
|
||||
if os.path.exists(unrelaxed_output_path):
|
||||
return
|
||||
|
||||
precompute_alignments(tags, seqs, alignment_dir, args)
|
||||
|
||||
def generate_feature_dict(
|
||||
tags,
|
||||
seqs,
|
||||
alignment_dir,
|
||||
data_processor,
|
||||
args,
|
||||
):
|
||||
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
|
||||
if len(seqs) == 1:
|
||||
tag = tags[0]
|
||||
seq = seqs[0]
|
||||
with open(tmp_fasta_path, "w") as fp:
|
||||
fp.write(f">{tag}\n{seq}")
|
||||
|
@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
|
|||
# Remove temporary FASTA file
|
||||
os.remove(tmp_fasta_path)
|
||||
|
||||
processed_feature_dict = feature_processor.process_features(
|
||||
feature_dict, mode='predict',
|
||||
)
|
||||
return processed_feature_dict, tag, feature_dict
|
||||
return feature_dict
|
||||
|
||||
|
||||
def load_models_from_command_line(args, config):
|
||||
|
@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
|
|||
model = AlphaFold(config)
|
||||
model = model.eval()
|
||||
import_jax_weights_(
|
||||
model, path, version=args.model_name
|
||||
model, path, version=args.config_preset
|
||||
)
|
||||
model = model.to(args.model_device)
|
||||
logger.info(
|
||||
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
|
||||
)
|
||||
yield model, None
|
||||
model_version = os.path.basename(
|
||||
os.path.normpath(args.jax_param_path),
|
||||
)
|
||||
model_version = os.path.splitext(model_version)[0]
|
||||
yield model, model_version
|
||||
|
||||
if args.openfold_checkpoint_path:
|
||||
for path in args.openfold_checkpoint_path.split(","):
|
||||
model = AlphaFold(config)
|
||||
|
@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
|
|||
# The public weights have had this done to them already
|
||||
d = d["ema"]["params"]
|
||||
model.load_state_dict(d)
|
||||
|
||||
model = model.to(args.model_device)
|
||||
logger.info(
|
||||
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
|
||||
)
|
||||
|
||||
yield model, checkpoint_basename
|
||||
|
||||
if not args.jax_param_path and not args.openfold_checkpoint_path:
|
||||
raise ValueError(
|
||||
"At least one of jax_param_path or openfold_checkpoint_path must "
|
||||
|
@ -311,23 +308,40 @@ def main(args):
|
|||
os.makedirs(prediction_dir, exist_ok=True)
|
||||
|
||||
for fasta_file in os.listdir(args.fasta_dir):
|
||||
|
||||
batch_data = generate_batch(
|
||||
fasta_file,
|
||||
args.fasta_dir,
|
||||
alignment_dir,
|
||||
data_processor,
|
||||
feature_processor,
|
||||
prediction_dir)
|
||||
|
||||
if batch_data is None:
|
||||
# this file has already been processed
|
||||
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
|
||||
data = fp.read()
|
||||
|
||||
tags, seqs = parse_fasta(data)
|
||||
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
|
||||
tag = '-'.join(tags)
|
||||
|
||||
output_name = f'{tag}_{args.config_preset}'
|
||||
if args.output_postfix is not None:
|
||||
output_name = f'{output_name}_{args.output_postfix}'
|
||||
|
||||
unrelaxed_output_path = os.path.join(
|
||||
prediction_dir, f'{output_name}_unrelaxed.pdb'
|
||||
)
|
||||
|
||||
# Output already exists
|
||||
if os.path.exists(unrelaxed_output_path):
|
||||
continue
|
||||
|
||||
batch, tag, feature_dict = batch_data
|
||||
precompute_alignments(tags, seqs, alignment_dir, args)
|
||||
|
||||
feature_dict = generate_feature_dict(
|
||||
tags,
|
||||
seqs,
|
||||
alignment_dir,
|
||||
data_processor,
|
||||
args,
|
||||
)
|
||||
|
||||
processed_feature_dict = feature_processor.process_features(
|
||||
feature_dict, mode='predict',
|
||||
)
|
||||
|
||||
for model, model_version in load_models_from_command_line(args, config):
|
||||
|
||||
working_batch = deepcopy(batch)
|
||||
out = run_model(model, working_batch, tag, args)
|
||||
|
||||
|
@ -339,21 +353,11 @@ def main(args):
|
|||
out, working_batch, feature_dict, feature_processor, args
|
||||
)
|
||||
|
||||
output_name = f'{tag}_{args.config_preset}'
|
||||
|
||||
if model_version is not None:
|
||||
output_name = f'{output_name}_{model_version}'
|
||||
if args.output_postfix is not None:
|
||||
output_name = f'{output_name}_{args.output_postfix}'
|
||||
|
||||
# Save the unrelaxed PDB.
|
||||
unrelaxed_output_path = os.path.join(
|
||||
prediction_dir, f'{output_name}_unrelaxed.pdb'
|
||||
)
|
||||
with open(unrelaxed_output_path, 'w') as fp:
|
||||
fp.write(protein.to_pdb(unrelaxed_protein))
|
||||
|
||||
logger.info(f"Output written to {unrelaxed_output_path}...")
|
||||
|
||||
if not args.skip_relaxation:
|
||||
amber_relaxer = relax.AmberRelaxation(
|
||||
use_gpu=(args.model_device != "cpu"),
|
||||
|
@ -377,6 +381,7 @@ def main(args):
|
|||
)
|
||||
with open(relaxed_output_path, 'w') as fp:
|
||||
fp.write(relaxed_pdb_str)
|
||||
|
||||
logger.info(f"Relaxed output written to {relaxed_output_path}...")
|
||||
|
||||
if args.save_outputs:
|
||||
|
@ -388,6 +393,7 @@ def main(args):
|
|||
|
||||
logger.info(f"Model output written to {output_dict_path}...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
@ -413,8 +419,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
parser.add_argument(
|
||||
"--config_preset", type=str, default="model_1",
|
||||
help="""Name of a model config. Choose one of model_{1-5} or
|
||||
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
|
||||
help="""Name of a model config preset defined in openfold/config.py"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jax_param_path", type=str, default=None,
|
||||
|
|
Loading…
Reference in New Issue