From b8a422f6cef0e15026745b44a0906e586aaa11b6 Mon Sep 17 00:00:00 2001 From: Rafal Piotrowski Date: Mon, 19 Apr 2021 16:33:32 +0200 Subject: [PATCH 1/2] Make thread local field transient in `ThreadLocalRandomGenerator` Make it also lazy to not create ThreadLocal until it is used. Fixes #806 --- .../ThreadLocalRandomGenerator.scala | 2 +- .../ThreadLocalRandomGeneratorTest.scala | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala diff --git a/math/src/main/scala/breeze/stats/distributions/ThreadLocalRandomGenerator.scala b/math/src/main/scala/breeze/stats/distributions/ThreadLocalRandomGenerator.scala index bc76fa344..fc293e7c7 100644 --- a/math/src/main/scala/breeze/stats/distributions/ThreadLocalRandomGenerator.scala +++ b/math/src/main/scala/breeze/stats/distributions/ThreadLocalRandomGenerator.scala @@ -9,7 +9,7 @@ import org.apache.commons.math3.random.RandomGenerator **/ @SerialVersionUID(1L) class ThreadLocalRandomGenerator(genThunk: => RandomGenerator) extends RandomGenerator with Serializable { - private val genTL = new ThreadLocal[RandomGenerator] { + @transient private lazy val genTL = new ThreadLocal[RandomGenerator] { override def initialValue(): RandomGenerator = genThunk } def nextBytes(bytes: Array[Byte]) = genTL.get().nextBytes(bytes) diff --git a/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala b/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala new file mode 100644 index 000000000..0a18623e2 --- /dev/null +++ b/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala @@ -0,0 +1,54 @@ +package breeze.stats.distributions + +import com.sun.xml.internal.ws.encoding.soap.SerializationException +import org.apache.commons.math3.random.MersenneTwister +import org.scalatest.{FunSuite, Matchers} + +import java.io._ + +class ThreadLocalRandomGeneratorTest extends FunSuite with Matchers { + test("ThreadLocalRandomGeneratorTest should be serializable") { + val generator = new ThreadLocalRandomGenerator(new MersenneTwister()) + serialize(generator) + } + + test("ThreadLocalRandomGeneratorTest should be serializable after usage") { + val generator = new ThreadLocalRandomGenerator(new MersenneTwister()) + generator.nextInt() + serialize(generator) + } + + test("ThreadLocalRandomGeneratorTest should be deserializable") { + val generator = new ThreadLocalRandomGenerator(new MersenneTwister()) + val i1 = generator.nextInt() + val bytes = serialize(generator) + val deserialized = deserialize(bytes) + val i2 = deserialized.nextInt() + + i1 should not be i2 + } + + private def serialize(generator: ThreadLocalRandomGenerator): Array[Byte] = { + val outputStream = new ByteArrayOutputStream(512) + val out = new ObjectOutputStream(outputStream) + try { + out.writeObject(generator) + outputStream.toByteArray + } catch { + case ex: IOException => throw new SerializationException(ex) + } finally { + if (out != null) out.close() + } + } + + private def deserialize(bytes: Array[Byte]): ThreadLocalRandomGenerator = { + val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) + try { + in.readObject().asInstanceOf[ThreadLocalRandomGenerator] + } catch { + case ex: IOException => throw new SerializationException(ex) + } finally { + if (in != null) in.close() + } + } +} From bf0ef50ab9566817b539c32ee2b03e65ccbebdd9 Mon Sep 17 00:00:00 2001 From: Rafal Piotrowski Date: Mon, 19 Apr 2021 21:05:33 +0200 Subject: [PATCH 2/2] Fix imports --- .../stats/distributions/ThreadLocalRandomGeneratorTest.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala b/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala index 0a18623e2..f8595e4c3 100644 --- a/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala +++ b/math/src/test/scala/breeze/stats/distributions/ThreadLocalRandomGeneratorTest.scala @@ -1,6 +1,5 @@ package breeze.stats.distributions -import com.sun.xml.internal.ws.encoding.soap.SerializationException import org.apache.commons.math3.random.MersenneTwister import org.scalatest.{FunSuite, Matchers} @@ -35,7 +34,7 @@ class ThreadLocalRandomGeneratorTest extends FunSuite with Matchers { out.writeObject(generator) outputStream.toByteArray } catch { - case ex: IOException => throw new SerializationException(ex) + case _: IOException => fail("cannot serialize") } finally { if (out != null) out.close() } @@ -46,7 +45,7 @@ class ThreadLocalRandomGeneratorTest extends FunSuite with Matchers { try { in.readObject().asInstanceOf[ThreadLocalRandomGenerator] } catch { - case ex: IOException => throw new SerializationException(ex) + case _: IOException => fail("cannot deserialize") } finally { if (in != null) in.close() }