Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions basicrta/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
basicrta
A package to extract binding kinetics from molecular dynamics simulations
"""

# Add imports here
from importlib.metadata import version
from basicrta import *
import argparse
import subprocess

__version__ = version("basicrta")

commands = ['contacts', 'cluster', 'combine', 'kinetics', 'gibbs']

def main():
    """Command-line entry point for the basicrta workflow.

    Defines one subcommand per workflow step ('contacts', 'combine',
    'cluster', 'gibbs', 'kinetics') and forwards the parsed options to the
    matching ``basicrta/<command>.py`` script, run in a subprocess with the
    current Python interpreter.
    """
    # Local imports so the module-level import block is untouched.
    import os
    import sys

    parser = argparse.ArgumentParser(prog='basicrta')
    subparsers = parser.add_subparsers(dest='command')

    # -- contacts: build the contact timeseries from a trajectory ---------
    parserA = subparsers.add_parser('contacts', help='Ahelp')
    parserA.add_argument('--top', type=str, help='Topology')
    parserA.add_argument('--traj', type=str, help='Trajectory')
    parserA.add_argument('--sel1', type=str,
                         help='First selection (group for which tau is to '
                              'be calculated)')
    parserA.add_argument('--sel2', type=str,
                         help='Second selection (group of interest in '
                              'interactions with first selection)')
    parserA.add_argument('--cutoff', type=float,
                         help='Value to use (in A) for the maximum '
                              'separation distance that constitutes a '
                              'contact.')
    parserA.add_argument('--nproc', type=int, default=1,
                         help='Number of processes to use')
    parserA.add_argument('--nslices', type=int, default=100,
                         help='Number of trajectory segments to use (if '
                              'encountering a memory error, try using a '
                              'greater value)')

    # -- combine: pool contact files from repeat runs ---------------------
    parserB = subparsers.add_parser('combine', add_help=True, help='Bhelp')
    parserB.add_argument('--contacts', nargs='+', required=True,
                         help="List of contact pickle files to combine "
                              "(e.g., contacts_7.0.pkl from different runs)")
    parserB.add_argument('--output', type=str,
                         default='combined_contacts.pkl',
                         help="Output filename for combined contacts "
                              "(default: combined_contacts.pkl)")
    parserB.add_argument('--no-validate', action='store_true',
                         help="Skip compatibility validation (use with "
                              "caution)")

    # -- cluster ----------------------------------------------------------
    parserC = subparsers.add_parser('cluster', help='Chelp')
    parserC.add_argument('--nproc', type=int, default=1)
    parserC.add_argument('--cutoff', type=float)
    parserC.add_argument('--niter', type=int, default=110000)
    parserC.add_argument('--prot', type=str, default=None, nargs='?')
    parserC.add_argument('--label-cutoff', type=float, default=3,
                         dest='label_cutoff',
                         help='Only label residues with tau > '
                              'LABEL-CUTOFF * <tau>. ')
    parserC.add_argument('--structure', type=str, nargs='?')
    parserC.add_argument('--gskip', type=int, default=1000,
                         help='Gibbs skip parameter for decorrelated '
                              'samples; default from '
                              'https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
    parserC.add_argument('--burnin', type=int, default=10000,
                         help='Burn-in parameter, drop first N samples as '
                              'equilibration; default from '
                              'https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')

    # -- gibbs ------------------------------------------------------------
    parserD = subparsers.add_parser('gibbs', help='Dhelp')
    parserD.add_argument('--contacts')
    parserD.add_argument('--resid', type=int, default=None)
    parserD.add_argument('--nproc', type=int, default=1)
    parserD.add_argument('--niter', type=int, default=110000)
    parserD.add_argument('--ncomp', type=int, default=15)

    # -- kinetics ---------------------------------------------------------
    parserE = subparsers.add_parser('kinetics', help='Ehelp')
    parserE.add_argument("--gibbs", type=str)
    parserE.add_argument("--contacts", type=str)
    parserE.add_argument("--top_n", type=int, nargs='?', default=None)
    parserE.add_argument("--step", type=int, nargs='?', default=1)
    parserE.add_argument("--wdensity", action='store_true')

    args = parser.parse_args()
    if args.command is None:
        # No subcommand: show usage instead of trying to run "None.py".
        parser.print_help()
        return

    # Re-serialize the parsed options into CLI arguments for the step
    # script. Options left at None are dropped (the step script applies its
    # own defaults); store_true flags are passed bare, and only when set;
    # nargs='+' lists are expanded element by element.
    # NOTE(review): dests containing underscores (e.g. no_validate) are
    # re-emitted with underscores — confirm the step scripts accept that
    # spelling for options declared with hyphens.
    inlist = []
    for key, value in vars(args).items():
        if key == 'command' or value is None:
            continue
        if isinstance(value, bool):
            if value:
                inlist.append(f"--{key}")
        elif isinstance(value, list):
            inlist.append(f"--{key}")
            inlist.extend(str(item) for item in value)
        else:
            inlist.extend([f"--{key}", f"{value}"])

    # Locate the step script inside the installed basicrta package rather
    # than a hard-coded developer path, and run it with the interpreter
    # executing this CLI.
    pkgdir = os.path.dirname(os.path.abspath(__file__))
    script = os.path.join(pkgdir, f'{args.command}.py')
    subprocess.run([sys.executable, script] + inlist)

if __name__ == "__main__":
    main()
132 changes: 127 additions & 5 deletions basicrta/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,135 @@

import os
import argparse
from basicrta.contacts import CombineContacts

class CombineContacts(object):
    """Class to combine contact timeseries from multiple repeat runs.

    This class enables pooling data from multiple trajectory repeats and
    calculating posteriors from all data together, rather than analyzing
    each run separately.

    :param contact_files: List of contact pickle files to combine
    :type contact_files: list of str
    :param output_name: Name for the combined output file (default: 'combined_contacts.pkl')
    :type output_name: str, optional
    :param validate_compatibility: Whether to validate that files are compatible (default: True)
    :type validate_compatibility: bool, optional
    """

    def __init__(self, contact_files, output_name='combined_contacts.pkl',
                 validate_compatibility=True):
        self.contact_files = contact_files
        self.output_name = output_name
        self.validate_compatibility = validate_compatibility

        # Combining a single file is a no-op; fail fast on misuse.
        if len(contact_files) < 2:
            raise ValueError("At least 2 contact files are required for combining")

    def _load_contact_file(self, filename):
        """Load a contact pickle file and return (contacts, metadata).

        Run parameters (cutoff, atom groups, timestep, ...) travel on the
        array's dtype metadata.
        """
        if not os.path.exists(filename):
            # Include the offending path so the user can spot typos.
            raise FileNotFoundError(f"Contact file not found: {filename}")

        with open(filename, 'rb') as f:
            contacts = pickle.load(f)

        metadata = contacts.dtype.metadata
        return contacts, metadata

    def _validate_compatibility(self, metadatas):
        """Validate that contact files are compatible for combining.

        Raises ValueError on mismatched cutoffs or atom-group residues;
        only warns on differing timesteps.
        """
        reference = metadatas[0]

        # Check that all files have the same atom groups
        for i, meta in enumerate(metadatas[1:], 1):
            # Compare cutoff
            if meta['cutoff'] != reference['cutoff']:
                raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, "
                                 f"file {i} has {meta['cutoff']}")

            # Compare atom group selections by checking if resids match
            ref_ag1_resids = set(reference['ag1'].residues.resids)
            ref_ag2_resids = set(reference['ag2'].residues.resids)
            meta_ag1_resids = set(meta['ag1'].residues.resids)
            meta_ag2_resids = set(meta['ag2'].residues.resids)

            if ref_ag1_resids != meta_ag1_resids:
                raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}")
            if ref_ag2_resids != meta_ag2_resids:
                raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}")

        # Check timesteps and warn if different
        timesteps = [meta['ts'] for meta in metadatas]
        if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps):
            print("WARNING: Different timesteps detected across runs:")
            for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)):
                print(f"  File {i} ({filename}): dt = {ts} ns")
            print("This may affect residence time estimates, especially for fast events.")

    def run(self):
        """Combine contact files and save the result.

        Returns the output filename. The combined array gains one extra
        column holding the index of the source trajectory for each row.
        """
        print(f"Combining {len(self.contact_files)} contact files...")

        all_contacts = []
        all_metadatas = []

        # Load all contact files
        for i, filename in enumerate(self.contact_files):
            print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}")
            contacts, metadata = self._load_contact_file(filename)
            all_contacts.append(contacts)
            all_metadatas.append(metadata)

        # Validate compatibility if requested
        if self.validate_compatibility:
            print("Validating file compatibility...")
            self._validate_compatibility(all_metadatas)

        # Combine contact data
        print("Combining contact data...")

        # Calculate total size and create combined array
        total_size = sum(len(contacts) for contacts in all_contacts)
        reference_metadata = all_metadatas[0].copy()

        # Extend metadata to include trajectory source information
        reference_metadata['source_files'] = self.contact_files
        reference_metadata['n_trajectories'] = len(self.contact_files)

        # Determine number of columns (5 for raw contacts, 4 for processed)
        n_cols = all_contacts[0].shape[1]

        # Create dtype with extended metadata
        combined_dtype = np.dtype(np.float64, metadata=reference_metadata)

        # Add trajectory source column (will be last column)
        combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64)

        # Combine data and add trajectory source information
        offset = 0
        for traj_idx, contacts in enumerate(all_contacts):
            n_contacts = len(contacts)
            # Copy original contact data
            combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:]
            # Add trajectory source index
            combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx
            offset += n_contacts

        # Attach the extended metadata via a dtype view (data not copied)
        final_contacts = combined_contacts.view(combined_dtype)

        # Save combined contacts. np.ndarray.dump() accepts only a file
        # argument (no protocol keyword), so pickle the array explicitly
        # with protocol 5.
        print(f"Saving combined contacts to {self.output_name}...")
        with open(self.output_name, 'wb') as f:
            pickle.dump(final_contacts, f, protocol=5)

        print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}")
        print(f"Total contacts: {total_size}")
        print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support")

        return self.output_name

def main():
if __name__ == "__main__":
"""Main function for combining contact files."""
parser = argparse.ArgumentParser(
description="Combine contact timeseries from multiple repeat runs. "
Expand Down Expand Up @@ -82,6 +207,3 @@ def main():
print(f"ERROR: {e}")
return 1


if __name__ == '__main__':
exit(main())
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ doc = [
source = "https://github.com/becksteinlab/basicrta"
documentation = "https://basicrta.readthedocs.io"

[project.scripts]
basicrta = "basicrta.cli:main"

[tool.setuptools]
py-modules = []

Expand Down
Loading