Skip to content

Commit

Permalink
make message thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
alovak committed Sep 21, 2023
1 parent 4b3d7eb commit 805db80
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 57 deletions.
131 changes: 82 additions & 49 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ const (
)

type Message struct {
spec *MessageSpec
bitmap *field.Bitmap
spec *MessageSpec
cachedBitmap *field.Bitmap

// stores all fields according to the spec
fields map[int]field.Field
Expand Down Expand Up @@ -57,28 +57,35 @@ func (m *Message) SetData(data interface{}) error {
}

func (m *Message) Bitmap() *field.Bitmap {
if m.bitmap != nil {
return m.bitmap
m.mu.Lock()
defer m.mu.Unlock()

return m.bitmap()
}

// bitmap creates and returns the bitmap field, it's not thread safe
// and should be called from a thread safe function
func (m *Message) bitmap() *field.Bitmap {
if m.cachedBitmap != nil {
return m.cachedBitmap
}

// We validate the presence and type of the bitmap field in
// spec.Validate() when we create the message so we can safely assume
// it exists and is of the correct type
m.bitmap, _ = m.fields[bitmapIdx].(*field.Bitmap)
m.bitmap.Reset()
m.cachedBitmap, _ = m.fields[bitmapIdx].(*field.Bitmap)
m.cachedBitmap.Reset()

m.mu.Lock()
m.fieldsMap[bitmapIdx] = struct{}{}
m.mu.Unlock()

return m.bitmap
return m.cachedBitmap
}

func (m *Message) MTI(val string) {
m.mu.Lock()
m.fieldsMap[mtiIdx] = struct{}{}
m.mu.Unlock()
defer m.mu.Unlock()

m.fieldsMap[mtiIdx] = struct{}{}
m.fields[mtiIdx].SetBytes([]byte(val))
}

Expand All @@ -87,48 +94,45 @@ func (m *Message) GetSpec() *MessageSpec {
}

func (m *Message) Field(id int, val string) error {
m.mu.Lock()
defer m.mu.Unlock()

if f, ok := m.fields[id]; ok {
m.mu.Lock()
m.fieldsMap[id] = struct{}{}
m.mu.Unlock()
return f.SetBytes([]byte(val))
}
return fmt.Errorf("failed to set field %d. ID does not exist", id)
}

func (m *Message) BinaryField(id int, val []byte) error {
m.mu.Lock()
defer m.mu.Unlock()

if f, ok := m.fields[id]; ok {
m.mu.Lock()
m.fieldsMap[id] = struct{}{}
m.mu.Unlock()

return f.SetBytes(val)
}
return fmt.Errorf("failed to set binary field %d. ID does not exist", id)
}

func (m *Message) GetMTI() (string, error) {
// check index
// we validate the presence and type of the mti field in
// spec.Validate() when we create the message so we can safely assume
// it exists
return m.fields[mtiIdx].String()
}

func (m *Message) GetString(id int) (string, error) {
if f, ok := m.fields[id]; ok {
m.mu.Lock()
m.fieldsMap[id] = struct{}{}
m.mu.Unlock()

// m.fieldsMap[id] = struct{}{}
return f.String()
}
return "", fmt.Errorf("failed to get string for field %d. ID does not exist", id)
}

func (m *Message) GetBytes(id int) ([]byte, error) {
if f, ok := m.fields[id]; ok {
m.mu.Lock()
m.fieldsMap[id] = struct{}{}
m.mu.Unlock()

// m.fieldsMap[id] = struct{}{}
return f.Bytes()
}
return nil, fmt.Errorf("failed to get bytes for field %d. ID does not exist", id)
Expand All @@ -143,16 +147,33 @@ func (m *Message) GetFields() map[int]field.Field {
m.mu.Lock()
defer m.mu.Unlock()

return m.getFields()
}

// getFields returns the map of the set fields. It assumes that the mutex
// is already locked by the caller.
func (m *Message) getFields() map[int]field.Field {
fields := map[int]field.Field{}
for i := range m.fieldsMap {
fields[i] = m.GetField(i)
}
return fields
}

// Pack returns the packed message or an error if the message is invalid
// error is of type *PackError
// Pack locks the message, packs its fields, and then unlocks it.
// If any errors are encountered during packing, they will be wrapped
// in a *PackError before being returned.
func (m *Message) Pack() ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()

return m.wrapErrorPack()

}

// wrapErrorPack calls the core packing logic and wraps any errors in a
// *PackError. It assumes that the mutex is already locked by the caller.
func (m *Message) wrapErrorPack() ([]byte, error) {
data, err := m.pack()
if err != nil {
return nil, &PackError{Err: err}
Expand All @@ -161,9 +182,12 @@ func (m *Message) Pack() ([]byte, error) {
return data, nil
}

// pack contains the core logic for packing the message. This method does not
// handle locking or error wrapping and should typically be used internally
// after ensuring concurrency safety.
func (m *Message) pack() ([]byte, error) {
packed := []byte{}
m.Bitmap().Reset()
m.bitmap().Reset()

ids, err := m.packableFieldIDs()
if err != nil {
Expand All @@ -174,16 +198,16 @@ func (m *Message) pack() ([]byte, error) {
// indexes 0 and 1 are for mti and bitmap
// regular field number startd from index 2
// do not pack presence bits as well
if id < 2 || m.Bitmap().IsBitmapPresenceBit(id) {
if id < 2 || m.bitmap().IsBitmapPresenceBit(id) {
continue
}
m.Bitmap().Set(id)
m.bitmap().Set(id)
}

// pack fields
for _, i := range ids {
// do not pack presence bits other than the first one as it's the bitmap itself
if i != 1 && m.Bitmap().IsBitmapPresenceBit(i) {
if i != 1 && m.bitmap().IsBitmapPresenceBit(i) {
continue
}

Expand All @@ -204,6 +228,16 @@ func (m *Message) pack() ([]byte, error) {
// Unpack unpacks the message from the given byte slice or returns an error
// which is of type *UnpackError and contains the raw message
func (m *Message) Unpack(src []byte) error {
m.mu.Lock()
defer m.mu.Unlock()

return m.wrappErrorUnpack(src)
}

// wrappErrorUnpack calls the core unpacking logic and wraps any
// errors in a *UnpackError. It assumes that the mutex is already
// locked by the caller.
func (m *Message) wrappErrorUnpack(src []byte) error {
if err := m.unpack(src); err != nil {
return &UnpackError{
Err: err,
Expand All @@ -213,22 +247,17 @@ func (m *Message) Unpack(src []byte) error {
return nil
}

// unpack contains the core logic for unpacking the message. This method does
// not handle locking or error wrapping and should typically be used internally
// after ensuring concurrency safety.
func (m *Message) unpack(src []byte) error {
var off int

m.mu.Lock()
// reset fields that were set
m.fieldsMap = map[int]struct{}{}
// we unlock here as m.Bitmap() will lock the mutex again
m.mu.Unlock()

bitmap := m.Bitmap()
// This method implicitly also sets m.fieldsMap[bitmapIdx]
bitmap.Reset()

// lock the mutex again as we're going to set fields
m.mu.Lock()
defer m.mu.Unlock()
m.bitmap().Reset()

read, err := m.fields[mtiIdx].Unpack(src)
if err != nil {
Expand All @@ -247,13 +276,13 @@ func (m *Message) unpack(src []byte) error {

off += read

for i := 2; i <= bitmap.Len(); i++ {
for i := 2; i <= m.bitmap().Len(); i++ {
// skip bitmap presence bits (for default bitmap length of 64 these are bits 1, 65, 129, 193, etc.)
if bitmap.IsBitmapPresenceBit(i) {
if m.bitmap().IsBitmapPresenceBit(i) {
continue
}

if bitmap.IsSet(i) {
if m.bitmap().IsSet(i) {
fl, ok := m.fields[i]
if !ok {
return fmt.Errorf("failed to unpack field %d: no specification found", i)
Expand All @@ -273,15 +302,19 @@ func (m *Message) unpack(src []byte) error {
return nil
}

// TODO: protect against concurrent access
func (m *Message) MarshalJSON() ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()

// by packing the message we will generate bitmap
// create HEX representation
// and validate message against the spec
if _, err := m.Pack(); err != nil {
if _, err := m.wrapErrorPack(); err != nil {
return nil, err
}

fieldMap := m.GetFields()
fieldMap := m.getFields()
strFieldMap := map[string]field.Field{}
for id, field := range fieldMap {
strFieldMap[fmt.Sprint(id)] = field
Expand Down Expand Up @@ -326,9 +359,6 @@ func (m *Message) UnmarshalJSON(b []byte) error {
}

func (m *Message) packableFieldIDs() ([]int, error) {
m.mu.Lock()
defer m.mu.Unlock()

// Index 1 represent bitmap which is always populated.
populatedFieldIDs := []int{1}

Expand All @@ -349,9 +379,12 @@ func (m *Message) packableFieldIDs() ([]int, error) {
// Clone clones the message by creating a new message from the binary
// representation of the original message
func (m *Message) Clone() (*Message, error) {
m.mu.Lock()
defer m.mu.Unlock()

newMessage := NewMessage(m.spec)

bytes, err := m.Pack()
bytes, err := m.wrapErrorPack()
if err != nil {
return nil, err
}
Expand Down
8 changes: 0 additions & 8 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package iso8583
import (
"encoding/hex"
"encoding/json"
"log"
"net/http"
"reflect"
"sync"
"testing"
Expand All @@ -17,15 +15,9 @@ import (
"github.com/moov-io/iso8583/sort"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

_ "net/http/pprof"
)

func TestMessage(t *testing.T) {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()

spec := &MessageSpec{
Fields: map[int]field.Field{
0: field.NewString(&field.Spec{
Expand Down

0 comments on commit 805db80

Please sign in to comment.