From 36673bdcbcd0cd3d8306eafc74c095f852b91612 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= Date: Thu, 27 Jul 2023 23:04:09 +0200 Subject: [PATCH] feat(plc4go/opcua): some progress on secure channel --- plc4go/internal/opcua/EncryptionHandler.go | 23 +- plc4go/internal/opcua/MessageCodec.go | 13 + plc4go/internal/opcua/SecureChannel.go | 619 ++++++++++++++++++++- 3 files changed, 641 insertions(+), 14 deletions(-) diff --git a/plc4go/internal/opcua/EncryptionHandler.go b/plc4go/internal/opcua/EncryptionHandler.go index 738d839d409..4cb43c3de06 100644 --- a/plc4go/internal/opcua/EncryptionHandler.go +++ b/plc4go/internal/opcua/EncryptionHandler.go @@ -19,7 +19,10 @@ package opcua -import readWriteModel "github.com/apache/plc4x/plc4go/protocols/opcua/readwrite/model" +import ( + "crypto/x509" + readWriteModel "github.com/apache/plc4x/plc4go/protocols/opcua/readwrite/model" +) type EncryptionHandler struct { // TODO: implement me @@ -29,6 +32,22 @@ func NewEncryptionHandler(any, []byte, string) *EncryptionHandler { return &EncryptionHandler{} } -func (h *EncryptionHandler) encodeMessage(messageRequest readWriteModel.OpcuaMessageRequest, bytes []byte) []byte { +func (h *EncryptionHandler) encodeMessage(messageRequest readWriteModel.MessagePDU, bytes []byte) []byte { + return nil +} + +func (h *EncryptionHandler) decodeMessage(apu readWriteModel.OpcuaAPU) readWriteModel.OpcuaAPUExactly { + return nil +} + +func (h *EncryptionHandler) getCertificateX509(senderCertificate []byte) x509.Certificate { + return x509.Certificate{} +} + +func (h *EncryptionHandler) setServerCertificate(certificateX509 x509.Certificate) { + return +} + +func (h *EncryptionHandler) encryptPassword(password []byte) []byte { return nil } diff --git a/plc4go/internal/opcua/MessageCodec.go b/plc4go/internal/opcua/MessageCodec.go index ec3d6da52d5..82b45c155ec 100644 --- a/plc4go/internal/opcua/MessageCodec.go +++ b/plc4go/internal/opcua/MessageCodec.go @@ -61,6 +61,19 @@ func (m *MessageCodec) GetCodec() spi.MessageCodec { return m } +func (m *MessageCodec) Connect() error { + return m.ConnectWithContext(context.Background()) +} + +func (m *MessageCodec) ConnectWithContext(ctx context.Context) error { + if err := m.DefaultCodec.ConnectWithContext(ctx); err != nil { + return errors.Wrap(err, "error connecting default codec") + } + m.log.Debug().Msg("Opcua Driver running in ACTIVE mode.") + m.channel.onConnect(ctx, m) + return nil +} + func (m *MessageCodec) Send(message spi.Message) error { m.log.Trace().Msgf("Sending message\n%s", message) // Cast the message to the correct type of struct diff --git a/plc4go/internal/opcua/SecureChannel.go b/plc4go/internal/opcua/SecureChannel.go index a9314b81c73..127c497a064 100644 --- a/plc4go/internal/opcua/SecureChannel.go +++ b/plc4go/internal/opcua/SecureChannel.go @@ -20,12 +20,18 @@ package opcua import ( + "bytes" "context" + "encoding/binary" + "github.com/apache/plc4x/plc4go/spi" + "github.com/pkg/errors" "github.com/rs/zerolog" + "golang.org/x/exp/slices" "math/rand" "net" "net/url" "regexp" + "strconv" "sync/atomic" "time" @@ -83,7 +89,7 @@ var ( type SecureChannel struct { sessionName string clientNonce []byte - requestHandleGenerator atomic.Int32 + requestHandleGenerator atomic.Uint32 policyId readWriteModel.PascalString tokenType readWriteModel.UserTokenType discovery bool @@ -106,14 +112,14 @@ type SecureChannel struct { channelId atomic.Int32 tokenId atomic.Int32 authenticationToken readWriteModel.NodeIdTypeDefinition - context *MessageCodec // TODO: not sure if we need the codec here + codec *MessageCodec channelTransactionManager *SecureChannelTransactionManager lifetime int - keepAlive chan struct{} // TODO: check if this is the right thing + keepAlive func() sendBufferSize int maxMessageSize int endpoints []string - senderSequenceNumber atomic.Int64 + senderSequenceNumber atomic.Int32 log zerolog.Logger } @@ -164,10 +170,6 @@ func NewSecureChannel(log zerolog.Logger, ctx DriverContext, configuration Confi return s } -func (s *SecureChannel) Something() { - -} - func (s *SecureChannel) getAuthenticationToken() readWriteModel.NodeId { return readWriteModel.NewNodeId(s.authenticationToken) } @@ -202,11 +204,604 @@ func (s *SecureChannel) submit(ctx context.Context, codec *MessageCodec, errorDi } requestConsumer := func(transactionId int32) { - // bos - //codec.SendRequest(ctx,) // TODO: feed - _ = apu + var messageBuffer []byte + if err := codec.SendRequest(ctx, apu, + func(message spi.Message) bool { + opcuaAPU, ok := message.(readWriteModel.OpcuaAPUExactly) + if !ok { + s.log.Debug().Type("type", message).Msg("Not relevant") + return false + } + opcuaAPU = s.encryptionHandler.decodeMessage(opcuaAPU) + messagePDU := opcuaAPU.GetMessage() + opcuaResponse, ok := messagePDU.(readWriteModel.OpcuaMessageResponseExactly) + if !ok { + s.log.Debug().Type("type", message).Msg("Not relevant") + return false + } + if requestId := opcuaResponse.GetRequestId(); requestId != transactionId { + s.log.Debug().Int32("requestId", requestId).Int32("transactionId", transactionId).Msg("Not relevant") + return false + } else { + messageBuffer = opcuaResponse.GetMessage() + if !(s.senderSequenceNumber.Add(1) == (opcuaResponse.GetSequenceNumber())) { + s.log.Error().Msgf("Sequence number isn't as expected, we might have missed a packet. - %d != %d", s.senderSequenceNumber.Add(1), opcuaResponse.GetSequenceNumber()) + // TODO: where to dispatch the disconnect too + // codec.fireDisconnected() + } + } + return true + }, + func(message spi.Message) error { + opcuaAPU := message.(readWriteModel.OpcuaAPU) + opcuaAPU = s.encryptionHandler.decodeMessage(opcuaAPU) + messagePDU := opcuaAPU.GetMessage() + opcuaResponse := messagePDU.(readWriteModel.OpcuaMessageResponse) + if opcuaResponse.GetChunk() == (FINAL_CHUNK) { + s.tokenId.Store(opcuaResponse.GetSecureTokenId()) + s.channelId.Store(opcuaResponse.GetSecureChannelId()) + + consumer(messageBuffer) + } + return nil + }, + func(err error) error { + errorDispatcher(err) + return nil + }, + REQUEST_TIMEOUT); err != nil { + errorDispatcher(err) + } } s.log.Debug().Msgf("Submitting Transaction to TransactionManager %v", transactionId) - s.channelTransactionManager.submit(requestConsumer, transactionId) + if err := s.channelTransactionManager.submit(requestConsumer, transactionId); err != nil { + s.log.Debug().Err(err).Msg("error submitting") + } +} + +func (s *SecureChannel) onConnect(ctx context.Context, codec *MessageCodec) { + // Only the TCP transport supports login. + s.log.Debug().Msg("Opcua Driver running in ACTIVE mode.") + s.codec = codec + + hello := readWriteModel.NewOpcuaHelloRequest(FINAL_CHUNK, + VERSION, + DEFAULT_RECEIVE_BUFFER_SIZE, + DEFAULT_SEND_BUFFER_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + DEFAULT_MAX_CHUNK_COUNT, + s.endpoint) + + requestConsumer := func(transactionId int32) { + if err := codec.SendRequest( + ctx, + hello, + func(message spi.Message) bool { + opcuaAPU, ok := message.(readWriteModel.OpcuaAPUExactly) + if !ok { + s.log.Debug().Type("type", message).Msg("Not relevant") + return false + } + messagePDU := opcuaAPU.GetMessage() + _, ok = messagePDU.(readWriteModel.OpcuaAcknowledgeResponseExactly) + if !ok { + s.log.Debug().Type("type", messagePDU).Msg("Not relevant") + return false + } + return true + }, + func(message spi.Message) error { + opcuaAPU := message.(readWriteModel.OpcuaAPU) + messagePDU := opcuaAPU.GetMessage() + opcuaAcknowledgeResponse := messagePDU.(readWriteModel.OpcuaAcknowledgeResponse) + s.onConnectOpenSecureChannel(ctx, codec, opcuaAcknowledgeResponse) + return nil + }, + func(err error) error { + s.log.Debug().Err(err).Msg("error submitting") + return nil + }, + REQUEST_TIMEOUT); err != nil { + s.log.Debug().Err(err).Msg("error sending") + } + } + if err := s.channelTransactionManager.submit(requestConsumer, s.channelTransactionManager.getTransactionIdentifier()); err != nil { + s.log.Debug().Err(err).Msg("error submitting") + } +} + +func (s *SecureChannel) onConnectOpenSecureChannel(ctx context.Context, codec *MessageCodec, response readWriteModel.OpcuaAcknowledgeResponse) { + transactionId := s.channelTransactionManager.getTransactionIdentifier() + + requestHeader := readWriteModel.NewRequestHeader(readWriteModel.NewNodeId(s.authenticationToken), + s.getCurrentDateTime(), + 0, //RequestHandle + 0, + NULL_STRING, + REQUEST_TIMEOUT_LONG, + NULL_EXTENSION_OBJECT) + + var openSecureChannelRequest readWriteModel.OpenSecureChannelRequest + if s.isEncrypted { + openSecureChannelRequest = readWriteModel.NewOpenSecureChannelRequest( + requestHeader, + VERSION, + readWriteModel.SecurityTokenRequestType_securityTokenRequestTypeIssue, + readWriteModel.MessageSecurityMode_messageSecurityModeSignAndEncrypt, + readWriteModel.NewPascalByteString(int32(len(s.clientNonce)), s.clientNonce), + uint32(s.lifetime)) + } else { + openSecureChannelRequest = readWriteModel.NewOpenSecureChannelRequest( + requestHeader, + VERSION, + readWriteModel.SecurityTokenRequestType_securityTokenRequestTypeIssue, + readWriteModel.MessageSecurityMode_messageSecurityModeNone, + NULL_BYTE_STRING, + uint32(s.lifetime)) + } + + identifier, err := strconv.ParseUint(openSecureChannelRequest.GetIdentifier(), 10, 16) + if err != nil { + s.log.Debug().Err(err).Msg("error parsing identifier") + return + } + expandedNodeId := readWriteModel.NewExpandedNodeId( + false, //Namespace Uri Specified + false, //Server Index Specified + readWriteModel.NewNodeIdFourByte(0, uint16(identifier)), + nil, + nil, + ) + + extObject := readWriteModel.NewExtensionObject( + expandedNodeId, + nil, + openSecureChannelRequest, + false, + ) + + buffer := utils.NewWriteBufferByteBased(utils.WithByteOrderForByteBasedBuffer(binary.LittleEndian)) + if err := extObject.SerializeWithWriteBuffer(ctx, buffer); err != nil { + s.log.Debug().Err(err).Msg("error serializing") + return + } + + openRequest := readWriteModel.NewOpcuaOpenRequest( + FINAL_CHUNK, + 0, + readWriteModel.NewPascalString(s.securityPolicy), + s.publicCertificate, + s.thumbprint, + transactionId, + transactionId, + buffer.GetBytes(), + ) + + var apu readWriteModel.OpcuaAPU + + if s.isEncrypted { + apu, err = readWriteModel.OpcuaAPUParse(ctx, s.encryptionHandler.encodeMessage(openRequest, buffer.GetBytes()), false) + if err != nil { + s.log.Debug().Err(err).Msg("error parsing") + return + } + } else { + apu = readWriteModel.NewOpcuaAPU(openRequest, false) + } + + requestConsumer := func(transactionId int32) { + if err := codec.SendRequest( + ctx, + apu, + func(message spi.Message) bool { + opcuaAPU, ok := message.(readWriteModel.OpcuaAPUExactly) + if !ok { + s.log.Debug().Type("type", message).Msg("Not relevant") + return false + } + messagePDU := opcuaAPU.GetMessage() + openResponse, ok := messagePDU.(readWriteModel.OpcuaOpenResponseExactly) + if !ok { + s.log.Debug().Type("type", messagePDU).Msg("Not relevant") + return false + } + return openResponse.GetRequestId() == transactionId + }, + func(message spi.Message) error { + opcuaAPU := message.(readWriteModel.OpcuaAPU) + messagePDU := opcuaAPU.GetMessage() + opcuaOpenResponse := messagePDU.(readWriteModel.OpcuaOpenResponse) + readBuffer := utils.NewReadBufferByteBased(opcuaOpenResponse.GetMessage(), utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian)) + extensionObject, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, readBuffer, false) + if err != nil { + return errors.Wrap(err, "error parsing") + } + //Store the initial sequence number from the server. there's no requirement for the server and client to use the same starting number. + s.senderSequenceNumber.Store(opcuaOpenResponse.GetSequenceNumber()) + + if fault, ok := extensionObject.GetBody().(readWriteModel.ServiceFaultExactly); ok { + statusCode := fault.GetResponseHeader().(readWriteModel.ResponseHeader).GetServiceResult().GetStatusCode() + statusCodeByValue, _ := readWriteModel.OpcuaStatusCodeByValue(statusCode) + s.log.Error().Msgf("Failed to connect to opc ua server for the following reason:- %v, %v", + statusCode, + statusCodeByValue) + } else { + s.log.Debug().Msg("Got Secure Response Connection Response") + openSecureChannelResponse := extensionObject.GetBody().(readWriteModel.OpenSecureChannelResponse) + s.tokenId.Store(int32(openSecureChannelResponse.GetSecurityToken().(readWriteModel.ChannelSecurityToken).GetTokenId())) // TODO: strange that int32 and uint32 missmatch + s.channelId.Store(int32(openSecureChannelResponse.GetSecurityToken().(readWriteModel.ChannelSecurityToken).GetChannelId())) + s.onConnectCreateSessionRequest(ctx, codec) + } + return nil + }, + func(err error) error { + s.log.Debug().Err(err).Msg("error submitting") + return nil + }, + REQUEST_TIMEOUT, + ); err != nil { + s.log.Debug().Err(err).Msg("a error") + } + } + s.log.Debug().Msgf("Submitting OpenSecureChannel with id of %d", transactionId) + if err := s.channelTransactionManager.submit(requestConsumer, transactionId); err != nil { + s.log.Debug().Err(err).Msg("error submitting") + } +} + +func (s *SecureChannel) onConnectCreateSessionRequest(ctx context.Context, codec *MessageCodec) { + requestHeader := readWriteModel.NewRequestHeader( + readWriteModel.NewNodeId(s.authenticationToken), + s.getCurrentDateTime(), + 0, + 0, + NULL_STRING, + REQUEST_TIMEOUT_LONG, + NULL_EXTENSION_OBJECT) + + applicationName := readWriteModel.NewLocalizedText( + true, + true, + readWriteModel.NewPascalString("en"), + APPLICATION_TEXT) + + noOfDiscoveryUrls := int32(-1) + var discoveryUrls []readWriteModel.PascalString + + clientDescription := readWriteModel.NewApplicationDescription(APPLICATION_URI, + PRODUCT_URI, + applicationName, + readWriteModel.ApplicationType_applicationTypeClient, + NULL_STRING, + NULL_STRING, + noOfDiscoveryUrls, + discoveryUrls) + + createSessionRequest := readWriteModel.NewCreateSessionRequest( + requestHeader, + clientDescription, + NULL_STRING, + s.endpoint, + readWriteModel.NewPascalString(s.sessionName), + readWriteModel.NewPascalByteString(int32(len(s.clientNonce)), s.clientNonce), + NULL_BYTE_STRING, + 120000, + 0, + ) + + identifier, err := strconv.ParseUint(createSessionRequest.GetIdentifier(), 10, 16) + if err != nil { + s.log.Debug().Err(err).Msg("error parsing identifier") + return + } + expandedNodeId := readWriteModel.NewExpandedNodeId( + false, //Namespace Uri Specified + false, //Server Index Specified + readWriteModel.NewNodeIdFourByte(0, uint16(identifier)), + nil, + nil) + + extObject := readWriteModel.NewExtensionObject( + expandedNodeId, + nil, + createSessionRequest, + false, + ) + + buffer := utils.NewWriteBufferByteBased(utils.WithByteOrderForByteBasedBuffer(binary.LittleEndian)) + if err := extObject.SerializeWithWriteBuffer(ctx, buffer); err != nil { + s.log.Debug().Err(err).Msg("error serializing") + return + } + + consumer := func(opcuaResponse []byte) { + message, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, utils.NewReadBufferByteBased(opcuaResponse, utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian)), false) + if err != nil { + s.log.Error().Err(err).Msg("error parsing") + return + } + if fault, ok := message.GetBody().(readWriteModel.ServiceFaultExactly); ok { + statusCode := fault.GetResponseHeader().(readWriteModel.ResponseHeader).GetServiceResult().GetStatusCode() + statusCodeByValue, _ := readWriteModel.OpcuaStatusCodeByValue(statusCode) + s.log.Error().Msgf("Failed to connect to opc ua server for the following reason:- %v, %v", + statusCode, + statusCodeByValue) + } else { + s.log.Debug().Msg("Got Create Session Response Connection Response") + + extensionObject, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, utils.NewReadBufferByteBased(opcuaResponse, utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian)), false) + if err != nil { + s.log.Error().Err(err).Msg("error parsing") + return + } + unknownExtensionObject := extensionObject.GetBody() + if responseMessage, ok := unknownExtensionObject.(readWriteModel.CreateSessionResponseExactly); ok { + s.authenticationToken = responseMessage.GetAuthenticationToken().GetNodeId() + + s.onConnectActivateSessionRequest(ctx, codec, responseMessage, message.GetBody().(readWriteModel.CreateSessionResponse)) + } else { + serviceFault := unknownExtensionObject.(readWriteModel.ServiceFault) + header := serviceFault.GetResponseHeader().(readWriteModel.ResponseHeader) + s.log.Error().Msgf("Subscription ServiceFault returned from server with error code, '%s'", header.GetServiceResult()) + } + } + } + + errorDispatcher := func(err error) { + s.log.Error().Err(err).Msg("Error while waiting for subscription response") + } + + result := make(chan apiModel.PlcReadRequestResult, 1) + s.submit(ctx, codec, errorDispatcher, result, consumer, buffer) +} + +func (s *SecureChannel) onConnectActivateSessionRequest(ctx context.Context, codec *MessageCodec, opcuaMessageResponse readWriteModel.CreateSessionResponse, sessionResponse readWriteModel.CreateSessionResponse) { + s.senderCertificate = sessionResponse.GetServerCertificate().GetStringValue() + s.encryptionHandler.setServerCertificate(s.encryptionHandler.getCertificateX509(s.senderCertificate)) + s.senderNonce = sessionResponse.GetServerNonce().GetStringValue() + endpoints := make([]string, 3) + if address, err := url.Parse(s.configuration.host); err != nil { + if names, err := net.LookupAddr(address.Host); err != nil { + endpoints[0] = "opc.tcp://" + names[rand.Intn(len(names))] + ":" + s.configuration.port + s.configuration.transportEndpoint + } + endpoints[1] = "opc.tcp://" + address.Hostname() + ":" + s.configuration.port + s.configuration.transportEndpoint + //endpoints[2] = "opc.tcp://" + address.getCanonicalHostName() + ":" + s.configuration.getPort() + s.configuration.transportEndpoint// TODO: not sure how to get that in golang + } + + s.selectEndpoint(sessionResponse) + + if s.policyId == nil { + s.log.Error().Msg("Unable to find endpoint - " + endpoints[1]) + return + } + + userIdentityToken := s.getIdentityToken(s.tokenType, s.policyId.GetStringValue()) + + requestHandle := s.getRequestHandle() + + requestHeader := readWriteModel.NewRequestHeader( + readWriteModel.NewNodeId(s.authenticationToken), + s.getCurrentDateTime(), + requestHandle, + 0, + NULL_STRING, + REQUEST_TIMEOUT_LONG, + NULL_EXTENSION_OBJECT) + + clientSignature := readWriteModel.NewSignatureData(NULL_STRING, NULL_BYTE_STRING) + + activateSessionRequest := readWriteModel.NewActivateSessionRequest( + requestHeader, + clientSignature, + 0, + nil, + 0, + nil, + userIdentityToken, + clientSignature) + + identifier, err := strconv.ParseUint(activateSessionRequest.GetIdentifier(), 10, 16) + if err != nil { + s.log.Debug().Err(err).Msg("error parsing identifier") + return + } + + expandedNodeId := readWriteModel.NewExpandedNodeId(false, //Namespace Uri Specified + false, //Server Index Specified + readWriteModel.NewNodeIdFourByte(0, uint16(identifier)), + nil, + nil) + + extObject := readWriteModel.NewExtensionObject( + expandedNodeId, + nil, + activateSessionRequest, + false, + ) + + buffer := utils.NewWriteBufferByteBased(utils.WithByteOrderForByteBasedBuffer(binary.LittleEndian)) + if err := extObject.SerializeWithWriteBuffer(ctx, buffer); err != nil { + s.log.Debug().Err(err).Msg("error serializing") + return + } + + consumer := func(opcuaResponse []byte) { + message, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, utils.NewReadBufferByteBased(opcuaResponse, utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian)), false) + if err != nil { + s.log.Error().Err(err).Msg("error parsing") + return + } + if fault, ok := message.GetBody().(readWriteModel.ServiceFaultExactly); ok { + statusCode := fault.GetResponseHeader().(readWriteModel.ResponseHeader).GetServiceResult().GetStatusCode() + statusCodeByValue, _ := readWriteModel.OpcuaStatusCodeByValue(statusCode) + s.log.Error().Msgf("Failed to connect to opc ua server for the following reason:- %v, %v", + statusCode, + statusCodeByValue) + } else { + s.log.Debug().Msg("Got Activate Session Response Connection Response") + + extensionObject, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, utils.NewReadBufferByteBased(opcuaResponse, utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian)), false) + if err != nil { + s.log.Error().Err(err).Msg("error parsing") + return + } + unknownExtensionObject := extensionObject.GetBody() + if responseMessage, ok := unknownExtensionObject.(readWriteModel.ActivateSessionResponseExactly); ok { + returnedRequestHandle := responseMessage.GetResponseHeader().(readWriteModel.ResponseHeader).GetRequestHandle() + if !(requestHandle == returnedRequestHandle) { + s.log.Error().Msgf("Request handle isn't as expected, we might have missed a packet. %d != %d", requestHandle, returnedRequestHandle) + } + + // Send an event that connection setup is complete. + s.keepAlive = s.createKeepAlive() + // codec.fireConnected()// TODO: how to do that + } else { + serviceFault := unknownExtensionObject.(readWriteModel.ServiceFault) + header := serviceFault.GetResponseHeader().(readWriteModel.ResponseHeader) + s.log.Error().Msgf("Subscription ServiceFault returned from server with error code, '%s'", header.GetServiceResult()) + } + } + } + + errorDispatcher := func(err error) { + s.log.Error().Err(err).Msg("Error while waiting for subscription response") + } + + result := make(chan apiModel.PlcReadRequestResult, 1) + s.submit(ctx, codec, errorDispatcher, result, consumer, buffer) +} + +func (s *SecureChannel) getRequestHandle() uint32 { + return s.requestHandleGenerator.Add(1) +} + +func (s *SecureChannel) createKeepAlive() func() { + //TODO big wip: look for keepalive method not sure how to implement that properly + return nil +} + +func (s *SecureChannel) selectEndpoint(sessionResponse readWriteModel.CreateSessionResponse) { + // Get a list of the endpoints which match ours. + var filteredEndpoints []readWriteModel.EndpointDescription + for _, endpoint := range sessionResponse.GetServerEndpoints() { + endpointDescription := endpoint.(readWriteModel.EndpointDescription) + if s.isEndpoint(endpointDescription) { + filteredEndpoints = append(filteredEndpoints, endpointDescription) + } + } + + //Determine if the requested security policy is included in the endpoint + for _, endpoint := range filteredEndpoints { + userIdentityTokens := make([]readWriteModel.UserTokenPolicy, len(endpoint.GetUserIdentityTokens())) + for i, definition := range endpoint.GetUserIdentityTokens() { + userIdentityTokens[i] = definition.(readWriteModel.UserTokenPolicy) + } + s.hasIdentity(userIdentityTokens) + } + + if s.policyId == nil { + s.log.Error().Msgf("Unable to find endpoint - %s", s.endpoints[0]) + return + } + + if s.tokenType == 0xffffffff { // TODO: what did we use as undefined + s.log.Error().Msgf("Unable to find Security Policy for endpoint - %s", s.endpoints[0]) + return + } +} + +func (s *SecureChannel) isEndpoint(endpoint readWriteModel.EndpointDescription) bool { + // Split up the connection string into it's individual segments. + matches := utils.GetSubgroupMatches(URI_PATTERN, endpoint.GetEndpointUrl().GetStringValue()) + if len(matches) == 0 { + s.log.Error().Msgf("Endpoint returned from the server doesn't match the format '{protocol-code}:({transport-code})?//{transport-host}(:{transport-port})(/{transport-endpoint})'") + return false + } + s.log.Trace().Msgf("Using Endpoint %s %s %s", matches["transportHost"], matches["transportPort"], matches["transportEndpoint"]) + + if s.configuration.discovery && !slices.Contains(s.endpoints, matches["transportHost"]) { + return false + } + + if s.configuration.port != matches["transportPort"] { + return false + } + + if s.configuration.transportEndpoint != matches["transportEndpoint"] { + return false + } + + if !s.configuration.discovery { + s.configuration.host = matches["transportHost"] + } + + return true +} + +func (s *SecureChannel) hasIdentity(policies []readWriteModel.UserTokenPolicy) { + for _, identityToken := range policies { + if (identityToken.GetTokenType() == readWriteModel.UserTokenType_userTokenTypeAnonymous) && (s.username == "") { + s.policyId = identityToken.GetPolicyId() + s.tokenType = identityToken.GetTokenType() + } else if (identityToken.GetTokenType() == readWriteModel.UserTokenType_userTokenTypeUserName) && (s.username != "") { + s.policyId = identityToken.GetPolicyId() + s.tokenType = identityToken.GetTokenType() + } + } +} + +func (s *SecureChannel) getIdentityToken(tokenType readWriteModel.UserTokenType, value string) readWriteModel.ExtensionObject { + switch tokenType { + case readWriteModel.UserTokenType_userTokenTypeAnonymous: + //If we aren't using authentication tell the server we would like to log in anonymously + anonymousIdentityToken := readWriteModel.NewAnonymousIdentityToken() + extExpandedNodeId := readWriteModel.NewExpandedNodeId( + false, //Namespace Uri Specified + false, //Server Index Specified + readWriteModel.NewNodeIdFourByte( + 0, uint16(readWriteModel.OpcuaNodeIdServices_AnonymousIdentityToken_Encoding_DefaultBinary)), + nil, + nil, + ) + return readWriteModel.NewExtensionObject( + extExpandedNodeId, + readWriteModel.NewExtensionObjectEncodingMask(false, false, true), + readWriteModel.NewUserIdentityToken(readWriteModel.NewPascalString(s.securityPolicy), anonymousIdentityToken), + false, + ) + case readWriteModel.UserTokenType_userTokenTypeUserName: + //Encrypt the password using the server nonce and server public key + passwordBytes := []byte(s.password) + encodeableBuffer := new(bytes.Buffer) + var err error + err = binary.Write(encodeableBuffer, binary.LittleEndian, len(passwordBytes)+len(s.senderNonce)) + s.log.Debug().Err(err).Msg("write") + err = binary.Write(encodeableBuffer, binary.LittleEndian, passwordBytes) + s.log.Debug().Err(err).Msg("write") + err = binary.Write(encodeableBuffer, binary.LittleEndian, s.senderNonce) + s.log.Debug().Err(err).Msg("write") + encodeablePassword := make([]byte, 4+len(passwordBytes)+len(s.senderNonce)) + n, err := encodeableBuffer.Read(encodeablePassword) + s.log.Debug().Err(err).Int("n", n).Msg("read") + encryptedPassword := s.encryptionHandler.encryptPassword(encodeablePassword) + userNameIdentityToken := readWriteModel.NewUserNameIdentityToken( + readWriteModel.NewPascalString(s.username), + readWriteModel.NewPascalByteString(int32(len(encryptedPassword)), encryptedPassword), + readWriteModel.NewPascalString(PASSWORD_ENCRYPTION_ALGORITHM), + ) + extExpandedNodeId := readWriteModel.NewExpandedNodeId( + false, //Namespace Uri Specified + false, //Server Index Specified + readWriteModel.NewNodeIdFourByte(0, uint16(readWriteModel.OpcuaNodeIdServices_UserNameIdentityToken_Encoding_DefaultBinary)), + nil, + nil) + return readWriteModel.NewExtensionObject( + extExpandedNodeId, + readWriteModel.NewExtensionObjectEncodingMask(false, false, true), + readWriteModel.NewUserIdentityToken(readWriteModel.NewPascalString(s.securityPolicy), userNameIdentityToken), + false, + ) + } + return nil }