Skip to content

Commit

Permalink
Make the package safe for concurrent access (#284)
Browse files Browse the repository at this point in the history
* draft data race solution for message

* make message thread safe

* remove //nolint as it does not work

moov-io/infra#280

* add test for concurrent access

* protect Composite field against concurrent access

* satisfy linkers

* remove unnecessary comments and address feedback
  • Loading branch information
alovak authored Sep 26, 2023
1 parent b652df4 commit b05481c
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 42 deletions.
85 changes: 68 additions & 17 deletions field/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"reflect"
"regexp"
"strconv"
"sync"

"github.com/moov-io/iso8583/encoding"
"github.com/moov-io/iso8583/prefix"
Expand Down Expand Up @@ -63,11 +64,15 @@ var _ json.Unmarshaler = (*Composite)(nil)
// For the sake of determinism, packing of subfields is executed in order of Tag
// (using Spec.Tag.Sort) regardless of the value of Spec.Tag.Length.
type Composite struct {
spec *Spec
bitmap *Bitmap
spec *Spec
cachedBitmap *Bitmap

orderedSpecFieldTags []string

// mu is used to synchronize access to the subfields and
// setSubfields maps when the composite is used concurrently
mu sync.Mutex

// stores all fields according to the spec
subfields map[string]Field

Expand All @@ -92,7 +97,13 @@ type CompositeWithSubfields interface {
ConstructSubfields()
}

// ConstructSubfields creates subfields according to the spec
// this method is used when composite field is created without
// calling NewComposite (when we create message spec and composite spec)
func (f *Composite) ConstructSubfields() {
f.mu.Lock()
defer f.mu.Unlock()

if f.subfields == nil {
f.subfields = CreateSubfields(f.spec)
}
Expand All @@ -106,6 +117,15 @@ func (f *Composite) Spec() *Spec {

// GetSubfields returns the map of set sub fields
func (f *Composite) GetSubfields() map[string]Field {
f.mu.Lock()
defer f.mu.Unlock()

return f.getSubfields()
}

// getSubfields returns the map of set sub fields, it should be called
// only when the mutex is locked
func (f *Composite) getSubfields() map[string]Field {
fields := map[string]Field{}
for i := range f.setSubfields {
fields[i] = f.subfields[i]
Expand All @@ -120,7 +140,7 @@ func (f *Composite) GetSubfields() map[string]Field {
// will result in a panic.
func (f *Composite) SetSpec(spec *Spec) {
if err := spec.Validate(); err != nil {
panic(err) //nolint // as specs moslty static, we panic on spec validation errors
panic(err) //nolint:forbidigo,nolintlint // as specs moslty static, we panic on spec validation errors
}
f.spec = spec

Expand All @@ -137,6 +157,9 @@ func (f *Composite) SetSpec(spec *Spec) {
}

func (f *Composite) Unmarshal(v interface{}) error {
f.mu.Lock()
defer f.mu.Unlock()

rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("data is not a pointer or nil")
Expand Down Expand Up @@ -203,6 +226,9 @@ func (f *Composite) SetData(v interface{}) error {
// F4 *SubfieldCompositeData
// }
func (f *Composite) Marshal(v interface{}) error {
f.mu.Lock()
defer f.mu.Unlock()

rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("data is not a pointer or nil")
Expand Down Expand Up @@ -251,6 +277,9 @@ func (f *Composite) Marshal(v interface{}) error {
// Pack deserialises data held by the receiver (via SetData)
// into bytes and returns an error on failure.
func (f *Composite) Pack() ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()

packed, err := f.pack()
if err != nil {
return nil, err
Expand All @@ -268,6 +297,9 @@ func (f *Composite) Pack() ([]byte, error) {
// subfields. An offset (unit depends on encoding and prefix values) is
// returned on success. A non-nil error is returned on failure.
func (f *Composite) Unpack(data []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

dataLen, offset, err := f.spec.Pref.DecodeLength(f.spec.Length, data)
if err != nil {
return 0, fmt.Errorf("failed to decode length: %w", err)
Expand Down Expand Up @@ -300,6 +332,9 @@ func (f *Composite) Unpack(data []byte) (int, error) {
// pack all subfields in full. However, unlike Unpack(), it requires the
// aggregate length of the subfields not to be encoded in the prefix.
func (f *Composite) SetBytes(data []byte) error {
f.mu.Lock()
defer f.mu.Unlock()

_, err := f.unpack(data, false)
return err
}
Expand All @@ -308,14 +343,24 @@ func (f *Composite) SetBytes(data []byte) error {
// does not incorporate the encoded aggregate length of the subfields in the
// prefix.
func (f *Composite) Bytes() ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()

return f.pack()
}

// Bitmap returns the parsed bitmap instantiated on the key "0" of the spec.
// In case the bitmap is not instantiated on the spec, returns nil.
func (f *Composite) Bitmap() *Bitmap {
if f.bitmap != nil {
return f.bitmap
f.mu.Lock()
defer f.mu.Unlock()

return f.bitmap()
}

func (f *Composite) bitmap() *Bitmap {
if f.cachedBitmap != nil {
return f.cachedBitmap
}

if f.spec.Bitmap == nil {
Expand All @@ -327,9 +372,9 @@ func (f *Composite) Bitmap() *Bitmap {
return nil
}

f.bitmap = bitmap
f.cachedBitmap = bitmap

return f.bitmap
return f.cachedBitmap
}

// String iterates over the receiver's subfields, packs them and converts the
Expand All @@ -345,7 +390,10 @@ func (f *Composite) String() (string, error) {

// MarshalJSON implements the encoding/json.Marshaler interface.
func (f *Composite) MarshalJSON() ([]byte, error) {
jsonData := OrderedMap(f.GetSubfields())
f.mu.Lock()
defer f.mu.Unlock()

jsonData := OrderedMap(f.getSubfields())
bytes, err := json.Marshal(jsonData)
if err != nil {
return nil, utils.NewSafeError(err, "failed to JSON marshal map to bytes")
Expand All @@ -357,6 +405,9 @@ func (f *Composite) MarshalJSON() ([]byte, error) {
// An error is thrown if the JSON consists of a subfield that has not
// been defined in the spec.
func (f *Composite) UnmarshalJSON(b []byte) error {
f.mu.Lock()
defer f.mu.Unlock()

var data map[string]json.RawMessage
err := json.Unmarshal(b, &data)
if err != nil {
Expand Down Expand Up @@ -384,15 +435,15 @@ func (f *Composite) UnmarshalJSON(b []byte) error {
}

func (f *Composite) pack() ([]byte, error) {
if f.Bitmap() != nil {
if f.bitmap() != nil {
return f.packByBitmap()
}

return f.packByTag()
}

func (f *Composite) packByBitmap() ([]byte, error) {
f.Bitmap().Reset()
f.bitmap().Reset()

var packedFields []byte

Expand All @@ -409,7 +460,7 @@ func (f *Composite) packByBitmap() ([]byte, error) {
}

// set bitmap bit for this field
f.Bitmap().Set(idInt)
f.bitmap().Set(idInt)

field, ok := f.subfields[id]
if !ok {
Expand All @@ -425,7 +476,7 @@ func (f *Composite) packByBitmap() ([]byte, error) {
}

// pack bitmap.
packedBitmap, err := f.Bitmap().Pack()
packedBitmap, err := f.bitmap().Pack()
if err != nil {
return nil, fmt.Errorf("packing bitmap: %w", err)
}
Expand Down Expand Up @@ -469,7 +520,7 @@ func (f *Composite) packByTag() ([]byte, error) {
}

func (f *Composite) unpack(data []byte, isVariableLength bool) (int, error) {
if f.Bitmap() != nil {
if f.bitmap() != nil {
return f.unpackSubfieldsByBitmap(data)
}
if f.spec.Tag.Enc != nil {
Expand Down Expand Up @@ -509,17 +560,17 @@ func (f *Composite) unpackSubfieldsByBitmap(data []byte) (int, error) {
// Reset fields that were set.
f.setSubfields = make(map[string]struct{})

f.Bitmap().Reset()
f.bitmap().Reset()

read, err := f.Bitmap().Unpack(data[off:])
read, err := f.bitmap().Unpack(data[off:])
if err != nil {
return 0, fmt.Errorf("failed to unpack bitmap: %w", err)
}

off += read

for i := 1; i <= f.Bitmap().Len(); i++ {
if f.Bitmap().IsSet(i) {
for i := 1; i <= f.bitmap().Len(); i++ {
if f.bitmap().IsSet(i) {
iStr := strconv.Itoa(i)

fl, ok := f.subfields[iStr]
Expand Down
59 changes: 59 additions & 0 deletions field/composite_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package field

import (
"encoding/hex"
"fmt"
"reflect"
"strconv"
"sync"
"testing"

"github.com/moov-io/iso8583/encoding"
Expand Down Expand Up @@ -1892,3 +1894,60 @@ func TestComposite_getFieldIndexOrTag(t *testing.T) {
require.Empty(t, index)
})
}

func TestComposit_concurrency(t *testing.T) {
t.Run("Pack and Marshal", func(t *testing.T) {
// packing and marshaling
data := &TLVTestData{
F9A: NewHexValue("210720"),
F9F02: NewHexValue("000000000501"),
}

composite := NewComposite(tlvTestSpec)

wg := sync.WaitGroup{}

for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()

err := composite.Marshal(data)
require.NoError(t, err)

_, err = composite.Pack()
require.NoError(t, err)
}()
}

wg.Wait()
})

t.Run("Unpack and Unmarshal", func(t *testing.T) {
packed, err := hex.DecodeString("3031349A032107209F0206000000000501")
require.NoError(t, err)

composite := NewComposite(tlvTestSpec)

wg := sync.WaitGroup{}
wg.Add(5)

for i := 0; i < 5; i++ {
go func() {
defer wg.Done()

data := &TLVTestData{}
_, err := composite.Unpack(packed)
require.NoError(t, err)

err = composite.Unmarshal(data)
require.NoError(t, err)

require.Equal(t, "210720", data.F9A.Value())
require.Equal(t, "000000000501", data.F9F02.Value())
}()
}

wg.Wait()
})
}
Loading

0 comments on commit b05481c

Please sign in to comment.