diff --git a/src/Ydb.Sdk/src/Ado/YdbDataReader.cs b/src/Ydb.Sdk/src/Ado/YdbDataReader.cs index 45c872a8..3fe6b2ea 100644 --- a/src/Ydb.Sdk/src/Ado/YdbDataReader.cs +++ b/src/Ydb.Sdk/src/Ado/YdbDataReader.cs @@ -1,4 +1,3 @@ -using System.Collections; using System.Data.Common; using Google.Protobuf.Collections; using Ydb.Issue; @@ -7,7 +6,7 @@ namespace Ydb.Sdk.Ado; -public sealed class YdbDataReader : DbDataReader +public sealed class YdbDataReader : DbDataReader, IAsyncEnumerable { private readonly IAsyncEnumerator _stream; private readonly YdbTransaction? _ydbTransaction; @@ -60,19 +59,72 @@ public sbyte GetSByte(int ordinal) return GetFieldYdbValue(ordinal).GetInt8(); } + // ReSharper disable once MemberCanBePrivate.Global + public byte[] GetBytes(int ordinal) + { + return GetFieldYdbValue(ordinal).GetString(); + } + public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) { - throw new NotImplementedException(); + var bytes = GetBytes(ordinal); + + CheckOffsets(dataOffset, buffer, bufferOffset, length); + + if (buffer == null) + { + return 0; + } + + var copyCount = Math.Min(bytes.Length - dataOffset, length); + Array.Copy(bytes, (int)dataOffset, buffer, bufferOffset, copyCount); + + return copyCount; } public override char GetChar(int ordinal) { - throw new NotImplementedException(); + return GetString(ordinal)[0]; } public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) { - throw new NotImplementedException(); + var chars = GetString(ordinal).ToCharArray(); + + CheckOffsets(dataOffset, buffer, bufferOffset, length); + + if (buffer == null) + { + return 0; + } + + var copyCount = Math.Min(chars.Length - dataOffset, length); + Array.Copy(chars, (int)dataOffset, buffer, bufferOffset, copyCount); + + return copyCount; + } + + private static void CheckOffsets(long dataOffset, T[]? buffer, int bufferOffset, int length) + { + if (dataOffset is < 0 or > int.MaxValue) + { + throw new IndexOutOfRangeException($"dataOffset must be between 0 and {int.MaxValue}"); + } + + if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length)) + { + throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + } + + if (buffer != null && length < 0) + { + throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length}"); + } + + if (buffer != null && length > buffer.Length - bufferOffset) + { + throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length - length}"); + } } public override string GetDataTypeName(int ordinal) @@ -124,7 +176,7 @@ public override float GetFloat(int ordinal) public override Guid GetGuid(int ordinal) { - throw new NotImplementedException(); + throw new YdbException("Ydb does not supported Guid"); } public override short GetInt16(int ordinal) @@ -267,9 +319,12 @@ public override async Task ReadAsync(CancellationToken cancellationToken) public override int Depth => 0; - public override IEnumerator GetEnumerator() + public override IEnumerator GetEnumerator() { - throw new NotImplementedException(); + while (Read()) + { + yield return new YdbDataRecord(this); + } } public override async Task CloseAsync() @@ -360,9 +415,8 @@ private async Task NextExecPart() if (part.Status != StatusIds.Types.StatusCode.Success) { - CompleteTransaction(); + OnFailReadStream(); - ReaderState = State.Closed; while (await _stream.MoveNextAsync()) { _issueMessagesInStream.AddRange(_stream.Current.Issues); @@ -389,14 +443,15 @@ private async Task NextExecPart() } catch (Driver.TransportException e) { - CompleteTransaction(); + OnFailReadStream(); throw new YdbException(e.Status); } } - private void CompleteTransaction() + private void OnFailReadStream() { + ReaderState = State.Closed; if (_ydbTransaction != null) { _ydbTransaction.Completed = true; @@ -407,4 +462,12 @@ public override async ValueTask DisposeAsync() { await CloseAsync(); } + + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = new()) + { + while (await ReadAsync(cancellationToken)) + { + yield return new YdbDataRecord(this); + } + } } diff --git a/src/Ydb.Sdk/src/Ado/YdbDataRecord.cs b/src/Ydb.Sdk/src/Ado/YdbDataRecord.cs new file mode 100644 index 00000000..acc9c5ed --- /dev/null +++ b/src/Ydb.Sdk/src/Ado/YdbDataRecord.cs @@ -0,0 +1,164 @@ +using System.Data.Common; + +namespace Ydb.Sdk.Ado; + +public class YdbDataRecord : DbDataRecord +{ + private readonly YdbDataReader _ydbDataReader; + + internal YdbDataRecord(YdbDataReader ydbDataReader) + { + _ydbDataReader = ydbDataReader; + } + + public override bool GetBoolean(int i) + { + return _ydbDataReader.GetBoolean(i); + } + + public override byte GetByte(int i) + { + return _ydbDataReader.GetByte(i); + } + + public override long GetBytes(int i, long dataIndex, byte[]? buffer, int bufferIndex, int length) + { + return _ydbDataReader.GetBytes(i, dataIndex, buffer, bufferIndex, length); + } + + public override char GetChar(int i) + { + return _ydbDataReader.GetChar(i); + } + + public override long GetChars(int i, long dataIndex, char[]? buffer, int bufferIndex, int length) + { + return _ydbDataReader.GetChars(i, dataIndex, buffer, bufferIndex, length); + } + + public override string GetDataTypeName(int i) + { + return _ydbDataReader.GetDataTypeName(i); + } + + public override DateTime GetDateTime(int i) + { + return _ydbDataReader.GetDateTime(i); + } + + public override decimal GetDecimal(int i) + { + return _ydbDataReader.GetDecimal(i); + } + + public override double GetDouble(int i) + { + return _ydbDataReader.GetDouble(i); + } + + public override System.Type GetFieldType(int i) + { + return _ydbDataReader.GetFieldType(i); + } + + public override float GetFloat(int i) + { + return _ydbDataReader.GetFloat(i); + } + + public override Guid GetGuid(int i) + { + return _ydbDataReader.GetGuid(i); + } + + public override short GetInt16(int i) + { + return _ydbDataReader.GetInt16(i); + } + + public override int GetInt32(int i) + { + return _ydbDataReader.GetInt32(i); + } + + public override long GetInt64(int i) + { + return _ydbDataReader.GetInt64(i); + } + + public override string GetName(int i) + { + return _ydbDataReader.GetName(i); + } + + public override int GetOrdinal(string name) + { + return _ydbDataReader.GetOrdinal(name); + } + + public override string GetString(int i) + { + return _ydbDataReader.GetString(i); + } + + public override object GetValue(int i) + { + return _ydbDataReader.GetValue(i); + } + + public override int GetValues(object[] values) + { + return _ydbDataReader.GetValues(values); + } + + public override bool IsDBNull(int i) + { + return _ydbDataReader.IsDBNull(i); + } + + public override int FieldCount => _ydbDataReader.FieldCount; + + public override object this[int i] => _ydbDataReader[i]; + + public override object this[string name] => _ydbDataReader[name]; + + public byte[] GetBytes(int i) + { + return _ydbDataReader.GetBytes(i); + } + + public sbyte GetSByte(int i) + { + return _ydbDataReader.GetSByte(i); + } + + public ulong GetUint16(int i) + { + return _ydbDataReader.GetUint16(i); + } + + public ulong GetUint32(int i) + { + return _ydbDataReader.GetUint32(i); + } + + public ulong GetUint64(int i) + { + return _ydbDataReader.GetUint64(i); + } + + public string GetJson(int i) + { + return _ydbDataReader.GetJson(i); + } + + public string GetJsonDocument(int i) + { + return _ydbDataReader.GetJsonDocument(i); + } + + public TimeSpan GetInterval(int i) + { + return _ydbDataReader.GetInterval(i); + } +} diff --git a/src/Ydb.Sdk/tests/Ado/YdbCommandTests.cs b/src/Ydb.Sdk/tests/Ado/YdbCommandTests.cs index 8f29ae17..eca60a88 100644 --- a/src/Ydb.Sdk/tests/Ado/YdbCommandTests.cs +++ b/src/Ydb.Sdk/tests/Ado/YdbCommandTests.cs @@ -1,4 +1,5 @@ using System.Data; +using System.Text; using Xunit; using Ydb.Sdk.Ado; @@ -168,4 +169,142 @@ public void ExecuteDbDataReader_WhenPreviousIsNotClosed_ThrowException() ydbDataReader.Close(); Assert.True(ydbDataReader.IsClosed); } + + [Fact] + public void GetChars_WhenSelectText_MoveCharsToBuffer() + { + using var connection = new YdbConnection(); + connection.Open(); + var ydbDataReader = + new YdbCommand(connection) { CommandText = "SELECT CAST('abacaba' AS Text)" }.ExecuteReader(); + Assert.True(ydbDataReader.Read()); + var bufferChars = new char[10]; + var checkBuffer = new char[10]; + + Assert.Equal(0, ydbDataReader.GetChars(0, 4, null, 0, 6)); + Assert.Equal($"dataOffset must be between 0 and {int.MaxValue}", + Assert.Throws(() => ydbDataReader.GetChars(0, -1, null, 0, 6)).Message); + Assert.Equal($"dataOffset must be between 0 and {int.MaxValue}", + Assert.Throws( + () => ydbDataReader.GetChars(0, long.MaxValue, null, 0, 6)).Message); + + Assert.Equal("bufferOffset must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetChars(0, 0, bufferChars, -1, 6)).Message); + Assert.Equal("bufferOffset must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetChars(0, 0, bufferChars, -1, 6)).Message); + + Assert.Equal("length must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetChars(0, 0, bufferChars, 3, -1)).Message); + Assert.Equal("bufferOffset must be between 0 and 5", Assert.Throws( + () => ydbDataReader.GetChars(0, 0, bufferChars, 8, 5)).Message); + + Assert.Equal(6, ydbDataReader.GetChars(0, 0, bufferChars, 4, 6)); + checkBuffer[4] = 'a'; + checkBuffer[5] = 'b'; + checkBuffer[6] = 'a'; + checkBuffer[7] = 'c'; + checkBuffer[8] = 'a'; + checkBuffer[9] = 'b'; + Assert.Equal(checkBuffer, bufferChars); + bufferChars = new char[10]; + checkBuffer = new char[10]; + + Assert.Equal(4, ydbDataReader.GetChars(0, 3, bufferChars, 4, 6)); + checkBuffer[4] = 'c'; + checkBuffer[5] = 'a'; + checkBuffer[6] = 'b'; + checkBuffer[7] = 'a'; + Assert.Equal(checkBuffer, bufferChars); + + Assert.Equal('a', ydbDataReader.GetChar(0)); + } + + [Fact] + public void GetBytes_WhenSelectBytes_MoveBytesToBuffer() + { + using var connection = new YdbConnection(); + connection.Open(); + var ydbDataReader = new YdbCommand(connection) { CommandText = "SELECT 'abacaba'" }.ExecuteReader(); + Assert.True(ydbDataReader.Read()); + var bufferChars = new byte[10]; + var checkBuffer = new byte[10]; + + Assert.Equal(0, ydbDataReader.GetBytes(0, 4, null, 0, 6)); + Assert.Equal($"dataOffset must be between 0 and {int.MaxValue}", + Assert.Throws(() => ydbDataReader.GetBytes(0, -1, null, 0, 6)).Message); + Assert.Equal($"dataOffset must be between 0 and {int.MaxValue}", + Assert.Throws( + () => ydbDataReader.GetBytes(0, long.MaxValue, null, 0, 6)).Message); + + Assert.Equal("bufferOffset must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetBytes(0, 0, bufferChars, -1, 6)).Message); + Assert.Equal("bufferOffset must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetBytes(0, 0, bufferChars, -1, 6)).Message); + + Assert.Equal("length must be between 0 and 10", Assert.Throws( + () => ydbDataReader.GetBytes(0, 0, bufferChars, 3, -1)).Message); + Assert.Equal("bufferOffset must be between 0 and 5", Assert.Throws( + () => ydbDataReader.GetBytes(0, 0, bufferChars, 8, 5)).Message); + + Assert.Equal(6, ydbDataReader.GetBytes(0, 0, bufferChars, 4, 6)); + checkBuffer[4] = (byte)'a'; + checkBuffer[5] = (byte)'b'; + checkBuffer[6] = (byte)'a'; + checkBuffer[7] = (byte)'c'; + checkBuffer[8] = (byte)'a'; + checkBuffer[9] = (byte)'b'; + Assert.Equal(checkBuffer, bufferChars); + bufferChars = new byte[10]; + checkBuffer = new byte[10]; + + Assert.Equal(4, ydbDataReader.GetBytes(0, 3, bufferChars, 4, 5)); + checkBuffer[4] = (byte)'c'; + checkBuffer[5] = (byte)'a'; + checkBuffer[6] = (byte)'b'; + checkBuffer[7] = (byte)'a'; + Assert.Equal(checkBuffer, bufferChars); + } + + [Fact] + public async Task GetEnumerator_WhenReadMultiSelect_ReadFirstResultSet() + { + await using var ydbConnection = new YdbConnection(); + ydbConnection.Open(); + var ydbCommand = new YdbCommand(ydbConnection) + { + CommandText = @" +$new_data = AsList( + AsStruct(1 AS Key, 'text' AS Value), + AsStruct(1 AS Key, 'text' AS Value) +); + +SELECT Key, Cast(Value AS Text) FROM AS_TABLE($new_data); SELECT 1, 'text';" + }; + var ydbDataReader = ydbCommand.ExecuteReader(); + + foreach (var row in ydbDataReader) + { + Assert.Equal(1, row.GetInt32(0)); + Assert.Equal("text", row.GetString(1)); + } + + Assert.True(ydbDataReader.NextResult()); + Assert.True(ydbDataReader.Read()); + Assert.Equal(1, ydbDataReader.GetInt32(0)); + Assert.Equal(Encoding.ASCII.GetBytes("text"), ydbDataReader.GetBytes(1)); + Assert.False(ydbDataReader.Read()); + + ydbDataReader = ydbCommand.ExecuteReader(); + await foreach (var row in ydbDataReader) + { + Assert.Equal(1, row.GetInt32(0)); + Assert.Equal("text", row.GetString(1)); + } + + Assert.True(ydbDataReader.NextResult()); + Assert.True(ydbDataReader.Read()); + Assert.Equal(1, ydbDataReader.GetInt32(0)); + Assert.Equal(Encoding.ASCII.GetBytes("text"), ydbDataReader.GetBytes(1)); + Assert.False(ydbDataReader.Read()); + } }