From 909fc01c032d6e3697a6b50fbb0121abfe007755 Mon Sep 17 00:00:00 2001 From: Brahm Yachnin Date: Fri, 16 May 2025 10:31:01 -0400 Subject: [PATCH 1/3] Maintains the input chain ids in RFdiffusion output The output was previously renumbering all of the chains, making comparisons to the input structures and handling of multi-chain inputs challenging. This commit maintains the input chain ids in the output. --- rfdiffusion/contigs.py | 14 +++++++++++++- rfdiffusion/inference/model_runners.py | 12 ++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/rfdiffusion/contigs.py b/rfdiffusion/contigs.py index a940d6b6..8eb55f66 100644 --- a/rfdiffusion/contigs.py +++ b/rfdiffusion/contigs.py @@ -80,6 +80,7 @@ def __init__( self.inpaint, self.inpaint_hal, self.inpaint_rf, + self.sampled_mask_length_bound, ) = self.expand_sampled_mask() self.ref = self.inpaint + self.receptor self.hal = self.inpaint_hal + self.receptor_hal @@ -241,6 +242,8 @@ def expand_sampled_mask(self): inpaint_chain_idx = -1 receptor_chain_break = [] inpaint_chain_break = [] + _receptor_mask_length_bound = [] + _inpaint_mask_length_bound = [] for con in self.sampled_mask: if ( all([i[0].isalpha() for i in con.split("/")[:-1]]) @@ -286,6 +289,7 @@ def expand_sampled_mask(self): receptor_chain_break.append( (receptor_idx - 1, 200) ) # 200 aa chain break + _receptor_mask_length_bound.append(len(receptor)) else: inpaint_chain_idx += 1 for subcon in con.split("/"): @@ -320,6 +324,7 @@ def expand_sampled_mask(self): ) inpaint_idx += int(subcon.split("-")[0]) inpaint_chain_break.append((inpaint_idx - 1, 200)) + _inpaint_mask_length_bound.append(len(inpaint)) if self.topo is True or inpaint_hal == []: receptor_hal = [(i[0], i[1]) for i in receptor_hal] @@ -335,7 +340,13 @@ def expand_sampled_mask(self): inpaint_rf[ch_break[0] :] += ch_break[1] for ch_break in receptor_chain_break[:-1]: receptor_rf[ch_break[0] :] += ch_break[1] - + sampled_mask_length_bound = [] + 
sampled_mask_length_bound.extend(_inpaint_mask_length_bound) + if _inpaint_mask_length_bound: + inpaint_last_bound = _inpaint_mask_length_bound[-1] + else: + inpaint_last_bound = 0 + sampled_mask_length_bound.extend(map(lambda x: x + inpaint_last_bound, _receptor_mask_length_bound)) return ( receptor, receptor_hal, @@ -343,6 +354,7 @@ def expand_sampled_mask(self): inpaint, inpaint_hal, inpaint_rf.tolist(), + sampled_mask_length_bound ) def get_inpaint_seq_str(self, inpaint_s, ss=False): diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 3e6505f4..d8936397 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -309,8 +309,16 @@ def sample_init(self, return_forward_trajectory=False): contig_map=self.contig_map self.diffusion_mask = self.mask_str - self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)] - + length_bound = self.contig_map.sampled_mask_length_bound.copy() + + first_res = 0 + self.chain_idx = [] + for last_res in length_bound: + chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - {"_"} + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) + first_res = last_res + #################################### ### Generate initial coordinates ### #################################### From 63e270f715104896a5684108f582046b1d551b1e Mon Sep 17 00:00:00 2001 From: Brahm Yachnin Date: Tue, 20 May 2025 10:38:19 -0400 Subject: [PATCH 2/3] For fixed chains, retain residue numbering For chains that are completely fixed, retain the residue numbering from the input rather than renumbering. For chains that are partially or fully designed by RFdiffusion, it isn't clear to me what the 'correct' behaviour should be, so these chains will be re-numbered starting at residue 1. 
--- rfdiffusion/inference/model_runners.py | 31 ++++++++++++++++++++++---- scripts/run_inference.py | 1 + 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index d8936397..8d0f5902 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -14,6 +14,7 @@ from rfdiffusion import util from hydra.core.hydra_config import HydraConfig import os +import string from rfdiffusion.model_input_logger import pickle_function_call import sys @@ -144,13 +145,14 @@ def initialize(self, conf: DictConfig) -> None: self.symmetry = None self.allatom = ComputeAllAtomCoords().to(self.device) - + if self.inf_conf.input_pdb is None: # set default pdb script_dir=os.path.dirname(os.path.realpath(__file__)) self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) self.chain_idx = None + self.idx_pdb = None ############################## ### Handle Partial Noising ### @@ -313,10 +315,31 @@ def sample_init(self, return_forward_trajectory=False): first_res = 0 self.chain_idx = [] + self.idx_pdb = [] + all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} + available_chains = sorted(list(set(string.ascii_uppercase) - all_chains)) + # Iterate over each chain for last_res in length_bound: - chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - {"_"} - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" - self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) + chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} + # If we are designing this chain, it will have a '_' placeholder in the contig map + # Renumber this chain from 1 + if "_" in chain_ids: + self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)] + chain_ids = chain_ids 
- {"_"} + # If there are no fixed residues that have a chain id, pick the first available letter + if not chain_ids: + chain_id = available_chains[0] + available_chains.remove(chain_id) + # Otherwise, use the chain of the fixed (motif) residues + else: + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + chain_id = list(chain_ids)[0] + self.chain_idx += [chain_id] * (last_res - first_res) + # If this is a fixed chain, maintain the chain and residue numbering + else: + self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) first_res = last_res #################################### diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 2a3bf362..3fb6466e 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -141,6 +141,7 @@ def main(conf: HydraConfig) -> None: sampler.binderlen, chain_idx=sampler.chain_idx, bfacts=bfacts, + idx_pdb=sampler.idx_pdb ) # run metadata From d32205a17f244757fcbe94f4658329ef9cb0459c Mon Sep 17 00:00:00 2001 From: Brahm Yachnin Date: Tue, 17 Jun 2025 13:24:16 -0400 Subject: [PATCH 3/3] Extend available new chains to include lowercase letters Also raise a ValueError if the user exceeds 52 chains (using up all upper- and lower-case chain ids). 
--- rfdiffusion/inference/model_runners.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 8d0f5902..1a881ff7 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -317,7 +317,8 @@ def sample_init(self, return_forward_trajectory=False): self.chain_idx = [] self.idx_pdb = [] all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} - available_chains = sorted(list(set(string.ascii_uppercase) - all_chains)) + available_chains = sorted(list(set(string.ascii_letters) - all_chains)) + # Iterate over each chain for last_res in length_bound: chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} @@ -328,6 +329,10 @@ def sample_init(self, return_forward_trajectory=False): chain_ids = chain_ids - {"_"} # If there are no fixed residues that have a chain id, pick the first available letter if not chain_ids: + if not available_chains: + raise ValueError(f"No available chains! You are trying to design a new chain, and you have " + f"already used all upper- and lower-case chain ids (up to 52 chains): " + f"{','.join(all_chains)}.") chain_id = available_chains[0] available_chains.remove(chain_id) # Otherwise, use the chain of the fixed (motif) residues