Skip to content

Commit

Permalink
expression: provide OptPropPrivilegeChecker for EvalContext (#56302)
Browse files Browse the repository at this point in the history
close #56301
  • Loading branch information
lcwangchao authored Oct 11, 2024
1 parent a0bcee3 commit a56674c
Show file tree
Hide file tree
Showing 17 changed files with 210 additions and 191 deletions.
1 change: 0 additions & 1 deletion pkg/ddl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ go_test(
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/privilege",
"//pkg/server",
"//pkg/session",
"//pkg/session/types",
Expand Down
21 changes: 0 additions & 21 deletions pkg/ddl/backfilling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/privilege"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table"
Expand Down Expand Up @@ -182,26 +181,6 @@ func assertStaticExprContextEqual(t *testing.T, sctx sessionctx.Context, exprCtx
require.NoError(t, err)
},
},
{
field: "requestVerificationFn",
check: func(ctx *exprstatic.EvalContext) {
// RequestVerification should allow all privileges
// that is the same with input session context (GetPrivilegeManager returns nil).
require.Nil(t, privilege.GetPrivilegeManager(sctx))
require.True(t, sctx.GetExprCtx().GetEvalCtx().RequestVerification("any", "any", "any", mysql.CreatePriv))
require.True(t, ctx.RequestVerification("any", "any", "any", mysql.CreatePriv))
},
},
{
field: "requestDynamicVerificationFn",
check: func(ctx *exprstatic.EvalContext) {
// RequestDynamicVerification should allow all privileges
// that is the same with input session context (GetPrivilegeManager returns nil).
require.Nil(t, privilege.GetPrivilegeManager(sctx))
require.True(t, sctx.GetExprCtx().GetEvalCtx().RequestDynamicVerification("RESTRICTED_USER_ADMIN", true))
require.True(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true))
},
},
}

// check ExprContext except EvalContext
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ var funcs = map[string]functionClass{
ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}},
ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}},
ast.TiDBDecodeBinaryPlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodeBinaryPlan, 1, 1}},
ast.TiDBDecodeSQLDigests: &tidbDecodeSQLDigestsFunctionClass{baseFunctionClass{ast.TiDBDecodeSQLDigests, 1, 2}},
ast.TiDBDecodeSQLDigests: &tidbDecodeSQLDigestsFunctionClass{baseFunctionClass: baseFunctionClass{ast.TiDBDecodeSQLDigests, 1, 2}},
ast.TiDBEncodeSQLDigest: &tidbEncodeSQLDigestFunctionClass{baseFunctionClass{ast.TiDBEncodeSQLDigest, 1, 1}},

