Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ __pycache__/
.DS_Store
._.DS_Store
._*
_*
flexcraft_params/
110 changes: 94 additions & 16 deletions flexcraft/rosetta/interface_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,91 @@
from copy import deepcopy

import pyrosetta as pr
from pyrosetta.rosetta.core.select.residue_selector import ChainSelector
from pyrosetta.rosetta.core.select.residue_selector import ChainSelector, NotResidueSelector, TrueResidueSelector

from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
from pyrosetta.rosetta.protocols.rosetta_scripts import XmlObjects

from pyrosetta.rosetta.core.pack.guidance_scoreterms.sap import calculate_sap
from pyrosetta.rosetta.protocols.simple_filters import ContactMolecularSurfaceFilter

from flexcraft.files.pdb import PDBFile
from flexcraft.data.data import DesignData
import flexcraft.sequence.aa_codes as aas

def score_interface(pdb_file: PDBFile | str, is_target):
# load pose
if isinstance(pdb_file, str):
pdb_file = PDBFile(path=str)

def score_interface(pdb_file: PDBFile | str, is_target, set_int=None):
# Rosetta crashes if more than 3 chains are present in the pose
# Relabel all target chains to chain A and renumber residues, relabel binder to chain B
# If multiple binders are generated, these will all be relabelled to chain B and renumbered
# If there are multiple binders, these will all be relabelled to chain B and renumbered
# set_int can specify target and binder chains in the form "A,B_C,D" (target chains A,B; binder chains C,D) - all other chains will be excluded from the Rosetta analysis
# create a temporary pdb file (calc_pdb) with the renumbered/reassigned chains to use for calculations

if isinstance(pdb_file, str):
pdb_file = PDBFile(path=pdb_file)


if set_int is None:
intmode = "default"
# simply relabel target chain(s) to chain A and binder chain(s) to chain B
rosetta_data = deepcopy(pdb_file.data)
rosetta_data["chain_index"][is_target] = 0 # all target chains are now chain A
rosetta_data["residue_index"][is_target] = np.arange(1, np.sum(is_target) + 1) # renumber target residues
rosetta_data["chain_index"][~is_target] = 1 # binder is now chain B
rosetta_data["residue_index"][~is_target] = np.arange(1, np.sum(~is_target) + 1) # renumber binder residues

rosetta_data = deepcopy(pdb_file.data)
rosetta_data["chain_index"][is_target] = 0 # all target chains are now chain A
rosetta_data["residue_index"][is_target] = np.arange(1, np.sum(is_target) + 1) #renumber target residues
rosetta_data["chain_index"][~is_target] = 1 # binder is now chain B
rosetta_data["residue_index"][~is_target] = np.arange(1, np.sum(~is_target) + 1) #renumber binder residues
calc_pdb = PDBFile(data=rosetta_data, prefix="interface_calc_all")
else:
try:
# set_int can e.g. be "A,B_C,D" to calculate the interface between target chains A,B and binder chains C,D
intmode = "specified"
# Map provided chain letters to integers (A->0, B->1, ...)
target_str, binder_str = set_int.split("_")
CHAIN_NAMES = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
try:
target_chain_ids = [CHAIN_NAMES.index(c) for c in target_str.split(",")]
binder_chain_ids = [CHAIN_NAMES.index(c) for c in binder_str.split(",")]
except ValueError:
raise ValueError(f"Invalid chain ID in set_int='{set_int}'. Chains must be in '{CHAIN_NAMES}'.")

calc_pdb = PDBFile(data=rosetta_data, prefix="interface_calc_")
orig_data = deepcopy(pdb_file.data)
orig_chains = orig_data["chain_index"]
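# create masks over chain_index: True where an entry belongs to one of the requested chains
# NOTE(review): entries appear to be per-residue rather than per-atom (they are renumbered via residue_index below) - confirm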
t_mask = np.isin(orig_chains, target_chain_ids)
b_mask = np.isin(orig_chains, binder_chain_ids)
keep_mask = t_mask | b_mask # union of both masks
if np.sum(keep_mask) == 0:
raise ValueError(f"No atoms found matching set_int: {set_int}")

# filter the original data for chains requested for the interface calculations
rosetta_data = {}
for k, v in orig_data.items():
# filter every array whose length matches chain_index (positions, aa types, etc.) with the keep mask
if hasattr(v, "__len__") and len(v) == len(orig_chains):
rosetta_data[k] = v[keep_mask]
else:
# keep metadata as is
rosetta_data[k] = v

# update is_target since we may have deleted target chains
tmp_is_target = t_mask[keep_mask]
# set all selected target chains to 0 (Chain A)
rosetta_data["chain_index"][tmp_is_target] = 0
# set all selected binder chains to 1 (Chain B)
rosetta_data["chain_index"][~tmp_is_target] = 1
# renumber residues
rosetta_data["residue_index"][tmp_is_target] = np.arange(1, np.sum(tmp_is_target) + 1)
rosetta_data["residue_index"][~tmp_is_target] = np.arange(1, np.sum(~tmp_is_target) + 1)

calc_pdb = PDBFile(data=DesignData(data=rosetta_data), prefix="interface_calc_specified")
except Exception as e:
print(e)
raise ValueError("Please check provided set_int. Format is 'A,B_C,D' to calculate interface between target chains A,B and binder chains C,D.")

pose = pr.pose_from_pdb(calc_pdb.path)

# analyze interface statistics
iam = InterfaceAnalyzerMover()
iam.set_interface("A_B")
iam.set_interface("A_B") # (requested) target chains are now chain A, and (requested) binder chains are chain B
scorefxn = pr.get_fa_scorefxn()
iam.set_scorefunction(scorefxn)
iam.set_compute_packstat(True)
Expand All @@ -75,8 +130,12 @@ def score_interface(pdb_file: PDBFile | str, is_target):
# We also use the pdb file with re-labelled residues here
data: DesignData = calc_pdb.data

target_data = data[is_target]
binder_data = data[(~is_target)]
if intmode == "default":
target_data = data[is_target]
binder_data = data[(~is_target)]
else:
target_data = data[tmp_is_target]
binder_data = data[(~tmp_is_target)]
hotspot = jnp.linalg.norm(target_data["atom_positions"][:, None, 1] - binder_data["atom_positions"][None, :, 1], axis=-1)
hotspot = (hotspot < 8.0).any(axis=1)
aa = aas.decode(target_data["aa"][hotspot], aas.AF2_CODE)
Expand Down Expand Up @@ -156,10 +215,29 @@ def score_interface(pdb_file: PDBFile | str, is_target):

surface_hydrophobicity = exp_apol_count/total_count

# Added: ContactMolecularSurface (combines dSASA and SC)
cms_filter = ContactMolecularSurfaceFilter()
cms_filter.selector1(ChainSelector("A"))
cms_filter.selector2(ChainSelector("B"))
contact_molecular_surface = cms_filter.report_sm(pose)

# Added: binder SAP
true_sel = TrueResidueSelector() # binder pose is isolated, so we select everything
binder_sap_score = calculate_sap(binder_pose, true_sel, true_sel, true_sel)
binder_len = binder_pose.total_residue()

if binder_len > 0:
binder_sap_per_res = binder_sap_score / binder_len
else:
binder_sap_per_res = 0.0

# output interface score array and amino acid counts at the interface
interface_scores = {
'binder_score': binder_score,
'surface_hydrophobicity': surface_hydrophobicity,
'binder_sap_score': binder_sap_score, # Added metric
'binder_sap_per_res': binder_sap_per_res, # Added metric
'contact_molecular_surface': contact_molecular_surface, # Added metric
'interface_sc': interface_sc,
'interface_packstat': interface_packstat,
'interface_dG': interface_dG,
Expand Down