summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lib/config.h.in2
-rw-r--r--src/lib/crypt/openssl.c243
-rw-r--r--src/lib/tests/CMakeLists.txt1
-rw-r--r--src/lib/tests/crypt_test.c178
4 files changed, 385 insertions, 39 deletions
diff --git a/src/lib/config.h.in b/src/lib/config.h.in
index b34e6a7b..465068cb 100644
--- a/src/lib/config.h.in
+++ b/src/lib/config.h.in
@@ -97,3 +97,5 @@
#define ACKQ_SLOTS (@ACK_WHEEL_SLOTS@)
#define ACKQ_RES (@ACK_WHEEL_RESOLUTION@) /* 2^N ns */
+
+#define KEY_ROTATION_BIT (@KEY_ROTATION_BIT@) /* Bit for key rotation */
diff --git a/src/lib/crypt/openssl.c b/src/lib/crypt/openssl.c
index 5eee2a13..13ed1c64 100644
--- a/src/lib/crypt/openssl.c
+++ b/src/lib/crypt/openssl.c
@@ -50,16 +50,29 @@
#define IS_EC_GROUP(str) (strcmp(str, "EC") == 0)
#define IS_DH_GROUP(str) (strcmp(str, "DH") == 0)
-#define HKDF_INFO_DHE "o7s-ossl-dhe"
-#define HKDF_INFO_ENCAP "o7s-ossl-encap"
-#define HKDF_SALT_LEN 32 /* SHA-256 output size */
+#define HKDF_INFO_DHE "o7s-ossl-dhe"
+#define HKDF_INFO_ENCAP "o7s-ossl-encap"
+#define HKDF_INFO_ROTATION "o7s-key-rotation"
+#define HKDF_SALT_LEN 32 /* SHA-256 output size */
struct ossl_crypt_ctx {
EVP_CIPHER_CTX * evp_ctx;
const EVP_CIPHER * cipher;
- uint8_t * key;
int ivsz;
int tagsz;
+
+ struct {
+ uint8_t * cur; /* current key */
+ uint8_t * prv; /* rotated key */
+ } keys;
+
+ struct {
+ uint32_t cntr; /* counter */
+ uint32_t mask; /* phase mask */
+ uint32_t age; /* counter within epoch */
+ uint8_t phase; /* current key phase */
+ uint8_t salt[HKDF_SALT_LEN];
+ } rot; /* rotation logic */
};
struct kdf_info {
@@ -70,6 +83,17 @@ struct kdf_info {
buffer_t key;
};
+/* Key rotation macros */
+#define HAS_PHASE_BIT_TOGGLED(ctx) \
+ (((ctx)->rot.cntr & (ctx)->rot.mask) != \
+ (((ctx)->rot.cntr - 1) & (ctx)->rot.mask))
+
+#define HAS_GRACE_EXPIRED(ctx) \
+ ((ctx)->rot.age >= ((ctx)->rot.mask >> 1))
+
+#define ROTATION_TOO_RECENT(ctx) \
+ ((ctx)->rot.age < ((ctx)->rot.mask - ((ctx)->rot.mask >> 2)))
+
/* Convert hash NID to OpenSSL digest name string for HKDF */
static const char * hash_nid_to_digest_name(int nid)
{
@@ -234,6 +258,119 @@ static int derive_key_hkdf(struct kdf_info * ki)
return -ECRYPT;
}
+/* Key rotation helper functions implementation */
+static int should_rotate_key_rx(struct ossl_crypt_ctx * ctx,
+ uint8_t rx_phase)
+{
+ assert(ctx != NULL);
+
+ /* Phase must have changed */
+ if (rx_phase == ctx->rot.phase)
+ return 0;
+
+ if (ROTATION_TOO_RECENT(ctx))
+ return 0;
+
+ return 1;
+}
+
+static int rotate_key(struct ossl_crypt_ctx * ctx)
+{
+ struct kdf_info ki;
+ uint8_t * tmp;
+
+ assert(ctx != NULL);
+
+ /* Swap keys - move current to prev */
+ tmp = ctx->keys.prv;
+ ctx->keys.prv = ctx->keys.cur;
+
+ if (tmp != NULL) {
+ /* Reuse old prev_key memory for new key */
+ ctx->keys.cur = tmp;
+ } else {
+ /* First rotation - allocate new memory */
+ ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ);
+ if (ctx->keys.cur == NULL)
+ return -ECRYPT;
+ }
+
+ /* Derive new key from previous key using HKDF */
+ ki.secret.data = ctx->keys.prv;
+ ki.secret.len = SYMMKEYSZ;
+ ki.nid = NID_sha256;
+ ki.salt.data = ctx->rot.salt;
+ ki.salt.len = HKDF_SALT_LEN;
+ ki.info.data = (uint8_t *) HKDF_INFO_ROTATION;
+ ki.info.len = strlen(HKDF_INFO_ROTATION);
+ ki.key.data = ctx->keys.cur;
+ ki.key.len = SYMMKEYSZ;
+
+ if (derive_key_hkdf(&ki) != 0)
+ return -ECRYPT;
+
+ ctx->rot.age = 0;
+ ctx->rot.phase = !ctx->rot.phase;
+
+ return 0;
+}
+
+static void cleanup_old_key(struct ossl_crypt_ctx * ctx)
+{
+ assert(ctx != NULL);
+
+ if (ctx->keys.prv == NULL)
+ return;
+
+ if (!HAS_GRACE_EXPIRED(ctx))
+ return;
+
+ OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ);
+ ctx->keys.prv = NULL;
+}
+
+static int try_decrypt(struct ossl_crypt_ctx * ctx,
+ uint8_t * key,
+ uint8_t * iv,
+ uint8_t * input,
+ int in_sz,
+ uint8_t * out,
+ int * out_sz)
+{
+ uint8_t * tag;
+ int tmp_sz;
+ int ret;
+
+ tag = input + in_sz;
+
+ EVP_CIPHER_CTX_reset(ctx->evp_ctx);
+
+ ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, key, iv);
+ if (ret != 1)
+ return -1;
+
+ if (ctx->tagsz > 0) {
+ ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG,
+ ctx->tagsz, tag);
+ if (ret != 1)
+ return -1;
+ }
+
+ ret = EVP_DecryptUpdate(ctx->evp_ctx, out, &tmp_sz, input, in_sz);
+ if (ret != 1)
+ return -1;
+
+ *out_sz = tmp_sz;
+
+ ret = EVP_DecryptFinal_ex(ctx->evp_ctx, out + tmp_sz, &tmp_sz);
+ if (ret != 1)
+ return -1;
+
+ *out_sz += tmp_sz;
+
+ return 0;
+}
+
/*
* Derive the common secret from
* - your public key pair (pkp)
@@ -837,9 +974,16 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx,
if (random_buffer(iv, ctx->ivsz) < 0)
goto fail_encrypt;
+ /* Set IV bit 7 to current key phase (bit KEY_ROTATION_BIT of counter) */
+ if (ctx->rot.cntr & ctx->rot.mask)
+ iv[0] |= 0x80;
+ else
+ iv[0] &= 0x7F;
+
EVP_CIPHER_CTX_reset(ctx->evp_ctx);
- ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv);
+ ret = EVP_EncryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL,
+ ctx->keys.cur, iv);
if (ret != 1)
goto fail_encrypt;
@@ -867,6 +1011,17 @@ int openssl_encrypt(struct ossl_crypt_ctx * ctx,
out->len = (size_t) out_sz + ctx->ivsz;
+ /* Increment packet counter and check for key rotation */
+ ctx->rot.cntr++;
+ ctx->rot.age++;
+
+ if (HAS_PHASE_BIT_TOGGLED(ctx)) {
+ if (rotate_key(ctx) != 0)
+ goto fail_encrypt;
+ }
+
+ cleanup_old_key(ctx);
+
return 0;
fail_encrypt:
free(out->data);
@@ -879,13 +1034,11 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx,
buffer_t in,
buffer_t * out)
{
- uint8_t * ptr;
- uint8_t * iv;
- uint8_t * input;
- int ret;
- int out_sz;
- int in_sz;
- int tmp_sz;
+ uint8_t * iv;
+ uint8_t * input;
+ uint8_t rx_phase;
+ int out_sz;
+ int in_sz;
assert(ctx != NULL);
@@ -900,34 +1053,27 @@ int openssl_decrypt(struct ossl_crypt_ctx * ctx,
goto fail_malloc;
iv = in.data;
- ptr = out->data;
input = in.data + ctx->ivsz;
- EVP_CIPHER_CTX_reset(ctx->evp_ctx);
-
- ret = EVP_DecryptInit_ex(ctx->evp_ctx, ctx->cipher, NULL, ctx->key, iv);
- if (ret != 1)
- goto fail_decrypt;
+ /* Extract phase from IV bit 7 and check for key rotation */
+ rx_phase = (iv[0] & 0x80) ? 1 : 0;
- /* For AEAD ciphers, set the expected authentication tag */
- if (ctx->tagsz > 0) {
- uint8_t * tag = input + in_sz;
- ret = EVP_CIPHER_CTX_ctrl(ctx->evp_ctx, EVP_CTRL_AEAD_SET_TAG,
- ctx->tagsz, tag);
- if (ret != 1)
+ if (should_rotate_key_rx(ctx, rx_phase)) {
+ if (rotate_key(ctx) != 0)
goto fail_decrypt;
}
- ret = EVP_DecryptUpdate(ctx->evp_ctx, ptr, &tmp_sz, input, in_sz);
- if (ret != 1)
- goto fail_decrypt;
-
- out_sz = tmp_sz;
- ret = EVP_DecryptFinal_ex(ctx->evp_ctx, ptr + tmp_sz, &tmp_sz);
- if (ret != 1)
- goto fail_decrypt;
+ ctx->rot.cntr++;
+ ctx->rot.age++;
- out_sz += tmp_sz;
+ if (try_decrypt(ctx, ctx->keys.cur, iv, input, in_sz, out->data,
+ &out_sz) != 0) {
+ if (ctx->keys.prv == NULL)
+ goto fail_decrypt;
+ if (try_decrypt(ctx, ctx->keys.prv, iv, input, in_sz,
+ out->data, &out_sz) != 0)
+ goto fail_decrypt;
+ }
assert(out_sz <= in_sz);
@@ -954,11 +1100,18 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk)
memset(ctx, 0, sizeof(*ctx));
- ctx->key = OPENSSL_secure_malloc(SYMMKEYSZ);
- if (ctx->key == NULL)
+ ctx->keys.cur = OPENSSL_secure_malloc(SYMMKEYSZ);
+ if (ctx->keys.cur == NULL)
goto fail_key;
- memcpy(ctx->key, sk->key, SYMMKEYSZ);
+ memcpy(ctx->keys.cur, sk->key, SYMMKEYSZ);
+
+ ctx->keys.prv = NULL;
+
+ /* Derive rotation salt from initial shared secret */
+ if (EVP_Digest(sk->key, SYMMKEYSZ, ctx->rot.salt, NULL,
+ EVP_sha256(), NULL) != 1)
+ goto fail_cipher;
ctx->cipher = EVP_get_cipherbynid(sk->nid);
if (ctx->cipher == NULL)
@@ -970,6 +1123,15 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk)
if (EVP_CIPHER_flags(ctx->cipher) & EVP_CIPH_FLAG_AEAD_CIPHER)
ctx->tagsz = 16; /* Standard AEAD tag length (128 bits) */
+ ctx->rot.cntr = 0;
+#ifdef TEST_KEY_ROTATION_BIT
+ ctx->rot.mask = (1U << TEST_KEY_ROTATION_BIT);
+#else
+ ctx->rot.mask = (1U << KEY_ROTATION_BIT);
+#endif
+ ctx->rot.age = 0;
+ ctx->rot.phase = 0;
+
ctx->evp_ctx = EVP_CIPHER_CTX_new();
if (ctx->evp_ctx == NULL)
goto fail_cipher;
@@ -977,7 +1139,7 @@ struct ossl_crypt_ctx * openssl_crypt_create_ctx(struct crypt_sk * sk)
return ctx;
fail_cipher:
- OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ);
+ OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ);
fail_key:
free(ctx);
fail_malloc:
@@ -989,8 +1151,11 @@ void openssl_crypt_destroy_ctx(struct ossl_crypt_ctx * ctx)
if (ctx == NULL)
return;
- if (ctx->key != NULL)
- OPENSSL_secure_clear_free(ctx->key, SYMMKEYSZ);
+ if (ctx->keys.cur != NULL)
+ OPENSSL_secure_clear_free(ctx->keys.cur, SYMMKEYSZ);
+
+ if (ctx->keys.prv != NULL)
+ OPENSSL_secure_clear_free(ctx->keys.prv, SYMMKEYSZ);
EVP_CIPHER_CTX_free(ctx->evp_ctx);
free(ctx);
diff --git a/src/lib/tests/CMakeLists.txt b/src/lib/tests/CMakeLists.txt
index 6ab69bd1..fe4c1342 100644
--- a/src/lib/tests/CMakeLists.txt
+++ b/src/lib/tests/CMakeLists.txt
@@ -26,6 +26,7 @@ add_executable(${PARENT_DIR}_test ${${PARENT_DIR}_tests})
disable_test_logging_for_target(${PARENT_DIR}_test)
target_link_libraries(${PARENT_DIR}_test ouroboros-common)
+target_compile_definitions(${PARENT_DIR}_test PRIVATE TEST_KEY_ROTATION_BIT=10)
add_dependencies(build_tests ${PARENT_DIR}_test)
diff --git a/src/lib/tests/crypt_test.c b/src/lib/tests/crypt_test.c
index 906059be..a24cde66 100644
--- a/src/lib/tests/crypt_test.c
+++ b/src/lib/tests/crypt_test.c
@@ -254,6 +254,182 @@ static int test_md_nid_values(void)
}
#endif
+static int test_key_rotation(void)
+{
+ uint8_t pkt[TEST_PACKET_SIZE];
+ struct crypt_ctx * tx_ctx;
+ struct crypt_ctx * rx_ctx;
+ uint8_t key[SYMMKEYSZ];
+ struct crypt_sk sk = {
+ .nid = NID_aes_256_gcm,
+ .key = key
+ };
+ buffer_t in;
+ buffer_t enc;
+ buffer_t dec;
+ uint32_t i;
+ uint32_t threshold;
+
+ TEST_START();
+
+ if (random_buffer(key, sizeof(key)) < 0) {
+ printf("Failed to generate random key.\n");
+ goto fail;
+ }
+
+ if (random_buffer(pkt, sizeof(pkt)) < 0) {
+ printf("Failed to generate random data.\n");
+ goto fail;
+ }
+
+ tx_ctx = crypt_create_ctx(&sk);
+ if (tx_ctx == NULL) {
+ printf("Failed to create TX context.\n");
+ goto fail;
+ }
+
+ rx_ctx = crypt_create_ctx(&sk);
+ if (rx_ctx == NULL) {
+ printf("Failed to create RX context.\n");
+ goto fail_tx;
+ }
+
+ in.len = sizeof(pkt);
+ in.data = pkt;
+
+ threshold = (1U << TEST_KEY_ROTATION_BIT);
+
+ /* Encrypt and decrypt across multiple rotations */
+ for (i = 0; i < threshold * 3; i++) {
+ if (crypt_encrypt(tx_ctx, in, &enc) < 0) {
+ printf("Encryption failed at packet %u.\n", i);
+ goto fail_rx;
+ }
+
+ if (crypt_decrypt(rx_ctx, enc, &dec) < 0) {
+ printf("Decryption failed at packet %u.\n", i);
+ freebuf(enc);
+ goto fail_rx;
+ }
+
+ if (dec.len != in.len ||
+ memcmp(in.data, dec.data, in.len) != 0) {
+ printf("Data mismatch at packet %u.\n", i);
+ freebuf(dec);
+ freebuf(enc);
+ goto fail_rx;
+ }
+
+ freebuf(dec);
+ freebuf(enc);
+ }
+
+ crypt_destroy_ctx(rx_ctx);
+ crypt_destroy_ctx(tx_ctx);
+
+ TEST_SUCCESS();
+
+ return TEST_RC_SUCCESS;
+ fail_rx:
+ crypt_destroy_ctx(rx_ctx);
+ fail_tx:
+ crypt_destroy_ctx(tx_ctx);
+ fail:
+ TEST_FAIL();
+ return TEST_RC_FAIL;
+}
+
+static int test_key_phase_bit(void)
+{
+ uint8_t pkt[TEST_PACKET_SIZE];
+ struct crypt_ctx * ctx;
+ uint8_t key[SYMMKEYSZ];
+ struct crypt_sk sk = {
+ .nid = NID_aes_256_gcm,
+ .key = key
+ };
+ buffer_t in;
+ buffer_t out;
+ uint32_t count;
+ uint32_t threshold;
+ uint8_t phase_before;
+ uint8_t phase_after;
+ int ivsz;
+
+ TEST_START();
+
+ if (random_buffer(key, sizeof(key)) < 0) {
+ printf("Failed to generate random key.\n");
+ goto fail;
+ }
+
+ if (random_buffer(pkt, sizeof(pkt)) < 0) {
+ printf("Failed to generate random data.\n");
+ goto fail;
+ }
+
+ ctx = crypt_create_ctx(&sk);
+ if (ctx == NULL) {
+ printf("Failed to initialize cryptography.\n");
+ goto fail;
+ }
+
+ ivsz = crypt_get_ivsz(ctx);
+ if (ivsz <= 0) {
+ printf("Invalid IV size.\n");
+ goto fail_ctx;
+ }
+
+ in.len = sizeof(pkt);
+ in.data = pkt;
+
+ /* Encrypt packets up to just before rotation threshold */
+ threshold = (1U << KEY_ROTATION_BIT);
+
+ /* Encrypt threshold - 1 packets (indices 0 to threshold-2) */
+ for (count = 0; count < threshold - 1; count++) {
+ if (crypt_encrypt(ctx, in, &out) < 0) {
+ printf("Encryption failed at count %u.\n", count);
+ goto fail_ctx;
+ }
+ freebuf(out);
+ }
+
+ /* Packet at index threshold-1: phase should still be initial */
+ if (crypt_encrypt(ctx, in, &out) < 0) {
+ printf("Encryption failed before rotation.\n");
+ goto fail_ctx;
+ }
+ phase_before = (out.data[0] & 0x80) ? 1 : 0;
+ freebuf(out);
+
+ /* Packet at index threshold: phase should have toggled */
+ if (crypt_encrypt(ctx, in, &out) < 0) {
+ printf("Encryption failed at rotation threshold.\n");
+ goto fail_ctx;
+ }
+ phase_after = (out.data[0] & 0x80) ? 1 : 0;
+ freebuf(out);
+
+ /* Phase bit should have toggled */
+ if (phase_before == phase_after) {
+ printf("Phase bit did not toggle: before=%u, after=%u.\n",
+ phase_before, phase_after);
+ goto fail_ctx;
+ }
+
+ crypt_destroy_ctx(ctx);
+
+ TEST_SUCCESS();
+
+ return TEST_RC_SUCCESS;
+ fail_ctx:
+ crypt_destroy_ctx(ctx);
+ fail:
+ TEST_FAIL();
+ return TEST_RC_FAIL;
+}
+
int crypt_test(int argc,
char ** argv)
{
@@ -264,6 +440,8 @@ int crypt_test(int argc,
ret |= test_crypt_create_destroy();
ret |= test_encrypt_decrypt_all();
+ ret |= test_key_rotation();
+ ret |= test_key_phase_bit();
#ifdef HAVE_OPENSSL
ret |= test_cipher_nid_values();
ret |= test_md_nid_values();