diff --git a/db.go b/db.go index 521b6e9..eedae3d 100644 --- a/db.go +++ b/db.go @@ -27,6 +27,8 @@ import ( "time" "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + otlog "github.com/opentracing/opentracing-go/log" "golang.org/x/net/context" ) @@ -710,9 +712,9 @@ func (db *DB) Exec(query interface{}, args ...interface{}) (sql.Result, error) { } // ExecContext is the context version of Exec. -func (db *DB) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) { - ctx, span, finishTrace := db.addTracerToContext(ctx, opExec) - defer finishTrace() +func (db *DB) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (_ sql.Result, err error) { + ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opExec) + defer finishTrace(&err) serializer, err := db.getSerializer(query) if err != nil { @@ -824,9 +826,9 @@ func rewriteQuery(ctx context.Context, db *DB, start time.Time, q string) (strin } // QueryContext is the context version of Query. -func (db *DB) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (*Rows, error) { - ctx, span, finishTrace := db.addTracerToContext(ctx, opQuery) - defer finishTrace() +func (db *DB) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (_ *Rows, err error) { + ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opQuery) + defer finishTrace(&err) start := time.Now() serializer, err := db.getSerializer(q) @@ -865,9 +867,13 @@ func (db *DB) QueryRow(q interface{}, args ...interface{}) *Row { } // QueryRowContext is the context version of QueryRow. -func (db *DB) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) *Row { - ctx, span, finishTrace := db.addTracerToContext(ctx, opQueryRow) - defer finishTrace() +func (db *DB) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) (row *Row) { + ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opQueryRow) + defer func() { + if row.err != nil { + finishTrace(&row.err) + } + }() start := time.Now() serializer, err := db.getSerializer(q) @@ -972,9 +978,9 @@ func (db *DB) UpsertContext(ctx context.Context, list ...interface{}) error { // Begin begins a transaction and returns a *squalor.Tx instead of a // *sql.Tx. -func (db *DB) Begin() (*Tx, error) { - _, _, finishTrace := db.addTracerToContext(db.Context(), "begin") - defer finishTrace() +func (db *DB) Begin() (_ *Tx, err error) { + _, _, finishTrace := db.addTracerToContext(db.Context(), "squalor.begin") + defer finishTrace(&err) tx, err := begin(db) if err != nil { @@ -1031,33 +1037,46 @@ func (tx *Tx) AddPostCommitHook(post PostCommit) { // Commit is a wrapper around sql.Tx.Commit() which also provides pre- and post- // commit hooks. -func (tx *Tx) Commit() error { - _, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "commit") - defer finishTrace() +func (tx *Tx) Commit() (err error) { + _, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "squalor.commit") + defer finishTrace(&err) for _, pre := range tx.preHooks { if err := pre(tx); err != nil { return err } } - err := tx.Tx.Commit() + err = tx.Tx.Commit() for _, post := range tx.postHooks { post(err) } return err } -func (tx *Tx) Rollback() error { - _, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "rollback") - defer finishTrace() +// CommitContext is a wrapper around sql.Tx.Commit() which also provides pre- and post- +// commit hooks. The given Context is passed to any commit hooks. +func (tx *Tx) CommitContext(ctx context.Context) error { + if ctx == nil { + panic("Nil Context passed to Executor.CommitContext") + } + return tx.withContextHelper(ctx).Commit() +} + +func (tx *Tx) Rollback() (err error) { + _, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "squalor.rollback") + defer finishTrace(&err) return tx.Tx.Rollback() } -func (tx *Tx) WithContext(ctx context.Context) ExecutorContext { +func (tx *Tx) RollbackContext(ctx context.Context) error { if ctx == nil { - panic("Nil Context passed to Executor.WithContext") + panic("Nil Context passed to Executor.RollbackContext") } + return tx.withContextHelper(ctx).Rollback() +} + +func (tx *Tx) withContextHelper(ctx context.Context) *Tx { newTx := *tx newDB := *newTx.DB newDB.context = ctx @@ -1065,6 +1084,13 @@ func (tx *Tx) WithContext(ctx context.Context) ExecutorContext { return &newTx } +func (tx *Tx) WithContext(ctx context.Context) ExecutorContext { + if ctx == nil { + panic("Nil Context passed to Executor.WithContext") + } + return tx.withContextHelper(ctx) +} + func (tx *Tx) Context() context.Context { return tx.DB.Context() } @@ -1078,9 +1104,9 @@ func (tx *Tx) Exec(query interface{}, args ...interface{}) (sql.Result, error) { const MaxTracerStatementLength = 400 // ExecContext executes a query using the provided context. -func (tx *Tx) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) { - ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, opExec) - defer finishTrace() +func (tx *Tx) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (_ sql.Result, err error) { + ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opExec) + defer finishTrace(&err) serializer, err := tx.DB.getSerializer(query) if err != nil { @@ -1181,9 +1207,9 @@ func (tx *Tx) Query(q interface{}, args ...interface{}) (*Rows, error) { } // QueryContext is the context version of Query. -func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (*Rows, error) { - ctx, _, finishTrace := tx.DB.addTracerToContext(ctx, opQuery) - defer finishTrace() +func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (_ *Rows, err error) { + ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opQuery) + defer finishTrace(&err) start := time.Now() serializer, err := tx.DB.getSerializer(q) @@ -1194,6 +1220,10 @@ func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface if err != nil { return nil, err } + if span != nil { + span.SetTag("db.statement", truncate(queryStr, MaxTracerStatementLength)) + } + queryStr, err = rewriteQuery(ctx, tx.DB, start, queryStr) if err != nil { return nil, err @@ -1218,9 +1248,13 @@ func (tx *Tx) QueryRow(q interface{}, args ...interface{}) *Row { } // QueryRowContext is the context version of QueryRow. -func (tx *Tx) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) *Row { - ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, opQueryRow) - defer finishTrace() +func (tx *Tx) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) (row *Row) { + ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opQueryRow) + defer func() { + if row.err != nil { + finishTrace(&row.err) + } + }() start := time.Now() serializer, err := tx.DB.getSerializer(q) @@ -1716,9 +1750,13 @@ func deleteModel(ctx context.Context, model *Model, exec Executor, list []interf return count, nil } -func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (int64, error) { - ctx, _, finishTrace := db.addTracerToContext(ctx, opDelete) - defer finishTrace() +func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (_ int64, err error) { + ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opDelete) + defer finishTrace(&err) + if span != nil { + span.SetTag("db.objects_to_delete", len(list)) + } + objs, err := groupObjects(db, list) if err != nil { return -1, err @@ -1736,9 +1774,9 @@ func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{ return count, nil } -func getObject(ctx context.Context, db *DB, exec Executor, obj interface{}, keys []interface{}) error { - ctx, _, finishTrace := db.addTracerToContext(ctx, opGet) - defer finishTrace() +func getObject(ctx context.Context, db *DB, exec Executor, obj interface{}, keys []interface{}) (err error) { + ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opGet) + defer finishTrace(&err) objT := reflect.TypeOf(obj) if objT.Kind() != reflect.Ptr { return fmt.Errorf("obj must be a pointer: %T", obj) @@ -1896,9 +1934,9 @@ func insertModel(ctx context.Context, model *Model, exec Executor, getPlan func( return nil } -func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *Model) insertPlan, list []interface{}, name operationName) error { - ctx, _, finishTrace := db.addTracerToContext(ctx, name) - defer finishTrace() +func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *Model) insertPlan, list []interface{}, name operationName) (err error) { + ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+name) + defer finishTrace(&err) objs, err := groupObjects(db, list) if err != nil { return err @@ -1912,9 +1950,9 @@ func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *M return nil } -func selectObjects(ctx context.Context, db *DB, exec Executor, dest interface{}, query interface{}, args []interface{}) error { - ctx, _, finishTrace := db.addTracerToContext(ctx, opSelect) - defer finishTrace() +func selectObjects(ctx context.Context, db *DB, exec Executor, dest interface{}, query interface{}, args []interface{}) (err error) { + ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opSelect) + defer finishTrace(&err) sliceValue := reflect.ValueOf(dest) if sliceValue.Kind() != reflect.Ptr { return fmt.Errorf("dest must be a pointer to a slice: %T", dest) @@ -2021,9 +2059,9 @@ func updateModel(ctx context.Context, model *Model, exec Executor, list []interf return count, nil } -func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (int64, error) { - ctx, _, finishTrace := db.addTracerToContext(ctx, opUpdate) - defer finishTrace() +func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (_ int64, err error) { + ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opUpdate) + defer finishTrace(&err) objs, err := groupObjects(db, list) if err != nil { return -1, err @@ -2040,21 +2078,31 @@ func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{ return count, nil } -// addTracerToContext returns a new ctx with tracer attached to it and a method to finish the trace +// addTracerToContext returns a new ctx with tracer attached to it and a method to finish the trace. +// A pointer to the error returned from the enclosing function (or nil) should be passed to the finish method +// within a defer statement. // -// finishing the trace would finish the traced span, so code should call finish as soon as the +// Finishing the trace would finish the traced span, so code should call finish as soon as the // operations complete: // -// func execWithTracer(ctx context.Context) { +// func execWithTracer(ctx context.Context) (err error) { // ctx, span, finishTracer := addTracerToContext(ctx, "exec") -// defer finishTracer() +// defer finishTracer(&err) // return exec(ctx) // } -func (db *DB) addTracerToContext(ctx context.Context, name operationName) (tracedCtx context.Context, span opentracing.Span, finishTrace func()) { +// +// The function has no effect if db.OpentracingEnabled is false. +func (db *DB) addTracerToContext(ctx context.Context, name operationName) (tracedCtx context.Context, span opentracing.Span, finishTrace func(errPtr *error)) { if db.OpentracingEnabled { span, tracedCtx = opentracing.StartSpanFromContext(ctx, string(name)) - return tracedCtx, span, span.Finish + return tracedCtx, span, func(errPtr *error) { + if errPtr != nil { + ext.Error.Set(span, true) + span.LogFields(otlog.Error(*errPtr)) + } + span.Finish() + } } - return ctx, nil, func() {} + return ctx, nil, func(*error) {} } diff --git a/table_test.go b/table_test.go index 4c0329d..c3a2723 100644 --- a/table_test.go +++ b/table_test.go @@ -50,10 +50,10 @@ func TestLoadTableNameInjection(t *testing.T) { t.Fatalf("Expected error %q from injection attempt, got nil", expectedError) } - // Use strings.ToLower in order to prevent error message discrepancies + // Use strings.EqualFold in order to prevent error message discrepancies // between running tests locally through integration_test.sh and the // build running through CI. - if strings.EqualFold(expectedError) { + if !strings.EqualFold(expectedError, err.Error()) { t.Fatalf("Expected error %q from injection attempt, got %q", expectedError, err.Error()) } diff --git a/util.go b/util.go index 0b38f1e..ead87e1 100644 --- a/util.go +++ b/util.go @@ -33,7 +33,7 @@ func recoveryToError(r interface{}) error { // Rune-based string truncation with ellipsis with reasonable support for multibyte characters. func truncate(s string, n int) string { - if n == 0 { + if n <= 0 { return "" } if len(s) < n {