Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement initial guess feature #592

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 180 additions & 20 deletions colabfold/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def validate_and_fix_mmcif(cif_file: Path):
"CSD" : "CYS", "SEC" : "CYS"
}

order_to_restype = {v: k for k, v in residue_constants.restype_order_with_x.items()}

class ReplaceOrRemoveHetatmSelect(Select):
def accept_residue(self, residue):
hetfield, _, _ = residue.get_id()
Expand Down Expand Up @@ -304,6 +306,74 @@ def pad_input(
) # template_mask (4, 4) second value
return input_fix

MODRES = {'MSE':'MET','MLY':'LYS','FME':'MET','HYP':'PRO',
'TPO':'THR','CSO':'CYS','SEP':'SER','M3L':'LYS',
'HSK':'HIS','SAC':'SER','PCA':'GLU','DAL':'ALA',
'CME':'CYS','CSD':'CYS','OCS':'CYS','DPR':'PRO',
'B3K':'LYS','ALY':'LYS','YCM':'CYS','MLZ':'LYS',
'4BF':'TYR','KCX':'LYS','B3E':'GLU','B3D':'ASP',
'HZP':'PRO','CSX':'CYS','BAL':'ALA','HIC':'HIS',
'DBZ':'ALA','DCY':'CYS','DVA':'VAL','NLE':'LEU',
'SMC':'CYS','AGM':'ARG','B3A':'ALA','DAS':'ASP',
'DLY':'LYS','DSN':'SER','DTH':'THR','GL3':'GLY',
'HY3':'PRO','LLP':'LYS','MGN':'GLN','MHS':'HIS',
'TRQ':'TRP','B3Y':'TYR','PHI':'PHE','PTR':'TYR',
'TYS':'TYR','IAS':'ASP','GPL':'LYS','KYN':'TRP',
'CSD':'CYS','SEC':'CYS'}

def pdb_to_string(
pdb_file: str,
chains: Optional[str] = None,
models: Optional[list] = None
) -> str:
'''read pdb file and return as string'''

if chains is not None:
if "," in chains: chains = chains.split(",")
if not isinstance(chains,list): chains = [chains]
if models is not None:
if not isinstance(models,list): models = [models]

modres = {**MODRES}
lines = []
seen = []
model = 1

if "\n" in pdb_file:
old_lines = pdb_file.split("\n")
else:
with open(pdb_file,"rb") as f:
old_lines = [line.decode("utf-8","ignore").rstrip() for line in f]
for line in old_lines:
if line[:5] == "MODEL":
model = int(line[5:])
if models is None or model in models:
if line[:6] == "MODRES":
k = line[12:15]
v = line[24:27]
if k not in modres and v in residue_constants.restype_3to1:
modres[k] = v
if line[:6] == "HETATM":
k = line[17:20]
if k in modres:
line = "ATOM "+line[6:17]+modres[k]+line[20:]
if line[:4] == "ATOM":
chain = line[21:22]
if chains is None or chain in chains:
atom = line[12:12+4].strip()
resi = line[17:17+3]
resn = line[22:22+5].strip()
if resn[-1].isalpha(): # alternative atom
resn = resn[:-1]
line = line[:26]+" "+line[27:]
key = f"{model}_{chain}_{resn}_{resi}_{atom}"
if key not in seen: # skip alternative placements
lines.append(line)
seen.append(key)
if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL":
lines.append(line)
return "\n".join(lines)

