From 4d42b87339f042f71d04d658195b155a4a713804 Mon Sep 17 00:00:00 2001 From: Simon Ott Date: Thu, 26 Sep 2024 15:42:10 +0000 Subject: [PATCH] treewide: implement serializer interface for socket API For the socket API, the cmcd and testtool used CBOR serialization for the payloads (regardless of the serialization of metadata and attestation reports). This commit introduces the socketApiSerializer configuration for the testtool and adds automatic serialization detection for socket API requests in the cmcd. This makes it possible to interact with the cmcd socket API without requiring CBOR support. Signed-off-by: Simon Ott --- api/api.go | 29 +--------- cmcd/socket.go | 133 ++++++++++++++++++++++++++++----------------- testtool/config.go | 66 ++++++++++++++-------- testtool/socket.go | 52 +++++++++++------- 4 files changed, 163 insertions(+), 117 deletions(-) diff --git a/api/api.go b/api/api.go index 45ab7d0f..cf2b9e8b 100644 --- a/api/api.go +++ b/api/api.go @@ -26,7 +26,6 @@ import ( "fmt" "net" - "github.com/fxamacker/cbor/v2" log "github.com/sirupsen/logrus" ) @@ -203,7 +202,7 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) { // // Len uint32 -> Length of the payload to be sent // Type uint32 -> Type of the payload -// payload []byte -> CBOR-encoded payload +// payload []byte -> encoded payload func Receive(conn net.Conn) ([]byte, uint32, error) { // If unix domain sockets are used, set the write buffer size @@ -260,16 +259,6 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { log.Tracef("Received payload length %v", payloadLen) - if msgType == TypeError { - resp := new(SocketError) - err = cbor.Unmarshal(payload.Bytes(), resp) - if err != nil { - return nil, 0, fmt.Errorf("failed to unmarshal error response") - } else { - return nil, 0, fmt.Errorf("server responded with error: %v", resp.Msg) - } - } - return payload.Bytes(), msgType, nil } @@ -277,7 +266,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { // // Len uint32 -> Length of the payload to be sent // Type uint32 -> Type of the payload -// payload []byte -> CBOR-encoded payload +// payload []byte -> encoded payload func Send(conn net.Conn, payload []byte, t uint32) error { if len(payload) > MaxMsgLen { @@ -320,17 +309,3 @@ func Send(conn net.Conn, payload []byte, t uint32) error { return nil } - -func SendError(conn net.Conn, format string, args ...interface{}) error { - msg := fmt.Sprintf(format, args...) - log.Warn(msg) - resp := &SocketError{ - Msg: msg, - } - payload, err := cbor.Marshal(resp) - if err != nil { - return fmt.Errorf("failed to marshal error response: %v", err) - } - - return Send(conn, payload, TypeError) -} diff --git a/cmcd/socket.go b/cmcd/socket.go index 4440472e..1a639158 100644 --- a/cmcd/socket.go +++ b/cmcd/socket.go @@ -29,15 +29,15 @@ import ( "encoding/hex" "encoding/json" - "github.com/fxamacker/cbor/v2" - // local modules "github.com/Fraunhofer-AISEC/cmc/api" + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" "github.com/Fraunhofer-AISEC/cmc/cmc" "github.com/Fraunhofer-AISEC/cmc/generate" "github.com/Fraunhofer-AISEC/cmc/internal" m "github.com/Fraunhofer-AISEC/cmc/measure" "github.com/Fraunhofer-AISEC/cmc/verify" + "github.com/fxamacker/cbor/v2" ) // Server is the server structure @@ -81,33 +81,40 @@ func handleIncoming(conn net.Conn, cmc *cmc.Cmc) { payload, reqType, err := api.Receive(conn) if err != nil { - api.SendError(conn, "Failed to receive: %v", err) + s, err := detectSerialization(payload) + sendError(conn, s, "Failed to receive: %v", err) + return + } + + s, err := detectSerialization(payload) + if err != nil { + log.Errorf("Failed to detect serialization of request: %v", err) return } // Handle request switch reqType { case api.TypeAttest: - attest(conn, payload, cmc) + attest(conn, payload, cmc, s) case api.TypeVerify: - validate(conn, payload, cmc) + validate(conn, payload, cmc, s) case api.TypeMeasure: - measure(conn, payload, cmc) + measure(conn, payload, cmc, s) case api.TypeTLSCert: - tlscert(conn, payload, cmc) + tlscert(conn, payload, cmc, s) case api.TypeTLSSign: - tlssign(conn, payload, cmc) + tlssign(conn, payload, cmc, s) default: - api.SendError(conn, "Invalid Type: %v", reqType) + sendError(conn, s, "Invalid Type: %v", reqType) } } -func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Prover: Received socket attestation request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } @@ -116,9 +123,9 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { } req := new(api.AttestationRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal attestation request: %v", err) + sendError(conn, s, "failed to unmarshal attestation request: %v", err) return } @@ -126,14 +133,14 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { report, err := generate.Generate(req.Nonce, cmc.Metadata, cmc.Drivers, cmc.Serializer) if err != nil { - api.SendError(conn, "failed to generate attestation report: %v", err) + sendError(conn, s, "failed to generate attestation report: %v", err) return } log.Debug("Prover: Signing Attestation Report") r, err := generate.Sign(report, cmc.Drivers[0], cmc.Serializer) if err != nil { - api.SendError(conn, "Failed to sign attestation report: %v", err) + sendError(conn, s, "Failed to sign attestation report: %v", err) return } @@ -141,28 +148,28 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.AttestationResponse{ AttestationReport: r, } - data, err := cbor.Marshal(resp) + data, err := s.Marshal(resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeAttest) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Prover: Finished") } -func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received Connection Request Type 'Verification Request'") req := new(api.VerificationRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "Failed to unmarshal verification request: %v", err) + sendError(conn, s, "Failed to unmarshal verification request: %v", err) return } @@ -173,7 +180,7 @@ func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { log.Debug("Verifier: Marshaling Attestation Result") r, err := json.Marshal(result) if err != nil { - api.SendError(conn, "Verifier: failed to marshal Attestation Result: %v", err) + sendError(conn, s, "Verifier: failed to marshal Attestation Result: %v", err) return } @@ -181,28 +188,28 @@ func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := api.VerificationResponse{ VerificationResult: r, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeVerify) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Verifier: Finished") } -func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received Connection Request Type 'Measure Request'") req := new(api.MeasureRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "Failed to unmarshal measure request: %v", err) + sendError(conn, s, "Failed to unmarshal measure request: %v", err) return } @@ -225,48 +232,48 @@ func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := api.MeasureResponse{ Success: success, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeMeasure) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Measurer: Finished") } -func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received TLS sign request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } // Parse the message and return the TLS signing request req := new(api.TLSSignRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal payload: %v", err) + sendError(conn, s, "failed to unmarshal payload: %v", err) return } // Get signing options from request opts, err := api.HashToSignerOpts(req.Hashtype, req.PssOpts) if err != nil { - api.SendError(conn, "failed to choose requested hash function: %v", err) + sendError(conn, s, "failed to choose requested hash function: %v", err) return } // Get key handle from (hardware) interface tlsKeyPriv, _, err := cmc.Drivers[0].GetSigningKeys() if err != nil { - api.SendError(conn, "failed to get IK: %v", err) + sendError(conn, s, "failed to get IK: %v", err) return } @@ -274,7 +281,7 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { log.Trace("TLSSign using opts: ", opts) signature, err := tlsKeyPriv.(crypto.Signer).Sign(rand.Reader, req.Content, opts) if err != nil { - api.SendError(conn, "failed to sign: %v", err) + sendError(conn, s, "failed to sign: %v", err) return } @@ -282,34 +289,34 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.TLSSignResponse{ SignedContent: signature, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeTLSSign) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Performed signing") } -func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received TLS cert request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } // Parse the message and return the TLS signing request req := new(api.TLSSignRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal payload: %v", err) + sendError(conn, s, "failed to unmarshal payload: %v", err) return } // TODO ID is currently not used @@ -318,7 +325,7 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { // Retrieve certificates certChain, err := cmc.Drivers[0].GetCertChain() if err != nil { - api.SendError(conn, "failed to get certchain: %v", err) + sendError(conn, s, "failed to get certchain: %v", err) return } @@ -326,16 +333,44 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.TLSCertResponse{ Certificate: internal.WriteCertsPem(certChain), } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeTLSCert) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Obtained TLS cert") } + +func sendError(conn net.Conn, s ar.Serializer, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + log.Warn(msg) + resp := &api.SocketError{ + Msg: msg, + } + payload, err := s.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal error response: %v", err) + } + + return api.Send(conn, payload, api.TypeError) +} + +func detectSerialization(payload []byte) (ar.Serializer, error) { + log.Trace("Detecting serialization of request..") + if json.Valid(payload) { + log.Trace("Detected JSON serialization") + return ar.JsonSerializer{}, nil + } else if err := cbor.Valid(payload); err == nil { + log.Trace("Detected CBOR serialization") + return ar.CborSerializer{}, nil + } else { + log.Trace("Unable to detect AR serialization format") + return nil, fmt.Errorf("failed to detect request serialization") + } +} diff --git a/testtool/config.go b/testtool/config.go index 001f66a9..033196c1 100644 --- a/testtool/config.go +++ b/testtool/config.go @@ -25,6 +25,7 @@ import ( "strings" "time" + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" "github.com/Fraunhofer-AISEC/cmc/cmc" "github.com/Fraunhofer-AISEC/cmc/internal" "github.com/sirupsen/logrus" @@ -38,6 +39,11 @@ const ( var ( apis = map[string]Api{} + serializers = map[string]ar.Serializer{ + "json": ar.JsonSerializer{}, + "cbor": ar.CborSerializer{}, + } + logLevels = map[string]logrus.Level{ "panic": logrus.PanicLevel, "fatal": logrus.FatalLevel, @@ -83,6 +89,7 @@ type config struct { Header []string `json:"header"` Method string `json:"method"` Data string `json:"data"` + Serializer string `json:"socketApiSerializer"` // Only Lib API ProvAddr string `json:"provServerAddr"` Metadata []string `json:"metadata"` @@ -98,30 +105,32 @@ type config struct { CtrRootfs string `json:"ctrRootfs"` CtrConfig string `json:"ctrConfig"` - ca []byte - policies []byte - api Api - interval time.Duration + ca []byte + policies []byte + api Api + interval time.Duration + serializer ar.Serializer } const ( // Generic flags - configFlag = "config" - modeFlag = "mode" - addrFlag = "addr" - cmcFlag = "cmc" - reportFlag = "report" - resultFlag = "result" - nonceFlag = "nonce" - caFlag = "ca" - policiesFlag = "policies" - apiFlag = "api" - networkFlag = "network" - mtlsFlag = "mtls" - attestFlag = "attest" - logFlag = "log" - publishFlag = "publish" - intervalFlag = "interval" + configFlag = "config" + modeFlag = "mode" + addrFlag = "addr" + cmcFlag = "cmc" + reportFlag = "report" + resultFlag = "result" + nonceFlag = "nonce" + caFlag = "ca" + policiesFlag = "policies" + apiFlag = "api" + networkFlag = "network" + mtlsFlag = "mtls" + attestFlag = "attest" + logFlag = "log" + publishFlag = "publish" + intervalFlag = "interval" + serializerFlag = "serializer" // Only lib API provAddrFlag = "prov" metadataFlag = "metadata" @@ -165,6 +174,7 @@ func getConfig() *config { interval := flag.String(intervalFlag, "", "Interval at which connectors will be attested. If set to <=0, attestation will only be"+ " done once") + serializer := flag.String(serializerFlag, "", "Serializer to be used for socket API (JSON or CBOR)") // Lib API flags provAddr := flag.String(provAddrFlag, "", "Address of the provisioning server (only for libapi)") @@ -205,6 +215,7 @@ func getConfig() *config { IntervalStr: "0s", Attest: "mutual", Method: "GET", + Serializer: "cbor", } // Obtain custom configuration from file if specified @@ -266,6 +277,9 @@ func getConfig() *config { if internal.FlagPassed(intervalFlag) { c.IntervalStr = *interval } + if internal.FlagPassed(serializerFlag) { + c.Serializer = *serializer + } // Lib API flags if internal.FlagPassed(provAddrFlag) { c.ProvAddr = *provAddr @@ -356,11 +370,19 @@ func getConfig() *config { } } + // Get Serializer + log.Tracef("Getting serializer %v", c.Serializer) + c.serializer, ok = serializers[strings.ToLower(c.Serializer)] + if !ok { + flag.Usage() + log.Fatalf("Serializer %v is not implemented", c.Serializer) + } + // Get API c.api, ok = apis[strings.ToLower(c.Api)] if !ok { flag.Usage() - log.Fatalf("API %v is not implemented\n", c.Api) + log.Fatalf("API %v is not implemented", c.Api) } return c @@ -433,7 +455,7 @@ func printConfig(c *config) { log.Debugf("\tPoliciesFile: %v", c.PoliciesFile) } if strings.EqualFold(c.Api, "socket") { - log.Debugf("\tApi (Network): %v (%v)", c.Api, c.Network) + log.Debugf("\tApi (Network): %v (%v, %v)", c.Api, c.Network, c.Serializer) } else { log.Debugf("\tApi : %v", c.Api) } diff --git a/testtool/socket.go b/testtool/socket.go index 509cee5f..aec7ccb4 100644 --- a/testtool/socket.go +++ b/testtool/socket.go @@ -24,9 +24,8 @@ import ( "net" "os" - "github.com/fxamacker/cbor/v2" - // local modules + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" m "github.com/Fraunhofer-AISEC/cmc/measure" "github.com/Fraunhofer-AISEC/cmc/api" @@ -62,7 +61,7 @@ func (a SocketApi) generate(c *config) { } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { log.Fatalf("failed to marshal payload: %v", err) } @@ -74,14 +73,15 @@ func (a SocketApi) generate(c *config) { } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response var attestationResp api.AttestationResponse - err = cbor.Unmarshal(payload, &attestationResp) + err = c.serializer.Unmarshal(payload, &attestationResp) if err != nil { log.Fatalf("failed to unmarshal response") } @@ -122,7 +122,7 @@ func (a SocketApi) verify(c *config) { Policies: c.policies, } - resp, err := verifySocketRequest(c.Network, c.CmcAddr, req) + resp, err := verifySocketRequest(c, req) if err != nil { log.Fatalf("Failed to verify: %v", err) } @@ -156,7 +156,7 @@ func (a SocketApi) measure(c *config) { RootfsSha256: rootfsHash, } - resp, err := measureSocketRequest(c.Network, c.CmcAddr, req) + resp, err := measureSocketRequest(c, req) if err != nil { log.Fatalf("Failed to measure: %v", err) } @@ -188,19 +188,19 @@ func (a SocketApi) iothub(c *config) { log.Fatalf("IoT hub not implemented for sockets API") } -func verifySocketRequest(network, addr string, req *api.VerificationRequest, +func verifySocketRequest(c *config, req *api.VerificationRequest, ) (*api.VerificationResponse, error) { - log.Tracef("Connecting via %v socket to %v", network, addr) + log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr) // Establish connection - conn, err := net.Dial(network, addr) + conn, err := net.Dial(c.Network, c.CmcAddr) if err != nil { return nil, fmt.Errorf("error dialing: %v", err) } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal payload: %v", err) } @@ -212,14 +212,15 @@ func verifySocketRequest(network, addr string, req *api.VerificationRequest, } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response verifyResp := new(api.VerificationResponse) - err = cbor.Unmarshal(payload, verifyResp) + err = c.serializer.Unmarshal(payload, verifyResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal response") } @@ -227,19 +228,19 @@ func verifySocketRequest(network, addr string, req *api.VerificationRequest, return verifyResp, nil } -func measureSocketRequest(network, addr string, req *api.MeasureRequest, +func measureSocketRequest(c *config, req *api.MeasureRequest, ) (*api.MeasureResponse, error) { - log.Tracef("Connecting via %v socket to %v", network, addr) + log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr) // Establish connection - conn, err := net.Dial(network, addr) + conn, err := net.Dial(c.Network, c.CmcAddr) if err != nil { return nil, fmt.Errorf("error dialing: %v", err) } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal payload: %v", err) } @@ -251,17 +252,30 @@ func measureSocketRequest(network, addr string, req *api.MeasureRequest, } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response measureResp := new(api.MeasureResponse) - err = cbor.Unmarshal(payload, measureResp) + err = c.serializer.Unmarshal(payload, measureResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal response") } return measureResp, nil } + +func checkError(t uint32, payload []byte, s ar.Serializer) { + if t == api.TypeError { + resp := new(api.SocketError) + err := s.Unmarshal(payload, resp) + if err != nil { + log.Fatal("failed to unmarshal error response") + } else { + log.Fatalf("server responded with error: %v", resp.Msg) + } + } +}