diff --git a/src/Connection.php b/src/Connection.php index baf620697f..4410aa1b07 100644 --- a/src/Connection.php +++ b/src/Connection.php @@ -9,6 +9,7 @@ use Doctrine\DBAL\Cache\QueryCacheProfile; use Doctrine\DBAL\Driver\API\ExceptionConverter; use Doctrine\DBAL\Driver\Connection as DriverConnection; +use Doctrine\DBAL\Driver\Exception as TheDriverException; use Doctrine\DBAL\Driver\ServerInfoAwareConnection; use Doctrine\DBAL\Driver\Statement as DriverStatement; use Doctrine\DBAL\Event\TransactionBeginEventArgs; @@ -1279,12 +1280,29 @@ public function transactional(Closure $func) { $this->beginTransaction(); + $successful = false; + try { $res = $func($this); + $successful = true; + } finally { + if (! $successful) { + $this->rollBack(); + } + } + + $shouldRollback = true; + try { $this->commit(); + + $shouldRollback = false; + } catch (TheDriverException $t) { + $shouldRollback = false; + + throw $t; } finally { - if ($this->isTransactionActive()) { + if ($shouldRollback) { $this->rollBack(); } } diff --git a/tests/Functional/TransactionTest.php b/tests/Functional/TransactionTest.php index 1421bd10a8..9e4fbfcd8e 100644 --- a/tests/Functional/TransactionTest.php +++ b/tests/Functional/TransactionTest.php @@ -2,6 +2,7 @@ namespace Doctrine\DBAL\Tests\Functional; +use Doctrine\DBAL\Connection; use Doctrine\DBAL\Driver\Exception as DriverException; use Doctrine\DBAL\Platforms\AbstractMySQLPlatform; use Doctrine\DBAL\Tests\FunctionalTestCase; @@ -11,17 +12,12 @@ class TransactionTest extends FunctionalTestCase { - protected function setUp(): void + public function testCommitFalse(): void { - if ($this->connection->getDatabasePlatform() instanceof AbstractMySQLPlatform) { - return; + if (! $this->connection->getDatabasePlatform() instanceof AbstractMySQLPlatform) { + $this->markTestSkipped('Restricted to MySQL.'); } - $this->markTestSkipped('Restricted to MySQL.'); - } - - public function testCommitFalse(): void - { $this->connection->executeStatement('SET SESSION wait_timeout=1'); self::assertTrue($this->connection->beginTransaction()); @@ -40,4 +36,15 @@ public function testCommitFalse(): void $this->connection->close(); } } + + public function testNestedTransactionWalkthrough(): void + { + $result = $this->connection->transactional( + static fn (Connection $connection) => $connection->transactional( + static fn (Connection $connection) => $connection->fetchOne('SELECT 1'), + ), + ); + + self::assertSame('1', (string) $result); + } }