diff --git a/benchmark/bench_modules/wh_bench_mod_mlkem.c b/benchmark/bench_modules/wh_bench_mod_mlkem.c index def29ccc..5d75b6a8 100644 --- a/benchmark/bench_modules/wh_bench_mod_mlkem.c +++ b/benchmark/bench_modules/wh_bench_mod_mlkem.c @@ -24,7 +24,7 @@ #include "wolfhsm/wh_client_crypto.h" #if !defined(WOLFHSM_CFG_NO_CRYPTO) && defined(WOLFHSM_CFG_BENCH_ENABLE) -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #if defined(WOLFSSL_HAVE_MLKEM) @@ -53,7 +53,7 @@ static int _benchMlKemKeyGen(whClientContext* client, whBenchOpContext* ctx, else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlKemMakeExportKey(client, securityLevel, 0, key); + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); } benchStopRet = wh_Bench_StopOp(ctx, id); @@ -97,7 +97,7 @@ static int _benchMlKemEncaps(whClientContext* client, whBenchOpContext* ctx, else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlKemMakeExportKey(client, securityLevel, 0, key); + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); } if (ret != WH_ERROR_OK) { WH_BENCH_PRINTF("Failed ML-KEM key setup %d\n", ret); @@ -169,7 +169,7 @@ static int _benchMlKemDecaps(whClientContext* client, whBenchOpContext* ctx, else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlKemMakeExportKey(client, securityLevel, 0, key); + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); } if (ret != WH_ERROR_OK) { WH_BENCH_PRINTF("Failed ML-KEM key setup %d\n", ret); diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index e28fc390..60e91005 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -55,7 +55,6 @@ #include "wolfssl/wolfcrypt/curve25519.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" -#include "wolfssl/wolfcrypt/mlkem.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" @@ -127,7 +126,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level, #endif /* HAVE_DILITHIUM */ #ifdef WOLFSSL_HAVE_MLKEM -static int _MlKemMakeKey(whClientContext* ctx, int size, int level, +static int _MlKemMakeKey(whClientContext* ctx, int level, whKeyId* inout_key_id, whNvmFlags flags, uint16_t label_len, uint8_t* label, MlKemKey* key); #ifdef WOLFHSM_CFG_DMA @@ -6484,7 +6483,7 @@ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + byte* buffer = NULL; uint16_t buffer_len = 0; if ((ctx == NULL) || (key == NULL) || @@ -6492,11 +6491,20 @@ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, return WH_ERROR_BADARGS; } + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + if (inout_keyId != NULL) { key_id = *inout_keyId; } - ret = wh_Crypto_MlKemSerializeKey(key, sizeof(buffer), buffer, &buffer_len); + ret = wh_Crypto_MlKemSerializeKey(key, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, + buffer, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: serialize ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); if (ret == WH_ERROR_OK) { ret = wh_Client_KeyCache(ctx, flags, label, label_len, buffer, buffer_len, &key_id); @@ -6504,31 +6512,45 @@ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, *inout_keyId = key_id; } } + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: ret:%d keyId:%u\n", ret, key_id); + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return ret; } int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, uint16_t label_len, uint8_t* label) { - int ret = WH_ERROR_OK; - byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; - uint16_t buffer_len = sizeof(buffer); + int ret = WH_ERROR_OK; + byte* buffer = NULL; + uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { return WH_ERROR_BADARGS; } + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + ret = wh_Client_KeyExport(ctx, keyId, label, label_len, buffer, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: export ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); if (ret == WH_ERROR_OK) { ret = wh_Crypto_MlKemDeserializeKey(buffer, buffer_len, key); } + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: keyId:%x ret:%d\n", keyId, ret); + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return ret; } -static int _MlKemMakeKey(whClientContext* ctx, int size, int level, +static int _MlKemMakeKey(whClientContext* ctx, int level, whKeyId* inout_key_id, whNvmFlags flags, uint16_t label_len, uint8_t* label, MlKemKey* key) { @@ -6555,63 +6577,71 @@ static int _MlKemMakeKey(whClientContext* ctx, int size, int level, key_id = *inout_key_id; } - if (ret == WH_ERROR_OK) { + { uint16_t group = WH_MESSAGE_GROUP_CRYPTO; uint16_t action = WC_ALGO_TYPE_PK; uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + uint16_t res_len; - if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) { - memset(req, 0, sizeof(*req)); - req->level = level; - req->sz = size; - req->flags = flags; - req->keyId = key_id; - if ((label != NULL) && (label_len > 0)) { - if (label_len > WH_NVM_LABEL_LEN) { - label_len = WH_NVM_LABEL_LEN; - } - memcpy(req->label, label, label_len); + /* Defense in depth: ensure request fits in comm buffer */ + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->level = level; + req->flags = flags; + req->keyId = key_id; + if ((label != NULL) && (label_len > 0)) { + if (label_len > WH_NVM_LABEL_LEN) { + label_len = WH_NVM_LABEL_LEN; } + memcpy(req->label, label, label_len); + } - ret = wh_Client_SendRequest(ctx, group, action, req_len, - (uint8_t*)dataPtr); - if (ret == WH_ERROR_OK) { - uint16_t res_len; - do { - ret = wh_Client_RecvResponse(ctx, &group, &action, &res_len, - (uint8_t*)dataPtr); - } while (ret == WH_ERROR_NOTREADY); + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemMakeKey: Req sent:level:%d, " + "ret:%d\n", level, ret); + if (ret != WH_ERROR_OK) { + return ret; + } - if (ret == WH_ERROR_OK) { - ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, - (uint8_t**)&res); - if (ret >= 0) { - key_id = (whKeyId)res->keyId; - if (inout_key_id != NULL) { - *inout_key_id = key_id; - } - if (key != NULL) { - wh_Client_MlKemSetKeyId(key, key_id); - if ((flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { - uint8_t* key_raw = (uint8_t*)(res + 1); - ret = wh_Crypto_MlKemDeserializeKey( - key_raw, (uint16_t)res->len, key); - } - } - } + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &res_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + if (ret != WH_ERROR_OK) { + return ret; + } + + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, + (uint8_t**)&res); + if (ret >= 0) { + key_id = (whKeyId)res->keyId; + WH_DEBUG_CLIENT_VERBOSE("MlKemMakeKey: Res recv:" + "keyId:%u, len:%u, ret:%d\n", + (unsigned int)res->keyId, + (unsigned int)res->len, ret); + if (inout_key_id != NULL) { + *inout_key_id = key_id; + } + if (key != NULL) { + wh_Client_MlKemSetKeyId(key, key_id); + if ((flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { + uint8_t* key_raw = (uint8_t*)(res + 1); + ret = wh_Crypto_MlKemDeserializeKey( + key_raw, (uint16_t)res->len, key); } } } - else { - ret = WH_ERROR_BADARGS; - } } return ret; } -int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int size, int level, +int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int level, whKeyId* inout_key_id, whNvmFlags flags, uint16_t label_len, uint8_t* label) { @@ -6619,18 +6649,18 @@ int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int size, int level, return WH_ERROR_BADARGS; } - return _MlKemMakeKey(ctx, size, level, inout_key_id, flags, label_len, + return _MlKemMakeKey(ctx, level, inout_key_id, flags, label_len, label, NULL); } -int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, int size, +int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, MlKemKey* key) { if (key == NULL) { return WH_ERROR_BADARGS; } - return _MlKemMakeKey(ctx, size, level, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, + return _MlKemMakeKey(ctx, level, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, NULL, key); } @@ -6645,17 +6675,13 @@ int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, whKeyId key_id; int evict = 0; - word32 ct_len; - word32 ss_len; if ((ctx == NULL) || (key == NULL) || (ct == NULL) || (ss == NULL) || (inout_ct_len == NULL) || (inout_ss_len == NULL)) { return WH_ERROR_BADARGS; } - ct_len = *inout_ct_len; - ss_len = *inout_ss_len; - if ((ct_len == 0) || (ss_len == 0)) { + if ((*inout_ct_len == 0) || (*inout_ss_len == 0)) { return WH_ERROR_BADARGS; } @@ -6679,6 +6705,9 @@ int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } return WH_ERROR_BADARGS; } @@ -6696,11 +6725,13 @@ int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, req->options = options; req->level = key->type; req->keyId = key_id; - req->ctSz = ct_len; - req->ssSz = ss_len; ret = wh_Client_SendRequest(ctx, group, action, req_len, (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemEncapsulate: Req sent:keyId:%u, " + "level:%u, ret:%d\n", + (unsigned int)key_id, + (unsigned int)key->type, ret); if (ret == WH_ERROR_OK) { evict = 0; do { @@ -6716,16 +6747,23 @@ int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, uint8_t* resp_data = (uint8_t*)(res + 1); word32 out_ct_len = res->ctSz; word32 out_ss_len = res->ssSz; - if (out_ct_len > *inout_ct_len) { - out_ct_len = *inout_ct_len; + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (uint16_t)((uint8_t*)resp_data - dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemEncapsulate: Res recv:" + "ctSz:%u, ssSz:%u, ret:%d\n", + (unsigned int)out_ct_len, + (unsigned int)out_ss_len, ret); + if (out_ct_len + out_ss_len > max_resp || + *inout_ct_len < out_ct_len || + *inout_ss_len < out_ss_len) { + ret = WH_ERROR_BADARGS; } - if (out_ss_len > *inout_ss_len) { - out_ss_len = *inout_ss_len; + else { + memcpy(ct, resp_data, out_ct_len); + memcpy(ss, resp_data + out_ct_len, out_ss_len); + *inout_ct_len = out_ct_len; + *inout_ss_len = out_ss_len; } - memcpy(ct, resp_data, out_ct_len); - memcpy(ss, resp_data + res->ctSz, out_ss_len); - *inout_ct_len = out_ct_len; - *inout_ss_len = out_ss_len; } } } @@ -6778,6 +6816,9 @@ int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } return WH_ERROR_BADARGS; } @@ -6799,13 +6840,16 @@ int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, req->level = key->type; req->keyId = key_id; req->ctSz = ct_len; - req->ssSz = *inout_ss_len; if ((ct != NULL) && (ct_len > 0)) { memcpy(req_ct, ct, ct_len); } ret = wh_Client_SendRequest(ctx, group, action, req_len, (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemDecapsulate: Req sent:keyId:%u, " + "ctSz:%u, ret:%d\n", + (unsigned int)key_id, + (unsigned int)ct_len, ret); if (ret == WH_ERROR_OK) { evict = 0; do { @@ -6820,11 +6864,19 @@ int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, if (ret >= 0) { uint8_t* resp_ss = (uint8_t*)(res + 1); word32 out_ss_len = res->ssSz; - if (out_ss_len > *inout_ss_len) { - out_ss_len = *inout_ss_len; + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (uint16_t)((uint8_t*)resp_ss - dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemDecapsulate: Res recv:" + "ssSz:%u, ret:%d\n", + (unsigned int)out_ss_len, ret); + if (out_ss_len > max_resp || + *inout_ss_len < out_ss_len) { + ret = WH_ERROR_BADARGS; + } + else { + memcpy(ss, resp_ss, out_ss_len); + *inout_ss_len = out_ss_len; } - memcpy(ss, resp_ss, out_ss_len); - *inout_ss_len = out_ss_len; } } } @@ -6847,7 +6899,7 @@ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + byte* buffer = NULL; uint16_t buffer_len = 0; if ((ctx == NULL) || (key == NULL) || @@ -6855,11 +6907,20 @@ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, return WH_ERROR_BADARGS; } + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + if (inout_keyId != NULL) { key_id = *inout_keyId; } - ret = wh_Crypto_MlKemSerializeKey(key, sizeof(buffer), buffer, &buffer_len); + ret = wh_Crypto_MlKemSerializeKey(key, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, + buffer, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: serialize ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); if (ret == WH_ERROR_OK) { ret = wh_Client_KeyCacheDma(ctx, flags, label, label_len, buffer, buffer_len, &key_id); @@ -6867,7 +6928,11 @@ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, *inout_keyId = key_id; } } + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: ret:%d keyId:%u\n", + ret, key_id); + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return ret; } @@ -6876,19 +6941,32 @@ int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, uint8_t* label) { int ret = WH_ERROR_OK; - byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE] = {0}; - uint16_t buffer_len = sizeof(buffer); + byte* buffer = NULL; + uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { return WH_ERROR_BADARGS; } + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + memset(buffer, 0, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + ret = wh_Client_KeyExportDma(ctx, keyId, buffer, buffer_len, label, label_len, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: export ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); if (ret == WH_ERROR_OK) { ret = wh_Crypto_MlKemDeserializeKey(buffer, buffer_len, key); } + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: keyId:%x ret:%d\n", + keyId, ret); + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return ret; } @@ -6898,19 +6976,27 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + byte* buffer = NULL; uint8_t* dataPtr = NULL; whMessageCrypto_MlKemKeyGenDmaRequest* req = NULL; whMessageCrypto_MlKemKeyGenDmaResponse* res = NULL; uintptr_t keyAddr = 0; - uint64_t keyAddrSz = sizeof(buffer); + uint64_t keyAddrSz = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; if (ctx == NULL) { return WH_ERROR_BADARGS; } + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return WH_ERROR_BADARGS; } @@ -6988,6 +7074,8 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, } } + ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return ret; } @@ -7040,6 +7128,9 @@ int wh_Client_MlKemEncapsulateDma(whClientContext* ctx, MlKemKey* key, dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } return WH_ERROR_BADARGS; } @@ -7154,6 +7245,9 @@ int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } return WH_ERROR_BADARGS; } diff --git a/src/wh_client_cryptocb.c b/src/wh_client_cryptocb.c index 41e2cd76..e283dd93 100644 --- a/src/wh_client_cryptocb.c +++ b/src/wh_client_cryptocb.c @@ -47,7 +47,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_crypto.h" #include "wolfhsm/wh_client_crypto.h" @@ -636,6 +636,8 @@ static int _handlePqcKemKeyGen(whClientContext* ctx, wc_CryptoInfo* info, } #endif + (void)size; + switch (type) { case WC_PQC_KEM_TYPE_KYBER: { int level = ((MlKemKey*)key)->type; @@ -646,7 +648,7 @@ static int _handlePqcKemKeyGen(whClientContext* ctx, wc_CryptoInfo* info, else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlKemMakeExportKey(ctx, level, size, key); + ret = wh_Client_MlKemMakeExportKey(ctx, level, key); } } break; diff --git a/src/wh_crypto.c b/src/wh_crypto.c index 9364b37a..fac54900 100644 --- a/src/wh_crypto.c +++ b/src/wh_crypto.c @@ -44,7 +44,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_error.h" #include "wolfhsm/wh_utils.h" diff --git a/src/wh_message_crypto.c b/src/wh_message_crypto.c index d6884993..5c4692a5 100644 --- a/src/wh_message_crypto.c +++ b/src/wh_message_crypto.c @@ -843,7 +843,6 @@ int wh_MessageCrypto_TranslateMlKemKeyGenRequest( if ((src == NULL) || (dest == NULL)) { return WH_ERROR_BADARGS; } - WH_T32(magic, dest, src, sz); WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); WH_T32(magic, dest, src, flags); @@ -878,8 +877,6 @@ int wh_MessageCrypto_TranslateMlKemEncapsRequest( WH_T32(magic, dest, src, options); WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); - WH_T32(magic, dest, src, ctSz); - WH_T32(magic, dest, src, ssSz); return 0; } @@ -908,7 +905,6 @@ int wh_MessageCrypto_TranslateMlKemDecapsRequest( WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); WH_T32(magic, dest, src, ctSz); - WH_T32(magic, dest, src, ssSz); return 0; } diff --git a/src/wh_server_crypto.c b/src/wh_server_crypto.c index e2509961..b8c59c8e 100644 --- a/src/wh_server_crypto.c +++ b/src/wh_server_crypto.c @@ -44,7 +44,7 @@ #include "wolfssl/wolfcrypt/sha512.h" #include "wolfssl/wolfcrypt/cmac.h" #include "wolfssl/wolfcrypt/dilithium.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/hmac.h" #include "wolfssl/wolfcrypt/kdf.h" @@ -4602,12 +4602,8 @@ static int _HandleMlKemEncaps(whServerContext* ctx, uint16_t magic, int devId, } if (ret == WH_ERROR_OK) { - ct_len = req.ctSz; - ss_len = req.ssSz; - if (ct_len == 0) { - ret = wc_MlKemKey_CipherTextSize(key, &ct_len); - } - if ((ret == WH_ERROR_OK) && (ss_len == 0)) { + ret = wc_MlKemKey_CipherTextSize(key, &ct_len); + if (ret == WH_ERROR_OK) { ret = wc_MlKemKey_SharedSecretSize(key, &ss_len); } } @@ -4707,8 +4703,7 @@ static int _HandleMlKemDecaps(whServerContext* ctx, uint16_t magic, int devId, ret = wh_Server_MlKemKeyCacheExport(ctx, key_id, key); } - ss_len = req.ssSz; - if ((ret == WH_ERROR_OK) && (ss_len == 0)) { + if (ret == WH_ERROR_OK) { ret = wc_MlKemKey_SharedSecretSize(key, &ss_len); } @@ -6132,6 +6127,8 @@ static int _HandleMlKemKeyGenDma(whServerContext* ctx, uint16_t magic, whMessageCrypto_MlKemKeyGenDmaRequest req; whMessageCrypto_MlKemKeyGenDmaResponse res; + memset(&res, 0, sizeof(res)); + if (inSize < sizeof(whMessageCrypto_MlKemKeyGenDmaRequest)) { return WH_ERROR_BADARGS; } @@ -6221,13 +6218,15 @@ static int _HandleMlKemEncapsDma(whServerContext* ctx, uint16_t magic, MlKemKey key[1]; void* ctAddr = NULL; void* ssAddr = NULL; - word32 ctLen; - word32 ssLen; + word32 ctLen = 0; + word32 ssLen = 0; whKeyId key_id; - int evict; + int evict = 0; whMessageCrypto_MlKemEncapsDmaRequest req; whMessageCrypto_MlKemEncapsDmaResponse res; + memset(&res, 0, sizeof(res)); + if (inSize < sizeof(whMessageCrypto_MlKemEncapsDmaRequest)) { return WH_ERROR_BADARGS; } @@ -6339,12 +6338,14 @@ static int _HandleMlKemDecapsDma(whServerContext* ctx, uint16_t magic, MlKemKey key[1]; void* ctAddr = NULL; void* ssAddr = NULL; - word32 ssLen; + word32 ssLen = 0; whKeyId key_id; - int evict; + int evict = 0; whMessageCrypto_MlKemDecapsDmaRequest req; whMessageCrypto_MlKemDecapsDmaResponse res; + memset(&res, 0, sizeof(res)); + if (inSize < sizeof(whMessageCrypto_MlKemDecapsDmaRequest)) { return WH_ERROR_BADARGS; } diff --git a/test/wh_test_crypto.c b/test/wh_test_crypto.c index 4f6681ab..7e8ffff5 100644 --- a/test/wh_test_crypto.c +++ b/test/wh_test_crypto.c @@ -32,7 +32,7 @@ #include "wolfssl/wolfcrypt/types.h" #include "wolfssl/wolfcrypt/kdf.h" #include "wolfssl/wolfcrypt/ed25519.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_error.h" @@ -5261,7 +5261,7 @@ static int whTestCrypto_MlKemClient(whClientContext* ctx, int devId, WC_RNG* rng } if (ret == 0) { - ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], 0, key); + ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], key); if (ret != 0) { WH_ERROR_PRINT( "Failed ML-KEM make export key level=%d ret=%d\n", @@ -5269,7 +5269,7 @@ static int whTestCrypto_MlKemClient(whClientContext* ctx, int devId, WC_RNG* rng } } if (ret == 0) { - ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], 0, wrongKey); + ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], wrongKey); if (ret != 0) { WH_ERROR_PRINT( "Failed ML-KEM make wrong export key level=%d ret=%d\n", @@ -5318,7 +5318,7 @@ static int whTestCrypto_MlKemClient(whClientContext* ctx, int devId, WC_RNG* rng if (ret == 0) { ret = wh_Client_MlKemMakeCacheKey( - ctx, 0, levels[i], &usageKeyId, WH_NVM_FLAGS_NONE, + ctx, levels[i], &usageKeyId, WH_NVM_FLAGS_NONE, (uint16_t)strlen((const char*)usageLabel), (uint8_t*)usageLabel); if (ret != 0) { WH_ERROR_PRINT( diff --git a/wolfhsm/wh_client_crypto.h b/wolfhsm/wh_client_crypto.h index 8001e587..22e68bc9 100644 --- a/wolfhsm/wh_client_crypto.h +++ b/wolfhsm/wh_client_crypto.h @@ -50,7 +50,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/hmac.h" /** @@ -1294,38 +1294,197 @@ int wh_Client_MlDsaCheckPrivKeyDma(whClientContext* ctx, MlDsaKey* key, #endif /* HAVE_DILITHIUM */ #ifdef WOLFSSL_HAVE_MLKEM + +/** + * @brief Associate a ML-KEM key with a specific key ID. + * + * Sets the device context of a ML-KEM key to the specified key ID. On the + * server side, this key ID is used to reference the key stored in the HSM. + * + * @param[in] key Pointer to the ML-KEM key structure. + * @param[in] keyId Key ID to be associated with the ML-KEM key. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemSetKeyId(MlKemKey* key, whKeyId keyId); + +/** + * @brief Retrieve the key ID associated with a ML-KEM key. + * + * @param[in] key Pointer to the ML-KEM key structure. + * @param[out] outId Pointer to store the retrieved key ID. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemGetKeyId(MlKemKey* key, whKeyId* outId); +/** + * @brief Import a ML-KEM key to the server key cache. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key to import. + * @param[in,out] inout_keyId Pointer to key ID to use/receive. + * @param[in] flags Flags to control key persistence. + * @param[in] label_len Length of optional label in bytes. + * @param[in] label Optional label to associate with the key. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, whKeyId* inout_keyId, whNvmFlags flags, uint16_t label_len, uint8_t* label); + +/** + * @brief Export a ML-KEM key from the server key cache. + * + * @param[in] ctx Pointer to the client context. + * @param[in] keyId Key ID of the key to export. + * @param[out] key Pointer to the ML-KEM key structure to populate. + * @param[in] label_len Length of optional label in bytes. + * @param[out] label Optional label buffer to receive the key label. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, uint16_t label_len, uint8_t* label); -int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, int size, + +/** + * @brief Generate a ML-KEM key pair and return it as an ephemeral key. + * + * The key pair is generated on the server, serialized, and returned to the + * client without being cached. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[out] key Pointer to the ML-KEM key to populate with the generated key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, MlKemKey* key); -int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int size, int level, + +/** + * @brief Generate a ML-KEM key pair and cache it on the server. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[in,out] inout_key_id Pointer to key ID to use/receive. + * @param[in] flags Flags to control key persistence and usage. + * @param[in] label_len Length of optional label in bytes. + * @param[in] label Optional label to associate with the key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int level, whKeyId* inout_key_id, whNvmFlags flags, uint16_t label_len, uint8_t* label); + +/** + * @brief Perform ML-KEM encapsulation using a server-cached public key. + * + * Generates a shared secret and ciphertext using the public key identified by + * the key ID stored in the provided MlKemKey. If the key is not yet cached, + * it will be auto-imported and evicted after use. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[out] ct Buffer to receive the ciphertext. + * @param[in,out] inout_ct_len On input, size of ct buffer; on output, actual + * ciphertext length. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, byte* ct, word32* inout_ct_len, byte* ss, word32* inout_ss_len); + +/** + * @brief Perform ML-KEM decapsulation using a server-cached private key. + * + * Recovers the shared secret from the ciphertext using the private key + * identified by the key ID stored in the provided MlKemKey. If the key is not + * yet cached, it will be auto-imported and evicted after use. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[in] ct Pointer to the ciphertext. + * @param[in] ct_len Length of the ciphertext in bytes. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, const byte* ct, word32 ct_len, byte* ss, word32* inout_ss_len); #ifdef WOLFHSM_CFG_DMA + +/** + * @brief Import a ML-KEM key using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key to import. + * @param[in,out] inout_keyId Pointer to store/provide the key ID. + * @param[in] flags NVM flags for key storage. + * @param[in] label_len Length of the key label in bytes. + * @param[in] label Pointer to the key label. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, whKeyId* inout_keyId, whNvmFlags flags, uint16_t label_len, uint8_t* label); + +/** + * @brief Export a ML-KEM key from the server using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] keyId Key ID of the key to export. + * @param[out] key Pointer to the ML-KEM key structure to populate. + * @param[in] label_len Length of the key label in bytes. + * @param[out] label Pointer to the key label buffer. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, MlKemKey* key, uint16_t label_len, uint8_t* label); + +/** + * @brief Generate an ephemeral ML-KEM key pair using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[out] key Pointer to the ML-KEM key to populate. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemMakeExportKeyDma(whClientContext* ctx, int level, MlKemKey* key); + +/** + * @brief Perform ML-KEM encapsulation using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[out] ct Buffer to receive the ciphertext. + * @param[in,out] inout_ct_len On input, size of ct buffer; on output, actual + * ciphertext length. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemEncapsulateDma(whClientContext* ctx, MlKemKey* key, byte* ct, word32* inout_ct_len, byte* ss, word32* inout_ss_len); + +/** + * @brief Perform ML-KEM decapsulation using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[in] ct Pointer to the ciphertext. + * @param[in] ct_len Length of the ciphertext in bytes. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, const byte* ct, word32 ct_len, byte* ss, word32* inout_ss_len); diff --git a/wolfhsm/wh_crypto.h b/wolfhsm/wh_crypto.h index 1279a17e..3774fa63 100644 --- a/wolfhsm/wh_crypto.h +++ b/wolfhsm/wh_crypto.h @@ -43,7 +43,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_message_crypto.h" diff --git a/wolfhsm/wh_message_crypto.h b/wolfhsm/wh_message_crypto.h index 7a39e68b..bc2eb05c 100644 --- a/wolfhsm/wh_message_crypto.h +++ b/wolfhsm/wh_message_crypto.h @@ -953,7 +953,6 @@ int wh_MessageCrypto_TranslateMlDsaVerifyResponse( /* ML-KEM Key Generation Request */ typedef struct { - uint32_t sz; uint32_t level; uint32_t keyId; uint32_t flags; @@ -984,8 +983,6 @@ typedef struct { #define WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT (1 << 0) uint32_t level; uint32_t keyId; - uint32_t ctSz; - uint32_t ssSz; uint8_t WH_PAD[4]; } whMessageCrypto_MlKemEncapsRequest; @@ -1014,8 +1011,6 @@ typedef struct { uint32_t level; uint32_t keyId; uint32_t ctSz; - uint32_t ssSz; - uint8_t WH_PAD[4]; /* Data follows: * uint8_t ct[ctSz]; */ diff --git a/wolfhsm/wh_server_crypto.h b/wolfhsm/wh_server_crypto.h index 18ee72f2..af945ce9 100644 --- a/wolfhsm/wh_server_crypto.h +++ b/wolfhsm/wh_server_crypto.h @@ -37,7 +37,7 @@ #include "wolfssl/wolfcrypt/curve25519.h" #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" -#include "wolfssl/wolfcrypt/mlkem.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/aes.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/cmac.h"