Skip to content

Commit

Permalink
Fixes in \copy implementation for Clickhouse
Browse files Browse the repository at this point in the history
  • Loading branch information
murfffi committed Apr 12, 2024
1 parent 99e974d commit 9ad5cce
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 3 deletions.
80 changes: 79 additions & 1 deletion drivers/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"

Expand Down Expand Up @@ -35,7 +38,82 @@ func init() {
}
return false
},
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
Copy: CopyWithInsert,
NewMetadataReader: NewMetadataReader,
})
}

// CopyWithInsert copies every row of rows into a ClickHouse table by executing
// a prepared INSERT statement once per source row inside a single transaction.
//
// table selects the destination and may take one of three forms:
//   - a complete statement starting with "insert into" (case-insensitive),
//     which is prepared verbatim;
//   - a table name followed by a parenthesized column list, e.g. "db.t(a, b)";
//   - a bare table name, in which case the target column list is discovered by
//     running "SELECT * FROM <table> WHERE 1=0" against db.
//
// For the last two forms one "?" placeholder is generated per source column.
//
// It returns the sum of RowsAffected reported by the driver for each insert,
// together with the first error encountered. On any error, including a failure
// while iterating the source rows, the transaction is rolled back so that a
// partial copy is never committed and the connection is returned to the pool.
func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
	columns, err := rows.Columns()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
	}
	clen := len(columns)
	query := table
	if !strings.HasPrefix(strings.ToLower(query), "insert into") {
		// strings.Repeat panics on a negative count, so reject a source
		// without columns before building the placeholder list.
		if clen == 0 {
			return 0, fmt.Errorf("failed to build insert query: source rows have no columns")
		}
		if !strings.ContainsRune(table, '(') {
			// No explicit column list: ask the target table for its columns
			// with a query that can never return rows.
			colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
			if err != nil {
				return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
			}
			targetColumns, err := colRows.Columns()
			// Only the column metadata is needed; a close error is irrelevant.
			_ = colRows.Close()
			if err != nil {
				return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
			}
			table += "(" + strings.Join(targetColumns, ", ") + ")"
		}
		// TODO if the db supports multiple rows per insert, create batches of 10 rows
		query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)"
	}

	// Inspect the source before opening the transaction so that a metadata
	// failure does not leave a transaction behind.
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source column types: %w", err)
	}
	values := make([]interface{}, clen)
	valueRefs := make([]reflect.Value, clen)
	actuals := make([]interface{}, clen)
	for i, ct := range columnTypes {
		valueRefs[i] = reflect.New(ct.ScanType())
		values[i] = valueRefs[i].Interface()
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback is a no-op (sql.ErrTxDone) once Commit has been called, so this
	// only takes effect on the error paths below.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.PrepareContext(ctx, query)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare insert query: %w", err)
	}
	defer stmt.Close()

	var n int64
	for rows.Next() {
		if err := rows.Scan(values...); err != nil {
			return n, fmt.Errorf("failed to scan row: %w", err)
		}
		// We can't use values... in ExecContext below, because, in some cases,
		// the clickhouse driver doesn't accept a pointer to an argument instead
		// of the argument itself.
		for i := range values {
			actuals[i] = valueRefs[i].Elem().Interface()
		}
		res, err := stmt.ExecContext(ctx, actuals...)
		if err != nil {
			return n, fmt.Errorf("failed to exec insert: %w", err)
		}
		rn, err := res.RowsAffected()
		if err != nil {
			return n, fmt.Errorf("failed to check rows affected: %w", err)
		}
		n += rn
	}
	// A failed iteration means the source was only partially read; do not
	// commit a truncated copy.
	if err := rows.Err(); err != nil {
		return n, fmt.Errorf("failed to read source rows: %w", err)
	}
	// TODO if using batches, flush the last batch,
	// TODO prepare another statement and count remaining rows
	if err := tx.Commit(); err != nil {
		return n, fmt.Errorf("failed to commit transaction: %w", err)
	}
	return n, nil
}
80 changes: 78 additions & 2 deletions drivers/clickhouse/clickhouse_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
package clickhouse_test

import (
"context"
"database/sql"
"flag"
"fmt"
"github.com/xo/dburl"
"github.com/xo/usql/drivers"
"log"
"os"
"path/filepath"
"testing"
"time"

dt "github.com/ory/dockertest/v3"
"github.com/xo/usql/drivers/clickhouse"
"github.com/xo/usql/drivers/metadata"
"github.com/yookoala/realpath"

_ "github.com/xo/usql/drivers/csvq"
_ "github.com/xo/usql/drivers/moderncsqlite"
)

