ddl: refine some context usage #56243

Merged
merged 12 commits on Sep 27, 2024
2 changes: 1 addition & 1 deletion br/pkg/backup/client.go
@@ -1036,7 +1036,7 @@ func WriteBackupDDLJobs(metaWriter *metautil.MetaWriter, g glue.Glue, store kv.S
newestMeta := meta.NewSnapshotMeta(store.GetSnapshot(kv.NewVersion(version.Ver)))
var allJobs []*model.Job
err = g.UseOneShotSession(store, !needDomain, func(se glue.Session) error {
- allJobs, err = ddl.GetAllDDLJobs(se.GetSessionCtx())
+ allJobs, err = ddl.GetAllDDLJobs(context.Background(), se.GetSessionCtx())
if err != nil {
return errors.Trace(err)
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/ddl/backfilling.go
@@ -675,7 +675,9 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
return errors.Trace(err)
}
job := reorgInfo.Job
- opCtx := NewLocalOperatorCtx(ctx, job.ID)
+ opCtx, cancel := NewLocalOperatorCtx(ctx, job.ID)
+ defer cancel()

idxCnt := len(reorgInfo.elements)
indexIDs := make([]int64, 0, idxCnt)
indexInfos := make([]*model.IndexInfo, 0, idxCnt)
@@ -705,11 +707,6 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
return errors.Trace(err)
}
defer ingest.LitBackCtxMgr.Unregister(job.ID)
- sctx, err := sessPool.Get()
- if err != nil {
- return errors.Trace(err)
- }
- defer sessPool.Put(sctx)

cpMgr, err := ingest.NewCheckpointManager(
ctx,
@@ -737,6 +734,11 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
metrics.GenerateReorgLabel("add_idx_rate", job.SchemaName, job.TableName)),
}

+ sctx, err := sessPool.Get()
+ if err != nil {
+ return errors.Trace(err)
+ }
+ defer sessPool.Put(sctx)
avgRowSize := estimateTableRowSize(ctx, dc.store, sctx.GetRestrictedSQLExecutor(), t)

engines, err := bcCtx.Register(indexIDs, uniques, t)
24 changes: 12 additions & 12 deletions pkg/ddl/backfilling_operators.go
@@ -83,33 +83,33 @@ type OperatorCtx {
}

// NewDistTaskOperatorCtx is used for adding index with dist framework.
- func NewDistTaskOperatorCtx(ctx context.Context, taskID, subtaskID int64) *OperatorCtx {
+ func NewDistTaskOperatorCtx(
+ ctx context.Context,
+ taskID, subtaskID int64,
+ ) (*OperatorCtx, context.CancelFunc) {
opCtx, cancel := context.WithCancel(ctx)
- opCtx = logutil.WithFields(opCtx, zap.Int64("task-id", taskID), zap.Int64("subtask-id", subtaskID))
+ opCtx = logutil.WithFields(opCtx,
+ zap.Int64("task-id", taskID),
+ zap.Int64("subtask-id", subtaskID))
return &OperatorCtx{
Context: opCtx,
cancel: cancel,
- }
+ }, cancel
Contributor

why not just use OperatorCtx.Cancel?

Contributor Author
@lance6716 lance6716 Sep 26, 2024

the unused variable (cancel) can remind the caller. This PR fixes some functions that forgot to call cancel, like line 678 in pkg/ddl/backfilling.go in the old code.

Contributor

I doubt the reminder part; coders can simply ignore it, since OperatorCtx already has it.

We should unify the place where cancel is called, not scatter it all around.

Contributor Author
@lance6716 lance6716 Sep 26, 2024

Developers have forgotten to call OperatorCtx.Cancel at line 678 in pkg/ddl/backfilling.go (there are also missing calls in unit test files), so I prefer to use Go's built-in detection of unused variables to prevent future bugs.
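To make the trade-off concrete, the pattern roughly looks like this (a simplified, self-contained sketch for illustration; the real constructor lives in pkg/ddl and also wires up logging):

```go
package main

import (
	"context"
	"fmt"
)

// OperatorCtx is a stand-in for the real type in pkg/ddl; only the parts
// relevant to the cancel discussion are kept.
type OperatorCtx struct {
	context.Context
	cancel context.CancelFunc
}

// NewLocalOperatorCtx mirrors the new shape of the constructor: the cancel
// func is handed back to the caller instead of only being stored internally.
func NewLocalOperatorCtx(ctx context.Context, jobID int64) (*OperatorCtx, context.CancelFunc) {
	opCtx, cancel := context.WithCancel(ctx)
	_ = jobID // the real code also attaches jobID to the logger
	return &OperatorCtx{Context: opCtx, cancel: cancel}, cancel
}

func main() {
	opCtx, cancel := NewLocalOperatorCtx(context.Background(), 1)
	// Dropping the next line leaves cancel as an unused variable, which Go
	// rejects at compile time; silencing it needs an explicit "_ = cancel",
	// which stands out in review far more than a missing opCtx.Cancel() call.
	defer cancel()

	fmt.Println(opCtx.Err()) // <nil> until cancel runs
}
```

It doesn't catch every misuse, but it turns the most common one (never calling cancel at all) into something the compiler flags.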

Contributor

static checks don't help with those logic bugs, but ok anyway

}

// NewLocalOperatorCtx is used for adding index with local ingest mode.
- func NewLocalOperatorCtx(ctx context.Context, jobID int64) *OperatorCtx {
+ func NewLocalOperatorCtx(ctx context.Context, jobID int64) (*OperatorCtx, context.CancelFunc) {
opCtx, cancel := context.WithCancel(ctx)
opCtx = logutil.WithFields(opCtx, zap.Int64("jobID", jobID))
return &OperatorCtx{
Context: opCtx,
cancel: cancel,
- }
+ }, cancel
}

func (ctx *OperatorCtx) onError(err error) {
tracedErr := errors.Trace(err)
ctx.cancel()
ctx.err.CompareAndSwap(nil, &tracedErr)
}

- // Cancel cancels the pipeline.
- func (ctx *OperatorCtx) Cancel() {
- ctx.cancel()
- }

@@ -769,7 +769,7 @@ func (w *indexIngestLocalWorker) HandleTask(ck IndexRecordChunk, send func(Index
return
}
w.rowCntListener.Written(rs.Added)
- flushed, imported, err := w.backendCtx.Flush(ingest.FlushModeAuto)
+ flushed, imported, err := w.backendCtx.Flush(w.ctx, ingest.FlushModeAuto)
if err != nil {
w.ctx.onError(err)
return
@@ -949,7 +949,7 @@ func (s *indexWriteResultSink) flush() error {
failpoint.Inject("mockFlushError", func(_ failpoint.Value) {
failpoint.Return(errors.New("mock flush error"))
})
- flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport)
+ flushed, imported, err := s.backendCtx.Flush(s.ctx, ingest.FlushModeForceFlushAndImport)
if s.cpMgr != nil {
// Try to advance watermark even if there is an error.
s.cpMgr.AdvanceWatermark(flushed, imported)
4 changes: 2 additions & 2 deletions pkg/ddl/backfilling_read_index.go
@@ -105,8 +105,8 @@ func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subta
return err
}

- opCtx := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID)
- defer opCtx.Cancel()
+ opCtx, cancel := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID)
+ defer cancel()
r.curRowCount.Store(0)

if len(r.cloudStorageURI) > 0 {
2 changes: 1 addition & 1 deletion pkg/ddl/cluster.go
@@ -297,7 +297,7 @@ func checkAndSetFlashbackClusterInfo(ctx context.Context, se sessionctx.Context,
}
}

- jobs, err := GetAllDDLJobs(se)
+ jobs, err := GetAllDDLJobs(ctx, se)
if err != nil {
return errors.Trace(err)
}
10 changes: 6 additions & 4 deletions pkg/ddl/db_change_test.go
@@ -1297,11 +1297,12 @@ func prepareTestControlParallelExecSQL(t *testing.T, store kv.Storage) (*testkit
return
}
var qLen int
+ ctx := context.Background()
for {
sess := testkit.NewTestKit(t, store).Session()
- err := sessiontxn.NewTxn(context.Background(), sess)
+ err := sessiontxn.NewTxn(ctx, sess)
require.NoError(t, err)
- jobs, err := ddl.GetAllDDLJobs(sess)
+ jobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
qLen = len(jobs)
if qLen == 2 {
@@ -1321,11 +1322,12 @@ func prepareTestControlParallelExecSQL(t *testing.T, store kv.Storage) (*testkit
// Make sure the sql1 is put into the DDLJobQueue.
go func() {
var qLen int
+ ctx := context.Background()
for {
sess := testkit.NewTestKit(t, store).Session()
- err := sessiontxn.NewTxn(context.Background(), sess)
+ err := sessiontxn.NewTxn(ctx, sess)
require.NoError(t, err)
- jobs, err := ddl.GetAllDDLJobs(sess)
+ jobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
qLen = len(jobs)
if qLen == 1 {
61 changes: 35 additions & 26 deletions pkg/ddl/ddl.go
@@ -1229,15 +1229,16 @@ func GetDDLInfo(s sessionctx.Context) (*Info, error) {

func get2JobsFromTable(sess *sess.Session) (*model.Job, *model.Job, error) {
var generalJob, reorgJob *model.Job
- jobs, err := getJobsBySQL(sess, JobTable, "not reorg order by job_id limit 1")
+ ctx := context.Background()
+ jobs, err := getJobsBySQL(ctx, sess, JobTable, "not reorg order by job_id limit 1")
if err != nil {
return nil, nil, errors.Trace(err)
}

if len(jobs) != 0 {
generalJob = jobs[0]
}
- jobs, err = getJobsBySQL(sess, JobTable, "reorg order by job_id limit 1")
+ jobs, err = getJobsBySQL(ctx, sess, JobTable, "reorg order by job_id limit 1")
if err != nil {
return nil, nil, errors.Trace(err)
}
@@ -1309,6 +1310,7 @@ func resumePausedJob(_ *sess.Session, job *model.Job,

// processJobs command on the Job according to the process
func processJobs(
+ ctx context.Context,
process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
sessCtx sessionctx.Context,
ids []int64,
@@ -1336,11 +1338,11 @@
idsStr = append(idsStr, strconv.FormatInt(id, 10))
}

- err = ns.Begin(context.Background())
+ err = ns.Begin(ctx)
if err != nil {
return nil, err
}
- jobs, err := getJobsBySQL(ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", ")))
+ jobs, err := getJobsBySQL(ctx, ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", ")))
if err != nil {
ns.Rollback()
return nil, err
@@ -1362,7 +1364,7 @@
continue
}

- err = updateDDLJob2Table(ns, job, false)
+ err = updateDDLJob2Table(ctx, ns, job, false)
if err != nil {
jobErrs[i] = err
continue
@@ -1376,7 +1378,7 @@
})

// There may be some conflict during the update, try it again
- if err = ns.Commit(context.Background()); err != nil {
+ if err = ns.Commit(ctx); err != nil {
continue
}

@@ -1391,43 +1393,50 @@ }
}

// CancelJobs cancels the DDL jobs according to user command.
- func CancelJobs(se sessionctx.Context, ids []int64) (errs []error, err error) {
- return processJobs(cancelRunningJob, se, ids, model.AdminCommandByEndUser)
+ func CancelJobs(ctx context.Context, se sessionctx.Context, ids []int64) (errs []error, err error) {
+ return processJobs(ctx, cancelRunningJob, se, ids, model.AdminCommandByEndUser)
}

// PauseJobs pause all the DDL jobs according to user command.
- func PauseJobs(se sessionctx.Context, ids []int64) ([]error, error) {
- return processJobs(pauseRunningJob, se, ids, model.AdminCommandByEndUser)
+ func PauseJobs(ctx context.Context, se sessionctx.Context, ids []int64) ([]error, error) {
+ return processJobs(ctx, pauseRunningJob, se, ids, model.AdminCommandByEndUser)
}

// ResumeJobs resume all the DDL jobs according to user command.
- func ResumeJobs(se sessionctx.Context, ids []int64) ([]error, error) {
- return processJobs(resumePausedJob, se, ids, model.AdminCommandByEndUser)
+ func ResumeJobs(ctx context.Context, se sessionctx.Context, ids []int64) ([]error, error) {
+ return processJobs(ctx, resumePausedJob, se, ids, model.AdminCommandByEndUser)
}

// CancelJobsBySystem cancels Jobs because of internal reasons.
func CancelJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
- return processJobs(cancelRunningJob, se, ids, model.AdminCommandBySystem)
+ ctx := context.Background()
+ return processJobs(ctx, cancelRunningJob, se, ids, model.AdminCommandBySystem)
}

// PauseJobsBySystem pauses Jobs because of internal reasons.
func PauseJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
- return processJobs(pauseRunningJob, se, ids, model.AdminCommandBySystem)
+ ctx := context.Background()
+ return processJobs(ctx, pauseRunningJob, se, ids, model.AdminCommandBySystem)
}

// ResumeJobsBySystem resumes Jobs that are paused by TiDB itself.
func ResumeJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
- return processJobs(resumePausedJob, se, ids, model.AdminCommandBySystem)
+ ctx := context.Background()
+ return processJobs(ctx, resumePausedJob, se, ids, model.AdminCommandBySystem)
}

// processAllJobs processes all the jobs in the job table, 100 jobs at a time in case of high memory usage.
- func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
- se sessionctx.Context, byWho model.AdminCommandOperator) (map[int64]error, error) {
+ func processAllJobs(
+ ctx context.Context,
+ process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
+ se sessionctx.Context,
+ byWho model.AdminCommandOperator,
+ ) (map[int64]error, error) {
var err error
var jobErrs = make(map[int64]error)

ns := sess.NewSession(se)
- err = ns.Begin(context.Background())
+ err = ns.Begin(ctx)
if err != nil {
return nil, err
}
@@ -1437,7 +1446,7 @@ func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOp
var limit = 100
for {
var jobs []*model.Job
- jobs, err = getJobsBySQL(ns, JobTable,
+ jobs, err = getJobsBySQL(ctx, ns, JobTable,
fmt.Sprintf("job_id >= %s order by job_id asc limit %s",
strconv.FormatInt(jobID, 10),
strconv.FormatInt(int64(limit), 10)))
@@ -1453,7 +1462,7 @@
continue
}

- err = updateDDLJob2Table(ns, job, false)
+ err = updateDDLJob2Table(ctx, ns, job, false)
if err != nil {
jobErrs[job.ID] = err
continue
@@ -1473,7 +1482,7 @@
jobID = jobIDMax + 1
}

- err = ns.Commit(context.Background())
+ err = ns.Commit(ctx)
if err != nil {
return nil, err
}
@@ -1482,23 +1491,23 @@

// PauseAllJobsBySystem pauses all running Jobs because of internal reasons.
func PauseAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) {
- return processAllJobs(pauseRunningJob, se, model.AdminCommandBySystem)
+ return processAllJobs(context.Background(), pauseRunningJob, se, model.AdminCommandBySystem)
}

// ResumeAllJobsBySystem resumes all paused Jobs because of internal reasons.
func ResumeAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) {
- return processAllJobs(resumePausedJob, se, model.AdminCommandBySystem)
+ return processAllJobs(context.Background(), resumePausedJob, se, model.AdminCommandBySystem)
}

// GetAllDDLJobs get all DDL jobs and sorts jobs by job.ID.
- func GetAllDDLJobs(se sessionctx.Context) ([]*model.Job, error) {
- return getJobsBySQL(sess.NewSession(se), JobTable, "1 order by job_id")
+ func GetAllDDLJobs(ctx context.Context, se sessionctx.Context) ([]*model.Job, error) {
+ return getJobsBySQL(ctx, sess.NewSession(se), JobTable, "1 order by job_id")
}

// IterAllDDLJobs will iterates running DDL jobs first, return directly if `finishFn` return true or error,
// then iterates history DDL jobs until the `finishFn` return true or error.
func IterAllDDLJobs(ctx sessionctx.Context, txn kv.Transaction, finishFn func([]*model.Job) (bool, error)) error {
- jobs, err := GetAllDDLJobs(ctx)
+ jobs, err := GetAllDDLJobs(context.Background(), ctx)
if err != nil {
return err
}
8 changes: 5 additions & 3 deletions pkg/ddl/executor_test.go
@@ -49,6 +49,7 @@ func TestGetDDLJobs(t *testing.T) {

cnt := 10
jobs := make([]*model.Job, cnt)
+ ctx := context.Background()
var currJobs2 []*model.Job
for i := 0; i < cnt; i++ {
jobs[i] = &model.Job{
@@ -59,7 +60,7 @@
err := addDDLJobs(sess, txn, jobs[i])
require.NoError(t, err)

- currJobs, err := ddl.GetAllDDLJobs(sess)
+ currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
require.Len(t, currJobs, i+1)

@@ -77,7 +78,7 @@
require.Len(t, currJobs2, i+1)
}

- currJobs, err := ddl.GetAllDDLJobs(sess)
+ currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)

for i, job := range jobs {
@@ -93,6 +94,7 @@

func TestGetDDLJobsIsSort(t *testing.T) {
store := testkit.CreateMockStore(t)
+ ctx := context.Background()

sess := testkit.NewTestKit(t, store).Session()
_, err := sess.Execute(context.Background(), "begin")
@@ -110,7 +112,7 @@
// insert add index jobs to AddIndexJobListKey queue
enQueueDDLJobs(t, sess, txn, model.ActionAddIndex, 5, 10)

- currJobs, err := ddl.GetAllDDLJobs(sess)
+ currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
require.Len(t, currJobs, 15)

3 changes: 2 additions & 1 deletion pkg/ddl/export_test.go
@@ -45,7 +45,8 @@ func FetchChunk4Test(copCtx copr.CopContext, tbl table.PhysicalTable, startKey,
for i := 0; i < 10; i++ {
srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, batchSize)
}
- opCtx := ddl.NewLocalOperatorCtx(context.Background(), 1)
+ opCtx, cancel := ddl.NewLocalOperatorCtx(context.Background(), 1)
+ defer cancel()
src := testutil.NewOperatorTestSource(ddl.TableScanTask{ID: 1, Start: startKey, End: endKey})
scanOp := ddl.NewTableScanOperator(opCtx, sessPool, copCtx, srcChkPool, 1, nil, 0)
sink := testutil.NewOperatorTestSink[ddl.IndexRecordChunk]()