Fixed learning rate scheduler issue, returned to original msa file parsing

This commit is contained in:
Christina Floristean 2024-02-07 15:47:24 -05:00
parent bc07500422
commit 6275091c96
3 changed files with 22 additions and 15 deletions

View File

@ -21,14 +21,11 @@ import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import subprocess
import numpy as np
import torch
import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
FeatureDict = MutableMapping[str, np.ndarray]
@ -704,10 +701,10 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
msa_data = {}
if(alignment_index is not None):
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
@ -718,14 +715,14 @@ class DataPipeline:
for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name)
if(ext == ".a3m"):
if ext == ".a3m":
msa = parsers.parse_a3m(
read_msa(start, size)
)
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
# Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
msa = parsers.parse_stockholm(read_msa(start, size))
else:
continue
@ -734,13 +731,22 @@ class DataPipeline:
fp.close()
else:
# Now will split the following steps into multiple processes
current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/tools/parse_msa_files.py"
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
msa_data_path = msa_data_path.stdout.lstrip().rstrip()
msa_data = pickle.load((open(msa_data_path,'rb')))
os.remove(msa_data_path)
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if ext == ".a3m":
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data

View File

@ -63,7 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag),
local_alignment_dir = os.path.join(alignment_dir, tag)
if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...")

View File

@ -234,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
last_epoch=self.last_lr_step
)
return {