diff options
| -rw-r--r-- | cmake/lib/lib.cmake | 2 | ||||
| -rw-r--r-- | src/lib/config.h.in | 2 | ||||
| -rw-r--r-- | src/lib/crypt/openssl.c | 243 | ||||
| -rw-r--r-- | src/lib/tests/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | src/lib/tests/crypt_test.c | 178 |
5 files changed, 387 insertions, 39 deletions
diff --git a/cmake/lib/lib.cmake b/cmake/lib/lib.cmake index 674d8503..bb24d1b9 100644 --- a/cmake/lib/lib.cmake +++ b/cmake/lib/lib.cmake @@ -85,6 +85,8 @@ set(TPM_DEBUG_REPORT_INTERVAL 0 CACHE STRING "Interval at wich the TPM will report long running threads (s), 0 disables") set(TPM_DEBUG_ABORT_TIMEOUT 0 CACHE STRING "TPM abort process after a thread reaches this timeout (s), 0 disables") +set(KEY_ROTATION_BIT 20 CACHE STRING + "Bit position in packet counter that triggers key rotation (default 20 = every 2^20 packets)") if (HAVE_FUSE) set(PROC_FLOW_STATS TRUE CACHE BOOL diff --git a/src/lib/config.h.in b/src/lib/config.h.in index b34e6a7b..465068cb 100644 --- a/src/lib/config.h.in +++ b/src/lib/config.h.in @@ -97,3 +97,5 @@ #define ACKQ_SLOTS (@ACK_WHEEL_SLOTS@) #define ACKQ_RES (@ACK_WHEEL_RESOLUTION@) /* 2^N ns */ + +#define KEY_ROTATION_BIT (@KEY_ROTATION_BIT@) /* Bit for key rotation */ diff --git a/src/lib/crypt/openssl.c b/src/lib/crypt/openssl.c index 5eee2a13..13ed1c64 100644 --- a/src/lib/crypt/openssl.c +++ b/src/lib/crypt/openssl.c @@ -50,16 +50,29 @@ #define IS_EC_GROUP(str) (strcmp(str, "EC") == 0) #define IS_DH_GROUP(str) (strcmp(str, "DH") == 0) -#define HKDF_INFO_DHE "o7s-ossl-dhe" -#define HKDF_INFO_ENCAP "o7s-ossl-encap" -#define HKDF_SALT_LEN 32 /* SHA-256 output size */ +#define HKDF_INFO_DHE "o7s-ossl-dhe" +#define HKDF_INFO_ENCAP "o7s-ossl-encap" +#define HKDF_INFO_ROTATION "o7s-key-rotation" +#define HKDF_SALT_LEN 32 /* SHA-256 output size */ struct ossl_crypt_ctx { EVP_CIPHER_CTX * evp_ctx; const EVP_CIPHER * cipher; - uint8_t * key; int ivsz; int tagsz; + + struct { + uint8_t * cur; /* current key */ + uint8_t * prv; /* rotated key */ + } keys; + + struct { + uint32_t cntr; /* counter */ + uint32_t mask; /* phase mask */ + uint32_t age; /* counter within epoch */ + uint8_t phase; /* current key phase */ + uint8_t salt[HKDF_SALT_LEN]; + } rot; /* rotation logic */ }; struct kdf_info { @@ -70,6 +83,17 @@ struct kdf_info { buffer_t key; }; +/* Key rotation macros */ +#define HAS_PHASE_BIT_TOGGLED(ctx) \ + (((ctx)->rot.cntr & (ctx)->rot.mask) != \ + (((ctx)->rot.cntr - 1) & (ctx)->rot.mask)) + +#define HAS_GRACE_EXPIRED(ctx) \ + ((ctx)->rot.age >= ((ctx)->rot.mask >> 1)) + +#define ROTATION_TOO_RECENT(ctx) \ + ((ctx)->rot.age < ((ctx)->rot.mask - ((ctx)->rot.mask >> 2))) + /* Convert hash NID to OpenSSL digest name string for HKDF */ static const char * hash_nid_to_digest_name(int nid) { @@ -234,6 +258,119 @@ static int derive_key_hkdf(struct kdf_info * ki) return -ECRYPT; } +/* Key rotation helper functions implementation */ +static int should_rotate_key_rx(struct ossl_crypt_ctx * ctx, + uint8_t rx_phase) +{ + assert(ctx != NULL); + + /* Phase must have changed */ + if (rx_phase == ctx->rot.phase) + return 0; + + if (ROTATION_TOO_RECENT(ctx)) + return 0; + + return 1; +} + +static int rotate_key(struct ossl_crypt_ctx * ctx) +{ + struct kdf_info ki; + uint8_t * tmp; + + assert(ctx != NULL); + + /* Swap keys - move current to prev */ + tmp = ctx->keys.prv; + ctx->keys.prv = ctx->keys.cur; + + if (tmp != NULL) { + /* Reuse old prev_key memory for new key */ + ctx->keys.cur = tmp; + } else { + /* First rotation - allocate new memory */ + ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ); + if (ctx->keys.cur == NULL) + return -ECRYPT; + } + + /* Derive new key from previous key using HKDF */ + ki.secret.data = ctx->keys.prv; + ki.secret.len = SYMMKEYSZ; + ki.nid = NID_sha256; + ki.salt.data = ctx->rot.salt; + ki.salt.len = HKDF_SALT_LEN; + ki.info.data = (uint8_t *) HKDF_INFO_ROTATION; + ki.info.len = strlen(HKDF_INFO_ROTATION); + ki.key.data = ctx->keys.cur; + ki.key.len = SYMMKEYSZ; + + if (derive_key_hkdf(&ki) != 0) + return -ECRYPT; + + ctx->rot.age = 0; + ctx->rot.phase = !ctx->rot.phase; + + return 0; +} + +static void cleanup_old_key(struct ossl_crypt_ctx * ctx) +{ + assert(ctx != NULL); + + if (ctx->keys.prv == NULL) + return; + + if (!HAS_GRACE_EXPIRED(ctx)) + return; + + OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ); + ctx->keys.prv = NULL; +} + +static int try_decrypt(struct ossl_crypt_ctx * ctx, + uint8_t * key, + uint8_t * iv, + uint8_t * input, + int in_sz, + uint8_t * out, + int * out_sz) +{ + uint8_t * tag; + int tmp_sz; + int ret; + + tag = input + in_sz; + + EVP_CIPHER_CTX_reset(ctx->evp_ctx); + + ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, key, iv); + if (ret != 1) + return -1; + + if (ctx->tagsz > 0) { + ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG, + ctx->tagsz, tag); + if (ret != 1) + return -1; + } + + ret = EVP_DecryptUpdate(ctx->evp_ctx, out, &tmp_sz, input, in_sz); + if (ret != 1) + return -1; + + *out_sz = tmp_sz; + + ret = EVP_DecryptFinal_ex(ctx->evp_ctx, out + tmp_sz, &tmp_sz); + if (ret != 1) + return -1; + + *out_sz += tmp_sz; + + return 0; +} + /* * Derive the common secret from * - your public key pair (pkp) @@ -837,9 +974,16 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx, if (random_buffer(iv, ctx->ivsz) < 0) goto fail_encrypt; + /* Set IV bit 7 to current key phase (bit KEY_ROTATION_BIT of counter) */ + if (ctx->rot.cntr & ctx->rot.mask) + iv[0] |= 0x80; + else + iv[0] &= 0x7F; + EVP_CIPHER_CTX_reset(ctx->evp_ctx); - ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv); + ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, + ctx->keys.cur, iv); if (ret != 1) goto fail_encrypt; @@ -867,6 +1011,17 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx, out->len = (size_t) out_sz + ctx->ivsz; + /* Increment packet counter and check for key rotation */ + ctx->rot.cntr++; + ctx->rot.age++; + + if (HAS_PHASE_BIT_TOGGLED(ctx)) { + if (rotate_key(ctx) != 0) + goto fail_encrypt; + } + + cleanup_old_key(ctx); + return 0; fail_encrypt: free(out->data); @@ -879,13 +1034,11 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx, buffer_t in, buffer_t * out) { - uint8_t * ptr; - uint8_t * iv; - uint8_t * input; - int ret; - int out_sz; - int in_sz; - int tmp_sz; + uint8_t * iv; + uint8_t * input; + uint8_t rx_phase; + int out_sz; + int in_sz; assert(ctx != NULL); @@ -900,34 +1053,27 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx, goto fail_malloc; iv = in.data; - ptr = out->data; input = in.data + ctx->ivsz; - EVP_CIPHER_CTX_reset(ctx->evp_ctx); - - ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv); - if (ret != 1) - goto fail_decrypt; + /* Extract phase from IV bit 7 and check for key rotation */ + rx_phase = (iv[0] & 0x80) ? 1 : 0; - /* For AEAD ciphers, set the expected authentication tag */ - if (ctx->tagsz > 0) { - uint8_t * tag = input + in_sz; - ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG, - ctx->tagsz, tag); - if (ret != 1) + if (should_rotate_key_rx(ctx, rx_phase)) { + if (rotate_key(ctx) != 0) goto fail_decrypt; } - ret = EVP_DecryptUpdate(ctx->evp_ctx, ptr, &tmp_sz, input, in_sz); - if (ret != 1) - goto fail_decrypt; - - out_sz = tmp_sz; - ret = EVP_DecryptFinal_ex(ctx->evp_ctx, ptr + tmp_sz, &tmp_sz); - if (ret != 1) - goto fail_decrypt; + ctx->rot.cntr++; + ctx->rot.age++; - out_sz += tmp_sz; + if (try_decrypt(ctx, ctx->keys.cur, iv, input, in_sz, out->data, + &out_sz) != 0) { + if (ctx->keys.prv == NULL) + goto fail_decrypt; + if (try_decrypt(ctx, ctx->keys.prv, iv, input, in_sz, + out->data, &out_sz) != 0) + goto fail_decrypt; + } assert(out_sz <= in_sz); @@ -954,11 +1100,18 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) memset(ctx, 0, sizeof(*ctx)); - ctx->key = OPENSSL_secure_malloc(SYMMKEYSZ); - if (ctx->key == NULL) + ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ); + if (ctx->keys.cur == NULL) goto fail_key; - memcpy(ctx->key, sk->key, SYMMKEYSZ); + memcpy(ctx->keys.cur, sk->key, SYMMKEYSZ); + + ctx->keys.prv = NULL; + + /* Derive rotation salt from initial shared secret */ + if (EVP_Digest(sk->key, SYMMKEYSZ, ctx->rot.salt, NULL, + EVP_sha256(), NULL) != 1) + goto fail_cipher; ctx->cipher = EVP_get_cipherbynid(sk->nid); if (ctx->cipher == NULL) @@ -970,6 +1123,15 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) if (EVP_CIPHER_flags(ctx->cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) ctx->tagsz = 16; /* Standard AEAD tag length (128 bits) */ + ctx->rot.cntr = 0; +#ifdef TEST_KEY_ROTATION_BIT + ctx->rot.mask = (1U << TEST_KEY_ROTATION_BIT); +#else + ctx->rot.mask = (1U << KEY_ROTATION_BIT); +#endif + ctx->rot.age = 0; + ctx->rot.phase = 0; + ctx->evp_ctx = EVP_CIPHER_CTX_new(); if (ctx->evp_ctx == NULL) goto fail_cipher; @@ -977,7 +1139,7 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk) return ctx; fail_cipher: - OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ); + OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ); fail_key: free(ctx); fail_malloc: @@ -989,8 +1151,11 @@ void openssl_crypt_destroy_ctx(struct ossl_crypt_ctx * ctx) if (ctx == NULL) return; - if (ctx->key != NULL) - OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ); + if (ctx->keys.cur != NULL) + OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ); + + if (ctx->keys.prv != NULL) + OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ); EVP_CIPHER_CTX_free(ctx->evp_ctx); free(ctx); diff --git a/src/lib/tests/CMakeLists.txt b/src/lib/tests/CMakeLists.txt index 6ab69bd1..fe4c1342 100644 --- a/src/lib/tests/CMakeLists.txt +++ b/src/lib/tests/CMakeLists.txt @@ -26,6 +26,7 @@ add_executable(${PARENT_DIR}_test ${${PARENT_DIR}_tests}) disable_test_logging_for_target(${PARENT_DIR}_test) target_link_libraries(${PARENT_DIR}_test ouroboros-common) +target_compile_definitions(${PARENT_DIR}_test PRIVATE TEST_KEY_ROTATION_BIT=10) add_dependencies(build_tests ${PARENT_DIR}_test) diff --git a/src/lib/tests/crypt_test.c b/src/lib/tests/crypt_test.c index 906059be..a24cde66 100644 --- a/src/lib/tests/crypt_test.c +++ b/src/lib/tests/crypt_test.c @@ -254,6 +254,182 @@ static int test_md_nid_values(void) } #endif +static int test_key_rotation(void) +{ + uint8_t pkt[TEST_PACKET_SIZE]; + struct crypt_ctx * tx_ctx; + struct crypt_ctx * rx_ctx; + uint8_t key[SYMMKEYSZ]; + struct crypt_sk sk = { + .nid = NID_aes_256_gcm, + .key = key + }; + buffer_t in; + buffer_t enc; + buffer_t dec; + uint32_t i; + uint32_t threshold; + + TEST_START(); + + if (random_buffer(key, sizeof(key)) < 0) { + printf("Failed to generate random key.\n"); + goto fail; + } + + if (random_buffer(pkt, sizeof(pkt)) < 0) { + printf("Failed to generate random data.\n"); + goto fail; + } + + tx_ctx = crypt_create_ctx(&sk); + if (tx_ctx == NULL) { + printf("Failed to create TX context.\n"); + goto fail; + } + + rx_ctx = crypt_create_ctx(&sk); + if (rx_ctx == NULL) { + printf("Failed to create RX context.\n"); + goto fail_tx; + } + + in.len = sizeof(pkt); + in.data = pkt; + + threshold = (1U << TEST_KEY_ROTATION_BIT); + + /* Encrypt and decrypt across multiple rotations */ + for (i = 0; i < threshold * 3; i++) { + if (crypt_encrypt(tx_ctx, in, &enc) < 0) { + printf("Encryption failed at packet %u.\n", i); + goto fail_rx; + } + + if (crypt_decrypt(rx_ctx, enc, &dec) < 0) { + printf("Decryption failed at packet %u.\n", i); + freebuf(enc); + goto fail_rx; + } + + if (dec.len != in.len || + memcmp(in.data, dec.data, in.len) != 0) { + printf("Data mismatch at packet %u.\n", i); + freebuf(dec); + freebuf(enc); + goto fail_rx; + } + + freebuf(dec); + freebuf(enc); + } + + crypt_destroy_ctx(rx_ctx); + crypt_destroy_ctx(tx_ctx); + + TEST_SUCCESS(); + + return TEST_RC_SUCCESS; + fail_rx: + crypt_destroy_ctx(rx_ctx); + fail_tx: + crypt_destroy_ctx(tx_ctx); + fail: + TEST_FAIL(); + return TEST_RC_FAIL; +} + +static int test_key_phase_bit(void) +{ + uint8_t pkt[TEST_PACKET_SIZE]; + struct crypt_ctx * ctx; + uint8_t key[SYMMKEYSZ]; + struct crypt_sk sk = { + .nid = NID_aes_256_gcm, + .key = key + }; + buffer_t in; + buffer_t out; + uint32_t count; + uint32_t threshold; + uint8_t phase_before; + uint8_t phase_after; + int ivsz; + + TEST_START(); + + if (random_buffer(key, sizeof(key)) < 0) { + printf("Failed to generate random key.\n"); + goto fail; + } + + if (random_buffer(pkt, sizeof(pkt)) < 0) { + printf("Failed to generate random data.\n"); + goto fail; + } + + ctx = crypt_create_ctx(&sk); + if (ctx == NULL) { + printf("Failed to initialize cryptography.\n"); + goto fail; + } + + ivsz = crypt_get_ivsz(ctx); + if (ivsz <= 0) { + printf("Invalid IV size.\n"); + goto fail_ctx; + } + + in.len = sizeof(pkt); + in.data = pkt; + + /* Encrypt packets up to just before rotation threshold */ + threshold = (1U << KEY_ROTATION_BIT); + + /* Encrypt threshold - 1 packets (indices 0 to threshold-2) */ + for (count = 0; count < threshold - 1; count++) { + if (crypt_encrypt(ctx, in, &out) < 0) { + printf("Encryption failed at count %u.\n", count); + goto fail_ctx; + } + freebuf(out); + } + + /* Packet at index threshold-1: phase should still be initial */ + if (crypt_encrypt(ctx, in, &out) < 0) { + printf("Encryption failed before rotation.\n"); + goto fail_ctx; + } + phase_before = (out.data[0] & 0x80) ? 1 : 0; + freebuf(out); + + /* Packet at index threshold: phase should have toggled */ + if (crypt_encrypt(ctx, in, &out) < 0) { + printf("Encryption failed at rotation threshold.\n"); + goto fail_ctx; + } + phase_after = (out.data[0] & 0x80) ? 1 : 0; + freebuf(out); + + /* Phase bit should have toggled */ + if (phase_before == phase_after) { + printf("Phase bit did not toggle: before=%u, after=%u.\n", + phase_before, phase_after); + goto fail_ctx; + } + + crypt_destroy_ctx(ctx); + + TEST_SUCCESS(); + + return TEST_RC_SUCCESS; + fail_ctx: + crypt_destroy_ctx(ctx); + fail: + TEST_FAIL(); + return TEST_RC_FAIL; +} + int crypt_test(int argc, char ** argv) { @@ -264,6 +440,8 @@ int crypt_test(int argc, ret |= test_crypt_create_destroy(); ret |= test_encrypt_decrypt_all(); + ret |= test_key_rotation(); + ret |= test_key_phase_bit(); #ifdef HAVE_OPENSSL ret |= test_cipher_nid_values(); ret |= test_md_nid_values(); |
