Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

br: have better crypter key error msg #56589

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion br/pkg/task/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ go_test(
],
embed = [":task"],
flaky = True,
shard_count = 38,
shard_count = 39,
deps = [
"//br/pkg/backup",
"//br/pkg/config",
Expand Down
59 changes: 38 additions & 21 deletions br/pkg/task/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/pingcap/tidb/br/pkg/metautil"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -115,9 +114,7 @@ const (
)

const (
// Once TableInfoVersion updated. BR need to check compatibility with
// new TableInfoVersion. both snapshot restore and pitr need to be checked.
CURRENT_BACKUP_SUPPORT_TABLE_INFO_VERSION = model.TableInfoVersion5
cipherKeyNonHexErrorMsg = "cipher key must be a valid hexadecimal string"
)

// FullBackupType type when doing full backup or restore
Expand Down Expand Up @@ -464,34 +461,52 @@ func GetCipherKeyContent(cipherKey, cipherKeyFile string) ([]byte, error) {
return nil, errors.Trace(err)
}

// if cipher-key is valid, convert the hexadecimal string to bytes
var hexString string

// Check if cipher-key is provided directly
if len(cipherKey) > 0 {
return hex.DecodeString(cipherKey)
hexString = cipherKey
} else {
// Read content from cipher-file
content, err := os.ReadFile(cipherKeyFile)
if err != nil {
return nil, errors.Annotate(err, "failed to read cipher file")
}
hexString = string(bytes.TrimSuffix(content, []byte("\n")))
}

// convert the content(as hexadecimal string) from cipher-file to bytes
content, err := os.ReadFile(cipherKeyFile)
// Attempt to decode the hex string
decodedKey, err := hex.DecodeString(hexString)
if err != nil {
return nil, errors.Annotate(err, "failed to read cipher file")
return nil, errors.Annotate(berrors.ErrInvalidArgument, cipherKeyNonHexErrorMsg)
}

content = bytes.TrimSuffix(content, []byte("\n"))
return hex.DecodeString(string(content))
return decodedKey, nil
}

func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool {
func checkCipherKeyMatch(cipher *backuppb.CipherInfo) error {
switch cipher.CipherType {
case encryptionpb.EncryptionMethod_PLAINTEXT:
return true
return nil
case encryptionpb.EncryptionMethod_AES128_CTR:
return len(cipher.CipherKey) == crypterAES128KeyLen
if len(cipher.CipherKey) != crypterAES128KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-128 key length mismatch: expected %d, got %d",
crypterAES128KeyLen, len(cipher.CipherKey))
}
case encryptionpb.EncryptionMethod_AES192_CTR:
return len(cipher.CipherKey) == crypterAES192KeyLen
if len(cipher.CipherKey) != crypterAES192KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-192 key length mismatch: expected %d, got %d",
crypterAES192KeyLen, len(cipher.CipherKey))
}
case encryptionpb.EncryptionMethod_AES256_CTR:
return len(cipher.CipherKey) == crypterAES256KeyLen
if len(cipher.CipherKey) != crypterAES256KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-256 key length mismatch: expected %d, got %d",
crypterAES256KeyLen, len(cipher.CipherKey))
}
default:
return false
return errors.Errorf("Unknown encryption method: %v", cipher.CipherType)
}
return nil
}

func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
Expand Down Expand Up @@ -524,8 +539,9 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return errors.Trace(err)
}

if !checkCipherKeyMatch(&cfg.CipherInfo) {
return errors.Annotate(berrors.ErrInvalidArgument, "crypter method and key length not match")
err = checkCipherKeyMatch(&cfg.CipherInfo)
if err != nil {
return errors.Trace(err)
}

return nil
Expand Down Expand Up @@ -561,8 +577,9 @@ func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error)
return false, errors.Trace(err)
}

if !checkCipherKeyMatch(&cfg.CipherInfo) {
return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length not match")
err = checkCipherKeyMatch(&cfg.CipherInfo)
if err != nil {
return false, errors.Trace(err)
}

return true, nil
Expand Down
102 changes: 70 additions & 32 deletions br/pkg/task/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package task

import (
"encoding/hex"
"fmt"
"testing"

Expand Down Expand Up @@ -70,57 +69,89 @@ func TestStripingPDURL(t *testing.T) {

func TestCheckCipherKeyMatch(t *testing.T) {
cases := []struct {
CipherType encryptionpb.EncryptionMethod
CipherKey string
ok bool
name string
cipherInfo *backup.CipherInfo
expectErr bool
errMsg string
}{
{
CipherType: encryptionpb.EncryptionMethod_PLAINTEXT,
ok: true,
name: "PLAINTEXT",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_PLAINTEXT,
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_UNKNOWN,
ok: false,
name: "UNKNOWN",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_UNKNOWN,
},
expectErr: true,
errMsg: "Unknown encryption method: UNKNOWN",
},
{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: "0123456789abcdef0123456789abcdef",
ok: true,
name: "AES128_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: make([]byte, crypterAES128KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: "0123456789abcdef0123456789abcd",
ok: false,
name: "AES128_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: make([]byte, crypterAES128KeyLen-1),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-128 key length mismatch: expected %d, got %d", crypterAES128KeyLen, crypterAES128KeyLen-1),
},
{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef",
ok: true,
name: "AES192_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: make([]byte, crypterAES192KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdefff",
ok: false,
name: "AES192_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: make([]byte, crypterAES192KeyLen+1),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-192 key length mismatch: expected %d, got %d", crypterAES192KeyLen, crypterAES192KeyLen+1),
},
{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
ok: true,
name: "AES256_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: make([]byte, crypterAES256KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: "",
ok: false,
name: "AES256_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: make([]byte, 0),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-256 key length mismatch: expected %d, got %d", crypterAES256KeyLen, 0),
},
}

for _, c := range cases {
cipherKey, err := hex.DecodeString(c.CipherKey)
require.NoError(t, err)
require.Equal(t, c.ok, checkCipherKeyMatch(&backup.CipherInfo{
CipherType: c.CipherType,
CipherKey: cipherKey,
}))
t.Run(c.name, func(t *testing.T) {
err := checkCipherKeyMatch(c.cipherInfo)
if c.expectErr {
require.Error(t, err)
require.Contains(t, err.Error(), c.errMsg)
} else {
require.NoError(t, err)
}
})
}
}

Expand Down Expand Up @@ -162,6 +193,13 @@ func TestCheckCipherKey(t *testing.T) {
}
}

func TestGetCipherKey(t *testing.T) {
nonHexKey := "this is not a hex string"
_, err := GetCipherKeyContent(nonHexKey, "")
require.Error(t, err)
require.Contains(t, err.Error(), cipherKeyNonHexErrorMsg)
}

func must[T any](t T, err error) T {
if err != nil {
panic(err)
Expand Down
Loading