Skip to content

Commit

Permalink
Make parallel tempering parallel again!
Browse files Browse the repository at this point in the history
  • Loading branch information
incud committed Jan 30, 2024
1 parent ac645d4 commit 1fdcf8f
Showing 1 changed file with 84 additions and 27 deletions.
111 changes: 84 additions & 27 deletions quepistasis/parallel_tempering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from abc import ABC, abstractmethod
from utils import Trace
from joblib import Parallel, delayed
import random


class ParallelTempering(ABC):
Expand Down Expand Up @@ -83,6 +85,28 @@ def acceptance_probability(self, old_prob, new_prob, beta):
"""
return min(1, np.exp(beta * (old_prob - new_prob)))

def initialize_chain(self, i):
self.chains[i, 0] = self.initial_probability()
self.log_likes[i, 0] = self.log_like(self.chains[i, 0]) + self.log_prior(self.chains[i, 0])

def step_chain(self, step, i):
current_chain = self.chains[i, step - 1]
current_log_prob = self.log_like(current_chain) + self.log_prior(current_chain)

# Perform Metropolis-Hastings proposal
proposed_chain = self.metropolis_hastings_proposal(current_chain, self.betas[i])

# Calculate log likelihood and log prior for proposed chain
proposed_log_prob = self.log_like(proposed_chain) + self.log_prior(proposed_chain)

# Accept or reject the proposed chain
if np.log(np.random.uniform()) < self.acceptance_probability(current_log_prob, proposed_log_prob, self.betas[i]):
self.chains[i, step] = proposed_chain
self.log_likes[i, step] = proposed_log_prob
else:
self.chains[i, step] = current_chain
self.log_likes[i, step] = current_log_prob

def run(self):
"""
Run the Parallel Tempering MCMC algorithm.
Expand All @@ -95,27 +119,44 @@ def run(self):
self.log_likes = np.zeros((self.num_chains, self.num_steps))

for i in range(self.num_chains):
self.chains[i, 0] = self.initial_probability()
self.log_likes[i, 0] = self.log_like(self.chains[i, 0]) + self.log_prior(self.chains[i, 0])
self.initialize_chain(i)

for step in range(1, self.num_steps):

for i in range(self.num_chains):
current_chain = self.chains[i, step - 1]
current_log_prob = self.log_like(current_chain) + self.log_prior(current_chain)
self.step_chain(step, i)

# Exchange states between adjacent chains
for i in range(self.num_chains - 1):
diff_prob = self.betas[i + 1] * (self.log_likes[i, step] - self.log_likes[i + 1, step])
if np.log(np.random.uniform()) < diff_prob:
# Swap states between chains i and i+1
self.chains[i, step], self.chains[i + 1, step] = self.chains[i + 1, step], self.chains[i, step]
self.log_likes[i, step], self.log_likes[i + 1, step] = self.log_likes[i + 1, step], self.log_likes[i, step]

return self.chains, self.log_likes

# Perform Metropolis-Hastings proposal
proposed_chain = self.metropolis_hastings_proposal(current_chain, self.betas[i])
def run_parallel(self):
"""
Run the Parallel Tempering MCMC algorithm in parallel.
Returns:
- tuple: Chains and log likelihoods.
"""

ndim = len(self.initial_probability())
self.chains = np.zeros((self.num_chains, self.num_steps, ndim))
self.log_likes = np.zeros((self.num_chains, self.num_steps))

# Initialize chains and log likelihoods for each chain
Parallel(n_jobs=self.num_chains, backend="threading")(delayed(self.initialize_chain)(i) for i in range(self.num_chains))
print(self.chains)

# Calculate log likelihood and log prior for proposed chain
proposed_log_prob = self.log_like(proposed_chain) + self.log_prior(proposed_chain)
for step in range(1, self.num_steps):

# Accept or reject the proposed chain
if np.log(np.random.uniform()) < self.acceptance_probability(current_log_prob, proposed_log_prob, self.betas[i]):
self.chains[i, step] = proposed_chain
self.log_likes[i, step] = proposed_log_prob
else:
self.chains[i, step] = current_chain
self.log_likes[i, step] = current_log_prob
# Run Metropolis-Hastings proposals and updates in parallel
Parallel(n_jobs=self.num_chains, backend="threading")(delayed(self.step_chain)(step, i) for i in range(self.num_chains))
print(self.chains)

# Exchange states between adjacent chains
for i in range(self.num_chains - 1):
Expand All @@ -127,8 +168,6 @@ def run(self):

return self.chains, self.log_likes



class IsingParallelTempering(ParallelTempering):

def __init__(self, h, J, num_chains, num_steps, betas):
Expand Down Expand Up @@ -206,6 +245,16 @@ def get_solution(self):
return best_configuration, best_log_likelihood, best_energy


def sparsify_solution(spins, MAX_UP_SPIN=10):
# delete snps if they are too many
indexes_to_one = [i for i, value in enumerate(spins) if value == 1]
if len(indexes_to_one) > MAX_UP_SPIN:
random.shuffle(indexes_to_one)
for idx in indexes_to_one[MAX_UP_SPIN:]:
spins[idx] = -1
return spins


def run_parallel_tempering(h, J, num_chains, num_steps, save_path):

# trace the input except for the qubo formulation itself which is too large
Expand All @@ -215,25 +264,33 @@ def run_parallel_tempering(h, J, num_chains, num_steps, save_path):
trace.add('input', None, 'save_path', save_path)
betas = np.geomspace(1, 1e-2, num_chains)
ising_pt = IsingParallelTempering(h, J, num_chains, num_steps, betas)
ising_pt.run()
ising_pt.run_parallel()
spins, loglike, energy = ising_pt.get_solution()
trace.add('solution', None, 'spins', spins)
trace.add('solution', None, 'loglike', loglike)
trace.add('solution', None, 'energy', energy)
return spins.tolist()

def test_parallel_tempering():
spins = [int(i) for i in spins.tolist()]
trace.add('solution', 'original', 'spins', spins)
trace.add('solution', 'original', 'loglike', loglike)
trace.add('solution', 'original', 'energy', energy)
spins = sparsify_solution(spins, 10)
trace.add('solution', 'sparsified', 'spins', spins)
trace.add('solution', 'sparsified', 'loglike', None)
trace.add('solution', 'sparsified', 'energy', IsingParallelTempering.ising_energy(spins, h, J))
return spins

def test_parallel_tempering(parallel=False):
"""
Run a basic Ising model to check if the code works correctly.
"""
h = np.array([5, 10, -20])
J = {(0, 1): 1, (1, 2): 2}

num_chains = 4
num_steps = 1000
num_chains = 2
num_steps = 3
betas = np.geomspace(1, 1e-2, num_chains)

ising_pt = IsingParallelTempering(h, J, num_chains, num_steps, betas)
ising_pt.run()
if parallel:
ising_pt.run_parallel()
else:
ising_pt.run()
spins, loglike, energy = ising_pt.get_solution()
print(f"{spins=} {loglike=} {energy=}")

0 comments on commit 1fdcf8f

Please sign in to comment.