Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt #40004

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 144 additions & 36 deletions be/src/util/encryption_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <algorithm>
#include <cstring>
#include <string>
#include <unordered_map>

namespace doris {

Expand Down Expand Up @@ -80,6 +81,12 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
return EVP_aes_256_ctr();
case EncryptionMode::AES_256_OFB:
return EVP_aes_256_ofb();
case EncryptionMode::AES_128_GCM:
return EVP_aes_128_gcm();
case EncryptionMode::AES_192_GCM:
return EVP_aes_192_gcm();
case EncryptionMode::AES_256_GCM:
return EVP_aes_256_gcm();
case EncryptionMode::SM4_128_CBC:
return EVP_sm4_cbc();
case EncryptionMode::SM4_128_ECB:
Expand All @@ -95,41 +102,29 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
}
}

static uint mode_key_sizes[] = {
128 /* AES_128_ECB */,
192 /* AES_192_ECB */,
256 /* AES_256_ECB */,
128 /* AES_128_CBC */,
192 /* AES_192_CBC */,
256 /* AES_256_CBC */,
128 /* AES_128_CFB */,
192 /* AES_192_CFB */,
256 /* AES_256_CFB */,
128 /* AES_128_CFB1 */,
192 /* AES_192_CFB1 */,
256 /* AES_256_CFB1 */,
128 /* AES_128_CFB8 */,
192 /* AES_192_CFB8 */,
256 /* AES_256_CFB8 */,
128 /* AES_128_CFB128 */,
192 /* AES_192_CFB128 */,
256 /* AES_256_CFB128 */,
128 /* AES_128_CTR */,
192 /* AES_192_CTR */,
256 /* AES_256_CTR */,
128 /* AES_128_OFB */,
192 /* AES_192_OFB */,
256 /* AES_256_OFB */,
128 /* SM4_128_ECB */,
128 /* SM4_128_CBC */,
128 /* SM4_128_CFB128 */,
128 /* SM4_128_OFB */,
128 /* SM4_128_CTR */
};
static std::unordered_map<EncryptionMode, uint> mode_key_sizes = {
{EncryptionMode::AES_128_ECB, 128}, {EncryptionMode::AES_192_ECB, 192},
{EncryptionMode::AES_256_ECB, 256}, {EncryptionMode::AES_128_CBC, 128},
{EncryptionMode::AES_192_CBC, 192}, {EncryptionMode::AES_256_CBC, 256},
{EncryptionMode::AES_128_CFB, 128}, {EncryptionMode::AES_192_CFB, 192},
{EncryptionMode::AES_256_CFB, 256}, {EncryptionMode::AES_128_CFB1, 128},
{EncryptionMode::AES_192_CFB1, 192}, {EncryptionMode::AES_256_CFB1, 256},
{EncryptionMode::AES_128_CFB8, 128}, {EncryptionMode::AES_192_CFB8, 192},
{EncryptionMode::AES_256_CFB8, 256}, {EncryptionMode::AES_128_CFB128, 128},
{EncryptionMode::AES_192_CFB128, 192}, {EncryptionMode::AES_256_CFB128, 256},
{EncryptionMode::AES_128_CTR, 128}, {EncryptionMode::AES_192_CTR, 192},
{EncryptionMode::AES_256_CTR, 256}, {EncryptionMode::AES_128_OFB, 128},
{EncryptionMode::AES_192_OFB, 192}, {EncryptionMode::AES_256_OFB, 256},
{EncryptionMode::AES_128_GCM, 128}, {EncryptionMode::AES_192_GCM, 192},
{EncryptionMode::AES_256_GCM, 256},

{EncryptionMode::SM4_128_ECB, 128}, {EncryptionMode::SM4_128_CBC, 128},
{EncryptionMode::SM4_128_CFB128, 128}, {EncryptionMode::SM4_128_OFB, 128},
{EncryptionMode::SM4_128_CTR, 128}};

static void create_key(const unsigned char* origin_key, uint32_t key_length, uint8_t* encrypt_key,
EncryptionMode mode) {
const uint key_size = mode_key_sizes[int(mode)] / 8;
const uint key_size = mode_key_sizes[mode] / 8;
uint8_t* origin_key_end = ((uint8_t*)origin_key) + key_length; /* origin key boundary*/

uint8_t* encrypt_key_end; /* encrypt key boundary */
Expand Down Expand Up @@ -172,10 +167,58 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
return ret;
}

static int do_gcm_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
const unsigned char* source, uint32_t source_length,
const unsigned char* encrypt_key, const unsigned char* iv, int iv_length,
unsigned char* encrypt, int* length_ptr, const unsigned char* aad,
uint32_t aad_length) {
int ret = EVP_EncryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr);
if (ret != 1) {
return ret;
}
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr);
if (ret != 1) {
return ret;
}
ret = EVP_EncryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, iv);
if (ret != 1) {
return ret;
}
if (aad) {
int tmp_len = 0;
ret = EVP_EncryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length);
if (ret != 1) {
return ret;
}
}

std::memcpy(encrypt, iv, iv_length);
encrypt += iv_length;