// db is the database connection.
Expand Down Expand Up @@ -59,7 +66,7 @@ func doMain(m *testing.M, cleanup bool) (int, error) {
if cleanup {
defer func() {
if err := pool.Purge(db.res); err != nil {
fmt.Fprintf(os.Stderr, "error: could not purge resoure: %v\n", err)
fmt.Fprintf(os.Stderr, "error: could not purge resource: %v\n", err)
}
}()
}
Expand All @@ -85,7 +92,7 @@ func TestSchemas(t *testing.T) {
if err != nil {
t.Fatalf("could not read schemas: %v", err)
}
checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema")
checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema", "copy_test")
}

func TestTables(t *testing.T) {
Expand Down Expand Up @@ -119,6 +126,75 @@ func TestColumns(t *testing.T) {
checkNames(t, "column", res, colNames()...)
}

// TestCopy runs the copy round trip against every supported form of the
// destination spec, using two different source drivers.
func TestCopy(t *testing.T) {
	type copyCase struct {
		name          string
		destTableSpec string
		sourceURL     string
	}
	var cases []copyCase
	// Cases with a csvq source DB. That driver doesn't support ScanType().
	for _, spec := range []string{
		"copy_test.dest",
		"copy_test.dest(StringCol, NumCol)",
		"insert into copy_test.dest values(?, ?)",
	} {
		cases = append(cases, copyCase{
			name:          "csvq_" + spec,
			destTableSpec: spec,
			sourceURL:     "csvq:.",
		})
	}
	// Case with a driver that supports ScanType().
	cases = append(cases, copyCase{
		name:          "sqlite",
		destTableSpec: "copy_test.dest",
		sourceURL:     "moderncsqlite://:memory:",
	})
	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			testCopy(t, c.destTableSpec, c.sourceURL)
		})
	}
}

// testCopy performs one copy round trip: it empties copy_test.dest, selects a
// single ('string', 1) row from the database at sourceDbUrlStr, copies it into
// the ClickHouse test container using destTableSpec as the destination, and
// then reads the row back to verify both column values.
func testCopy(t *testing.T, destTableSpec string, sourceDbUrlStr string) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	// Start from an empty table so a row left by an earlier subtest cannot
	// mask a failed copy.
	if _, err := db.db.ExecContext(ctx, "truncate table copy_test.dest"); err != nil {
		t.Fatalf("could not truncate copy_test table: %v", err)
	}
	// Prepare copy destination URL
	port := db.res.GetPort("9000/tcp")
	dbUrlStr := fmt.Sprintf("clickhouse://127.0.0.1:%s", port)
	dbUrl, err := dburl.Parse(dbUrlStr)
	if err != nil {
		t.Fatalf("could not parse clickhouse url %s: %v", dbUrlStr, err)
	}
	// Prepare source data
	sourceDbUrl, err := dburl.Parse(sourceDbUrlStr)
	if err != nil {
		t.Fatalf("could not parse source DB url %s: %v", sourceDbUrlStr, err)
	}
	sourceDb, err := drivers.Open(ctx, sourceDbUrl, nil, nil)
	if err != nil {
		t.Fatalf("could not open sourceDb: %v", err)
	}
	defer sourceDb.Close()
	sourceRows, err := sourceDb.QueryContext(ctx, "select 'string', 1")
	if err != nil {
		t.Fatalf("could not retrieve source rows: %v", err)
	}
	// Release the source result set even when the copy fails part-way.
	defer sourceRows.Close()
	// Do Copy, ignoring copied rows count because clickhouse driver doesn't report RowsAffected
	if _, err = drivers.Copy(ctx, dbUrl, nil, nil, sourceRows, destTableSpec); err != nil {
		t.Fatalf("copy failed: %v", err)
	}
	copied, err := db.db.QueryContext(ctx, "select StringCol, NumCol from copy_test.dest")
	if err != nil {
		t.Fatalf("failed to query: %v", err)
	}
	defer copied.Close()
	var copiedString string
	var copiedNum int
	if !copied.Next() {
		// Distinguish a read failure from a genuinely empty table.
		if err := copied.Err(); err != nil {
			t.Fatalf("could not read copied data: %v", err)
		}
		t.Fatalf("nothing copied")
	}
	if err := copied.Scan(&copiedString, &copiedNum); err != nil {
		t.Fatalf("could not read copied data: %v", err)
	}
	if copiedString != "string" || copiedNum != 1 {
		t.Fatalf("copied data differs: %s != string, %d != 1", copiedString, copiedNum)
	}
}

func checkNames(t *testing.T, typ string, res interface{ Next() bool }, exp ...string) {
n := make(map[string]bool)
for _, s := range exp {
Expand Down
8 changes: 8 additions & 0 deletions drivers/clickhouse/testdata/clickhouse.sql
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,11 @@ CREATE TABLE tutorial_unexpected.hits_v1 (
)
ENGINE = MergeTree()
ORDER BY (Unexpected);

-- Destination for the \copy tests in clickhouse_test.go: each run truncates
-- copy_test.dest, copies a single ('string', 1) row into it and reads it back.
-- The copy_test database is also expected in the schema list by TestSchemas.
CREATE DATABASE copy_test;
CREATE TABLE copy_test.dest (
StringCol String,
NumCol UInt32
)
ENGINE = MergeTree()
ORDER BY (StringCol);

0 comments on commit 9ad5cce

Please sign in to comment.