diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..9965205d0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +.github +assets + +LICENSE + +*.yml +*.md diff --git a/.github/workflows/build_release.yml b/.github/workflows/build_release.yml index 2a8948ab0..c1576d25d 100644 --- a/.github/workflows/build_release.yml +++ b/.github/workflows/build_release.yml @@ -76,7 +76,7 @@ jobs: with: context: . platforms: linux/amd64,linux/arm64 - file: docker/${{ matrix.image }}/Dockerfile + file: cmd/${{ matrix.image }}/Dockerfile push: true tags: | zhenghaoz/${{ matrix.image }}:latest diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 68842efc0..bf4a5eafd 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -134,7 +134,6 @@ jobs: - uses: actions/checkout@v1 - name: Build the stack - working-directory: ./docker run: docker-compose up -d - name: Check the deployed service URL diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4eedb1c3d..d149e2cfb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,7 +24,7 @@ These following installations are required: - **Docker Compose**: Multiple databases are required for unit tests. It's convenient to manage databases on Docker Compose. ```bash -cd misc/database_test +cd storage docker-compose up -d ``` @@ -73,7 +73,7 @@ Most logics in Gorse are covered by unit tests. Run unit tests by the following go test -v ./... ``` -The default database URLs are directed to these databases in `misc/database_test/docker-compose.yml`. Test databases could be overrode by setting following environment variables: +The default database URLs are directed to these databases in `storage/docker-compose.yml`. Test databases could be overrode by setting following environment variables: | Environment Value | Default Value | |-------------------|----------------------------------------------| diff --git a/README.md b/README.md index e31671630..718628ed2 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ +![](https://img.shields.io/github/go-mod/go-version/zhenghaoz/gorse) [![build](https://github.com/zhenghaoz/gorse/workflows/build/badge.svg)](https://github.com/zhenghaoz/gorse/actions?query=workflow%3Abuild) [![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse) [![Go Report Card](https://goreportcard.com/badge/github.com/zhenghaoz/gorse)](https://goreportcard.com/report/github.com/zhenghaoz/gorse) diff --git a/base/copier/copier.go b/base/copier/copier.go index 03f2a3f66..621af599e 100644 --- a/base/copier/copier.go +++ b/base/copier/copier.go @@ -86,10 +86,10 @@ func copyValue(dst, src reflect.Value) error { dstPointer := reflect.New(dst.Type()) srcPointer := reflect.New(src.Type()) srcPointer.Elem().Set(src) - srcMarshaller, hasSrcMarshaler := srcPointer.Interface().(encoding.BinaryMarshaler) + srcMarshaller, hasSrcMarshaller := srcPointer.Interface().(encoding.BinaryMarshaler) dstUnmarshaler, hasDstUnmarshaler := dstPointer.Interface().(encoding.BinaryUnmarshaler) - if hasDstUnmarshaler && hasSrcMarshaler { + if hasDstUnmarshaler && hasSrcMarshaller { dstByte, err := srcMarshaller.MarshalBinary() if err != nil { return err @@ -114,6 +114,11 @@ func copyValue(dst, src reflect.Value) error { } } case reflect.Ptr: + if src.IsNil() { + // If source is nil, set dst to nil. 
+ dst.Set(reflect.Zero(dst.Type())) + return nil + } if dst.IsNil() { dst.Set(reflect.New(src.Elem().Type())) } @@ -124,6 +129,11 @@ func copyValue(dst, src reflect.Value) error { return err } case reflect.Interface: + if src.IsNil() { + // If source is nil, set dst to nil. + dst.Set(reflect.Zero(dst.Type())) + return nil + } if !dst.IsNil() { switch dst.Elem().Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, diff --git a/base/copier/copier_test.go b/base/copier/copier_test.go index 06ddf34ba..fca089e67 100644 --- a/base/copier/copier_test.go +++ b/base/copier/copier_test.go @@ -170,3 +170,22 @@ func TestPrivate(t *testing.T) { *a.text = "world" assert.Equal(t, "hello", *b.text) } + +type NilInterface interface{} + +type NilStruct struct { + Interface NilInterface + Pointer *Foo +} + +func TestNil(t *testing.T) { + var d = NilStruct{ + Interface: 100, + Pointer: &Foo{A: 100}, + } + var e NilStruct + err := Copy(&d, e) + assert.NoError(t, err) + assert.Nil(t, d.Interface) + assert.Nil(t, d.Pointer) +} diff --git a/base/parallel/parallel.go b/base/parallel/parallel.go index 77aee96fa..600cae641 100644 --- a/base/parallel/parallel.go +++ b/base/parallel/parallel.go @@ -15,9 +15,17 @@ package parallel import ( + "sync" + "github.com/juju/errors" + "github.com/zhenghaoz/gorse/base/task" + "go.uber.org/atomic" "modernc.org/mathutil" - "sync" +) + +const ( + chanSize = 1024 + allocPeriod = 128 ) /* Parallel Schedulers */ @@ -33,19 +41,13 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error } } } else { - const chanSize = 64 - const chanEOF = -1 c := make(chan int, chanSize) // producer go func() { - // send jobs for i := 0; i < nJobs; i++ { c <- i } - // send EOF - for i := 0; i < nWorkers; i++ { - c <- chanEOF - } + close(c) }() // consumer var wg sync.WaitGroup @@ -57,8 +59,8 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error defer wg.Done() for { // read job - jobId := <-c - if jobId == chanEOF { + jobId, ok := <-c + if !ok { return } // run job @@ -80,6 +82,55 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error return nil } +func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(workerId, jobId int) error) error { + c := make(chan int, chanSize) + // producer + go func() { + for i := 0; i < nJobs; i++ { + c <- i + } + close(c) + }() + // consumer + for { + exit := atomic.NewBool(true) + numJobs := jobsAlloc.AvailableJobs(nil) + var wg sync.WaitGroup + wg.Add(numJobs) + errs := make([]error, nJobs) + for j := 0; j < numJobs; j++ { + // start workers + go func(workerId int) { + defer wg.Done() + for i := 0; i < allocPeriod; i++ { + // read job + jobId, ok := <-c + if !ok { + return + } + exit.Store(false) + // run job + if err := worker(workerId, jobId); err != nil { + errs[jobId] = err + return + } + } + }(j) + } + wg.Wait() + // check errors + for _, err := range errs { + if err != nil { + return errors.Trace(err) + } + } + // exit if finished + if exit.Load() { + return nil + } + } +} + type batchJob struct { beginId int endId int @@ -90,19 +141,13 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo if nWorkers == 1 { return worker(0, 0, nJobs) } - const chanSize = 64 - const chanEOF = -1 c := make(chan batchJob, chanSize) // producer go func() { - // send jobs for i := 0; i < nJobs; i += batchSize { c <- batchJob{beginId: i, endId: mathutil.Min(i+batchSize, nJobs)} } - // send EOF 
- for i := 0; i < nWorkers; i++ { - c <- batchJob{beginId: chanEOF, endId: chanEOF} - } + close(c) }() // consumer var wg sync.WaitGroup @@ -114,8 +159,8 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo defer wg.Done() for { // read job - job := <-c - if job.beginId == chanEOF { + job, ok := <-c + if !ok { return } // run job diff --git a/base/parallel/parallel_test.go b/base/parallel/parallel_test.go index ffd3ab0f8..aa21b4124 100644 --- a/base/parallel/parallel_test.go +++ b/base/parallel/parallel_test.go @@ -18,6 +18,7 @@ import ( "github.com/scylladb/go-set" "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base" + "github.com/zhenghaoz/gorse/base/task" "testing" "time" ) @@ -48,6 +49,22 @@ func TestParallel(t *testing.T) { assert.Equal(t, 1, workersSet.Size()) } +func TestDynamicParallel(t *testing.T) { + a := base.RangeInt(10000) + b := make([]int, len(a)) + workerIds := make([]int, len(a)) + _ = DynamicParallel(len(a), task.NewConstantJobsAllocator(3), func(workerId, jobId int) error { + b[jobId] = a[jobId] + workerIds[jobId] = workerId + time.Sleep(time.Microsecond) + return nil + }) + workersSet := set.NewIntSet(workerIds...) + assert.Equal(t, a, b) + assert.GreaterOrEqual(t, 4, workersSet.Size()) + assert.Less(t, 1, workersSet.Size()) +} + func TestBatchParallel(t *testing.T) { a := base.RangeInt(10000) b := make([]int, len(a)) diff --git a/base/random.go b/base/random.go index eb02aa5a1..10fd00c10 100644 --- a/base/random.go +++ b/base/random.go @@ -15,9 +15,10 @@ package base import ( + "math/rand" + "github.com/scylladb/go-set/i32set" "github.com/scylladb/go-set/iset" - "math/rand" ) // RandomGenerator is the random generator for gorse. @@ -76,15 +77,6 @@ func (rng RandomGenerator) NormalVector64(size int, mean, stdDev float64) []floa return ret } -// NormalMatrix64 makes a matrix filled with normal random floats. -func (rng RandomGenerator) NormalMatrix64(row, col int, mean, stdDev float64) [][]float64 { - ret := make([][]float64, row) - for i := range ret { - ret[i] = rng.NormalVector64(col, mean, stdDev) - } - return ret -} - // Sample n values between low and high, but not in exclude. 
func (rng RandomGenerator) Sample(low, high, n int, exclude ...*iset.Set) []int { intervalLength := high - low diff --git a/base/random_test.go b/base/random_test.go index 1b110f023..caa2a1e79 100644 --- a/base/random_test.go +++ b/base/random_test.go @@ -14,13 +14,12 @@ package base import ( + "testing" + "github.com/chewxy/math32" "github.com/scylladb/go-set" "github.com/stretchr/testify/assert" "github.com/thoas/go-funk" - "gonum.org/v1/gonum/stat" - "math" - "testing" ) const randomEpsilon = 0.1 @@ -39,13 +38,6 @@ func TestRandomGenerator_MakeUniformMatrix(t *testing.T) { assert.False(t, funk.MaxFloat32(vec) > 2) } -func TestRandomGenerator_MakeNormalMatrix64(t *testing.T) { - rng := NewRandomGenerator(0) - vec := rng.NormalMatrix64(1, 1000, 1, 2)[0] - assert.False(t, math.Abs(stat.Mean(vec, nil)-1) > randomEpsilon) - assert.False(t, math.Abs(stat.StdDev(vec, nil)-2) > randomEpsilon) -} - func TestRandomGenerator_Sample(t *testing.T) { excludeSet := set.NewIntSet(0, 1, 2, 3, 4) rng := NewRandomGenerator(0) @@ -80,8 +72,10 @@ func stdDev(x []float32) float32 { } // meanVariance computes the sample mean and unbiased variance, where the mean and variance are -// \sum_i w_i * x_i / (sum_i w_i) -// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1) +// +// \sum_i w_i * x_i / (sum_i w_i) +// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1) +// // respectively. // If weights is nil then all of the weights are 1. If weights is not nil, then // len(x) must equal len(weights). diff --git a/base/search/index_test.go b/base/search/index_test.go index 65b03183d..9a070c290 100644 --- a/base/search/index_test.go +++ b/base/search/index_test.go @@ -36,7 +36,7 @@ func TestHNSW_InnerProduct(t *testing.T) { model.InitMean: 0, model.InitStdDev: 0.001, }) - fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU()) + fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())) m.Fit(trainSet, testSet, fitConfig) var vectors []Vector for i, itemFactor := range m.ItemFactor { diff --git a/base/search/ivf.go b/base/search/ivf.go index a8435ab10..12bb20355 100644 --- a/base/search/ivf.go +++ b/base/search/ivf.go @@ -15,6 +15,11 @@ package search import ( + "math" + "math/rand" + "sync" + "time" + "github.com/chewxy/math32" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/heap" @@ -23,12 +28,7 @@ import ( "github.com/zhenghaoz/gorse/base/task" "go.uber.org/atomic" "go.uber.org/zap" - "math" - "math/rand" "modernc.org/mathutil" - "runtime" - "sync" - "time" ) const ( @@ -46,7 +46,7 @@ type IVF struct { errorRate float32 maxIter int numProbe int - numJobs int + jobsAlloc *task.JobsAllocator task *task.SubTask } @@ -65,9 +65,9 @@ func SetClusterErrorRate(errorRate float32) IVFConfig { } } -func SetIVFNumJobs(numJobs int) IVFConfig { +func SetIVFJobsAllocator(jobsAlloc *task.JobsAllocator) IVFConfig { return func(ivf *IVF) { - ivf.numJobs = numJobs + ivf.jobsAlloc = jobsAlloc } } @@ -90,7 +90,6 @@ func NewIVF(vectors []Vector, configs ...IVFConfig) *IVF { errorRate: 0.05, maxIter: DefaultMaxIter, numProbe: 1, - numJobs: runtime.NumCPU(), } for _, config := range configs { config(idx) @@ -207,7 +206,7 @@ func (idx *IVF) Build() { // reassign clusters nextClusters := make([]ivfCluster, idx.k) - _ = parallel.Parallel(len(idx.data), idx.numJobs, func(_, i int) error { + _ = parallel.Parallel(len(idx.data), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { if !idx.data[i].IsHidden() { nextCluster, nextDistance := -1, 
float32(math32.MaxFloat32) for c := range clusters { @@ -270,7 +269,7 @@ func (b *IVFBuilder) evaluate(idx *IVF, prune0 bool) float32 { samples := b.rng.Sample(0, len(b.data), testSize) var result, count float32 var mu sync.Mutex - _ = parallel.Parallel(len(samples), idx.numJobs, func(_, i int) error { + _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { sample := samples[i] expected, _ := b.bruteForce.Search(b.data[sample], b.k, prune0) if len(expected) > 0 { @@ -321,7 +320,7 @@ func (b *IVFBuilder) evaluateTermSearch(idx *IVF, prune0 bool, term string) floa samples := b.rng.Sample(0, len(b.data), testSize) var result, count float32 var mu sync.Mutex - _ = parallel.Parallel(len(samples), idx.numJobs, func(_, i int) error { + _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { sample := samples[i] expected, _ := b.bruteForce.MultiSearch(b.data[sample], []string{term}, b.k, prune0) if len(expected) > 0 { diff --git a/base/search/ivf_test.go b/base/search/ivf_test.go index 8f9ba1ffb..da875d4e8 100644 --- a/base/search/ivf_test.go +++ b/base/search/ivf_test.go @@ -15,8 +15,10 @@ package search import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/base/task" ) func TestIVFConfig(t *testing.T) { @@ -28,8 +30,8 @@ func TestIVFConfig(t *testing.T) { SetClusterErrorRate(0.123)(ivf) assert.Equal(t, float32(0.123), ivf.errorRate) - SetIVFNumJobs(234)(ivf) - assert.Equal(t, 234, ivf.numJobs) + SetIVFJobsAllocator(task.NewConstantJobsAllocator(234))(ivf) + assert.Equal(t, 234, ivf.jobsAlloc.AvailableJobs(nil)) SetMaxIteration(345)(ivf) assert.Equal(t, 345, ivf.maxIter) diff --git a/base/task/schedule.go b/base/task/schedule.go index b5308e09c..bc0860258 100644 --- a/base/task/schedule.go +++ b/base/task/schedule.go @@ -15,70 +15,181 @@ package task import ( - "github.com/scylladb/go-set/strset" + "runtime" + "sort" "sync" + + "github.com/samber/lo" + "github.com/zhenghaoz/gorse/base/log" + "go.uber.org/zap" + "modernc.org/mathutil" ) -// Scheduler schedules that pre-locked tasks are executed first. -type Scheduler struct { - *sync.Cond - Privileged *strset.Set - Running bool +type JobsAllocator struct { + numJobs int // the max number of jobs + taskName string // its task name in scheduler + scheduler *JobsScheduler } -// NewTaskScheduler creates a Scheduler. -func NewTaskScheduler() *Scheduler { - return &Scheduler{ - Cond: sync.NewCond(&sync.Mutex{}), - Privileged: strset.New(), +func NewConstantJobsAllocator(num int) *JobsAllocator { + return &JobsAllocator{ + numJobs: num, } } -// PreLock a task, the task has the privilege to run first than un-pre-clocked tasks. -func (t *Scheduler) PreLock(name string) { - t.L.Lock() - defer t.L.Unlock() - t.Privileged.Add(name) +func (allocator *JobsAllocator) MaxJobs() int { + if allocator == nil || allocator.numJobs < 1 { + // Return 1 for invalid allocator + return 1 + } + return allocator.numJobs } -// Lock gets the permission to run task. 
-func (t *Scheduler) Lock(name string) { - t.L.Lock() - defer t.L.Unlock() - for t.Running || (!t.Privileged.IsEmpty() && !t.Privileged.Has(name)) { - t.Wait() +func (allocator *JobsAllocator) AvailableJobs(tracker *Task) int { + if allocator == nil || allocator.numJobs < 1 { + // Return 1 for invalid allocator + return 1 + } else if allocator.scheduler != nil { + // Use jobs scheduler + return allocator.scheduler.allocateJobsForTask(allocator.taskName, true, tracker) } - t.Running = true + return allocator.numJobs } -// UnLock returns the permission to run task. -func (t *Scheduler) UnLock(name string) { - t.L.Lock() - defer t.L.Unlock() - t.Running = false - t.Privileged.Remove(name) - t.Broadcast() +// Init jobs allocation. This method is used to request allocation of jobs for the first time. +func (allocator *JobsAllocator) Init() { + if allocator.scheduler != nil { + allocator.scheduler.allocateJobsForTask(allocator.taskName, true, nil) + } } -func (t *Scheduler) NewRunner(name string) *Runner { - return &Runner{ - Scheduler: t, - Name: name, +// taskInfo represents a task in JobsScheduler. +type taskInfo struct { + name string // name of the task + priority int // high priority tasks are allocated first + privileged bool // privileged tasks are allocated first + jobs int // number of jobs allocated to the task + previous int // previous number of jobs allocated to the task +} + +// JobsScheduler allocates jobs to multiple tasks. +type JobsScheduler struct { + *sync.Cond + numJobs int // number of jobs + freeJobs int // number of free jobs + tasks map[string]*taskInfo +} + +// NewJobsScheduler creates a JobsScheduler with num jobs. +func NewJobsScheduler(num int) *JobsScheduler { + if num <= 0 { + // Use all cores if num is less than 1. + num = runtime.NumCPU() + } + return &JobsScheduler{ + Cond: sync.NewCond(&sync.Mutex{}), + numJobs: num, + freeJobs: num, + tasks: make(map[string]*taskInfo), } } -// Runner is a Scheduler bounded with a task. -type Runner struct { - *Scheduler - Name string +// Register a task in the JobsScheduler. Registered tasks will be ignored and return false. +func (s *JobsScheduler) Register(taskName string, priority int, privileged bool) bool { + s.L.Lock() + defer s.L.Unlock() + if _, exits := s.tasks[taskName]; !exits { + s.tasks[taskName] = &taskInfo{name: taskName, priority: priority, privileged: privileged} + return true + } else { + return false + } } -// Lock gets the permission to run task. -func (locker *Runner) Lock() { - locker.Scheduler.Lock(locker.Name) +// Unregister a task from the JobsScheduler. +func (s *JobsScheduler) Unregister(taskName string) { + s.L.Lock() + defer s.L.Unlock() + if task, exits := s.tasks[taskName]; exits { + // Return allocated jobs. + s.freeJobs += task.jobs + delete(s.tasks, taskName) + s.Broadcast() + } } -// UnLock returns the permission to run task. -func (locker *Runner) UnLock() { - locker.Scheduler.UnLock(locker.Name) +func (s *JobsScheduler) GetJobsAllocator(taskName string) *JobsAllocator { + return &JobsAllocator{ + numJobs: s.numJobs, + taskName: taskName, + scheduler: s, + } +} + +func (s *JobsScheduler) allocateJobsForTask(taskName string, block bool, tracker *Task) int { + // Find current task and return the jobs temporarily. 
+ s.L.Lock() + currentTask, exist := s.tasks[taskName] + if !exist { + panic("task not found") + } + s.freeJobs += currentTask.jobs + currentTask.jobs = 0 + s.L.Unlock() + + s.L.Lock() + defer s.L.Unlock() + for { + s.allocateJobsForAll() + if currentTask.jobs == 0 && block { + tracker.Suspend(true) + s.Wait() + } else { + tracker.Suspend(false) + return currentTask.jobs + } + } +} + +func (s *JobsScheduler) allocateJobsForAll() { + // Separate privileged tasks and normal tasks + privileged := make([]*taskInfo, 0) + normal := make([]*taskInfo, 0) + for _, task := range s.tasks { + if task.privileged { + privileged = append(privileged, task) + } else { + normal = append(normal, task) + } + } + + var tasks []*taskInfo + if len(privileged) > 0 { + tasks = privileged + } else { + tasks = normal + } + + // allocate jobs + sort.Slice(tasks, func(i, j int) bool { + return tasks[i].priority > tasks[j].priority + }) + for i, task := range tasks { + if s.freeJobs == 0 { + return + } + targetJobs := s.numJobs/len(tasks) + lo.If(i < s.numJobs%len(tasks), 1).Else(0) + targetJobs = mathutil.Min(targetJobs, s.freeJobs) + if task.jobs < targetJobs { + if task.previous != targetJobs { + log.Logger().Debug("reallocate jobs for task", + zap.String("task", task.name), + zap.Int("previous_jobs", task.previous), + zap.Int("target_jobs", targetJobs)) + } + s.freeJobs -= targetJobs - task.jobs + task.jobs = targetJobs + task.previous = task.jobs + } + } } diff --git a/base/task/schedule_test.go b/base/task/schedule_test.go index 614c5e70f..c3541bcfb 100644 --- a/base/task/schedule_test.go +++ b/base/task/schedule_test.go @@ -1,51 +1,109 @@ +// Copyright 2022 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ package task import ( - "fmt" - "github.com/stretchr/testify/assert" "sync" "testing" + + "github.com/stretchr/testify/assert" ) -func TestTaskScheduler(t *testing.T) { - taskScheduler := NewTaskScheduler() +func TestConstantJobsAllocator(t *testing.T) { + allocator := NewConstantJobsAllocator(314) + assert.Equal(t, 314, allocator.MaxJobs()) + assert.Equal(t, 314, allocator.AvailableJobs(nil)) + + allocator = NewConstantJobsAllocator(-1) + assert.Equal(t, 1, allocator.MaxJobs()) + assert.Equal(t, 1, allocator.AvailableJobs(nil)) + + allocator = nil + assert.Equal(t, 1, allocator.MaxJobs()) + assert.Equal(t, 1, allocator.AvailableJobs(nil)) +} + +func TestDynamicJobsAllocator(t *testing.T) { + s := NewJobsScheduler(8) + s.Register("a", 1, true) + s.Register("b", 2, true) + s.Register("c", 3, true) + s.Register("d", 4, false) + s.Register("e", 4, false) + c := s.GetJobsAllocator("c") + assert.Equal(t, 8, c.MaxJobs()) + assert.Equal(t, 3, c.AvailableJobs(nil)) + b := s.GetJobsAllocator("b") + assert.Equal(t, 3, b.AvailableJobs(nil)) + a := s.GetJobsAllocator("a") + assert.Equal(t, 2, a.AvailableJobs(nil)) + + barrier := make(chan struct{}) var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + barrier <- struct{}{} + d := s.GetJobsAllocator("d") + assert.Equal(t, 4, d.AvailableJobs(nil)) + }() + go func() { + defer wg.Done() + barrier <- struct{}{} + e := s.GetJobsAllocator("e") + e.Init() + assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil)) + }() - // pre-lock for privileged tasks - for i := 0; i < 50; i++ { - taskScheduler.PreLock(fmt.Sprintf("privileged_%d", i)) - } - - // start ragtag tasks - result := make([]string, 0, 1000) - for i := 0; i < 50; i++ { - wg.Add(1) - go func(name string) { - taskScheduler.Lock(name) - result = append(result, name) - taskScheduler.UnLock(name) - wg.Done() - }(fmt.Sprintf("ragtag_%d", i)) - } - - // start privileged tasks - for i := 0; i < 50; i++ { - wg.Add(1) - go func(locker *Runner) { - locker.Lock() - result = append(result, locker.Name) - locker.UnLock() - wg.Done() - }(taskScheduler.NewRunner(fmt.Sprintf("privileged_%d", i))) - } - - // check result + <-barrier + <-barrier + s.Unregister("a") + s.Unregister("b") + s.Unregister("c") wg.Wait() - for i := 0; i < 100; i++ { - if i < 50 { - assert.Contains(t, result[i], "privileged_") - } else { - assert.Contains(t, result[i], "ragtag_") - } - } +} + +func TestJobsScheduler(t *testing.T) { + s := NewJobsScheduler(8) + assert.True(t, s.Register("a", 1, true)) + assert.True(t, s.Register("b", 2, true)) + assert.True(t, s.Register("c", 3, true)) + assert.True(t, s.Register("d", 4, false)) + assert.True(t, s.Register("e", 4, false)) + assert.False(t, s.Register("c", 1, true)) + assert.Equal(t, 3, s.allocateJobsForTask("c", false, nil)) + assert.Equal(t, 3, s.allocateJobsForTask("b", false, nil)) + assert.Equal(t, 2, s.allocateJobsForTask("a", false, nil)) + assert.Equal(t, 0, s.allocateJobsForTask("d", false, nil)) + assert.Equal(t, 0, s.allocateJobsForTask("e", false, nil)) + + // several tasks complete + s.Unregister("b") + s.Unregister("c") + assert.Equal(t, 8, s.allocateJobsForTask("a", false, nil)) + + // privileged tasks complete + s.Unregister("a") + assert.Equal(t, 4, s.allocateJobsForTask("d", false, nil)) + assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil)) + + // block privileged tasks if normal tasks are running + s.Register("a", 1, true) + s.Register("b", 2, true) + s.Register("c", 3, true) + assert.Equal(t, 0, s.allocateJobsForTask("c", false, nil)) + 
assert.Equal(t, 0, s.allocateJobsForTask("b", false, nil)) + assert.Equal(t, 0, s.allocateJobsForTask("a", false, nil)) } diff --git a/config/config.toml.template b/config/config.toml.template index f80e73ff5..f07dea27d 100644 --- a/config/config.toml.template +++ b/config/config.toml.template @@ -2,7 +2,9 @@ # The database for caching, support Redis, MySQL, Postgres and MongoDB: # redis://:@:/ +# rediss://:@:/ # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full +# postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] cache_store = "redis://localhost:6379/0" @@ -10,7 +12,10 @@ cache_store = "redis://localhost:6379/0" # The database for persist data, support MySQL, Postgres, ClickHouse and MongoDB: # mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full +# postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # clickhouse://user:password@host[:port]/database?param1=value1&...¶mN=valueN +# chhttp://user:password@host[:port]/database?param1=value1&...¶mN=valueN +# chhttps://user:password@host[:port]/database?param1=value1&...¶mN=valueN # mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] # mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]] data_store = "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse" diff --git a/docker/docker-compose.yml b/docker-compose.yml similarity index 85% rename from docker/docker-compose.yml rename to docker-compose.yml index 8bfb7b91f..7ccdd4050 100644 --- a/docker/docker-compose.yml +++ b/docker-compose.yml @@ -20,7 +20,9 @@ services: - mysql_data:/var/lib/mysql worker: - image: zhenghaoz/gorse-worker:nightly + build: + context: . + dockerfile: cmd/gorse-worker/Dockerfile restart: unless-stopped ports: - 8089:8089 @@ -36,7 +38,9 @@ services: - master server: - image: zhenghaoz/gorse-server:nightly + build: + context: . + dockerfile: cmd/gorse-server/Dockerfile restart: unless-stopped ports: - 8087:8087 @@ -52,7 +56,9 @@ services: - master master: - image: zhenghaoz/gorse-master:nightly + build: + context: . + dockerfile: cmd/gorse-master/Dockerfile restart: unless-stopped ports: - 8086:8086 @@ -65,7 +71,7 @@ services: --log-path /var/log/gorse/master.log --cache-path /var/lib/gorse/master_cache.data volumes: - - ./config.toml:/etc/gorse/config.toml + - ./config/config.toml.template:/etc/gorse/config.toml - gorse_log:/var/log/gorse - master_data:/var/lib/gorse depends_on: diff --git a/docker/config.toml b/docker/config.toml deleted file mode 100644 index a8631d298..000000000 --- a/docker/config.toml +++ /dev/null @@ -1,192 +0,0 @@ -[master] - -# GRPC port of the master node. The default value is 8086. -port = 8086 - -# gRPC host of the master node. The default values is "0.0.0.0". -host = "0.0.0.0" - -# HTTP port of the master node. The default values is 8088. -http_port = 8088 - -# HTTP host of the master node. The default values is "0.0.0.0". -http_host = "0.0.0.0" - -# Number of working jobs in the master node. The default value is 1. -n_jobs = 1 - -# Meta information timeout. The default value is 10s. -meta_timeout = "10s" - -# Username for the master node dashboard. -dashboard_user_name = "" - -# Password for the master node dashboard. 
-dashboard_password = "" - -[server] - -# Default number of returned items. The default value is 10. -default_n = 10 - -# Secret key for RESTful APIs (SSL required). -api_key = "" - -# Clock error in the cluster. The default value is 5s. -clock_error = "5s" - -# Insert new users while inserting feedback. The default value is true. -auto_insert_user = true - -# Insert new items while inserting feedback. The default value is true. -auto_insert_item = false - -# Server-side cache expire time. The default value is 10s. -cache_expire = "10s" - -[recommend] - -# The cache size for recommended/popular/latest items. The default value is 10. -cache_size = 100 - -# Recommended cache expire time. The default value is 72h. -cache_expire = "72h" - -[recommend.data_source] - -# The feedback types for positive events. -positive_feedback_types = ["star","like"] - -# The feedback types for read events. -read_feedback_types = ["read"] - -# The time-to-live (days) of positive feedback, 0 means disabled. The default value is 0. -positive_feedback_ttl = 0 - -# The time-to-live (days) of items, 0 means disabled. The default value is 0. -item_ttl = 0 - -[recommend.popular] - -# The time window of popular items. The default values is 4320h. -popular_window = "720h" - -[recommend.user_neighbors] - -# The type of neighbors for users. There are three types: -# similar: Neighbors are found by number of common labels. -# related: Neighbors are found by number of common liked items. -# auto: If a user have labels, neighbors are found by number of common labels. -# If this user have no labels, neighbors are found by number of common liked items. -# The default value is "auto". -neighbor_type = "similar" - -# Enable approximate user neighbor searching using vector index. The default value is true. -enable_index = true - -# Minimal recall for approximate user neighbor searching. The default value is 0.8. -index_recall = 0.8 - -# Maximal number of fit epochs for approximate user neighbor searching vector index. The default value is 3. -index_fit_epoch = 3 - -[recommend.item_neighbors] - -# The type of neighbors for items. There are three types: -# similar: Neighbors are found by number of common labels. -# related: Neighbors are found by number of common users. -# auto: If a item have labels, neighbors are found by number of common labels. -# If this item have no labels, neighbors are found by number of common users. -# The default value is "auto". -neighbor_type = "similar" - -# Enable approximate item neighbor searching using vector index. The default value is true. -enable_index = true - -# Minimal recall for approximate item neighbor searching. The default value is 0.8. -index_recall = 0.8 - -# Maximal number of fit epochs for approximate item neighbor searching vector index. The default value is 3. -index_fit_epoch = 3 - -[recommend.collaborative] - -# Enable approximate collaborative filtering recommend using vector index. The default value is true. -enable_index = true - -# Minimal recall for approximate collaborative filtering recommend. The default value is 0.9. -index_recall = 0.9 - -# Maximal number of fit epochs for approximate collaborative filtering recommend vector index. The default value is 3. -index_fit_epoch = 3 - -# The time period for model fitting. The default value is "60m". -model_fit_period = "60m" - -# The time period for model searching. The default value is "360m". -model_search_period = "360m" - -# The number of epochs for model searching. The default value is 100. 
-model_search_epoch = 100 - -# The number of trials for model searching. The default value is 10. -model_search_trials = 10 - -# Enable searching models of different sizes, which consume more memory. The default value is false. -enable_model_size_search = false - -[recommend.replacement] - -# Replace historical items back to recommendations. The default value is false. -enable_replacement = false - -# Decay the weights of replaced items from positive feedbacks. The default value is 0.8. -positive_replacement_decay = 0.8 - -# Decay the weights of replaced items from read feedbacks. The default value is 0.6. -read_replacement_decay = 0.6 - -[recommend.offline] - -# The time period to check recommendation for users. The default values is 1m. -check_recommend_period = "1m" - -# The time period to refresh recommendation for inactive users. The default values is 120h. -refresh_recommend_period = "24h" - -# Enable latest recommendation during offline recommendation. The default value is false. -enable_latest_recommend = true - -# Enable popular recommendation during offline recommendation. The default value is false. -enable_popular_recommend = false - -# Enable user-based similarity recommendation during offline recommendation. The default value is false. -enable_user_based_recommend = true - -# Enable item-based similarity recommendation during offline recommendation. The default value is false. -enable_item_based_recommend = false - -# Enable collaborative filtering recommendation during offline recommendation. The default value is true. -enable_collaborative_recommend = true - -# Enable click-though rate prediction during offline recommendation. Otherwise, results from multi-way recommendation -# would be merged randomly. The default value is false. -enable_click_through_prediction = true - -# The explore recommendation method is used to inject popular items or latest items into recommended result: -# popular: Recommend popular items to cold-start users. -# latest: Recommend latest items to cold-start users. -# The default values is { popular = 0.0, latest = 0.0 }. -explore_recommend = { popular = 0.1, latest = 0.2 } - -[recommend.online] - -# The fallback recommendation method is used when cached recommendation drained out: -# item_based: Recommend similar items to cold-start users. -# popular: Recommend popular items to cold-start users. -# latest: Recommend latest items to cold-start users. -# Recommenders are used in order. The default values is ["latest"]. -fallback_recommend = ["item_based", "latest"] - -# The number of feedback used in fallback item-based similar recommendation. The default values is 10. 
-num_feedback_fallback_item_based = 10 diff --git a/go.mod b/go.mod index a8abb7436..57ddda4a3 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,13 @@ go 1.18 require ( github.com/ReneKroon/ttlcache/v2 v2.11.0 - github.com/alicebob/miniredis/v2 v2.16.1 + github.com/alicebob/miniredis/v2 v2.23.0 github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de github.com/benhoyt/goawk v1.20.0 github.com/bits-and-blooms/bitset v1.2.1 github.com/chewxy/math32 v1.10.1 github.com/emicklei/go-restful-openapi/v2 v2.9.0 - github.com/emicklei/go-restful/v3 v3.8.0 + github.com/emicklei/go-restful/v3 v3.9.0 github.com/go-playground/locales v0.14.0 github.com/go-playground/universal-translator v0.18.0 github.com/go-playground/validator/v10 v10.11.0 @@ -29,24 +29,23 @@ require ( github.com/mailru/go-clickhouse/v2 v2.0.0 github.com/mitchellh/mapstructure v1.5.0 github.com/orcaman/concurrent-map v1.0.0 - github.com/prometheus/client_golang v1.12.2 + github.com/prometheus/client_golang v1.13.0 github.com/rakyll/statik v0.1.7 - github.com/samber/lo v1.25.0 - github.com/schollz/progressbar/v3 v3.8.7 + github.com/samber/lo v1.27.0 + github.com/schollz/progressbar/v3 v3.9.0 github.com/scylladb/go-set v1.0.2 github.com/sijms/go-ora/v2 v2.4.27 github.com/spf13/cobra v1.5.0 github.com/spf13/viper v1.12.0 - github.com/steinfletcher/apitest v1.5.11 - github.com/stretchr/testify v1.7.5 + github.com/steinfletcher/apitest v1.5.12 + github.com/stretchr/testify v1.8.0 github.com/thoas/go-funk v0.9.2 - go.mongodb.org/mongo-driver v1.10.0 - go.uber.org/atomic v1.9.0 - go.uber.org/zap v1.21.0 - golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75 - gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2 + go.mongodb.org/mongo-driver v1.10.1 + go.uber.org/atomic v1.10.0 + go.uber.org/zap v1.22.0 + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e google.golang.org/grpc v1.48.0 - google.golang.org/protobuf v1.28.0 + google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v2 v2.4.0 gorm.io/driver/clickhouse v0.4.2 gorm.io/driver/mysql v1.3.4 @@ -107,28 +106,27 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.37.0 // indirect - github.com/prometheus/procfs v0.7.3 // indirect + github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect - github.com/rivo/uniseg v0.2.0 // indirect + github.com/rivo/uniseg v0.3.4 // indirect github.com/shopspring/decimal v1.3.1 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.4.0 // indirect github.com/subosito/gotenv v1.4.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.1 // indirect github.com/xdg-go/stringprep v1.0.3 // indirect github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect - github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da // indirect + github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 // indirect go.uber.org/multierr v1.8.0 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/net v0.0.0-20220708220712-1185a9018129 // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect - golang.org/x/sys 
v0.0.0-20220715151400-c0bba94af5f8 // indirect - golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 // indirect + golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect + golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.1.11 // indirect google.golang.org/genproto v0.0.0-20220719170305-83ca9fad585f // indirect diff --git a/go.sum b/go.sum index 5638e717c..3c05f7012 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/miniredis/v2 v2.16.1 h1:ikfCfUHWlfiVCVVaaDO60SBgPWS4UNIi1A7p7QmUVyw= -github.com/alicebob/miniredis/v2 v2.16.1/go.mod h1:gquAfGbzn92jvtrSC69+6zZnwSODVXVpYDRaGhWaL6I= +github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw= +github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA= github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw= @@ -100,8 +100,8 @@ github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25Kn github.com/emicklei/go-restful-openapi/v2 v2.9.0 h1:djsWqjhI0EVYfkLCCX6jZxUkLmYUq2q9tt09ZbixfyE= github.com/emicklei/go-restful-openapi/v2 v2.9.0/go.mod h1:VKNgZyYviM1hnyrjD9RDzP2RuE94xTXxV+u6MGN4v4k= github.com/emicklei/go-restful/v3 v3.7.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= -github.com/emicklei/go-restful/v3 v3.8.0 h1:eCZ8ulSerjdAiaNpF7GxXIE7ZCMo1moN1qX+S609eVw= -github.com/emicklei/go-restful/v3 v3.8.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE= +github.com/emicklei/go-restful/v3 v3.9.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -415,8 +415,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34= -github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU= +github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ= 
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -432,16 +432,18 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= +github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/rakyll/statik v0.1.7 h1:OF3QCZUuyPxuGEP7B4ypUa7sB/iHtqOTDYZXGM8KOdQ= github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc= github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.3.4 h1:3Z3Eu6FGHZWSfNKJTOUiPatWwfc7DzJRU04jFUqJODw= +github.com/rivo/uniseg v0.3.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -451,11 +453,11 @@ github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/samber/lo v1.25.0 h1:H8F6cB0RotRdgcRCivTByAQePaYhGMdOTJIj2QFS2I0= -github.com/samber/lo v1.25.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A= +github.com/samber/lo v1.27.0 h1:GOyDWxsblvqYobqsmUuMddPa2/mMzkKyojlXol4+LaQ= +github.com/samber/lo v1.27.0/go.mod h1:it33p9UtPMS7z72fP4gw/EIfQB2eI8ke7GR2wc6+Rhg= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/schollz/progressbar/v3 v3.8.7 h1:rtje4lnXVD1Dy/RtPpGd2ijLCmQ7Su3G2ia8dJcRKIo= -github.com/schollz/progressbar/v3 v3.8.7/go.mod h1:W5IEwbJecncFGBvuEh4A7HT1nZZ6WNIL2i3qbnI0WKY= +github.com/schollz/progressbar/v3 v3.9.0 h1:k9SRNQ8KZyibz1UZOaKxnkUE3iGtmGSDt1YY9KlCYQk= +github.com/schollz/progressbar/v3 v3.9.0/go.mod h1:W5IEwbJecncFGBvuEh4A7HT1nZZ6WNIL2i3qbnI0WKY= github.com/scylladb/go-set v1.0.2 
h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE= github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs= github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg= @@ -481,12 +483,11 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.12.0 h1:CZ7eSOd3kZoaYDLbXnmzgQI5RlciuXBMA+18HwHRfZQ= github.com/spf13/viper v1.12.0/go.mod h1:b6COn30jlNxbm/V2IqWiNWkJ+vZNiMNksliPCiuKtSI= -github.com/steinfletcher/apitest v1.5.11 h1:bG3hq3sA4+oPHln3O/xQ6LzsQgN0J2WJl+6EpydQZ8Q= -github.com/steinfletcher/apitest v1.5.11/go.mod h1:cf7Bneo52IIAgpqhP8xaLlzWgAiQ9fHtsDMjeDnZ3so= +github.com/steinfletcher/apitest v1.5.12 h1:zv+UiSXxDspQ8R7DILKiQVIlamDH7ufCsyQuzWsn3H4= +github.com/steinfletcher/apitest v1.5.12/go.mod h1:cf7Bneo52IIAgpqhP8xaLlzWgAiQ9fHtsDMjeDnZ3so= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -496,8 +497,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.7.5 h1:s5PTfem8p8EbKQOctVV53k6jCJt3UX4IEJzwh+C324Q= -github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/subosito/gotenv v1.4.0 h1:yAzM1+SmVcz5R4tXGsNMu1jUl2aOJXoiWUCEwwnGrvs= github.com/subosito/gotenv v1.4.0/go.mod h1:mZd6rFysKEcUhUHXJk0C/08wAgyDBFuwEYL7vWWGaGo= github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk= @@ -518,11 +519,12 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da h1:NimzV1aGyq29m5ukMK0AMWEhFaL/lrEOaephfuoiARg= -github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 h1:5mLPGnFdSsevFRFc9q3yYbBkB6tsm4aCwwQV/j1JQAQ= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zenazn/goji v0.9.0/go.mod 
h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.mongodb.org/mongo-driver v1.10.0 h1:UtV6N5k14upNp4LTduX0QCufG124fSu25Wz9tu94GLg= -go.mongodb.org/mongo-driver v1.10.0/go.mod h1:wsihk0Kdgv8Kqu1Anit4sfK+22vSFbUrAVEYRhCXrA8= +go.mongodb.org/mongo-driver v1.10.1 h1:NujsPveKwHaWuKUer/ceo9DzEe7HIj1SlJ6uvXZG0S4= +go.mongodb.org/mongo-driver v1.10.1/go.mod h1:z4XpeoU6w+9Vht+jAFyLgVrD+jGSQQe0+CBWFHNiHt8= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -535,8 +537,8 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= @@ -551,8 +553,9 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +go.uber.org/zap v1.22.0 h1:Zcye5DUgBloQ9BaT4qc9BnjOFog5TvBSAGkJ3Nf70c0= +go.uber.org/zap v1.22.0/go.mod h1:H4siCOZOrAolnUPJEkfaSjDqyP+BDS0DdDWzwcgt3+U= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= @@ -570,10 +573,10 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod 
h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= @@ -583,8 +586,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75 h1:x03zeu7B2B11ySp+daztnwM5oBJ/8wGUSqrwcw9L0RA= -golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= +golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= +golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -738,13 +741,13 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM= -golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -759,7 +762,6 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time 
v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -822,10 +824,6 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2 h1:IZ343hUHFQ/v9IrSpcT1ih//e8VIctin7w2WyVl4kcc= -gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2/go.mod h1:2ltnJ7xHfj0zHS40VVPYEAAMTa3ZGguvHGBSJeRWqE0= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= @@ -924,8 +922,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/master/master.go b/master/master.go index b27c1bce7..1d9c274a1 100644 --- a/master/master.go +++ b/master/master.go @@ -17,15 +17,6 @@ package master import ( "context" "fmt" - "github.com/emicklei/go-restful/v3" - "github.com/juju/errors" - "github.com/zhenghaoz/gorse/base/encoding" - "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/base/task" - "github.com/zhenghaoz/gorse/model/click" - "github.com/zhenghaoz/gorse/server" - "go.uber.org/zap" - "google.golang.org/grpc" "math" "math/rand" "net" @@ -33,12 +24,21 @@ import ( "time" "github.com/ReneKroon/ttlcache/v2" + "github.com/emicklei/go-restful/v3" + 
"github.com/juju/errors" "github.com/zhenghaoz/gorse/base" + "github.com/zhenghaoz/gorse/base/encoding" + "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/model/click" "github.com/zhenghaoz/gorse/model/ranking" "github.com/zhenghaoz/gorse/protocol" + "github.com/zhenghaoz/gorse/server" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" + "go.uber.org/zap" + "google.golang.org/grpc" ) // Master is the master node. @@ -48,7 +48,7 @@ type Master struct { grpcServer *grpc.Server taskMonitor *task.Monitor - taskScheduler *task.Scheduler + jobsScheduler *task.JobsScheduler cacheFile string // cluster meta cache @@ -81,7 +81,8 @@ type Master struct { // events fitTicker *time.Ticker - importedChan chan bool // feedback inserted events + importedChan chan struct{} // feedback inserted events + loadDataChan chan struct{} // dataset loaded events } // NewMaster creates a master node. @@ -99,20 +100,18 @@ func NewMaster(cfg *config.Config, cacheFile string) *Master { // create task monitor cacheFile: cacheFile, taskMonitor: taskMonitor, - taskScheduler: task.NewTaskScheduler(), + jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs), // default ranking model rankingModelName: "bpr", rankingModelSearcher: ranking.NewModelSearcher( cfg.Recommend.Collaborative.ModelSearchEpoch, cfg.Recommend.Collaborative.ModelSearchTrials, - cfg.Master.NumJobs, cfg.Recommend.Collaborative.EnableModelSizeSearch, ), // default click model clickModelSearcher: click.NewModelSearcher( cfg.Recommend.Collaborative.ModelSearchEpoch, cfg.Recommend.Collaborative.ModelSearchTrials, - cfg.Master.NumJobs, cfg.Recommend.Collaborative.EnableModelSizeSearch, ), RestServer: server.RestServer{ @@ -131,7 +130,8 @@ func NewMaster(cfg *config.Config, cacheFile string) *Master { WebService: new(restful.WebService), }, fitTicker: time.NewTicker(cfg.Recommend.Collaborative.ModelFitPeriod), - importedChan: make(chan bool), + importedChan: make(chan struct{}), + loadDataChan: make(chan struct{}), } } @@ -161,7 +161,7 @@ func (m *Master) Serve() { CollaborativeFilteringPrecision10.Set(float64(m.rankingScore.Precision)) CollaborativeFilteringRecall10.Set(float64(m.rankingScore.Recall)) CollaborativeFilteringNDCG10.Set(float64(m.rankingScore.NDCG)) - MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes())) + MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes())) } if m.localCache.ClickModel != nil { log.Logger().Info("load cached click model", @@ -174,7 +174,7 @@ func (m *Master) Serve() { RankingPrecision.Set(float64(m.clickScore.Precision)) RankingRecall.Set(float64(m.clickScore.Recall)) RankingAUC.Set(float64(m.clickScore.AUC)) - MemoryInuseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes())) + MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes())) } // create cluster meta cache @@ -208,12 +208,6 @@ func (m *Master) Serve() { m.RestServer.HiddenItemsManager = server.NewHiddenItemsManager(&m.RestServer) m.RestServer.PopularItemsCache = server.NewPopularItemsCache(&m.RestServer) - // pre-lock privileged tasks - tasksNames := []string{TaskLoadDataset, TaskFindItemNeighbors, TaskFindUserNeighbors, TaskFitRankingModel, TaskFitClickModel} - for _, taskName := range tasksNames { - m.taskScheduler.PreLock(taskName) - } - go m.RunPrivilegedTasksLoop() 
log.Logger().Info("start model fit", zap.Duration("period", m.Config.Recommend.Collaborative.ModelFitPeriod)) go m.RunRagtagTasksLoop() @@ -252,19 +246,20 @@ func (m *Master) Shutdown() { func (m *Master) RunPrivilegedTasksLoop() { defer base.CheckPanic() var ( - lastNumRankingUsers int - lastNumRankingItems int - lastNumRankingFeedback int - lastNumClickUsers int - lastNumClickItems int - lastNumClickFeedback int - err error + err error + tasks = []Task{ + NewFitClickModelTask(m), + NewFitRankingModelTask(m), + NewFindUserNeighborsTask(m), + NewFindItemNeighborsTask(m), + } + firstLoop = true ) go func() { - m.importedChan <- true + m.importedChan <- struct{}{} for { if m.checkDataImported() { - m.importedChan <- true + m.importedChan <- struct{}{} } time.Sleep(time.Second) } @@ -274,11 +269,6 @@ func (m *Master) RunPrivilegedTasksLoop() { case <-m.fitTicker.C: case <-m.importedChan: } - // pre-lock privileged tasks - tasksNames := []string{TaskLoadDataset, TaskFindItemNeighbors, TaskFindUserNeighbors, TaskFitRankingModel, TaskFitClickModel} - for _, taskName := range tasksNames { - m.taskScheduler.PreLock(taskName) - } // download dataset err = m.runLoadDatasetTask() @@ -286,27 +276,33 @@ func (m *Master) RunPrivilegedTasksLoop() { log.Logger().Error("failed to load ranking dataset", zap.Error(err)) continue } - - // fit ranking model - lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedback, err = - m.runRankingRelatedTasks(lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedback) - if err != nil { - log.Logger().Error("failed to fit ranking model", zap.Error(err)) + if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 { + log.Logger().Warn("empty ranking dataset", + zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes)) continue } - // fit click model - lastNumClickUsers, lastNumClickItems, lastNumClickFeedback, err = - m.runFitClickModelTask(lastNumClickUsers, lastNumClickItems, lastNumClickFeedback) - if err != nil { - log.Logger().Error("failed to fit click model", zap.Error(err)) - m.taskMonitor.Fail(TaskFitClickModel, err.Error()) - continue + if firstLoop { + m.loadDataChan <- struct{}{} + firstLoop = false } - // release locks - for _, taskName := range tasksNames { - m.taskScheduler.UnLock(taskName) + var registeredTask []Task + for _, t := range tasks { + if m.jobsScheduler.Register(t.name(), t.priority(), true) { + registeredTask = append(registeredTask, t) + } + } + for _, t := range registeredTask { + go func(task Task) { + j := m.jobsScheduler.GetJobsAllocator(task.name()) + defer m.jobsScheduler.Unregister(task.name()) + j.Init() + if err := task.run(j); err != nil { + log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err)) + return + } + }(t) } } } @@ -315,40 +311,36 @@ func (m *Master) RunPrivilegedTasksLoop() { // rankingModelSearcher, clickSearchedModel and clickSearchedScore. 
func (m *Master) RunRagtagTasksLoop() { defer base.CheckPanic() + <-m.loadDataChan var ( - lastNumRankingUsers int - lastNumRankingItems int - lastNumRankingFeedbacks int - lastNumClickUsers int - lastNumClickItems int - lastNumClickFeedbacks int - err error + err error + tasks = []Task{ + NewCacheGarbageCollectionTask(m), + NewSearchRankingModelTask(m), + NewSearchClickModelTask(m), + } ) for { - // garbage collection - m.taskScheduler.Lock(TaskCacheGarbageCollection) - if err = m.runCacheGarbageCollectionTask(); err != nil { - log.Logger().Error("failed to collect garbage", zap.Error(err)) - m.taskMonitor.Fail(TaskCacheGarbageCollection, err.Error()) - } - m.taskScheduler.UnLock(TaskCacheGarbageCollection) - // search optimal ranking model - lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedbacks, err = - m.runSearchRankingModelTask(lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedbacks) - if err != nil { - log.Logger().Error("failed to search ranking model", zap.Error(err)) - m.taskMonitor.Fail(TaskSearchRankingModel, err.Error()) - time.Sleep(time.Minute) + if m.rankingTrainSet == nil || m.clickTrainSet == nil { + time.Sleep(time.Second) continue } - // search optimal click model - lastNumClickUsers, lastNumClickItems, lastNumClickFeedbacks, err = - m.runSearchClickModelTask(lastNumClickUsers, lastNumClickItems, lastNumClickFeedbacks) - if err != nil { - log.Logger().Error("failed to search click model", zap.Error(err)) - m.taskMonitor.Fail(TaskSearchClickModel, err.Error()) - time.Sleep(time.Minute) - continue + var registeredTask []Task + for _, t := range tasks { + if m.jobsScheduler.Register(t.name(), t.priority(), false) { + registeredTask = append(registeredTask, t) + } + } + for _, t := range registeredTask { + go func(task Task) { + defer m.jobsScheduler.Unregister(task.name()) + j := m.jobsScheduler.GetJobsAllocator(task.name()) + j.Init() + if err = task.run(j); err != nil { + log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err)) + m.taskMonitor.Fail(task.name(), err.Error()) + } + }(t) } time.Sleep(m.Config.Recommend.Collaborative.ModelSearchPeriod) } diff --git a/master/master_test.go b/master/master_test.go index ae95721f1..3f2e17748 100644 --- a/master/master_test.go +++ b/master/master_test.go @@ -14,13 +14,14 @@ package master import ( + "testing" + "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" - "testing" ) type mockMaster struct { diff --git a/master/metrics.go b/master/metrics.go index bdbad29df..2184e99f7 100644 --- a/master/metrics.go +++ b/master/metrics.go @@ -15,13 +15,14 @@ package master import ( + "time" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/samber/lo" "github.com/scylladb/go-set/i32set" "github.com/zhenghaoz/gorse/server" "github.com/zhenghaoz/gorse/storage/cache" - "time" ) const ( @@ -218,7 +219,7 @@ var ( Subsystem: "master", Name: "negative_feedbacks_total", }) - MemoryInuseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ + MemoryInUseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "gorse", Subsystem: "master", Name: "memory_inuse_bytes", diff --git a/master/tasks.go b/master/tasks.go index d3d57ff03..a79061b61 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -16,6 +16,11 @@ package master import ( "fmt" + 
"math" + "sort" + "strings" + "time" + "github.com/chewxy/math32" "github.com/juju/errors" "github.com/samber/lo" @@ -26,6 +31,7 @@ import ( "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" "github.com/zhenghaoz/gorse/base/search" + "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model/click" "github.com/zhenghaoz/gorse/model/ranking" @@ -33,11 +39,7 @@ import ( "github.com/zhenghaoz/gorse/storage/data" "go.uber.org/atomic" "go.uber.org/zap" - "math" "modernc.org/sortutil" - "sort" - "strings" - "time" ) const ( @@ -56,6 +58,12 @@ const ( similarityShrink = 100 ) +type Task interface { + name() string + priority() int + run(j *task.JobsAllocator) error +} + // runLoadDatasetTask loads dataset. func (m *Master) runLoadDatasetTask() error { initialStartTime := time.Now() @@ -172,8 +180,8 @@ func (m *Master) runLoadDatasetTask() error { rankingDataset = nil m.rankingModelMutex.Unlock() LoadDatasetStepSecondsVec.WithLabelValues("split_ranking_dataset").Set(time.Since(startTime).Seconds()) - MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_train_set").Set(float64(m.rankingTrainSet.Bytes())) - MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_test_set").Set(float64(m.rankingTestSet.Bytes())) + MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_train_set").Set(float64(m.rankingTrainSet.Bytes())) + MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_test_set").Set(float64(m.rankingTestSet.Bytes())) // split click dataset startTime = time.Now() @@ -182,8 +190,8 @@ func (m *Master) runLoadDatasetTask() error { clickDataset = nil m.clickModelMutex.Unlock() LoadDatasetStepSecondsVec.WithLabelValues("split_click_dataset").Set(time.Since(startTime).Seconds()) - MemoryInuseBytesVec.WithLabelValues("ranking_train_set").Set(float64(m.clickTrainSet.Bytes())) - MemoryInuseBytesVec.WithLabelValues("ranking_test_set").Set(float64(m.clickTestSet.Bytes())) + MemoryInUseBytesVec.WithLabelValues("ranking_train_set").Set(float64(m.clickTrainSet.Bytes())) + MemoryInUseBytesVec.WithLabelValues("ranking_test_set").Set(float64(m.clickTestSet.Bytes())) LoadDatasetTotalSeconds.Set(time.Since(initialStartTime).Seconds()) return nil @@ -205,17 +213,49 @@ func (m *Master) estimateFindItemNeighborsComplexity(dataset *ranking.DataSet) i return complexity } -// runFindItemNeighborsTask updates neighbors of items. -func (m *Master) runFindItemNeighborsTask(dataset *ranking.DataSet) { +// FindItemNeighborsTask updates neighbors of items. 
+type FindItemNeighborsTask struct { + *Master + lastNumItems int + lastNumFeedback int +} + +func NewFindItemNeighborsTask(m *Master) *FindItemNeighborsTask { + return &FindItemNeighborsTask{Master: m} +} + +func (t *FindItemNeighborsTask) name() string { + return TaskFindItemNeighbors +} + +func (t *FindItemNeighborsTask) priority() int { + return -t.rankingTrainSet.ItemCount() * t.rankingTrainSet.ItemCount() +} + +func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { + t.rankingDataMutex.RLock() + defer t.rankingDataMutex.RUnlock() + dataset := t.rankingTrainSet + numItems := dataset.ItemCount() + numFeedback := dataset.Count() + + if numItems == 0 { + t.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.") + return nil + } else if numItems == t.lastNumItems && numFeedback == t.lastNumFeedback { + log.Logger().Info("No item neighbors need to be updated.") + return nil + } + startTaskTime := time.Now() - m.taskMonitor.Start(TaskFindItemNeighbors, m.estimateFindItemNeighborsComplexity(dataset)) + t.taskMonitor.Start(TaskFindItemNeighbors, t.estimateFindItemNeighborsComplexity(dataset)) log.Logger().Info("start searching neighbors of items", - zap.Int("n_cache", m.Config.Recommend.CacheSize)) + zap.Int("n_cache", t.Config.Recommend.CacheSize)) // create progress tracker completed := make(chan struct{}, 1000) go func() { completedCount, previousCount := 0, 0 - ticker := time.NewTicker(time.Second) + ticker := time.NewTicker(time.Second * 10) for { select { case _, ok := <-completed: @@ -227,74 +267,78 @@ func (m *Master) runFindItemNeighborsTask(dataset *ranking.DataSet) { throughput := completedCount - previousCount previousCount = completedCount if throughput > 0 { - m.taskMonitor.Add(TaskFindItemNeighbors, throughput*dataset.ItemCount()) + t.taskMonitor.Add(TaskFindItemNeighbors, throughput*dataset.ItemCount()) log.Logger().Debug("searching neighbors of items", zap.Int("n_complete_items", completedCount), zap.Int("n_items", dataset.ItemCount()), - zap.Int("throughput", throughput)) + zap.Int("throughput", throughput/10)) } } } }() userIDF := make([]float32, dataset.UserCount()) - if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeRelated || - m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { + if t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeRelated || + t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { for _, feedbacks := range dataset.ItemFeedback { sort.Sort(sortutil.Int32Slice(feedbacks)) } - m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback)) + t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback)) // inverse document frequency of users for i := range dataset.UserFeedback { userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i]))) } - m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback)) + t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback)) } labeledItems := make([][]int32, dataset.NumItemLabels) labelIDF := make([]float32, dataset.NumItemLabels) - if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeSimilar || - m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { + if t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeSimilar || + t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { for i, itemLabels := range dataset.ItemLabels { sort.Sort(sortutil.Int32Slice(itemLabels)) for _, label := range 
itemLabels { labeledItems[label] = append(labeledItems[label], int32(i)) } } - m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels)) + t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels)) // inverse document frequency of labels for i := range labeledItems { labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i]))) } - m.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems)) + t.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems)) } start := time.Now() var err error - if m.Config.Recommend.ItemNeighbors.EnableIndex { - err = m.findItemNeighborsIVF(dataset, labelIDF, userIDF, completed) + if t.Config.Recommend.ItemNeighbors.EnableIndex { + err = t.findItemNeighborsIVF(dataset, labelIDF, userIDF, completed, j) } else { - err = m.findItemNeighborsBruteForce(dataset, labeledItems, labelIDF, userIDF, completed) + err = t.findItemNeighborsBruteForce(dataset, labeledItems, labelIDF, userIDF, completed, j) } searchTime := time.Since(start) close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of items", zap.Error(err)) - m.taskMonitor.Fail(TaskFindItemNeighbors, err.Error()) + t.taskMonitor.Fail(TaskFindItemNeighbors, err.Error()) FindItemNeighborsTotalSeconds.Set(0) } else { - if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil { + if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil { log.Logger().Error("failed to set neighbors of items update time", zap.Error(err)) } log.Logger().Info("complete searching neighbors of items", zap.String("search_time", searchTime.String())) - m.taskMonitor.Finish(TaskFindItemNeighbors) + t.taskMonitor.Finish(TaskFindItemNeighbors) FindItemNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds()) } + + t.lastNumItems = numItems + t.lastNumFeedback = numFeedback + return nil } func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledItems [][]int32, - labelIDF, userIDF []float32, completed chan struct{}) error { + labelIDF, userIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateItemCount atomic.Float64 findNeighborSeconds atomic.Float64 @@ -314,7 +358,7 @@ func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledIt return errors.NotImplementedf("item neighbor type `%v`", m.Config.Recommend.ItemNeighbors.NeighborType) } - err := parallel.Parallel(dataset.ItemCount(), m.Config.Master.NumJobs, func(workerId, itemIndex int) error { + err := parallel.DynamicParallel(dataset.ItemCount(), j, func(workerId, itemIndex int) error { defer func() { completed <- struct{}{} }() @@ -372,7 +416,7 @@ func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledIt return nil } -func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userIDF []float32, completed chan struct{}) error { +func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateItemCount atomic.Float64 findNeighborSeconds atomic.Float64 @@ -400,7 +444,8 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID return errors.NotImplementedf("item neighbor type `%v`", m.Config.Recommend.ItemNeighbors.NeighborType) } - builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, search.SetIVFNumJobs(m.Config.Master.NumJobs)) + builder := 
search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, + search.SetIVFJobsAllocator(j)) var recall float32 index, recall = builder.Build(m.Config.Recommend.ItemNeighbors.IndexRecall, m.Config.Recommend.ItemNeighbors.IndexFitEpoch, @@ -412,7 +457,7 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID } buildIndexSeconds.Add(time.Since(buildStart).Seconds()) - err := parallel.Parallel(dataset.ItemCount(), m.Config.Master.NumJobs, func(workerId, itemIndex int) error { + err := parallel.DynamicParallel(dataset.ItemCount(), j, func(workerId, itemIndex int) error { defer func() { completed <- struct{}{} }() @@ -479,12 +524,44 @@ func (m *Master) estimateFindUserNeighborsComplexity(dataset *ranking.DataSet) i return complexity } -// runFindUserNeighborsTask updates neighbors of users. -func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) { +// FindUserNeighborsTask updates neighbors of users. +type FindUserNeighborsTask struct { + *Master + lastNumUsers int + lastNumFeedback int +} + +func NewFindUserNeighborsTask(m *Master) *FindUserNeighborsTask { + return &FindUserNeighborsTask{Master: m} +} + +func (t *FindUserNeighborsTask) name() string { + return TaskFindUserNeighbors +} + +func (t *FindUserNeighborsTask) priority() int { + return -t.rankingTrainSet.UserCount() * t.rankingTrainSet.UserCount() +} + +func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { + t.rankingDataMutex.RLock() + defer t.rankingDataMutex.RUnlock() + dataset := t.rankingTrainSet + numUsers := dataset.UserCount() + numFeedback := dataset.Count() + + if numUsers == 0 { + t.taskMonitor.Fail(TaskFindUserNeighbors, "No user found.") + return nil + } else if numUsers == t.lastNumUsers && numFeedback == t.lastNumFeedback { + log.Logger().Info("No update of user neighbors needed.") + return nil + } + startTaskTime := time.Now() - m.taskMonitor.Start(TaskFindUserNeighbors, m.estimateFindUserNeighborsComplexity(dataset)) + t.taskMonitor.Start(TaskFindUserNeighbors, t.estimateFindUserNeighborsComplexity(dataset)) log.Logger().Info("start searching neighbors of users", - zap.Int("n_cache", m.Config.Recommend.CacheSize)) + zap.Int("n_cache", t.Config.Recommend.CacheSize)) // create progress tracker completed := make(chan struct{}, 1000) go func() { @@ -501,7 +578,7 @@ func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) { throughput := completedCount - previousCount previousCount = completedCount if throughput > 0 { - m.taskMonitor.Add(TaskFindUserNeighbors, throughput*dataset.UserCount()) + t.taskMonitor.Add(TaskFindUserNeighbors, throughput*dataset.UserCount()) log.Logger().Debug("searching neighbors of users", zap.Int("n_complete_users", completedCount), zap.Int("n_users", dataset.UserCount()), @@ -512,62 +589,66 @@ func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) { }() itemIDF := make([]float32, dataset.ItemCount()) - if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeRelated || - m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { + if t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeRelated || + t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { for _, feedbacks := range dataset.UserFeedback { sort.Sort(sortutil.Int32Slice(feedbacks)) } - m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback)) + t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback)) // inverse document frequency of items for i := range dataset.ItemFeedback
{ itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i]))) } - m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback)) + t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback)) } labeledUsers := make([][]int32, dataset.NumUserLabels) labelIDF := make([]float32, dataset.NumUserLabels) - if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeSimilar || - m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { + if t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeSimilar || + t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { for i, userLabels := range dataset.UserLabels { sort.Sort(sortutil.Int32Slice(userLabels)) for _, label := range userLabels { labeledUsers[label] = append(labeledUsers[label], int32(i)) } } - m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels)) + t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels)) // inverse document frequency of labels for i := range labeledUsers { labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i]))) } - m.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers)) + t.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers)) } start := time.Now() var err error - if m.Config.Recommend.UserNeighbors.EnableIndex { - err = m.findUserNeighborsIVF(dataset, labelIDF, itemIDF, completed) + if t.Config.Recommend.UserNeighbors.EnableIndex { + err = t.findUserNeighborsIVF(dataset, labelIDF, itemIDF, completed, j) } else { - err = m.findUserNeighborsBruteForce(dataset, labeledUsers, labelIDF, itemIDF, completed) + err = t.findUserNeighborsBruteForce(dataset, labeledUsers, labelIDF, itemIDF, completed, j) } searchTime := time.Since(start) close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of users", zap.Error(err)) - m.taskMonitor.Fail(TaskFindUserNeighbors, err.Error()) + t.taskMonitor.Fail(TaskFindUserNeighbors, err.Error()) FindUserNeighborsTotalSeconds.Set(0) } else { - if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil { + if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil { log.Logger().Error("failed to set neighbors of users update time", zap.Error(err)) } log.Logger().Info("complete searching neighbors of users", zap.String("search_time", searchTime.String())) - m.taskMonitor.Finish(TaskFindUserNeighbors) + t.taskMonitor.Finish(TaskFindUserNeighbors) FindUserNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds()) } + + t.lastNumUsers = numUsers + t.lastNumFeedback = numFeedback + return nil } -func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}) error { +func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateUserCount atomic.Float64 findNeighborSeconds atomic.Float64 @@ -587,7 +668,7 @@ func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUs return errors.NotImplementedf("user neighbor type `%v`", m.Config.Recommend.UserNeighbors.NeighborType) } - err := parallel.Parallel(dataset.UserCount(), m.Config.Master.NumJobs, func(workerId, userIndex int) error { + err := parallel.DynamicParallel(dataset.UserCount(), j, 
func(workerId, userIndex int) error { defer func() { completed <- struct{}{} }() @@ -636,7 +717,7 @@ func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUs return nil } -func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}) error { +func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateUserCount atomic.Float64 buildIndexSeconds atomic.Float64 @@ -665,7 +746,8 @@ func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemID return errors.NotImplementedf("user neighbor type `%v`", m.Config.Recommend.UserNeighbors.NeighborType) } - builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, search.SetIVFNumJobs(m.Config.Master.NumJobs)) + builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, + search.SetIVFJobsAllocator(j)) var recall float32 index, recall = builder.Build( m.Config.Recommend.UserNeighbors.IndexRecall, @@ -678,7 +760,7 @@ func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemID } buildIndexSeconds.Add(time.Since(buildStart).Seconds()) - err := parallel.Parallel(dataset.UserCount(), m.Config.Master.NumJobs, func(workerId, userIndex int) error { + err := parallel.DynamicParallel(dataset.UserCount(), j, func(workerId, userIndex int) error { defer func() { completed <- struct{}{} }() @@ -848,319 +930,387 @@ func (m *Master) checkItemNeighborCacheTimeout(itemId string, categories []strin return updateTime.Unix() <= modifiedTime.Unix() } -// fitRankingModel fits ranking model using passed dataset. After model fitted, following states are changed: -// 1. Ranking model version are increased. -// 2. Ranking model score are updated. -// 3. Ranking model, version and score are persisted to local cache. 
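The neighbor tasks above now pass the *task.JobsAllocator received by run to parallel.DynamicParallel instead of calling parallel.Parallel with the fixed Config.Master.NumJobs, so the degree of parallelism comes from the jobs scheduler rather than from configuration. A minimal sketch of the call shape, assuming the master package context; exampleNeighborLoop and its worker body are invented for illustration:

// exampleNeighborLoop sketches how a task drives its per-item loop: the
// allocator handed to run, not a fixed job count, supplies the workers.
func exampleNeighborLoop(n int, j *task.JobsAllocator) error {
	return parallel.DynamicParallel(n, j, func(workerId, itemIndex int) error {
		// per-item neighbor computation goes here (placeholder)
		return nil
	})
}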
-func (m *Master) runRankingRelatedTasks( - lastNumUsers, lastNumItems, lastNumFeedback int, -) (numUsers, numItems, numFeedback int, err error) { - log.Logger().Info("start fitting ranking model", zap.Int("n_jobs", m.Config.Master.NumJobs)) - m.rankingDataMutex.RLock() - defer m.rankingDataMutex.RUnlock() - numUsers = m.rankingTrainSet.UserCount() - numItems = m.rankingTrainSet.ItemCount() - numFeedback = m.rankingTrainSet.Count() - - if numUsers == 0 && numItems == 0 && numFeedback == 0 { - log.Logger().Warn("empty ranking dataset", - zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes)) - return - } - numFeedbackChanged := numFeedback != lastNumFeedback - numUsersChanged := numUsers != lastNumUsers - numItemsChanged := numItems != lastNumItems +type FitRankingModelTask struct { + *Master + lastNumFeedback int +} + +func NewFitRankingModelTask(m *Master) *FitRankingModelTask { + return &FitRankingModelTask{Master: m} +} + +func (t *FitRankingModelTask) name() string { + return TaskFitRankingModel +} + +func (t *FitRankingModelTask) priority() int { + return -t.rankingTrainSet.Count() +} + +func (t *FitRankingModelTask) run(j *task.JobsAllocator) error { + t.rankingDataMutex.RLock() + defer t.rankingDataMutex.RUnlock() + dataset := t.rankingTrainSet + numFeedback := dataset.Count() var modelChanged bool - bestRankingName, bestRankingModel, bestRankingScore := m.rankingModelSearcher.GetBestModel() - m.rankingModelMutex.Lock() + bestRankingName, bestRankingModel, bestRankingScore := t.rankingModelSearcher.GetBestModel() + t.rankingModelMutex.Lock() if bestRankingModel != nil && !bestRankingModel.Invalid() && - (bestRankingName != m.rankingModelName || bestRankingModel.GetParams().ToString() != m.RankingModel.GetParams().ToString()) && - (bestRankingScore.NDCG > m.rankingScore.NDCG) { + (bestRankingName != t.rankingModelName || bestRankingModel.GetParams().ToString() != t.RankingModel.GetParams().ToString()) && + (bestRankingScore.NDCG > t.rankingScore.NDCG) { // 1. best ranking model must have been found. // 2. best ranking model must be different from current model // 3. 
best ranking model must perform better than current model - m.RankingModel = bestRankingModel - m.rankingModelName = bestRankingName - m.rankingScore = bestRankingScore + t.RankingModel = bestRankingModel + t.rankingModelName = bestRankingName + t.rankingScore = bestRankingScore modelChanged = true log.Logger().Info("find better ranking model", zap.Any("score", bestRankingScore), zap.String("name", bestRankingName), - zap.Any("params", m.RankingModel.GetParams())) + zap.Any("params", t.RankingModel.GetParams())) } - rankingModel := m.RankingModel - m.rankingModelMutex.Unlock() + rankingModel := ranking.Clone(t.RankingModel) + t.rankingModelMutex.Unlock() - // collect neighbors of items - if numItems == 0 { - m.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.") - } else if numItemsChanged || numFeedbackChanged { - m.runFindItemNeighborsTask(m.rankingTrainSet) - } - // collect neighbors of users - if numUsers == 0 { - m.taskMonitor.Fail(TaskFindUserNeighbors, "No user found.") - } else if numUsersChanged || numFeedbackChanged { - m.runFindUserNeighborsTask(m.rankingTrainSet) - } - - // training model if numFeedback == 0 { - m.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.") - return - } else if !numFeedbackChanged && !modelChanged { + t.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.") + return nil + } else if numFeedback == t.lastNumFeedback && !modelChanged { log.Logger().Info("nothing changed") - return + return nil } - m.runFitRankingModelTask(rankingModel) - return -} -func (m *Master) runFitRankingModelTask(rankingModel ranking.MatrixFactorization) { startFitTime := time.Now() - score := rankingModel.Fit(m.rankingTrainSet, m.rankingTestSet, ranking.NewFitConfig(). - SetJobs(m.Config.Master.NumJobs). - SetTask(m.taskMonitor.Start(TaskFitRankingModel, rankingModel.Complexity()))) + score := rankingModel.Fit(t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig(). + SetJobsAllocator(j). 
+ SetTask(t.taskMonitor.Start(TaskFitRankingModel, rankingModel.Complexity()))) CollaborativeFilteringFitSeconds.Set(time.Since(startFitTime).Seconds()) // update ranking model - m.rankingModelMutex.Lock() - m.RankingModel = rankingModel - m.RankingModelVersion++ - m.rankingScore = score - m.rankingModelMutex.Unlock() + t.rankingModelMutex.Lock() + t.RankingModel = rankingModel + t.RankingModelVersion++ + t.rankingScore = score + t.rankingModelMutex.Unlock() log.Logger().Info("fit ranking model complete", - zap.String("version", fmt.Sprintf("%x", m.RankingModelVersion))) + zap.String("version", fmt.Sprintf("%x", t.RankingModelVersion))) CollaborativeFilteringNDCG10.Set(float64(score.NDCG)) CollaborativeFilteringRecall10.Set(float64(score.Recall)) CollaborativeFilteringPrecision10.Set(float64(score.Precision)) - MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes())) - if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime), time.Now())); err != nil { + MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(t.RankingModel.Bytes())) + if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime), time.Now())); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) } // caching model - m.rankingModelMutex.RLock() - m.localCache.RankingModelName = m.rankingModelName - m.localCache.RankingModelVersion = m.RankingModelVersion - m.localCache.RankingModel = rankingModel - m.localCache.RankingModelScore = score - m.rankingModelMutex.RUnlock() - if m.localCache.ClickModel == nil || m.localCache.ClickModel.Invalid() { + t.rankingModelMutex.RLock() + t.localCache.RankingModelName = t.rankingModelName + t.localCache.RankingModelVersion = t.RankingModelVersion + t.localCache.RankingModel = rankingModel + t.localCache.RankingModelScore = score + t.rankingModelMutex.RUnlock() + if t.localCache.ClickModel == nil || t.localCache.ClickModel.Invalid() { log.Logger().Info("wait click model") - } else if err := m.localCache.WriteLocalCache(); err != nil { + } else if err := t.localCache.WriteLocalCache(); err != nil { log.Logger().Error("failed to write local cache", zap.Error(err)) } else { log.Logger().Info("write model to local cache", - zap.String("ranking_model_name", m.localCache.RankingModelName), - zap.String("ranking_model_version", encoding.Hex(m.localCache.RankingModelVersion)), - zap.Float32("ranking_model_score", m.localCache.RankingModelScore.NDCG), - zap.Any("ranking_model_params", m.localCache.RankingModel.GetParams())) + zap.String("ranking_model_name", t.localCache.RankingModelName), + zap.String("ranking_model_version", encoding.Hex(t.localCache.RankingModelVersion)), + zap.Float32("ranking_model_score", t.localCache.RankingModelScore.NDCG), + zap.Any("ranking_model_params", t.localCache.RankingModel.GetParams())) } - m.taskMonitor.Finish(TaskFitRankingModel) + t.taskMonitor.Finish(TaskFitRankingModel) + t.lastNumFeedback = numFeedback + return nil } -// runFitClickModelTask fits click model using latest data. After model fitted, following states are changed: +// FitClickModelTask fits click model using latest data. After model fitted, following states are changed: // 1. Click model version are increased. // 2. Click model score are updated. // 3. Click model, version and score are persisted to local cache. 
-func (m *Master) runFitClickModelTask( - lastNumUsers, lastNumItems, lastNumFeedback int, -) (numUsers, numItems, numFeedback int, err error) { - log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", m.Config.Master.NumJobs)) - m.clickDataMutex.RLock() - defer m.clickDataMutex.RUnlock() - numUsers = m.clickTrainSet.UserCount() - numItems = m.clickTrainSet.ItemCount() - numFeedback = m.clickTrainSet.Count() +type FitClickModelTask struct { + *Master + lastNumUsers int + lastNumItems int + lastNumFeedback int +} + +func NewFitClickModelTask(m *Master) *FitClickModelTask { + return &FitClickModelTask{Master: m} +} + +func (t *FitClickModelTask) name() string { + return TaskFitClickModel +} + +func (t *FitClickModelTask) priority() int { + return -t.clickTrainSet.Count() +} + +func (t *FitClickModelTask) run(j *task.JobsAllocator) error { + log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", t.Config.Master.NumJobs)) + t.clickDataMutex.RLock() + defer t.clickDataMutex.RUnlock() + numUsers := t.clickTrainSet.UserCount() + numItems := t.clickTrainSet.ItemCount() + numFeedback := t.clickTrainSet.Count() var shouldFit bool - if numUsers == 0 || numItems == 0 || numFeedback == 0 { + if t.clickTrainSet == nil || numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty ranking dataset", - zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes)) - m.taskMonitor.Fail(TaskFitClickModel, "No feedback found.") - return - } else if numUsers != lastNumUsers || - numItems != lastNumItems || - numFeedback != lastNumFeedback { + zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) + t.taskMonitor.Fail(TaskFitClickModel, "No feedback found.") + return nil + } else if numUsers != t.lastNumUsers || + numItems != t.lastNumItems || + numFeedback != t.lastNumFeedback { shouldFit = true } - bestClickModel, bestClickScore := m.clickModelSearcher.GetBestModel() - m.clickModelMutex.Lock() + bestClickModel, bestClickScore := t.clickModelSearcher.GetBestModel() + t.clickModelMutex.Lock() if bestClickModel != nil && !bestClickModel.Invalid() && - bestClickModel.GetParams().ToString() != m.ClickModel.GetParams().ToString() && - bestClickScore.Precision > m.clickScore.Precision { + bestClickModel.GetParams().ToString() != t.ClickModel.GetParams().ToString() && + bestClickScore.Precision > t.clickScore.Precision { // 1. best click model must have been found. // 2. best click model must be different from current model // 3. best click model must perform better than current model - m.ClickModel = bestClickModel - m.clickScore = bestClickScore + t.ClickModel = bestClickModel + t.clickScore = bestClickScore shouldFit = true log.Logger().Info("find better click model", zap.Float32("Precision", bestClickScore.Precision), zap.Float32("Recall", bestClickScore.Recall), - zap.Any("params", m.ClickModel.GetParams())) + zap.Any("params", t.ClickModel.GetParams())) } - clickModel := m.ClickModel - m.clickModelMutex.Unlock() + clickModel := click.Clone(t.ClickModel) + t.clickModelMutex.Unlock() // training model if !shouldFit { log.Logger().Info("nothing changed") - return + return nil } startFitTime := time.Now() - score := clickModel.Fit(m.clickTrainSet, m.clickTestSet, click.NewFitConfig(). - SetJobs(m.Config.Master.NumJobs). - SetTask(m.taskMonitor.Start(TaskFitClickModel, clickModel.Complexity()))) + score := clickModel.Fit(t.clickTrainSet, t.clickTestSet, click.NewFitConfig(). + SetJobsAllocator(j). 
+ SetTask(t.taskMonitor.Start(TaskFitClickModel, clickModel.Complexity()))) RankingFitSeconds.Set(time.Since(startFitTime).Seconds()) // update match model - m.clickModelMutex.Lock() - m.ClickModel = clickModel - m.clickScore = score - m.ClickModelVersion++ - m.clickModelMutex.Unlock() + t.clickModelMutex.Lock() + t.ClickModel = clickModel + t.clickScore = score + t.ClickModelVersion++ + t.clickModelMutex.Unlock() log.Logger().Info("fit click model complete", - zap.String("version", fmt.Sprintf("%x", m.ClickModelVersion))) + zap.String("version", fmt.Sprintf("%x", t.ClickModelVersion))) RankingPrecision.Set(float64(score.Precision)) RankingRecall.Set(float64(score.Recall)) RankingAUC.Set(float64(score.AUC)) - MemoryInuseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes())) - if err = m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime), time.Now())); err != nil { + MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(t.ClickModel.Bytes())) + if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime), time.Now())); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) } // caching model - m.clickModelMutex.RLock() - m.localCache.ClickModelScore = m.clickScore - m.localCache.ClickModelVersion = m.ClickModelVersion - m.localCache.ClickModel = m.ClickModel - m.clickModelMutex.RUnlock() - if m.localCache.RankingModel == nil || m.localCache.RankingModel.Invalid() { + t.clickModelMutex.RLock() + t.localCache.ClickModelScore = t.clickScore + t.localCache.ClickModelVersion = t.ClickModelVersion + t.localCache.ClickModel = t.ClickModel + t.clickModelMutex.RUnlock() + if t.localCache.RankingModel == nil || t.localCache.RankingModel.Invalid() { log.Logger().Info("wait ranking model") - } else if err = m.localCache.WriteLocalCache(); err != nil { + } else if err := t.localCache.WriteLocalCache(); err != nil { log.Logger().Error("failed to write local cache", zap.Error(err)) } else { log.Logger().Info("write model to local cache", - zap.String("click_model_version", encoding.Hex(m.localCache.ClickModelVersion)), + zap.String("click_model_version", encoding.Hex(t.localCache.ClickModelVersion)), zap.Float32("click_model_score", score.Precision), - zap.Any("click_model_params", m.localCache.ClickModel.GetParams())) + zap.Any("click_model_params", t.localCache.ClickModel.GetParams())) } - m.taskMonitor.Finish(TaskFitClickModel) - return + t.taskMonitor.Finish(TaskFitClickModel) + t.lastNumItems = numItems + t.lastNumUsers = numUsers + t.lastNumFeedback = numFeedback + return nil } -// runSearchRankingModelTask searches best hyper-parameters for ranking models. +// SearchRankingModelTask searches best hyper-parameters for ranking models. // It requires read lock on the ranking dataset. 
-func (m *Master) runSearchRankingModelTask( - lastNumUsers, lastNumItems, lastNumFeedback int, -) (numUsers, numItems, numFeedback int, err error) { +type SearchRankingModelTask struct { + *Master + lastNumUsers int + lastNumItems int + lastNumFeedback int +} + +func NewSearchRankingModelTask(m *Master) *SearchRankingModelTask { + return &SearchRankingModelTask{Master: m} +} + +func (t *SearchRankingModelTask) name() string { + return TaskSearchRankingModel +} + +func (t *SearchRankingModelTask) priority() int { + return -t.rankingTrainSet.Count() +} + +func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error { log.Logger().Info("start searching ranking model") - m.rankingDataMutex.RLock() - defer m.rankingDataMutex.RUnlock() - numUsers = m.rankingTrainSet.UserCount() - numItems = m.rankingTrainSet.ItemCount() - numFeedback = m.rankingTrainSet.Count() + t.rankingDataMutex.RLock() + defer t.rankingDataMutex.RUnlock() + if t.rankingTrainSet == nil { + log.Logger().Debug("dataset has not been loaded") + return nil + } + numUsers := t.rankingTrainSet.UserCount() + numItems := t.rankingTrainSet.ItemCount() + numFeedback := t.rankingTrainSet.Count() if numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty ranking dataset", - zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes)) - m.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.") - return - } else if numUsers == lastNumUsers && - numItems == lastNumItems && - numFeedback == lastNumFeedback { + zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) + t.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.") + return nil + } else if numUsers == t.lastNumUsers && + numItems == t.lastNumItems && + numFeedback == t.lastNumFeedback { log.Logger().Info("ranking dataset not changed") - return + return nil } startTime := time.Now() - err = m.rankingModelSearcher.Fit(m.rankingTrainSet, m.rankingTestSet, - m.taskMonitor.Start(TaskSearchRankingModel, m.rankingModelSearcher.Complexity()), - m.taskScheduler.NewRunner(TaskSearchRankingModel)) + err := t.rankingModelSearcher.Fit(t.rankingTrainSet, t.rankingTestSet, + t.taskMonitor.Start(TaskSearchRankingModel, t.rankingModelSearcher.Complexity()), j) if err != nil { log.Logger().Error("failed to search collaborative filtering model", zap.Error(err)) - return + return nil } CollaborativeFilteringSearchSeconds.Set(time.Since(startTime).Seconds()) - _, _, bestScore := m.rankingModelSearcher.GetBestModel() + _, _, bestScore := t.rankingModelSearcher.GetBestModel() CollaborativeFilteringSearchPrecision10.Set(float64(bestScore.Precision)) - m.taskMonitor.Finish(TaskSearchRankingModel) - return + t.taskMonitor.Finish(TaskSearchRankingModel) + t.lastNumItems = numItems + t.lastNumUsers = numUsers + t.lastNumFeedback = numFeedback + return nil } -// runSearchClickModelTask searches best hyper-parameters for factorization machines. +// SearchClickModelTask searches best hyper-parameters for factorization machines. // It requires read lock on the click dataset. 
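Each refactored job in this file, the search tasks included, implements the Task interface declared near the top of master/tasks.go: name() string, priority() int and run(j *task.JobsAllocator) error. A bare-bones sketch of that shape, assuming the master package context; ExampleTask, its name string and its priority formula are illustrative only:

// ExampleTask is an illustrative task, not part of this patch.
type ExampleTask struct {
	*Master // concrete tasks embed the master node to reach datasets and config
}

func (t *ExampleTask) name() string { return "ExampleTask" }

// priority follows the pattern of the concrete tasks: the negated size of the workload.
func (t *ExampleTask) priority() int { return -t.rankingTrainSet.Count() }

// run receives a JobsAllocator from the jobs scheduler instead of a fixed job count.
func (t *ExampleTask) run(j *task.JobsAllocator) error {
	return nil
}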
-func (m *Master) runSearchClickModelTask( - lastNumUsers, lastNumItems, lastNumFeedback int, -) (numUsers, numItems, numFeedback int, err error) { +type SearchClickModelTask struct { + *Master + lastNumUsers int + lastNumItems int + lastNumFeedback int +} + +func NewSearchClickModelTask(m *Master) *SearchClickModelTask { + return &SearchClickModelTask{Master: m} +} + +func (t *SearchClickModelTask) name() string { + return TaskSearchClickModel +} + +func (t *SearchClickModelTask) priority() int { + return -t.clickTrainSet.Count() +} + +func (t *SearchClickModelTask) run(j *task.JobsAllocator) error { log.Logger().Info("start searching click model") - m.clickDataMutex.RLock() - defer m.clickDataMutex.RUnlock() - numUsers = m.clickTrainSet.UserCount() - numItems = m.clickTrainSet.ItemCount() - numFeedback = m.clickTrainSet.Count() + t.clickDataMutex.RLock() + defer t.clickDataMutex.RUnlock() + if t.clickTrainSet == nil { + log.Logger().Debug("dataset has not been loaded") + return nil + } + numUsers := t.clickTrainSet.UserCount() + numItems := t.clickTrainSet.ItemCount() + numFeedback := t.clickTrainSet.Count() if numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty click dataset", - zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes)) - m.taskMonitor.Fail(TaskSearchClickModel, "No feedback found.") - return - } else if numUsers == lastNumUsers && - numItems == lastNumItems && - numFeedback == lastNumFeedback { + zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) + t.taskMonitor.Fail(TaskSearchClickModel, "No feedback found.") + return nil + } else if numUsers == t.lastNumUsers && + numItems == t.lastNumItems && + numFeedback == t.lastNumFeedback { log.Logger().Info("click dataset not changed") - return + return nil } startTime := time.Now() - err = m.clickModelSearcher.Fit(m.clickTrainSet, m.clickTestSet, - m.taskMonitor.Start(TaskSearchClickModel, m.clickModelSearcher.Complexity()), - m.taskScheduler.NewRunner(TaskSearchClickModel)) + err := t.clickModelSearcher.Fit(t.clickTrainSet, t.clickTestSet, + t.taskMonitor.Start(TaskSearchClickModel, t.clickModelSearcher.Complexity()), j) if err != nil { log.Logger().Error("failed to search ranking model", zap.Error(err)) - return + return nil } RankingSearchSeconds.Set(time.Since(startTime).Seconds()) - _, bestScore := m.clickModelSearcher.GetBestModel() + _, bestScore := t.clickModelSearcher.GetBestModel() RankingSearchPrecision.Set(float64(bestScore.Precision)) - m.taskMonitor.Finish(TaskSearchClickModel) - return + t.taskMonitor.Finish(TaskSearchClickModel) + t.lastNumItems = numItems + t.lastNumUsers = numUsers + t.lastNumFeedback = numFeedback + return nil +} + +type CacheGarbageCollectionTask struct { + *Master +} + +func NewCacheGarbageCollectionTask(m *Master) *CacheGarbageCollectionTask { + return &CacheGarbageCollectionTask{m} +} + +func (t *CacheGarbageCollectionTask) name() string { + return TaskCacheGarbageCollection +} + +func (t *CacheGarbageCollectionTask) priority() int { + return -t.rankingTrainSet.UserCount() - t.rankingTrainSet.ItemCount() } -func (m *Master) runCacheGarbageCollectionTask() error { - if m.rankingTrainSet == nil { +func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error { + if t.rankingTrainSet == nil { + log.Logger().Debug("dataset has not been loaded") return nil } + log.Logger().Info("start cache garbage collection") - m.taskMonitor.Start(TaskCacheGarbageCollection, 
m.rankingTrainSet.UserCount()*9+m.rankingTrainSet.ItemCount()*4) + t.taskMonitor.Start(TaskCacheGarbageCollection, t.rankingTrainSet.UserCount()*9+t.rankingTrainSet.ItemCount()*4) var scanCount, reclaimCount int start := time.Now() - err := m.CacheClient.Scan(func(s string) error { + err := t.CacheClient.Scan(func(s string) error { splits := strings.Split(s, "/") if len(splits) <= 1 { return nil } scanCount++ - m.taskMonitor.Update(TaskCacheGarbageCollection, scanCount) + t.taskMonitor.Update(TaskCacheGarbageCollection, scanCount) switch splits[0] { case cache.UserNeighbors, cache.UserNeighborsDigest, cache.IgnoreItems, cache.OfflineRecommend, cache.OfflineRecommendDigest, cache.CollaborativeRecommend, cache.LastModifyUserTime, cache.LastUpdateUserNeighborsTime, cache.LastUpdateUserRecommendTime: userId := splits[1] // check user in dataset - if m.rankingTrainSet != nil && m.rankingTrainSet.UserIndex.ToNumber(userId) != base.NotId { + if t.rankingTrainSet != nil && t.rankingTrainSet.UserIndex.ToNumber(userId) != base.NotId { return nil } // check user in database - _, err := m.DataClient.GetUser(userId) + _, err := t.DataClient.GetUser(userId) if !errors.Is(err, errors.NotFound) { if err != nil { log.Logger().Error("failed to load user", zap.String("user_id", userId), zap.Error(err)) @@ -1170,10 +1320,10 @@ func (m *Master) runCacheGarbageCollectionTask() error { // delete user cache switch splits[0] { case cache.UserNeighbors, cache.IgnoreItems, cache.CollaborativeRecommend, cache.OfflineRecommend: - err = m.CacheClient.SetSorted(s, nil) + err = t.CacheClient.SetSorted(s, nil) case cache.UserNeighborsDigest, cache.OfflineRecommendDigest, cache.LastModifyUserTime, cache.LastUpdateUserNeighborsTime, cache.LastUpdateUserRecommendTime: - err = m.CacheClient.Delete(s) + err = t.CacheClient.Delete(s) } if err != nil { return errors.Trace(err) @@ -1182,11 +1332,11 @@ func (m *Master) runCacheGarbageCollectionTask() error { case cache.ItemNeighbors, cache.ItemNeighborsDigest, cache.LastModifyItemTime, cache.LastUpdateItemNeighborsTime: itemId := splits[1] // check item in dataset - if m.rankingTrainSet != nil && m.rankingTrainSet.ItemIndex.ToNumber(itemId) != base.NotId { + if t.rankingTrainSet != nil && t.rankingTrainSet.ItemIndex.ToNumber(itemId) != base.NotId { return nil } // check item in database - _, err := m.DataClient.GetItem(itemId) + _, err := t.DataClient.GetItem(itemId) if !errors.Is(err, errors.NotFound) { if err != nil { log.Logger().Error("failed to load item", zap.String("item_id", itemId), zap.Error(err)) @@ -1196,9 +1346,9 @@ func (m *Master) runCacheGarbageCollectionTask() error { // delete item cache switch splits[0] { case cache.ItemNeighbors: - err = m.CacheClient.SetSorted(s, nil) + err = t.CacheClient.SetSorted(s, nil) case cache.ItemNeighborsDigest, cache.LastModifyItemTime, cache.LastUpdateItemNeighborsTime: - err = m.CacheClient.Delete(s) + err = t.CacheClient.Delete(s) } if err != nil { return errors.Trace(err) @@ -1208,10 +1358,10 @@ func (m *Master) runCacheGarbageCollectionTask() error { return nil }) // remove stale hidden items - if err := m.CacheClient.RemSortedByScore(cache.HiddenItemsV2, math.Inf(-1), float64(time.Now().Add(-m.Config.Recommend.CacheExpire).Unix())); err != nil { + if err := t.CacheClient.RemSortedByScore(cache.HiddenItemsV2, math.Inf(-1), float64(time.Now().Add(-t.Config.Recommend.CacheExpire).Unix())); err != nil { return errors.Trace(err) } - m.taskMonitor.Finish(TaskCacheGarbageCollection) + t.taskMonitor.Finish(TaskCacheGarbageCollection) 
CacheScannedTotal.Set(float64(scanCount)) CacheReclaimedTotal.Set(float64(reclaimCount)) CacheScannedSeconds.Set(time.Since(start).Seconds()) diff --git a/master/tasks_test.go b/master/tasks_test.go index 86a9cf095..5951500a6 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -15,15 +15,16 @@ package master import ( + "strconv" + "testing" + "time" + "github.com/juju/errors" "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" - "strconv" - "testing" - "time" ) func TestMaster_FindItemNeighborsBruteForce(t *testing.T) { @@ -85,10 +86,12 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) { // load mock dataset dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) assert.NoError(t, err) + m.rankingTrainSet = dataset // similar items (common users) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated - m.runFindItemNeighborsTask(dataset) + neighborTask := NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "9"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar)) @@ -103,7 +106,8 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now())) assert.NoError(t, err) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar - m.runFindItemNeighborsTask(dataset) + neighborTask = NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -120,7 +124,8 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "9"), time.Now())) assert.NoError(t, err) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto - m.runFindItemNeighborsTask(dataset) + neighborTask = NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -193,10 +198,12 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) { // load mock dataset dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) assert.NoError(t, err) + m.rankingTrainSet = dataset // similar items (common users) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated - m.runFindItemNeighborsTask(dataset) + neighborTask := NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "9"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar)) @@ -211,7 +218,8 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now())) assert.NoError(t, err) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar - m.runFindItemNeighborsTask(dataset) + neighborTask = 
NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -228,7 +236,8 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "9"), time.Now())) assert.NoError(t, err) m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto - m.runFindItemNeighborsTask(dataset) + neighborTask = NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -282,10 +291,12 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) { assert.NoError(t, err) dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) assert.NoError(t, err) + m.rankingTrainSet = dataset // similar items (common users) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated - m.runFindUserNeighborsTask(dataset) + neighborTask := NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "9"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar)) @@ -296,7 +307,8 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) assert.NoError(t, err) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar - m.runFindUserNeighborsTask(dataset) + neighborTask = NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -309,7 +321,8 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "9"), time.Now())) assert.NoError(t, err) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto - m.runFindUserNeighborsTask(dataset) + neighborTask = NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -366,10 +379,12 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) { assert.NoError(t, err) dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) assert.NoError(t, err) + m.rankingTrainSet = dataset // similar items (common users) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated - m.runFindUserNeighborsTask(dataset) + neighborTask := NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "9"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar)) @@ -380,7 +395,8 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) assert.NoError(t, 
err) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar - m.runFindUserNeighborsTask(dataset) + neighborTask = NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -393,7 +409,8 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) { err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "9"), time.Now())) assert.NoError(t, err) m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto - m.runFindUserNeighborsTask(dataset) + neighborTask = NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100) assert.NoError(t, err) assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar)) @@ -673,7 +690,8 @@ func TestRunCacheGarbageCollectionTask(t *testing.T) { // remove cache assert.NotNil(t, m.rankingTrainSet) - err = m.runCacheGarbageCollectionTask() + gcTask := NewCacheGarbageCollectionTask(&m.Master) + err = gcTask.run(nil) assert.NoError(t, err) var s string diff --git a/misc/csv_test/feedback.csv b/misc/csv_test/feedback.csv deleted file mode 100644 index c0ee0fab3..000000000 --- a/misc/csv_test/feedback.csv +++ /dev/null @@ -1,6 +0,0 @@ -UserId,ItemId -0,0,0 -1,2,3 -2,4,6 -3,6,9 -4,8,12 \ No newline at end of file diff --git a/misc/csv_test/item_date.csv b/misc/csv_test/item_date.csv deleted file mode 100644 index 0df23062a..000000000 --- a/misc/csv_test/item_date.csv +++ /dev/null @@ -1,5 +0,0 @@ -1,2020-1-1,a|b|c -2,2020-2-1,e|f|g -3,2020-3-1,a|b|c -4,2020-4-1,e|f|g -5,2020-5-1,a|b|c \ No newline at end of file diff --git a/misc/csv_test/items.csv b/misc/csv_test/items.csv deleted file mode 100644 index ee45216a5..000000000 --- a/misc/csv_test/items.csv +++ /dev/null @@ -1,5 +0,0 @@ -1::Toy Story (1995)::Animation|Children's|Comedy -2::Jumanji (1995)::Adventure|Children's|Fantasy -3::Grumpier Old Men (1995)::Comedy|Romance -4::Waiting to Exhale (1995)::Comedy|Drama -5::Father of the Bride Part II (1995)::Comedy \ No newline at end of file diff --git a/misc/csv_test/items_header.csv b/misc/csv_test/items_header.csv deleted file mode 100644 index 7817fa0c0..000000000 --- a/misc/csv_test/items_header.csv +++ /dev/null @@ -1,6 +0,0 @@ -ItemId::Title::Genres -1::Toy Story (1995)::Animation|Children's|Comedy -2::Jumanji (1995)::Adventure|Children's|Fantasy -3::Grumpier Old Men (1995)::Comedy|Romance -4::Waiting to Exhale (1995)::Comedy|Drama -5::Father of the Bride Part II (1995)::Comedy \ No newline at end of file diff --git a/misc/csv_test/items_id_only.csv b/misc/csv_test/items_id_only.csv deleted file mode 100644 index 8a1218a10..000000000 --- a/misc/csv_test/items_id_only.csv +++ /dev/null @@ -1,5 +0,0 @@ -1 -2 -3 -4 -5 diff --git a/model/click/model.go b/model/click/model.go index dc93b7c3c..8e4a8943f 100644 --- a/model/click/model.go +++ b/model/click/model.go @@ -17,6 +17,10 @@ package click import ( "encoding/binary" "fmt" + "io" + "reflect" + "time" + "github.com/chewxy/math32" "github.com/juju/errors" "github.com/samber/lo" @@ -30,9 +34,6 @@ import ( "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" - "io" - "reflect" - "time" ) type Score struct { @@ -91,14 +92,13 @@ func (score Score) BetterThan(s Score) bool { } type FitConfig struct { - Jobs int + *task.JobsAllocator 
Verbose int Task *task.Task } func NewFitConfig() *FitConfig { return &FitConfig{ - Jobs: 1, Verbose: 10, } } @@ -108,8 +108,8 @@ func (config *FitConfig) SetVerbose(verbose int) *FitConfig { return config } -func (config *FitConfig) SetJobs(nJobs int) *FitConfig { - config.Jobs = nJobs +func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitConfig { + config.JobsAllocator = allocator return config } @@ -280,8 +280,9 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { zap.Any("params", fm.GetParams()), zap.Any("config", config)) fm.Init(trainSet) - temp := base.NewMatrix32(config.Jobs, fm.nFactors) - vGrad := base.NewMatrix32(config.Jobs, fm.nFactors) + maxJobs := config.MaxJobs() + temp := base.NewMatrix32(maxJobs, fm.nFactors) + vGrad := base.NewMatrix32(maxJobs, fm.nFactors) snapshots := SnapshotManger{} evalStart := time.Now() @@ -306,7 +307,7 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { } fitStart := time.Now() cost := float32(0) - _ = parallel.BatchParallel(trainSet.Count(), config.Jobs, 128, func(workerId, beginJobId, endJobId int) error { + _ = parallel.BatchParallel(trainSet.Count(), config.AvailableJobs(config.Task), 128, func(workerId, beginJobId, endJobId int) error { for i := beginJobId; i < endJobId; i++ { features, values, target := trainSet.Get(i) prediction := fm.internalPredictImpl(features, values) diff --git a/model/click/model_test.go b/model/click/model_test.go index e84599bda..c1d965ec0 100644 --- a/model/click/model_test.go +++ b/model/click/model_test.go @@ -30,7 +30,7 @@ func newFitConfigWithTestTracker(numEpoch int) *FitConfig { t := task.NewTask("test", numEpoch) cfg := NewFitConfig(). SetVerbose(1). - SetJobs(1). + SetJobsAllocator(task.NewConstantJobsAllocator(1)). SetTask(t) return cfg } diff --git a/model/click/search.go b/model/click/search.go index 8de6a379a..da4a673bc 100644 --- a/model/click/search.go +++ b/model/click/search.go @@ -16,13 +16,14 @@ package click import ( "fmt" + "sync" + "time" + "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" - "sync" - "time" ) // ParamsSearchResult contains the return of grid search. @@ -37,7 +38,7 @@ type ParamsSearchResult struct { // GridSearchCV finds the best parameters for a model. func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, - _ int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult { + _ int64, fitConfig *FitConfig) ParamsSearchResult { // Retrieve parameter names and length paramNames := make([]model.ParamName, 0, len(paramGrid)) count := 1 @@ -60,11 +61,7 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da // Cross validate estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - fitConfig.Task.Suspend(true) - runner.Lock() - fitConfig.Task.Suspend(false) score := estimator.Fit(trainSet, testSet, fitConfig) - runner.UnLock() // Create GridSearch result results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) @@ -90,10 +87,10 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da // RandomSearchCV searches hyper-parameters by random. 
func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, - numTrials int, seed int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult { + numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult { // if the number of combination is less than number of trials, use grid search if paramGrid.NumCombinations() <= numTrials { - return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig, runner) + return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig) } rng := base.NewRandomGenerator(seed) results := ParamsSearchResult{ @@ -112,11 +109,7 @@ func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet * zap.Any("params", params)) estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - fitConfig.Task.Suspend(true) - runner.Lock() - fitConfig.Task.Suspend(false) score := estimator.Fit(trainSet, testSet, fitConfig) - runner.UnLock() results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) if len(results.Scores) == 0 || score.BetterThan(results.BestScore) { @@ -135,7 +128,6 @@ type ModelSearcher struct { // arguments numEpochs int numTrials int - numJobs int searchSize bool // results bestMutex sync.Mutex @@ -144,12 +136,11 @@ type ModelSearcher struct { } // NewModelSearcher creates a thread-safe personal ranking model searcher. -func NewModelSearcher(nEpoch, nTrials, nJobs int, searchSize bool) *ModelSearcher { +func NewModelSearcher(nEpoch, nTrials int, searchSize bool) *ModelSearcher { return &ModelSearcher{ model: NewFM(FMClassification, model.Params{model.NEpochs: nEpoch}), numTrials: nTrials, numEpochs: nEpoch, - numJobs: nJobs, searchSize: searchSize, } } @@ -165,7 +156,7 @@ func (searcher *ModelSearcher) Complexity() int { return searcher.numTrials * searcher.numEpochs } -func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, runner model.Runner) error { +func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, j *task.JobsAllocator) error { log.Logger().Info("click model search", zap.Int("n_users", trainSet.UserCount()), zap.Int("n_items", trainSet.ItemCount()), @@ -176,8 +167,8 @@ func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, runn // Random search grid := searcher.model.GetParamsGrid(searcher.searchSize) r := RandomSearchCV(searcher.model, trainSet, valSet, grid, searcher.numTrials, 0, NewFitConfig(). - SetJobs(searcher.numJobs). - SetTask(t), runner) + SetJobsAllocator(j). 
+ SetTask(t)) searcher.bestMutex.Lock() defer searcher.bestMutex.Unlock() searcher.bestModel = r.BestModel diff --git a/model/click/search_test.go b/model/click/search_test.go index 00ac61922..3092d825e 100644 --- a/model/click/search_test.go +++ b/model/click/search_test.go @@ -15,7 +15,6 @@ package click import ( "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" @@ -87,22 +86,9 @@ func (m *mockFactorizationMachineForSearch) GetParamsGrid(_ bool) model.ParamsGr } } -type mockRunner struct { - mock.Mock -} - -func (r *mockRunner) Lock() { - r.Called() -} - -func (r *mockRunner) UnLock() { - r.Called() -} - func newFitConfigForSearch() *FitConfig { t := task.NewTask("test", 0) return &FitConfig{ - Jobs: 1, Verbose: 1, Task: t, } @@ -111,13 +97,8 @@ func newFitConfigForSearch() *FitConfig { func TestGridSearchCV(t *testing.T) { m := &mockFactorizationMachineForSearch{} fitConfig := newFitConfigForSearch() - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig, runner) + r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.AUC) - runner.AssertCalled(t, "Lock") - runner.AssertCalled(t, "UnLock") assert.Equal(t, model.Params{ model.NFactors: 4, model.InitMean: 4, @@ -128,12 +109,7 @@ func TestGridSearchCV(t *testing.T) { func TestRandomSearchCV(t *testing.T) { m := &mockFactorizationMachineForSearch{} fitConfig := newFitConfigForSearch() - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig, runner) - runner.AssertCalled(t, "Lock") - runner.AssertCalled(t, "UnLock") + r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.AUC) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -143,13 +119,10 @@ func TestRandomSearchCV(t *testing.T) { } func TestModelSearcher_RandomSearch(t *testing.T) { - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - searcher := NewModelSearcher(2, 63, 1, false) + searcher := NewModelSearcher(2, 63, false) searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}} tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner) + err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) assert.NoError(t, err) m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.AUC) @@ -163,13 +136,10 @@ func TestModelSearcher_RandomSearch(t *testing.T) { } func TestModelSearcher_GridSearch(t *testing.T) { - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - searcher := NewModelSearcher(2, 64, 1, false) + searcher := NewModelSearcher(2, 64, false) searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}} tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner) + err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) assert.NoError(t, err) m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.AUC) diff --git a/model/model.go b/model/model.go index 3bb1b0371..a1f731a4d 100644 --- 
a/model/model.go +++ b/model/model.go @@ -51,8 +51,3 @@ func (model *BaseModel) GetParams() Params { func (model *BaseModel) GetRandomGenerator() base.RandomGenerator { return model.rng } - -type Runner interface { - Lock() - UnLock() -} diff --git a/model/ranking/data.go b/model/ranking/data.go index 0f3f140a1..fe0097765 100644 --- a/model/ranking/data.go +++ b/model/ranking/data.go @@ -267,49 +267,6 @@ func (dataset *DataSet) GetIndex(i int) (int32, int32) { return dataset.FeedbackUsers.Get(i), dataset.FeedbackItems.Get(i) } -// LoadDataFromCSV loads Data from a CSV file. The CSV file should be: -// [optional header] -// -// -// -// ... -// For example, the `u.Data` from MovieLens 100K is: -// 196\t242\t3\t881250949 -// 186\t302\t3\t891717742 -// 22\t377\t1\t878887116 -func LoadDataFromCSV(fileName, sep string, hasHeader bool) *DataSet { - dataset := NewMapIndexDataset() - // Open file - file, err := os.Open(fileName) - if err != nil { - log.Logger().Fatal("failed to open csv file", zap.Error(err), - zap.String("csv_file", fileName)) - } - defer func(file *os.File) { - err = file.Close() - if err != nil { - log.Logger().Error("failed to close file", zap.Error(err)) - } - }(file) - // Read CSV file - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - // Ignore header - if hasHeader { - hasHeader = false - continue - } - fields := strings.Split(line, sep) - // Ignore empty line - if len(fields) < 2 { - continue - } - dataset.AddFeedback(fields[0], fields[1], true) - } - return dataset -} - func loadTest(dataset *DataSet, path string) error { // Open file, err := os.Open(path) diff --git a/model/ranking/data_test.go b/model/ranking/data_test.go index 9ea8279e2..bd9220ff2 100644 --- a/model/ranking/data_test.go +++ b/model/ranking/data_test.go @@ -36,16 +36,6 @@ func TestNewMapIndexDataset(t *testing.T) { assert.Equal(t, 6, dataSet.ItemCount()) } -func TestLoadDataFromCSV(t *testing.T) { - dataset := LoadDataFromCSV("../../misc/csv_test/feedback.csv", ",", true) - assert.Equal(t, 5, dataset.Count()) - for i := 0; i < dataset.Count(); i++ { - userIndex, itemIndex := dataset.GetIndex(i) - assert.Equal(t, int32(i), userIndex) - assert.Equal(t, int32(i), itemIndex) - } -} - func TestDataSet_Split(t *testing.T) { numUsers, numItems := 3, 5 // create dataset diff --git a/model/ranking/model.go b/model/ranking/model.go index 7e9207399..6d2344ec7 100644 --- a/model/ranking/model.go +++ b/model/ranking/model.go @@ -16,6 +16,10 @@ package ranking import ( "fmt" + "io" + "reflect" + "time" + "github.com/bits-and-blooms/bitset" "github.com/chewxy/math32" "github.com/juju/errors" @@ -30,9 +34,6 @@ import ( "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" - "io" - "reflect" - "time" ) type Score struct { @@ -42,7 +43,7 @@ type Score struct { } type FitConfig struct { - Jobs int + *task.JobsAllocator Verbose int Candidates int TopK int @@ -51,7 +52,6 @@ type FitConfig struct { func NewFitConfig() *FitConfig { return &FitConfig{ - Jobs: 1, Verbose: 10, Candidates: 100, TopK: 10, @@ -63,8 +63,8 @@ func (config *FitConfig) SetVerbose(verbose int) *FitConfig { return config } -func (config *FitConfig) SetJobs(nJobs int) *FitConfig { - config.Jobs = nJobs +func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitConfig { + config.JobsAllocator = allocator return config } @@ -302,11 +302,12 @@ func UnmarshalModel(r io.Reader) (MatrixFactorization, error) { // model with implicit feedback. 
The pairwise ranking between item i and j for user u is estimated // by: // -// p(i >_u j) = \sigma( p_u^T (q_i - q_j) ) +// p(i >_u j) = \sigma( p_u^T (q_i - q_j) ) // // Hyper-parameters: +// // Reg - The regularization parameter of the cost function that is -// optimized. Default is 0.01. +// optimized. Default is 0.01. // Lr - The learning rate of SGD. Default is 0.05. // nFactors - The number of latent factors. Default is 10. // NEpochs - The number of iteration of the SGD procedure. Default is 100. @@ -401,12 +402,13 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Any("config", config)) bpr.Init(trainSet) // Create buffers - temp := base.NewMatrix32(config.Jobs, bpr.nFactors) - userFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - positiveItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - negativeItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - rng := make([]base.RandomGenerator, config.Jobs) - for i := 0; i < config.Jobs; i++ { + maxJobs := config.MaxJobs() + temp := base.NewMatrix32(maxJobs, bpr.nFactors) + userFactor := base.NewMatrix32(maxJobs, bpr.nFactors) + positiveItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors) + negativeItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors) + rng := make([]base.RandomGenerator, maxJobs) + for i := 0; i < maxJobs; i++ { rng[i] = base.NewRandomGenerator(bpr.GetRandomGenerator().Int63()) } // Convert array to hashmap @@ -419,7 +421,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } snapshots := SnapshotManger{} evalStart := time.Now() - scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) + scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", 0, bpr.nEpochs), zap.String("eval_time", evalTime.String()), @@ -431,8 +433,9 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { for epoch := 1; epoch <= bpr.nEpochs; epoch++ { fitStart := time.Now() // Training epoch - cost := make([]float32, config.Jobs) - _ = parallel.Parallel(trainSet.Count(), config.Jobs, func(workerId, _ int) error { + numJobs := config.AvailableJobs(config.Task) + cost := make([]float32, numJobs) + _ = parallel.Parallel(trainSet.Count(), numJobs, func(workerId, _ int) error { // Select a user var userIndex int32 var ratingCount int @@ -479,7 +482,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // Cross validation if epoch%config.Verbose == 0 || epoch == bpr.nEpochs { evalStart = time.Now() - scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) + scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) evalTime = time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", epoch, bpr.nEpochs), zap.String("fit_time", fitTime.String()), @@ -721,12 +724,13 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Any("config", config)) ccd.Init(trainSet) // Create temporary matrix + maxJobs := config.MaxJobs() s := base.NewMatrix32(ccd.nFactors, ccd.nFactors) - userPredictions := make([][]float32, config.Jobs) - itemPredictions := make([][]float32, config.Jobs) - userRes := make([][]float32, config.Jobs) - itemRes := make([][]float32, config.Jobs) - for i := 
0; i < config.Jobs; i++ { + userPredictions := make([][]float32, maxJobs) + itemPredictions := make([][]float32, maxJobs) + userRes := make([][]float32, maxJobs) + itemRes := make([][]float32, maxJobs) + for i := 0; i < maxJobs; i++ { userPredictions[i] = make([]float32, trainSet.ItemCount()) itemPredictions[i] = make([]float32, trainSet.UserCount()) userRes[i] = make([]float32, trainSet.ItemCount()) @@ -735,7 +739,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // evaluate initial model snapshots := SnapshotManger{} evalStart := time.Now() - scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) + scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", 0, ccd.nEpochs), zap.String("eval_time", evalTime.String()), @@ -757,7 +761,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } } } - _ = parallel.Parallel(trainSet.UserCount(), config.Jobs, func(workerId, userIndex int) error { + _ = parallel.Parallel(trainSet.UserCount(), config.AvailableJobs(config.Task), func(workerId, userIndex int) error { userFeedback := trainSet.UserFeedback[userIndex] for _, i := range userFeedback { userPredictions[workerId][i] = ccd.InternalPredict(int32(userIndex), i) @@ -798,7 +802,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } } } - _ = parallel.Parallel(trainSet.ItemCount(), config.Jobs, func(workerId, itemIndex int) error { + _ = parallel.Parallel(trainSet.ItemCount(), config.AvailableJobs(config.Task), func(workerId, itemIndex int) error { itemFeedback := trainSet.ItemFeedback[itemIndex] for _, u := range itemFeedback { itemPredictions[workerId][u] = ccd.InternalPredict(u, int32(itemIndex)) @@ -831,7 +835,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // Cross validation if ep%config.Verbose == 0 || ep == ccd.nEpochs { evalStart = time.Now() - scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall) + scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) evalTime = time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", ep, ccd.nEpochs), zap.String("fit_time", fitTime.String()), diff --git a/model/ranking/model_test.go b/model/ranking/model_test.go index 1fa686c8f..0508429d8 100644 --- a/model/ranking/model_test.go +++ b/model/ranking/model_test.go @@ -31,7 +31,7 @@ const ( func newFitConfig(numEpoch int) *FitConfig { t := task.NewTask("test", numEpoch) - cfg := NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU()).SetTask(t) + cfg := NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())).SetTask(t) return cfg } diff --git a/model/ranking/search.go b/model/ranking/search.go index fdc2cc903..e1f65c2c7 100644 --- a/model/ranking/search.go +++ b/model/ranking/search.go @@ -16,13 +16,14 @@ package ranking import ( "fmt" + "sync" + "time" + "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" - "sync" - "time" ) // ParamsSearchResult contains the return of grid search. 
@@ -47,7 +48,7 @@ func (r *ParamsSearchResult) AddScore(params model.Params, score Score) { // GridSearchCV finds the best parameters for a model. func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, - _ int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult { + _ int64, fitConfig *FitConfig) ParamsSearchResult { // Retrieve parameter names and length paramNames := make([]model.ParamName, 0, len(paramGrid)) count := 1 @@ -70,11 +71,7 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat // Cross validate estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - fitConfig.Task.Suspend(true) - runner.Lock() - fitConfig.Task.Suspend(false) score := estimator.Fit(trainSet, testSet, fitConfig) - runner.UnLock() // Create GridSearch result results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) @@ -100,10 +97,10 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat // RandomSearchCV searches hyper-parameters by random. func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, - numTrials int, seed int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult { + numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult { // if the number of combination is less than number of trials, use grid search if paramGrid.NumCombinations() < numTrials { - return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig, runner) + return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig) } rng := base.NewRandomGenerator(seed) results := ParamsSearchResult{ @@ -122,11 +119,7 @@ func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *D zap.Any("params", params)) estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - fitConfig.Task.Suspend(true) - runner.Lock() - fitConfig.Task.Suspend(false) score := estimator.Fit(trainSet, testSet, fitConfig) - runner.UnLock() results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) if len(results.Scores) == 0 || score.NDCG > results.BestScore.NDCG { @@ -145,7 +138,6 @@ type ModelSearcher struct { // arguments numEpochs int numTrials int - numJobs int searchSize bool // results bestMutex sync.Mutex @@ -155,11 +147,10 @@ type ModelSearcher struct { } // NewModelSearcher creates a thread-safe personal ranking model searcher. 
-func NewModelSearcher(nEpoch, nTrials, nJobs int, searchSize bool) *ModelSearcher { +func NewModelSearcher(nEpoch, nTrials int, searchSize bool) *ModelSearcher { searcher := &ModelSearcher{ numTrials: nTrials, numEpochs: nEpoch, - numJobs: nJobs, searchSize: searchSize, } searcher.models = append(searcher.models, NewBPR(model.Params{model.NEpochs: searcher.numEpochs})) @@ -178,7 +169,7 @@ func (searcher *ModelSearcher) Complexity() int { return len(searcher.models) * searcher.numEpochs * searcher.numTrials } -func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, runner model.Runner) error { +func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, j *task.JobsAllocator) error { log.Logger().Info("ranking model search", zap.Int("n_users", trainSet.UserCount()), zap.Int("n_items", trainSet.ItemCount())) @@ -186,8 +177,8 @@ func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, runn for _, m := range searcher.models { r := RandomSearchCV(m, trainSet, valSet, m.GetParamsGrid(searcher.searchSize), searcher.numTrials, 0, NewFitConfig(). - SetJobs(searcher.numJobs). - SetTask(t), runner) + SetJobsAllocator(j). + SetTask(t)) searcher.bestMutex.Lock() if searcher.bestModel == nil || r.BestScore.NDCG > searcher.bestScore.NDCG { searcher.bestModel = r.BestModel diff --git a/model/ranking/search_test.go b/model/ranking/search_test.go index f91b500ce..71141b4e8 100644 --- a/model/ranking/search_test.go +++ b/model/ranking/search_test.go @@ -15,7 +15,6 @@ package ranking import ( "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" @@ -104,22 +103,9 @@ func (m *mockMatrixFactorizationForSearch) GetParamsGrid(_ bool) model.ParamsGri } } -type mockRunner struct { - mock.Mock -} - -func (r *mockRunner) Lock() { - r.Called() -} - -func (r *mockRunner) UnLock() { - r.Called() -} - func newFitConfigForSearch() *FitConfig { t := task.NewTask("test", 100) return &FitConfig{ - Jobs: 1, Verbose: 1, Task: t, } @@ -128,12 +114,7 @@ func newFitConfigForSearch() *FitConfig { func TestGridSearchCV(t *testing.T) { m := &mockMatrixFactorizationForSearch{} fitConfig := newFitConfigForSearch() - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig, runner) - runner.AssertCalled(t, "Lock") - runner.AssertCalled(t, "UnLock") + r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.NDCG) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -145,12 +126,7 @@ func TestGridSearchCV(t *testing.T) { func TestRandomSearchCV(t *testing.T) { m := &mockMatrixFactorizationForSearch{} fitConfig := newFitConfigForSearch() - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig, runner) - runner.AssertCalled(t, "Lock") - runner.AssertCalled(t, "UnLock") + r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.NDCG) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -160,13 +136,10 @@ func TestRandomSearchCV(t *testing.T) { } func TestModelSearcher(t *testing.T) { - runner := new(mockRunner) - runner.On("Lock") - runner.On("UnLock") - searcher := NewModelSearcher(2, 63, 1, false) + searcher := NewModelSearcher(2, 63, false) searcher.models = 
[]MatrixFactorization{newMockMatrixFactorizationForSearch(2)} tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner) + err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) assert.NoError(t, err) _, m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.NDCG) diff --git a/misc/database_test/docker-compose.yml b/storage/docker-compose.yml similarity index 100% rename from misc/database_test/docker-compose.yml rename to storage/docker-compose.yml
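
The hunks above replace the fixed `Jobs int` field on both `FitConfig` types with an embedded `*task.JobsAllocator`, drop the `model.Runner` lock around `Fit`, and remove the `nJobs` argument from `NewModelSearcher`. The sketch below is a minimal illustration of how a caller might wire the new API together; it only uses constructors and methods that appear in the diff (`task.NewConstantJobsAllocator`, `SetJobsAllocator`, `SetTask`, `MaxJobs`, `AvailableJobs`), while the package name, function name, and variable names are illustrative assumptions, not part of the patch.

// Illustrative sketch only: fitting a ranking model through the new
// JobsAllocator-based FitConfig. Names prefixed with "example" are assumed.
package example

import (
	"runtime"

	"github.com/zhenghaoz/gorse/base/task"
	"github.com/zhenghaoz/gorse/model/ranking"
)

// exampleFit trains a matrix-factorization model with a constant allocator,
// which reproduces the behaviour of the removed SetJobs(nJobs) call.
func exampleFit(m ranking.MatrixFactorization, trainSet, valSet *ranking.DataSet) ranking.Score {
	t := task.NewTask("fit", 100)
	alloc := task.NewConstantJobsAllocator(runtime.NumCPU())
	cfg := ranking.NewFitConfig().
		SetVerbose(10).
		SetJobsAllocator(alloc).
		SetTask(t)
	// Inside Fit, MaxJobs() sizes the per-worker buffers once, while
	// AvailableJobs(config.Task) is consulted each epoch to pick the actual
	// degree of parallelism, as in the BPR and CCD hunks above.
	return m.Fit(trainSet, valSet, cfg)
}

For hyper-parameter search the allocator is passed straight into the searcher, mirroring the updated tests: `NewModelSearcher(2, 63, false)` followed by `searcher.Fit(trainSet, valSet, tk, task.NewConstantJobsAllocator(1))`.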