diff --git a/drivers/clickhouse/clickhouse.go b/drivers/clickhouse/clickhouse.go
--- a/drivers/clickhouse/clickhouse.go
+++ b/drivers/clickhouse/clickhouse.go
@@ -5,7 +5,10 @@
 package clickhouse
 
 import (
+	"context"
 	"database/sql"
+	"fmt"
+	"reflect"
 	"strconv"
 	"strings"
 
@@ -35,7 +38,113 @@ func init() {
 			}
 			return false
 		},
-		Copy:              drivers.CopyWithInsert(func(int) string { return "?" }),
+		Copy:              CopyWithInsert,
 		NewMetadataReader: NewMetadataReader,
 	})
 }
+
+// CopyWithInsert copies every row of rows into the ClickHouse database db by
+// executing one prepared INSERT statement per source row inside a single
+// transaction.
+//
+// table may be given in one of three forms:
+//   - a bare table name ("db.tbl"): the target column list is discovered by
+//     running a "SELECT * ... WHERE 1=0" query for that table against db;
+//   - a table name followed by a parenthesized column list ("db.tbl(a, b)");
+//   - a complete statement starting with "insert into" (case-insensitive),
+//     which is prepared as-is.
+//
+// It returns the sum of RowsAffected reported by the driver for each insert,
+// together with the first error encountered. Once the transaction has begun,
+// any error rolls it back, so a failed copy commits nothing.
+func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
+	columns, err := rows.Columns()
+	if err != nil {
+		return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
+	}
+	clen := len(columns)
+	query := table
+	if !strings.HasPrefix(strings.ToLower(query), "insert into") {
+		// strings.Repeat panics on a negative count, so reject an empty
+		// source column list before building the placeholder list.
+		if clen == 0 {
+			return 0, fmt.Errorf("source rows have no columns to copy into %q", table)
+		}
+		if !strings.ContainsRune(table, '(') {
+			// No explicit column list: ask the target table for its columns
+			// with a query that returns no rows.
+			colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
+			if err != nil {
+				return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
+			}
+			targetColumns, err := colRows.Columns()
+			// Only the column names are needed; a Close error would not
+			// change the outcome, so it is ignored.
+			_ = colRows.Close()
+			if err != nil {
+				return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
+			}
+			table += "(" + strings.Join(targetColumns, ", ") + ")"
+		}
+		// TODO if the db supports multiple rows per insert, create batches of 10 rows
+		query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)"
+	}
+	tx, err := db.BeginTx(ctx, nil)
+	if err != nil {
+		return 0, fmt.Errorf("failed to begin transaction: %w", err)
+	}
+	// Roll back on every early return so the transaction and its connection
+	// are released. After a successful Commit, Rollback only returns
+	// sql.ErrTxDone, which is safe to ignore.
+	defer func() { _ = tx.Rollback() }()
+	stmt, err := tx.PrepareContext(ctx, query)
+	if err != nil {
+		return 0, fmt.Errorf("failed to prepare insert query: %w", err)
+	}
+	defer stmt.Close()
+	columnTypes, err := rows.ColumnTypes()
+	if err != nil {
+		return 0, fmt.Errorf("failed to fetch source column types: %w", err)
+	}
+	// values holds the pointers handed to Scan, valueRefs the reflect handles
+	// to the same storage, and actuals the dereferenced values handed to Exec.
+	values := make([]interface{}, clen)
+	valueRefs := make([]reflect.Value, clen)
+	actuals := make([]interface{}, clen)
+	for i, columnType := range columnTypes {
+		valueRefs[i] = reflect.New(columnType.ScanType())
+		values[i] = valueRefs[i].Interface()
+	}
+	var n int64
+	for rows.Next() {
+		if err := rows.Scan(values...); err != nil {
+			return n, fmt.Errorf("failed to scan row: %w", err)
+		}
+		// We can't pass values... to ExecContext below because, in some
+		// cases, the clickhouse driver doesn't accept a pointer to an
+		// argument instead of the argument itself.
+		for i := range values {
+			actuals[i] = valueRefs[i].Elem().Interface()
+		}
+		res, err := stmt.ExecContext(ctx, actuals...)
+		if err != nil {
+			return n, fmt.Errorf("failed to exec insert: %w", err)
+		}
+		rn, err := res.RowsAffected()
+		if err != nil {
+			return n, fmt.Errorf("failed to check rows affected: %w", err)
+		}
+		n += rn
+	}
+	// Check for an iteration error before committing so that a failed read
+	// of the source does not commit a partial copy.
+	if err := rows.Err(); err != nil {
+		return n, fmt.Errorf("failed to read source rows: %w", err)
+	}
+	// TODO if using batches, flush the last batch,
+	// TODO prepare another statement and count remaining rows
+	if err := tx.Commit(); err != nil {
+		return n, fmt.Errorf("failed to commit transaction: %w", err)
+	}
+	return n, nil
+}
diff --git a/drivers/clickhouse/clickhouse_test.go b/drivers/clickhouse/clickhouse_test.go
--- a/drivers/clickhouse/clickhouse_test.go
+++ b/drivers/clickhouse/clickhouse_test.go
@@ -1,18 +1,25 @@
 package clickhouse_test
 
 import (
+	"context"
 	"database/sql"
 	"flag"
 	"fmt"
 	"log"
 	"os"
 	"path/filepath"
 	"testing"
+	"time"
 
 	dt "github.com/ory/dockertest/v3"
+	"github.com/xo/dburl"
+	"github.com/xo/usql/drivers"
 	"github.com/xo/usql/drivers/clickhouse"
 	"github.com/xo/usql/drivers/metadata"
 	"github.com/yookoala/realpath"
+
+	_ "github.com/xo/usql/drivers/csvq"
+	_ "github.com/xo/usql/drivers/moderncsqlite"
 )
 
 // db is the database connection.
