diff --git a/api/api.go b/api/api.go index 45ab7d0f..cf2b9e8b 100644 --- a/api/api.go +++ b/api/api.go @@ -26,7 +26,6 @@ import ( "fmt" "net" - "github.com/fxamacker/cbor/v2" log "github.com/sirupsen/logrus" ) @@ -203,7 +202,7 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) { // // Len uint32 -> Length of the payload to be sent // Type uint32 -> Type of the payload -// payload []byte -> CBOR-encoded payload +// payload []byte -> encoded payload func Receive(conn net.Conn) ([]byte, uint32, error) { // If unix domain sockets are used, set the write buffer size @@ -260,16 +259,6 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { log.Tracef("Received payload length %v", payloadLen) - if msgType == TypeError { - resp := new(SocketError) - err = cbor.Unmarshal(payload.Bytes(), resp) - if err != nil { - return nil, 0, fmt.Errorf("failed to unmarshal error response") - } else { - return nil, 0, fmt.Errorf("server responded with error: %v", resp.Msg) - } - } - return payload.Bytes(), msgType, nil } @@ -277,7 +266,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { // // Len uint32 -> Length of the payload to be sent // Type uint32 -> Type of the payload -// payload []byte -> CBOR-encoded payload +// payload []byte -> encoded payload func Send(conn net.Conn, payload []byte, t uint32) error { if len(payload) > MaxMsgLen { @@ -320,17 +309,3 @@ func Send(conn net.Conn, payload []byte, t uint32) error { return nil } - -func SendError(conn net.Conn, format string, args ...interface{}) error { - msg := fmt.Sprintf(format, args...) - log.Warn(msg) - resp := &SocketError{ - Msg: msg, - } - payload, err := cbor.Marshal(resp) - if err != nil { - return fmt.Errorf("failed to marshal error response: %v", err) - } - - return Send(conn, payload, TypeError) -} diff --git a/cmcd/socket.go b/cmcd/socket.go index 4440472e..1a639158 100644 --- a/cmcd/socket.go +++ b/cmcd/socket.go @@ -29,15 +29,15 @@ import ( "encoding/hex" "encoding/json" - "github.com/fxamacker/cbor/v2" - // local modules "github.com/Fraunhofer-AISEC/cmc/api" + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" "github.com/Fraunhofer-AISEC/cmc/cmc" "github.com/Fraunhofer-AISEC/cmc/generate" "github.com/Fraunhofer-AISEC/cmc/internal" m "github.com/Fraunhofer-AISEC/cmc/measure" "github.com/Fraunhofer-AISEC/cmc/verify" + "github.com/fxamacker/cbor/v2" ) // Server is the server structure @@ -81,33 +81,40 @@ func handleIncoming(conn net.Conn, cmc *cmc.Cmc) { payload, reqType, err := api.Receive(conn) if err != nil { - api.SendError(conn, "Failed to receive: %v", err) + s, err := detectSerialization(payload) + sendError(conn, s, "Failed to receive: %v", err) + return + } + + s, err := detectSerialization(payload) + if err != nil { + log.Errorf("Failed to detect serialization of request: %v", err) return } // Handle request switch reqType { case api.TypeAttest: - attest(conn, payload, cmc) + attest(conn, payload, cmc, s) case api.TypeVerify: - validate(conn, payload, cmc) + validate(conn, payload, cmc, s) case api.TypeMeasure: - measure(conn, payload, cmc) + measure(conn, payload, cmc, s) case api.TypeTLSCert: - tlscert(conn, payload, cmc) + tlscert(conn, payload, cmc, s) case api.TypeTLSSign: - tlssign(conn, payload, cmc) + tlssign(conn, payload, cmc, s) default: - api.SendError(conn, "Invalid Type: %v", reqType) + sendError(conn, s, "Invalid Type: %v", reqType) } } -func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Prover: Received socket attestation request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } @@ -116,9 +123,9 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { } req := new(api.AttestationRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal attestation request: %v", err) + sendError(conn, s, "failed to unmarshal attestation request: %v", err) return } @@ -126,14 +133,14 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { report, err := generate.Generate(req.Nonce, cmc.Metadata, cmc.Drivers, cmc.Serializer) if err != nil { - api.SendError(conn, "failed to generate attestation report: %v", err) + sendError(conn, s, "failed to generate attestation report: %v", err) return } log.Debug("Prover: Signing Attestation Report") r, err := generate.Sign(report, cmc.Drivers[0], cmc.Serializer) if err != nil { - api.SendError(conn, "Failed to sign attestation report: %v", err) + sendError(conn, s, "Failed to sign attestation report: %v", err) return } @@ -141,28 +148,28 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.AttestationResponse{ AttestationReport: r, } - data, err := cbor.Marshal(resp) + data, err := s.Marshal(resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeAttest) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Prover: Finished") } -func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received Connection Request Type 'Verification Request'") req := new(api.VerificationRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "Failed to unmarshal verification request: %v", err) + sendError(conn, s, "Failed to unmarshal verification request: %v", err) return } @@ -173,7 +180,7 @@ func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { log.Debug("Verifier: Marshaling Attestation Result") r, err := json.Marshal(result) if err != nil { - api.SendError(conn, "Verifier: failed to marshal Attestation Result: %v", err) + sendError(conn, s, "Verifier: failed to marshal Attestation Result: %v", err) return } @@ -181,28 +188,28 @@ func validate(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := api.VerificationResponse{ VerificationResult: r, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeVerify) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Verifier: Finished") } -func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received Connection Request Type 'Measure Request'") req := new(api.MeasureRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "Failed to unmarshal measure request: %v", err) + sendError(conn, s, "Failed to unmarshal measure request: %v", err) return } @@ -225,48 +232,48 @@ func measure(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := api.MeasureResponse{ Success: success, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeMeasure) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Measurer: Finished") } -func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received TLS sign request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } // Parse the message and return the TLS signing request req := new(api.TLSSignRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal payload: %v", err) + sendError(conn, s, "failed to unmarshal payload: %v", err) return } // Get signing options from request opts, err := api.HashToSignerOpts(req.Hashtype, req.PssOpts) if err != nil { - api.SendError(conn, "failed to choose requested hash function: %v", err) + sendError(conn, s, "failed to choose requested hash function: %v", err) return } // Get key handle from (hardware) interface tlsKeyPriv, _, err := cmc.Drivers[0].GetSigningKeys() if err != nil { - api.SendError(conn, "failed to get IK: %v", err) + sendError(conn, s, "failed to get IK: %v", err) return } @@ -274,7 +281,7 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { log.Trace("TLSSign using opts: ", opts) signature, err := tlsKeyPriv.(crypto.Signer).Sign(rand.Reader, req.Content, opts) if err != nil { - api.SendError(conn, "failed to sign: %v", err) + sendError(conn, s, "failed to sign: %v", err) return } @@ -282,34 +289,34 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.TLSSignResponse{ SignedContent: signature, } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeTLSSign) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Performed signing") } -func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { +func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc, s ar.Serializer) { log.Debug("Received TLS cert request") if len(cmc.Drivers) == 0 { - api.SendError(conn, "no valid signers configured") + sendError(conn, s, "no valid signers configured") return } // Parse the message and return the TLS signing request req := new(api.TLSSignRequest) - err := cbor.Unmarshal(payload, req) + err := s.Unmarshal(payload, req) if err != nil { - api.SendError(conn, "failed to unmarshal payload: %v", err) + sendError(conn, s, "failed to unmarshal payload: %v", err) return } // TODO ID is currently not used @@ -318,7 +325,7 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { // Retrieve certificates certChain, err := cmc.Drivers[0].GetCertChain() if err != nil { - api.SendError(conn, "failed to get certchain: %v", err) + sendError(conn, s, "failed to get certchain: %v", err) return } @@ -326,16 +333,44 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { resp := &api.TLSCertResponse{ Certificate: internal.WriteCertsPem(certChain), } - data, err := cbor.Marshal(&resp) + data, err := s.Marshal(&resp) if err != nil { - api.SendError(conn, "failed to marshal message: %v", err) + sendError(conn, s, "failed to marshal message: %v", err) return } err = api.Send(conn, data, api.TypeTLSCert) if err != nil { - api.SendError(conn, "failed to send: %v", err) + sendError(conn, s, "failed to send: %v", err) } log.Debug("Obtained TLS cert") } + +func sendError(conn net.Conn, s ar.Serializer, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + log.Warn(msg) + resp := &api.SocketError{ + Msg: msg, + } + payload, err := s.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal error response: %v", err) + } + + return api.Send(conn, payload, api.TypeError) +} + +func detectSerialization(payload []byte) (ar.Serializer, error) { + log.Trace("Detecting serialization of request..") + if json.Valid(payload) { + log.Trace("Detected JSON serialization") + return ar.JsonSerializer{}, nil + } else if err := cbor.Valid(payload); err == nil { + log.Trace("Detected CBOR serialization") + return ar.CborSerializer{}, nil + } else { + log.Trace("Unable to detect AR serialization format") + return nil, fmt.Errorf("failed to detect request serialization") + } +} diff --git a/testtool/config.go b/testtool/config.go index 001f66a9..033196c1 100644 --- a/testtool/config.go +++ b/testtool/config.go @@ -25,6 +25,7 @@ import ( "strings" "time" + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" "github.com/Fraunhofer-AISEC/cmc/cmc" "github.com/Fraunhofer-AISEC/cmc/internal" "github.com/sirupsen/logrus" @@ -38,6 +39,11 @@ const ( var ( apis = map[string]Api{} + serializers = map[string]ar.Serializer{ + "json": ar.JsonSerializer{}, + "cbor": ar.CborSerializer{}, + } + logLevels = map[string]logrus.Level{ "panic": logrus.PanicLevel, "fatal": logrus.FatalLevel, @@ -83,6 +89,7 @@ type config struct { Header []string `json:"header"` Method string `json:"method"` Data string `json:"data"` + Serializer string `json:"socketApiSerializer"` // Only Lib API ProvAddr string `json:"provServerAddr"` Metadata []string `json:"metadata"` @@ -98,30 +105,32 @@ type config struct { CtrRootfs string `json:"ctrRootfs"` CtrConfig string `json:"ctrConfig"` - ca []byte - policies []byte - api Api - interval time.Duration + ca []byte + policies []byte + api Api + interval time.Duration + serializer ar.Serializer } const ( // Generic flags - configFlag = "config" - modeFlag = "mode" - addrFlag = "addr" - cmcFlag = "cmc" - reportFlag = "report" - resultFlag = "result" - nonceFlag = "nonce" - caFlag = "ca" - policiesFlag = "policies" - apiFlag = "api" - networkFlag = "network" - mtlsFlag = "mtls" - attestFlag = "attest" - logFlag = "log" - publishFlag = "publish" - intervalFlag = "interval" + configFlag = "config" + modeFlag = "mode" + addrFlag = "addr" + cmcFlag = "cmc" + reportFlag = "report" + resultFlag = "result" + nonceFlag = "nonce" + caFlag = "ca" + policiesFlag = "policies" + apiFlag = "api" + networkFlag = "network" + mtlsFlag = "mtls" + attestFlag = "attest" + logFlag = "log" + publishFlag = "publish" + intervalFlag = "interval" + serializerFlag = "serializer" // Only lib API provAddrFlag = "prov" metadataFlag = "metadata" @@ -165,6 +174,7 @@ func getConfig() *config { interval := flag.String(intervalFlag, "", "Interval at which connectors will be attested. If set to <=0, attestation will only be"+ " done once") + serializer := flag.String(serializerFlag, "", "Serializer to be used for socket API (JSON or CBOR)") // Lib API flags provAddr := flag.String(provAddrFlag, "", "Address of the provisioning server (only for libapi)") @@ -205,6 +215,7 @@ func getConfig() *config { IntervalStr: "0s", Attest: "mutual", Method: "GET", + Serializer: "cbor", } // Obtain custom configuration from file if specified @@ -266,6 +277,9 @@ func getConfig() *config { if internal.FlagPassed(intervalFlag) { c.IntervalStr = *interval } + if internal.FlagPassed(serializerFlag) { + c.Serializer = *serializer + } // Lib API flags if internal.FlagPassed(provAddrFlag) { c.ProvAddr = *provAddr @@ -356,11 +370,19 @@ func getConfig() *config { } } + // Get Serializer + log.Tracef("Getting serializer %v", c.Serializer) + c.serializer, ok = serializers[strings.ToLower(c.Serializer)] + if !ok { + flag.Usage() + log.Fatalf("Serializer %v is not implemented", c.Serializer) + } + // Get API c.api, ok = apis[strings.ToLower(c.Api)] if !ok { flag.Usage() - log.Fatalf("API %v is not implemented\n", c.Api) + log.Fatalf("API %v is not implemented", c.Api) } return c @@ -433,7 +455,7 @@ func printConfig(c *config) { log.Debugf("\tPoliciesFile: %v", c.PoliciesFile) } if strings.EqualFold(c.Api, "socket") { - log.Debugf("\tApi (Network): %v (%v)", c.Api, c.Network) + log.Debugf("\tApi (Network): %v (%v, %v)", c.Api, c.Network, c.Serializer) } else { log.Debugf("\tApi : %v", c.Api) } diff --git a/testtool/socket.go b/testtool/socket.go index 509cee5f..aec7ccb4 100644 --- a/testtool/socket.go +++ b/testtool/socket.go @@ -24,9 +24,8 @@ import ( "net" "os" - "github.com/fxamacker/cbor/v2" - // local modules + ar "github.com/Fraunhofer-AISEC/cmc/attestationreport" m "github.com/Fraunhofer-AISEC/cmc/measure" "github.com/Fraunhofer-AISEC/cmc/api" @@ -62,7 +61,7 @@ func (a SocketApi) generate(c *config) { } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { log.Fatalf("failed to marshal payload: %v", err) } @@ -74,14 +73,15 @@ func (a SocketApi) generate(c *config) { } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response var attestationResp api.AttestationResponse - err = cbor.Unmarshal(payload, &attestationResp) + err = c.serializer.Unmarshal(payload, &attestationResp) if err != nil { log.Fatalf("failed to unmarshal response") } @@ -122,7 +122,7 @@ func (a SocketApi) verify(c *config) { Policies: c.policies, } - resp, err := verifySocketRequest(c.Network, c.CmcAddr, req) + resp, err := verifySocketRequest(c, req) if err != nil { log.Fatalf("Failed to verify: %v", err) } @@ -156,7 +156,7 @@ func (a SocketApi) measure(c *config) { RootfsSha256: rootfsHash, } - resp, err := measureSocketRequest(c.Network, c.CmcAddr, req) + resp, err := measureSocketRequest(c, req) if err != nil { log.Fatalf("Failed to measure: %v", err) } @@ -188,19 +188,19 @@ func (a SocketApi) iothub(c *config) { log.Fatalf("IoT hub not implemented for sockets API") } -func verifySocketRequest(network, addr string, req *api.VerificationRequest, +func verifySocketRequest(c *config, req *api.VerificationRequest, ) (*api.VerificationResponse, error) { - log.Tracef("Connecting via %v socket to %v", network, addr) + log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr) // Establish connection - conn, err := net.Dial(network, addr) + conn, err := net.Dial(c.Network, c.CmcAddr) if err != nil { return nil, fmt.Errorf("error dialing: %v", err) } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal payload: %v", err) } @@ -212,14 +212,15 @@ func verifySocketRequest(network, addr string, req *api.VerificationRequest, } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response verifyResp := new(api.VerificationResponse) - err = cbor.Unmarshal(payload, verifyResp) + err = c.serializer.Unmarshal(payload, verifyResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal response") } @@ -227,19 +228,19 @@ func verifySocketRequest(network, addr string, req *api.VerificationRequest, return verifyResp, nil } -func measureSocketRequest(network, addr string, req *api.MeasureRequest, +func measureSocketRequest(c *config, req *api.MeasureRequest, ) (*api.MeasureResponse, error) { - log.Tracef("Connecting via %v socket to %v", network, addr) + log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr) // Establish connection - conn, err := net.Dial(network, addr) + conn, err := net.Dial(c.Network, c.CmcAddr) if err != nil { return nil, fmt.Errorf("error dialing: %v", err) } // Marshal payload - payload, err := cbor.Marshal(req) + payload, err := c.serializer.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal payload: %v", err) } @@ -251,17 +252,30 @@ func measureSocketRequest(network, addr string, req *api.MeasureRequest, } // Read reply - payload, _, err = api.Receive(conn) + payload, msgType, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + checkError(msgType, payload, c.serializer) // Unmarshal attestation response measureResp := new(api.MeasureResponse) - err = cbor.Unmarshal(payload, measureResp) + err = c.serializer.Unmarshal(payload, measureResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal response") } return measureResp, nil } + +func checkError(t uint32, payload []byte, s ar.Serializer) { + if t == api.TypeError { + resp := new(api.SocketError) + err := s.Unmarshal(payload, resp) + if err != nil { + log.Fatal("failed to unmarshal error response") + } else { + log.Fatalf("server responded with error: %v", resp.Msg) + } + } +}