diff --git a/Wrappers/Python/cil/optimisation/utilities/sampler.py b/Wrappers/Python/cil/optimisation/utilities/sampler.py index a19a946de4..5631d511de 100644 --- a/Wrappers/Python/cil/optimisation/utilities/sampler.py +++ b/Wrappers/Python/cil/optimisation/utilities/sampler.py @@ -20,6 +20,255 @@ import math import time +class SamplerFromFunction(): + def __init__(self, num_indices,function, sampling_type='from_function', prob_weights=None): + """ + TODO: How should a user call this? + A class to select from a list of indices {0, 1, …, S-1} + The function next() outputs a single next index from the list {0,1,…,S-1} . Different orders are possible including with and without replacement. To be run again and again, depending on how many iterations. + + + Parameters + ---------- + num_indices: int + The sampler will select from a list of indices {0, 1, …, S-1} with S=num_indices. + + sampling_type:str + The sampling type used. Choose from "sequential", "custom_order", "herman_meyer", "staggered", "random_with_replacement" and "random_without_replacement". + + + function: TODO: + + + prob_weights: list of floats of length num_indices that sum to 1. + Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. + + + """ + self.sampling_type=sampling_type + self.num_indices=num_indices + self.function=function + self.prob_weights=prob_weights + if self.prob_weights is None: + self.prob_weights=[1/num_indices]*num_indices + self.iteration_number=-1 + + + + def next(self): + """ + + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + """ + + self.iteration_number+=1 + return (self.function(self.iteration_number)) + + def __next__(self): + """ + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + Allows the user to call next(sampler), to get the same result as sampler.next()""" + return (self.next()) + + def get_samples(self, num_samples=20): + """ + Function that takes an index, num_samples, and returns the first num_samples as a numpy array. + + num_samples: int, default=20 + The number of samples to return. + + Example + ------- + + >>> sampler=Sampler.randomWithReplacement(5) + >>> print(sampler.get_samples()) + [2 4 2 4 1 3 2 2 1 2 4 4 2 3 2 1 0 4 2 3] + """ + save_last_index = self.iteration_number + self.iteration_number = -1 + output = [self.next() for _ in range(num_samples)] + self.iteration_number = save_last_index + return (np.array(output)) + + +class SamplerFromOrder(): + + def __init__(self, num_indices, order, sampling_type, prob_weights=None): + + """ + TODO: How should a user call this? + A class to select from a list of indices {0, 1, …, S-1} + The function next() outputs a single next index from the list {0,1,…,S-1} . Different orders are possible including with and without replacement. To be run again and again, depending on how many iterations. + + + Parameters + ---------- + num_indices: int + The sampler will select from a list of indices {0, 1, …, S-1} with S=num_indices. + + sampling_type:str + The sampling type used. Choose from "sequential", "custom_order", "herman_meyer", "staggered", "random_with_replacement" and "random_without_replacement". + + order: list of indices + The list of indices the method selects from using next. + + + prob_weights: list of floats of length num_indices that sum to 1. + Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. + + + + """ + self.prob_weights=prob_weights + self.type = sampling_type + self.num_indices = num_indices + self.order = order + self.initial_order = self.order + + + self.last_index = len(order)-1 + + + + def next(self): + """ + + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + """ + + self.last_index = (self.last_index+1) % len(self.order) + return (self.order[self.last_index]) + + def __next__(self): + """ + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + Allows the user to call next(sampler), to get the same result as sampler.next()""" + return (self.next()) + + def get_samples(self, num_samples=20): + """ + Function that takes an index, num_samples, and returns the first num_samples as a numpy array. + + num_samples: int, default=20 + The number of samples to return. + + Example + ------- + + >>> sampler=Sampler.randomWithReplacement(5) + >>> print(sampler.get_samples()) + [2 4 2 4 1 3 2 2 1 2 4 4 2 3 2 1 0 4 2 3] + + """ + save_last_index = self.last_index + self.last_index = len(self.order)-1 + output = [self.next() for _ in range(num_samples)] + self.last_index = save_last_index + return (np.array(output)) + + +class SamplerRandom(): + + r""" + A class to select from a list of indices {0, 1, …, S-1} using numpy.random.choice with and without replacement. + The function next() outputs a single next index from the list {0,1,…,S-1} . To be run again and again, depending on how many iterations. + + + Parameters + ---------- + num_indices: int + The sampler will select from a list of indices {0, 1, …, S-1} with S=num_indices. + + sampling_type:str + The sampling type used. + + + replace= bool + If True, sample with replace, otherwise sample without replacement + + + prob: list of floats of length num_indices that sum to 1. + For random sampling with replacement, this is the probability for each index to be called by next. + + seed:int, default=None + Random seed for the methods that use a numpy random number generator. If set to None, the seed will be set using the current time. + + prob_weights: list of floats of length num_indices that sum to 1. + Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. + + + """ + def __init__(self, num_indices, replace, sampling_type, prob=None, seed=None): + """ + This method is the internal init for the sampler method. Most users should call the static methods e.g. Sampler.sequential or Sampler.staggered. + + """ + + self.replace=replace + self.prob=prob + if prob is None: + self.prob=[1/num_indices]*num_indices + if replace: + self.prob_weights=prob + else: + self.prob_weights=[1/num_indices]*num_indices + self.type = sampling_type + self.num_indices = num_indices + if seed is not None: + self.seed = seed + else: + self.seed = int(time.time()) + self.generator = np.random.RandomState(self.seed) + + + + + def next(self): + """ + + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + This function us used by samplers that select from a list of indices{0, 1, …, S-1}, with S=num_indices, randomly with and without replacement. + + """ + if self.replace: + return int(self.generator.choice(self.num_indices, 1, p=self.prob, replace=self.replace)) + else: + return int(self.generator.choice(self.num_indices, 1, p=self.prob, replace=self.replace)) + + + + def __next__(self): + """ + A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + Allows the user to call next(sampler), to get the same result as sampler.next()""" + return (self.next()) + + def get_samples(self, num_samples=20): + """ + Function that takes an index, num_samples, and returns the first num_samples as a numpy array. + + num_samples: int, default=20 + The number of samples to return. + + Example + ------- + + >>> sampler=Sampler.randomWithReplacement(5) + >>> print(sampler.get_samples()) + [2 4 2 4 1 3 2 2 1 2 4 4 2 3 2 1 0 4 2 3] + + """ + save_generator = self.generator + self.generator = np.random.RandomState(self.seed) + output = [self.next() for _ in range(num_samples)] + self.generator = save_generator + return (np.array(output)) class Sampler(): @@ -39,9 +288,6 @@ class Sampler(): order: list of indices The list of indices the method selects from using next. - shuffle= bool, default=False - If True, the drawing order changes every each `num_indices`, otherwise the same random order each time the data is sampled is used. - prob: list of floats of length num_indices that sum to 1. For random sampling with replacement, this is the probability for each index to be called by next. @@ -136,21 +382,23 @@ def sequential(num_indices): 0 """ order = list(range(num_indices)) - sampler = Sampler(num_indices, sampling_type='sequential', order=order, prob_weights=[1/num_indices]*num_indices) + sampler = SamplerFromOrder(num_indices, sampling_type='sequential', order=order, prob_weights=[1/num_indices]*num_indices) return sampler @staticmethod - def customOrder(customlist): + def customOrder(num_indices, customlist, prob_weights=None): #TODO: swap to underscores """ Function that outputs a sampler that outputs from a list, one entry at a time before cycling back to the beginning. customlist: list of indices The list that will be sampled from in order. + #TODO: + Example -------- - >>> sampler=Sampler.customOrder([1,4,6,7,8,9,11]) + >>> sampler=Sampler.customOrder(12,[1,4,6,7,8,9,11]) >>> print(sampler.get_samples(11)) >>> for _ in range(9): >>> print(sampler.next()) @@ -169,9 +417,15 @@ def customOrder(customlist): [1 4 6 7 8] """ - num_indices = len(customlist)#TODO: is this an issue - sampler = Sampler( - num_indices, sampling_type='custom_order', order=customlist, prob_weights=None)#TODO: + if prob_weights is None: + temp_list=[] + for i in range(num_indices): + temp_list.append(customlist.count(i)) + total=sum(temp_list) + prob_weights=[x/total for x in temp_list] + + sampler = SamplerFromOrder( + num_indices, sampling_type='custom_order', order=customlist, prob_weights=prob_weights) return sampler @staticmethod @@ -232,7 +486,7 @@ def _herman_meyer_order(n): return order order = _herman_meyer_order(num_indices) - sampler = Sampler( + sampler = SamplerFromOrder( num_indices, sampling_type='herman_meyer', order=order, prob_weights=[1/num_indices]*num_indices) return sampler @@ -279,7 +533,7 @@ def staggered(num_indices, offset): indices = list(range(num_indices)) order = [] [order.extend(indices[i::offset]) for i in range(offset)] - sampler = Sampler(num_indices, sampling_type='staggered', order=order, prob_weights=[1/num_indices]*num_indices) + sampler = SamplerFromOrder(num_indices, sampling_type='staggered', order=order, prob_weights=[1/num_indices]*num_indices) return sampler @staticmethod @@ -317,12 +571,12 @@ def randomWithReplacement(num_indices, prob=None, seed=None): if prob == None: prob = [1/num_indices] * num_indices - sampler = Sampler( - num_indices, sampling_type='random_with_replacement', prob=prob, seed=seed, prob_weights=prob) + sampler = SamplerRandom( + num_indices, sampling_type='random_with_replacement', replace=True, prob=prob, seed=seed) return sampler @staticmethod - def randomWithoutReplacement(num_indices, seed=None, shuffle=True): + def randomWithoutReplacement(num_indices, seed=None, prob=None): """ Function that takes a number of indices and returns a sampler which outputs from a list of indices {0, 1, …, S-1} with S=num_indices uniformly randomly without replacement. @@ -333,8 +587,6 @@ def randomWithoutReplacement(num_indices, seed=None, shuffle=True): seed:int, default=None Random seed for the random number generator. If set to None, the seed will be set using the current time. - shuffle:boolean, default=True - If True, the drawing order changes every each `num_indices`, otherwise the same random order each time the data is sampled is used. Example ------- @@ -342,105 +594,34 @@ def randomWithoutReplacement(num_indices, seed=None, shuffle=True): >>> print(sampler.get_samples(16)) [6 2 1 0 4 3 5 1 0 4 2 5 6 3 3 2] - Example - ------- - >>> sampler=Sampler.randomWithoutReplacement(7, seed=1, shuffle=False) - >>> print(sampler.get_samples(16)) - [6 2 1 0 4 3 5 6 2 1 0 4 3 5 6 2] - """ - order = list(range(num_indices)) - sampler = Sampler(num_indices, sampling_type='random_without_replacement', - order=order, shuffle=shuffle, seed=seed, prob_weights=[1/num_indices]*num_indices) - return sampler - - def __init__(self, num_indices, sampling_type, shuffle=False, order=None, prob=None, seed=None, prob_weights=None): """ - This method is the internal init for the sampler method. Most users should call the static methods e.g. Sampler.sequential or Sampler.staggered. - """ - self.prob_weights=prob_weights - self.type = sampling_type - self.num_indices = num_indices - if seed is not None: - self.seed = seed - else: - self.seed = int(time.time()) - self.generator = np.random.RandomState(self.seed) - self.order = order - if order is not None: - self.iterator = self._next_order - self.shuffle = shuffle - if self.type == 'random_without_replacement' and self.shuffle == False: - self.order = self.generator.permutation(self.order) - self.initial_order = self.order - self.prob = prob - if prob is not None: - self.iterator = self._next_prob - self.last_index = self.num_indices-1 - - def _next_order(self): - """ - The user should call sampler.next() or next(sampler) rather than use this function. - - A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. - - This function is used by samplers that sample without replacement. + sampler = SamplerRandom(num_indices, sampling_type='random_without_replacement', replace=False, seed=seed, prob=prob ) + return sampler + @staticmethod + def from_function(num_indices, function): """ - # print(self.last_index) - if self.shuffle == True and self.last_index == self.num_indices-1: - self.order = self.generator.permutation(self.order) - # print(self.order) - self.last_index = (self.last_index+1) % self.num_indices - return (self.order[self.last_index]) - - def _next_prob(self): - """ - The user should call sampler.next() or next(sampler) rather than use this function. + Function that takes a number of indices and returns a sampler which outputs from a list of indices {0, 1, …, S-1} with S=num_indices TODO: - A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. - - This function us used by samplers that select from a list of indices{0, 1, …, S-1}, with S=num_indices, randomly with replacement. - - """ - return int(self.generator.choice(self.num_indices, 1, p=self.prob)) - def next(self): - """ A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. """ + num_indices: int + The sampler will select from a list of indices {0, 1, …, S-1} with S=num_indices. - return (self.iterator()) + function: TODO: - def __next__(self): - """ - A function of the sampler that selects from a list of indices {0, 1, …, S-1}, with S=num_indices, the next sample according to the type of sampling. + + Example + ------- + TODO: - Allows the user to call next(sampler), to get the same result as sampler.next()""" - return (self.next()) - def get_samples(self, num_samples=20): """ - Function that takes an index, num_samples, and returns the first num_samples as a numpy array. - num_samples: int, default=20 - The number of samples to return. + sampler = SamplerFromFunction(num_indices, sampling_type='random_without_replacement', function=function ) + return sampler - Example - ------- - >>> sampler=Sampler.randomWithReplacement(5) - >>> print(sampler.get_samples()) - [2 4 2 4 1 3 2 2 1 2 4 4 2 3 2 1 0 4 2 3] - """ - save_generator = self.generator - save_last_index = self.last_index - self.last_index = self.num_indices-1 - save_order = self.order - self.order = self.initial_order - self.generator = np.random.RandomState(self.seed) - output = [self.next() for _ in range(num_samples)] - self.generator = save_generator - self.order = save_order - self.last_index = save_last_index - return (np.array(output)) + \ No newline at end of file diff --git a/Wrappers/Python/test/test_sampler.py b/Wrappers/Python/test/test_sampler.py index d751034d45..3de70afd05 100644 --- a/Wrappers/Python/test/test_sampler.py +++ b/Wrappers/Python/test/test_sampler.py @@ -37,55 +37,43 @@ def test_init(self): self.assertEqual(sampler.type, 'sequential') self.assertListEqual(sampler.order, list(range(10))) self.assertListEqual(sampler.initial_order, list(range(10))) - self.assertEqual(sampler.shuffle, False) - self.assertEqual(sampler.prob, None) self.assertEqual(sampler.last_index, 9) + self.assertListEqual(sampler.prob_weights, [1/10]*10) - sampler = Sampler.randomWithoutReplacement(7, shuffle=True) + sampler = Sampler.randomWithoutReplacement(7) self.assertEqual(sampler.num_indices, 7) self.assertEqual(sampler.type, 'random_without_replacement') - self.assertListEqual(sampler.order, list(range(7))) - self.assertListEqual(sampler.initial_order, list(range(7))) - self.assertEqual(sampler.shuffle, True) - self.assertEqual(sampler.prob, None) - self.assertEqual(sampler.last_index, 6) + self.assertEqual(sampler.prob, [1/7]*7) + self.assertListEqual(sampler.prob_weights, sampler.prob) - sampler = Sampler.randomWithoutReplacement(8, shuffle=False, seed=1) + sampler = Sampler.randomWithoutReplacement(8, seed=1) self.assertEqual(sampler.num_indices, 8) self.assertEqual(sampler.type, 'random_without_replacement') - self.assertEqual(sampler.shuffle, False) - self.assertEqual(sampler.prob, None) - self.assertEqual(sampler.last_index, 7) + self.assertEqual(sampler.prob, [1/8]*8) self.assertEqual(sampler.seed, 1) + self.assertListEqual(sampler.prob_weights, sampler.prob) sampler = Sampler.hermanMeyer(12) self.assertEqual(sampler.num_indices, 12) self.assertEqual(sampler.type, 'herman_meyer') - self.assertEqual(sampler.shuffle, False) - self.assertEqual(sampler.prob, None) self.assertEqual(sampler.last_index, 11) self.assertListEqual( sampler.order, [0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11]) self.assertListEqual(sampler.initial_order, [ 0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11]) + self.assertListEqual(sampler.prob_weights, [1/12] * 12) sampler = Sampler.randomWithReplacement(5) self.assertEqual(sampler.num_indices, 5) self.assertEqual(sampler.type, 'random_with_replacement') - self.assertEqual(sampler.order, None) - self.assertEqual(sampler.initial_order, None) - self.assertEqual(sampler.shuffle, False) self.assertListEqual(sampler.prob, [1/5] * 5) - self.assertEqual(sampler.last_index, 4) + self.assertListEqual(sampler.prob_weights, [1/5] * 5) sampler = Sampler.randomWithReplacement(4, [0.7, 0.1, 0.1, 0.1]) self.assertEqual(sampler.num_indices, 4) self.assertEqual(sampler.type, 'random_with_replacement') - self.assertEqual(sampler.order, None) - self.assertEqual(sampler.initial_order, None) - self.assertEqual(sampler.shuffle, False) self.assertListEqual(sampler.prob, [0.7, 0.1, 0.1, 0.1]) - self.assertEqual(sampler.last_index, 3) + self.assertListEqual(sampler.prob_weights, [0.7, 0.1, 0.1, 0.1]) sampler = Sampler.staggered(21, 4) self.assertEqual(sampler.num_indices, 21) @@ -94,76 +82,73 @@ def test_init(self): 0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3, 7, 11, 15, 19]) self.assertListEqual(sampler.initial_order, [ 0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3, 7, 11, 15, 19]) - self.assertEqual(sampler.shuffle, False) - self.assertEqual(sampler.prob, None) self.assertEqual(sampler.last_index, 20) + self.assertListEqual(sampler.prob_weights, [1/21] * 21) try: Sampler.staggered(22, 25) except ValueError: self.assertTrue(True) - sampler = Sampler.customOrder([1, 4, 6, 7, 8, 9, 11]) - self.assertEqual(sampler.num_indices, 7) + sampler = Sampler.customOrder(12, [1, 4, 6, 7, 8, 9, 11]) + self.assertEqual(sampler.num_indices, 12) self.assertEqual(sampler.type, 'custom_order') self.assertListEqual(sampler.order, [1, 4, 6, 7, 8, 9, 11]) self.assertListEqual(sampler.initial_order, [1, 4, 6, 7, 8, 9, 11]) - self.assertEqual(sampler.shuffle, False) - self.assertEqual(sampler.prob, None) self.assertEqual(sampler.last_index, 6) + self.assertListEqual(sampler.prob_weights, [ + 0, 1/7, 0, 0, 1/7, 0, 1/7, 1/7, 1/7, 1/7, 0, 1/7]) - - def test_sequential_iterator_and_get_samples(self): - - #Test the squential sampler + + # Test the squential sampler sampler = Sampler.sequential(10) for i in range(25): self.assertEqual(next(sampler), i % 10) - if i%5==0: # Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual(sampler.get_samples(), np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + sampler = Sampler.sequential(10) for i in range(25): - self.assertEqual(sampler.next(), i % 10) # Repeat the test for .next() - if i%5==0: - self.assertNumpyArrayEqual(sampler.get_samples(), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - - def test_random_without_replacement_iterator_and_get_samples(self): - #Test the random without replacement sampler - sampler = Sampler.randomWithoutReplacement(7, shuffle=True, seed=1) - order = [6, 2, 1, 0, 4, 3, 5, 1, 0, 4, 2, 5, - 6, 3, 3, 2, 1, 4, 0, 5, 6, 2, 6, 3, 4] + # Repeat the test for .next() + self.assertEqual(sampler.next(), i % 10) + if i % 5 == 0: + self.assertNumpyArrayEqual(sampler.get_samples(), np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + + def test_random_without_replacement_iterator_and_get_samples(self): + # Test the random without replacement sampler + sampler = Sampler.randomWithoutReplacement(7, seed=1) + order = [2, 5, 0, 2, 1, 0, 1, 2, 2, 3, 2, 4, + 1, 6, 0, 4, 2, 3, 0, 1, 5, 6, 2, 4, 6] for i in range(25): self.assertEqual(next(sampler), order[i]) - if i%4==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(6), np.array(order[:6])) - - #Repeat the test for shuffle=False - sampler = Sampler.randomWithoutReplacement(8, shuffle=False, seed=1) - order = [7, 2, 1, 6, 0, 4, 3, 5] - for i in range(25): - self.assertEqual(sampler.next(), order[i % 8]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(5), np.array(order[:5])) + if i % 4 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(6), np.array(order[:6])) - def test_herman_meyer_iterator_and_get_samples(self): - #Test the Herman Meyer sampler + def test_herman_meyer_iterator_and_get_samples(self): + # Test the Herman Meyer sampler sampler = Sampler.hermanMeyer(12) - order = [0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11, 0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11] + order = [0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, + 11, 0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11] for i in range(25): self.assertEqual(sampler.next(), order[i % 12]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(14), np.array(order[:14])) + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(14), np.array(order[:14])) - def test_random_with_replacement_iterator_and_get_samples(self): - #Test the Random with replacement sampler + def test_random_with_replacement_iterator_and_get_samples(self): + # Test the Random with replacement sampler sampler = Sampler.randomWithReplacement(5, seed=5) - order=[1, 4, 1, 4, 2, 3, 3, 2, 1, 0, 0, 3, 2, 0, 4, 1, 2, 1, 3, 2, 2, 1, 1, 1, 1] + order = [1, 4, 1, 4, 2, 3, 3, 2, 1, 0, 0, 3, + 2, 0, 4, 1, 2, 1, 3, 2, 2, 1, 1, 1, 1] for i in range(25): self.assertEqual(next(sampler), order[i]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(14), np.array(order[:14])) + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(14), np.array(order[:14])) sampler = Sampler.randomWithReplacement( 4, [0.7, 0.1, 0.1, 0.1], seed=5) @@ -171,24 +156,28 @@ def test_random_with_replacement_iterator_and_get_samples(self): 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] for i in range(25): self.assertEqual(sampler.next(), order[i]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(14), np.array(order[:14])) + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(14), np.array(order[:14])) - def test_staggered_iterator_and_get_samples(self): - #Test the staggered sampler + def test_staggered_iterator_and_get_samples(self): + # Test the staggered sampler sampler = Sampler.staggered(21, 4) order = [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3, 7, 11, 15, 19] for i in range(25): self.assertEqual(next(sampler), order[i % 21]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(10), np.array(order[:10])) - - def test_custom_order_iterator_and_get_samples(self): - #Test the custom order sampler - sampler = Sampler.customOrder([1, 4, 6, 7, 8, 9, 11]) - order = [1, 4, 6, 7, 8, 9, 11,1, 4, 6, 7, 8, 9, 11,1, 4, 6, 7, 8, 9, 11,1, 4, 6, 7, 8, 9, 11] + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(10), np.array(order[:10])) + + def test_custom_order_iterator_and_get_samples(self): + # Test the custom order sampler + sampler = Sampler.customOrder(12, [1, 4, 6, 7, 8, 9, 11]) + order = [1, 4, 6, 7, 8, 9, 11, 1, 4, 6, 7, 8, 9, + 11, 1, 4, 6, 7, 8, 9, 11, 1, 4, 6, 7, 8, 9, 11] for i in range(25): self.assertEqual(sampler.next(), order[i % 7]) - if i%5==0:# Check both that get samples works and doesn't interrupt the sampler - self.assertNumpyArrayEqual(sampler.get_samples(10), np.array(order[:10])) \ No newline at end of file + if i % 5 == 0: # Check both that get samples works and doesn't interrupt the sampler + self.assertNumpyArrayEqual( + sampler.get_samples(10), np.array(order[:10]))