// TiDB Sequence function.
Expand Down
85 changes: 69 additions & 16 deletions pkg/expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,11 +919,13 @@ func (c *tidbMVCCInfoFunctionClass) getFunction(ctx BuildContext, args []Express
type builtinTiDBMVCCInfoSig struct {
baseBuiltinFunc
expropt.KVStorePropReader
expropt.PrivilegeCheckerPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBMVCCInfoSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.KVStorePropReader.RequiredOptionalEvalProps()
return b.KVStorePropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBMVCCInfoSig) Clone() builtinFunc {
Expand All @@ -934,7 +936,11 @@ func (b *builtinTiDBMVCCInfoSig) Clone() builtinFunc {

// evalString evals a builtinTiDBMVCCInfoSig.
func (b *builtinTiDBMVCCInfoSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
if !ctx.RequestVerification("", "", "", mysql.SuperPriv) {
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return "", false, err
}
if !privChecker.RequestVerification("", "", "", mysql.SuperPriv) {
return "", false, plannererrors.ErrSpecificAccessDenied.FastGenByArgs("SUPER")
}
s, isNull, err := b.args[0].EvalString(ctx, row)
Expand Down Expand Up @@ -1006,12 +1012,14 @@ type builtinTiDBEncodeRecordKeySig struct {
baseBuiltinFunc
expropt.InfoSchemaPropReader
expropt.SessionVarsPropReader
expropt.PrivilegeCheckerPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBEncodeRecordKeySig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.InfoSchemaPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBEncodeRecordKeySig) Clone() builtinFunc {
Expand All @@ -1029,7 +1037,11 @@ func (b *builtinTiDBEncodeRecordKeySig) evalString(ctx EvalContext, row chunk.Ro
if EncodeRecordKeyFromRow == nil {
return "", false, errors.New("EncodeRecordKeyFromRow is not initialized")
}
recordKey, isNull, err := EncodeRecordKeyFromRow(ctx, is, b.args, row)
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return "", false, err
}
recordKey, isNull, err := EncodeRecordKeyFromRow(ctx, privChecker, is, b.args, row)
if isNull || err != nil {
if errors.ErrorEqual(err, plannererrors.ErrSpecificAccessDenied) {
sv, err2 := b.GetSessionVars(ctx)
Expand Down Expand Up @@ -1072,12 +1084,14 @@ type builtinTiDBEncodeIndexKeySig struct {
baseBuiltinFunc
expropt.InfoSchemaPropReader
expropt.SessionVarsPropReader
expropt.PrivilegeCheckerPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBEncodeIndexKeySig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.InfoSchemaPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBEncodeIndexKeySig) Clone() builtinFunc {
Expand All @@ -1095,7 +1109,11 @@ func (b *builtinTiDBEncodeIndexKeySig) evalString(ctx EvalContext, row chunk.Row
if EncodeIndexKeyFromRow == nil {
return "", false, errors.New("EncodeIndexKeyFromRow is not initialized")
}
idxKey, isNull, err := EncodeIndexKeyFromRow(ctx, is, b.args, row)
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return "", false, err
}
idxKey, isNull, err := EncodeIndexKeyFromRow(ctx, privChecker, is, b.args, row)
if isNull || err != nil {
if errors.ErrorEqual(err, plannererrors.ErrSpecificAccessDenied) {
sv, err2 := b.GetSessionVars(ctx)
Expand Down Expand Up @@ -1133,10 +1151,10 @@ func (c *tidbDecodeKeyFunctionClass) getFunction(ctx BuildContext, args []Expres
var DecodeKeyFromString func(types.Context, infoschema.MetaOnlyInfoSchema, string) string

// EncodeRecordKeyFromRow is used to encode record key by expressions.
var EncodeRecordKeyFromRow func(ctx EvalContext, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error)
var EncodeRecordKeyFromRow func(ctx EvalContext, checker expropt.PrivilegeChecker, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error)

// EncodeIndexKeyFromRow is used to encode index key by expressions.
var EncodeIndexKeyFromRow func(ctx EvalContext, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error)
var EncodeIndexKeyFromRow func(ctx EvalContext, checker expropt.PrivilegeChecker, is infoschema.MetaOnlyInfoSchema, args []Expression, row chunk.Row) ([]byte, bool, error)

type builtinTiDBDecodeKeySig struct {
baseBuiltinFunc
Expand Down Expand Up @@ -1173,14 +1191,20 @@ func (b *builtinTiDBDecodeKeySig) evalString(ctx EvalContext, row chunk.Row) (st

type tidbDecodeSQLDigestsFunctionClass struct {
baseFunctionClass
expropt.PrivilegeCheckerPropReader
}

func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}

if !ctx.GetEvalCtx().RequestVerification("", "", "", mysql.ProcessPriv) {
privChecker, err := c.GetPrivilegeChecker(ctx.GetEvalCtx())
if err != nil {
return nil, err
}

if !privChecker.RequestVerification("", "", "", mysql.ProcessPriv) {
return nil, errSpecificAccessDenied.GenWithStackByArgs("PROCESS")
}

Expand All @@ -1202,12 +1226,14 @@ type builtinTiDBDecodeSQLDigestsSig struct {
baseBuiltinFunc
expropt.SessionVarsPropReader
expropt.SQLExecutorPropReader
expropt.PrivilegeCheckerPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBDecodeSQLDigestsSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.SQLExecutorPropReader.RequiredOptionalEvalProps()
b.SQLExecutorPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc {
Expand All @@ -1217,6 +1243,15 @@ func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc {
}

func (b *builtinTiDBDecodeSQLDigestsSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return "", true, err
}

if !privChecker.RequestVerification("", "", "", mysql.ProcessPriv) {
return "", true, errSpecificAccessDenied.GenWithStackByArgs("PROCESS")
}

args := b.getArgs()
digestsStr, isNull, err := args[0].EvalString(ctx, row)
if err != nil {
Expand Down Expand Up @@ -1443,11 +1478,13 @@ type builtinNextValSig struct {
baseBuiltinFunc
expropt.SequenceOperatorPropReader
expropt.SessionVarsPropReader
expropt.PrivilegeCheckerPropReader
}

func (b *builtinNextValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinNextValSig) Clone() builtinFunc {
Expand Down Expand Up @@ -1477,7 +1514,11 @@ func (b *builtinNextValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool
}
// Do the privilege check.
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) {
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return 0, false, err
}
if !privChecker.RequestVerification(db, seq, "", mysql.InsertPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq)
}
nextVal, err := sequence.GetSequenceNextVal()
Expand Down Expand Up @@ -1510,11 +1551,13 @@ type builtinLastValSig struct {
baseBuiltinFunc
expropt.SequenceOperatorPropReader
expropt.SessionVarsPropReader
expropt.PrivilegeCheckerPropReader
}

func (b *builtinLastValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinLastValSig) Clone() builtinFunc {
Expand Down Expand Up @@ -1544,7 +1587,11 @@ func (b *builtinLastValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool
}
// Do the privilege check.
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.SelectPriv) {
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return 0, false, err
}
if !privChecker.RequestVerification(db, seq, "", mysql.SelectPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("SELECT", user.AuthUsername, user.AuthHostname, seq)
}
return vars.SequenceState.GetLastValue(sequence.GetSequenceID())
Expand All @@ -1571,11 +1618,13 @@ type builtinSetValSig struct {
baseBuiltinFunc
expropt.SequenceOperatorPropReader
expropt.SessionVarsPropReader
expropt.PrivilegeCheckerPropReader
}

func (b *builtinSetValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.PrivilegeCheckerPropReader.RequiredOptionalEvalProps()
}

func (b *builtinSetValSig) Clone() builtinFunc {
Expand Down Expand Up @@ -1605,7 +1654,11 @@ func (b *builtinSetValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool,
}
// Do the privilege check.
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) {
privChecker, err := b.GetPrivilegeChecker(ctx)
if err != nil {
return 0, false, err
}
if !privChecker.RequestVerification(db, seq, "", mysql.InsertPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq)
}
setValue, isNull, err := b.args[1].EvalInt(ctx, row)
Expand Down
6 changes: 0 additions & 6 deletions pkg/expression/exprctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ type EvalContext interface {
GetDivPrecisionIncrement() int
// GetUserVarsReader returns the `UserVarsReader` to read user vars.
GetUserVarsReader() variable.UserVarsReader
// RequestVerification verifies user privilege
RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool
// RequestDynamicVerification verifies user privilege for a DYNAMIC privilege.
RequestDynamicVerification(privName string, grantable bool) bool
// GetOptionalPropSet returns the optional properties provided by this context.
GetOptionalPropSet() OptionalEvalPropKeySet
// GetOptionalPropProvider gets the optional property provider by key
Expand Down Expand Up @@ -261,6 +257,4 @@ type StaticConvertibleEvalContext interface {

AllParamValues() []types.Datum
GetWarnHandler() contextutil.WarnHandler
GetRequestVerificationFn() func(db, table, column string, priv mysql.PrivilegeType) bool
GetDynamicPrivCheckFn() func(privName string, grantable bool) bool
}
6 changes: 6 additions & 0 deletions pkg/expression/exprctx/optional.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ const (
OptPropAdvisoryLock
// OptPropDDLOwnerInfo indicates to provide DDL owner information.
OptPropDDLOwnerInfo
// OptPropPrivilegeChecker indicates to provide the privilege checker.
OptPropPrivilegeChecker
// optPropsCnt is the count of optional properties. DO NOT use it as a property key.
optPropsCnt
)
Expand Down Expand Up @@ -147,6 +149,10 @@ var optionalPropertyDescList = []OptionalEvalPropDesc{
key: OptPropDDLOwnerInfo,
str: "OptPropDDLOwnerInfo",
},
{
key: OptPropPrivilegeChecker,
str: "OptPropPrivilegeChecker",
},
}

// OptionalEvalPropKeySet is a bit map for optional evaluation properties in EvalContext
Expand Down
3 changes: 2 additions & 1 deletion pkg/expression/exprctx/optional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ func TestOptionalPropKeySet(t *testing.T) {
Add(OptPropKVStore).
Add(OptPropSQLExecutor).
Add(OptPropSequenceOperator).
Add(OptPropAdvisoryLock)
Add(OptPropAdvisoryLock).
Add(OptPropPrivilegeChecker)
require.True(t, keySet4.IsFull())
require.False(t, keySet4.IsEmpty())
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/expression/expropt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ go_library(
"infoschema.go",
"kvstore.go",
"optional.go",
"priv.go",
"sequence.go",
"sessionvars.go",
"sqlexec.go",
Expand All @@ -20,6 +21,7 @@ go_library(
"//pkg/infoschema/context",
"//pkg/kv",
"//pkg/parser/auth",
"//pkg/parser/mysql",
"//pkg/planner/core/resolve",
"//pkg/sessionctx/variable",
"//pkg/util/chunk",
Expand Down
3 changes: 3 additions & 0 deletions pkg/expression/expropt/optional.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ func (o *OptionalEvalPropProviders) Add(val exprctx.OptionalEvalPropProvider) {
case exprctx.OptPropSequenceOperator:
_, ok := val.(SequenceOperatorProvider)
intest.Assert(ok)
case exprctx.OptPropPrivilegeChecker:
_, ok := val.(PrivilegeCheckerProvider)
intest.Assert(ok)
default:
intest.Assert(false)
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/expression/expropt/optional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,21 @@ func TestOptionalEvalPropProviders(t *testing.T) {
isOwner = false
require.False(t, assertReaderFuncValue(t, ctx, r.IsDDLOwner))
}
case exprctx.OptPropPrivilegeChecker:
type mockPrivCheckerTp struct {
PrivilegeChecker
}
mockPrivChecker := &mockPrivCheckerTp{}
p = PrivilegeCheckerProvider(func() PrivilegeChecker { return mockPrivChecker })
r := PrivilegeCheckerPropReader{}
reader = r
verifyNoProvider = func(ctx exprctx.EvalContext) {
assertReaderFuncReturnErr(t, ctx, r.GetPrivilegeChecker)
}
verifyProvider = func(ctx exprctx.EvalContext, val exprctx.OptionalEvalPropProvider) {
require.Same(t, mockPrivChecker, assertReaderFuncValue(t, ctx, r.GetPrivilegeChecker))
require.Same(t, mockPrivChecker, val.(PrivilegeCheckerProvider)())
}
default:
require.Fail(t, "unexpected optional property key")
}
Expand Down
Loading

0 comments on commit a56674c

Please sign in to comment.