Skip to content

Commit

Permalink
Add ability to save and restore state
Browse files Browse the repository at this point in the history
  • Loading branch information
dhadka committed Sep 22, 2024
1 parent 36541ec commit 94209d5
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 112 deletions.
18 changes: 18 additions & 0 deletions examples/resumable_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from platypus import NSGAII, DTLZ2, save_state, load_state

try:
import jsonpickle
except ImportError:
print("Please install jsonpickle to run this example!")

problem = DTLZ2()

algorithm = NSGAII(problem)
algorithm.run(5000)

# Save the algorithm state to a file.
save_state("state.json", algorithm, json=True, indent=4)

# Load the state and continue running.
algorithm = load_state("state.json", json=True)
algorithm.run(5000)
3 changes: 2 additions & 1 deletion platypus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@

from .weights import chebyshev, pbi, random_weights, normal_boundary_weights

from .io import save_objectives, load_objectives, save_json, load_json, dump, load
from .io import save_objectives, load_objectives, save_json, load_json, \
dump, load, save_state, load_state

from .deprecated import default_variator, default_mutator, nondominated_cmp

Expand Down
78 changes: 78 additions & 0 deletions platypus/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import os
import json
import pickle
import random
from .core import Algorithm, Archive, FixedLengthArray, Problem, Solution

