diff options
Diffstat (limited to 'src/lib/crypt/openssl.c')
| -rw-r--r-- | src/lib/crypt/openssl.c | 243 |
1 files changed, 204 insertions, 39 deletions
diff --git a/src/lib/crypt/openssl.c b/src/lib/crypt/openssl.c index 5eee2a13..13ed1c64 100644 --- a/src/lib/crypt/openssl.c +++ b/src/lib/crypt/openssl.c @@ -50,16 +50,29 @@ #define IS_EC_GROUP(str) (strcmp(str, "EC") == 0) #define IS_DH_GROUP(str) (strcmp(str, "DH") == 0) -#define HKDF_INFO_DHE "o7s-ossl-dhe" -#define HKDF_INFO_ENCAP "o7s-ossl-encap" -#define HKDF_SALT_LEN 32 /* SHA-256 output size */ +#define HKDF_INFO_DHE "o7s-ossl-dhe" +#define HKDF_INFO_ENCAP "o7s-ossl-encap" +#define HKDF_INFO_ROTATION "o7s-key-rotation" +#define HKDF_SALT_LEN 32 /* SHA-256 output size */ struct ossl_crypt_ctx { EVP_CIPHER_CTX * evp_ctx; const EVP_CIPHER * cipher; - uint8_t * key; int ivsz; int tagsz; + + struct { + uint8_t * cur; /* current key */ + uint8_t * prv; /* rotated key */ + } keys; + + struct { + uint32_t cntr; /* counter */ + uint32_t mask; /* phase mask */ + uint32_t age; /* counter within epoch */ + uint8_t phase; /* current key phase */ + uint8_t salt[HKDF_SALT_LEN]; + } rot; /* rotation logic */ }; struct kdf_info { @@ -70,6 +83,17 @@ struct kdf_info { buffer_t key; }; +/* Key rotation macros */ +#define HAS_PHASE_BIT_TOGGLED(ctx) \ + (((ctx)->rot.cntr & (ctx)->rot.mask) != \ + (((ctx)->rot.cntr - 1) & (ctx)->rot.mask)) + +#define HAS_GRACE_EXPIRED(ctx) \ + ((ctx)->rot.age >= ((ctx)->rot.mask >> 1)) + +#define ROTATION_TOO_RECENT(ctx) \ + ((ctx)->rot.age < ((ctx)->rot.mask - ((ctx)->rot.mask >> 2))) + /* Convert hash NID to OpenSSL digest name string for HKDF */ static const char * hash_nid_to_digest_name(int nid) { @@ -234,6 +258,119 @@ static int derive_key_hkdf(struct kdf_info * ki) return -ECRYPT; } +/* Key rotation helper functions implementation */ +static int should_rotate_key_rx(struct ossl_crypt_ctx * ctx, + uint8_t rx_phase) +{ + assert(ctx != NULL); + + /* Phase must have changed */ + if (rx_phase == ctx->rot.phase) + return 0; + + if (ROTATION_TOO_RECENT(ctx)) + return 0; + + return 1; +} + +static int rotate_key(struct ossl_crypt_ctx * ctx) +{ + struct kdf_info ki; + uint8_t * tmp; + + assert(ctx != NULL); + + /* Swap keys - move current to prev */ + tmp = ctx->keys.prv; + ctx->keys.prv = ctx->keys.cur; + + if (tmp != NULL) { + /* Reuse old prev_key memory for new key */ + ctx->keys.cur = tmp; + } else { + /* First rotation - allocate new memory */ + ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ); + if (ctx->keys.cur == NULL) + return -ECRYPT; + } + + /* Derive new key from previous key using HKDF */ + ki.secret.data = ctx->keys.prv; + ki.secret.len = SYMMKEYSZ; + ki.nid = NID_sha256; + ki.salt.data = ctx->rot.salt; + ki.salt.len = HKDF_SALT_LEN; + ki.info.data = (uint8_t *) HKDF_INFO_ROTATION; + ki.info.len = strlen(HKDF_INFO_ROTATION); + ki.key.data = ctx->keys.cur; + ki.key.len = SYMMKEYSZ; + + if (derive_key_hkdf(&ki) != 0) + return -ECRYPT; + + ctx->rot.age = 0; + ctx->rot.phase = !ctx->rot.phase; + + return 0; +} + +static void cleanup_old_key(struct ossl_crypt_ctx * ctx) +{ + assert(ctx != NULL); + + if (ctx->keys.prv == NULL) + return; + + if (!HAS_GRACE_EXPIRED(ctx)) + return; + + OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ); + ctx->keys.prv = NULL; +} + +static int try_decrypt(struct ossl_crypt_ctx * ctx, + uint8_t * key, + uint8_t * iv, + uint8_t * input, + int in_sz, + uint8_t * out, + int * out_sz) +{ + uint8_t * tag; + int tmp_sz; + int ret; + + tag = input + in_sz; + + EVP_CIPHER_CTX_reset(ctx->evp_ctx); + + ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, key, iv); + if (ret != 1) + return -1; + + if (ctx->tagsz > 0) { + ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG, + ctx->tagsz, tag); + if (ret != 1) + return -1; + } + + ret = EVP_DecryptUpdate(ctx->evp_ctx, out, &tmp_sz, input, in_sz); + if (ret != 1) + return -1; + + *out_sz = tmp_sz; + + ret = EVP_DecryptFinal_ex(ctx->evp_ctx, out + tmp_sz, &tmp_sz); + if (ret != 1) + return -1; + + *out_sz += tmp_sz; + + return 0; +} + /* * Derive the common secret from * - your public key pair (pkp) @@ -837,9 +974,16 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx, if (random_buffer(iv, ctx->ivsz) < 0) goto fail_encrypt; + /* Set IV bit 7 to current key phase (bit KEY_ROTATION_BIT of counter) */ + if (ctx->rot.cntr & ctx->rot.mask) + iv[0] |= 0x80; + else + iv[0] &= 0x7F; + EVP_CIPHER_CTX_reset(ctx->evp_ctx); - ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv); + ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, + ctx->keys.cur, iv); if (ret != 1) goto fail_encrypt; @@ -867,6 +1011,17 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx, out->len = (size_t) out_sz + ctx->ivsz; + /* Increment packet counter and check for key rotation */ + ctx->rot.cntr++; + ctx->rot.age++; + + if (HAS_PHASE_BIT_TOGGLED(ctx)) { + if (rotate_key(ctx) != 0) + goto fail_encrypt; + } + + cleanup_old_key(ctx); + return 0; fail_encrypt: free(out->data); @@ -879,13 +1034,11 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx, buffer_t in, buffer_t * out) { - uint8_t * ptr; - uint8_t * iv; - uint8_t * input; - int ret; - int out_sz; - int in_sz; - int tmp_sz; + uint8_t * iv; + uint8_t * input; + uint8_t rx_phase; + int out_sz; + int in_sz; assert(ctx != NULL); @@ -900,34 +1053,27 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx, goto fail_malloc; iv = in.data; - ptr = out->data; input = in.data + ctx->ivsz; - EVP_CIPHER_CTX_reset(ctx->evp_ctx); - - ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv); - if (ret != 1) - goto fail_decrypt; + /* Extract phase from IV bit 7 and check for key rotation */ + rx_phase = (iv[0] & 0x80) ? 1 : 0; - /* For AEAD ciphers, set the expected authentication tag */ - if (ctx->tagsz > 0) { - uint8_t * tag = input + in_sz; - ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG, - ctx->tagsz, tag); - if (ret != 1) + if (should_rotate_key_rx(ctx, rx_phase)) { + if (rotate_key(ctx) != 0) goto fail_decrypt; } - ret = EVP_DecryptUpdate(ctx->evp_ctx, ptr, &tmp_sz, input, in_sz); - if (ret != 1) - goto fail_decrypt; - - out_sz = tmp_sz; - ret = EVP_DecryptFinal_ex(ctx->evp_ctx, ptr + tmp_sz, &tmp_sz); - if (ret != 1) - goto fail_decrypt; + ctx->rot.cntr++; + ctx->rot.age++; - out_sz += tmp_sz; + if (try_decrypt(ctx, ctx->keys.cur, iv, input, in_sz, out->data, + &out_sz) != 0) { + if (ctx->keys.prv == NULL) + goto fail_decrypt; + if (try_decrypt(ctx, ctx->keys.prv, iv, input, in_sz, + out->data, &out_sz) != 0) + goto fail_decrypt; + } assert(out_sz <= in_sz); @@ -954,11 +1100,18 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) memset(ctx, 0, sizeof(*ctx)); - ctx->key = OPENSSL_secure_malloc(SYMMKEYSZ); - if (ctx->key == NULL) + ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ); + if (ctx->keys.cur == NULL) goto fail_key; - memcpy(ctx->key, sk->key, SYMMKEYSZ); + memcpy(ctx->keys.cur, sk->key, SYMMKEYSZ); + + ctx->keys.prv = NULL; + + /* Derive rotation salt from initial shared secret */ + if (EVP_Digest(sk->key, SYMMKEYSZ, ctx->rot.salt, NULL, + EVP_sha256(), NULL) != 1) + goto fail_cipher; ctx->cipher = EVP_get_cipherbynid(sk->nid); if (ctx->cipher == NULL) @@ -970,6 +1123,15 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) if (EVP_CIPHER_flags(ctx->cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) ctx->tagsz = 16; /* Standard AEAD tag length (128 bits) */ + ctx->rot.cntr = 0; +#ifdef TEST_KEY_ROTATION_BIT + ctx->rot.mask = (1U << TEST_KEY_ROTATION_BIT); +#else + ctx->rot.mask = (1U << KEY_ROTATION_BIT); +#endif + ctx->rot.age = 0; + ctx->rot.phase = 0; + ctx->evp_ctx = EVP_CIPHER_CTX_new(); if (ctx->evp_ctx == NULL) goto fail_cipher; @@ -977,7 +1139,7 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) return ctx; fail_cipher: - OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ); + OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ); fail_key: free(ctx); fail_malloc: @@ -989,8 +1151,11 @@ void openssl_crypt_destroy_ctx(struct ossl_crypt_ctx * ctx) if (ctx == NULL) return; - if (ctx->key != NULL) - OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ); + if (ctx->keys.cur != NULL) + OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ); + + if (ctx->keys.prv != NULL) + OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ); EVP_CIPHER_CTX_free(ctx->evp_ctx); free(ctx); |
