diff --git a/examples/resumable_run.py b/examples/resumable_run.py
new file mode 100644
index 00000000..1c07f3d9
--- /dev/null
+++ b/examples/resumable_run.py
@@ -0,0 +1,18 @@
+from platypus import NSGAII, DTLZ2, save_state, load_state
+
+try:
+ import jsonpickle
+except ImportError:
+ print("Please install jsonpickle to run this example!")
+
+problem = DTLZ2()
+
+algorithm = NSGAII(problem)
+algorithm.run(5000)
+
+# Save the algorithm state to a file.
+save_state("state.json", algorithm, json=True, indent=4)
+
+# Load the state and continue running.
+algorithm = load_state("state.json", json=True)
+algorithm.run(5000)
diff --git a/platypus/__init__.py b/platypus/__init__.py
index 1d026472..0a234c8b 100644
--- a/platypus/__init__.py
+++ b/platypus/__init__.py
@@ -65,7 +65,8 @@
from .weights import chebyshev, pbi, random_weights, normal_boundary_weights
-from .io import save_objectives, load_objectives, save_json, load_json, dump, load
+from .io import save_objectives, load_objectives, save_json, load_json, \
+ dump, load, save_state, load_state
from .deprecated import default_variator, default_mutator, nondominated_cmp
diff --git a/platypus/io.py b/platypus/io.py
index 9e4520f2..7cdff24b 100644
--- a/platypus/io.py
+++ b/platypus/io.py
@@ -19,6 +19,8 @@
import os
import json
+import pickle
+import random
from .core import Algorithm, Archive, FixedLengthArray, Problem, Solution
def load_objectives(file, problem=None):
@@ -185,3 +187,79 @@ def load_json(file, problem=None):
"""
with open(os.fspath(file), "r") as f:
return load(f, problem=problem)
+
+def save_state(file, algorithm, json=False, **kwargs):
+ """Capture and save the algorithm state to a file.
+
+ Allows saving the algorithm and RNG state to a file, which can be later
+ restored using :meth:`load_state`. This is useful to either:
+ 1. Inspect or record the configuration of an algorithm; or
+ 2. Allow resuming runs from the state file.
+
+ This feature is experimental. Please take note of the following:
+ 1. Platypus uses Python's :code:`random` library, which uses a global RNG
+ state. Reproducibility is not guaranteed when running multithreaded
+ or async programs.
+ 2. State files are not guaranteed to be compatible across versions,
+ including minor or patch versions.
+ 3. Internally, :mod:`pickle` and :mod:`jsonpickle` (for JSON output) are
+ used. Refer to each for warnings related to potential security
+ concerns when dealing with untrusted inputs.
+
+ Setting :code:`json=True` will produce human-readable output in a JSON
+ format. This requires the optional :code:`jsonpickle` dependency.
+
+ Parameters
+ ----------
+ file: str, bytes, or os.PathLike
+ The file.
+ algorithm: Algorithm
+ The algorithm to capture.
+ json: bool
+ If :code:`False`, produces a binary-encoded state file.
+ If :code:`True`, produces a JSON file.
+ kwargs
+ Additional arguments passed to the pickle library.
+ """
+ state = {"random": random.getstate(),
+ "algorithm": algorithm}
+
+ if json:
+ import jsonpickle
+ with open(os.fspath(file), "w") as f:
+ f.write(jsonpickle.dumps(state, **kwargs))
+ else:
+ with open(os.fspath(file), "wb") as f:
+ f.write(pickle.dumps(state, **kwargs))
+
+def load_state(file, json=False, update_rng=True, **kwargs):
+ """Restores the algorithm from a state file.
+
+ Refer to :meth:`save_state` for details and warnings when working with
+ state files.
+
+ Parameters
+ ----------
+ file: str, bytes, or os.PathLike
+ The file.
+ json: bool
+ If :code:`False`, reads a binary-encoded state file.
+ If :code:`True`, reads a JSON file.
+ update_rng: bool
+ If :code:`True`, updates the RNG state. Must be set for reproducible
+ results.
+ kwargs
+ Additional arguments passed to the pickle library.
+ """
+ if json:
+ import jsonpickle
+ with open(os.fspath(file), "r") as f:
+ state = jsonpickle.loads(f.read())
+ else:
+ with open(os.fspath(file), "rb") as f:
+ state = pickle.loads(f.read())
+
+ if update_rng:
+ random.setstate(state["random"])
+
+ return state["algorithm"]
diff --git a/platypus/operators.py b/platypus/operators.py
index 918fe5b4..76d07e48 100644
--- a/platypus/operators.py
+++ b/platypus/operators.py
@@ -36,7 +36,7 @@ def __init__(self):
def generate(self, problem):
solution = Solution(problem)
- solution.variables = [x.rand() for x in problem.types]
+ solution.variables[:] = [x.rand() for x in problem.types]
return solution
class InjectedPopulation(Generator):
@@ -65,7 +65,7 @@ def generate(self, problem):
else:
# Otherwise generate a random solution
solution = Solution(problem)
- solution.variables = [x.rand() for x in problem.types]
+ solution.variables[:] = [x.rand() for x in problem.types]
return solution
class TournamentSelector(Selector):
diff --git a/platypus/tests/_utils.py b/platypus/tests/_utils.py
new file mode 100644
index 00000000..46114192
--- /dev/null
+++ b/platypus/tests/_utils.py
@@ -0,0 +1,21 @@
+import math
+from ..core import Problem, Solution, FixedLengthArray
+
+class SolutionMixin:
+
+ def createSolution(self, *args):
+ problem = Problem(0, len(args))
+ solution = Solution(problem)
+ solution.objectives[:] = [float(x) for x in args]
+ return solution
+
+ def assertSimilar(self, a, b, epsilon=0.0000001):
+ if isinstance(a, Solution) and isinstance(b, Solution):
+ self.assertSimilar(a.variables, b.variables)
+ self.assertSimilar(a.objectives, b.objectives)
+ self.assertSimilar(a.constraints, b.constraints)
+ elif isinstance(a, (list, FixedLengthArray)) and isinstance(b, (list, FixedLengthArray)):
+ for (x, y) in zip(a, b):
+ self.assertSimilar(x, y, epsilon)
+ else:
+ self.assertLessEqual(math.fabs(b - a), epsilon)
diff --git a/platypus/tests/test_core.py b/platypus/tests/test_core.py
index fdc16ffb..dfc302de 100644
--- a/platypus/tests/test_core.py
+++ b/platypus/tests/test_core.py
@@ -19,22 +19,16 @@
import copy
import random
import unittest
-from ..core import Constraint, Problem, Solution, ParetoDominance, Archive, \
+from ._utils import SolutionMixin
+from ..core import Constraint, ParetoDominance, Archive, EpsilonBoxArchive, \
nondominated_sort, nondominated_truncate, nondominated_prune, \
- POSITIVE_INFINITY, nondominated_split, truncate_fitness, normalize, \
- EpsilonBoxArchive
+ POSITIVE_INFINITY, nondominated_split, truncate_fitness, normalize
from ..errors import PlatypusError
-def createSolution(*args):
- problem = Problem(0, len(args))
- solution = Solution(problem)
- solution.objectives[:] = [float(x) for x in args]
- return solution
-
-class TestSolution(unittest.TestCase):
+class TestSolution(SolutionMixin, unittest.TestCase):
def test_deepcopy(self):
- orig = createSolution(4, 5)
+ orig = self.createSolution(4, 5)
orig.constraint_violation = 2
orig.evaluated = True
@@ -106,13 +100,13 @@ def test_invalid_empty_string(self):
with self.assertRaises(PlatypusError):
Constraint("")
-class TestParetoDominance(unittest.TestCase):
+class TestParetoDominance(SolutionMixin, unittest.TestCase):
def test_dominance(self):
dominance = ParetoDominance()
- s1 = createSolution(0.0, 0.0)
- s2 = createSolution(1.0, 1.0)
- s3 = createSolution(0.0, 1.0)
+ s1 = self.createSolution(0.0, 0.0)
+ s2 = self.createSolution(1.0, 1.0)
+ s3 = self.createSolution(0.0, 1.0)
self.assertEqual(-1, dominance.compare(s1, s2))
self.assertEqual(1, dominance.compare(s2, s1))
@@ -121,9 +115,9 @@ def test_dominance(self):
def test_nondominance(self):
dominance = ParetoDominance()
- s1 = createSolution(0.0, 1.0)
- s2 = createSolution(0.5, 0.5)
- s3 = createSolution(1.0, 0.0)
+ s1 = self.createSolution(0.0, 1.0)
+ s2 = self.createSolution(0.5, 0.5)
+ s3 = self.createSolution(1.0, 0.0)
self.assertEqual(0, dominance.compare(s1, s2))
self.assertEqual(0, dominance.compare(s2, s1))
@@ -132,12 +126,12 @@ def test_nondominance(self):
self.assertEqual(0, dominance.compare(s1, s3))
self.assertEqual(0, dominance.compare(s3, s1))
-class TestArchive(unittest.TestCase):
+class TestArchive(SolutionMixin, unittest.TestCase):
def test_dominance(self):
- s1 = createSolution(1.0, 1.0)
- s2 = createSolution(0.0, 0.0)
- s3 = createSolution(0.0, 1.0)
+ s1 = self.createSolution(1.0, 1.0)
+ s2 = self.createSolution(0.0, 0.0)
+ s3 = self.createSolution(0.0, 1.0)
archive = Archive(ParetoDominance())
archive += [s1, s2, s3]
@@ -146,23 +140,23 @@ def test_dominance(self):
self.assertEqual(s2, archive[0])
def test_nondominance(self):
- s1 = createSolution(0.0, 1.0)
- s2 = createSolution(0.5, 0.5)
- s3 = createSolution(1.0, 0.0)
+ s1 = self.createSolution(0.0, 1.0)
+ s2 = self.createSolution(0.5, 0.5)
+ s3 = self.createSolution(1.0, 0.0)
archive = Archive(ParetoDominance())
archive += [s1, s2, s3]
self.assertEqual(3, len(archive))
-class TestNondominatedSort(unittest.TestCase):
+class TestNondominatedSort(SolutionMixin, unittest.TestCase):
def setUp(self):
- self.s1 = createSolution(0.0, 1.0)
- self.s2 = createSolution(0.5, 0.5)
- self.s3 = createSolution(1.0, 0.0)
- self.s4 = createSolution(0.75, 0.75)
- self.s5 = createSolution(1.0, 1.0)
+ self.s1 = self.createSolution(0.0, 1.0)
+ self.s2 = self.createSolution(0.5, 0.5)
+ self.s3 = self.createSolution(1.0, 0.0)
+ self.s4 = self.createSolution(0.75, 0.75)
+ self.s5 = self.createSolution(1.0, 1.0)
self.population = [self.s1, self.s2, self.s3, self.s4, self.s5]
random.shuffle(self.population)
@@ -280,12 +274,12 @@ def test_truncate_fitness_min(self):
self.assertIn(self.s5, result)
self.assertIn(self.s3, result)
-class TestNormalize(unittest.TestCase):
+class TestNormalize(SolutionMixin, unittest.TestCase):
def test_normalize(self):
- s1 = createSolution(0, 2)
- s2 = createSolution(2, 3)
- s3 = createSolution(1, 1)
+ s1 = self.createSolution(0, 2)
+ s2 = self.createSolution(2, 3)
+ s3 = self.createSolution(1, 1)
solutions = [s1, s2, s3]
normalize(solutions)
@@ -294,19 +288,19 @@ def test_normalize(self):
self.assertEqual([1.0, 1.0], s2.normalized_objectives)
self.assertEqual([0.5, 0.0], s3.normalized_objectives)
-class TestEpsilonBoxArchive(unittest.TestCase):
+class TestEpsilonBoxArchive(SolutionMixin, unittest.TestCase):
def test_improvements(self):
- s1 = createSolution(0.25, 0.25) # Improvement 1 - First solution always counted as improvement
- s2 = createSolution(0.10, 0.10) # Improvement 2 - Dominates prior solution and in new epsilon-box
- s3 = createSolution(0.24, 0.24)
- s4 = createSolution(0.09, 0.50) # Improvement 3 - Non-dominated to all existing solutions
- s5 = createSolution(0.50, 0.50)
- s6 = createSolution(0.05, 0.05) # Improvement 4 - Dominates prior solution and in new epsilon-box
- s7 = createSolution(0.04, 0.04)
- s8 = createSolution(0.02, 0.02)
- s9 = createSolution(0.00, 0.00)
- s10 = createSolution(-0.01, -0.01) # Improvement 5 - Dominates prior solution and in new epsilon-box
+ s1 = self.createSolution(0.25, 0.25) # Improvement 1 - First solution always counted as improvement
+ s2 = self.createSolution(0.10, 0.10) # Improvement 2 - Dominates prior solution and in new epsilon-box
+ s3 = self.createSolution(0.24, 0.24)
+ s4 = self.createSolution(0.09, 0.50) # Improvement 3 - Non-dominated to all existing solutions
+ s5 = self.createSolution(0.50, 0.50)
+ s6 = self.createSolution(0.05, 0.05) # Improvement 4 - Dominates prior solution and in new epsilon-box
+ s7 = self.createSolution(0.04, 0.04)
+ s8 = self.createSolution(0.02, 0.02)
+ s9 = self.createSolution(0.00, 0.00)
+ s10 = self.createSolution(-0.01, -0.01) # Improvement 5 - Dominates prior solution and in new epsilon-box
solutions = [s1, s2, s3, s4, s5, s6, s7, s8, s9, s10]
expectedImprovements = [1, 2, 2, 3, 3, 4, 4, 4, 4, 5]
diff --git a/platypus/tests/test_distance.py b/platypus/tests/test_distance.py
index ee4a924a..33d94863 100644
--- a/platypus/tests/test_distance.py
+++ b/platypus/tests/test_distance.py
@@ -17,8 +17,8 @@
# You should have received a copy of the GNU General Public License
# along with Platypus. If not, see .
import unittest
+from ._utils import SolutionMixin
from ..distance import euclidean_dist, manhattan_dist, DistanceMatrix
-from .test_core import createSolution
class TestDistances(unittest.TestCase):
@@ -33,10 +33,10 @@ def test_manhattan(self):
self.assertAlmostEqual(2.0, manhattan_dist([0, 0], [1, 1]), delta=0.001)
self.assertAlmostEqual(2.0, manhattan_dist([1, 1], [0, 0]), delta=0.001)
-class TestDistanceMatrix(unittest.TestCase):
+class TestDistanceMatrix(SolutionMixin, unittest.TestCase):
def test(self):
- solutions = [createSolution(0, 1), createSolution(0.5, 0.5), createSolution(0.75, 0.25), createSolution(1, 0)]
+ solutions = [self.createSolution(0, 1), self.createSolution(0.5, 0.5), self.createSolution(0.75, 0.25), self.createSolution(1, 0)]
matrix = DistanceMatrix(solutions)
self.assertAlmostEqual(0.353, matrix[1, 2], delta=0.001)
diff --git a/platypus/tests/test_filters.py b/platypus/tests/test_filters.py
index 9972f6e0..3cdc8431 100644
--- a/platypus/tests/test_filters.py
+++ b/platypus/tests/test_filters.py
@@ -18,7 +18,7 @@
# along with Platypus. If not, see .
import unittest
from abc import ABCMeta, abstractmethod
-from .test_core import createSolution
+from ._utils import SolutionMixin
from ..filters import unique, group, truncate, matches, objectives_key, \
objective_value_at_index
@@ -31,22 +31,24 @@ def generator(*args):
def view(*args):
return {x: x for x in args}.keys()
-class TestKeys(unittest.TestCase):
+class TestKeys(SolutionMixin, unittest.TestCase):
def test_objectives(self):
- s = createSolution(0.0, 1.0)
+ s = self.createSolution(0.0, 1.0)
self.assertEqual((0.0, 1.0), objectives_key(s))
def test_objective_value_at_index(self):
- s = createSolution(0.0, 1.0)
+ s = self.createSolution(0.0, 1.0)
self.assertEqual(0.0, objective_value_at_index(0)(s))
self.assertEqual(1.0, objective_value_at_index(1)(s))
-class FilterTestCase(unittest.TestCase, metaclass=ABCMeta):
+class FilterTestCase(SolutionMixin, unittest.TestCase, metaclass=ABCMeta):
- s1 = createSolution(0.0, 1.0)
- s2 = createSolution(1.0, 0.0)
- s3 = createSolution(0.0, 1.0)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.s1 = self.createSolution(0.0, 1.0)
+ self.s2 = self.createSolution(1.0, 0.0)
+ self.s3 = self.createSolution(0.0, 1.0)
@abstractmethod
def filter(self, solutions):
diff --git a/platypus/tests/test_indicators.py b/platypus/tests/test_indicators.py
index b3b77db8..bbc525a1 100644
--- a/platypus/tests/test_indicators.py
+++ b/platypus/tests/test_indicators.py
@@ -18,69 +18,69 @@
# along with Platypus. If not, see .
import math
import unittest
-from .test_core import createSolution
+from ._utils import SolutionMixin
from ..indicators import GenerationalDistance, InvertedGenerationalDistance, \
EpsilonIndicator, Spacing, Hypervolume
from ..core import Solution, Problem, POSITIVE_INFINITY
-class TestGenerationalDistance(unittest.TestCase):
+class TestGenerationalDistance(SolutionMixin, unittest.TestCase):
def test(self):
- reference_set = [createSolution(0, 1), createSolution(1, 0)]
+ reference_set = [self.createSolution(0, 1), self.createSolution(1, 0)]
gd = GenerationalDistance(reference_set)
set = []
self.assertEqual(POSITIVE_INFINITY, gd(set))
- set = [createSolution(0.0, 1.0)]
+ set = [self.createSolution(0.0, 1.0)]
self.assertEqual(0.0, gd(set))
- set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
self.assertEqual(0.0, gd(set))
- set = [createSolution(2.0, 2.0)]
+ set = [self.createSolution(2.0, 2.0)]
self.assertEqual(math.sqrt(5.0), gd(set))
- set = [createSolution(0.5, 0.0), createSolution(0.0, 0.5)]
+ set = [self.createSolution(0.5, 0.0), self.createSolution(0.0, 0.5)]
self.assertEqual(math.sqrt(0.5)/2.0, gd(set))
-class TestInvertedGenerationalDistance(unittest.TestCase):
+class TestInvertedGenerationalDistance(SolutionMixin, unittest.TestCase):
def test(self):
- reference_set = [createSolution(0, 1), createSolution(1, 0)]
+ reference_set = [self.createSolution(0, 1), self.createSolution(1, 0)]
igd = InvertedGenerationalDistance(reference_set)
set = []
self.assertEqual(POSITIVE_INFINITY, igd(set))
- set = [createSolution(0.0, 1.0)]
+ set = [self.createSolution(0.0, 1.0)]
self.assertEqual(math.sqrt(2.0)/2.0, igd(set))
- set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
self.assertEqual(0.0, igd(set))
- set = [createSolution(2.0, 2.0)]
+ set = [self.createSolution(2.0, 2.0)]
self.assertEqual(2.0*math.sqrt(5.0)/2.0, igd(set))
-class TestEpsilonIndicator(unittest.TestCase):
+class TestEpsilonIndicator(SolutionMixin, unittest.TestCase):
def test(self):
- reference_set = [createSolution(0, 1), createSolution(1, 0)]
+ reference_set = [self.createSolution(0, 1), self.createSolution(1, 0)]
ei = EpsilonIndicator(reference_set)
set = []
self.assertEqual(POSITIVE_INFINITY, ei(set))
- set = [createSolution(0.0, 1.0)]
+ set = [self.createSolution(0.0, 1.0)]
self.assertEqual(1.0, ei(set))
- set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
self.assertEqual(0.0, ei(set))
- set = [createSolution(2.0, 2.0)]
+ set = [self.createSolution(2.0, 2.0)]
self.assertEqual(2.0, ei(set))
-class TestSpacing(unittest.TestCase):
+class TestSpacing(SolutionMixin, unittest.TestCase):
def test(self):
sp = Spacing()
@@ -88,41 +88,41 @@ def test(self):
set = []
self.assertEqual(0.0, sp(set))
- set = [createSolution(0.5, 0.5)]
+ set = [self.createSolution(0.5, 0.5)]
self.assertEqual(0.0, sp(set))
- set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
self.assertEqual(0.0, sp(set))
- set = [createSolution(0.0, 1.0), createSolution(0.5, 0.5), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(0.5, 0.5), self.createSolution(1.0, 0.0)]
self.assertEqual(0.0, sp(set))
- set = [createSolution(0.0, 1.0), createSolution(0.25, 0.75), createSolution(1.0, 0.0)]
+ set = [self.createSolution(0.0, 1.0), self.createSolution(0.25, 0.75), self.createSolution(1.0, 0.0)]
self.assertGreater(sp(set), 0.0)
-class TestHypervolume(unittest.TestCase):
+class TestHypervolume(SolutionMixin, unittest.TestCase):
def test(self):
- reference_set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ reference_set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
hyp = Hypervolume(reference_set)
set = []
self.assertEqual(0.0, hyp(set))
- set = [createSolution(0.5, 0.5)]
+ set = [self.createSolution(0.5, 0.5)]
self.assertEqual(0.25, hyp(set))
- set = [createSolution(0.0, 0.0)]
+ set = [self.createSolution(0.0, 0.0)]
self.assertEqual(1.0, hyp(set))
- set = [createSolution(1.0, 1.0)]
+ set = [self.createSolution(1.0, 1.0)]
self.assertEqual(0.0, hyp(set))
- set = [createSolution(0.5, 0.0), createSolution(0.0, 0.5)]
+ set = [self.createSolution(0.5, 0.0), self.createSolution(0.0, 0.5)]
self.assertEqual(0.75, hyp(set))
def test_maximize(self):
- reference_set = [createSolution(0.0, 1.0), createSolution(1.0, 0.0)]
+ reference_set = [self.createSolution(0.0, 1.0), self.createSolution(1.0, 0.0)]
hyp = Hypervolume(reference_set)
problem = Problem(0, 2)
diff --git a/platypus/tests/test_io.py b/platypus/tests/test_io.py
index 6fb07f36..73f3c11d 100644
--- a/platypus/tests/test_io.py
+++ b/platypus/tests/test_io.py
@@ -17,20 +17,19 @@
# You should have received a copy of the GNU General Public License
# along with Platypus. If not, see .
-import math
import tempfile
import unittest
-from .test_core import createSolution
+from ._utils import SolutionMixin
from ..algorithms import NSGAII
-from ..core import FixedLengthArray
from ..problems import DTLZ2
-from ..io import save_objectives, load_objectives, save_json, load_json
+from ..io import save_objectives, load_objectives, save_json, load_json, \
+ save_state, load_state
-class TestObjectives(unittest.TestCase):
+class TestObjectives(SolutionMixin, unittest.TestCase):
def test(self):
- s1 = createSolution(0.0, 1.0)
- s2 = createSolution(1.0, 0.0)
+ s1 = self.createSolution(0.0, 1.0)
+ s2 = self.createSolution(1.0, 0.0)
expected = [s1, s2]
with tempfile.NamedTemporaryFile() as f:
@@ -42,18 +41,11 @@ def test(self):
for i in range(len(expected)):
self.assertEqual(expected[i].objectives, actual[i].objectives)
-class TestJSON(unittest.TestCase):
+class TestJSON(SolutionMixin, unittest.TestCase):
- def assertSimilar(self, a, b, epsilon=0.0000001):
- if isinstance(a, (list, FixedLengthArray)) and isinstance(b, (list, FixedLengthArray)):
- for (x, y) in zip(a, b):
- self.assertSimilar(x, y, epsilon)
- else:
- self.assertLessEqual(math.fabs(b - a), epsilon)
-
- def testSolutions(self):
- s1 = createSolution(0.0, 1.0)
- s2 = createSolution(1.0, 0.0)
+ def test_solutions(self):
+ s1 = self.createSolution(0.0, 1.0)
+ s2 = self.createSolution(1.0, 0.0)
expected = [s1, s2]
with tempfile.NamedTemporaryFile() as f:
@@ -64,11 +56,9 @@ def testSolutions(self):
for i in range(len(expected)):
self.assertIsNotNone(actual[i].problem)
- self.assertSimilar(expected[i].variables, actual[i].variables)
- self.assertSimilar(expected[i].objectives, actual[i].objectives)
- self.assertSimilar(expected[i].constraints, actual[i].constraints)
+ self.assertSimilar(expected[i], actual[i])
- def testAlgorithm(self):
+ def test_algorithm(self):
problem = DTLZ2()
algorithm = NSGAII(problem)
algorithm.run(10000)
@@ -86,3 +76,33 @@ def testAlgorithm(self):
self.assertSimilar(expected[i].variables, actual[i].variables)
self.assertSimilar(expected[i].objectives, actual[i].objectives)
self.assertSimilar(expected[i].constraints, actual[i].constraints)
+
+class TestState(SolutionMixin, unittest.TestCase):
+
+ def run_test(self, json):
+ problem = DTLZ2()
+ original = NSGAII(problem)
+
+ with tempfile.NamedTemporaryFile() as f:
+ save_state(f.name, original, json=json)
+
+ original.run(10000)
+
+ copy = load_state(f.name, json=json)
+ copy.run(10000)
+
+ self.assertEqual(original.nfe, copy.nfe)
+
+ expected = original.result
+ actual = copy.result
+
+ self.assertEqual(len(expected), len(actual))
+
+ for i in range(len(expected)):
+ self.assertSimilar(expected[i], actual[i])
+
+ def test_binary(self):
+ self.run_test(False)
+
+ def test_json(self):
+ self.run_test(True)
diff --git a/pyproject.toml b/pyproject.toml
index fe39f573..52b81e32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dynamic = ["version"] # Version is read from platypus/__init__.py
"Bug Tracker" = "https://github.com/Project-Platypus/Platypus/issues"
[project.optional-dependencies]
-test = ["pytest", "mock", "flake8", "flake8-pyproject", "numpy", "matplotlib"]
+test = ["pytest", "mock", "flake8", "flake8-pyproject", "numpy", "matplotlib", "jsonpickle"]
docs = ["sphinx", "sphinx-rtd-theme"]
full = ["mpi4py", "Platypus-Opt[test]", "Platypus-Opt[docs]"]