def load_objectives(file, problem=None):
Expand Down Expand Up @@ -185,3 +187,79 @@ def load_json(file, problem=None):
"""
with open(os.fspath(file), "r") as f:
return load(f, problem=problem)

def save_state(file, algorithm, json=False, **kwargs):
"""Capture and save the algorithm state to a file.
Allows saving the algorithm and RNG state to a file, which can be later
restored using :meth:`load_state`. This is useful to either:
1. Inspect or record the configuration of an algorithm; or
2. Allow resuming runs from the state file.
This feature is experimental. Please take note of the following:
1. Platypus uses Python's :code:`random` library, which uses a global RNG
state. Reproducibility is not guaranteed when running multithreaded
or async programs.
2. State files are not guaranteed to be compatible across versions,
including minor or patch versions.
3. Internally, :mod:`pickle` and :mod:`jsonpickle` (for JSON output) are
used. Refer to each for warnings related to potential security
concerns when dealing with untrusted inputs.
Setting :code:`json=True` will produce human-readable output in a JSON
format. This requires the optional :code:`jsonpickle` dependency.
Parameters
----------
file: str, bytes, or os.PathLike
The file.
algorithm: Algorithm
The algorithm to capture.
json: bool
If :code:`False`, produces a binary-encoded state file.
If :code:`True`, produces a JSON file.
kwargs
Additional arguments passed to the pickle library.
"""
state = {"random": random.getstate(),
"algorithm": algorithm}

if json:
import jsonpickle
with open(os.fspath(file), "w") as f:
f.write(jsonpickle.dumps(state, **kwargs))
else:
with open(os.fspath(file), "wb") as f:
f.write(pickle.dumps(state, **kwargs))

def load_state(file, json=False, update_rng=True, **kwargs):
"""Restores the algorithm from a state file.
Refer to :meth:`save_state` for details and warnings when working with
state files.
Parameters
----------
file: str, bytes, or os.PathLike
The file.
json: bool
If :code:`False`, reads a binary-encoded state file.
If :code:`True`, reads a JSON file.
update_rng: bool
If :code:`True`, updates the RNG state. Must be set for reproducible
results.
kwargs
Additional arguments passed to the pickle library.
"""
if json:
import jsonpickle
with open(os.fspath(file), "r") as f:
state = jsonpickle.loads(f.read())
else:
with open(os.fspath(file), "rb") as f:
state = pickle.loads(f.read())

if update_rng:
random.setstate(state["random"])

return state["algorithm"]
4 changes: 2 additions & 2 deletions platypus/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self):

def generate(self, problem):
solution = Solution(problem)
solution.variables = [x.rand() for x in problem.types]
solution.variables[:] = [x.rand() for x in problem.types]
return solution

class InjectedPopulation(Generator):
Expand Down Expand Up @@ -65,7 +65,7 @@ def generate(self, problem):
else:
# Otherwise generate a random solution
solution = Solution(problem)
solution.variables = [x.rand() for x in problem.types]
solution.variables[:] = [x.rand() for x in problem.types]
return solution

class TournamentSelector(Selector):
Expand Down
21 changes: 21 additions & 0 deletions platypus/tests/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import math
from ..core import Problem, Solution, FixedLengthArray

class SolutionMixin:

def createSolution(self, *args):
problem = Problem(0, len(args))
solution = Solution(problem)
solution.objectives[:] = [float(x) for x in args]
return solution

def assertSimilar(self, a, b, epsilon=0.0000001):
if isinstance(a, Solution) and isinstance(b, Solution):
self.assertSimilar(a.variables, b.variables)
self.assertSimilar(a.objectives, b.objectives)
self.assertSimilar(a.constraints, b.constraints)
elif isinstance(a, (list, FixedLengthArray)) and isinstance(b, (list, FixedLengthArray)):
for (x, y) in zip(a, b):
self.assertSimilar(x, y, epsilon)
else:
self.assertLessEqual(math.fabs(b - a), epsilon)
86 changes: 40 additions & 46 deletions platypus/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,16 @@
import copy
import random
import unittest
from ..core import Constraint, Problem, Solution, ParetoDominance, Archive, \
from ._utils import SolutionMixin
from ..core import Constraint, ParetoDominance, Archive, EpsilonBoxArchive, \
nondominated_sort, nondominated_truncate, nondominated_prune, \
POSITIVE_INFINITY, nondominated_split, truncate_fitness, normalize, \
EpsilonBoxArchive
POSITIVE_INFINITY, nondominated_split, truncate_fitness, normalize
from ..errors import PlatypusError

def createSolution(*args):
problem = Problem(0, len(args))
solution = Solution(problem)
solution.objectives[:] = [float(x) for x in args]
return solution

class TestSolution(unittest.TestCase):
class TestSolution(SolutionMixin, unittest.TestCase):

def test_deepcopy(self):
orig = createSolution(4, 5)
orig = self.createSolution(4, 5)
orig.constraint_violation = 2
orig.evaluated = True

Expand Down Expand Up @@ -106,13 +100,13 @@ def test_invalid_empty_string(self):
with self.assertRaises(PlatypusError):
Constraint("")

class TestParetoDominance(unittest.TestCase):
class TestParetoDominance(SolutionMixin, unittest.TestCase):

def test_dominance(self):
dominance = ParetoDominance()
s1 = createSolution(0.0, 0.0)
s2 = createSolution(1.0, 1.0)
s3 = createSolution(0.0, 1.0)
s1 = self.createSolution(0.0, 0.0)
s2 = self.createSolution(1.0, 1.0)
s3 = self.createSolution(0.0, 1.0)

self.assertEqual(-1, dominance.compare(s1, s2))
self.assertEqual(1, dominance.compare(s2, s1))
Expand All @@ -121,9 +115,9 @@ def test_dominance(self):

def test_nondominance(self):
dominance = ParetoDominance()
s1 = createSolution(0.0, 1.0)
s2 = createSolution(0.5, 0.5)
s3 = createSolution(1.0, 0.0)
s1 = self.createSolution(0.0, 1.0)
s2 = self.createSolution(0.5, 0.5)
s3 = self.createSolution(1.0, 0.0)

self.assertEqual(0, dominance.compare(s1, s2))
self.assertEqual(0, dominance.compare(s2, s1))
Expand All @@ -132,12 +126,12 @@ def test_nondominance(self):
self.assertEqual(0, dominance.compare(s1, s3))
self.assertEqual(0, dominance.compare(s3, s1))

class TestArchive(unittest.TestCase):
class TestArchive(SolutionMixin, unittest.TestCase):

def test_dominance(self):
s1 = createSolution(1.0, 1.0)
s2 = createSolution(0.0, 0.0)
s3 = createSolution(0.0, 1.0)
s1 = self.createSolution(1.0, 1.0)
s2 = self.createSolution(0.0, 0.0)
s3 = self.createSolution(0.0, 1.0)

archive = Archive(ParetoDominance())
archive += [s1, s2, s3]
Expand All @@ -146,23 +140,23 @@ def test_dominance(self):
self.assertEqual(s2, archive[0])

def test_nondominance(self):
s1 = createSolution(0.0, 1.0)
s2 = createSolution(0.5, 0.5)
s3 = createSolution(1.0, 0.0)
s1 = self.createSolution(0.0, 1.0)
s2 = self.createSolution(0.5, 0.5)
s3 = self.createSolution(1.0, 0.0)

archive = Archive(ParetoDominance())
archive += [s1, s2, s3]

self.assertEqual(3, len(archive))

class TestNondominatedSort(unittest.TestCase):
class TestNondominatedSort(SolutionMixin, unittest.TestCase):

def setUp(self):
self.s1 = createSolution(0.0, 1.0)
self.s2 = createSolution(0.5, 0.5)
self.s3 = createSolution(1.0, 0.0)
self.s4 = createSolution(0.75, 0.75)
self.s5 = createSolution(1.0, 1.0)
self.s1 = self.createSolution(0.0, 1.0)
self.s2 = self.createSolution(0.5, 0.5)
self.s3 = self.createSolution(1.0, 0.0)
self.s4 = self.createSolution(0.75, 0.75)
self.s5 = self.createSolution(1.0, 1.0)

self.population = [self.s1, self.s2, self.s3, self.s4, self.s5]
random.shuffle(self.population)
Expand Down Expand Up @@ -280,12 +274,12 @@ def test_truncate_fitness_min(self):
self.assertIn(self.s5, result)
self.assertIn(self.s3, result)

class TestNormalize(unittest.TestCase):
class TestNormalize(SolutionMixin, unittest.TestCase):

def test_normalize(self):
s1 = createSolution(0, 2)
s2 = createSolution(2, 3)
s3 = createSolution(1, 1)
s1 = self.createSolution(0, 2)
s2 = self.createSolution(2, 3)
s3 = self.createSolution(1, 1)
solutions = [s1, s2, s3]

normalize(solutions)
Expand All @@ -294,19 +288,19 @@ def test_normalize(self):
self.assertEqual([1.0, 1.0], s2.normalized_objectives)
self.assertEqual([0.5, 0.0], s3.normalized_objectives)

class TestEpsilonBoxArchive(unittest.TestCase):
class TestEpsilonBoxArchive(SolutionMixin, unittest.TestCase):

def test_improvements(self):
s1 = createSolution(0.25, 0.25) # Improvement 1 - First solution always counted as improvement
s2 = createSolution(0.10, 0.10) # Improvement 2 - Dominates prior solution and in new epsilon-box
s3 = createSolution(0.24, 0.24)
s4 = createSolution(0.09, 0.50) # Improvement 3 - Non-dominated to all existing solutions
s5 = createSolution(0.50, 0.50)
s6 = createSolution(0.05, 0.05) # Improvement 4 - Dominates prior solution and in new epsilon-box
s7 = createSolution(0.04, 0.04)
s8 = createSolution(0.02, 0.02)
s9 = createSolution(0.00, 0.00)
s10 = createSolution(-0.01, -0.01) # Improvement 5 - Dominates prior solution and in new epsilon-box
s1 = self.createSolution(0.25, 0.25) # Improvement 1 - First solution always counted as improvement
s2 = self.createSolution(0.10, 0.10) # Improvement 2 - Dominates prior solution and in new epsilon-box
s3 = self.createSolution(0.24, 0.24)
s4 = self.createSolution(0.09, 0.50) # Improvement 3 - Non-dominated to all existing solutions
s5 = self.createSolution(0.50, 0.50)
s6 = self.createSolution(0.05, 0.05) # Improvement 4 - Dominates prior solution and in new epsilon-box
s7 = self.createSolution(0.04, 0.04)
s8 = self.createSolution(0.02, 0.02)
s9 = self.createSolution(0.00, 0.00)
s10 = self.createSolution(-0.01, -0.01) # Improvement 5 - Dominates prior solution and in new epsilon-box

solutions = [s1, s2, s3, s4, s5, s6, s7, s8, s9, s10]
expectedImprovements = [1, 2, 2, 3, 3, 4, 4, 4, 4, 5]
Expand Down
6 changes: 3 additions & 3 deletions platypus/tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# You should have received a copy of the GNU General Public License
# along with Platypus. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ._utils import SolutionMixin
from ..distance import euclidean_dist, manhattan_dist, DistanceMatrix
from .test_core import createSolution


class TestDistances(unittest.TestCase):
Expand All @@ -33,10 +33,10 @@ def test_manhattan(self):
self.assertAlmostEqual(2.0, manhattan_dist([0, 0], [1, 1]), delta=0.001)
self.assertAlmostEqual(2.0, manhattan_dist([1, 1], [0, 0]), delta=0.001)

class TestDistanceMatrix(unittest.TestCase):
class TestDistanceMatrix(SolutionMixin, unittest.TestCase):

def test(self):
solutions = [createSolution(0, 1), createSolution(0.5, 0.5), createSolution(0.75, 0.25), createSolution(1, 0)]
solutions = [self.createSolution(0, 1), self.createSolution(0.5, 0.5), self.createSolution(0.75, 0.25), self.createSolution(1, 0)]
matrix = DistanceMatrix(solutions)

self.assertAlmostEqual(0.353, matrix[1, 2], delta=0.001)
Expand Down
18 changes: 10 additions & 8 deletions platypus/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with Platypus. If not, see <http://www.gnu.org/licenses/>.
import unittest
from abc import ABCMeta, abstractmethod
from .test_core import createSolution
from ._utils import SolutionMixin
from ..filters import unique, group, truncate, matches, objectives_key, \
objective_value_at_index

Expand All @@ -31,22 +31,24 @@ def generator(*args):
def view(*args):
return {x: x for x in args}.keys()

class TestKeys(unittest.TestCase):
class TestKeys(SolutionMixin, unittest.TestCase):

def test_objectives(self):
s = createSolution(0.0, 1.0)
s = self.createSolution(0.0, 1.0)
self.assertEqual((0.0, 1.0), objectives_key(s))

def test_objective_value_at_index(self):
s = createSolution(0.0, 1.0)
s = self.createSolution(0.0, 1.0)
self.assertEqual(0.0, objective_value_at_index(0)(s))
self.assertEqual(1.0, objective_value_at_index(1)(s))

class FilterTestCase(unittest.TestCase, metaclass=ABCMeta):
class FilterTestCase(SolutionMixin, unittest.TestCase, metaclass=ABCMeta):

s1 = createSolution(0.0, 1.0)
s2 = createSolution(1.0, 0.0)
s3 = createSolution(0.0, 1.0)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.s1 = self.createSolution(0.0, 1.0)
self.s2 = self.createSolution(1.0, 0.0)
self.s3 = self.createSolution(0.0, 1.0)

@abstractmethod
def filter(self, solutions):
Expand Down
Loading

0 comments on commit 94209d5

Please sign in to comment.