Skip to content

Commit

Permalink
fix bug in numRoots stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Wilson committed Jul 25, 2024
1 parent 0ff9f9f commit a7de325
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 4 deletions.
30 changes: 26 additions & 4 deletions graphmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,29 @@ func (g *graphMap) Range(f func(key PipelineID, value *registeredPipeline) bool)

// Store calls sync.Map.Store
func (g *graphMap) Store(id PipelineID, root *registeredPipeline) {
// Store the root node and increment how many we have.
// Store the root node and increment how many we have (if this is a new pipeline).
// NOTE: These two actions might not be atomic, so potentially something could
// start to range over the map before we've made the change to the total number
// of roots.
if !g.Exists(id) {
g.numRoots++
}
g.m.Store(id, root)
g.numRoots++
}

// Delete calls sync.Map.Delete
func (g *graphMap) Delete(id PipelineID) {
// Delete the nodes for the pipeline, and decrement how many root nodes we have.
if !g.Exists(id) {
return
}

// Delete the root node for the pipeline if it was already stored, and decrement
// how many we have.
// NOTE: These two actions might not be atomic, so potentially something could
// start to range over the map before we've made the change to the total number
// of roots.
g.m.Delete(id)
g.numRoots--
g.m.Delete(id)
}

// Nodes returns all the nodes referenced by the specified Pipeline
Expand All @@ -78,3 +85,18 @@ func (g *graphMap) Nodes(id PipelineID) ([]NodeID, error) {

return result, nil
}

// Exists determines whether a PipelineID is already stored within the graphMap.
func (g *graphMap) Exists(id PipelineID) bool {
var found bool

g.Range(func(key PipelineID, v *registeredPipeline) bool {
if key == id {
found = true
return false
}
return true
})

return found
}
88 changes: 88 additions & 0 deletions graphmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,91 @@ func TestNodes_ListNodes_RegisteredPipeline(t *testing.T) {
require.Contains(t, nodeIDs, NodeID("b"))
require.Contains(t, nodeIDs, NodeID("c"))
}

func TestGraphMap_Store(t *testing.T) {
t.Parallel()

g := &graphMap{}
findPipeline := pipelineFinder(g)

// Set up pipeline for storing.
id := PipelineID("foo")
p := &registeredPipeline{
registrationPolicy: "bar",
}

// Sanity check we have nothing to start with.
require.Equal(t, 0, g.numRoots)
v := findPipeline(id)
require.Nil(t, v)

// Store the pipeline then check we stored it and incremented the counter.
g.Store(id, p)
require.Equal(t, 1, g.numRoots)
v = findPipeline(id)
require.NotNil(t, v)
require.Equal(t, RegistrationPolicy("bar"), v.registrationPolicy)

// Store it again, and check it's still there but the counter hasn't changed
// since it's per distinct ID.
p.registrationPolicy = "baz"
g.Store(id, p)
require.Equal(t, 1, g.numRoots)
v = findPipeline(id)
require.NotNil(t, v)
require.Equal(t, RegistrationPolicy("baz"), v.registrationPolicy)
}

func TestGraphMap_Delete(t *testing.T) {
t.Parallel()

g := &graphMap{}
findPipeline := pipelineFinder(g)

// Set up pipeline for storing.
id := PipelineID("foo")
p := &registeredPipeline{
registrationPolicy: "bar",
}

// Sanity check we have nothing to start with.
require.Equal(t, 0, g.numRoots)
v := findPipeline(id)
require.Nil(t, v)

// Store the pipeline then check we stored it and incremented the counter.
g.Store(id, p)
require.Equal(t, 1, g.numRoots)
v = findPipeline(id)
require.NotNil(t, v)
require.Equal(t, RegistrationPolicy("bar"), v.registrationPolicy)

// Now delete the pipeline and check it's gone and the counter went down.
g.Delete(id)
require.Equal(t, 0, g.numRoots)
v = findPipeline(id)
require.Nil(t, v)

// Delete again and make sure nothing funky happens.
g.Delete(id)
require.Equal(t, 0, g.numRoots)
v = findPipeline(id)
require.Nil(t, v)
}

// pipelineFinder returns a func that can be used to find a pipeline in a graphMap.
func pipelineFinder(g *graphMap) func(id PipelineID) *registeredPipeline {
return func(id PipelineID) *registeredPipeline {
var res *registeredPipeline

g.Range(func(key PipelineID, rp *registeredPipeline) bool {
if key == id {
res = rp
return false
}
return true
})

return res
}
}

0 comments on commit a7de325

Please sign in to comment.