class file_manager:
def __init__(self, prefix: str, result_dir: Path):
self.prefix = prefix
Expand Down Expand Up @@ -331,6 +401,7 @@ def predict_structure(
pad_len: int,
model_type: str,
model_runner_and_params: List[Tuple[str, model.RunModel, haiku.Params]],
initial_guess: str = None,
num_relax: int = 0,
relax_max_iterations: int = 0,
relax_tolerance: float = 2.39,
Expand Down Expand Up @@ -388,6 +459,18 @@ def predict_structure(
model_names.append(tag)
files.set_tag(tag)

# initial guess
if initial_guess:
input_guess = Path(initial_guess)
if input_guess.suffix == ".pdb":
pdb_string = pdb_to_string(initial_guess)
input_features["all_atom_positions"] = protein.from_pdb_string(pdb_string).atom_positions
elif input_guess.suffix == ".cif":
input_features["all_atom_positions"] = protein.from_mmcif_string(input_guess.read_text()).atom_positions
else:
raise ValueError(f"Unsupported initial guess file format: {initial_guess}")


########################
# predict
########################
Expand Down Expand Up @@ -571,6 +654,27 @@ def parse_fasta(fasta_string: str) -> Tuple[List[str], List[str]]:

return sequences, descriptions

def decode_structure_sequences(
aatype_array: List[int],
chain_index_array: List[int],
order_dict: Dict[int, str] = order_to_restype
) -> List[str]:
decoded_sequences = []
current_sequence = []

for i in range(len(aatype_array)):
amino_acid = order_dict[aatype_array[i]]
if i == 0 or chain_index_array[i] == chain_index_array[i - 1]:
current_sequence.append(amino_acid)
else:
decoded_sequences.append("".join(current_sequence))
current_sequence = [amino_acid]

# Append the last sequence
decoded_sequences.append("".join(current_sequence))

return decoded_sequences

def get_queries(
input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]:
Expand Down Expand Up @@ -612,6 +716,20 @@ def get_queries(
else:
# Complex mode
queries.append((header, sequence.upper().split(":"), None))
elif input_path.suffix in [".pdb", ".cif"]:
if input_path.suffix == ".pdb":
pdb_string = pdb_to_string(input_path.read_text())
prot = protein.from_pdb_string(pdb_string)
elif input_path.suffix == ".cif":
prot = protein.from_mmcif_string(input_path.read_text())
header = input_path.stem
sequences = decode_structure_sequences(prot.aatype, prot.chain_index)

if len(sequences) == 0:
raise ValueError(f"{input_path} is empty")

queries = [(header, sequences, None)]

else:
raise ValueError(f"Unknown file format {input_path.suffix}")
else:
Expand All @@ -620,29 +738,44 @@ def get_queries(
for file in sorted(input_path.iterdir()):
if not file.is_file():
continue
if file.suffix.lower() not in [".a3m", ".fasta", ".faa"]:
logger.warning(f"non-fasta/a3m file in input directory: {file}")
if file.suffix.lower() not in [".a3m", ".fasta", ".faa", ".pdb", ".cif"]:
logger.warning(f"non-fasta/a3m/pdb/cif file in input directory: {file}")
continue
(seqs, header) = parse_fasta(file.read_text())
if len(seqs) == 0:
logger.error(f"{file} is empty")
continue
query_sequence = seqs[0]
if len(seqs) > 1 and file.suffix in [".fasta", ".faa", ".fa"]:
logger.warning(
f"More than one sequence in {file}, ignoring all but the first sequence"
)
if file.suffix.lower() in [".pdb", ".cif"]:
header = file.stem
if file.suffix.lower() == ".pdb":
pdb_string = pdb_to_string(file.read_text())
prot = protein.from_pdb_string(pdb_string)
else: # file.suffix.lower() == ".cif"
prot = protein.from_mmcif_string(file.read_text())
sequences = decode_structure_sequences(prot.aatype, prot.chain_index)

if len(sequences) == 0:
logger.error(f"{file} is empty")
continue

if file.suffix.lower() == ".a3m":
a3m_lines = [file.read_text()]
queries.append((file.stem, query_sequence.upper(), a3m_lines))
else:
if query_sequence.count(":") == 0:
# Single sequence
queries.append((file.stem, query_sequence, None))
queries.append((header, sequences, None))
else: # file.suffix.lower() in [".a3m", ".fasta", ".faa"]
(seqs, header) = parse_fasta(file.read_text())
if len(seqs) == 0:
logger.error(f"{file} is empty")
continue
query_sequence = seqs[0]
if len(seqs) > 1 and file.suffix in [".fasta", ".faa", ".fa"]:
logger.warning(
f"More than one sequence in {file}, ignoring all but the first sequence"
)

if file.suffix.lower() == ".a3m":
a3m_lines = [file.read_text()]
queries.append((file.stem, query_sequence.upper(), a3m_lines))
else:
# Complex mode
queries.append((file.stem, query_sequence.upper().split(":"), None))
if query_sequence.count(":") == 0:
# Single sequence
queries.append((file.stem, query_sequence, None))
else:
# Complex mode
queries.append((file.stem, query_sequence.upper().split(":"), None))

# sort by seq. len
if sort_queries_by == "length":
Expand Down Expand Up @@ -1226,6 +1359,7 @@ def run(
num_recycles: Optional[int] = None,
recycle_early_stop_tolerance: Optional[float] = None,
model_order: List[int] = [1,2,3,4,5],
initial_guess: str = None,
num_ensemble: int = 1,
model_type: str = "auto",
msa_mode: str = "mmseqs2_uniref_env",
Expand Down Expand Up @@ -1356,6 +1490,10 @@ def run(
# sort model order
model_order.sort()

# initial guess
if initial_guess is not None:
logger.info(f'Using initial guess: {initial_guess}')

# Record the parameters of this run
config = {
"num_queries": len(queries),
Expand All @@ -1372,6 +1510,7 @@ def run(
"recycle_early_stop_tolerance": recycle_early_stop_tolerance,
"num_ensemble": num_ensemble,
"model_order": model_order,
"initial_guess": initial_guess,
"keep_existing_results": keep_existing_results,
"rank_by": rank_by,
"max_seq": max_seq,
Expand Down Expand Up @@ -1578,6 +1717,7 @@ def run(
use_templates=use_templates,
sequences_lengths=query_sequence_len_array,
pad_len=pad_len,
initial_guess=initial_guess,
model_type=model_type,
model_runner_and_params=model_runner_and_params,
num_relax=num_relax,
Expand Down Expand Up @@ -1811,6 +1951,14 @@ def main():
],
)
pred_group.add_argument("--model-order", default="1,2,3,4,5", type=str)
pred_group.add_argument(
"--initial-guess",
nargs="?",
const=True,
help="Specify a starting model for the prediction. If the main input file is a PDB format, "
"it will be used as the initial guess. Otherwise, you can provide an input file with this flag, "
"which will override the main input."
)
pred_group.add_argument(
"--use-dropout",
default=False,
Expand Down Expand Up @@ -2011,6 +2159,17 @@ def main():
queries, is_complex = get_queries(args.input, args.sort_queries_by)
model_type = set_model_type(is_complex, args.model_type)

# use pdb or cif input as initial guess
if args.initial_guess is not None:
if isinstance(args.initial_guess, str) and Path(args.initial_guess).suffix in (".pdb", ".cif"):
initial_guess = args.initial_guess
elif Path(args.input).suffix in (".pdb", ".cif"):
initial_guess = args.input
else:
raise ValueError("Provide PDB or CIF file for initial guess.")
else:
initial_guess = None

if args.msa_only:
args.num_models = 0

Expand Down Expand Up @@ -2049,6 +2208,7 @@ def main():
recycle_early_stop_tolerance=args.recycle_early_stop_tolerance,
num_ensemble=args.num_ensemble,
model_order=model_order,
initial_guess=initial_guess,
is_complex=is_complex,
keep_existing_results=not args.overwrite_existing_results,
rank_by=args.rank,
Expand Down