diff --git a/spdm_emu/spdm_emu_common/command.c b/spdm_emu/spdm_emu_common/command.c index d1613c6..017faa4 100644 --- a/spdm_emu/spdm_emu_common/command.c +++ b/spdm_emu/spdm_emu_common/command.c @@ -5,6 +5,8 @@ **/ #include "spdm_emu.h" +#include +#include /* hack to add MCTP header for PCAP*/ #include "industry_standard/mctp.h" @@ -17,6 +19,8 @@ bool m_send_receive_buffer_acquired = false; uint8_t m_send_receive_buffer[LIBSPDM_MAX_SENDER_RECEIVER_BUFFER_SIZE]; size_t m_send_receive_buffer_size; +uint8_t m_use_eid = 0; +uint8_t m_send_single_spdm_cmd = 0; /** * Read number of bytes data in blocking mode. * @@ -31,8 +35,17 @@ bool read_bytes(const SOCKET socket, uint8_t *buffer, number_received = 0; while (number_received < number_of_bytes) { - result = recv(socket, (char *)(buffer + number_received), - number_of_bytes - number_received, 0); + if (m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { + struct sockaddr_mctp addr = { 0 }; + socklen_t addrlen = sizeof(addr); + result = recvfrom(socket, (char *)(buffer + number_received), + number_of_bytes - number_received, MSG_TRUNC, + (struct sockaddr *)&addr, &addrlen); + } + else { + result = recv(socket, (char *)(buffer + number_received), + number_of_bytes - number_received, 0); + } if (result == -1) { printf("Receive error - 0x%x\n", #ifdef _MSC_VER @@ -79,31 +92,51 @@ bool read_multiple_bytes(const SOCKET socket, uint8_t *buffer, uint32_t length; bool result; - result = read_data32(socket, &length); - if (!result) { - return result; - } - printf("Platform port Receive size: "); - length = ntohl(length); - dump_data((uint8_t *)&length, sizeof(uint32_t)); - printf("\n"); - length = ntohl(length); - - *bytes_received = length; - if (*bytes_received > max_buffer_length) { - printf("buffer too small (0x%x). Expected - 0x%x\n", - max_buffer_length, *bytes_received); - return false; - } - if (length == 0) { - return true; - } - result = read_bytes(socket, buffer, length); - if (!result) { - return result; + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { + result = read_data32(socket, &length); + if (!result) { + return result; + } + printf("Platform port Receive size: "); + length = ntohl(length); + dump_data((uint8_t *)&length, sizeof(uint32_t)); + printf("\n"); + length = ntohl(length); + + *bytes_received = length; + if (*bytes_received > max_buffer_length) { + printf("buffer too small (0x%x). Expected - 0x%x\n", + max_buffer_length, *bytes_received); + return false; + } + if (length == 0) { + return true; + } + result = read_bytes(socket, buffer, length); + if (!result) { + return result; + } + } else { + length = recv(socket, NULL, 0, MSG_PEEK | MSG_TRUNC); + if (length == -1) { + printf("Error: %s\n", strerror(errno)); + return false; + } + if (length > max_buffer_length - 1) { + printf("buffer too small (0x%x). Expected - 0x%x\n", + max_buffer_length, length); + return false; + } + result = read_bytes(socket, buffer+1, length); + if (!result) + return result; + // mctp kernel receive payload only. + // So add msg_type byte to receive_message + buffer[0] = MCTP_MESSAGE_TYPE_SPDM; + *bytes_received = length + 1; } printf("Platform port Receive buffer:\n "); - dump_data(buffer, length); + dump_data(buffer, length + 1); printf("\n"); return true; @@ -118,30 +151,31 @@ bool receive_platform_data(const SOCKET socket, uint32_t *command, uint32_t transport_type; uint32_t bytes_received; - result = read_data32(socket, &response); - if (!result) { - return result; - } - *command = response; - printf("Platform port Receive command: "); - response = ntohl(response); - dump_data((uint8_t *)&response, sizeof(uint32_t)); - printf("\n"); - - result = read_data32(socket, &transport_type); - if (!result) { - return result; - } - printf("Platform port Receive transport_type: "); - transport_type = ntohl(transport_type); - dump_data((uint8_t *)&transport_type, sizeof(uint32_t)); - printf("\n"); - transport_type = ntohl(transport_type); - if (transport_type != m_use_transport_layer) { - printf("transport_type mismatch\n"); - return false; + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { + result = read_data32(socket, &response); + if (!result) { + return result; + } + *command = response; + printf("Platform port Receive command: "); + response = ntohl(response); + dump_data((uint8_t *)&response, sizeof(uint32_t)); + printf("\n"); + + result = read_data32(socket, &transport_type); + if (!result) { + return result; + } + printf("Platform port Receive transport_type: "); + transport_type = ntohl(transport_type); + dump_data((uint8_t *)&transport_type, sizeof(uint32_t)); + printf("\n"); + transport_type = ntohl(transport_type); + if (transport_type != m_use_transport_layer) { + printf("transport_type mismatch\n"); + return false; + } } - bytes_received = 0; result = read_multiple_bytes(socket, receive_buffer, &bytes_received, (uint32_t)*bytes_to_receive); @@ -193,8 +227,25 @@ bool write_bytes(const SOCKET socket, const uint8_t *buffer, number_sent = 0; while (number_sent < number_of_bytes) { - result = send(socket, (char *)(buffer + number_sent), - number_of_bytes - number_sent, 0); + if (m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { + /* MCTP kernel approach does not support send() syscall + * sendto() is recommanded to send messages currently + * https://discord.com/channels/775381525260664832/775381525260664836/1161513903319158904 */ + struct sockaddr_mctp addr = { 0 }; + addr.smctp_family = AF_MCTP; + if (m_use_eid != 0) + addr.smctp_addr.s_addr = m_use_eid; + addr.smctp_type = MCTP_MESSAGE_TYPE_SPDM; + addr.smctp_tag = MCTP_TAG_OWNER; + /* we own the tag, and so the kernel will allocate one for us */ + result = sendto(socket, (char *)(buffer + number_sent), + number_of_bytes - number_sent, 0, + (struct sockaddr *)&addr, sizeof(addr)); + } + else { + result = send(socket, (char *)(buffer + number_sent), + number_of_bytes - number_sent, 0); + } if (result == -1) { #ifdef _MSC_VER if (WSAGetLastError() == 0x2745) { @@ -208,6 +259,7 @@ bool write_bytes(const SOCKET socket, const uint8_t *buffer, errno #endif ); + printf("Something went wrong, cannot send()! errno = %s\n", strerror(errno)); #ifdef _MSC_VER } #endif @@ -235,9 +287,12 @@ bool write_multiple_bytes(const SOCKET socket, const uint8_t *buffer, { bool result; - result = write_data32(socket, bytes_to_send); - if (!result) { - return result; + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) + { + result = write_data32(socket, bytes_to_send); + if (!result) { + return result; + } } printf("Platform port Transmit size: "); bytes_to_send = htonl(bytes_to_send); @@ -245,7 +300,12 @@ bool write_multiple_bytes(const SOCKET socket, const uint8_t *buffer, printf("\n"); bytes_to_send = htonl(bytes_to_send); - result = write_bytes(socket, buffer, bytes_to_send); + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) + result = write_bytes(socket, buffer, bytes_to_send); + else { + //mctp kernel do not need to send message type in the payload + result = write_bytes(socket, buffer+1, bytes_to_send-1); + } if (!result) { return result; } @@ -262,19 +322,25 @@ bool send_platform_data(const SOCKET socket, uint32_t command, uint32_t request; uint32_t transport_type; - request = command; - result = write_data32(socket, request); - if (!result) { - return result; - } - printf("Platform port Transmit command: "); - request = htonl(request); - dump_data((uint8_t *)&request, sizeof(uint32_t)); - printf("\n"); - - result = write_data32(socket, m_use_transport_layer); - if (!result) { - return result; + if(m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) + { + request = command; + result = write_data32(socket, request); + if (!result) { + return result; + } + printf("Platform port Transmit command: "); + request = htonl(request); + dump_data((uint8_t *)&request, sizeof(uint32_t)); + printf("\n"); + result = write_data32(socket, m_use_transport_layer); + if (!result) { + return result; + } + printf("Platform port Transmit transport_type: "); + transport_type = ntohl(m_use_transport_layer); + dump_data((uint8_t *)&transport_type, sizeof(uint32_t)); + printf("\n"); } printf("Platform port Transmit transport_type: "); transport_type = ntohl(m_use_transport_layer); diff --git a/spdm_emu/spdm_emu_common/command.h b/spdm_emu/spdm_emu_common/command.h index e90379d..171f09d 100644 --- a/spdm_emu/spdm_emu_common/command.h +++ b/spdm_emu/spdm_emu_common/command.h @@ -22,6 +22,8 @@ #define SOCKET_TRANSPORT_TYPE_MCTP 0x01 #define SOCKET_TRANSPORT_TYPE_PCI_DOE 0x02 #define SOCKET_TRANSPORT_TYPE_TCP 0x03 +/* Support mctp kernel */ +#define SOCKET_TRANSPORT_TYPE_MCTP_KERNEL 0x05 #define SOCKET_TCP_NO_HANDSHAKE 0x00 #define SOCKET_TCP_HANDSHAKE 0x01 diff --git a/spdm_emu/spdm_emu_common/spdm_emu.c b/spdm_emu/spdm_emu_common/spdm_emu.c index ddb32f6..e747de3 100644 --- a/spdm_emu/spdm_emu_common/spdm_emu.c +++ b/spdm_emu/spdm_emu_common/spdm_emu.c @@ -5,7 +5,11 @@ **/ #include "spdm_emu.h" +#include +#include +#include +#define MCTP_MESSAGE_TYPE_SPDM 0x05 /* * EXE_MODE_SHUTDOWN * EXE_MODE_CONTINUE @@ -35,7 +39,7 @@ struct in_addr m_ip_address = { 0x0100007F }; void print_usage(const char *name) { - printf("\n%s [--trans MCTP|PCI_DOE|TCP|NONE]\n", name); + printf("\n%s [--trans MCTP|PCI_DOE|MCTP_KERNEL|NONE]\n", name); printf(" [--tcp_sub HS|NO_HS]\n"); printf(" [--ver 1.0|1.1|1.2|1.3]\n"); printf(" [--sec_ver 1.0|1.1]\n"); @@ -51,6 +55,7 @@ void print_usage(const char *name) " [--req_asym RSASSA_2048|RSASSA_3072|RSASSA_4096|RSAPSS_2048|RSAPSS_3072|RSAPSS_4096|ECDSA_P256|ECDSA_P384|ECDSA_P521|SM2_P256|EDDSA_25519|EDDSA_448]\n"); printf( " [--dhe FFDHE_2048|FFDHE_3072|FFDHE_4096|SECP_256_R1|SECP_384_R1|SECP_521_R1|SM2_P256]\n"); + printf(" [--cmd GET_VERSION]\n"); printf(" [--aead AES_128_GCM|AES_256_GCM|CHACHA20_POLY1305|SM4_128_GCM]\n"); printf(" [--key_schedule HMAC_HASH]\n"); printf(" [--other_param OPAQUE_FMT_1|MULTI_KEY_CONN]\n"); @@ -188,7 +193,8 @@ value_string_entry_t m_transport_value_string_table[] = { { SOCKET_TRANSPORT_TYPE_NONE, "NONE"}, { SOCKET_TRANSPORT_TYPE_MCTP, "MCTP" }, { SOCKET_TRANSPORT_TYPE_PCI_DOE, "PCI_DOE" }, - { SOCKET_TRANSPORT_TYPE_TCP, "TCP"} + { SOCKET_TRANSPORT_TYPE_TCP, "TCP"}, + { SOCKET_TRANSPORT_TYPE_MCTP_KERNEL, "MCTP_KERNEL" }, }; value_string_entry_t m_tcp_subtype_string_table[] = { @@ -420,6 +426,10 @@ value_string_entry_t m_exe_session_string_table[] = { { EXE_SESSION_APP, "APP" }, }; +value_string_entry_t m_cmd_string_table[] = { + { 0x1, "GET_VERSION" }, +}; + bool get_value_from_name(const value_string_entry_t *table, size_t entry_count, const char *name, uint32_t *value) @@ -1241,6 +1251,42 @@ void process_args(char *program_name, int argc, char *argv[]) } } + if (strcmp(argv[0], "--cmd") == 0) { + if (argc >= 2) { + if (!get_value_from_name( + m_cmd_string_table, + LIBSPDM_ARRAY_SIZE(m_cmd_string_table), + argv[1], &data32)) { + printf("invalid --slot_id %s\n", + argv[1]); + print_usage(program_name); + exit(0); + } + m_send_single_spdm_cmd = (uint8_t)data32; + printf("spdm_cmd - 0x%02x\n", m_send_single_spdm_cmd); + argc -= 2; + argv += 2; + continue; + } else { + printf("invalid --spdm_cmd\n"); + print_usage(program_name); + exit(0); + } + } + + if (strcmp(argv[0], "--eid") == 0) { + m_use_eid = (uint8_t)atoi(argv[1]); + if (argc >= 2 && m_use_eid < 256) { + argc -= 2; + argv += 2; + continue; + } else { + printf("invalid --eid\n"); + print_usage(program_name); + exit(0); + } + } + printf("invalid %s\n", argv[0]); print_usage(program_name); exit(0); @@ -1262,49 +1308,74 @@ void process_args(char *program_name, int argc, char *argv[]) bool init_client(SOCKET *sock, uint16_t port) { SOCKET client_socket; - struct sockaddr_in server_addr; - int32_t ret_val; + if (m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) + { + struct sockaddr_mctp addr = { 0 }; + int rc = -1; + client_socket = socket(AF_MCTP, SOCK_DGRAM, 0); + if (-1 == client_socket) + { + printf("Failed to create the socket : RC = %d\n", client_socket); + return false; + } + + addr.smctp_family = AF_MCTP; + addr.smctp_network = MCTP_NET_ANY; + addr.smctp_addr.s_addr = MCTP_ADDR_ANY; + addr.smctp_type = MCTP_MESSAGE_TYPE_SPDM; + addr.smctp_tag = MCTP_TAG_OWNER; + + rc = bind(client_socket, (struct sockaddr *)&addr, sizeof(addr)); + if (rc) + { + printf("Failed to bind socket: RC=%d\n", rc); + return false; + } + } + else { + struct sockaddr_in server_addr; + int32_t ret_val; #ifdef _MSC_VER - WSADATA ws; - if (WSAStartup(MAKEWORD(2, 2), &ws) != 0) { - printf("Init Windows socket Failed - %x\n", WSAGetLastError()); - return false; - } + WSADATA ws; + if (WSAStartup(MAKEWORD(2, 2), &ws) != 0) { + printf("Init Windows socket Failed - %x\n", WSAGetLastError()); + return false; + } #endif - client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (client_socket == INVALID_SOCKET) { - printf("Create socket Failed - %x\n", + client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (client_socket == INVALID_SOCKET) { + printf("Create socket Failed - %x\n", #ifdef _MSC_VER - WSAGetLastError() + WSAGetLastError() #else - errno + errno #endif - ); - return false; - } + ); + return false; + } - server_addr.sin_family = AF_INET; - libspdm_copy_mem(&server_addr.sin_addr.s_addr, sizeof(struct in_addr), &m_ip_address, - sizeof(struct in_addr)); - server_addr.sin_port = htons(port); - libspdm_zero_mem(server_addr.sin_zero, sizeof(server_addr.sin_zero)); + server_addr.sin_family = AF_INET; + libspdm_copy_mem(&server_addr.sin_addr.s_addr, sizeof(struct in_addr), &m_ip_address, + sizeof(struct in_addr)); + server_addr.sin_port = htons(port); + libspdm_zero_mem(server_addr.sin_zero, sizeof(server_addr.sin_zero)); - ret_val = connect(client_socket, (struct sockaddr *)&server_addr, - sizeof(server_addr)); - if (ret_val == SOCKET_ERROR) { - printf("Connect Error - %x\n", + ret_val = connect(client_socket, (struct sockaddr *)&server_addr, + sizeof(server_addr)); + if (ret_val == SOCKET_ERROR) { + printf("Connect Error - %x\n", #ifdef _MSC_VER - WSAGetLastError() + WSAGetLastError() #else - errno + errno #endif - ); - closesocket(client_socket); - return false; + ); + closesocket(client_socket); + return false; + } } - printf("connect success!\n"); *sock = client_socket; diff --git a/spdm_emu/spdm_emu_common/spdm_emu.h b/spdm_emu/spdm_emu_common/spdm_emu.h index 687101d..fc21b18 100644 --- a/spdm_emu/spdm_emu_common/spdm_emu.h +++ b/spdm_emu/spdm_emu_common/spdm_emu.h @@ -21,6 +21,8 @@ #include "command.h" #include "nv_storage.h" +extern uint8_t m_send_single_spdm_cmd; +extern uint8_t m_use_eid; extern uint32_t m_use_transport_layer; extern uint32_t m_use_tcp_handshake; extern uint8_t m_use_version; diff --git a/spdm_emu/spdm_requester_emu/spdm_requester_emu.c b/spdm_emu/spdm_requester_emu/spdm_requester_emu.c index 7b7349d..b39b2ee 100644 --- a/spdm_emu/spdm_requester_emu/spdm_requester_emu.c +++ b/spdm_emu/spdm_requester_emu/spdm_requester_emu.c @@ -5,6 +5,9 @@ **/ #include "spdm_requester_emu.h" +#include +#include +#include uint8_t m_receive_buffer[LIBSPDM_MAX_SENDER_RECEIVER_BUFFER_SIZE]; @@ -18,6 +21,10 @@ extern void *m_scratch_buffer; uint8_t m_other_slot_id = 0; +enum { + GET_VERSION = 1, +}; + void *spdm_client_init(void); libspdm_return_t pci_doe_init_requester(void); @@ -41,6 +48,21 @@ libspdm_return_t do_authentication_via_spdm(void); libspdm_return_t do_session_via_spdm(bool use_psk); libspdm_return_t do_certificate_provising_via_spdm(uint32_t* session_id); +void do_send_single_spdm_cmd(void *spdm_context) +{ + if (spdm_context == NULL) + return; + + switch(m_send_single_spdm_cmd) { + case GET_VERSION: + libspdm_get_version(spdm_context, NULL, NULL); + break; + default: + break; + } + +} + bool platform_client_routine(uint16_t port_number) { SOCKET platform_socket; @@ -74,7 +96,8 @@ bool platform_client_routine(uint16_t port_number) m_socket = platform_socket; } - if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_NONE) { + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_NONE && + m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { response_size = sizeof(m_receive_buffer); result = communicate_platform_data( m_socket, @@ -97,6 +120,12 @@ bool platform_client_routine(uint16_t port_number) } } + if (m_send_single_spdm_cmd == GET_VERSION) { + m_spdm_context = spdm_client_init(); + do_send_single_spdm_cmd(m_spdm_context); + goto done_single_cmd; + } + m_spdm_context = spdm_client_init(); if (m_spdm_context == NULL) { goto done; @@ -172,12 +201,15 @@ bool platform_client_routine(uint16_t port_number) result = true; done: response_size = 0; - if (!communicate_platform_data( - m_socket, SOCKET_SPDM_COMMAND_SHUTDOWN - m_exe_mode, - NULL, 0, &response, &response_size, NULL)) { - return false; - } + if (m_use_transport_layer != SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { + if (!communicate_platform_data( + m_socket, SOCKET_SPDM_COMMAND_SHUTDOWN - m_exe_mode, + NULL, 0, &response, &response_size, NULL)) { + return false; + } + } +done_single_cmd: if (m_spdm_context != NULL) { #if LIBSPDM_FIPS_MODE if (!libspdm_export_fips_selftest_context_from_spdm_context( diff --git a/spdm_emu/spdm_requester_emu/spdm_requester_spdm.c b/spdm_emu/spdm_requester_emu/spdm_requester_spdm.c index f1fc1b7..1deeccf 100644 --- a/spdm_emu/spdm_requester_emu/spdm_requester_spdm.c +++ b/spdm_emu/spdm_requester_emu/spdm_requester_spdm.c @@ -13,6 +13,8 @@ void *m_fips_selftest_context; void *m_scratch_buffer; SOCKET m_socket; +#define GET_VERSION 0x01 + bool communicate_platform_data(SOCKET socket, uint32_t command, const uint8_t *send_buffer, size_t bytes_to_send, uint32_t *response, @@ -179,7 +181,8 @@ void *spdm_client_init(void) libspdm_register_device_io_func(spdm_context, spdm_device_send_message, spdm_device_receive_message); - if (m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP) { + if (m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP || + m_use_transport_layer == SOCKET_TRANSPORT_TYPE_MCTP_KERNEL) { libspdm_register_transport_layer_func( spdm_context, LIBSPDM_MAX_SPDM_MSG_SIZE, @@ -312,6 +315,10 @@ void *spdm_client_init(void) libspdm_set_data(spdm_context, LIBSPDM_DATA_MEL_SPEC, ¶meter, &data8, sizeof(data8)); + if(m_send_single_spdm_cmd == GET_VERSION) { + return m_spdm_context; + } + if (m_load_state_file_name == NULL) { /* Skip if state is loaded*/ status = libspdm_init_connection(