Fixed learning rate scheduler issue, returned to original msa file parsing
This commit is contained in:
parent
bc07500422
commit
6275091c96
|
@ -21,14 +21,11 @@ import dataclasses
|
|||
from multiprocessing import cpu_count
|
||||
import tempfile
|
||||
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import torch
|
||||
import pickle
|
||||
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
|
||||
from openfold.data.templates import get_custom_template_features, empty_template_feats
|
||||
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
|
||||
from openfold.data.tools.utils import to_date
|
||||
from openfold.np import residue_constants, protein
|
||||
|
||||
FeatureDict = MutableMapping[str, np.ndarray]
|
||||
|
@ -704,10 +701,10 @@ class DataPipeline:
|
|||
def _parse_msa_data(
|
||||
self,
|
||||
alignment_dir: str,
|
||||
alignment_index: Optional[Any] = None,
|
||||
alignment_index: Optional[Any] = None
|
||||
) -> Mapping[str, Any]:
|
||||
msa_data = {}
|
||||
if(alignment_index is not None):
|
||||
if alignment_index is not None:
|
||||
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
|
||||
|
||||
def read_msa(start, size):
|
||||
|
@ -718,14 +715,14 @@ class DataPipeline:
|
|||
for (name, start, size) in alignment_index["files"]:
|
||||
filename, ext = os.path.splitext(name)
|
||||
|
||||
if(ext == ".a3m"):
|
||||
if ext == ".a3m":
|
||||
msa = parsers.parse_a3m(
|
||||
read_msa(start, size)
|
||||
)
|
||||
# The "hmm_output" exception is a crude way to exclude
|
||||
# multimer template hits.
|
||||
# Multimer "uniprot_hits" processed separately.
|
||||
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
|
||||
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
|
||||
msa = parsers.parse_stockholm(read_msa(start, size))
|
||||
else:
|
||||
continue
|
||||
|
@ -734,13 +731,22 @@ class DataPipeline:
|
|||
|
||||
fp.close()
|
||||
else:
|
||||
# Now will split the following steps into multiple processes
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
cmd = f"{current_directory}/tools/parse_msa_files.py"
|
||||
msa_data_path = subprocess.run(['python',cmd, f"--alignment_dir={alignment_dir}"],capture_output=True, text=True)
|
||||
msa_data_path = msa_data_path.stdout.lstrip().rstrip()
|
||||
msa_data = pickle.load((open(msa_data_path,'rb')))
|
||||
os.remove(msa_data_path)
|
||||
for f in os.listdir(alignment_dir):
|
||||
path = os.path.join(alignment_dir, f)
|
||||
filename, ext = os.path.splitext(f)
|
||||
|
||||
if ext == ".a3m":
|
||||
with open(path, "r") as fp:
|
||||
msa = parsers.parse_a3m(fp.read())
|
||||
elif ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]:
|
||||
with open(path, "r") as fp:
|
||||
msa = parsers.parse_stockholm(
|
||||
fp.read()
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
msa_data[f] = msa
|
||||
|
||||
return msa_data
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
|
|||
with open(tmp_fasta_path, "w") as fp:
|
||||
fp.write(f">{tag}\n{seq}")
|
||||
|
||||
local_alignment_dir = os.path.join(alignment_dir, tag),
|
||||
local_alignment_dir = os.path.join(alignment_dir, tag)
|
||||
|
||||
if args.use_precomputed_alignments is None:
|
||||
logger.info(f"Generating alignments for {tag}...")
|
||||
|
|
|
@ -234,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
|||
|
||||
lr_scheduler = AlphaFoldLRScheduler(
|
||||
optimizer,
|
||||
last_epoch=self.last_lr_step
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
Loading…
Reference in New Issue