From 805db80d9bd3a9c192118c6acd8e7ceca8779a64 Mon Sep 17 00:00:00 2001 From: Pavel Gabriel Date: Thu, 21 Sep 2023 19:58:22 +0200 Subject: [PATCH] make message thread safe --- message.go | 131 ++++++++++++++++++++++++++++++------------------ message_test.go | 8 --- 2 files changed, 82 insertions(+), 57 deletions(-) diff --git a/message.go b/message.go index 980e24d..9823085 100644 --- a/message.go +++ b/message.go @@ -23,8 +23,8 @@ const ( ) type Message struct { - spec *MessageSpec - bitmap *field.Bitmap + spec *MessageSpec + cachedBitmap *field.Bitmap // stores all fields according to the spec fields map[int]field.Field @@ -57,28 +57,35 @@ func (m *Message) SetData(data interface{}) error { } func (m *Message) Bitmap() *field.Bitmap { - if m.bitmap != nil { - return m.bitmap + m.mu.Lock() + defer m.mu.Unlock() + + return m.bitmap() +} + +// bitmap creates and returns the bitmap field, it's not thread safe +// and should be called from a thread safe function +func (m *Message) bitmap() *field.Bitmap { + if m.cachedBitmap != nil { + return m.cachedBitmap } // We validate the presence and type of the bitmap field in // spec.Validate() when we create the message so we can safely assume // it exists and is of the correct type - m.bitmap, _ = m.fields[bitmapIdx].(*field.Bitmap) - m.bitmap.Reset() + m.cachedBitmap, _ = m.fields[bitmapIdx].(*field.Bitmap) + m.cachedBitmap.Reset() - m.mu.Lock() m.fieldsMap[bitmapIdx] = struct{}{} - m.mu.Unlock() - return m.bitmap + return m.cachedBitmap } func (m *Message) MTI(val string) { m.mu.Lock() - m.fieldsMap[mtiIdx] = struct{}{} - m.mu.Unlock() + defer m.mu.Unlock() + m.fieldsMap[mtiIdx] = struct{}{} m.fields[mtiIdx].SetBytes([]byte(val)) } @@ -87,37 +94,37 @@ func (m *Message) GetSpec() *MessageSpec { } func (m *Message) Field(id int, val string) error { + m.mu.Lock() + defer m.mu.Unlock() + if f, ok := m.fields[id]; ok { - m.mu.Lock() m.fieldsMap[id] = struct{}{} - m.mu.Unlock() return f.SetBytes([]byte(val)) } return fmt.Errorf("failed to set field %d. ID does not exist", id) } func (m *Message) BinaryField(id int, val []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if f, ok := m.fields[id]; ok { - m.mu.Lock() m.fieldsMap[id] = struct{}{} - m.mu.Unlock() - return f.SetBytes(val) } return fmt.Errorf("failed to set binary field %d. ID does not exist", id) } func (m *Message) GetMTI() (string, error) { - // check index + // we validate the presence and type of the mti field in + // spec.Validate() when we create the message so we can safely assume + // it exists return m.fields[mtiIdx].String() } func (m *Message) GetString(id int) (string, error) { if f, ok := m.fields[id]; ok { - m.mu.Lock() - m.fieldsMap[id] = struct{}{} - m.mu.Unlock() - + // m.fieldsMap[id] = struct{}{} return f.String() } return "", fmt.Errorf("failed to get string for field %d. ID does not exist", id) @@ -125,10 +132,7 @@ func (m *Message) GetString(id int) (string, error) { func (m *Message) GetBytes(id int) ([]byte, error) { if f, ok := m.fields[id]; ok { - m.mu.Lock() - m.fieldsMap[id] = struct{}{} - m.mu.Unlock() - + // m.fieldsMap[id] = struct{}{} return f.Bytes() } return nil, fmt.Errorf("failed to get bytes for field %d. ID does not exist", id) @@ -143,6 +147,12 @@ func (m *Message) GetFields() map[int]field.Field { m.mu.Lock() defer m.mu.Unlock() + return m.getFields() +} + +// getFields returns the map of the set fields. It assumes that the mutex +// is already locked by the caller. +func (m *Message) getFields() map[int]field.Field { fields := map[int]field.Field{} for i := range m.fieldsMap { fields[i] = m.GetField(i) @@ -150,9 +160,20 @@ func (m *Message) GetFields() map[int]field.Field { return fields } -// Pack returns the packed message or an error if the message is invalid -// error is of type *PackError +// Pack locks the message, packs its fields, and then unlocks it. +// If any errors are encountered during packing, they will be wrapped +// in a *PackError before being returned. func (m *Message) Pack() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.wrapErrorPack() + +} + +// wrapErrorPack calls the core packing logic and wraps any errors in a +// *PackError. It assumes that the mutex is already locked by the caller. +func (m *Message) wrapErrorPack() ([]byte, error) { data, err := m.pack() if err != nil { return nil, &PackError{Err: err} @@ -161,9 +182,12 @@ func (m *Message) Pack() ([]byte, error) { return data, nil } +// pack contains the core logic for packing the message. This method does not +// handle locking or error wrapping and should typically be used internally +// after ensuring concurrency safety. func (m *Message) pack() ([]byte, error) { packed := []byte{} - m.Bitmap().Reset() + m.bitmap().Reset() ids, err := m.packableFieldIDs() if err != nil { @@ -174,16 +198,16 @@ func (m *Message) pack() ([]byte, error) { // indexes 0 and 1 are for mti and bitmap // regular field number startd from index 2 // do not pack presence bits as well - if id < 2 || m.Bitmap().IsBitmapPresenceBit(id) { + if id < 2 || m.bitmap().IsBitmapPresenceBit(id) { continue } - m.Bitmap().Set(id) + m.bitmap().Set(id) } // pack fields for _, i := range ids { // do not pack presence bits other than the first one as it's the bitmap itself - if i != 1 && m.Bitmap().IsBitmapPresenceBit(i) { + if i != 1 && m.bitmap().IsBitmapPresenceBit(i) { continue } @@ -204,6 +228,16 @@ func (m *Message) pack() ([]byte, error) { // Unpack unpacks the message from the given byte slice or returns an error // which is of type *UnpackError and contains the raw message func (m *Message) Unpack(src []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.wrappErrorUnpack(src) +} + +// wrappErrorUnpack calls the core unpacking logic and wraps any +// errors in a *UnpackError. It assumes that the mutex is already +// locked by the caller. +func (m *Message) wrappErrorUnpack(src []byte) error { if err := m.unpack(src); err != nil { return &UnpackError{ Err: err, @@ -213,22 +247,17 @@ func (m *Message) Unpack(src []byte) error { return nil } +// unpack contains the core logic for unpacking the message. This method does +// not handle locking or error wrapping and should typically be used internally +// after ensuring concurrency safety. func (m *Message) unpack(src []byte) error { var off int - m.mu.Lock() // reset fields that were set m.fieldsMap = map[int]struct{}{} - // we unlock here as m.Bitmap() will lock the mutex again - m.mu.Unlock() - bitmap := m.Bitmap() // This method implicitly also sets m.fieldsMap[bitmapIdx] - bitmap.Reset() - - // lock the mutex again as we're going to set fields - m.mu.Lock() - defer m.mu.Unlock() + m.bitmap().Reset() read, err := m.fields[mtiIdx].Unpack(src) if err != nil { @@ -247,13 +276,13 @@ func (m *Message) unpack(src []byte) error { off += read - for i := 2; i <= bitmap.Len(); i++ { + for i := 2; i <= m.bitmap().Len(); i++ { // skip bitmap presence bits (for default bitmap length of 64 these are bits 1, 65, 129, 193, etc.) - if bitmap.IsBitmapPresenceBit(i) { + if m.bitmap().IsBitmapPresenceBit(i) { continue } - if bitmap.IsSet(i) { + if m.bitmap().IsSet(i) { fl, ok := m.fields[i] if !ok { return fmt.Errorf("failed to unpack field %d: no specification found", i) @@ -273,15 +302,19 @@ func (m *Message) unpack(src []byte) error { return nil } +// TODO: protect against concurrent access func (m *Message) MarshalJSON() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + // by packing the message we will generate bitmap // create HEX representation // and validate message against the spec - if _, err := m.Pack(); err != nil { + if _, err := m.wrapErrorPack(); err != nil { return nil, err } - fieldMap := m.GetFields() + fieldMap := m.getFields() strFieldMap := map[string]field.Field{} for id, field := range fieldMap { strFieldMap[fmt.Sprint(id)] = field @@ -326,9 +359,6 @@ func (m *Message) UnmarshalJSON(b []byte) error { } func (m *Message) packableFieldIDs() ([]int, error) { - m.mu.Lock() - defer m.mu.Unlock() - // Index 1 represent bitmap which is always populated. populatedFieldIDs := []int{1} @@ -349,9 +379,12 @@ func (m *Message) packableFieldIDs() ([]int, error) { // Clone clones the message by creating a new message from the binary // representation of the original message func (m *Message) Clone() (*Message, error) { + m.mu.Lock() + defer m.mu.Unlock() + newMessage := NewMessage(m.spec) - bytes, err := m.Pack() + bytes, err := m.wrapErrorPack() if err != nil { return nil, err } diff --git a/message_test.go b/message_test.go index 2eb4868..7968340 100644 --- a/message_test.go +++ b/message_test.go @@ -3,8 +3,6 @@ package iso8583 import ( "encoding/hex" "encoding/json" - "log" - "net/http" "reflect" "sync" "testing" @@ -17,15 +15,9 @@ import ( "github.com/moov-io/iso8583/sort" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - _ "net/http/pprof" ) func TestMessage(t *testing.T) { - go func() { - log.Println(http.ListenAndServe("localhost:6060", nil)) - }() - spec := &MessageSpec{ Fields: map[int]field.Field{ 0: field.NewString(&field.Spec{