int u_len = 0;
ret = EVP_EncryptUpdate(cipher_ctx, encrypt, &u_len, source, source_length);
if (ret != 1) {
return ret;
}
encrypt += u_len;

int f_len = 0;
ret = EVP_EncryptFinal_ex(cipher_ctx, encrypt, &f_len);
if (ret != 1) {
return ret;
}
encrypt += f_len;

ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_GET_TAG, EncryptionUtil::GCM_TAG_SIZE,
encrypt);
*length_ptr = iv_length + u_len + f_len + EncryptionUtil::GCM_TAG_SIZE;
return ret;
}

int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
const char* iv_str, int iv_input_length, bool padding,
unsigned char* encrypt) {
unsigned char* encrypt, const unsigned char* aad, uint32_t aad_length) {
const EVP_CIPHER* cipher = get_evp_type(mode);
/* The encrypt key to be used for encryption */
unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
Expand All @@ -196,8 +239,16 @@ int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
EVP_CIPHER_CTX_reset(cipher_ctx);
int length = 0;
int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
int ret = 0;
if (is_gcm_mode(mode)) {
ret = do_gcm_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
reinterpret_cast<unsigned char*>(init_vec), iv_length, encrypt,
&length, aad, aad_length);
} else {
ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length);
}

EVP_CIPHER_CTX_free(cipher_ctx);
if (ret == 0) {
ERR_clear_error();
Expand Down Expand Up @@ -230,10 +281,61 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
return ret;
}

static int do_gcm_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
const unsigned char* encrypt, uint32_t encrypt_length,
const unsigned char* encrypt_key, int iv_length,
unsigned char* decrypt_content, int* length_ptr, const unsigned char* aad,
uint32_t aad_length) {
if (encrypt_length < iv_length + EncryptionUtil::GCM_TAG_SIZE) {
return -1;
}
int ret = EVP_DecryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr);
if (ret != 1) {
return ret;
}
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr);
if (ret != 1) {
return ret;
}
ret = EVP_DecryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, encrypt);
if (ret != 1) {
return ret;
}
encrypt += iv_length;
if (aad) {
int tmp_len = 0;
ret = EVP_DecryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length);
if (ret != 1) {
return ret;
}
}

uint32_t real_encrypt_length = encrypt_length - iv_length - EncryptionUtil::GCM_TAG_SIZE;
int u_len = 0;
ret = EVP_DecryptUpdate(cipher_ctx, decrypt_content, &u_len, encrypt, real_encrypt_length);
if (ret != 1) {
return ret;
}
encrypt += real_encrypt_length;
decrypt_content += u_len;

void* tag = const_cast<void*>(reinterpret_cast<const void*>(encrypt));
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_TAG, EncryptionUtil::GCM_TAG_SIZE, tag);
if (ret != 1) {
return ret;
}

int f_len = 0;
ret = EVP_DecryptFinal_ex(cipher_ctx, decrypt_content, &f_len);
*length_ptr = u_len + f_len;
return ret;
}

int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
uint32_t encrypt_length, const unsigned char* key, uint32_t key_length,
const char* iv_str, int iv_input_length, bool padding,
unsigned char* decrypt_content) {
unsigned char* decrypt_content, const unsigned char* aad,
uint32_t aad_length) {
const EVP_CIPHER* cipher = get_evp_type(mode);

/* The encrypt key to be used for decryption */
Expand All @@ -255,9 +357,15 @@ int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
EVP_CIPHER_CTX_reset(cipher_ctx);
int length = 0;
int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
int ret = 0;
if (is_gcm_mode(mode)) {
ret = do_gcm_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv_length,
decrypt_content, &length, aad, aad_length);
} else {
ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content,
&length);
}
EVP_CIPHER_CTX_free(cipher_ctx);
if (ret > 0) {
return length;
Expand Down
17 changes: 15 additions & 2 deletions be/src/util/encryption_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ enum class EncryptionMode {
AES_128_OFB,
AES_192_OFB,
AES_256_OFB,
AES_128_GCM,
AES_192_GCM,
AES_256_GCM,
SM4_128_ECB,
SM4_128_CBC,
SM4_128_CFB128,
Expand All @@ -57,13 +60,23 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 };

class EncryptionUtil {
public:
static bool is_gcm_mode(EncryptionMode mode) {
return mode == EncryptionMode::AES_128_GCM || mode == EncryptionMode::AES_192_GCM ||
mode == EncryptionMode::AES_256_GCM;
}

// https://tools.ietf.org/html/rfc5116#section-5.1
static const int GCM_TAG_SIZE = 16;

static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length,
const unsigned char* key, uint32_t key_length, const char* iv_str,
int iv_input_length, bool padding, unsigned char* encrypt);
int iv_input_length, bool padding, unsigned char* encrypt,
const unsigned char* aad = nullptr, uint32_t aad_length = 0);

static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length,
const unsigned char* key, uint32_t key_length, const char* iv_str,
int iv_input_length, bool padding, unsigned char* decrypt_content);
int iv_input_length, bool padding, unsigned char* decrypt_content,
const unsigned char* aad = nullptr, uint32_t aad_length = 0);
};

} // namespace doris
Loading
Loading