diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel index 6a9118fe09e7e..f952cd7fc4cff 100644 --- a/pkg/ddl/BUILD.bazel +++ b/pkg/ddl/BUILD.bazel @@ -315,7 +315,6 @@ go_test( "//pkg/parser/mysql", "//pkg/parser/terror", "//pkg/parser/types", - "//pkg/privilege", "//pkg/server", "//pkg/session", "//pkg/session/types", diff --git a/pkg/ddl/backfilling_test.go b/pkg/ddl/backfilling_test.go index ac0567e10103a..cb769947e4d72 100644 --- a/pkg/ddl/backfilling_test.go +++ b/pkg/ddl/backfilling_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/privilege" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" @@ -182,26 +181,6 @@ func assertStaticExprContextEqual(t *testing.T, sctx sessionctx.Context, exprCtx require.NoError(t, err) }, }, - { - field: "requestVerificationFn", - check: func(ctx *exprstatic.EvalContext) { - // RequestVerification should allow all privileges - // that is the same with input session context (GetPrivilegeManager returns nil). - require.Nil(t, privilege.GetPrivilegeManager(sctx)) - require.True(t, sctx.GetExprCtx().GetEvalCtx().RequestVerification("any", "any", "any", mysql.CreatePriv)) - require.True(t, ctx.RequestVerification("any", "any", "any", mysql.CreatePriv)) - }, - }, - { - field: "requestDynamicVerificationFn", - check: func(ctx *exprstatic.EvalContext) { - // RequestDynamicVerification should allow all privileges - // that is the same with input session context (GetPrivilegeManager returns nil). - require.Nil(t, privilege.GetPrivilegeManager(sctx)) - require.True(t, sctx.GetExprCtx().GetEvalCtx().RequestDynamicVerification("RESTRICTED_USER_ADMIN", true)) - require.True(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true)) - }, - }, } // check ExprContext except EvalContext diff --git a/pkg/expression/builtin.go b/pkg/expression/builtin.go index 4ddba2d8d9654..bbe1f2ac0fc04 100644 --- a/pkg/expression/builtin.go +++ b/pkg/expression/builtin.go @@ -979,7 +979,7 @@ var funcs = map[string]functionClass{ ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}}, ast.TiDBDecodeBinaryPlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodeBinaryPlan, 1, 1}}, - ast.TiDBDecodeSQLDigests: &tidbDecodeSQLDigestsFunctionClass{baseFunctionClass{ast.TiDBDecodeSQLDigests, 1, 2}}, + ast.TiDBDecodeSQLDigests: &tidbDecodeSQLDigestsFunctionClass{baseFunctionClass: baseFunctionClass{ast.TiDBDecodeSQLDigests, 1, 2}}, ast.TiDBEncodeSQLDigest: &tidbEncodeSQLDigestFunctionClass{baseFunctionClass{ast.TiDBEncodeSQLDigest, 1, 1}}, // TiDB Sequence function. diff --git a/pkg/expression/builtin_info.go b/pkg/expression/builtin_info.go index 2b7a7ef51d9a7..6fe7ad10511df 100644 --- a/pkg/expression/builtin_info.go +++ b/pkg/expression/builtin_info.go @@ -919,11 +919,13 @@ func (c *tidbMVCCInfoFunctionClass) getFunction(ctx BuildContext, args []Express type builtinTiDBMVCCInfoSig struct { baseBuiltinFunc expropt.KVStorePropReader + expropt.PrivilegeCheckerPropReader } // RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. func (b *builtinTiDBMVCCInfoSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { - return b.KVStorePropReader.RequiredOptionalEvalProps() + return b.KVStorePropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinTiDBMVCCInfoSig) Clone() builtinFunc { @@ -934,7 +936,11 @@ func (b *builtinTiDBMVCCInfoSig) Clone() builtinFunc { // evalString evals a builtinTiDBMVCCInfoSig. func (b *builtinTiDBMVCCInfoSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - if !ctx.RequestVerification("", "", "", mysql.SuperPriv) { + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return "", false, err + } + if !privChecker.RequestVerification("", "", "", mysql.SuperPriv) { return "", false, plannererrors.ErrSpecificAccessDenied.FastGenByArgs("SUPER") } s, isNull, err := b.args[0].EvalString(ctx, row) @@ -1006,12 +1012,14 @@ type builtinTiDBEncodeRecordKeySig struct { baseBuiltinFunc expropt.InfoSchemaPropReader expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader } // RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. func (b *builtinTiDBEncodeRecordKeySig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.InfoSchemaPropReader.RequiredOptionalEvalProps() | - b.SessionVarsPropReader.RequiredOptionalEvalProps() + b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinTiDBEncodeRecordKeySig) Clone() builtinFunc { @@ -1029,7 +1037,11 @@ func (b *builtinTiDBEncodeRecordKeySig) evalString(ctx EvalContext, row chunk.Ro if EncodeRecordKeyFromRow == nil { return "", false, errors.New("EncodeRecordKeyFromRow is not initialized") } - recordKey, isNull, err := EncodeRecordKeyFromRow(ctx, is, b.args, row) + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return "", false, err + } + recordKey, isNull, err := EncodeRecordKeyFromRow(ctx, privChecker, is, b.args, row) if isNull || err != nil { if errors.ErrorEqual(err, plannererrors.ErrSpecificAccessDenied) { sv, err2 := b.GetSessionVars(ctx) @@ -1072,12 +1084,14 @@ type builtinTiDBEncodeIndexKeySig struct { baseBuiltinFunc expropt.InfoSchemaPropReader expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader } // RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. func (b *builtinTiDBEncodeIndexKeySig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.InfoSchemaPropReader.RequiredOptionalEvalProps() | - b.SessionVarsPropReader.RequiredOptionalEvalProps() + b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinTiDBEncodeIndexKeySig) Clone() builtinFunc { @@ -1095,7 +1109,11 @@ func (b *builtinTiDBEncodeIndexKeySig) evalString(ctx EvalContext, row chunk.Row if EncodeIndexKeyFromRow == nil { return "", false, errors.New("EncodeIndexKeyFromRow is not initialized") } - idxKey, isNull, err := EncodeIndexKeyFromRow(ctx, is, b.args, row) + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return "", false, err + } + idxKey, isNull, err := EncodeIndexKeyFromRow(ctx, privChecker, is, b.args, row) if isNull || err != nil { if errors.ErrorEqual(err, plannererrors.ErrSpecificAccessDenied) { sv, err2 := b.GetSessionVars(ctx) @@ -1133,10 +1151,10 @@ func (c *tidbDecodeKeyFunctionClass) getFunction(ctx BuildContext, args []Expres var DecodeKeyFromString func(types.Context, infoschema.MetaOnlyInfoSchema, string) string // EncodeRecordKeyFromRow is used to encode record key by expressions. -var EncodeRecordKeyFromRow func(ctx EvalContext, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error) +var EncodeRecordKeyFromRow func(ctx EvalContext, checker expropt.PrivilegeChecker, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error) // EncodeIndexKeyFromRow is used to encode index key by expressions. -var EncodeIndexKeyFromRow func(ctx EvalContext, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error) +var EncodeIndexKeyFromRow func(ctx EvalContext, checker expropt.PrivilegeChecker, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error) type builtinTiDBDecodeKeySig struct { baseBuiltinFunc @@ -1173,6 +1191,7 @@ func (b *builtinTiDBDecodeKeySig) evalString(ctx EvalContext, row chunk.Row) (st type tidbDecodeSQLDigestsFunctionClass struct { baseFunctionClass + expropt.PrivilegeCheckerPropReader } func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1180,7 +1199,12 @@ func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx BuildContext, args [ return nil, err } - if !ctx.GetEvalCtx().RequestVerification("", "", "", mysql.ProcessPriv) { + privChecker, err := c.GetPrivilegeChecker(ctx.GetEvalCtx()) + if err != nil { + return nil, err + } + + if !privChecker.RequestVerification("", "", "", mysql.ProcessPriv) { return nil, errSpecificAccessDenied.GenWithStackByArgs("PROCESS") } @@ -1202,12 +1226,14 @@ type builtinTiDBDecodeSQLDigestsSig struct { baseBuiltinFunc expropt.SessionVarsPropReader expropt.SQLExecutorPropReader + expropt.PrivilegeCheckerPropReader } // RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. func (b *builtinTiDBDecodeSQLDigestsSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.SessionVarsPropReader.RequiredOptionalEvalProps() | - b.SQLExecutorPropReader.RequiredOptionalEvalProps() + b.SQLExecutorPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc { @@ -1217,6 +1243,15 @@ func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc { } func (b *builtinTiDBDecodeSQLDigestsSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return "", true, err + } + + if !privChecker.RequestVerification("", "", "", mysql.ProcessPriv) { + return "", true, errSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + args := b.getArgs() digestsStr, isNull, err := args[0].EvalString(ctx, row) if err != nil { @@ -1443,11 +1478,13 @@ type builtinNextValSig struct { baseBuiltinFunc expropt.SequenceOperatorPropReader expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader } func (b *builtinNextValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() | - b.SessionVarsPropReader.RequiredOptionalEvalProps() + b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinNextValSig) Clone() builtinFunc { @@ -1477,7 +1514,11 @@ func (b *builtinNextValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool } // Do the privilege check. user := vars.User - if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) { + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return 0, false, err + } + if !privChecker.RequestVerification(db, seq, "", mysql.InsertPriv) { return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq) } nextVal, err := sequence.GetSequenceNextVal() @@ -1510,11 +1551,13 @@ type builtinLastValSig struct { baseBuiltinFunc expropt.SequenceOperatorPropReader expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader } func (b *builtinLastValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() | - b.SessionVarsPropReader.RequiredOptionalEvalProps() + b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinLastValSig) Clone() builtinFunc { @@ -1544,7 +1587,11 @@ func (b *builtinLastValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool } // Do the privilege check. user := vars.User - if !ctx.RequestVerification(db, seq, "", mysql.SelectPriv) { + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return 0, false, err + } + if !privChecker.RequestVerification(db, seq, "", mysql.SelectPriv) { return 0, false, errSequenceAccessDenied.GenWithStackByArgs("SELECT", user.AuthUsername, user.AuthHostname, seq) } return vars.SequenceState.GetLastValue(sequence.GetSequenceID()) @@ -1571,11 +1618,13 @@ type builtinSetValSig struct { baseBuiltinFunc expropt.SequenceOperatorPropReader expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader } func (b *builtinSetValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() | - b.SessionVarsPropReader.RequiredOptionalEvalProps() + b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *builtinSetValSig) Clone() builtinFunc { @@ -1605,7 +1654,11 @@ func (b *builtinSetValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, } // Do the privilege check. user := vars.User - if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) { + privChecker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return 0, false, err + } + if !privChecker.RequestVerification(db, seq, "", mysql.InsertPriv) { return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq) } setValue, isNull, err := b.args[1].EvalInt(ctx, row) diff --git a/pkg/expression/exprctx/context.go b/pkg/expression/exprctx/context.go index 1f3bf38c9d3bb..88a188d82c00c 100644 --- a/pkg/expression/exprctx/context.go +++ b/pkg/expression/exprctx/context.go @@ -88,10 +88,6 @@ type EvalContext interface { GetDivPrecisionIncrement() int // GetUserVarsReader returns the `UserVarsReader` to read user vars. GetUserVarsReader() variable.UserVarsReader - // RequestVerification verifies user privilege - RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool - // RequestDynamicVerification verifies user privilege for a DYNAMIC privilege. - RequestDynamicVerification(privName string, grantable bool) bool // GetOptionalPropSet returns the optional properties provided by this context. GetOptionalPropSet() OptionalEvalPropKeySet // GetOptionalPropProvider gets the optional property provider by key @@ -261,6 +257,4 @@ type StaticConvertibleEvalContext interface { AllParamValues() []types.Datum GetWarnHandler() contextutil.WarnHandler - GetRequestVerificationFn() func(db, table, column string, priv mysql.PrivilegeType) bool - GetDynamicPrivCheckFn() func(privName string, grantable bool) bool } diff --git a/pkg/expression/exprctx/optional.go b/pkg/expression/exprctx/optional.go index c7e443ae2d90b..e7d1cc5281c2f 100644 --- a/pkg/expression/exprctx/optional.go +++ b/pkg/expression/exprctx/optional.go @@ -87,6 +87,8 @@ const ( OptPropAdvisoryLock // OptPropDDLOwnerInfo indicates to provide DDL owner information. OptPropDDLOwnerInfo + // OptPropPrivilegeChecker indicates to provide the privilege checker. + OptPropPrivilegeChecker // optPropsCnt is the count of optional properties. DO NOT use it as a property key. optPropsCnt ) @@ -147,6 +149,10 @@ var optionalPropertyDescList = []OptionalEvalPropDesc{ key: OptPropDDLOwnerInfo, str: "OptPropDDLOwnerInfo", }, + { + key: OptPropPrivilegeChecker, + str: "OptPropPrivilegeChecker", + }, } // OptionalEvalPropKeySet is a bit map for optional evaluation properties in EvalContext diff --git a/pkg/expression/exprctx/optional_test.go b/pkg/expression/exprctx/optional_test.go index 8454a44997865..7d8b124020ff1 100644 --- a/pkg/expression/exprctx/optional_test.go +++ b/pkg/expression/exprctx/optional_test.go @@ -58,7 +58,8 @@ func TestOptionalPropKeySet(t *testing.T) { Add(OptPropKVStore). Add(OptPropSQLExecutor). Add(OptPropSequenceOperator). - Add(OptPropAdvisoryLock) + Add(OptPropAdvisoryLock). + Add(OptPropPrivilegeChecker) require.True(t, keySet4.IsFull()) require.False(t, keySet4.IsEmpty()) } diff --git a/pkg/expression/expropt/BUILD.bazel b/pkg/expression/expropt/BUILD.bazel index c301cf1ea989b..0be10503c5585 100644 --- a/pkg/expression/expropt/BUILD.bazel +++ b/pkg/expression/expropt/BUILD.bazel @@ -9,6 +9,7 @@ go_library( "infoschema.go", "kvstore.go", "optional.go", + "priv.go", "sequence.go", "sessionvars.go", "sqlexec.go", @@ -20,6 +21,7 @@ go_library( "//pkg/infoschema/context", "//pkg/kv", "//pkg/parser/auth", + "//pkg/parser/mysql", "//pkg/planner/core/resolve", "//pkg/sessionctx/variable", "//pkg/util/chunk", diff --git a/pkg/expression/expropt/optional.go b/pkg/expression/expropt/optional.go index 1a9c8b793fc9d..44cae044f377e 100644 --- a/pkg/expression/expropt/optional.go +++ b/pkg/expression/expropt/optional.go @@ -78,6 +78,9 @@ func (o *OptionalEvalPropProviders) Add(val exprctx.OptionalEvalPropProvider) { case exprctx.OptPropSequenceOperator: _, ok := val.(SequenceOperatorProvider) intest.Assert(ok) + case exprctx.OptPropPrivilegeChecker: + _, ok := val.(PrivilegeCheckerProvider) + intest.Assert(ok) default: intest.Assert(false) } diff --git a/pkg/expression/expropt/optional_test.go b/pkg/expression/expropt/optional_test.go index 6c59a33def353..127411cffd10a 100644 --- a/pkg/expression/expropt/optional_test.go +++ b/pkg/expression/expropt/optional_test.go @@ -267,6 +267,21 @@ func TestOptionalEvalPropProviders(t *testing.T) { isOwner = false require.False(t, assertReaderFuncValue(t, ctx, r.IsDDLOwner)) } + case exprctx.OptPropPrivilegeChecker: + type mockPrivCheckerTp struct { + PrivilegeChecker + } + mockPrivChecker := &mockPrivCheckerTp{} + p = PrivilegeCheckerProvider(func() PrivilegeChecker { return mockPrivChecker }) + r := PrivilegeCheckerPropReader{} + reader = r + verifyNoProvider = func(ctx exprctx.EvalContext) { + assertReaderFuncReturnErr(t, ctx, r.GetPrivilegeChecker) + } + verifyProvider = func(ctx exprctx.EvalContext, val exprctx.OptionalEvalPropProvider) { + require.Same(t, mockPrivChecker, assertReaderFuncValue(t, ctx, r.GetPrivilegeChecker)) + require.Same(t, mockPrivChecker, val.(PrivilegeCheckerProvider)()) + } default: require.Fail(t, "unexpected optional property key") } diff --git a/pkg/expression/expropt/priv.go b/pkg/expression/expropt/priv.go new file mode 100644 index 0000000000000..2afed7089ca15 --- /dev/null +++ b/pkg/expression/expropt/priv.go @@ -0,0 +1,55 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expropt + +import ( + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +// PrivilegeChecker provides privilege check for expressions. +type PrivilegeChecker interface { + // RequestVerification verifies user privilege + RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool + // RequestDynamicVerification verifies user privilege for a DYNAMIC privilege. + RequestDynamicVerification(privName string, grantable bool) bool +} + +var _ exprctx.OptionalEvalPropProvider = PrivilegeCheckerProvider(nil) + +// PrivilegeCheckerProvider is used to provide PrivilegeChecker. +type PrivilegeCheckerProvider func() PrivilegeChecker + +// Desc returns the description for the property key. +func (PrivilegeCheckerProvider) Desc() *exprctx.OptionalEvalPropDesc { + return exprctx.OptPropPrivilegeChecker.Desc() +} + +// PrivilegeCheckerPropReader is used by expression to get PrivilegeChecker. +type PrivilegeCheckerPropReader struct{} + +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (PrivilegeCheckerPropReader) RequiredOptionalEvalProps() exprctx.OptionalEvalPropKeySet { + return exprctx.OptPropPrivilegeChecker.AsPropKeySet() +} + +// GetPrivilegeChecker returns a PrivilegeChecker. +func (PrivilegeCheckerPropReader) GetPrivilegeChecker(ctx exprctx.EvalContext) (PrivilegeChecker, error) { + p, err := getPropProvider[PrivilegeCheckerProvider](ctx, exprctx.OptPropPrivilegeChecker) + if err != nil { + return nil, err + } + return p(), nil +} diff --git a/pkg/expression/exprstatic/evalctx.go b/pkg/expression/exprstatic/evalctx.go index 5086143497a1c..a8c3202cd357a 100644 --- a/pkg/expression/exprstatic/evalctx.go +++ b/pkg/expression/exprstatic/evalctx.go @@ -70,21 +70,19 @@ func (t *timeOnce) getTime(loc *time.Location) (tm time.Time, err error) { // evalCtxState is the internal state for `EvalContext`. // We make it as a standalone private struct here to make sure `EvalCtxOption` can only be called in constructor. type evalCtxState struct { - warnHandler contextutil.WarnHandler - sqlMode mysql.SQLMode - typeCtx types.Context - errCtx errctx.Context - currentDB string - currentTime *timeOnce - maxAllowedPacket uint64 - enableRedactLog string - defaultWeekFormatMode string - divPrecisionIncrement int - requestVerificationFn func(db, table, column string, priv mysql.PrivilegeType) bool - requestDynamicVerificationFn func(privName string, grantable bool) bool - paramList []types.Datum - userVars variable.UserVarsReader - props expropt.OptionalEvalPropProviders + warnHandler contextutil.WarnHandler + sqlMode mysql.SQLMode + typeCtx types.Context + errCtx errctx.Context + currentDB string + currentTime *timeOnce + maxAllowedPacket uint64 + enableRedactLog string + defaultWeekFormatMode string + divPrecisionIncrement int + paramList []types.Datum + userVars variable.UserVarsReader + props expropt.OptionalEvalPropProviders } // EvalCtxOption is the option to set `EvalContext`. @@ -172,20 +170,6 @@ func WithDivPrecisionIncrement(inc int) EvalCtxOption { } } -// WithPrivCheck sets the requestVerificationFn -func WithPrivCheck(fn func(db, table, column string, priv mysql.PrivilegeType) bool) EvalCtxOption { - return func(s *evalCtxState) { - s.requestVerificationFn = fn - } -} - -// WithDynamicPrivCheck sets the requestDynamicVerificationFn -func WithDynamicPrivCheck(fn func(privName string, grantable bool) bool) EvalCtxOption { - return func(s *evalCtxState) { - s.requestDynamicVerificationFn = fn - } -} - // WithOptionalProperty sets the optional property providers func WithOptionalProperty(providers ...exprctx.OptionalEvalPropProvider) EvalCtxOption { return func(s *evalCtxState) { @@ -366,22 +350,6 @@ func (ctx *EvalContext) GetUserVarsReader() variable.UserVarsReader { return ctx.userVars } -// RequestVerification verifies user privilege -func (ctx *EvalContext) RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool { - if fn := ctx.requestVerificationFn; fn != nil { - return fn(db, table, column, priv) - } - return true -} - -// RequestDynamicVerification verifies user privilege for a DYNAMIC privilege. -func (ctx *EvalContext) RequestDynamicVerification(privName string, grantable bool) bool { - if fn := ctx.requestDynamicVerificationFn; fn != nil { - return fn(privName, grantable) - } - return true -} - // GetOptionalPropSet gets the optional property set from context func (ctx *EvalContext) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet { return ctx.props.PropKeySet() @@ -429,16 +397,6 @@ func (ctx *EvalContext) AllParamValues() []types.Datum { return ctx.paramList } -// GetDynamicPrivCheckFn implements context.StaticConvertibleEvalContext. -func (ctx *EvalContext) GetDynamicPrivCheckFn() func(privName string, grantable bool) bool { - return ctx.requestDynamicVerificationFn -} - -// GetRequestVerificationFn implements context.StaticConvertibleEvalContext. -func (ctx *EvalContext) GetRequestVerificationFn() func(db string, table string, column string, priv mysql.PrivilegeType) bool { - return ctx.requestVerificationFn -} - // GetWarnHandler implements context.StaticConvertibleEvalContext. func (ctx *EvalContext) GetWarnHandler() contextutil.WarnHandler { return ctx.warnHandler @@ -561,8 +519,6 @@ func MakeEvalContextStatic(ctx exprctx.StaticConvertibleEvalContext) *EvalContex WithMaxAllowedPacket(ctx.GetMaxAllowedPacket()), WithDefaultWeekFormatMode(ctx.GetDefaultWeekFormatMode()), WithDivPrecisionIncrement(ctx.GetDivPrecisionIncrement()), - WithPrivCheck(ctx.GetRequestVerificationFn()), - WithDynamicPrivCheck(ctx.GetDynamicPrivCheckFn()), WithParamList(params), WithUserVarsReader(ctx.GetUserVarsReader().Clone()), WithOptionalProperty(props...), diff --git a/pkg/expression/exprstatic/evalctx_test.go b/pkg/expression/exprstatic/evalctx_test.go index 10b5fc02a69c7..4ac7cb0822e07 100644 --- a/pkg/expression/exprstatic/evalctx_test.go +++ b/pkg/expression/exprstatic/evalctx_test.go @@ -62,10 +62,6 @@ func checkDefaultStaticEvalCtx(t *testing.T, ctx *EvalContext) { require.Equal(t, variable.DefDivPrecisionIncrement, ctx.GetDivPrecisionIncrement()) require.Empty(t, ctx.AllParamValues()) require.Equal(t, variable.NewUserVars(), ctx.GetUserVarsReader()) - require.Nil(t, ctx.requestVerificationFn) - require.Nil(t, ctx.requestDynamicVerificationFn) - require.True(t, ctx.RequestVerification("test", "t1", "", mysql.CreatePriv)) - require.True(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true)) require.True(t, ctx.GetOptionalPropSet().IsEmpty()) p, ok := ctx.GetOptionalPropProvider(exprctx.OptPropAdvisoryLock) require.Nil(t, p) @@ -82,13 +78,11 @@ func checkDefaultStaticEvalCtx(t *testing.T, ctx *EvalContext) { } type evalCtxOptionsTestState struct { - now time.Time - loc *time.Location - warnHandler *contextutil.StaticWarnHandler - userVars *variable.UserVars - ddlOwner bool - privCheckArgs []any - privRet bool + now time.Time + loc *time.Location + warnHandler *contextutil.StaticWarnHandler + userVars *variable.UserVars + ddlOwner bool } func getEvalCtxOptionsForTest(t *testing.T) ([]EvalCtxOption, *evalCtxOptionsTestState) { @@ -127,16 +121,6 @@ func getEvalCtxOptionsForTest(t *testing.T) ([]EvalCtxOption, *evalCtxOptionsTes WithDefaultWeekFormatMode("3"), WithDivPrecisionIncrement(5), WithUserVarsReader(s.userVars), - WithPrivCheck(func(db, table, column string, priv mysql.PrivilegeType) bool { - require.Nil(t, s.privCheckArgs) - s.privCheckArgs = []any{db, table, column, priv} - return s.privRet - }), - WithDynamicPrivCheck(func(privName string, grantable bool) bool { - require.Nil(t, s.privCheckArgs) - s.privCheckArgs = []any{privName, grantable} - return s.privRet - }), WithOptionalProperty(provider1, provider2), }, s } @@ -163,19 +147,6 @@ func checkOptionsStaticEvalCtx(t *testing.T, ctx *EvalContext, s *evalCtxOptions require.Equal(t, 5, ctx.GetDivPrecisionIncrement()) require.Same(t, s.userVars, ctx.GetUserVarsReader()) - s.privCheckArgs, s.privRet = nil, false - require.False(t, ctx.RequestVerification("db", "table", "column", mysql.CreatePriv)) - require.Equal(t, []any{"db", "table", "column", mysql.CreatePriv}, s.privCheckArgs) - s.privCheckArgs, s.privRet = nil, true - require.True(t, ctx.RequestVerification("db2", "table2", "column2", mysql.UpdatePriv)) - require.Equal(t, []any{"db2", "table2", "column2", mysql.UpdatePriv}, s.privCheckArgs) - s.privCheckArgs, s.privRet = nil, false - require.False(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true)) - require.Equal(t, []any{"RESTRICTED_USER_ADMIN", true}, s.privCheckArgs) - s.privCheckArgs, s.privRet = nil, true - require.True(t, ctx.RequestDynamicVerification("RESTRICTED_TABLES_ADMIN", false)) - require.Equal(t, []any{"RESTRICTED_TABLES_ADMIN", false}, s.privCheckArgs) - var optSet exprctx.OptionalEvalPropKeySet optSet = optSet.Add(exprctx.OptPropCurrentUser).Add(exprctx.OptPropDDLOwnerInfo) require.Equal(t, optSet, ctx.GetOptionalPropSet()) @@ -493,12 +464,6 @@ func TestMakeEvalContextStatic(t *testing.T) { WithMaxAllowedPacket(12345), WithDefaultWeekFormatMode("3"), WithDivPrecisionIncrement(5), - WithPrivCheck(func(db, table, column string, priv mysql.PrivilegeType) bool { - return true - }), - WithDynamicPrivCheck(func(privName string, grantable bool) bool { - return true - }), WithParamList(paramList), WithUserVarsReader(userVars), WithOptionalProperty(provider), @@ -517,10 +482,7 @@ func TestMakeEvalContextStatic(t *testing.T) { } deeptest.AssertRecursivelyNotEqual(t, obj, NewEvalContext(), deeptest.WithIgnorePath(ignorePath), - deeptest.WithPointerComparePath([]string{ - "$.evalCtxState.requestVerificationFn", - "$.evalCtxState.requestDynamicVerificationFn", - })) + ) staticObj := MakeEvalContextStatic(obj) @@ -529,8 +491,6 @@ func TestMakeEvalContextStatic(t *testing.T) { deeptest.WithPointerComparePath([]string{ "$.evalCtxState.warnHandler", "$.evalCtxState.paramList*.b", - "$.evalCtxState.requestVerificationFn", - "$.evalCtxState.requestDynamicVerificationFn", }), ) @@ -634,8 +594,6 @@ func TestEvalCtxLoadSystemVars(t *testing.T) { "$.typeCtx.warnHandler", "$.errCtx", "$.currentDB", - "$.requestVerificationFn", - "$.requestDynamicVerificationFn", "$.paramList", "$.userVars", "$.props", diff --git a/pkg/expression/extension.go b/pkg/expression/extension.go index 4fc0f5ea0f074..ff32ab2219ed0 100644 --- a/pkg/expression/extension.go +++ b/pkg/expression/extension.go @@ -65,6 +65,7 @@ func removeExtensionFunc(name string) { type extensionFuncClass struct { baseFunctionClass + expropt.PrivilegeCheckerPropReader funcDef extension.FunctionDef flen int } @@ -96,7 +97,12 @@ func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, err } func (c *extensionFuncClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := checkPrivileges(ctx.GetEvalCtx(), &c.funcDef); err != nil { + checker, err := c.GetPrivilegeChecker(ctx.GetEvalCtx()) + if err != nil { + return nil, err + } + + if err := checkPrivileges(checker, &c.funcDef); err != nil { return nil, err } @@ -117,7 +123,7 @@ func (c *extensionFuncClass) getFunction(ctx BuildContext, args []Expression) (b return sig, nil } -func checkPrivileges(ctx EvalContext, fnDef *extension.FunctionDef) error { +func checkPrivileges(privChecker expropt.PrivilegeChecker, fnDef *extension.FunctionDef) error { fn := fnDef.RequireDynamicPrivileges if fn == nil { return nil @@ -130,7 +136,7 @@ func checkPrivileges(ctx EvalContext, fnDef *extension.FunctionDef) error { } for _, priv := range privs { - if !ctx.RequestDynamicVerification(priv, false) { + if !privChecker.RequestDynamicVerification(priv, false) { msg := priv if !semEnabled { msg = "SUPER or " + msg @@ -147,6 +153,7 @@ var _ extension.FunctionContext = extensionFnContext{} type extensionFuncSig struct { baseBuiltinFunc expropt.SessionVarsPropReader + expropt.PrivilegeCheckerPropReader extension.FunctionDef } @@ -159,11 +166,17 @@ func (b *extensionFuncSig) Clone() builtinFunc { } func (b *extensionFuncSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { - return b.SessionVarsPropReader.RequiredOptionalEvalProps() + return b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps() } func (b *extensionFuncSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - if err := checkPrivileges(ctx, &b.FunctionDef); err != nil { + checker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return "", true, err + } + + if err := checkPrivileges(checker, &b.FunctionDef); err != nil { return "", true, err } @@ -180,7 +193,12 @@ func (b *extensionFuncSig) evalString(ctx EvalContext, row chunk.Row) (string, b } func (b *extensionFuncSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - if err := checkPrivileges(ctx, &b.FunctionDef); err != nil { + checker, err := b.GetPrivilegeChecker(ctx) + if err != nil { + return 0, true, err + } + + if err := checkPrivileges(checker, &b.FunctionDef); err != nil { return 0, true, err } diff --git a/pkg/expression/sessionexpr/sessionctx.go b/pkg/expression/sessionexpr/sessionctx.go index af25f9b08e594..1dae8c5a6a7a6 100644 --- a/pkg/expression/sessionexpr/sessionctx.go +++ b/pkg/expression/sessionexpr/sessionctx.go @@ -160,6 +160,7 @@ func NewEvalContext(sctx sessionctx.Context) *EvalContext { ctx.setOptionalProp(sequenceOperatorProp(sctx)) ctx.setOptionalProp(expropt.NewAdvisoryLockPropProvider(sctx)) ctx.setOptionalProp(expropt.DDLOwnerInfoProvider(sctx.IsDDLOwner)) + ctx.setOptionalProp(expropt.PrivilegeCheckerProvider(func() expropt.PrivilegeChecker { return ctx })) // When EvalContext is created from a session, it should contain all the optional properties. intest.Assert(ctx.props.PropKeySet().IsFull()) return ctx @@ -418,36 +419,6 @@ func (ctx *EvalContext) AllParamValues() []types.Datum { return ctx.sctx.GetSessionVars().PlanCacheParams.AllParamValues() } -// GetDynamicPrivCheckFn implements context.StaticConvertibleEvalContext. -func (ctx *EvalContext) GetDynamicPrivCheckFn() func(privName string, grantable bool) bool { - checker := privilege.GetPrivilegeManager(ctx.sctx) - activeRoles := make([]*auth.RoleIdentity, len(ctx.sctx.GetSessionVars().ActiveRoles)) - copy(activeRoles, ctx.sctx.GetSessionVars().ActiveRoles) - - return func(privName string, grantable bool) bool { - if checker == nil { - return true - } - - return checker.RequestDynamicVerification(activeRoles, privName, grantable) - } -} - -// GetRequestVerificationFn implements context.StaticConvertibleEvalContext. -func (ctx *EvalContext) GetRequestVerificationFn() func(db string, table string, column string, priv mysql.PrivilegeType) bool { - checker := privilege.GetPrivilegeManager(ctx.sctx) - activeRoles := make([]*auth.RoleIdentity, len(ctx.sctx.GetSessionVars().ActiveRoles)) - copy(activeRoles, ctx.sctx.GetSessionVars().ActiveRoles) - - return func(db string, table string, column string, priv mysql.PrivilegeType) bool { - if checker == nil { - return true - } - - return checker.RequestVerification(activeRoles, db, table, column, priv) - } -} - // GetWarnHandler implements context.StaticConvertibleEvalContext. func (ctx *EvalContext) GetWarnHandler() contextutil.WarnHandler { return ctx.sctx.GetSessionVars().StmtCtx.WarnHandler diff --git a/pkg/expression/sessionexpr/sessionctx_test.go b/pkg/expression/sessionexpr/sessionctx_test.go index a316115f78e69..254bd5296633b 100644 --- a/pkg/expression/sessionexpr/sessionctx_test.go +++ b/pkg/expression/sessionexpr/sessionctx_test.go @@ -262,6 +262,12 @@ func TestSessionEvalContextOptProps(t *testing.T) { require.False(t, ddlInfoProvider()) ctx.SetIsDDLOwner(true) require.True(t, ddlInfoProvider()) + + // test for OptPropPrivilegeChecker + privCheckerProvider := getProvider[expropt.PrivilegeCheckerProvider](t, impl, exprctx.OptPropPrivilegeChecker) + privChecker := privCheckerProvider() + require.NotNil(t, privChecker) + require.Same(t, impl, privChecker) } func TestSessionBuildContext(t *testing.T) { diff --git a/pkg/planner/core/expression_codec_fn.go b/pkg/planner/core/expression_codec_fn.go index 461aca426e321..02d7e0fa91495 100644 --- a/pkg/planner/core/expression_codec_fn.go +++ b/pkg/planner/core/expression_codec_fn.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/expropt" "github.com/pingcap/tidb/pkg/infoschema" infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" @@ -50,6 +51,7 @@ type tidbCodecFuncHelper struct{} func (h tidbCodecFuncHelper) encodeHandleFromRow( ctx expression.EvalContext, + checker expropt.PrivilegeChecker, isVer infoschemactx.MetaOnlyInfoSchema, args []expression.Expression, row chunk.Row, @@ -63,7 +65,7 @@ func (h tidbCodecFuncHelper) encodeHandleFromRow( return nil, isNull, err } is := isVer.(infoschema.InfoSchema) - tbl, _, err := h.findCommonOrPartitionedTable(ctx, is, dbName, tblName) + tbl, _, err := h.findCommonOrPartitionedTable(checker, is, dbName, tblName) if err != nil { return nil, false, err } @@ -76,7 +78,7 @@ func (h tidbCodecFuncHelper) encodeHandleFromRow( } func (h tidbCodecFuncHelper) findCommonOrPartitionedTable( - ctx expression.EvalContext, + checker expropt.PrivilegeChecker, is infoschema.InfoSchema, dbName string, tblName string, @@ -86,7 +88,7 @@ func (h tidbCodecFuncHelper) findCommonOrPartitionedTable( if err != nil { return nil, 0, err } - if !ctx.RequestVerification(dbName, tblName, "", mysql.AllPrivMask) { + if !checker.RequestVerification(dbName, tblName, "", mysql.AllPrivMask) { // The arguments will be filled by caller. return nil, 0, plannererrors.ErrSpecificAccessDenied } @@ -165,6 +167,7 @@ func (tidbCodecFuncHelper) buildHandle( func (h tidbCodecFuncHelper) encodeIndexKeyFromRow( ctx expression.EvalContext, + checker expropt.PrivilegeChecker, isVer infoschemactx.MetaOnlyInfoSchema, args []expression.Expression, row chunk.Row, @@ -182,7 +185,7 @@ func (h tidbCodecFuncHelper) encodeIndexKeyFromRow( return nil, isNull, err } is := isVer.(infoschema.InfoSchema) - tbl, physicalID, err := h.findCommonOrPartitionedTable(ctx, is, dbName, tblName) + tbl, physicalID, err := h.findCommonOrPartitionedTable(checker, is, dbName, tblName) if err != nil { return nil, false, err }