Replaced ArrayIterator with StatementIterator in Portability\Connection

Fixes https://github.com/doctrine/dbal/issues/3114
parent 11037b43
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
namespace Doctrine\DBAL\Portability; namespace Doctrine\DBAL\Portability;
use Doctrine\DBAL\Driver\StatementIterator;
use Doctrine\DBAL\FetchMode; use Doctrine\DBAL\FetchMode;
use Doctrine\DBAL\ParameterType; use Doctrine\DBAL\ParameterType;
use function array_change_key_case; use function array_change_key_case;
...@@ -139,9 +140,7 @@ class Statement implements \IteratorAggregate, \Doctrine\DBAL\Driver\Statement ...@@ -139,9 +140,7 @@ class Statement implements \IteratorAggregate, \Doctrine\DBAL\Driver\Statement
*/ */
public function getIterator() public function getIterator()
{ {
$data = $this->fetchAll(); return new StatementIterator($this);
return new \ArrayIterator($data);
} }
/** /**
......
...@@ -2,39 +2,101 @@ ...@@ -2,39 +2,101 @@
namespace Doctrine\Tests\DBAL\Driver; namespace Doctrine\Tests\DBAL\Driver;
use Doctrine\DBAL\Driver\IBMDB2\DB2Statement;
use Doctrine\DBAL\Driver\Mysqli\MysqliStatement;
use Doctrine\DBAL\Driver\OCI8\OCI8Statement;
use Doctrine\DBAL\Driver\SQLAnywhere\SQLAnywhereStatement;
use Doctrine\DBAL\Driver\SQLSrv\SQLSrvStatement;
use Doctrine\DBAL\Driver\Statement; use Doctrine\DBAL\Driver\Statement;
use Doctrine\DBAL\Driver\StatementIterator; use Doctrine\DBAL\Driver\StatementIterator;
use Doctrine\DBAL\Portability\Statement as PortabilityStatement;
use IteratorAggregate;
use PHPUnit\Framework\MockObject\MockObject;
use Traversable;
use function extension_loaded;
class StatementIteratorTest extends \Doctrine\Tests\DbalTestCase class StatementIteratorTest extends \Doctrine\Tests\DbalTestCase
{ {
public function testGettingIteratorDoesNotCallFetch() /**
* @dataProvider statementProvider()
*/
public function testGettingIteratorDoesNotCallFetch(string $class) : void
{ {
$stmt = $this->createMock(Statement::class); /** @var IteratorAggregate|MockObject $stmt */
$stmt = $this->createPartialMock($class, ['fetch', 'fetchAll', 'fetchColumn']);
$stmt->expects($this->never())->method('fetch'); $stmt->expects($this->never())->method('fetch');
$stmt->expects($this->never())->method('fetchAll'); $stmt->expects($this->never())->method('fetchAll');
$stmt->expects($this->never())->method('fetchColumn'); $stmt->expects($this->never())->method('fetchColumn');
$stmt->getIterator();
}
public function testIteratorIterationCallsFetchOncePerStep() : void
{
$stmt = $this->createMock(Statement::class);
$calls = 0;
$this->configureStatement($stmt, $calls);
$stmtIterator = new StatementIterator($stmt); $stmtIterator = new StatementIterator($stmt);
$stmtIterator->getIterator();
$this->assertIterationCallsFetchOncePerStep($stmtIterator, $calls);
} }
public function testIterationCallsFetchOncePerStep() /**
* @dataProvider statementProvider()
*/
public function testStatementIterationCallsFetchOncePerStep(string $class) : void
{
$stmt = $this->createPartialMock($class, ['fetch']);
$calls = 0;
$this->configureStatement($stmt, $calls);
$this->assertIterationCallsFetchOncePerStep($stmt, $calls);
}
private function configureStatement(MockObject $stmt, int &$calls) : void
{ {
$values = ['foo', '', 'bar', '0', 'baz', 0, 'qux', null, 'quz', false, 'impossible']; $values = ['foo', '', 'bar', '0', 'baz', 0, 'qux', null, 'quz', false, 'impossible'];
$calls = 0; $calls = 0;
$stmt = $this->createMock(Statement::class);
$stmt->expects($this->exactly(10)) $stmt->expects($this->exactly(10))
->method('fetch') ->method('fetch')
->willReturnCallback(function() use ($values, &$calls) { ->willReturnCallback(function() use ($values, &$calls) {
$value = $values[$calls]; $value = $values[$calls];
$calls++; $calls++;
return $value; return $value;
}); });
}
$stmtIterator = new StatementIterator($stmt); private function assertIterationCallsFetchOncePerStep(Traversable $iterator, int &$calls) : void
foreach ($stmtIterator as $i => $_) { {
foreach ($iterator as $i => $_) {
$this->assertEquals($i + 1, $calls); $this->assertEquals($i + 1, $calls);
} }
} }
/**
* @return string[][]
*/
public static function statementProvider() : iterable
{
if (extension_loaded('ibm_db2')) {
yield [DB2Statement::class];
}
yield [MysqliStatement::class];
if (extension_loaded('oci8')) {
yield [OCI8Statement::class];
}
yield [PortabilityStatement::class];
yield [SQLAnywhereStatement::class];
if (extension_loaded('sqlsrv')) {
yield [SQLSrvStatement::class];
}
}
} }
...@@ -6,6 +6,7 @@ use Doctrine\DBAL\FetchMode; ...@@ -6,6 +6,7 @@ use Doctrine\DBAL\FetchMode;
use Doctrine\DBAL\ParameterType; use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Portability\Connection; use Doctrine\DBAL\Portability\Connection;
use Doctrine\DBAL\Portability\Statement; use Doctrine\DBAL\Portability\Statement;
use function iterator_to_array;
class StatementTest extends \Doctrine\Tests\DbalTestCase class StatementTest extends \Doctrine\Tests\DbalTestCase
{ {
...@@ -141,16 +142,11 @@ class StatementTest extends \Doctrine\Tests\DbalTestCase ...@@ -141,16 +142,11 @@ class StatementTest extends \Doctrine\Tests\DbalTestCase
public function testGetIterator() public function testGetIterator()
{ {
$data = array( $this->wrappedStmt->expects($this->exactly(3))
'foo' => 'bar', ->method('fetch')
'bar' => 'foo' ->willReturnOnConsecutiveCalls('foo', 'bar', false);
);
$this->wrappedStmt->expects($this->once())
->method('fetchAll')
->will($this->returnValue($data));
self::assertEquals(new \ArrayIterator($data), $this->stmt->getIterator()); self::assertSame(['foo', 'bar'], iterator_to_array($this->stmt->getIterator()));
} }
public function testRowCount() public function testRowCount()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment