diff --git a/field/composite.go b/field/composite.go index 2edeea5..6612c89 100644 --- a/field/composite.go +++ b/field/composite.go @@ -8,6 +8,7 @@ import ( "reflect" "regexp" "strconv" + "sync" "github.com/moov-io/iso8583/encoding" "github.com/moov-io/iso8583/prefix" @@ -63,11 +64,15 @@ var _ json.Unmarshaler = (*Composite)(nil) // For the sake of determinism, packing of subfields is executed in order of Tag // (using Spec.Tag.Sort) regardless of the value of Spec.Tag.Length. type Composite struct { - spec *Spec - bitmap *Bitmap + spec *Spec + cachedBitmap *Bitmap orderedSpecFieldTags []string + // mu is used to synchronize access to the subfields and + // setSubfields maps when the composite is used concurrently + mu sync.Mutex + // stores all fields according to the spec subfields map[string]Field @@ -92,7 +97,13 @@ type CompositeWithSubfields interface { ConstructSubfields() } +// ConstructSubfields creates subfields according to the spec +// this method is used when composite field is created without +// calling NewComposite (when we create message spec and composite spec) func (f *Composite) ConstructSubfields() { + f.mu.Lock() + defer f.mu.Unlock() + if f.subfields == nil { f.subfields = CreateSubfields(f.spec) } @@ -106,6 +117,15 @@ func (f *Composite) Spec() *Spec { // GetSubfields returns the map of set sub fields func (f *Composite) GetSubfields() map[string]Field { + f.mu.Lock() + defer f.mu.Unlock() + + return f.getSubfields() +} + +// getSubfields returns the map of set sub fields, it should be called +// only when the mutex is locked +func (f *Composite) getSubfields() map[string]Field { fields := map[string]Field{} for i := range f.setSubfields { fields[i] = f.subfields[i] @@ -120,7 +140,7 @@ func (f *Composite) GetSubfields() map[string]Field { // will result in a panic. func (f *Composite) SetSpec(spec *Spec) { if err := spec.Validate(); err != nil { - panic(err) //nolint // as specs moslty static, we panic on spec validation errors + panic(err) //nolint:forbidigo,nolintlint // as specs moslty static, we panic on spec validation errors } f.spec = spec @@ -137,6 +157,9 @@ func (f *Composite) SetSpec(spec *Spec) { } func (f *Composite) Unmarshal(v interface{}) error { + f.mu.Lock() + defer f.mu.Unlock() + rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return errors.New("data is not a pointer or nil") @@ -203,6 +226,9 @@ func (f *Composite) SetData(v interface{}) error { // F4 *SubfieldCompositeData // } func (f *Composite) Marshal(v interface{}) error { + f.mu.Lock() + defer f.mu.Unlock() + rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return errors.New("data is not a pointer or nil") @@ -251,6 +277,9 @@ func (f *Composite) Marshal(v interface{}) error { // Pack deserialises data held by the receiver (via SetData) // into bytes and returns an error on failure. func (f *Composite) Pack() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + packed, err := f.pack() if err != nil { return nil, err @@ -268,6 +297,9 @@ func (f *Composite) Pack() ([]byte, error) { // subfields. An offset (unit depends on encoding and prefix values) is // returned on success. A non-nil error is returned on failure. func (f *Composite) Unpack(data []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + dataLen, offset, err := f.spec.Pref.DecodeLength(f.spec.Length, data) if err != nil { return 0, fmt.Errorf("failed to decode length: %w", err) @@ -300,6 +332,9 @@ func (f *Composite) Unpack(data []byte) (int, error) { // pack all subfields in full. However, unlike Unpack(), it requires the // aggregate length of the subfields not to be encoded in the prefix. func (f *Composite) SetBytes(data []byte) error { + f.mu.Lock() + defer f.mu.Unlock() + _, err := f.unpack(data, false) return err } @@ -308,14 +343,24 @@ func (f *Composite) SetBytes(data []byte) error { // does not incorporate the encoded aggregate length of the subfields in the // prefix. func (f *Composite) Bytes() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + return f.pack() } // Bitmap returns the parsed bitmap instantiated on the key "0" of the spec. // In case the bitmap is not instantiated on the spec, returns nil. func (f *Composite) Bitmap() *Bitmap { - if f.bitmap != nil { - return f.bitmap + f.mu.Lock() + defer f.mu.Unlock() + + return f.bitmap() +} + +func (f *Composite) bitmap() *Bitmap { + if f.cachedBitmap != nil { + return f.cachedBitmap } if f.spec.Bitmap == nil { @@ -327,9 +372,9 @@ func (f *Composite) Bitmap() *Bitmap { return nil } - f.bitmap = bitmap + f.cachedBitmap = bitmap - return f.bitmap + return f.cachedBitmap } // String iterates over the receiver's subfields, packs them and converts the @@ -345,7 +390,10 @@ func (f *Composite) String() (string, error) { // MarshalJSON implements the encoding/json.Marshaler interface. func (f *Composite) MarshalJSON() ([]byte, error) { - jsonData := OrderedMap(f.GetSubfields()) + f.mu.Lock() + defer f.mu.Unlock() + + jsonData := OrderedMap(f.getSubfields()) bytes, err := json.Marshal(jsonData) if err != nil { return nil, utils.NewSafeError(err, "failed to JSON marshal map to bytes") @@ -357,6 +405,9 @@ func (f *Composite) MarshalJSON() ([]byte, error) { // An error is thrown if the JSON consists of a subfield that has not // been defined in the spec. func (f *Composite) UnmarshalJSON(b []byte) error { + f.mu.Lock() + defer f.mu.Unlock() + var data map[string]json.RawMessage err := json.Unmarshal(b, &data) if err != nil { @@ -384,7 +435,7 @@ func (f *Composite) UnmarshalJSON(b []byte) error { } func (f *Composite) pack() ([]byte, error) { - if f.Bitmap() != nil { + if f.bitmap() != nil { return f.packByBitmap() } @@ -392,7 +443,7 @@ func (f *Composite) pack() ([]byte, error) { } func (f *Composite) packByBitmap() ([]byte, error) { - f.Bitmap().Reset() + f.bitmap().Reset() var packedFields []byte @@ -409,7 +460,7 @@ func (f *Composite) packByBitmap() ([]byte, error) { } // set bitmap bit for this field - f.Bitmap().Set(idInt) + f.bitmap().Set(idInt) field, ok := f.subfields[id] if !ok { @@ -425,7 +476,7 @@ func (f *Composite) packByBitmap() ([]byte, error) { } // pack bitmap. - packedBitmap, err := f.Bitmap().Pack() + packedBitmap, err := f.bitmap().Pack() if err != nil { return nil, fmt.Errorf("packing bitmap: %w", err) } @@ -469,7 +520,7 @@ func (f *Composite) packByTag() ([]byte, error) { } func (f *Composite) unpack(data []byte, isVariableLength bool) (int, error) { - if f.Bitmap() != nil { + if f.bitmap() != nil { return f.unpackSubfieldsByBitmap(data) } if f.spec.Tag.Enc != nil { @@ -509,17 +560,17 @@ func (f *Composite) unpackSubfieldsByBitmap(data []byte) (int, error) { // Reset fields that were set. f.setSubfields = make(map[string]struct{}) - f.Bitmap().Reset() + f.bitmap().Reset() - read, err := f.Bitmap().Unpack(data[off:]) + read, err := f.bitmap().Unpack(data[off:]) if err != nil { return 0, fmt.Errorf("failed to unpack bitmap: %w", err) } off += read - for i := 1; i <= f.Bitmap().Len(); i++ { - if f.Bitmap().IsSet(i) { + for i := 1; i <= f.bitmap().Len(); i++ { + if f.bitmap().IsSet(i) { iStr := strconv.Itoa(i) fl, ok := f.subfields[iStr] diff --git a/field/composite_test.go b/field/composite_test.go index 874273a..78d7fe0 100644 --- a/field/composite_test.go +++ b/field/composite_test.go @@ -1,9 +1,11 @@ package field import ( + "encoding/hex" "fmt" "reflect" "strconv" + "sync" "testing" "github.com/moov-io/iso8583/encoding" @@ -1892,3 +1894,60 @@ func TestComposite_getFieldIndexOrTag(t *testing.T) { require.Empty(t, index) }) } + +func TestComposit_concurrency(t *testing.T) { + t.Run("Pack and Marshal", func(t *testing.T) { + // packing and marshaling + data := &TLVTestData{ + F9A: NewHexValue("210720"), + F9F02: NewHexValue("000000000501"), + } + + composite := NewComposite(tlvTestSpec) + + wg := sync.WaitGroup{} + + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + err := composite.Marshal(data) + require.NoError(t, err) + + _, err = composite.Pack() + require.NoError(t, err) + }() + } + + wg.Wait() + }) + + t.Run("Unpack and Unmarshal", func(t *testing.T) { + packed, err := hex.DecodeString("3031349A032107209F0206000000000501") + require.NoError(t, err) + + composite := NewComposite(tlvTestSpec) + + wg := sync.WaitGroup{} + wg.Add(5) + + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + + data := &TLVTestData{} + _, err := composite.Unpack(packed) + require.NoError(t, err) + + err = composite.Unmarshal(data) + require.NoError(t, err) + + require.Equal(t, "210720", data.F9A.Value()) + require.Equal(t, "000000000501", data.F9F02.Value()) + }() + } + + wg.Wait() + }) +} diff --git a/message.go b/message.go index 51c7a47..0153850 100644 --- a/message.go +++ b/message.go @@ -8,6 +8,7 @@ import ( "regexp" "sort" "strconv" + "sync" "github.com/moov-io/iso8583/field" "github.com/moov-io/iso8583/utils" @@ -22,12 +23,15 @@ const ( ) type Message struct { - spec *MessageSpec - bitmap *field.Bitmap + spec *MessageSpec + cachedBitmap *field.Bitmap // stores all fields according to the spec fields map[int]field.Field + // to guard fieldsMap + mu sync.Mutex + // tracks which fields were set fieldsMap map[int]struct{} } @@ -35,7 +39,7 @@ type Message struct { func NewMessage(spec *MessageSpec) *Message { // Validate the spec if err := spec.Validate(); err != nil { - panic(err) //nolint // as specs moslty static, we panic on spec validation errors + panic(err) //nolint:forbidigo,nolintlint // as specs moslty static, we panic on spec validation errors } fields := spec.CreateMessageFields() @@ -53,21 +57,34 @@ func (m *Message) SetData(data interface{}) error { } func (m *Message) Bitmap() *field.Bitmap { - if m.bitmap != nil { - return m.bitmap + m.mu.Lock() + defer m.mu.Unlock() + + return m.bitmap() +} + +// bitmap creates and returns the bitmap field, it's not thread safe +// and should be called from a thread safe function +func (m *Message) bitmap() *field.Bitmap { + if m.cachedBitmap != nil { + return m.cachedBitmap } // We validate the presence and type of the bitmap field in // spec.Validate() when we create the message so we can safely assume // it exists and is of the correct type - m.bitmap, _ = m.fields[bitmapIdx].(*field.Bitmap) - m.bitmap.Reset() + m.cachedBitmap, _ = m.fields[bitmapIdx].(*field.Bitmap) + m.cachedBitmap.Reset() + m.fieldsMap[bitmapIdx] = struct{}{} - return m.bitmap + return m.cachedBitmap } func (m *Message) MTI(val string) { + m.mu.Lock() + defer m.mu.Unlock() + m.fieldsMap[mtiIdx] = struct{}{} m.fields[mtiIdx].SetBytes([]byte(val)) } @@ -77,6 +94,9 @@ func (m *Message) GetSpec() *MessageSpec { } func (m *Message) Field(id int, val string) error { + m.mu.Lock() + defer m.mu.Unlock() + if f, ok := m.fields[id]; ok { m.fieldsMap[id] = struct{}{} return f.SetBytes([]byte(val)) @@ -85,6 +105,9 @@ func (m *Message) Field(id int, val string) error { } func (m *Message) BinaryField(id int, val []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if f, ok := m.fields[id]; ok { m.fieldsMap[id] = struct{}{} return f.SetBytes(val) @@ -93,13 +116,15 @@ func (m *Message) BinaryField(id int, val []byte) error { } func (m *Message) GetMTI() (string, error) { - // check index + // we validate the presence and type of the mti field in + // spec.Validate() when we create the message so we can safely assume + // it exists return m.fields[mtiIdx].String() } func (m *Message) GetString(id int) (string, error) { if f, ok := m.fields[id]; ok { - m.fieldsMap[id] = struct{}{} + // m.fieldsMap[id] = struct{}{} return f.String() } return "", fmt.Errorf("failed to get string for field %d. ID does not exist", id) @@ -107,7 +132,7 @@ func (m *Message) GetString(id int) (string, error) { func (m *Message) GetBytes(id int) ([]byte, error) { if f, ok := m.fields[id]; ok { - m.fieldsMap[id] = struct{}{} + // m.fieldsMap[id] = struct{}{} return f.Bytes() } return nil, fmt.Errorf("failed to get bytes for field %d. ID does not exist", id) @@ -119,6 +144,15 @@ func (m *Message) GetField(id int) field.Field { // Fields returns the map of the set fields func (m *Message) GetFields() map[int]field.Field { + m.mu.Lock() + defer m.mu.Unlock() + + return m.getFields() +} + +// getFields returns the map of the set fields. It assumes that the mutex +// is already locked by the caller. +func (m *Message) getFields() map[int]field.Field { fields := map[int]field.Field{} for i := range m.fieldsMap { fields[i] = m.GetField(i) @@ -126,9 +160,20 @@ func (m *Message) GetFields() map[int]field.Field { return fields } -// Pack returns the packed message or an error if the message is invalid -// error is of type *PackError +// Pack locks the message, packs its fields, and then unlocks it. +// If any errors are encountered during packing, they will be wrapped +// in a *PackError before being returned. func (m *Message) Pack() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.wrapErrorPack() + +} + +// wrapErrorPack calls the core packing logic and wraps any errors in a +// *PackError. It assumes that the mutex is already locked by the caller. +func (m *Message) wrapErrorPack() ([]byte, error) { data, err := m.pack() if err != nil { return nil, &PackError{Err: err} @@ -137,9 +182,12 @@ func (m *Message) Pack() ([]byte, error) { return data, nil } +// pack contains the core logic for packing the message. This method does not +// handle locking or error wrapping and should typically be used internally +// after ensuring concurrency safety. func (m *Message) pack() ([]byte, error) { packed := []byte{} - m.Bitmap().Reset() + m.bitmap().Reset() ids, err := m.packableFieldIDs() if err != nil { @@ -150,16 +198,16 @@ func (m *Message) pack() ([]byte, error) { // indexes 0 and 1 are for mti and bitmap // regular field number startd from index 2 // do not pack presence bits as well - if id < 2 || m.Bitmap().IsBitmapPresenceBit(id) { + if id < 2 || m.bitmap().IsBitmapPresenceBit(id) { continue } - m.Bitmap().Set(id) + m.bitmap().Set(id) } // pack fields for _, i := range ids { // do not pack presence bits other than the first one as it's the bitmap itself - if i != 1 && m.Bitmap().IsBitmapPresenceBit(i) { + if i != 1 && m.bitmap().IsBitmapPresenceBit(i) { continue } @@ -180,6 +228,16 @@ func (m *Message) pack() ([]byte, error) { // Unpack unpacks the message from the given byte slice or returns an error // which is of type *UnpackError and contains the raw message func (m *Message) Unpack(src []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.wrappErrorUnpack(src) +} + +// wrappErrorUnpack calls the core unpacking logic and wraps any +// errors in a *UnpackError. It assumes that the mutex is already +// locked by the caller. +func (m *Message) wrappErrorUnpack(src []byte) error { if err := m.unpack(src); err != nil { return &UnpackError{ Err: err, @@ -189,6 +247,9 @@ func (m *Message) Unpack(src []byte) error { return nil } +// unpack contains the core logic for unpacking the message. This method does +// not handle locking or error wrapping and should typically be used internally +// after ensuring concurrency safety. func (m *Message) unpack(src []byte) error { var off int @@ -196,7 +257,7 @@ func (m *Message) unpack(src []byte) error { m.fieldsMap = map[int]struct{}{} // This method implicitly also sets m.fieldsMap[bitmapIdx] - m.Bitmap().Reset() + m.bitmap().Reset() read, err := m.fields[mtiIdx].Unpack(src) if err != nil { @@ -215,13 +276,13 @@ func (m *Message) unpack(src []byte) error { off += read - for i := 2; i <= m.Bitmap().Len(); i++ { + for i := 2; i <= m.bitmap().Len(); i++ { // skip bitmap presence bits (for default bitmap length of 64 these are bits 1, 65, 129, 193, etc.) - if m.Bitmap().IsBitmapPresenceBit(i) { + if m.bitmap().IsBitmapPresenceBit(i) { continue } - if m.Bitmap().IsSet(i) { + if m.bitmap().IsSet(i) { fl, ok := m.fields[i] if !ok { return fmt.Errorf("failed to unpack field %d: no specification found", i) @@ -242,14 +303,17 @@ func (m *Message) unpack(src []byte) error { } func (m *Message) MarshalJSON() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + // by packing the message we will generate bitmap // create HEX representation // and validate message against the spec - if _, err := m.Pack(); err != nil { + if _, err := m.wrapErrorPack(); err != nil { return nil, err } - fieldMap := m.GetFields() + fieldMap := m.getFields() strFieldMap := map[string]field.Field{} for id, field := range fieldMap { strFieldMap[fmt.Sprint(id)] = field @@ -264,6 +328,9 @@ func (m *Message) MarshalJSON() ([]byte, error) { } func (m *Message) UnmarshalJSON(b []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + var data map[string]json.RawMessage if err := json.Unmarshal(b, &data); err != nil { return err @@ -311,9 +378,12 @@ func (m *Message) packableFieldIDs() ([]int, error) { // Clone clones the message by creating a new message from the binary // representation of the original message func (m *Message) Clone() (*Message, error) { + m.mu.Lock() + defer m.mu.Unlock() + newMessage := NewMessage(m.spec) - bytes, err := m.Pack() + bytes, err := m.wrapErrorPack() if err != nil { return nil, err } @@ -338,6 +408,9 @@ func (m *Message) Clone() (*Message, error) { // through the message fields and calls Unmarshal(...) on them setting the v If // v is not a struct or not a pointer to struct then it returns error. func (m *Message) Marshal(v interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + if v == nil { return nil } @@ -391,6 +464,9 @@ func (m *Message) Marshal(v interface{}) error { // through the message fields and calls Unmarshal(...) on them setting the v If // v is nil or not a pointer it returns error. func (m *Message) Unmarshal(v interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return errors.New("data is not a pointer or nil") diff --git a/message_test.go b/message_test.go index 2989d4c..7968340 100644 --- a/message_test.go +++ b/message_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "encoding/json" "reflect" + "sync" "testing" "time" @@ -90,6 +91,27 @@ func TestMessage(t *testing.T) { }, } + // this test most probably will fail in regular mode, + // and should fail when is run with -race flag + t.Run("No data race when accessing fields concurrently", func(t *testing.T) { + message := NewMessage(spec) + + var wg sync.WaitGroup + + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // calling GetString writes into the map of the + // set fields + message.GetString(0) + }() + } + + wg.Wait() + }) + t.Run("Test packing and unpacking untyped fields", func(t *testing.T) { message := NewMessage(spec) message.MTI("0100") diff --git a/test/fuzz-reader/reader.go b/test/fuzz-reader/reader.go index c66c682..d939bbc 100644 --- a/test/fuzz-reader/reader.go +++ b/test/fuzz-reader/reader.go @@ -40,7 +40,7 @@ func Fuzz(data []byte) int { _, err = message.Pack() if err != nil { - panic(fmt.Errorf("failed to pack unpacked message: %w", err)) //nolint + panic(fmt.Errorf("failed to pack unpacked message: %w", err)) //nolint:forbidigo,nolintlint } return 1