Skip to content

Commit

Permalink
Add more opentracing information
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Rice committed Oct 31, 2023
1 parent 1a9a328 commit 4ea67f6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 50 deletions.
140 changes: 92 additions & 48 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"time"

"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
otlog "github.com/opentracing/opentracing-go/log"
"golang.org/x/net/context"
)

Expand Down Expand Up @@ -710,9 +712,9 @@ func (db *DB) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
}

// ExecContext is the context version of Exec.
func (db *DB) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) {
ctx, span, finishTrace := db.addTracerToContext(ctx, opExec)
defer finishTrace()
func (db *DB) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (_ sql.Result, err error) {
ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opExec)
defer finishTrace(&err)

serializer, err := db.getSerializer(query)
if err != nil {
Expand Down Expand Up @@ -824,9 +826,9 @@ func rewriteQuery(ctx context.Context, db *DB, start time.Time, q string) (strin
}

// QueryContext is the context version of Query.
func (db *DB) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (*Rows, error) {
ctx, span, finishTrace := db.addTracerToContext(ctx, opQuery)
defer finishTrace()
func (db *DB) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (_ *Rows, err error) {
ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opQuery)
defer finishTrace(&err)
start := time.Now()

serializer, err := db.getSerializer(q)
Expand Down Expand Up @@ -865,9 +867,13 @@ func (db *DB) QueryRow(q interface{}, args ...interface{}) *Row {
}

// QueryRowContext is the context version of QueryRow.
func (db *DB) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) *Row {
ctx, span, finishTrace := db.addTracerToContext(ctx, opQueryRow)
defer finishTrace()
func (db *DB) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) (row *Row) {
ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opQueryRow)
defer func() {
if row.err != nil {
finishTrace(&row.err)
}
}()
start := time.Now()

serializer, err := db.getSerializer(q)
Expand Down Expand Up @@ -972,9 +978,9 @@ func (db *DB) UpsertContext(ctx context.Context, list ...interface{}) error {

// Begin begins a transaction and returns a *squalor.Tx instead of a
// *sql.Tx.
func (db *DB) Begin() (*Tx, error) {
_, _, finishTrace := db.addTracerToContext(db.Context(), "begin")
defer finishTrace()
func (db *DB) Begin() (_ *Tx, err error) {
_, _, finishTrace := db.addTracerToContext(db.Context(), "squalor.begin")
defer finishTrace(&err)

tx, err := begin(db)
if err != nil {
Expand Down Expand Up @@ -1031,40 +1037,60 @@ func (tx *Tx) AddPostCommitHook(post PostCommit) {

// Commit is a wrapper around sql.Tx.Commit() which also provides pre- and post-
// commit hooks.
func (tx *Tx) Commit() error {
_, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "commit")
defer finishTrace()
func (tx *Tx) Commit() (err error) {
_, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "squalor.commit")
defer finishTrace(&err)

for _, pre := range tx.preHooks {
if err := pre(tx); err != nil {
return err
}
}
err := tx.Tx.Commit()
err = tx.Tx.Commit()
for _, post := range tx.postHooks {
post(err)
}
return err
}

func (tx *Tx) Rollback() error {
_, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "rollback")
defer finishTrace()
// CommitContext is a wrapper around sql.Tx.Commit() which also provides pre- and post-
// commit hooks. The given Context is passed to any commit hooks.
func (tx *Tx) CommitContext(ctx context.Context) error {
if ctx == nil {
panic("Nil Context passed to Executor.CommitContext")
}
return tx.withContextHelper(ctx).Commit()
}

func (tx *Tx) Rollback() (err error) {
_, _, finishTrace := tx.DB.addTracerToContext(tx.Context(), "squalor.rollback")
defer finishTrace(&err)

return tx.Tx.Rollback()
}

func (tx *Tx) WithContext(ctx context.Context) ExecutorContext {
func (tx *Tx) RollbackContext(ctx context.Context) error {
if ctx == nil {
panic("Nil Context passed to Executor.WithContext")
panic("Nil Context passed to Executor.RollbackContext")
}
return tx.withContextHelper(ctx).Rollback()
}

func (tx *Tx) withContextHelper(ctx context.Context) *Tx {
newTx := *tx
newDB := *newTx.DB
newDB.context = ctx
newTx.DB = &newDB
return &newTx
}

func (tx *Tx) WithContext(ctx context.Context) ExecutorContext {
if ctx == nil {
panic("Nil Context passed to Executor.WithContext")
}
return tx.withContextHelper(ctx)
}

func (tx *Tx) Context() context.Context {
return tx.DB.Context()
}
Expand All @@ -1078,9 +1104,9 @@ func (tx *Tx) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
const MaxTracerStatementLength = 400

// ExecContext executes a query using the provided context.
func (tx *Tx) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) {
ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, opExec)
defer finishTrace()
func (tx *Tx) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (_ sql.Result, err error) {
ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opExec)
defer finishTrace(&err)

serializer, err := tx.DB.getSerializer(query)
if err != nil {
Expand Down Expand Up @@ -1181,9 +1207,9 @@ func (tx *Tx) Query(q interface{}, args ...interface{}) (*Rows, error) {
}

// QueryContext is the context version of Query.
func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (*Rows, error) {
ctx, _, finishTrace := tx.DB.addTracerToContext(ctx, opQuery)
defer finishTrace()
func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface{}) (_ *Rows, err error) {
ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opQuery)
defer finishTrace(&err)
start := time.Now()

serializer, err := tx.DB.getSerializer(q)
Expand All @@ -1194,6 +1220,10 @@ func (tx *Tx) QueryContext(ctx context.Context, q interface{}, args ...interface
if err != nil {
return nil, err
}
if span != nil {
span.SetTag("db.statement", truncate(queryStr, MaxTracerStatementLength))
}

queryStr, err = rewriteQuery(ctx, tx.DB, start, queryStr)
if err != nil {
return nil, err
Expand All @@ -1218,9 +1248,13 @@ func (tx *Tx) QueryRow(q interface{}, args ...interface{}) *Row {
}

// QueryRowContext is the context version of QueryRow.
func (tx *Tx) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) *Row {
ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, opQueryRow)
defer finishTrace()
func (tx *Tx) QueryRowContext(ctx context.Context, q interface{}, args ...interface{}) (row *Row) {
ctx, span, finishTrace := tx.DB.addTracerToContext(ctx, "squalor."+opQueryRow)
defer func() {
if row.err != nil {
finishTrace(&row.err)
}
}()
start := time.Now()

serializer, err := tx.DB.getSerializer(q)
Expand Down Expand Up @@ -1716,9 +1750,13 @@ func deleteModel(ctx context.Context, model *Model, exec Executor, list []interf
return count, nil
}

func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (int64, error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, opDelete)
defer finishTrace()
func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (_ int64, err error) {
ctx, span, finishTrace := db.addTracerToContext(ctx, "squalor."+opDelete)
defer finishTrace(&err)
if span != nil {
span.SetTag("db.objects_to_delete", len(list))
}

objs, err := groupObjects(db, list)
if err != nil {
return -1, err
Expand All @@ -1736,9 +1774,9 @@ func deleteObjects(ctx context.Context, db *DB, exec Executor, list []interface{
return count, nil
}

func getObject(ctx context.Context, db *DB, exec Executor, obj interface{}, keys []interface{}) error {
ctx, _, finishTrace := db.addTracerToContext(ctx, opGet)
defer finishTrace()
func getObject(ctx context.Context, db *DB, exec Executor, obj interface{}, keys []interface{}) (err error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opGet)
defer finishTrace(&err)
objT := reflect.TypeOf(obj)
if objT.Kind() != reflect.Ptr {
return fmt.Errorf("obj must be a pointer: %T", obj)
Expand Down Expand Up @@ -1896,9 +1934,9 @@ func insertModel(ctx context.Context, model *Model, exec Executor, getPlan func(
return nil
}

func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *Model) insertPlan, list []interface{}, name operationName) error {
ctx, _, finishTrace := db.addTracerToContext(ctx, name)
defer finishTrace()
func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *Model) insertPlan, list []interface{}, name operationName) (err error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+name)
defer finishTrace(&err)
objs, err := groupObjects(db, list)
if err != nil {
return err
Expand All @@ -1912,9 +1950,9 @@ func insertObjects(ctx context.Context, db *DB, exec Executor, getPlan func(m *M
return nil
}

func selectObjects(ctx context.Context, db *DB, exec Executor, dest interface{}, query interface{}, args []interface{}) error {
ctx, _, finishTrace := db.addTracerToContext(ctx, opSelect)
defer finishTrace()
func selectObjects(ctx context.Context, db *DB, exec Executor, dest interface{}, query interface{}, args []interface{}) (err error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opSelect)
defer finishTrace(&err)
sliceValue := reflect.ValueOf(dest)
if sliceValue.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer to a slice: %T", dest)
Expand Down Expand Up @@ -2021,9 +2059,9 @@ func updateModel(ctx context.Context, model *Model, exec Executor, list []interf
return count, nil
}

func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (int64, error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, opUpdate)
defer finishTrace()
func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{}) (_ int64, err error) {
ctx, _, finishTrace := db.addTracerToContext(ctx, "squalor."+opUpdate)
defer finishTrace(&err)
objs, err := groupObjects(db, list)
if err != nil {
return -1, err
Expand All @@ -2050,11 +2088,17 @@ func updateObjects(ctx context.Context, db *DB, exec Executor, list []interface{
// defer finishTracer()
// return exec(ctx)
// }
func (db *DB) addTracerToContext(ctx context.Context, name operationName) (tracedCtx context.Context, span opentracing.Span, finishTrace func()) {
func (db *DB) addTracerToContext(ctx context.Context, name operationName) (tracedCtx context.Context, span opentracing.Span, finishTrace func(errPtr *error)) {
if db.OpentracingEnabled {
span, tracedCtx = opentracing.StartSpanFromContext(ctx, string(name))
return tracedCtx, span, span.Finish
return tracedCtx, span, func(errPtr *error) {
if errPtr != nil {
ext.Error.Set(span, true)
span.LogFields(otlog.Error(*errPtr))
}
span.Finish()
}
}

return ctx, nil, func() {}
return ctx, nil, func(*error) {}
}
4 changes: 2 additions & 2 deletions table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ func TestLoadTableNameInjection(t *testing.T) {
t.Fatalf("Expected error %q from injection attempt, got nil", expectedError)
}

// Use strings.ToLower in order to prevent error message discrepancies
// Use strings.EqualFold in order to prevent error message discrepancies
// between running tests locally through integration_test.sh and the
// build running through CI.
if strings.EqualFold(expectedError) {
if !strings.EqualFold(expectedError, err.Error()) {
t.Fatalf("Expected error %q from injection attempt, got %q", expectedError, err.Error())
}

Expand Down

0 comments on commit 4ea67f6

Please sign in to comment.