@@ -59,7 +66,7 @@ func doMain(m *testing.M, cleanup bool) (int, error) {
 	if cleanup {
 		defer func() {
 			if err := pool.Purge(db.res); err != nil {
-				fmt.Fprintf(os.Stderr, "error: could not purge resoure: %v\n", err)
+				fmt.Fprintf(os.Stderr, "error: could not purge resource: %v\n", err)
 			}
 		}()
 	}
@@ -85,7 +92,7 @@ func TestSchemas(t *testing.T) {
 	if err != nil {
 		t.Fatalf("could not read schemas: %v", err)
 	}
-	checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema")
+	checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema", "copy_test")
 }
 
 func TestTables(t *testing.T) {
@@ -119,6 +126,80 @@ func TestColumns(t *testing.T) {
 	checkNames(t, "column", res, colNames()...)
 }
 
+// TestCopy copies a single row into copy_test.dest from two source drivers,
+// using each of the destination table spec forms listed below for csvq.
+func TestCopy(t *testing.T) {
+	// Tests with csvq source DB. That driver doesn't support ScanType()
+	for _, destTableSpec := range []string{
+		"copy_test.dest",
+		"copy_test.dest(StringCol, NumCol)",
+		"insert into copy_test.dest values(?, ?)",
+	} {
+		t.Run("csvq_"+destTableSpec, func(t *testing.T) {
+			testCopy(t, destTableSpec, "csvq:.")
+		})
+	}
+	// Test with a driver that supports ScanType()
+	t.Run("sqlite", func(t *testing.T) {
+		testCopy(t, "copy_test.dest", "moderncsqlite://:memory:")
+	})
+}
+
+// testCopy truncates copy_test.dest, copies the row ('string', 1) selected
+// from the database at sourceURLStr into destTableSpec, and verifies that
+// the row can be read back from ClickHouse.
+func testCopy(t *testing.T, destTableSpec string, sourceURLStr string) {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	if _, err := db.db.ExecContext(ctx, "truncate table copy_test.dest"); err != nil {
+		t.Fatalf("could not truncate copy_test table: %v", err)
+	}
+	// Prepare copy destination URL
+	port := db.res.GetPort("9000/tcp")
+	destURLStr := fmt.Sprintf("clickhouse://127.0.0.1:%s", port)
+	destURL, err := dburl.Parse(destURLStr)
+	if err != nil {
+		t.Fatalf("could not parse clickhouse url %s: %v", destURLStr, err)
+	}
+	// Prepare source data
+	sourceURL, err := dburl.Parse(sourceURLStr)
+	if err != nil {
+		t.Fatalf("could not parse source DB url %s: %v", sourceURLStr, err)
+	}
+	sourceDB, err := drivers.Open(ctx, sourceURL, nil, nil)
+	if err != nil {
+		t.Fatalf("could not open sourceDb: %v", err)
+	}
+	defer sourceDB.Close()
+	sourceRows, err := sourceDB.QueryContext(ctx, "select 'string', 1")
+	if err != nil {
+		t.Fatalf("could not retrieve source rows: %v", err)
+	}
+	// Closing is idempotent, so this is safe whether or not Copy already
+	// closed the source rows.
+	defer sourceRows.Close()
+	// Do Copy, ignoring copied rows count because clickhouse driver doesn't report RowsAffected
+	if _, err := drivers.Copy(ctx, destURL, nil, nil, sourceRows, destTableSpec); err != nil {
+		t.Fatalf("copy failed: %v", err)
+	}
+	copiedRows, err := db.db.QueryContext(ctx, "select StringCol, NumCol from copy_test.dest")
+	if err != nil {
+		t.Fatalf("failed to query: %v", err)
+	}
+	defer copiedRows.Close()
+	var copiedString string
+	var copiedNum int
+	if !copiedRows.Next() {
+		t.Fatalf("nothing copied (rows error: %v)", copiedRows.Err())
+	}
+	if err := copiedRows.Scan(&copiedString, &copiedNum); err != nil {
+		t.Fatalf("could not read copied data: %v", err)
+	}
+	if copiedString != "string" || copiedNum != 1 {
+		t.Fatalf("copied data differs: %s != string, %d != 1", copiedString, copiedNum)
+	}
+}
+
 func checkNames(t *testing.T, typ string, res interface{ Next() bool }, exp ...string) {
 	n := make(map[string]bool)
 	for _, s := range exp {
diff --git a/drivers/clickhouse/testdata/clickhouse.sql b/drivers/clickhouse/testdata/clickhouse.sql
--- a/drivers/clickhouse/testdata/clickhouse.sql
+++ b/drivers/clickhouse/testdata/clickhouse.sql
@@ -340,3 +340,11 @@ CREATE TABLE tutorial_unexpected.hits_v1 (
 )
 ENGINE = MergeTree()
 ORDER BY (Unexpected);
+
+CREATE DATABASE copy_test;
+CREATE TABLE copy_test.dest (
+    StringCol String,
+    NumCol UInt32
+)
+ENGINE = MergeTree()
+ORDER BY (StringCol);