Skip to content

Commit

Permalink
Merge pull request #11458 from KratosMultiphysics/core/mpi_data_commu…
Browse files Browse the repository at this point in the history
…nicator_pybind_hotfix

[Core] Fix pycaster casting data to wrong types in data_communicator
  • Loading branch information
sunethwarna authored Aug 4, 2023
2 parents 8defbb5 + 8860a7f commit 2a7e8f2
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 76 deletions.
64 changes: 63 additions & 1 deletion kratos/mpi/tests/test_mpi_data_communicator_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class TestMPIDataCommunicatorPython(KratosUnittest.TestCase):

def setUp(self):
self.world = Kratos.Testing.GetDefaultDataCommunicator()
self.world: Kratos.DataCommunicator = Kratos.Testing.GetDefaultDataCommunicator()
self.rank = self.world.Rank()
self.size = self.world.Size()

Expand Down Expand Up @@ -261,6 +261,68 @@ def testAllGathervOperations(self):
self.assertEqual(gathered_ints, [ [i, i+1, i+2] for i in range(self.size)])
self.assertEqual(gathered_doubles, [ [1.0 + i, 2.0 + i , 3.0 + i] for i in range(self.size)])

def test_CastingTypes(self):
n = self.world.Size()

def check_value_in_rank(ref_value, value, ranks_to_check: 'list[int]'):
self.assertTrue(isinstance(value, type(ref_value)))
if self.rank in ranks_to_check:
if isinstance(ref_value, int) or isinstance(ref_value, float):
self.assertEqual(ref_value, value)
elif isinstance(ref_value, Kratos.Array3) or isinstance(ref_value, Kratos.Array4) or isinstance(ref_value, Kratos.Array6) or isinstance(ref_value, Kratos.Array9) or isinstance(ref_value, Kratos.Vector):
self.assertVectorAlmostEqual(ref_value, value)
elif isinstance(ref_value, Kratos.Matrix):
self.assertMatrixAlmostEqual(ref_value, value)
elif isinstance(ref_value, list):
check_value_in_rank(ref_value[0], value[0], ranks_to_check)

def simple_reduce_check(method, rank: int, ref_value: float, ranks_to_check: 'list[int]'):
check_value_in_rank(method(self.rank+1, rank), int(ref_value), ranks_to_check)
check_value_in_rank(method(float(self.rank+1), rank), float(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array3(self.rank+1), rank), Kratos.Array3(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array4(self.rank+1), rank), Kratos.Array4(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array6(self.rank+1), rank), Kratos.Array6(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array9(self.rank+1), rank), Kratos.Array9(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Vector(2, self.rank+1), rank), Kratos.Vector(2, ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Matrix(2, 2, self.rank+1), rank), Kratos.Matrix(2, 2, ref_value), ranks_to_check)

check_value_in_rank(getattr(self.world, method.__name__ + "Ints")([self.rank+1], rank), [int(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Doubles")([float(self.rank+1)], rank), [float(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array3s")([Kratos.Array3(self.rank+1)], rank), [Kratos.Array3(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array4s")([Kratos.Array4(self.rank+1)], rank), [Kratos.Array4(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array6s")([Kratos.Array6(self.rank+1)], rank), [Kratos.Array6(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array9s")([Kratos.Array9(self.rank+1)], rank), [Kratos.Array9(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Vectors")([Kratos.Vector(2, self.rank+1)], rank), [Kratos.Vector(2, ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Matrices")([Kratos.Matrix(2, 2, self.rank+1)], rank), [Kratos.Matrix(2, 2, ref_value)], ranks_to_check)

def simple_all_reduce_check(method, ref_value: float):
ranks_to_check = [i for i in range(n)]
check_value_in_rank(method(self.rank+1), int(ref_value), ranks_to_check)
check_value_in_rank(method(float(self.rank+1)), float(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array3(self.rank+1)), Kratos.Array3(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array3(self.rank+1)), Kratos.Array3(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array3(self.rank+1)), Kratos.Array3(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Array3(self.rank+1)), Kratos.Array3(ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Vector(2, self.rank+1)), Kratos.Vector(2, ref_value), ranks_to_check)
check_value_in_rank(method(Kratos.Matrix(2, 2, self.rank+1)), Kratos.Matrix(2, 2, ref_value), ranks_to_check)

check_value_in_rank(getattr(self.world, method.__name__ + "Ints")([self.rank+1]), [int(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Doubles")([float(self.rank+1)]), [float(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array3s")([Kratos.Array3(self.rank+1)]), [Kratos.Array3(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array4s")([Kratos.Array4(self.rank+1)]), [Kratos.Array4(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array6s")([Kratos.Array6(self.rank+1)]), [Kratos.Array6(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Array9s")([Kratos.Array9(self.rank+1)]), [Kratos.Array9(ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Vectors")([Kratos.Vector(2, self.rank+1)]), [Kratos.Vector(2, ref_value)], ranks_to_check)
check_value_in_rank(getattr(self.world, method.__name__ + "Matrices")([Kratos.Matrix(2, 2, self.rank+1)]), [Kratos.Matrix(2, 2, ref_value)], ranks_to_check)

simple_reduce_check(self.world.Sum, 0, n*(n+1)/2, [0])
simple_reduce_check(self.world.Min, 0, 1, [0])
simple_reduce_check(self.world.Max, 0, n, [0])
simple_reduce_check(self.world.Broadcast, 0, 1, [i for i in range(n)])
simple_all_reduce_check(self.world.SumAll, n*(n+1)/2)
simple_all_reduce_check(self.world.MinAll, 1)
simple_all_reduce_check(self.world.MaxAll, n)

if __name__ == "__main__":
Kratos.Logger.GetDefaultOutput().SetSeverity(Kratos.Logger.Severity.WARNING)
KratosUnittest.main()
Loading

0 comments on commit 2a7e8f2

Please sign in to comment.