Skip to content

Commit

Permalink
Add Get Locations
Browse files Browse the repository at this point in the history
  • Loading branch information
tung.tq committed Oct 17, 2023
1 parent 3a3f12c commit 11e2064
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 14 deletions.
91 changes: 78 additions & 13 deletions svloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewUniverse() *Universe {
}

// CleanUp removes the data for wiring services
// after that, Get calls will panic
// after that, all calls will panic, excepts for Shutdown
func (u *Universe) CleanUp() {
u.data.mut.Lock()

Expand All @@ -74,8 +74,12 @@ type registeredService struct {
getCallLocation string
createUnv *Universe

newFunc func(unv *Universe) any
wrappers []func(unv *Universe, svc any) any
overrideCallLocation string

newFunc func(unv *Universe) any

wrappers []func(unv *Universe, svc any) any
wrapperLocs []string

onShutdown func()
}
Expand Down Expand Up @@ -216,17 +220,18 @@ func (u *universeData) appendShutdownFunc(fn func()) {
type Locator[T any] struct {
key *T
newFn func(unv *Universe) any

registerLoc string
}

// Get ...
// Get can panic if Shutdown already called
func (s *Locator[T]) Get(unv *Universe) T {
reg, err := unv.data.getService(s.key, s.newFn, "Get")
if err != nil {
panic(err.Error())
}

_, file, line, _ := runtime.Caller(1)
loc := fmt.Sprintf("%s:%d", file, line)
loc := getCallerLocation()

svc := reg.newService(unv, loc)
result, ok := svc.(T)
Expand All @@ -239,9 +244,9 @@ func (s *Locator[T]) Get(unv *Universe) T {

// Override prevents running the function inside Register
func (s *Locator[T]) Override(unv *Universe, svc T) error {
return s.OverrideFunc(unv, func(unv *Universe) T {
return s.overrideFuncWithLoc(unv, func(unv *Universe) T {
return svc
})
}, getCallerLocation())
}

func (s *Locator[T]) doBeforeGet(
Expand Down Expand Up @@ -269,10 +274,18 @@ func (s *Locator[T]) doBeforeGet(

// OverrideFunc ...
func (s *Locator[T]) OverrideFunc(unv *Universe, newFn func(unv *Universe) T) error {
return s.overrideFuncWithLoc(unv, newFn, getCallerLocation())
}

func (s *Locator[T]) overrideFuncWithLoc(
unv *Universe, newFn func(unv *Universe) T,
callLoc string,
) error {
if unv.prev != nil {
return errOverrideInsideNewFunctions
}
return s.doBeforeGet(unv, "Override", func(reg *registeredService) {
reg.overrideCallLocation = callLoc
reg.newFunc = func(unv *Universe) any {
return newFn(unv)
}
Expand All @@ -288,32 +301,42 @@ func (s *Locator[T]) panicOverrideError(err error) {

// MustOverride will fail if Override returns false
func (s *Locator[T]) MustOverride(unv *Universe, svc T) {
err := s.Override(unv, svc)
err := s.overrideFuncWithLoc(unv, func(unv *Universe) T {
return svc
}, getCallerLocation())
if err != nil {
s.panicOverrideError(err)
}
}

// MustOverrideFunc similar to OverrideFunc but panics if error returned
func (s *Locator[T]) MustOverrideFunc(unv *Universe, newFn func(unv *Universe) T) {
err := s.OverrideFunc(unv, newFn)
err := s.overrideFuncWithLoc(unv, newFn, getCallerLocation())
if err != nil {
s.panicOverrideError(err)
}
}

// Wrap the original implementation with the object created by wrapper
func (s *Locator[T]) Wrap(unv *Universe, wrapper func(unv *Universe, svc T) T) (err error) {
return s.wrapWithLoc(unv, wrapper, getCallerLocation())
}

func (s *Locator[T]) wrapWithLoc(
unv *Universe, wrapper func(unv *Universe, svc T) T,
callLoc string,
) (err error) {
return s.doBeforeGet(unv, "Wrap", func(reg *registeredService) {
reg.wrappers = append(reg.wrappers, func(unv *Universe, svc any) any {
return wrapper(unv, svc.(T))
})
reg.wrapperLocs = append(reg.wrapperLocs, callLoc)
})
}

// MustWrap similar to Wrap, but it will panic if not succeeded
func (s *Locator[T]) MustWrap(unv *Universe, wrapper func(unv *Universe, svc T) T) {
err := s.Wrap(unv, wrapper)
err := s.wrapWithLoc(unv, wrapper, getCallerLocation())
if err != nil {
var val *T

Expand All @@ -326,6 +349,39 @@ func (s *Locator[T]) MustWrap(unv *Universe, wrapper func(unv *Universe, svc T)
}
}

// GetLastOverrideLocation returns the last location that Override* is called
// if no Override* is called, returns the Register location
func (s *Locator[T]) GetLastOverrideLocation(unv *Universe) (string, error) {
reg, err := unv.data.getService(s.key, s.newFn, "GetLastOverrideLocation")
if err != nil {
return "", err
}

reg.mut.Lock()
defer reg.mut.Unlock()

if reg.overrideCallLocation != "" {
return reg.overrideCallLocation, nil
}
return s.registerLoc, nil
}

// GetWrapLocations returns Wrap* call's locations
func (s *Locator[T]) GetWrapLocations(unv *Universe) ([]string, error) {
reg, err := unv.data.getService(s.key, s.newFn, "GetWrapLocations")
if err != nil {
return nil, err
}

reg.mut.Lock()
defer reg.mut.Unlock()

locs := make([]string, len(reg.wrapperLocs))
copy(locs, reg.wrapperLocs)

return locs, nil
}

// OnShutdown must only be called inside 'new' functions
// It will panic if called outside
func (u *Universe) OnShutdown(fn func()) {
Expand All @@ -351,7 +407,7 @@ func (u *universeData) cloneShutdownFuncList() []func() {
return funcList
}

// Shutdown call each callback that registered by OnShutdown
// Shutdown call each callback that registered by OnShutdown
// This function must only be called outside the 'new' functions
// It will panic if called inside
func (u *Universe) Shutdown() {
Expand All @@ -371,6 +427,11 @@ func checkAllowRegistering() {
}
}

func getCallerLocation() string {
_, file, line, _ := runtime.Caller(2)
return fmt.Sprintf("%s:%d", file, line)
}

// Register creates a new Locator allow to call Get to create a new object
func Register[T any](newFn func(unv *Universe) T) *Locator[T] {
checkAllowRegistering()
Expand All @@ -381,15 +442,18 @@ func Register[T any](newFn func(unv *Universe) T) *Locator[T] {
newFn: func(unv *Universe) any {
return newFn(unv)
},
registerLoc: getCallerLocation(),
}
}

// RegisterSimple creates a new Locator with very simple newFn that returns the zero value
func RegisterSimple[T any]() *Locator[T] {
return Register[T](func(unv *Universe) T {
s := Register[T](func(unv *Universe) T {
var empty T
return empty
})
s.registerLoc = getCallerLocation()
return s
}

// RegisterEmpty does not init anything when calling Get, and must be Override
Expand All @@ -410,6 +474,7 @@ func RegisterEmpty[T any]() *Locator[T] {
),
)
},
registerLoc: getCallerLocation(),
}
}

Expand Down
128 changes: 127 additions & 1 deletion svloc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ func TestLocator_Do_Shutdown_Complex(t *testing.T) {
}

func TestSizeOfRegisteredService(t *testing.T) {
assert.Equal(t, 120, int(unsafe.Sizeof(registeredService{})))
assert.Equal(t, 160, int(unsafe.Sizeof(registeredService{})))
}

func TestUniverse_CleanUp(t *testing.T) {
Expand Down Expand Up @@ -834,3 +834,129 @@ func TestUniverse_CleanUp(t *testing.T) {
assert.Equal(t, []string{"repo"}, shutdowns)
})
}

func TestLocator_GetLastOverrideLocation(t *testing.T) {
t.Run("normal", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:840"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
})

t.Run("after override", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

repoLoc.MustOverride(unv, &RepoMock{})

loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:860"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
})

t.Run("after override func", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

err := repoLoc.OverrideFunc(unv, func(unv *Universe) Repo {
return &RepoMock{}
})
assert.Equal(t, nil, err)

loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:876"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
})

t.Run("after clean up", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

unv.CleanUp()

loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, errors.New("svloc: can NOT call 'GetLastOverrideLocation' after 'CleanUp'"), err)
assert.Equal(t, "", loc)
})
}

func TestLocator_GetWrapLocations(t *testing.T) {
t.Run("empty", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

locs, err := repoLoc.GetWrapLocations(unv)
assert.Equal(t, nil, err)
assert.Equal(t, 0, len(locs))
})

t.Run("multiple", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

repoLoc.MustWrap(unv, func(unv *Universe, repo Repo) Repo {
return &WrapperRepo{
repo: repo,
prefix: "prefix01",
}
})

repoLoc.MustWrap(unv, func(unv *Universe, repo Repo) Repo {
return &WrapperRepo{
repo: repo,
prefix: "prefix02",
}
})

locs, err := repoLoc.GetWrapLocations(unv)
assert.Equal(t, nil, err)

assert.Equal(t, 2, len(locs))

expect := "svloc_test.go:923"
assert.Equal(t, expect, locs[0][len(locs[0])-len(expect):])

expect = "svloc_test.go:930"
assert.Equal(t, expect, locs[1][len(locs[1])-len(expect):])
})

t.Run("fail after clean up", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

unv := NewUniverse()

unv.CleanUp()

locs, err := repoLoc.GetWrapLocations(unv)
assert.Equal(t, errors.New("svloc: can NOT call 'GetWrapLocations' after 'CleanUp'"), err)
assert.Equal(t, 0, len(locs))
})
}

0 comments on commit 11e2064

Please sign in to comment.