From 1fdcf8f3583e62eef36c29239ffd4cd5d8cd58ab Mon Sep 17 00:00:00 2001 From: Massimiliano Incudini Date: Tue, 30 Jan 2024 11:43:24 +0100 Subject: [PATCH] Make parallel tempering parallel again! --- quepistasis/parallel_tempering.py | 111 ++++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 27 deletions(-) diff --git a/quepistasis/parallel_tempering.py b/quepistasis/parallel_tempering.py index a82cbdeed..5efa86935 100644 --- a/quepistasis/parallel_tempering.py +++ b/quepistasis/parallel_tempering.py @@ -1,6 +1,8 @@ import numpy as np from abc import ABC, abstractmethod from utils import Trace +from joblib import Parallel, delayed +import random class ParallelTempering(ABC): @@ -83,6 +85,28 @@ def acceptance_probability(self, old_prob, new_prob, beta): """ return min(1, np.exp(beta * (old_prob - new_prob))) + def initialize_chain(self, i): + self.chains[i, 0] = self.initial_probability() + self.log_likes[i, 0] = self.log_like(self.chains[i, 0]) + self.log_prior(self.chains[i, 0]) + + def step_chain(self, step, i): + current_chain = self.chains[i, step - 1] + current_log_prob = self.log_like(current_chain) + self.log_prior(current_chain) + + # Perform Metropolis-Hastings proposal + proposed_chain = self.metropolis_hastings_proposal(current_chain, self.betas[i]) + + # Calculate log likelihood and log prior for proposed chain + proposed_log_prob = self.log_like(proposed_chain) + self.log_prior(proposed_chain) + + # Accept or reject the proposed chain + if np.log(np.random.uniform()) < self.acceptance_probability(current_log_prob, proposed_log_prob, self.betas[i]): + self.chains[i, step] = proposed_chain + self.log_likes[i, step] = proposed_log_prob + else: + self.chains[i, step] = current_chain + self.log_likes[i, step] = current_log_prob + def run(self): """ Run the Parallel Tempering MCMC algorithm. @@ -95,27 +119,44 @@ def run(self): self.log_likes = np.zeros((self.num_chains, self.num_steps)) for i in range(self.num_chains): - self.chains[i, 0] = self.initial_probability() - self.log_likes[i, 0] = self.log_like(self.chains[i, 0]) + self.log_prior(self.chains[i, 0]) + self.initialize_chain(i) for step in range(1, self.num_steps): + for i in range(self.num_chains): - current_chain = self.chains[i, step - 1] - current_log_prob = self.log_like(current_chain) + self.log_prior(current_chain) + self.step_chain(step, i) + + # Exchange states between adjacent chains + for i in range(self.num_chains - 1): + diff_prob = self.betas[i + 1] * (self.log_likes[i, step] - self.log_likes[i + 1, step]) + if np.log(np.random.uniform()) < diff_prob: + # Swap states between chains i and i+1 + self.chains[i, step], self.chains[i + 1, step] = self.chains[i + 1, step], self.chains[i, step] + self.log_likes[i, step], self.log_likes[i + 1, step] = self.log_likes[i + 1, step], self.log_likes[i, step] + + return self.chains, self.log_likes - # Perform Metropolis-Hastings proposal - proposed_chain = self.metropolis_hastings_proposal(current_chain, self.betas[i]) + def run_parallel(self): + """ + Run the Parallel Tempering MCMC algorithm in parallel. + + Returns: + - tuple: Chains and log likelihoods. + """ + + ndim = len(self.initial_probability()) + self.chains = np.zeros((self.num_chains, self.num_steps, ndim)) + self.log_likes = np.zeros((self.num_chains, self.num_steps)) + + # Initialize chains and log likelihoods for each chain + Parallel(n_jobs=self.num_chains, backend="threading")(delayed(self.initialize_chain)(i) for i in range(self.num_chains)) + print(self.chains) - # Calculate log likelihood and log prior for proposed chain - proposed_log_prob = self.log_like(proposed_chain) + self.log_prior(proposed_chain) + for step in range(1, self.num_steps): - # Accept or reject the proposed chain - if np.log(np.random.uniform()) < self.acceptance_probability(current_log_prob, proposed_log_prob, self.betas[i]): - self.chains[i, step] = proposed_chain - self.log_likes[i, step] = proposed_log_prob - else: - self.chains[i, step] = current_chain - self.log_likes[i, step] = current_log_prob + # Run Metropolis-Hastings proposals and updates in parallel + Parallel(n_jobs=self.num_chains, backend="threading")(delayed(self.step_chain)(step, i) for i in range(self.num_chains)) + print(self.chains) # Exchange states between adjacent chains for i in range(self.num_chains - 1): @@ -127,8 +168,6 @@ def run(self): return self.chains, self.log_likes - - class IsingParallelTempering(ParallelTempering): def __init__(self, h, J, num_chains, num_steps, betas): @@ -206,6 +245,16 @@ def get_solution(self): return best_configuration, best_log_likelihood, best_energy +def sparsify_solution(spins, MAX_UP_SPIN=10): + # delete snps if they are too many + indexes_to_one = [i for i, value in enumerate(spins) if value == 1] + if len(indexes_to_one) > MAX_UP_SPIN: + random.shuffle(indexes_to_one) + for idx in indexes_to_one[MAX_UP_SPIN:]: + spins[idx] = -1 + return spins + + def run_parallel_tempering(h, J, num_chains, num_steps, save_path): # trace the input except for the qubo formulation itself which is too large @@ -215,25 +264,33 @@ def run_parallel_tempering(h, J, num_chains, num_steps, save_path): trace.add('input', None, 'save_path', save_path) betas = np.geomspace(1, 1e-2, num_chains) ising_pt = IsingParallelTempering(h, J, num_chains, num_steps, betas) - ising_pt.run() + ising_pt.run_parallel() spins, loglike, energy = ising_pt.get_solution() - trace.add('solution', None, 'spins', spins) - trace.add('solution', None, 'loglike', loglike) - trace.add('solution', None, 'energy', energy) - return spins.tolist() - -def test_parallel_tempering(): + spins = [int(i) for i in spins.tolist()] + trace.add('solution', 'original', 'spins', spins) + trace.add('solution', 'original', 'loglike', loglike) + trace.add('solution', 'original', 'energy', energy) + spins = sparsify_solution(spins, 10) + trace.add('solution', 'sparsified', 'spins', spins) + trace.add('solution', 'sparsified', 'loglike', None) + trace.add('solution', 'sparsified', 'energy', IsingParallelTempering.ising_energy(spins, h, J)) + return spins + +def test_parallel_tempering(parallel=False): """ Run a basic Ising model to check if the code works correctly. """ h = np.array([5, 10, -20]) J = {(0, 1): 1, (1, 2): 2} - num_chains = 4 - num_steps = 1000 + num_chains = 2 + num_steps = 3 betas = np.geomspace(1, 1e-2, num_chains) ising_pt = IsingParallelTempering(h, J, num_chains, num_steps, betas) - ising_pt.run() + if parallel: + ising_pt.run_parallel() + else: + ising_pt.run() spins, loglike, energy = ising_pt.get_solution() print(f"{spins=} {loglike=} {energy=}")