diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..9965205d0
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,7 @@
+.github
+assets
+
+LICENSE
+
+*.yml
+*.md
diff --git a/.github/workflows/build_release.yml b/.github/workflows/build_release.yml
index 2a8948ab0..c1576d25d 100644
--- a/.github/workflows/build_release.yml
+++ b/.github/workflows/build_release.yml
@@ -76,7 +76,7 @@ jobs:
with:
context: .
platforms: linux/amd64,linux/arm64
- file: docker/${{ matrix.image }}/Dockerfile
+ file: cmd/${{ matrix.image }}/Dockerfile
push: true
tags: |
zhenghaoz/${{ matrix.image }}:latest
diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml
index 68842efc0..bf4a5eafd 100644
--- a/.github/workflows/build_test.yml
+++ b/.github/workflows/build_test.yml
@@ -134,7 +134,6 @@ jobs:
- uses: actions/checkout@v1
- name: Build the stack
- working-directory: ./docker
run: docker-compose up -d
- name: Check the deployed service URL
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4eedb1c3d..d149e2cfb 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -24,7 +24,7 @@ These following installations are required:
- **Docker Compose**: Multiple databases are required for unit tests. It's convenient to manage databases on Docker Compose.
```bash
-cd misc/database_test
+cd storage
docker-compose up -d
```
@@ -73,7 +73,7 @@ Most logics in Gorse are covered by unit tests. Run unit tests by the following
go test -v ./...
```
-The default database URLs are directed to these databases in `misc/database_test/docker-compose.yml`. Test databases could be overrode by setting following environment variables:
+The default database URLs are directed to these databases in `storage/docker-compose.yml`. Test databases can be overridden by setting the following environment variables:
| Environment Value | Default Value |
|-------------------|----------------------------------------------|
diff --git a/README.md b/README.md
index e31671630..718628ed2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
+![](https://img.shields.io/github/go-mod/go-version/zhenghaoz/gorse)
[![build](https://github.com/zhenghaoz/gorse/workflows/build/badge.svg)](https://github.com/zhenghaoz/gorse/actions?query=workflow%3Abuild)
[![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse)
[![Go Report Card](https://goreportcard.com/badge/github.com/zhenghaoz/gorse)](https://goreportcard.com/report/github.com/zhenghaoz/gorse)
diff --git a/base/copier/copier.go b/base/copier/copier.go
index 03f2a3f66..621af599e 100644
--- a/base/copier/copier.go
+++ b/base/copier/copier.go
@@ -86,10 +86,10 @@ func copyValue(dst, src reflect.Value) error {
dstPointer := reflect.New(dst.Type())
srcPointer := reflect.New(src.Type())
srcPointer.Elem().Set(src)
- srcMarshaller, hasSrcMarshaler := srcPointer.Interface().(encoding.BinaryMarshaler)
+ srcMarshaller, hasSrcMarshaller := srcPointer.Interface().(encoding.BinaryMarshaler)
dstUnmarshaler, hasDstUnmarshaler := dstPointer.Interface().(encoding.BinaryUnmarshaler)
- if hasDstUnmarshaler && hasSrcMarshaler {
+ if hasDstUnmarshaler && hasSrcMarshaller {
dstByte, err := srcMarshaller.MarshalBinary()
if err != nil {
return err
@@ -114,6 +114,11 @@ func copyValue(dst, src reflect.Value) error {
}
}
case reflect.Ptr:
+ if src.IsNil() {
+ // If source is nil, set dst to nil.
+ dst.Set(reflect.Zero(dst.Type()))
+ return nil
+ }
if dst.IsNil() {
dst.Set(reflect.New(src.Elem().Type()))
}
@@ -124,6 +129,11 @@ func copyValue(dst, src reflect.Value) error {
return err
}
case reflect.Interface:
+ if src.IsNil() {
+ // If source is nil, set dst to nil.
+ dst.Set(reflect.Zero(dst.Type()))
+ return nil
+ }
if !dst.IsNil() {
switch dst.Elem().Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint,
diff --git a/base/copier/copier_test.go b/base/copier/copier_test.go
index 06ddf34ba..fca089e67 100644
--- a/base/copier/copier_test.go
+++ b/base/copier/copier_test.go
@@ -170,3 +170,22 @@ func TestPrivate(t *testing.T) {
*a.text = "world"
assert.Equal(t, "hello", *b.text)
}
+
+type NilInterface interface{}
+
+type NilStruct struct {
+ Interface NilInterface
+ Pointer *Foo
+}
+
+func TestNil(t *testing.T) {
+ var d = NilStruct{
+ Interface: 100,
+ Pointer: &Foo{A: 100},
+ }
+ var e NilStruct
+ err := Copy(&d, e)
+ assert.NoError(t, err)
+ assert.Nil(t, d.Interface)
+ assert.Nil(t, d.Pointer)
+}
diff --git a/base/parallel/parallel.go b/base/parallel/parallel.go
index 77aee96fa..600cae641 100644
--- a/base/parallel/parallel.go
+++ b/base/parallel/parallel.go
@@ -15,9 +15,17 @@
package parallel
import (
+ "sync"
+
"github.com/juju/errors"
+ "github.com/zhenghaoz/gorse/base/task"
+ "go.uber.org/atomic"
"modernc.org/mathutil"
- "sync"
+)
+
+const (
+ chanSize = 1024
+ allocPeriod = 128
)
/* Parallel Schedulers */
@@ -33,19 +41,13 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
}
}
} else {
- const chanSize = 64
- const chanEOF = -1
c := make(chan int, chanSize)
// producer
go func() {
- // send jobs
for i := 0; i < nJobs; i++ {
c <- i
}
- // send EOF
- for i := 0; i < nWorkers; i++ {
- c <- chanEOF
- }
+ close(c)
}()
// consumer
var wg sync.WaitGroup
@@ -57,8 +59,8 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
defer wg.Done()
for {
// read job
- jobId := <-c
- if jobId == chanEOF {
+ jobId, ok := <-c
+ if !ok {
return
}
// run job
@@ -80,6 +82,55 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
return nil
}
+func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(workerId, jobId int) error) error {
+ c := make(chan int, chanSize)
+ // producer
+ go func() {
+ for i := 0; i < nJobs; i++ {
+ c <- i
+ }
+ close(c)
+ }()
+ // consumer
+ for {
+ exit := atomic.NewBool(true)
+ numJobs := jobsAlloc.AvailableJobs(nil)
+ var wg sync.WaitGroup
+ wg.Add(numJobs)
+ errs := make([]error, nJobs)
+ for j := 0; j < numJobs; j++ {
+ // start workers
+ go func(workerId int) {
+ defer wg.Done()
+ for i := 0; i < allocPeriod; i++ {
+ // read job
+ jobId, ok := <-c
+ if !ok {
+ return
+ }
+ exit.Store(false)
+ // run job
+ if err := worker(workerId, jobId); err != nil {
+ errs[jobId] = err
+ return
+ }
+ }
+ }(j)
+ }
+ wg.Wait()
+ // check errors
+ for _, err := range errs {
+ if err != nil {
+ return errors.Trace(err)
+ }
+ }
+ // exit if finished
+ if exit.Load() {
+ return nil
+ }
+ }
+}
+
type batchJob struct {
beginId int
endId int
@@ -90,19 +141,13 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo
if nWorkers == 1 {
return worker(0, 0, nJobs)
}
- const chanSize = 64
- const chanEOF = -1
c := make(chan batchJob, chanSize)
// producer
go func() {
- // send jobs
for i := 0; i < nJobs; i += batchSize {
c <- batchJob{beginId: i, endId: mathutil.Min(i+batchSize, nJobs)}
}
- // send EOF
- for i := 0; i < nWorkers; i++ {
- c <- batchJob{beginId: chanEOF, endId: chanEOF}
- }
+ close(c)
}()
// consumer
var wg sync.WaitGroup
@@ -114,8 +159,8 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo
defer wg.Done()
for {
// read job
- job := <-c
- if job.beginId == chanEOF {
+ job, ok := <-c
+ if !ok {
return
}
// run job
diff --git a/base/parallel/parallel_test.go b/base/parallel/parallel_test.go
index ffd3ab0f8..aa21b4124 100644
--- a/base/parallel/parallel_test.go
+++ b/base/parallel/parallel_test.go
@@ -18,6 +18,7 @@ import (
"github.com/scylladb/go-set"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base"
+ "github.com/zhenghaoz/gorse/base/task"
"testing"
"time"
)
@@ -48,6 +49,22 @@ func TestParallel(t *testing.T) {
assert.Equal(t, 1, workersSet.Size())
}
+func TestDynamicParallel(t *testing.T) {
+ a := base.RangeInt(10000)
+ b := make([]int, len(a))
+ workerIds := make([]int, len(a))
+ _ = DynamicParallel(len(a), task.NewConstantJobsAllocator(3), func(workerId, jobId int) error {
+ b[jobId] = a[jobId]
+ workerIds[jobId] = workerId
+ time.Sleep(time.Microsecond)
+ return nil
+ })
+ workersSet := set.NewIntSet(workerIds...)
+ assert.Equal(t, a, b)
+ assert.GreaterOrEqual(t, 4, workersSet.Size())
+ assert.Less(t, 1, workersSet.Size())
+}
+
func TestBatchParallel(t *testing.T) {
a := base.RangeInt(10000)
b := make([]int, len(a))
diff --git a/base/random.go b/base/random.go
index eb02aa5a1..10fd00c10 100644
--- a/base/random.go
+++ b/base/random.go
@@ -15,9 +15,10 @@
package base
import (
+ "math/rand"
+
"github.com/scylladb/go-set/i32set"
"github.com/scylladb/go-set/iset"
- "math/rand"
)
// RandomGenerator is the random generator for gorse.
@@ -76,15 +77,6 @@ func (rng RandomGenerator) NormalVector64(size int, mean, stdDev float64) []floa
return ret
}
-// NormalMatrix64 makes a matrix filled with normal random floats.
-func (rng RandomGenerator) NormalMatrix64(row, col int, mean, stdDev float64) [][]float64 {
- ret := make([][]float64, row)
- for i := range ret {
- ret[i] = rng.NormalVector64(col, mean, stdDev)
- }
- return ret
-}
-
// Sample n values between low and high, but not in exclude.
func (rng RandomGenerator) Sample(low, high, n int, exclude ...*iset.Set) []int {
intervalLength := high - low
diff --git a/base/random_test.go b/base/random_test.go
index 1b110f023..caa2a1e79 100644
--- a/base/random_test.go
+++ b/base/random_test.go
@@ -14,13 +14,12 @@
package base
import (
+ "testing"
+
"github.com/chewxy/math32"
"github.com/scylladb/go-set"
"github.com/stretchr/testify/assert"
"github.com/thoas/go-funk"
- "gonum.org/v1/gonum/stat"
- "math"
- "testing"
)
const randomEpsilon = 0.1
@@ -39,13 +38,6 @@ func TestRandomGenerator_MakeUniformMatrix(t *testing.T) {
assert.False(t, funk.MaxFloat32(vec) > 2)
}
-func TestRandomGenerator_MakeNormalMatrix64(t *testing.T) {
- rng := NewRandomGenerator(0)
- vec := rng.NormalMatrix64(1, 1000, 1, 2)[0]
- assert.False(t, math.Abs(stat.Mean(vec, nil)-1) > randomEpsilon)
- assert.False(t, math.Abs(stat.StdDev(vec, nil)-2) > randomEpsilon)
-}
-
func TestRandomGenerator_Sample(t *testing.T) {
excludeSet := set.NewIntSet(0, 1, 2, 3, 4)
rng := NewRandomGenerator(0)
@@ -80,8 +72,10 @@ func stdDev(x []float32) float32 {
}
// meanVariance computes the sample mean and unbiased variance, where the mean and variance are
-// \sum_i w_i * x_i / (sum_i w_i)
-// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
+//
+// \sum_i w_i * x_i / (sum_i w_i)
+// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
+//
// respectively.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
diff --git a/base/search/index_test.go b/base/search/index_test.go
index 65b03183d..9a070c290 100644
--- a/base/search/index_test.go
+++ b/base/search/index_test.go
@@ -36,7 +36,7 @@ func TestHNSW_InnerProduct(t *testing.T) {
model.InitMean: 0,
model.InitStdDev: 0.001,
})
- fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU())
+ fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU()))
m.Fit(trainSet, testSet, fitConfig)
var vectors []Vector
for i, itemFactor := range m.ItemFactor {
diff --git a/base/search/ivf.go b/base/search/ivf.go
index a8435ab10..12bb20355 100644
--- a/base/search/ivf.go
+++ b/base/search/ivf.go
@@ -15,6 +15,11 @@
package search
import (
+ "math"
+ "math/rand"
+ "sync"
+ "time"
+
"github.com/chewxy/math32"
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/heap"
@@ -23,12 +28,7 @@ import (
"github.com/zhenghaoz/gorse/base/task"
"go.uber.org/atomic"
"go.uber.org/zap"
- "math"
- "math/rand"
"modernc.org/mathutil"
- "runtime"
- "sync"
- "time"
)
const (
@@ -46,7 +46,7 @@ type IVF struct {
errorRate float32
maxIter int
numProbe int
- numJobs int
+ jobsAlloc *task.JobsAllocator
task *task.SubTask
}
@@ -65,9 +65,9 @@ func SetClusterErrorRate(errorRate float32) IVFConfig {
}
}
-func SetIVFNumJobs(numJobs int) IVFConfig {
+func SetIVFJobsAllocator(jobsAlloc *task.JobsAllocator) IVFConfig {
return func(ivf *IVF) {
- ivf.numJobs = numJobs
+ ivf.jobsAlloc = jobsAlloc
}
}
@@ -90,7 +90,6 @@ func NewIVF(vectors []Vector, configs ...IVFConfig) *IVF {
errorRate: 0.05,
maxIter: DefaultMaxIter,
numProbe: 1,
- numJobs: runtime.NumCPU(),
}
for _, config := range configs {
config(idx)
@@ -207,7 +206,7 @@ func (idx *IVF) Build() {
// reassign clusters
nextClusters := make([]ivfCluster, idx.k)
- _ = parallel.Parallel(len(idx.data), idx.numJobs, func(_, i int) error {
+ _ = parallel.Parallel(len(idx.data), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error {
if !idx.data[i].IsHidden() {
nextCluster, nextDistance := -1, float32(math32.MaxFloat32)
for c := range clusters {
@@ -270,7 +269,7 @@ func (b *IVFBuilder) evaluate(idx *IVF, prune0 bool) float32 {
samples := b.rng.Sample(0, len(b.data), testSize)
var result, count float32
var mu sync.Mutex
- _ = parallel.Parallel(len(samples), idx.numJobs, func(_, i int) error {
+ _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error {
sample := samples[i]
expected, _ := b.bruteForce.Search(b.data[sample], b.k, prune0)
if len(expected) > 0 {
@@ -321,7 +320,7 @@ func (b *IVFBuilder) evaluateTermSearch(idx *IVF, prune0 bool, term string) floa
samples := b.rng.Sample(0, len(b.data), testSize)
var result, count float32
var mu sync.Mutex
- _ = parallel.Parallel(len(samples), idx.numJobs, func(_, i int) error {
+ _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error {
sample := samples[i]
expected, _ := b.bruteForce.MultiSearch(b.data[sample], []string{term}, b.k, prune0)
if len(expected) > 0 {
diff --git a/base/search/ivf_test.go b/base/search/ivf_test.go
index 8f9ba1ffb..da875d4e8 100644
--- a/base/search/ivf_test.go
+++ b/base/search/ivf_test.go
@@ -15,8 +15,10 @@
package search
import (
- "github.com/stretchr/testify/assert"
"testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/zhenghaoz/gorse/base/task"
)
func TestIVFConfig(t *testing.T) {
@@ -28,8 +30,8 @@ func TestIVFConfig(t *testing.T) {
SetClusterErrorRate(0.123)(ivf)
assert.Equal(t, float32(0.123), ivf.errorRate)
- SetIVFNumJobs(234)(ivf)
- assert.Equal(t, 234, ivf.numJobs)
+ SetIVFJobsAllocator(task.NewConstantJobsAllocator(234))(ivf)
+ assert.Equal(t, 234, ivf.jobsAlloc.AvailableJobs(nil))
SetMaxIteration(345)(ivf)
assert.Equal(t, 345, ivf.maxIter)
diff --git a/base/task/schedule.go b/base/task/schedule.go
index b5308e09c..bc0860258 100644
--- a/base/task/schedule.go
+++ b/base/task/schedule.go
@@ -15,70 +15,181 @@
package task
import (
- "github.com/scylladb/go-set/strset"
+ "runtime"
+ "sort"
"sync"
+
+ "github.com/samber/lo"
+ "github.com/zhenghaoz/gorse/base/log"
+ "go.uber.org/zap"
+ "modernc.org/mathutil"
)
-// Scheduler schedules that pre-locked tasks are executed first.
-type Scheduler struct {
- *sync.Cond
- Privileged *strset.Set
- Running bool
+type JobsAllocator struct {
+ numJobs int // the max number of jobs
+ taskName string // its task name in scheduler
+ scheduler *JobsScheduler
}
-// NewTaskScheduler creates a Scheduler.
-func NewTaskScheduler() *Scheduler {
- return &Scheduler{
- Cond: sync.NewCond(&sync.Mutex{}),
- Privileged: strset.New(),
+func NewConstantJobsAllocator(num int) *JobsAllocator {
+ return &JobsAllocator{
+ numJobs: num,
}
}
-// PreLock a task, the task has the privilege to run first than un-pre-clocked tasks.
-func (t *Scheduler) PreLock(name string) {
- t.L.Lock()
- defer t.L.Unlock()
- t.Privileged.Add(name)
+func (allocator *JobsAllocator) MaxJobs() int {
+ if allocator == nil || allocator.numJobs < 1 {
+ // Return 1 for invalid allocator
+ return 1
+ }
+ return allocator.numJobs
}
-// Lock gets the permission to run task.
-func (t *Scheduler) Lock(name string) {
- t.L.Lock()
- defer t.L.Unlock()
- for t.Running || (!t.Privileged.IsEmpty() && !t.Privileged.Has(name)) {
- t.Wait()
+func (allocator *JobsAllocator) AvailableJobs(tracker *Task) int {
+ if allocator == nil || allocator.numJobs < 1 {
+ // Return 1 for invalid allocator
+ return 1
+ } else if allocator.scheduler != nil {
+ // Use jobs scheduler
+ return allocator.scheduler.allocateJobsForTask(allocator.taskName, true, tracker)
}
- t.Running = true
+ return allocator.numJobs
}
-// UnLock returns the permission to run task.
-func (t *Scheduler) UnLock(name string) {
- t.L.Lock()
- defer t.L.Unlock()
- t.Running = false
- t.Privileged.Remove(name)
- t.Broadcast()
+// Init requests the first allocation of jobs for this allocator from its scheduler, if one is attached.
+func (allocator *JobsAllocator) Init() {
+ if allocator.scheduler != nil {
+ allocator.scheduler.allocateJobsForTask(allocator.taskName, true, nil)
+ }
}
-func (t *Scheduler) NewRunner(name string) *Runner {
- return &Runner{
- Scheduler: t,
- Name: name,
+// taskInfo represents a task in JobsScheduler.
+type taskInfo struct {
+ name string // name of the task
+ priority int // high priority tasks are allocated first
+ privileged bool // privileged tasks are allocated first
+ jobs int // number of jobs allocated to the task
+ previous int // previous number of jobs allocated to the task
+}
+
+// JobsScheduler allocates jobs to multiple tasks.
+type JobsScheduler struct {
+ *sync.Cond
+ numJobs int // number of jobs
+ freeJobs int // number of free jobs
+ tasks map[string]*taskInfo
+}
+
+// NewJobsScheduler creates a JobsScheduler with num jobs.
+func NewJobsScheduler(num int) *JobsScheduler {
+ if num <= 0 {
+ // Use all cores if num is less than 1.
+ num = runtime.NumCPU()
+ }
+ return &JobsScheduler{
+ Cond: sync.NewCond(&sync.Mutex{}),
+ numJobs: num,
+ freeJobs: num,
+ tasks: make(map[string]*taskInfo),
}
}
-// Runner is a Scheduler bounded with a task.
-type Runner struct {
- *Scheduler
- Name string
+// Register a task in the JobsScheduler. Registering an already registered task is a no-op and returns false.
+func (s *JobsScheduler) Register(taskName string, priority int, privileged bool) bool {
+ s.L.Lock()
+ defer s.L.Unlock()
+ if _, exits := s.tasks[taskName]; !exits {
+ s.tasks[taskName] = &taskInfo{name: taskName, priority: priority, privileged: privileged}
+ return true
+ } else {
+ return false
+ }
}
-// Lock gets the permission to run task.
-func (locker *Runner) Lock() {
- locker.Scheduler.Lock(locker.Name)
+// Unregister a task from the JobsScheduler.
+func (s *JobsScheduler) Unregister(taskName string) {
+ s.L.Lock()
+ defer s.L.Unlock()
+ if task, exits := s.tasks[taskName]; exits {
+ // Return allocated jobs.
+ s.freeJobs += task.jobs
+ delete(s.tasks, taskName)
+ s.Broadcast()
+ }
}
-// UnLock returns the permission to run task.
-func (locker *Runner) UnLock() {
- locker.Scheduler.UnLock(locker.Name)
+func (s *JobsScheduler) GetJobsAllocator(taskName string) *JobsAllocator {
+ return &JobsAllocator{
+ numJobs: s.numJobs,
+ taskName: taskName,
+ scheduler: s,
+ }
+}
+
+func (s *JobsScheduler) allocateJobsForTask(taskName string, block bool, tracker *Task) int {
+ // Find current task and return the jobs temporarily.
+ s.L.Lock()
+ currentTask, exist := s.tasks[taskName]
+ if !exist {
+ panic("task not found")
+ }
+ s.freeJobs += currentTask.jobs
+ currentTask.jobs = 0
+ s.L.Unlock()
+
+ s.L.Lock()
+ defer s.L.Unlock()
+ for {
+ s.allocateJobsForAll()
+ if currentTask.jobs == 0 && block {
+ tracker.Suspend(true)
+ s.Wait()
+ } else {
+ tracker.Suspend(false)
+ return currentTask.jobs
+ }
+ }
+}
+
+func (s *JobsScheduler) allocateJobsForAll() {
+ // Separate privileged tasks and normal tasks
+ privileged := make([]*taskInfo, 0)
+ normal := make([]*taskInfo, 0)
+ for _, task := range s.tasks {
+ if task.privileged {
+ privileged = append(privileged, task)
+ } else {
+ normal = append(normal, task)
+ }
+ }
+
+ var tasks []*taskInfo
+ if len(privileged) > 0 {
+ tasks = privileged
+ } else {
+ tasks = normal
+ }
+
+ // allocate jobs
+ sort.Slice(tasks, func(i, j int) bool {
+ return tasks[i].priority > tasks[j].priority
+ })
+ for i, task := range tasks {
+ if s.freeJobs == 0 {
+ return
+ }
+ targetJobs := s.numJobs/len(tasks) + lo.If(i < s.numJobs%len(tasks), 1).Else(0)
+ targetJobs = mathutil.Min(targetJobs, s.freeJobs)
+ if task.jobs < targetJobs {
+ if task.previous != targetJobs {
+ log.Logger().Debug("reallocate jobs for task",
+ zap.String("task", task.name),
+ zap.Int("previous_jobs", task.previous),
+ zap.Int("target_jobs", targetJobs))
+ }
+ s.freeJobs -= targetJobs - task.jobs
+ task.jobs = targetJobs
+ task.previous = task.jobs
+ }
+ }
}
diff --git a/base/task/schedule_test.go b/base/task/schedule_test.go
index 614c5e70f..c3541bcfb 100644
--- a/base/task/schedule_test.go
+++ b/base/task/schedule_test.go
@@ -1,51 +1,109 @@
+// Copyright 2022 gorse Project Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package task
import (
- "fmt"
- "github.com/stretchr/testify/assert"
"sync"
"testing"
+
+ "github.com/stretchr/testify/assert"
)
-func TestTaskScheduler(t *testing.T) {
- taskScheduler := NewTaskScheduler()
+func TestConstantJobsAllocator(t *testing.T) {
+ allocator := NewConstantJobsAllocator(314)
+ assert.Equal(t, 314, allocator.MaxJobs())
+ assert.Equal(t, 314, allocator.AvailableJobs(nil))
+
+ allocator = NewConstantJobsAllocator(-1)
+ assert.Equal(t, 1, allocator.MaxJobs())
+ assert.Equal(t, 1, allocator.AvailableJobs(nil))
+
+ allocator = nil
+ assert.Equal(t, 1, allocator.MaxJobs())
+ assert.Equal(t, 1, allocator.AvailableJobs(nil))
+}
+
+func TestDynamicJobsAllocator(t *testing.T) {
+ s := NewJobsScheduler(8)
+ s.Register("a", 1, true)
+ s.Register("b", 2, true)
+ s.Register("c", 3, true)
+ s.Register("d", 4, false)
+ s.Register("e", 4, false)
+ c := s.GetJobsAllocator("c")
+ assert.Equal(t, 8, c.MaxJobs())
+ assert.Equal(t, 3, c.AvailableJobs(nil))
+ b := s.GetJobsAllocator("b")
+ assert.Equal(t, 3, b.AvailableJobs(nil))
+ a := s.GetJobsAllocator("a")
+ assert.Equal(t, 2, a.AvailableJobs(nil))
+
+ barrier := make(chan struct{})
var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ barrier <- struct{}{}
+ d := s.GetJobsAllocator("d")
+ assert.Equal(t, 4, d.AvailableJobs(nil))
+ }()
+ go func() {
+ defer wg.Done()
+ barrier <- struct{}{}
+ e := s.GetJobsAllocator("e")
+ e.Init()
+ assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil))
+ }()
- // pre-lock for privileged tasks
- for i := 0; i < 50; i++ {
- taskScheduler.PreLock(fmt.Sprintf("privileged_%d", i))
- }
-
- // start ragtag tasks
- result := make([]string, 0, 1000)
- for i := 0; i < 50; i++ {
- wg.Add(1)
- go func(name string) {
- taskScheduler.Lock(name)
- result = append(result, name)
- taskScheduler.UnLock(name)
- wg.Done()
- }(fmt.Sprintf("ragtag_%d", i))
- }
-
- // start privileged tasks
- for i := 0; i < 50; i++ {
- wg.Add(1)
- go func(locker *Runner) {
- locker.Lock()
- result = append(result, locker.Name)
- locker.UnLock()
- wg.Done()
- }(taskScheduler.NewRunner(fmt.Sprintf("privileged_%d", i)))
- }
-
- // check result
+ <-barrier
+ <-barrier
+ s.Unregister("a")
+ s.Unregister("b")
+ s.Unregister("c")
wg.Wait()
- for i := 0; i < 100; i++ {
- if i < 50 {
- assert.Contains(t, result[i], "privileged_")
- } else {
- assert.Contains(t, result[i], "ragtag_")
- }
- }
+}
+
+func TestJobsScheduler(t *testing.T) {
+ s := NewJobsScheduler(8)
+ assert.True(t, s.Register("a", 1, true))
+ assert.True(t, s.Register("b", 2, true))
+ assert.True(t, s.Register("c", 3, true))
+ assert.True(t, s.Register("d", 4, false))
+ assert.True(t, s.Register("e", 4, false))
+ assert.False(t, s.Register("c", 1, true))
+ assert.Equal(t, 3, s.allocateJobsForTask("c", false, nil))
+ assert.Equal(t, 3, s.allocateJobsForTask("b", false, nil))
+ assert.Equal(t, 2, s.allocateJobsForTask("a", false, nil))
+ assert.Equal(t, 0, s.allocateJobsForTask("d", false, nil))
+ assert.Equal(t, 0, s.allocateJobsForTask("e", false, nil))
+
+ // several tasks complete
+ s.Unregister("b")
+ s.Unregister("c")
+ assert.Equal(t, 8, s.allocateJobsForTask("a", false, nil))
+
+ // privileged tasks complete
+ s.Unregister("a")
+ assert.Equal(t, 4, s.allocateJobsForTask("d", false, nil))
+ assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil))
+
+ // block privileged tasks if normal tasks are running
+ s.Register("a", 1, true)
+ s.Register("b", 2, true)
+ s.Register("c", 3, true)
+ assert.Equal(t, 0, s.allocateJobsForTask("c", false, nil))
+ assert.Equal(t, 0, s.allocateJobsForTask("b", false, nil))
+ assert.Equal(t, 0, s.allocateJobsForTask("a", false, nil))
}
diff --git a/config/config.toml.template b/config/config.toml.template
index f80e73ff5..f07dea27d 100644
--- a/config/config.toml.template
+++ b/config/config.toml.template
@@ -2,7 +2,9 @@
# The database for caching, support Redis, MySQL, Postgres and MongoDB:
# redis://:@:/
+# rediss://:@:/
# postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
+# postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
# mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]]
# mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]]
cache_store = "redis://localhost:6379/0"
@@ -10,7 +12,10 @@ cache_store = "redis://localhost:6379/0"
# The database for persist data, support MySQL, Postgres, ClickHouse and MongoDB:
# mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]
# postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
+# postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
# clickhouse://user:password@host[:port]/database?param1=value1&...¶mN=valueN
+# chhttp://user:password@host[:port]/database?param1=value1&...¶mN=valueN
+# chhttps://user:password@host[:port]/database?param1=value1&...¶mN=valueN
# mongodb://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]]
# mongodb+srv://[username:password@]host1[:port1][,...hostN[:portN]][/[defaultauthdb][?options]]
data_store = "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse"
diff --git a/docker/docker-compose.yml b/docker-compose.yml
similarity index 85%
rename from docker/docker-compose.yml
rename to docker-compose.yml
index 8bfb7b91f..7ccdd4050 100644
--- a/docker/docker-compose.yml
+++ b/docker-compose.yml
@@ -20,7 +20,9 @@ services:
- mysql_data:/var/lib/mysql
worker:
- image: zhenghaoz/gorse-worker:nightly
+ build:
+ context: .
+ dockerfile: cmd/gorse-worker/Dockerfile
restart: unless-stopped
ports:
- 8089:8089
@@ -36,7 +38,9 @@ services:
- master
server:
- image: zhenghaoz/gorse-server:nightly
+ build:
+ context: .
+ dockerfile: cmd/gorse-server/Dockerfile
restart: unless-stopped
ports:
- 8087:8087
@@ -52,7 +56,9 @@ services:
- master
master:
- image: zhenghaoz/gorse-master:nightly
+ build:
+ context: .
+ dockerfile: cmd/gorse-master/Dockerfile
restart: unless-stopped
ports:
- 8086:8086
@@ -65,7 +71,7 @@ services:
--log-path /var/log/gorse/master.log
--cache-path /var/lib/gorse/master_cache.data
volumes:
- - ./config.toml:/etc/gorse/config.toml
+ - ./config/config.toml.template:/etc/gorse/config.toml
- gorse_log:/var/log/gorse
- master_data:/var/lib/gorse
depends_on:
diff --git a/docker/config.toml b/docker/config.toml
deleted file mode 100644
index a8631d298..000000000
--- a/docker/config.toml
+++ /dev/null
@@ -1,192 +0,0 @@
-[master]
-
-# GRPC port of the master node. The default value is 8086.
-port = 8086
-
-# gRPC host of the master node. The default values is "0.0.0.0".
-host = "0.0.0.0"
-
-# HTTP port of the master node. The default values is 8088.
-http_port = 8088
-
-# HTTP host of the master node. The default values is "0.0.0.0".
-http_host = "0.0.0.0"
-
-# Number of working jobs in the master node. The default value is 1.
-n_jobs = 1
-
-# Meta information timeout. The default value is 10s.
-meta_timeout = "10s"
-
-# Username for the master node dashboard.
-dashboard_user_name = ""
-
-# Password for the master node dashboard.
-dashboard_password = ""
-
-[server]
-
-# Default number of returned items. The default value is 10.
-default_n = 10
-
-# Secret key for RESTful APIs (SSL required).
-api_key = ""
-
-# Clock error in the cluster. The default value is 5s.
-clock_error = "5s"
-
-# Insert new users while inserting feedback. The default value is true.
-auto_insert_user = true
-
-# Insert new items while inserting feedback. The default value is true.
-auto_insert_item = false
-
-# Server-side cache expire time. The default value is 10s.
-cache_expire = "10s"
-
-[recommend]
-
-# The cache size for recommended/popular/latest items. The default value is 10.
-cache_size = 100
-
-# Recommended cache expire time. The default value is 72h.
-cache_expire = "72h"
-
-[recommend.data_source]
-
-# The feedback types for positive events.
-positive_feedback_types = ["star","like"]
-
-# The feedback types for read events.
-read_feedback_types = ["read"]
-
-# The time-to-live (days) of positive feedback, 0 means disabled. The default value is 0.
-positive_feedback_ttl = 0
-
-# The time-to-live (days) of items, 0 means disabled. The default value is 0.
-item_ttl = 0
-
-[recommend.popular]
-
-# The time window of popular items. The default values is 4320h.
-popular_window = "720h"
-
-[recommend.user_neighbors]
-
-# The type of neighbors for users. There are three types:
-# similar: Neighbors are found by number of common labels.
-# related: Neighbors are found by number of common liked items.
-# auto: If a user have labels, neighbors are found by number of common labels.
-# If this user have no labels, neighbors are found by number of common liked items.
-# The default value is "auto".
-neighbor_type = "similar"
-
-# Enable approximate user neighbor searching using vector index. The default value is true.
-enable_index = true
-
-# Minimal recall for approximate user neighbor searching. The default value is 0.8.
-index_recall = 0.8
-
-# Maximal number of fit epochs for approximate user neighbor searching vector index. The default value is 3.
-index_fit_epoch = 3
-
-[recommend.item_neighbors]
-
-# The type of neighbors for items. There are three types:
-# similar: Neighbors are found by number of common labels.
-# related: Neighbors are found by number of common users.
-# auto: If a item have labels, neighbors are found by number of common labels.
-# If this item have no labels, neighbors are found by number of common users.
-# The default value is "auto".
-neighbor_type = "similar"
-
-# Enable approximate item neighbor searching using vector index. The default value is true.
-enable_index = true
-
-# Minimal recall for approximate item neighbor searching. The default value is 0.8.
-index_recall = 0.8
-
-# Maximal number of fit epochs for approximate item neighbor searching vector index. The default value is 3.
-index_fit_epoch = 3
-
-[recommend.collaborative]
-
-# Enable approximate collaborative filtering recommend using vector index. The default value is true.
-enable_index = true
-
-# Minimal recall for approximate collaborative filtering recommend. The default value is 0.9.
-index_recall = 0.9
-
-# Maximal number of fit epochs for approximate collaborative filtering recommend vector index. The default value is 3.
-index_fit_epoch = 3
-
-# The time period for model fitting. The default value is "60m".
-model_fit_period = "60m"
-
-# The time period for model searching. The default value is "360m".
-model_search_period = "360m"
-
-# The number of epochs for model searching. The default value is 100.
-model_search_epoch = 100
-
-# The number of trials for model searching. The default value is 10.
-model_search_trials = 10
-
-# Enable searching models of different sizes, which consume more memory. The default value is false.
-enable_model_size_search = false
-
-[recommend.replacement]
-
-# Replace historical items back to recommendations. The default value is false.
-enable_replacement = false
-
-# Decay the weights of replaced items from positive feedbacks. The default value is 0.8.
-positive_replacement_decay = 0.8
-
-# Decay the weights of replaced items from read feedbacks. The default value is 0.6.
-read_replacement_decay = 0.6
-
-[recommend.offline]
-
-# The time period to check recommendation for users. The default values is 1m.
-check_recommend_period = "1m"
-
-# The time period to refresh recommendation for inactive users. The default values is 120h.
-refresh_recommend_period = "24h"
-
-# Enable latest recommendation during offline recommendation. The default value is false.
-enable_latest_recommend = true
-
-# Enable popular recommendation during offline recommendation. The default value is false.
-enable_popular_recommend = false
-
-# Enable user-based similarity recommendation during offline recommendation. The default value is false.
-enable_user_based_recommend = true
-
-# Enable item-based similarity recommendation during offline recommendation. The default value is false.
-enable_item_based_recommend = false
-
-# Enable collaborative filtering recommendation during offline recommendation. The default value is true.
-enable_collaborative_recommend = true
-
-# Enable click-though rate prediction during offline recommendation. Otherwise, results from multi-way recommendation
-# would be merged randomly. The default value is false.
-enable_click_through_prediction = true
-
-# The explore recommendation method is used to inject popular items or latest items into recommended result:
-# popular: Recommend popular items to cold-start users.
-# latest: Recommend latest items to cold-start users.
-# The default values is { popular = 0.0, latest = 0.0 }.
-explore_recommend = { popular = 0.1, latest = 0.2 }
-
-[recommend.online]
-
-# The fallback recommendation method is used when cached recommendation drained out:
-# item_based: Recommend similar items to cold-start users.
-# popular: Recommend popular items to cold-start users.
-# latest: Recommend latest items to cold-start users.
-# Recommenders are used in order. The default values is ["latest"].
-fallback_recommend = ["item_based", "latest"]
-
-# The number of feedback used in fallback item-based similar recommendation. The default values is 10.
-num_feedback_fallback_item_based = 10
diff --git a/go.mod b/go.mod
index a8abb7436..57ddda4a3 100644
--- a/go.mod
+++ b/go.mod
@@ -4,13 +4,13 @@ go 1.18
require (
github.com/ReneKroon/ttlcache/v2 v2.11.0
- github.com/alicebob/miniredis/v2 v2.16.1
+ github.com/alicebob/miniredis/v2 v2.23.0
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de
github.com/benhoyt/goawk v1.20.0
github.com/bits-and-blooms/bitset v1.2.1
github.com/chewxy/math32 v1.10.1
github.com/emicklei/go-restful-openapi/v2 v2.9.0
- github.com/emicklei/go-restful/v3 v3.8.0
+ github.com/emicklei/go-restful/v3 v3.9.0
github.com/go-playground/locales v0.14.0
github.com/go-playground/universal-translator v0.18.0
github.com/go-playground/validator/v10 v10.11.0
@@ -29,24 +29,23 @@ require (
github.com/mailru/go-clickhouse/v2 v2.0.0
github.com/mitchellh/mapstructure v1.5.0
github.com/orcaman/concurrent-map v1.0.0
- github.com/prometheus/client_golang v1.12.2
+ github.com/prometheus/client_golang v1.13.0
github.com/rakyll/statik v0.1.7
- github.com/samber/lo v1.25.0
- github.com/schollz/progressbar/v3 v3.8.7
+ github.com/samber/lo v1.27.0
+ github.com/schollz/progressbar/v3 v3.9.0
github.com/scylladb/go-set v1.0.2
github.com/sijms/go-ora/v2 v2.4.27
github.com/spf13/cobra v1.5.0
github.com/spf13/viper v1.12.0
- github.com/steinfletcher/apitest v1.5.11
- github.com/stretchr/testify v1.7.5
+ github.com/steinfletcher/apitest v1.5.12
+ github.com/stretchr/testify v1.8.0
github.com/thoas/go-funk v0.9.2
- go.mongodb.org/mongo-driver v1.10.0
- go.uber.org/atomic v1.9.0
- go.uber.org/zap v1.21.0
- golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75
- gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2
+ go.mongodb.org/mongo-driver v1.10.1
+ go.uber.org/atomic v1.10.0
+ go.uber.org/zap v1.22.0
+ golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e
google.golang.org/grpc v1.48.0
- google.golang.org/protobuf v1.28.0
+ google.golang.org/protobuf v1.28.1
gopkg.in/yaml.v2 v2.4.0
gorm.io/driver/clickhouse v0.4.2
gorm.io/driver/mysql v1.3.4
@@ -107,28 +106,27 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
- github.com/prometheus/procfs v0.7.3 // indirect
+ github.com/prometheus/procfs v0.8.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect
- github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/rivo/uniseg v0.3.4 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
- github.com/stretchr/objx v0.4.0 // indirect
github.com/subosito/gotenv v1.4.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.1 // indirect
github.com/xdg-go/stringprep v1.0.3 // indirect
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect
- github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da // indirect
+ github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 // indirect
go.uber.org/multierr v1.8.0 // indirect
- golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
+ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/net v0.0.0-20220708220712-1185a9018129 // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
- golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
- golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 // indirect
+ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect
+ golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.11 // indirect
google.golang.org/genproto v0.0.0-20220719170305-83ca9fad585f // indirect
diff --git a/go.sum b/go.sum
index 5638e717c..3c05f7012 100644
--- a/go.sum
+++ b/go.sum
@@ -51,8 +51,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
-github.com/alicebob/miniredis/v2 v2.16.1 h1:ikfCfUHWlfiVCVVaaDO60SBgPWS4UNIi1A7p7QmUVyw=
-github.com/alicebob/miniredis/v2 v2.16.1/go.mod h1:gquAfGbzn92jvtrSC69+6zZnwSODVXVpYDRaGhWaL6I=
+github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
+github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA=
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw=
@@ -100,8 +100,8 @@ github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25Kn
github.com/emicklei/go-restful-openapi/v2 v2.9.0 h1:djsWqjhI0EVYfkLCCX6jZxUkLmYUq2q9tt09ZbixfyE=
github.com/emicklei/go-restful-openapi/v2 v2.9.0/go.mod h1:VKNgZyYviM1hnyrjD9RDzP2RuE94xTXxV+u6MGN4v4k=
github.com/emicklei/go-restful/v3 v3.7.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
-github.com/emicklei/go-restful/v3 v3.8.0 h1:eCZ8ulSerjdAiaNpF7GxXIE7ZCMo1moN1qX+S609eVw=
-github.com/emicklei/go-restful/v3 v3.8.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
+github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE=
+github.com/emicklei/go-restful/v3 v3.9.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -415,8 +415,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
-github.com/prometheus/client_golang v1.12.2 h1:51L9cDoUHVrXx4zWYlcLQIZ+d+VXHgqnYKkIuq4g/34=
-github.com/prometheus/client_golang v1.12.2/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
+github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU=
+github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -432,16 +432,18 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
+github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
+github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/rakyll/statik v0.1.7 h1:OF3QCZUuyPxuGEP7B4ypUa7sB/iHtqOTDYZXGM8KOdQ=
github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc=
github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+github.com/rivo/uniseg v0.3.4 h1:3Z3Eu6FGHZWSfNKJTOUiPatWwfc7DzJRU04jFUqJODw=
+github.com/rivo/uniseg v0.3.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
@@ -451,11 +453,11 @@ github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/samber/lo v1.25.0 h1:H8F6cB0RotRdgcRCivTByAQePaYhGMdOTJIj2QFS2I0=
-github.com/samber/lo v1.25.0/go.mod h1:2I7tgIv8Q1SG2xEIkRq0F2i2zgxVpnyPOP0d3Gj2r+A=
+github.com/samber/lo v1.27.0 h1:GOyDWxsblvqYobqsmUuMddPa2/mMzkKyojlXol4+LaQ=
+github.com/samber/lo v1.27.0/go.mod h1:it33p9UtPMS7z72fP4gw/EIfQB2eI8ke7GR2wc6+Rhg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
-github.com/schollz/progressbar/v3 v3.8.7 h1:rtje4lnXVD1Dy/RtPpGd2ijLCmQ7Su3G2ia8dJcRKIo=
-github.com/schollz/progressbar/v3 v3.8.7/go.mod h1:W5IEwbJecncFGBvuEh4A7HT1nZZ6WNIL2i3qbnI0WKY=
+github.com/schollz/progressbar/v3 v3.9.0 h1:k9SRNQ8KZyibz1UZOaKxnkUE3iGtmGSDt1YY9KlCYQk=
+github.com/schollz/progressbar/v3 v3.9.0/go.mod h1:W5IEwbJecncFGBvuEh4A7HT1nZZ6WNIL2i3qbnI0WKY=
github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE=
github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs=
github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
@@ -481,12 +483,11 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.12.0 h1:CZ7eSOd3kZoaYDLbXnmzgQI5RlciuXBMA+18HwHRfZQ=
github.com/spf13/viper v1.12.0/go.mod h1:b6COn30jlNxbm/V2IqWiNWkJ+vZNiMNksliPCiuKtSI=
-github.com/steinfletcher/apitest v1.5.11 h1:bG3hq3sA4+oPHln3O/xQ6LzsQgN0J2WJl+6EpydQZ8Q=
-github.com/steinfletcher/apitest v1.5.11/go.mod h1:cf7Bneo52IIAgpqhP8xaLlzWgAiQ9fHtsDMjeDnZ3so=
+github.com/steinfletcher/apitest v1.5.12 h1:zv+UiSXxDspQ8R7DILKiQVIlamDH7ufCsyQuzWsn3H4=
+github.com/steinfletcher/apitest v1.5.12/go.mod h1:cf7Bneo52IIAgpqhP8xaLlzWgAiQ9fHtsDMjeDnZ3so=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
-github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -496,8 +497,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
-github.com/stretchr/testify v1.7.5 h1:s5PTfem8p8EbKQOctVV53k6jCJt3UX4IEJzwh+C324Q=
-github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/subosito/gotenv v1.4.0 h1:yAzM1+SmVcz5R4tXGsNMu1jUl2aOJXoiWUCEwwnGrvs=
github.com/subosito/gotenv v1.4.0/go.mod h1:mZd6rFysKEcUhUHXJk0C/08wAgyDBFuwEYL7vWWGaGo=
github.com/thoas/go-funk v0.9.2 h1:oKlNYv0AY5nyf9g+/GhMgS/UO2ces0QRdPKwkhY3VCk=
@@ -518,11 +519,12 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
-github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da h1:NimzV1aGyq29m5ukMK0AMWEhFaL/lrEOaephfuoiARg=
-github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
+github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
+github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 h1:5mLPGnFdSsevFRFc9q3yYbBkB6tsm4aCwwQV/j1JQAQ=
+github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
-go.mongodb.org/mongo-driver v1.10.0 h1:UtV6N5k14upNp4LTduX0QCufG124fSu25Wz9tu94GLg=
-go.mongodb.org/mongo-driver v1.10.0/go.mod h1:wsihk0Kdgv8Kqu1Anit4sfK+22vSFbUrAVEYRhCXrA8=
+go.mongodb.org/mongo-driver v1.10.1 h1:NujsPveKwHaWuKUer/ceo9DzEe7HIj1SlJ6uvXZG0S4=
+go.mongodb.org/mongo-driver v1.10.1/go.mod h1:z4XpeoU6w+9Vht+jAFyLgVrD+jGSQQe0+CBWFHNiHt8=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
@@ -535,8 +537,8 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
-go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
+go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
@@ -551,8 +553,9 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
-go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8=
go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw=
+go.uber.org/zap v1.22.0 h1:Zcye5DUgBloQ9BaT4qc9BnjOFog5TvBSAGkJ3Nf70c0=
+go.uber.org/zap v1.22.0/go.mod h1:H4siCOZOrAolnUPJEkfaSjDqyP+BDS0DdDWzwcgt3+U=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
@@ -570,10 +573,10 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
+golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
@@ -583,8 +586,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
-golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75 h1:x03zeu7B2B11ySp+daztnwM5oBJ/8wGUSqrwcw9L0RA=
-golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA=
+golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=
+golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -738,13 +741,13 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
-golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
-golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
-golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc=
+golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -759,7 +762,6 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
@@ -822,10 +824,6 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2 h1:IZ343hUHFQ/v9IrSpcT1ih//e8VIctin7w2WyVl4kcc=
-gonum.org/v1/gonum v0.0.0-20190409070159-6e46824336d2/go.mod h1:2ltnJ7xHfj0zHS40VVPYEAAMTa3ZGguvHGBSJeRWqE0=
-gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
-gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
@@ -924,8 +922,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
-google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
-google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
+google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/master/master.go b/master/master.go
index b27c1bce7..1d9c274a1 100644
--- a/master/master.go
+++ b/master/master.go
@@ -17,15 +17,6 @@ package master
import (
"context"
"fmt"
- "github.com/emicklei/go-restful/v3"
- "github.com/juju/errors"
- "github.com/zhenghaoz/gorse/base/encoding"
- "github.com/zhenghaoz/gorse/base/log"
- "github.com/zhenghaoz/gorse/base/task"
- "github.com/zhenghaoz/gorse/model/click"
- "github.com/zhenghaoz/gorse/server"
- "go.uber.org/zap"
- "google.golang.org/grpc"
"math"
"math/rand"
"net"
@@ -33,12 +24,21 @@ import (
"time"
"github.com/ReneKroon/ttlcache/v2"
+ "github.com/emicklei/go-restful/v3"
+ "github.com/juju/errors"
"github.com/zhenghaoz/gorse/base"
+ "github.com/zhenghaoz/gorse/base/encoding"
+ "github.com/zhenghaoz/gorse/base/log"
+ "github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/config"
+ "github.com/zhenghaoz/gorse/model/click"
"github.com/zhenghaoz/gorse/model/ranking"
"github.com/zhenghaoz/gorse/protocol"
+ "github.com/zhenghaoz/gorse/server"
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
+ "go.uber.org/zap"
+ "google.golang.org/grpc"
)
// Master is the master node.
@@ -48,7 +48,7 @@ type Master struct {
grpcServer *grpc.Server
taskMonitor *task.Monitor
- taskScheduler *task.Scheduler
+ jobsScheduler *task.JobsScheduler
cacheFile string
// cluster meta cache
@@ -81,7 +81,8 @@ type Master struct {
// events
fitTicker *time.Ticker
- importedChan chan bool // feedback inserted events
+ importedChan chan struct{} // feedback inserted events
+ loadDataChan chan struct{} // dataset loaded events
}
// NewMaster creates a master node.
@@ -99,20 +100,18 @@ func NewMaster(cfg *config.Config, cacheFile string) *Master {
// create task monitor
cacheFile: cacheFile,
taskMonitor: taskMonitor,
- taskScheduler: task.NewTaskScheduler(),
+ jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs),
// default ranking model
rankingModelName: "bpr",
rankingModelSearcher: ranking.NewModelSearcher(
cfg.Recommend.Collaborative.ModelSearchEpoch,
cfg.Recommend.Collaborative.ModelSearchTrials,
- cfg.Master.NumJobs,
cfg.Recommend.Collaborative.EnableModelSizeSearch,
),
// default click model
clickModelSearcher: click.NewModelSearcher(
cfg.Recommend.Collaborative.ModelSearchEpoch,
cfg.Recommend.Collaborative.ModelSearchTrials,
- cfg.Master.NumJobs,
cfg.Recommend.Collaborative.EnableModelSizeSearch,
),
RestServer: server.RestServer{
@@ -131,7 +130,8 @@ func NewMaster(cfg *config.Config, cacheFile string) *Master {
WebService: new(restful.WebService),
},
fitTicker: time.NewTicker(cfg.Recommend.Collaborative.ModelFitPeriod),
- importedChan: make(chan bool),
+ importedChan: make(chan struct{}),
+ loadDataChan: make(chan struct{}),
}
}
@@ -161,7 +161,7 @@ func (m *Master) Serve() {
CollaborativeFilteringPrecision10.Set(float64(m.rankingScore.Precision))
CollaborativeFilteringRecall10.Set(float64(m.rankingScore.Recall))
CollaborativeFilteringNDCG10.Set(float64(m.rankingScore.NDCG))
- MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes()))
}
if m.localCache.ClickModel != nil {
log.Logger().Info("load cached click model",
@@ -174,7 +174,7 @@ func (m *Master) Serve() {
RankingPrecision.Set(float64(m.clickScore.Precision))
RankingRecall.Set(float64(m.clickScore.Recall))
RankingAUC.Set(float64(m.clickScore.AUC))
- MemoryInuseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes()))
}
// create cluster meta cache
@@ -208,12 +208,6 @@ func (m *Master) Serve() {
m.RestServer.HiddenItemsManager = server.NewHiddenItemsManager(&m.RestServer)
m.RestServer.PopularItemsCache = server.NewPopularItemsCache(&m.RestServer)
- // pre-lock privileged tasks
- tasksNames := []string{TaskLoadDataset, TaskFindItemNeighbors, TaskFindUserNeighbors, TaskFitRankingModel, TaskFitClickModel}
- for _, taskName := range tasksNames {
- m.taskScheduler.PreLock(taskName)
- }
-
go m.RunPrivilegedTasksLoop()
log.Logger().Info("start model fit", zap.Duration("period", m.Config.Recommend.Collaborative.ModelFitPeriod))
go m.RunRagtagTasksLoop()
@@ -252,19 +246,20 @@ func (m *Master) Shutdown() {
func (m *Master) RunPrivilegedTasksLoop() {
defer base.CheckPanic()
var (
- lastNumRankingUsers int
- lastNumRankingItems int
- lastNumRankingFeedback int
- lastNumClickUsers int
- lastNumClickItems int
- lastNumClickFeedback int
- err error
+ err error
+ tasks = []Task{
+ NewFitClickModelTask(m),
+ NewFitRankingModelTask(m),
+ NewFindUserNeighborsTask(m),
+ NewFindItemNeighborsTask(m),
+ }
+ firstLoop = true
)
go func() {
- m.importedChan <- true
+ m.importedChan <- struct{}{}
for {
if m.checkDataImported() {
- m.importedChan <- true
+ m.importedChan <- struct{}{}
}
time.Sleep(time.Second)
}
@@ -274,11 +269,6 @@ func (m *Master) RunPrivilegedTasksLoop() {
case <-m.fitTicker.C:
case <-m.importedChan:
}
- // pre-lock privileged tasks
- tasksNames := []string{TaskLoadDataset, TaskFindItemNeighbors, TaskFindUserNeighbors, TaskFitRankingModel, TaskFitClickModel}
- for _, taskName := range tasksNames {
- m.taskScheduler.PreLock(taskName)
- }
// download dataset
err = m.runLoadDatasetTask()
@@ -286,27 +276,33 @@ func (m *Master) RunPrivilegedTasksLoop() {
log.Logger().Error("failed to load ranking dataset", zap.Error(err))
continue
}
-
- // fit ranking model
- lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedback, err =
- m.runRankingRelatedTasks(lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedback)
- if err != nil {
- log.Logger().Error("failed to fit ranking model", zap.Error(err))
+ if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 {
+ log.Logger().Warn("empty ranking dataset",
+ zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
continue
}
- // fit click model
- lastNumClickUsers, lastNumClickItems, lastNumClickFeedback, err =
- m.runFitClickModelTask(lastNumClickUsers, lastNumClickItems, lastNumClickFeedback)
- if err != nil {
- log.Logger().Error("failed to fit click model", zap.Error(err))
- m.taskMonitor.Fail(TaskFitClickModel, err.Error())
- continue
+ if firstLoop {
+ m.loadDataChan <- struct{}{}
+ firstLoop = false
}
- // release locks
- for _, taskName := range tasksNames {
- m.taskScheduler.UnLock(taskName)
+ var registeredTask []Task
+ for _, t := range tasks {
+ if m.jobsScheduler.Register(t.name(), t.priority(), true) {
+ registeredTask = append(registeredTask, t)
+ }
+ }
+ for _, t := range registeredTask {
+ go func(task Task) {
+ j := m.jobsScheduler.GetJobsAllocator(task.name())
+ defer m.jobsScheduler.Unregister(task.name())
+ j.Init()
+ if err := task.run(j); err != nil {
+ log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
+ return
+ }
+ }(t)
}
}
}
@@ -315,40 +311,35 @@
// rankingModelSearcher, clickSearchedModel and clickSearchedScore.
func (m *Master) RunRagtagTasksLoop() {
defer base.CheckPanic()
+ <-m.loadDataChan
var (
- lastNumRankingUsers int
- lastNumRankingItems int
- lastNumRankingFeedbacks int
- lastNumClickUsers int
- lastNumClickItems int
- lastNumClickFeedbacks int
- err error
+ tasks = []Task{
+ NewCacheGarbageCollectionTask(m),
+ NewSearchRankingModelTask(m),
+ NewSearchClickModelTask(m),
+ }
)
for {
- // garbage collection
- m.taskScheduler.Lock(TaskCacheGarbageCollection)
- if err = m.runCacheGarbageCollectionTask(); err != nil {
- log.Logger().Error("failed to collect garbage", zap.Error(err))
- m.taskMonitor.Fail(TaskCacheGarbageCollection, err.Error())
- }
- m.taskScheduler.UnLock(TaskCacheGarbageCollection)
- // search optimal ranking model
- lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedbacks, err =
- m.runSearchRankingModelTask(lastNumRankingUsers, lastNumRankingItems, lastNumRankingFeedbacks)
- if err != nil {
- log.Logger().Error("failed to search ranking model", zap.Error(err))
- m.taskMonitor.Fail(TaskSearchRankingModel, err.Error())
- time.Sleep(time.Minute)
+ if m.rankingTrainSet == nil || m.clickTrainSet == nil {
+ time.Sleep(time.Second)
continue
}
- // search optimal click model
- lastNumClickUsers, lastNumClickItems, lastNumClickFeedbacks, err =
- m.runSearchClickModelTask(lastNumClickUsers, lastNumClickItems, lastNumClickFeedbacks)
- if err != nil {
- log.Logger().Error("failed to search click model", zap.Error(err))
- m.taskMonitor.Fail(TaskSearchClickModel, err.Error())
- time.Sleep(time.Minute)
- continue
+ var registeredTask []Task
+ for _, t := range tasks {
+ if m.jobsScheduler.Register(t.name(), t.priority(), false) {
+ registeredTask = append(registeredTask, t)
+ }
+ }
+ for _, t := range registeredTask {
+ go func(task Task) {
+ defer m.jobsScheduler.Unregister(task.name())
+ j := m.jobsScheduler.GetJobsAllocator(task.name())
+ j.Init()
+ if err := task.run(j); err != nil {
+ log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
+ m.taskMonitor.Fail(task.name(), err.Error())
+ }
+ }(t)
}
time.Sleep(m.Config.Recommend.Collaborative.ModelSearchPeriod)
}
diff --git a/master/master_test.go b/master/master_test.go
index ae95721f1..3f2e17748 100644
--- a/master/master_test.go
+++ b/master/master_test.go
@@ -14,13 +14,14 @@
package master
import (
+ "testing"
+
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
- "testing"
)
type mockMaster struct {
diff --git a/master/metrics.go b/master/metrics.go
index bdbad29df..2184e99f7 100644
--- a/master/metrics.go
+++ b/master/metrics.go
@@ -15,13 +15,14 @@
package master
import (
+ "time"
+
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/samber/lo"
"github.com/scylladb/go-set/i32set"
"github.com/zhenghaoz/gorse/server"
"github.com/zhenghaoz/gorse/storage/cache"
- "time"
)
const (
@@ -218,7 +219,7 @@ var (
Subsystem: "master",
Name: "negative_feedbacks_total",
})
- MemoryInuseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{
+ MemoryInUseBytesVec = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "gorse",
Subsystem: "master",
Name: "memory_inuse_bytes",
diff --git a/master/tasks.go b/master/tasks.go
index d3d57ff03..a79061b61 100644
--- a/master/tasks.go
+++ b/master/tasks.go
@@ -16,6 +16,11 @@ package master
import (
"fmt"
+ "math"
+ "sort"
+ "strings"
+ "time"
+
"github.com/chewxy/math32"
"github.com/juju/errors"
"github.com/samber/lo"
@@ -26,6 +31,7 @@ import (
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/base/parallel"
"github.com/zhenghaoz/gorse/base/search"
+ "github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/model/click"
"github.com/zhenghaoz/gorse/model/ranking"
@@ -33,11 +39,7 @@ import (
"github.com/zhenghaoz/gorse/storage/data"
"go.uber.org/atomic"
"go.uber.org/zap"
- "math"
"modernc.org/sortutil"
- "sort"
- "strings"
- "time"
)
const (
@@ -56,6 +58,12 @@ const (
similarityShrink = 100
)
+type Task interface {
+ name() string
+ priority() int
+ run(j *task.JobsAllocator) error
+}
+
// runLoadDatasetTask loads dataset.
func (m *Master) runLoadDatasetTask() error {
initialStartTime := time.Now()
@@ -172,8 +180,8 @@ func (m *Master) runLoadDatasetTask() error {
rankingDataset = nil
m.rankingModelMutex.Unlock()
LoadDatasetStepSecondsVec.WithLabelValues("split_ranking_dataset").Set(time.Since(startTime).Seconds())
- MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_train_set").Set(float64(m.rankingTrainSet.Bytes()))
- MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_test_set").Set(float64(m.rankingTestSet.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_train_set").Set(float64(m.rankingTrainSet.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_test_set").Set(float64(m.rankingTestSet.Bytes()))
// split click dataset
startTime = time.Now()
@@ -182,8 +190,8 @@ func (m *Master) runLoadDatasetTask() error {
clickDataset = nil
m.clickModelMutex.Unlock()
LoadDatasetStepSecondsVec.WithLabelValues("split_click_dataset").Set(time.Since(startTime).Seconds())
- MemoryInuseBytesVec.WithLabelValues("ranking_train_set").Set(float64(m.clickTrainSet.Bytes()))
- MemoryInuseBytesVec.WithLabelValues("ranking_test_set").Set(float64(m.clickTestSet.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("ranking_train_set").Set(float64(m.clickTrainSet.Bytes()))
+ MemoryInUseBytesVec.WithLabelValues("ranking_test_set").Set(float64(m.clickTestSet.Bytes()))
LoadDatasetTotalSeconds.Set(time.Since(initialStartTime).Seconds())
return nil
@@ -205,17 +213,49 @@ func (m *Master) estimateFindItemNeighborsComplexity(dataset *ranking.DataSet) i
return complexity
}
-// runFindItemNeighborsTask updates neighbors of items.
-func (m *Master) runFindItemNeighborsTask(dataset *ranking.DataSet) {
+// FindItemNeighborsTask updates neighbors of items.
+type FindItemNeighborsTask struct {
+ *Master
+ lastNumItems int
+ lastNumFeedback int
+}
+
+func NewFindItemNeighborsTask(m *Master) *FindItemNeighborsTask {
+ return &FindItemNeighborsTask{Master: m}
+}
+
+func (t *FindItemNeighborsTask) name() string {
+ return TaskFindItemNeighbors
+}
+
+func (t *FindItemNeighborsTask) priority() int {
+ return -t.rankingTrainSet.ItemCount() * t.rankingTrainSet.ItemCount()
+}
+
+func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error {
+ t.rankingDataMutex.RLock()
+ defer t.rankingDataMutex.RUnlock()
+ dataset := t.rankingTrainSet
+ numItems := dataset.ItemCount()
+ numFeedback := dataset.Count()
+
+ if numItems == 0 {
+ t.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.")
+ return nil
+ } else if numItems == t.lastNumItems && numFeedback == t.lastNumFeedback {
+ log.Logger().Info("No item neighbors need to be updated.")
+ return nil
+ }
+
startTaskTime := time.Now()
- m.taskMonitor.Start(TaskFindItemNeighbors, m.estimateFindItemNeighborsComplexity(dataset))
+ t.taskMonitor.Start(TaskFindItemNeighbors, t.estimateFindItemNeighborsComplexity(dataset))
log.Logger().Info("start searching neighbors of items",
- zap.Int("n_cache", m.Config.Recommend.CacheSize))
+ zap.Int("n_cache", t.Config.Recommend.CacheSize))
// create progress tracker
completed := make(chan struct{}, 1000)
go func() {
completedCount, previousCount := 0, 0
- ticker := time.NewTicker(time.Second)
+ ticker := time.NewTicker(time.Second * 10)
for {
select {
case _, ok := <-completed:
@@ -227,74 +267,78 @@ func (m *Master) runFindItemNeighborsTask(dataset *ranking.DataSet) {
throughput := completedCount - previousCount
previousCount = completedCount
if throughput > 0 {
- m.taskMonitor.Add(TaskFindItemNeighbors, throughput*dataset.ItemCount())
+ t.taskMonitor.Add(TaskFindItemNeighbors, throughput*dataset.ItemCount())
log.Logger().Debug("searching neighbors of items",
zap.Int("n_complete_items", completedCount),
zap.Int("n_items", dataset.ItemCount()),
- zap.Int("throughput", throughput))
+ zap.Int("throughput", throughput/10))
}
}
}
}()
userIDF := make([]float32, dataset.UserCount())
- if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeRelated ||
- m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto {
+ if t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeRelated ||
+ t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto {
for _, feedbacks := range dataset.ItemFeedback {
sort.Sort(sortutil.Int32Slice(feedbacks))
}
- m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback))
+ t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback))
// inverse document frequency of users
for i := range dataset.UserFeedback {
userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i])))
}
- m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback))
+ t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback))
}
labeledItems := make([][]int32, dataset.NumItemLabels)
labelIDF := make([]float32, dataset.NumItemLabels)
- if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeSimilar ||
- m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto {
+ if t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeSimilar ||
+ t.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto {
for i, itemLabels := range dataset.ItemLabels {
sort.Sort(sortutil.Int32Slice(itemLabels))
for _, label := range itemLabels {
labeledItems[label] = append(labeledItems[label], int32(i))
}
}
- m.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels))
+ t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels))
// inverse document frequency of labels
for i := range labeledItems {
labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i])))
}
- m.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems))
+ t.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems))
}
start := time.Now()
var err error
- if m.Config.Recommend.ItemNeighbors.EnableIndex {
- err = m.findItemNeighborsIVF(dataset, labelIDF, userIDF, completed)
+ if t.Config.Recommend.ItemNeighbors.EnableIndex {
+ err = t.findItemNeighborsIVF(dataset, labelIDF, userIDF, completed, j)
} else {
- err = m.findItemNeighborsBruteForce(dataset, labeledItems, labelIDF, userIDF, completed)
+ err = t.findItemNeighborsBruteForce(dataset, labeledItems, labelIDF, userIDF, completed, j)
}
searchTime := time.Since(start)
close(completed)
if err != nil {
log.Logger().Error("failed to searching neighbors of items", zap.Error(err))
- m.taskMonitor.Fail(TaskFindItemNeighbors, err.Error())
+ t.taskMonitor.Fail(TaskFindItemNeighbors, err.Error())
FindItemNeighborsTotalSeconds.Set(0)
} else {
- if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil {
+ if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil {
log.Logger().Error("failed to set neighbors of items update time", zap.Error(err))
}
log.Logger().Info("complete searching neighbors of items",
zap.String("search_time", searchTime.String()))
- m.taskMonitor.Finish(TaskFindItemNeighbors)
+ t.taskMonitor.Finish(TaskFindItemNeighbors)
FindItemNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds())
}
+
+ t.lastNumItems = numItems
+ t.lastNumFeedback = numFeedback
+ return nil
}
func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledItems [][]int32,
- labelIDF, userIDF []float32, completed chan struct{}) error {
+ labelIDF, userIDF []float32, completed chan struct{}, j *task.JobsAllocator) error {
var (
updateItemCount atomic.Float64
findNeighborSeconds atomic.Float64
@@ -314,7 +358,7 @@ func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledIt
return errors.NotImplementedf("item neighbor type `%v`", m.Config.Recommend.ItemNeighbors.NeighborType)
}
- err := parallel.Parallel(dataset.ItemCount(), m.Config.Master.NumJobs, func(workerId, itemIndex int) error {
+ err := parallel.DynamicParallel(dataset.ItemCount(), j, func(workerId, itemIndex int) error {
defer func() {
completed <- struct{}{}
}()
@@ -372,7 +416,7 @@ func (m *Master) findItemNeighborsBruteForce(dataset *ranking.DataSet, labeledIt
return nil
}
-func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userIDF []float32, completed chan struct{}) error {
+func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userIDF []float32, completed chan struct{}, j *task.JobsAllocator) error {
var (
updateItemCount atomic.Float64
findNeighborSeconds atomic.Float64
@@ -400,7 +444,8 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID
return errors.NotImplementedf("item neighbor type `%v`", m.Config.Recommend.ItemNeighbors.NeighborType)
}
- builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, search.SetIVFNumJobs(m.Config.Master.NumJobs))
+ builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize,
+ search.SetIVFJobsAllocator(j))
var recall float32
index, recall = builder.Build(m.Config.Recommend.ItemNeighbors.IndexRecall,
m.Config.Recommend.ItemNeighbors.IndexFitEpoch,
@@ -412,7 +457,7 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID
}
buildIndexSeconds.Add(time.Since(buildStart).Seconds())
- err := parallel.Parallel(dataset.ItemCount(), m.Config.Master.NumJobs, func(workerId, itemIndex int) error {
+ err := parallel.DynamicParallel(dataset.ItemCount(), j, func(workerId, itemIndex int) error {
defer func() {
completed <- struct{}{}
}()
@@ -479,12 +524,44 @@ func (m *Master) estimateFindUserNeighborsComplexity(dataset *ranking.DataSet) i
return complexity
}
-// runFindUserNeighborsTask updates neighbors of users.
-func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) {
+// FindUserNeighborsTask updates neighbors of users.
+type FindUserNeighborsTask struct {
+ *Master
+ lastNumUsers int
+ lastNumFeedback int
+}
+
+func NewFindUserNeighborsTask(m *Master) *FindUserNeighborsTask {
+ return &FindUserNeighborsTask{Master: m}
+}
+
+func (t *FindUserNeighborsTask) name() string {
+ return TaskFindUserNeighbors
+}
+
+func (t *FindUserNeighborsTask) priority() int {
+ return -t.rankingTrainSet.UserCount() * t.rankingTrainSet.UserCount()
+}
+
+func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error {
+ t.rankingDataMutex.RLock()
+ defer t.rankingDataMutex.RUnlock()
+ dataset := t.rankingTrainSet
+ numUsers := dataset.UserCount()
+ numFeedback := dataset.Count()
+
+ if numUsers == 0 {
+ t.taskMonitor.Fail(TaskFindUserNeighbors, "No user found.")
+ return nil
+ } else if numUsers == t.lastNumUsers && numFeedback == t.lastNumFeedback {
+ log.Logger().Info("No update of user neighbors needed.")
+ return nil
+ }
+
startTaskTime := time.Now()
- m.taskMonitor.Start(TaskFindUserNeighbors, m.estimateFindUserNeighborsComplexity(dataset))
+ t.taskMonitor.Start(TaskFindUserNeighbors, t.estimateFindUserNeighborsComplexity(dataset))
log.Logger().Info("start searching neighbors of users",
- zap.Int("n_cache", m.Config.Recommend.CacheSize))
+ zap.Int("n_cache", t.Config.Recommend.CacheSize))
// create progress tracker
completed := make(chan struct{}, 1000)
go func() {
@@ -501,7 +578,7 @@ func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) {
throughput := completedCount - previousCount
previousCount = completedCount
if throughput > 0 {
- m.taskMonitor.Add(TaskFindUserNeighbors, throughput*dataset.UserCount())
+ t.taskMonitor.Add(TaskFindUserNeighbors, throughput*dataset.UserCount())
log.Logger().Debug("searching neighbors of users",
zap.Int("n_complete_users", completedCount),
zap.Int("n_users", dataset.UserCount()),
@@ -512,62 +589,66 @@ func (m *Master) runFindUserNeighborsTask(dataset *ranking.DataSet) {
}()
itemIDF := make([]float32, dataset.ItemCount())
- if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeRelated ||
- m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto {
+ if t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeRelated ||
+ t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto {
for _, feedbacks := range dataset.UserFeedback {
sort.Sort(sortutil.Int32Slice(feedbacks))
}
- m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback))
+ t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback))
// inverse document frequency of items
for i := range dataset.ItemFeedback {
itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i])))
}
- m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback))
+ t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback))
}
labeledUsers := make([][]int32, dataset.NumUserLabels)
labelIDF := make([]float32, dataset.NumUserLabels)
- if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeSimilar ||
- m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto {
+ if t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeSimilar ||
+ t.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto {
for i, userLabels := range dataset.UserLabels {
sort.Sort(sortutil.Int32Slice(userLabels))
for _, label := range userLabels {
labeledUsers[label] = append(labeledUsers[label], int32(i))
}
}
- m.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels))
+ t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels))
// inverse document frequency of labels
for i := range labeledUsers {
labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i])))
}
- m.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers))
+ t.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers))
}
start := time.Now()
var err error
- if m.Config.Recommend.UserNeighbors.EnableIndex {
- err = m.findUserNeighborsIVF(dataset, labelIDF, itemIDF, completed)
+ if t.Config.Recommend.UserNeighbors.EnableIndex {
+ err = t.findUserNeighborsIVF(dataset, labelIDF, itemIDF, completed, j)
} else {
- err = m.findUserNeighborsBruteForce(dataset, labeledUsers, labelIDF, itemIDF, completed)
+ err = t.findUserNeighborsBruteForce(dataset, labeledUsers, labelIDF, itemIDF, completed, j)
}
searchTime := time.Since(start)
close(completed)
if err != nil {
log.Logger().Error("failed to searching neighbors of users", zap.Error(err))
- m.taskMonitor.Fail(TaskFindUserNeighbors, err.Error())
+ t.taskMonitor.Fail(TaskFindUserNeighbors, err.Error())
FindUserNeighborsTotalSeconds.Set(0)
} else {
- if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil {
+ if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil {
log.Logger().Error("failed to set neighbors of users update time", zap.Error(err))
}
log.Logger().Info("complete searching neighbors of users",
zap.String("search_time", searchTime.String()))
- m.taskMonitor.Finish(TaskFindUserNeighbors)
+ t.taskMonitor.Finish(TaskFindUserNeighbors)
FindUserNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds())
}
+
+ t.lastNumUsers = numUsers
+ t.lastNumFeedback = numFeedback
+ return nil
}
-func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}) error {
+func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error {
var (
updateUserCount atomic.Float64
findNeighborSeconds atomic.Float64
@@ -587,7 +668,7 @@ func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUs
return errors.NotImplementedf("user neighbor type `%v`", m.Config.Recommend.UserNeighbors.NeighborType)
}
- err := parallel.Parallel(dataset.UserCount(), m.Config.Master.NumJobs, func(workerId, userIndex int) error {
+ err := parallel.DynamicParallel(dataset.UserCount(), j, func(workerId, userIndex int) error {
defer func() {
completed <- struct{}{}
}()
@@ -636,7 +717,7 @@ func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUs
return nil
}
-func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}) error {
+func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error {
var (
updateUserCount atomic.Float64
buildIndexSeconds atomic.Float64
@@ -665,7 +746,8 @@ func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemID
return errors.NotImplementedf("user neighbor type `%v`", m.Config.Recommend.UserNeighbors.NeighborType)
}
- builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize, search.SetIVFNumJobs(m.Config.Master.NumJobs))
+ builder := search.NewIVFBuilder(vectors, m.Config.Recommend.CacheSize,
+ search.SetIVFJobsAllocator(j))
var recall float32
index, recall = builder.Build(
m.Config.Recommend.UserNeighbors.IndexRecall,
@@ -678,7 +760,7 @@ func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemID
}
buildIndexSeconds.Add(time.Since(buildStart).Seconds())
- err := parallel.Parallel(dataset.UserCount(), m.Config.Master.NumJobs, func(workerId, userIndex int) error {
+ err := parallel.DynamicParallel(dataset.UserCount(), j, func(workerId, userIndex int) error {
defer func() {
completed <- struct{}{}
}()
@@ -848,319 +930,387 @@ func (m *Master) checkItemNeighborCacheTimeout(itemId string, categories []strin
return updateTime.Unix() <= modifiedTime.Unix()
}
-// fitRankingModel fits ranking model using passed dataset. After model fitted, following states are changed:
-// 1. Ranking model version are increased.
-// 2. Ranking model score are updated.
-// 3. Ranking model, version and score are persisted to local cache.
-func (m *Master) runRankingRelatedTasks(
- lastNumUsers, lastNumItems, lastNumFeedback int,
-) (numUsers, numItems, numFeedback int, err error) {
- log.Logger().Info("start fitting ranking model", zap.Int("n_jobs", m.Config.Master.NumJobs))
- m.rankingDataMutex.RLock()
- defer m.rankingDataMutex.RUnlock()
- numUsers = m.rankingTrainSet.UserCount()
- numItems = m.rankingTrainSet.ItemCount()
- numFeedback = m.rankingTrainSet.Count()
-
- if numUsers == 0 && numItems == 0 && numFeedback == 0 {
- log.Logger().Warn("empty ranking dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- return
- }
- numFeedbackChanged := numFeedback != lastNumFeedback
- numUsersChanged := numUsers != lastNumUsers
- numItemsChanged := numItems != lastNumItems
+type FitRankingModelTask struct {
+ *Master
+ lastNumFeedback int
+}
+
+func NewFitRankingModelTask(m *Master) *FitRankingModelTask {
+ return &FitRankingModelTask{Master: m}
+}
+
+func (t *FitRankingModelTask) name() string {
+ return TaskFitRankingModel
+}
+
+func (t *FitRankingModelTask) priority() int {
+ return -t.rankingTrainSet.Count()
+}
+
+func (t *FitRankingModelTask) run(j *task.JobsAllocator) error {
+ t.rankingDataMutex.RLock()
+ defer t.rankingDataMutex.RUnlock()
+ dataset := t.rankingTrainSet
+ numFeedback := dataset.Count()
var modelChanged bool
- bestRankingName, bestRankingModel, bestRankingScore := m.rankingModelSearcher.GetBestModel()
- m.rankingModelMutex.Lock()
+ bestRankingName, bestRankingModel, bestRankingScore := t.rankingModelSearcher.GetBestModel()
+ t.rankingModelMutex.Lock()
if bestRankingModel != nil && !bestRankingModel.Invalid() &&
- (bestRankingName != m.rankingModelName || bestRankingModel.GetParams().ToString() != m.RankingModel.GetParams().ToString()) &&
- (bestRankingScore.NDCG > m.rankingScore.NDCG) {
+ (bestRankingName != t.rankingModelName || bestRankingModel.GetParams().ToString() != t.RankingModel.GetParams().ToString()) &&
+ (bestRankingScore.NDCG > t.rankingScore.NDCG) {
// 1. best ranking model must have been found.
// 2. best ranking model must be different from current model
// 3. best ranking model must perform better than current model
- m.RankingModel = bestRankingModel
- m.rankingModelName = bestRankingName
- m.rankingScore = bestRankingScore
+ t.RankingModel = bestRankingModel
+ t.rankingModelName = bestRankingName
+ t.rankingScore = bestRankingScore
modelChanged = true
log.Logger().Info("find better ranking model",
zap.Any("score", bestRankingScore),
zap.String("name", bestRankingName),
- zap.Any("params", m.RankingModel.GetParams()))
+ zap.Any("params", t.RankingModel.GetParams()))
}
- rankingModel := m.RankingModel
- m.rankingModelMutex.Unlock()
+ rankingModel := ranking.Clone(t.RankingModel)
+ t.rankingModelMutex.Unlock()
- // collect neighbors of items
- if numItems == 0 {
- m.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.")
- } else if numItemsChanged || numFeedbackChanged {
- m.runFindItemNeighborsTask(m.rankingTrainSet)
- }
- // collect neighbors of users
- if numUsers == 0 {
- m.taskMonitor.Fail(TaskFindUserNeighbors, "No user found.")
- } else if numUsersChanged || numFeedbackChanged {
- m.runFindUserNeighborsTask(m.rankingTrainSet)
- }
-
- // training model
if numFeedback == 0 {
- m.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.")
- return
- } else if !numFeedbackChanged && !modelChanged {
+ t.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.")
+ return nil
+ } else if numFeedback == t.lastNumFeedback && !modelChanged {
log.Logger().Info("nothing changed")
- return
+ return nil
}
- m.runFitRankingModelTask(rankingModel)
- return
-}
-func (m *Master) runFitRankingModelTask(rankingModel ranking.MatrixFactorization) {
startFitTime := time.Now()
- score := rankingModel.Fit(m.rankingTrainSet, m.rankingTestSet, ranking.NewFitConfig().
- SetJobs(m.Config.Master.NumJobs).
- SetTask(m.taskMonitor.Start(TaskFitRankingModel, rankingModel.Complexity())))
+ score := rankingModel.Fit(t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig().
+ SetJobsAllocator(j).
+ SetTask(t.taskMonitor.Start(TaskFitRankingModel, rankingModel.Complexity())))
CollaborativeFilteringFitSeconds.Set(time.Since(startFitTime).Seconds())
// update ranking model
- m.rankingModelMutex.Lock()
- m.RankingModel = rankingModel
- m.RankingModelVersion++
- m.rankingScore = score
- m.rankingModelMutex.Unlock()
+ t.rankingModelMutex.Lock()
+ t.RankingModel = rankingModel
+ t.RankingModelVersion++
+ t.rankingScore = score
+ t.rankingModelMutex.Unlock()
log.Logger().Info("fit ranking model complete",
- zap.String("version", fmt.Sprintf("%x", m.RankingModelVersion)))
+ zap.String("version", fmt.Sprintf("%x", t.RankingModelVersion)))
CollaborativeFilteringNDCG10.Set(float64(score.NDCG))
CollaborativeFilteringRecall10.Set(float64(score.Recall))
CollaborativeFilteringPrecision10.Set(float64(score.Precision))
- MemoryInuseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes()))
- if err := m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime), time.Now())); err != nil {
+ MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(t.RankingModel.Bytes()))
+ if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime), time.Now())); err != nil {
log.Logger().Error("failed to write meta", zap.Error(err))
}
// caching model
- m.rankingModelMutex.RLock()
- m.localCache.RankingModelName = m.rankingModelName
- m.localCache.RankingModelVersion = m.RankingModelVersion
- m.localCache.RankingModel = rankingModel
- m.localCache.RankingModelScore = score
- m.rankingModelMutex.RUnlock()
- if m.localCache.ClickModel == nil || m.localCache.ClickModel.Invalid() {
+ t.rankingModelMutex.RLock()
+ t.localCache.RankingModelName = t.rankingModelName
+ t.localCache.RankingModelVersion = t.RankingModelVersion
+ t.localCache.RankingModel = rankingModel
+ t.localCache.RankingModelScore = score
+ t.rankingModelMutex.RUnlock()
+ if t.localCache.ClickModel == nil || t.localCache.ClickModel.Invalid() {
log.Logger().Info("wait click model")
- } else if err := m.localCache.WriteLocalCache(); err != nil {
+ } else if err := t.localCache.WriteLocalCache(); err != nil {
log.Logger().Error("failed to write local cache", zap.Error(err))
} else {
log.Logger().Info("write model to local cache",
- zap.String("ranking_model_name", m.localCache.RankingModelName),
- zap.String("ranking_model_version", encoding.Hex(m.localCache.RankingModelVersion)),
- zap.Float32("ranking_model_score", m.localCache.RankingModelScore.NDCG),
- zap.Any("ranking_model_params", m.localCache.RankingModel.GetParams()))
+ zap.String("ranking_model_name", t.localCache.RankingModelName),
+ zap.String("ranking_model_version", encoding.Hex(t.localCache.RankingModelVersion)),
+ zap.Float32("ranking_model_score", t.localCache.RankingModelScore.NDCG),
+ zap.Any("ranking_model_params", t.localCache.RankingModel.GetParams()))
}
- m.taskMonitor.Finish(TaskFitRankingModel)
+ t.taskMonitor.Finish(TaskFitRankingModel)
+ t.lastNumFeedback = numFeedback
+ return nil
}
-// runFitClickModelTask fits click model using latest data. After model fitted, following states are changed:
+// FitClickModelTask fits click model using latest data. After model fitted, following states are changed:
// 1. Click model version are increased.
// 2. Click model score are updated.
// 3. Click model, version and score are persisted to local cache.
-func (m *Master) runFitClickModelTask(
- lastNumUsers, lastNumItems, lastNumFeedback int,
-) (numUsers, numItems, numFeedback int, err error) {
- log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", m.Config.Master.NumJobs))
- m.clickDataMutex.RLock()
- defer m.clickDataMutex.RUnlock()
- numUsers = m.clickTrainSet.UserCount()
- numItems = m.clickTrainSet.ItemCount()
- numFeedback = m.clickTrainSet.Count()
+type FitClickModelTask struct {
+ *Master
+ lastNumUsers int
+ lastNumItems int
+ lastNumFeedback int
+}
+
+func NewFitClickModelTask(m *Master) *FitClickModelTask {
+ return &FitClickModelTask{Master: m}
+}
+
+func (t *FitClickModelTask) name() string {
+ return TaskFitClickModel
+}
+
+func (t *FitClickModelTask) priority() int {
+ return -t.clickTrainSet.Count()
+}
+
+func (t *FitClickModelTask) run(j *task.JobsAllocator) error {
+ log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", t.Config.Master.NumJobs))
+ t.clickDataMutex.RLock()
+ defer t.clickDataMutex.RUnlock()
+ if t.clickTrainSet == nil { return nil } // dataset not loaded yet; avoid nil dereference below
+ numUsers, numItems := t.clickTrainSet.UserCount(), t.clickTrainSet.ItemCount()
+ numFeedback := t.clickTrainSet.Count()
var shouldFit bool
- if numUsers == 0 || numItems == 0 || numFeedback == 0 {
+ if numUsers == 0 || numItems == 0 || numFeedback == 0 {
log.Logger().Warn("empty ranking dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- m.taskMonitor.Fail(TaskFitClickModel, "No feedback found.")
- return
- } else if numUsers != lastNumUsers ||
- numItems != lastNumItems ||
- numFeedback != lastNumFeedback {
+ zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes))
+ t.taskMonitor.Fail(TaskFitClickModel, "No feedback found.")
+ return nil
+ } else if numUsers != t.lastNumUsers ||
+ numItems != t.lastNumItems ||
+ numFeedback != t.lastNumFeedback {
shouldFit = true
}
- bestClickModel, bestClickScore := m.clickModelSearcher.GetBestModel()
- m.clickModelMutex.Lock()
+ bestClickModel, bestClickScore := t.clickModelSearcher.GetBestModel()
+ t.clickModelMutex.Lock()
if bestClickModel != nil && !bestClickModel.Invalid() &&
- bestClickModel.GetParams().ToString() != m.ClickModel.GetParams().ToString() &&
- bestClickScore.Precision > m.clickScore.Precision {
+ bestClickModel.GetParams().ToString() != t.ClickModel.GetParams().ToString() &&
+ bestClickScore.Precision > t.clickScore.Precision {
// 1. best click model must have been found.
// 2. best click model must be different from current model
// 3. best click model must perform better than current model
- m.ClickModel = bestClickModel
- m.clickScore = bestClickScore
+ t.ClickModel = bestClickModel
+ t.clickScore = bestClickScore
shouldFit = true
log.Logger().Info("find better click model",
zap.Float32("Precision", bestClickScore.Precision),
zap.Float32("Recall", bestClickScore.Recall),
- zap.Any("params", m.ClickModel.GetParams()))
+ zap.Any("params", t.ClickModel.GetParams()))
}
- clickModel := m.ClickModel
- m.clickModelMutex.Unlock()
+ clickModel := click.Clone(t.ClickModel)
+ t.clickModelMutex.Unlock()
// training model
if !shouldFit {
log.Logger().Info("nothing changed")
- return
+ return nil
}
startFitTime := time.Now()
- score := clickModel.Fit(m.clickTrainSet, m.clickTestSet, click.NewFitConfig().
- SetJobs(m.Config.Master.NumJobs).
- SetTask(m.taskMonitor.Start(TaskFitClickModel, clickModel.Complexity())))
+ score := clickModel.Fit(t.clickTrainSet, t.clickTestSet, click.NewFitConfig().
+ SetJobsAllocator(j).
+ SetTask(t.taskMonitor.Start(TaskFitClickModel, clickModel.Complexity())))
RankingFitSeconds.Set(time.Since(startFitTime).Seconds())
// update match model
- m.clickModelMutex.Lock()
- m.ClickModel = clickModel
- m.clickScore = score
- m.ClickModelVersion++
- m.clickModelMutex.Unlock()
+ t.clickModelMutex.Lock()
+ t.ClickModel = clickModel
+ t.clickScore = score
+ t.ClickModelVersion++
+ t.clickModelMutex.Unlock()
log.Logger().Info("fit click model complete",
- zap.String("version", fmt.Sprintf("%x", m.ClickModelVersion)))
+ zap.String("version", fmt.Sprintf("%x", t.ClickModelVersion)))
RankingPrecision.Set(float64(score.Precision))
RankingRecall.Set(float64(score.Recall))
RankingAUC.Set(float64(score.AUC))
- MemoryInuseBytesVec.WithLabelValues("ranking_model").Set(float64(m.ClickModel.Bytes()))
- if err = m.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime), time.Now())); err != nil {
+ MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(t.ClickModel.Bytes()))
+ if err := t.CacheClient.Set(cache.Time(cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime), time.Now())); err != nil {
log.Logger().Error("failed to write meta", zap.Error(err))
}
// caching model
- m.clickModelMutex.RLock()
- m.localCache.ClickModelScore = m.clickScore
- m.localCache.ClickModelVersion = m.ClickModelVersion
- m.localCache.ClickModel = m.ClickModel
- m.clickModelMutex.RUnlock()
- if m.localCache.RankingModel == nil || m.localCache.RankingModel.Invalid() {
+ t.clickModelMutex.RLock()
+ t.localCache.ClickModelScore = t.clickScore
+ t.localCache.ClickModelVersion = t.ClickModelVersion
+ t.localCache.ClickModel = t.ClickModel
+ t.clickModelMutex.RUnlock()
+ if t.localCache.RankingModel == nil || t.localCache.RankingModel.Invalid() {
log.Logger().Info("wait ranking model")
- } else if err = m.localCache.WriteLocalCache(); err != nil {
+ } else if err := t.localCache.WriteLocalCache(); err != nil {
log.Logger().Error("failed to write local cache", zap.Error(err))
} else {
log.Logger().Info("write model to local cache",
- zap.String("click_model_version", encoding.Hex(m.localCache.ClickModelVersion)),
+ zap.String("click_model_version", encoding.Hex(t.localCache.ClickModelVersion)),
zap.Float32("click_model_score", score.Precision),
- zap.Any("click_model_params", m.localCache.ClickModel.GetParams()))
+ zap.Any("click_model_params", t.localCache.ClickModel.GetParams()))
}
- m.taskMonitor.Finish(TaskFitClickModel)
- return
+ t.taskMonitor.Finish(TaskFitClickModel)
+ t.lastNumItems = numItems
+ t.lastNumUsers = numUsers
+ t.lastNumFeedback = numFeedback
+ return nil
}
-// runSearchRankingModelTask searches best hyper-parameters for ranking models.
+// SearchRankingModelTask searches best hyper-parameters for ranking models.
// It requires read lock on the ranking dataset.
-func (m *Master) runSearchRankingModelTask(
- lastNumUsers, lastNumItems, lastNumFeedback int,
-) (numUsers, numItems, numFeedback int, err error) {
+type SearchRankingModelTask struct {
+ *Master
+ lastNumUsers int
+ lastNumItems int
+ lastNumFeedback int
+}
+
+func NewSearchRankingModelTask(m *Master) *SearchRankingModelTask {
+ return &SearchRankingModelTask{Master: m}
+}
+
+func (t *SearchRankingModelTask) name() string {
+ return TaskSearchRankingModel
+}
+
+func (t *SearchRankingModelTask) priority() int {
+ return -t.rankingTrainSet.Count()
+}
+
+func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error {
log.Logger().Info("start searching ranking model")
- m.rankingDataMutex.RLock()
- defer m.rankingDataMutex.RUnlock()
- numUsers = m.rankingTrainSet.UserCount()
- numItems = m.rankingTrainSet.ItemCount()
- numFeedback = m.rankingTrainSet.Count()
+ t.rankingDataMutex.RLock()
+ defer t.rankingDataMutex.RUnlock()
+ if t.rankingTrainSet == nil {
+ log.Logger().Debug("dataset has not been loaded")
+ return nil
+ }
+ numUsers := t.rankingTrainSet.UserCount()
+ numItems := t.rankingTrainSet.ItemCount()
+ numFeedback := t.rankingTrainSet.Count()
if numUsers == 0 || numItems == 0 || numFeedback == 0 {
log.Logger().Warn("empty ranking dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- m.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.")
- return
- } else if numUsers == lastNumUsers &&
- numItems == lastNumItems &&
- numFeedback == lastNumFeedback {
+ zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes))
+ t.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.")
+ return nil
+ } else if numUsers == t.lastNumUsers &&
+ numItems == t.lastNumItems &&
+ numFeedback == t.lastNumFeedback {
log.Logger().Info("ranking dataset not changed")
- return
+ return nil
}
startTime := time.Now()
- err = m.rankingModelSearcher.Fit(m.rankingTrainSet, m.rankingTestSet,
- m.taskMonitor.Start(TaskSearchRankingModel, m.rankingModelSearcher.Complexity()),
- m.taskScheduler.NewRunner(TaskSearchRankingModel))
+ err := t.rankingModelSearcher.Fit(t.rankingTrainSet, t.rankingTestSet,
+ t.taskMonitor.Start(TaskSearchRankingModel, t.rankingModelSearcher.Complexity()), j)
if err != nil {
log.Logger().Error("failed to search collaborative filtering model", zap.Error(err))
- return
+ return nil
}
CollaborativeFilteringSearchSeconds.Set(time.Since(startTime).Seconds())
- _, _, bestScore := m.rankingModelSearcher.GetBestModel()
+ _, _, bestScore := t.rankingModelSearcher.GetBestModel()
CollaborativeFilteringSearchPrecision10.Set(float64(bestScore.Precision))
- m.taskMonitor.Finish(TaskSearchRankingModel)
- return
+ t.taskMonitor.Finish(TaskSearchRankingModel)
+ t.lastNumItems = numItems
+ t.lastNumUsers = numUsers
+ t.lastNumFeedback = numFeedback
+ return nil
}
-// runSearchClickModelTask searches best hyper-parameters for factorization machines.
+// SearchClickModelTask searches best hyper-parameters for factorization machines.
// It requires read lock on the click dataset.
-func (m *Master) runSearchClickModelTask(
- lastNumUsers, lastNumItems, lastNumFeedback int,
-) (numUsers, numItems, numFeedback int, err error) {
+type SearchClickModelTask struct {
+ *Master
+ lastNumUsers int
+ lastNumItems int
+ lastNumFeedback int
+}
+
+func NewSearchClickModelTask(m *Master) *SearchClickModelTask {
+ return &SearchClickModelTask{Master: m}
+}
+
+func (t *SearchClickModelTask) name() string {
+ return TaskSearchClickModel
+}
+
+func (t *SearchClickModelTask) priority() int {
+ return -t.clickTrainSet.Count()
+}
+
+func (t *SearchClickModelTask) run(j *task.JobsAllocator) error {
log.Logger().Info("start searching click model")
- m.clickDataMutex.RLock()
- defer m.clickDataMutex.RUnlock()
- numUsers = m.clickTrainSet.UserCount()
- numItems = m.clickTrainSet.ItemCount()
- numFeedback = m.clickTrainSet.Count()
+ t.clickDataMutex.RLock()
+ defer t.clickDataMutex.RUnlock()
+ if t.clickTrainSet == nil {
+ log.Logger().Debug("dataset has not been loaded")
+ return nil
+ }
+ numUsers := t.clickTrainSet.UserCount()
+ numItems := t.clickTrainSet.ItemCount()
+ numFeedback := t.clickTrainSet.Count()
if numUsers == 0 || numItems == 0 || numFeedback == 0 {
log.Logger().Warn("empty click dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- m.taskMonitor.Fail(TaskSearchClickModel, "No feedback found.")
- return
- } else if numUsers == lastNumUsers &&
- numItems == lastNumItems &&
- numFeedback == lastNumFeedback {
+ zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes))
+ t.taskMonitor.Fail(TaskSearchClickModel, "No feedback found.")
+ return nil
+ } else if numUsers == t.lastNumUsers &&
+ numItems == t.lastNumItems &&
+ numFeedback == t.lastNumFeedback {
log.Logger().Info("click dataset not changed")
- return
+ return nil
}
startTime := time.Now()
- err = m.clickModelSearcher.Fit(m.clickTrainSet, m.clickTestSet,
- m.taskMonitor.Start(TaskSearchClickModel, m.clickModelSearcher.Complexity()),
- m.taskScheduler.NewRunner(TaskSearchClickModel))
+ err := t.clickModelSearcher.Fit(t.clickTrainSet, t.clickTestSet,
+ t.taskMonitor.Start(TaskSearchClickModel, t.clickModelSearcher.Complexity()), j)
if err != nil {
log.Logger().Error("failed to search ranking model", zap.Error(err))
- return
+ return nil
}
RankingSearchSeconds.Set(time.Since(startTime).Seconds())
- _, bestScore := m.clickModelSearcher.GetBestModel()
+ _, bestScore := t.clickModelSearcher.GetBestModel()
RankingSearchPrecision.Set(float64(bestScore.Precision))
- m.taskMonitor.Finish(TaskSearchClickModel)
- return
+ t.taskMonitor.Finish(TaskSearchClickModel)
+ t.lastNumItems = numItems
+ t.lastNumUsers = numUsers
+ t.lastNumFeedback = numFeedback
+ return nil
+}
+
+type CacheGarbageCollectionTask struct {
+ *Master
+}
+
+func NewCacheGarbageCollectionTask(m *Master) *CacheGarbageCollectionTask {
+ return &CacheGarbageCollectionTask{m}
+}
+
+func (t *CacheGarbageCollectionTask) name() string {
+ return TaskCacheGarbageCollection
+}
+
+func (t *CacheGarbageCollectionTask) priority() int {
+ return -t.rankingTrainSet.UserCount() - t.rankingTrainSet.ItemCount()
}
-func (m *Master) runCacheGarbageCollectionTask() error {
- if m.rankingTrainSet == nil {
+func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error {
+ if t.rankingTrainSet == nil {
+ log.Logger().Debug("dataset has not been loaded")
return nil
}
+
log.Logger().Info("start cache garbage collection")
- m.taskMonitor.Start(TaskCacheGarbageCollection, m.rankingTrainSet.UserCount()*9+m.rankingTrainSet.ItemCount()*4)
+ t.taskMonitor.Start(TaskCacheGarbageCollection, t.rankingTrainSet.UserCount()*9+t.rankingTrainSet.ItemCount()*4)
var scanCount, reclaimCount int
start := time.Now()
- err := m.CacheClient.Scan(func(s string) error {
+ err := t.CacheClient.Scan(func(s string) error {
splits := strings.Split(s, "/")
if len(splits) <= 1 {
return nil
}
scanCount++
- m.taskMonitor.Update(TaskCacheGarbageCollection, scanCount)
+ t.taskMonitor.Update(TaskCacheGarbageCollection, scanCount)
switch splits[0] {
case cache.UserNeighbors, cache.UserNeighborsDigest, cache.IgnoreItems,
cache.OfflineRecommend, cache.OfflineRecommendDigest, cache.CollaborativeRecommend,
cache.LastModifyUserTime, cache.LastUpdateUserNeighborsTime, cache.LastUpdateUserRecommendTime:
userId := splits[1]
// check user in dataset
- if m.rankingTrainSet != nil && m.rankingTrainSet.UserIndex.ToNumber(userId) != base.NotId {
+ if t.rankingTrainSet != nil && t.rankingTrainSet.UserIndex.ToNumber(userId) != base.NotId {
return nil
}
// check user in database
- _, err := m.DataClient.GetUser(userId)
+ _, err := t.DataClient.GetUser(userId)
if !errors.Is(err, errors.NotFound) {
if err != nil {
log.Logger().Error("failed to load user", zap.String("user_id", userId), zap.Error(err))
@@ -1170,10 +1320,10 @@ func (m *Master) runCacheGarbageCollectionTask() error {
// delete user cache
switch splits[0] {
case cache.UserNeighbors, cache.IgnoreItems, cache.CollaborativeRecommend, cache.OfflineRecommend:
- err = m.CacheClient.SetSorted(s, nil)
+ err = t.CacheClient.SetSorted(s, nil)
case cache.UserNeighborsDigest, cache.OfflineRecommendDigest,
cache.LastModifyUserTime, cache.LastUpdateUserNeighborsTime, cache.LastUpdateUserRecommendTime:
- err = m.CacheClient.Delete(s)
+ err = t.CacheClient.Delete(s)
}
if err != nil {
return errors.Trace(err)
@@ -1182,11 +1332,11 @@ func (m *Master) runCacheGarbageCollectionTask() error {
case cache.ItemNeighbors, cache.ItemNeighborsDigest, cache.LastModifyItemTime, cache.LastUpdateItemNeighborsTime:
itemId := splits[1]
// check item in dataset
- if m.rankingTrainSet != nil && m.rankingTrainSet.ItemIndex.ToNumber(itemId) != base.NotId {
+ if t.rankingTrainSet != nil && t.rankingTrainSet.ItemIndex.ToNumber(itemId) != base.NotId {
return nil
}
// check item in database
- _, err := m.DataClient.GetItem(itemId)
+ _, err := t.DataClient.GetItem(itemId)
if !errors.Is(err, errors.NotFound) {
if err != nil {
log.Logger().Error("failed to load item", zap.String("item_id", itemId), zap.Error(err))
@@ -1196,9 +1346,9 @@ func (m *Master) runCacheGarbageCollectionTask() error {
// delete item cache
switch splits[0] {
case cache.ItemNeighbors:
- err = m.CacheClient.SetSorted(s, nil)
+ err = t.CacheClient.SetSorted(s, nil)
case cache.ItemNeighborsDigest, cache.LastModifyItemTime, cache.LastUpdateItemNeighborsTime:
- err = m.CacheClient.Delete(s)
+ err = t.CacheClient.Delete(s)
}
if err != nil {
return errors.Trace(err)
@@ -1208,10 +1358,10 @@ func (m *Master) runCacheGarbageCollectionTask() error {
return nil
})
// remove stale hidden items
- if err := m.CacheClient.RemSortedByScore(cache.HiddenItemsV2, math.Inf(-1), float64(time.Now().Add(-m.Config.Recommend.CacheExpire).Unix())); err != nil {
+ if err := t.CacheClient.RemSortedByScore(cache.HiddenItemsV2, math.Inf(-1), float64(time.Now().Add(-t.Config.Recommend.CacheExpire).Unix())); err != nil {
return errors.Trace(err)
}
- m.taskMonitor.Finish(TaskCacheGarbageCollection)
+ t.taskMonitor.Finish(TaskCacheGarbageCollection)
CacheScannedTotal.Set(float64(scanCount))
CacheReclaimedTotal.Set(float64(reclaimCount))
CacheScannedSeconds.Set(time.Since(start).Seconds())
diff --git a/master/tasks_test.go b/master/tasks_test.go
index 86a9cf095..5951500a6 100644
--- a/master/tasks_test.go
+++ b/master/tasks_test.go
@@ -15,15 +15,16 @@
package master
import (
+ "strconv"
+ "testing"
+ "time"
+
"github.com/juju/errors"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
- "strconv"
- "testing"
- "time"
)
func TestMaster_FindItemNeighborsBruteForce(t *testing.T) {
@@ -85,10 +86,12 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) {
// load mock dataset
dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
assert.NoError(t, err)
+ m.rankingTrainSet = dataset
// similar items (common users)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated
- m.runFindItemNeighborsTask(dataset)
+ neighborTask := NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "9"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar))
@@ -103,7 +106,8 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar
- m.runFindItemNeighborsTask(dataset)
+ neighborTask = NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -120,7 +124,8 @@ func TestMaster_FindItemNeighborsBruteForce(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "9"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto
- m.runFindItemNeighborsTask(dataset)
+ neighborTask = NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -193,10 +198,12 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) {
// load mock dataset
dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
assert.NoError(t, err)
+ m.rankingTrainSet = dataset
// similar items (common users)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated
- m.runFindItemNeighborsTask(dataset)
+ neighborTask := NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "9"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar))
@@ -211,7 +218,8 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "8"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar
- m.runFindItemNeighborsTask(dataset)
+ neighborTask = NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -228,7 +236,8 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyItemTime, "9"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto
- m.runFindItemNeighborsTask(dataset)
+ neighborTask = NewFindItemNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -282,10 +291,12 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) {
assert.NoError(t, err)
dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
assert.NoError(t, err)
+ m.rankingTrainSet = dataset
// similar items (common users)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated
- m.runFindUserNeighborsTask(dataset)
+ neighborTask := NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "9"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar))
@@ -296,7 +307,8 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar
- m.runFindUserNeighborsTask(dataset)
+ neighborTask = NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -309,7 +321,8 @@ func TestMaster_FindUserNeighborsBruteForce(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "9"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto
- m.runFindUserNeighborsTask(dataset)
+ neighborTask = NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -366,10 +379,12 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) {
assert.NoError(t, err)
dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator())
assert.NoError(t, err)
+ m.rankingTrainSet = dataset
// similar items (common users)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated
- m.runFindUserNeighborsTask(dataset)
+ neighborTask := NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "9"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"7", "5", "3"}, cache.RemoveScores(similar))
@@ -380,7 +395,8 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar
- m.runFindUserNeighborsTask(dataset)
+ neighborTask = NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -393,7 +409,8 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) {
err = m.CacheClient.Set(cache.Time(cache.Key(cache.LastModifyUserTime, "9"), time.Now()))
assert.NoError(t, err)
m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto
- m.runFindUserNeighborsTask(dataset)
+ neighborTask = NewFindUserNeighborsTask(&m.Master)
+ assert.NoError(t, neighborTask.run(nil))
similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "8"), 0, 100)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "2", "4"}, cache.RemoveScores(similar))
@@ -673,7 +690,8 @@ func TestRunCacheGarbageCollectionTask(t *testing.T) {
// remove cache
assert.NotNil(t, m.rankingTrainSet)
- err = m.runCacheGarbageCollectionTask()
+ gcTask := NewCacheGarbageCollectionTask(&m.Master)
+ err = gcTask.run(nil)
assert.NoError(t, err)
var s string
diff --git a/misc/csv_test/feedback.csv b/misc/csv_test/feedback.csv
deleted file mode 100644
index c0ee0fab3..000000000
--- a/misc/csv_test/feedback.csv
+++ /dev/null
@@ -1,6 +0,0 @@
-UserId,ItemId
-0,0,0
-1,2,3
-2,4,6
-3,6,9
-4,8,12
\ No newline at end of file
diff --git a/misc/csv_test/item_date.csv b/misc/csv_test/item_date.csv
deleted file mode 100644
index 0df23062a..000000000
--- a/misc/csv_test/item_date.csv
+++ /dev/null
@@ -1,5 +0,0 @@
-1,2020-1-1,a|b|c
-2,2020-2-1,e|f|g
-3,2020-3-1,a|b|c
-4,2020-4-1,e|f|g
-5,2020-5-1,a|b|c
\ No newline at end of file
diff --git a/misc/csv_test/items.csv b/misc/csv_test/items.csv
deleted file mode 100644
index ee45216a5..000000000
--- a/misc/csv_test/items.csv
+++ /dev/null
@@ -1,5 +0,0 @@
-1::Toy Story (1995)::Animation|Children's|Comedy
-2::Jumanji (1995)::Adventure|Children's|Fantasy
-3::Grumpier Old Men (1995)::Comedy|Romance
-4::Waiting to Exhale (1995)::Comedy|Drama
-5::Father of the Bride Part II (1995)::Comedy
\ No newline at end of file
diff --git a/misc/csv_test/items_header.csv b/misc/csv_test/items_header.csv
deleted file mode 100644
index 7817fa0c0..000000000
--- a/misc/csv_test/items_header.csv
+++ /dev/null
@@ -1,6 +0,0 @@
-ItemId::Title::Genres
-1::Toy Story (1995)::Animation|Children's|Comedy
-2::Jumanji (1995)::Adventure|Children's|Fantasy
-3::Grumpier Old Men (1995)::Comedy|Romance
-4::Waiting to Exhale (1995)::Comedy|Drama
-5::Father of the Bride Part II (1995)::Comedy
\ No newline at end of file
diff --git a/misc/csv_test/items_id_only.csv b/misc/csv_test/items_id_only.csv
deleted file mode 100644
index 8a1218a10..000000000
--- a/misc/csv_test/items_id_only.csv
+++ /dev/null
@@ -1,5 +0,0 @@
-1
-2
-3
-4
-5
diff --git a/model/click/model.go b/model/click/model.go
index dc93b7c3c..8e4a8943f 100644
--- a/model/click/model.go
+++ b/model/click/model.go
@@ -17,6 +17,10 @@ package click
import (
"encoding/binary"
"fmt"
+ "io"
+ "reflect"
+ "time"
+
"github.com/chewxy/math32"
"github.com/juju/errors"
"github.com/samber/lo"
@@ -30,9 +34,6 @@ import (
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
"go.uber.org/zap"
- "io"
- "reflect"
- "time"
)
type Score struct {
@@ -91,14 +92,13 @@ func (score Score) BetterThan(s Score) bool {
}
type FitConfig struct {
- Jobs int
+ *task.JobsAllocator
Verbose int
Task *task.Task
}
func NewFitConfig() *FitConfig {
return &FitConfig{
- Jobs: 1,
Verbose: 10,
}
}
@@ -108,8 +108,8 @@ func (config *FitConfig) SetVerbose(verbose int) *FitConfig {
return config
}
-func (config *FitConfig) SetJobs(nJobs int) *FitConfig {
- config.Jobs = nJobs
+func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitConfig {
+ config.JobsAllocator = allocator
return config
}
@@ -280,8 +280,9 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score {
zap.Any("params", fm.GetParams()),
zap.Any("config", config))
fm.Init(trainSet)
- temp := base.NewMatrix32(config.Jobs, fm.nFactors)
- vGrad := base.NewMatrix32(config.Jobs, fm.nFactors)
+ maxJobs := config.MaxJobs()
+ temp := base.NewMatrix32(maxJobs, fm.nFactors)
+ vGrad := base.NewMatrix32(maxJobs, fm.nFactors)
snapshots := SnapshotManger{}
evalStart := time.Now()
@@ -306,7 +307,7 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score {
}
fitStart := time.Now()
cost := float32(0)
- _ = parallel.BatchParallel(trainSet.Count(), config.Jobs, 128, func(workerId, beginJobId, endJobId int) error {
+ _ = parallel.BatchParallel(trainSet.Count(), config.AvailableJobs(config.Task), 128, func(workerId, beginJobId, endJobId int) error {
for i := beginJobId; i < endJobId; i++ {
features, values, target := trainSet.Get(i)
prediction := fm.internalPredictImpl(features, values)
diff --git a/model/click/model_test.go b/model/click/model_test.go
index e84599bda..c1d965ec0 100644
--- a/model/click/model_test.go
+++ b/model/click/model_test.go
@@ -30,7 +30,7 @@ func newFitConfigWithTestTracker(numEpoch int) *FitConfig {
t := task.NewTask("test", numEpoch)
cfg := NewFitConfig().
SetVerbose(1).
- SetJobs(1).
+ SetJobsAllocator(task.NewConstantJobsAllocator(1)).
SetTask(t)
return cfg
}
diff --git a/model/click/search.go b/model/click/search.go
index 8de6a379a..da4a673bc 100644
--- a/model/click/search.go
+++ b/model/click/search.go
@@ -16,13 +16,14 @@ package click
import (
"fmt"
+ "sync"
+ "time"
+
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
"go.uber.org/zap"
- "sync"
- "time"
)
// ParamsSearchResult contains the return of grid search.
@@ -37,7 +38,7 @@ type ParamsSearchResult struct {
// GridSearchCV finds the best parameters for a model.
func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid,
- _ int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult {
+ _ int64, fitConfig *FitConfig) ParamsSearchResult {
// Retrieve parameter names and length
paramNames := make([]model.ParamName, 0, len(paramGrid))
count := 1
@@ -60,11 +61,7 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da
// Cross validate
estimator.Clear()
estimator.SetParams(estimator.GetParams().Overwrite(params))
- fitConfig.Task.Suspend(true)
- runner.Lock()
- fitConfig.Task.Suspend(false)
score := estimator.Fit(trainSet, testSet, fitConfig)
- runner.UnLock()
// Create GridSearch result
results.Scores = append(results.Scores, score)
results.Params = append(results.Params, params.Copy())
@@ -90,10 +87,10 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da
// RandomSearchCV searches hyper-parameters by random.
func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid,
- numTrials int, seed int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult {
+ numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult {
// if the number of combination is less than number of trials, use grid search
if paramGrid.NumCombinations() <= numTrials {
- return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig, runner)
+ return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig)
}
rng := base.NewRandomGenerator(seed)
results := ParamsSearchResult{
@@ -112,11 +109,7 @@ func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *
zap.Any("params", params))
estimator.Clear()
estimator.SetParams(estimator.GetParams().Overwrite(params))
- fitConfig.Task.Suspend(true)
- runner.Lock()
- fitConfig.Task.Suspend(false)
score := estimator.Fit(trainSet, testSet, fitConfig)
- runner.UnLock()
results.Scores = append(results.Scores, score)
results.Params = append(results.Params, params.Copy())
if len(results.Scores) == 0 || score.BetterThan(results.BestScore) {
@@ -135,7 +128,6 @@ type ModelSearcher struct {
// arguments
numEpochs int
numTrials int
- numJobs int
searchSize bool
// results
bestMutex sync.Mutex
@@ -144,12 +136,11 @@ type ModelSearcher struct {
}
// NewModelSearcher creates a thread-safe personal ranking model searcher.
-func NewModelSearcher(nEpoch, nTrials, nJobs int, searchSize bool) *ModelSearcher {
+func NewModelSearcher(nEpoch, nTrials int, searchSize bool) *ModelSearcher {
return &ModelSearcher{
model: NewFM(FMClassification, model.Params{model.NEpochs: nEpoch}),
numTrials: nTrials,
numEpochs: nEpoch,
- numJobs: nJobs,
searchSize: searchSize,
}
}
@@ -165,7 +156,7 @@ func (searcher *ModelSearcher) Complexity() int {
return searcher.numTrials * searcher.numEpochs
}
-func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, runner model.Runner) error {
+func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, j *task.JobsAllocator) error {
log.Logger().Info("click model search",
zap.Int("n_users", trainSet.UserCount()),
zap.Int("n_items", trainSet.ItemCount()),
@@ -176,8 +167,8 @@ func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, runn
// Random search
grid := searcher.model.GetParamsGrid(searcher.searchSize)
r := RandomSearchCV(searcher.model, trainSet, valSet, grid, searcher.numTrials, 0, NewFitConfig().
- SetJobs(searcher.numJobs).
- SetTask(t), runner)
+ SetJobsAllocator(j).
+ SetTask(t))
searcher.bestMutex.Lock()
defer searcher.bestMutex.Unlock()
searcher.bestModel = r.BestModel
diff --git a/model/click/search_test.go b/model/click/search_test.go
index 00ac61922..3092d825e 100644
--- a/model/click/search_test.go
+++ b/model/click/search_test.go
@@ -15,7 +15,6 @@ package click
import (
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
@@ -87,22 +86,9 @@ func (m *mockFactorizationMachineForSearch) GetParamsGrid(_ bool) model.ParamsGr
}
}
-type mockRunner struct {
- mock.Mock
-}
-
-func (r *mockRunner) Lock() {
- r.Called()
-}
-
-func (r *mockRunner) UnLock() {
- r.Called()
-}
-
func newFitConfigForSearch() *FitConfig {
t := task.NewTask("test", 0)
return &FitConfig{
- Jobs: 1,
Verbose: 1,
Task: t,
}
@@ -111,13 +97,8 @@ func newFitConfigForSearch() *FitConfig {
func TestGridSearchCV(t *testing.T) {
m := &mockFactorizationMachineForSearch{}
fitConfig := newFitConfigForSearch()
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig, runner)
+ r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig)
assert.Equal(t, float32(12), r.BestScore.AUC)
- runner.AssertCalled(t, "Lock")
- runner.AssertCalled(t, "UnLock")
assert.Equal(t, model.Params{
model.NFactors: 4,
model.InitMean: 4,
@@ -128,12 +109,7 @@ func TestGridSearchCV(t *testing.T) {
func TestRandomSearchCV(t *testing.T) {
m := &mockFactorizationMachineForSearch{}
fitConfig := newFitConfigForSearch()
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig, runner)
- runner.AssertCalled(t, "Lock")
- runner.AssertCalled(t, "UnLock")
+ r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig)
assert.Equal(t, float32(12), r.BestScore.AUC)
assert.Equal(t, model.Params{
model.NFactors: 4,
@@ -143,13 +119,10 @@ func TestRandomSearchCV(t *testing.T) {
}
func TestModelSearcher_RandomSearch(t *testing.T) {
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- searcher := NewModelSearcher(2, 63, 1, false)
+ searcher := NewModelSearcher(2, 63, false)
searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}}
tk := task.NewTask("test", searcher.Complexity())
- err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner)
+ err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1))
assert.NoError(t, err)
m, score := searcher.GetBestModel()
assert.Equal(t, float32(12), score.AUC)
@@ -163,13 +136,10 @@ func TestModelSearcher_RandomSearch(t *testing.T) {
}
func TestModelSearcher_GridSearch(t *testing.T) {
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- searcher := NewModelSearcher(2, 64, 1, false)
+ searcher := NewModelSearcher(2, 64, false)
searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}}
tk := task.NewTask("test", searcher.Complexity())
- err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner)
+ err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1))
assert.NoError(t, err)
m, score := searcher.GetBestModel()
assert.Equal(t, float32(12), score.AUC)
diff --git a/model/model.go b/model/model.go
index 3bb1b0371..a1f731a4d 100644
--- a/model/model.go
+++ b/model/model.go
@@ -51,8 +51,3 @@ func (model *BaseModel) GetParams() Params {
func (model *BaseModel) GetRandomGenerator() base.RandomGenerator {
return model.rng
}
-
-type Runner interface {
- Lock()
- UnLock()
-}
diff --git a/model/ranking/data.go b/model/ranking/data.go
index 0f3f140a1..fe0097765 100644
--- a/model/ranking/data.go
+++ b/model/ranking/data.go
@@ -267,49 +267,6 @@ func (dataset *DataSet) GetIndex(i int) (int32, int32) {
return dataset.FeedbackUsers.Get(i), dataset.FeedbackItems.Get(i)
}
-// LoadDataFromCSV loads Data from a CSV file. The CSV file should be:
-// [optional header]
-//
-//
-//
-// ...
-// For example, the `u.Data` from MovieLens 100K is:
-// 196\t242\t3\t881250949
-// 186\t302\t3\t891717742
-// 22\t377\t1\t878887116
-func LoadDataFromCSV(fileName, sep string, hasHeader bool) *DataSet {
- dataset := NewMapIndexDataset()
- // Open file
- file, err := os.Open(fileName)
- if err != nil {
- log.Logger().Fatal("failed to open csv file", zap.Error(err),
- zap.String("csv_file", fileName))
- }
- defer func(file *os.File) {
- err = file.Close()
- if err != nil {
- log.Logger().Error("failed to close file", zap.Error(err))
- }
- }(file)
- // Read CSV file
- scanner := bufio.NewScanner(file)
- for scanner.Scan() {
- line := scanner.Text()
- // Ignore header
- if hasHeader {
- hasHeader = false
- continue
- }
- fields := strings.Split(line, sep)
- // Ignore empty line
- if len(fields) < 2 {
- continue
- }
- dataset.AddFeedback(fields[0], fields[1], true)
- }
- return dataset
-}
-
func loadTest(dataset *DataSet, path string) error {
// Open
file, err := os.Open(path)
diff --git a/model/ranking/data_test.go b/model/ranking/data_test.go
index 9ea8279e2..bd9220ff2 100644
--- a/model/ranking/data_test.go
+++ b/model/ranking/data_test.go
@@ -36,16 +36,6 @@ func TestNewMapIndexDataset(t *testing.T) {
assert.Equal(t, 6, dataSet.ItemCount())
}
-func TestLoadDataFromCSV(t *testing.T) {
- dataset := LoadDataFromCSV("../../misc/csv_test/feedback.csv", ",", true)
- assert.Equal(t, 5, dataset.Count())
- for i := 0; i < dataset.Count(); i++ {
- userIndex, itemIndex := dataset.GetIndex(i)
- assert.Equal(t, int32(i), userIndex)
- assert.Equal(t, int32(i), itemIndex)
- }
-}
-
func TestDataSet_Split(t *testing.T) {
numUsers, numItems := 3, 5
// create dataset
diff --git a/model/ranking/model.go b/model/ranking/model.go
index 7e9207399..6d2344ec7 100644
--- a/model/ranking/model.go
+++ b/model/ranking/model.go
@@ -16,6 +16,10 @@ package ranking
import (
"fmt"
+ "io"
+ "reflect"
+ "time"
+
"github.com/bits-and-blooms/bitset"
"github.com/chewxy/math32"
"github.com/juju/errors"
@@ -30,9 +34,6 @@ import (
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
"go.uber.org/zap"
- "io"
- "reflect"
- "time"
)
type Score struct {
@@ -42,7 +43,7 @@ type Score struct {
}
type FitConfig struct {
- Jobs int
+ *task.JobsAllocator
Verbose int
Candidates int
TopK int
@@ -51,7 +52,6 @@ type FitConfig struct {
func NewFitConfig() *FitConfig {
return &FitConfig{
- Jobs: 1,
Verbose: 10,
Candidates: 100,
TopK: 10,
@@ -63,8 +63,8 @@ func (config *FitConfig) SetVerbose(verbose int) *FitConfig {
return config
}
-func (config *FitConfig) SetJobs(nJobs int) *FitConfig {
- config.Jobs = nJobs
+func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitConfig {
+ config.JobsAllocator = allocator
return config
}
@@ -302,11 +302,12 @@ func UnmarshalModel(r io.Reader) (MatrixFactorization, error) {
// model with implicit feedback. The pairwise ranking between item i and j for user u is estimated
// by:
//
-// p(i >_u j) = \sigma( p_u^T (q_i - q_j) )
+// p(i >_u j) = \sigma( p_u^T (q_i - q_j) )
//
// Hyper-parameters:
+//
// Reg - The regularization parameter of the cost function that is
-// optimized. Default is 0.01.
+// optimized. Default is 0.01.
// Lr - The learning rate of SGD. Default is 0.05.
// nFactors - The number of latent factors. Default is 10.
// NEpochs - The number of iteration of the SGD procedure. Default is 100.
@@ -401,12 +402,13 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
zap.Any("config", config))
bpr.Init(trainSet)
// Create buffers
- temp := base.NewMatrix32(config.Jobs, bpr.nFactors)
- userFactor := base.NewMatrix32(config.Jobs, bpr.nFactors)
- positiveItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors)
- negativeItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors)
- rng := make([]base.RandomGenerator, config.Jobs)
- for i := 0; i < config.Jobs; i++ {
+ maxJobs := config.MaxJobs()
+ temp := base.NewMatrix32(maxJobs, bpr.nFactors)
+ userFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
+ positiveItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
+ negativeItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
+ rng := make([]base.RandomGenerator, maxJobs)
+ for i := 0; i < maxJobs; i++ {
rng[i] = base.NewRandomGenerator(bpr.GetRandomGenerator().Int63())
}
// Convert array to hashmap
@@ -419,7 +421,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
}
snapshots := SnapshotManger{}
evalStart := time.Now()
- scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall)
+ scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall)
evalTime := time.Since(evalStart)
log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", 0, bpr.nEpochs),
zap.String("eval_time", evalTime.String()),
@@ -431,8 +433,9 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
for epoch := 1; epoch <= bpr.nEpochs; epoch++ {
fitStart := time.Now()
// Training epoch
- cost := make([]float32, config.Jobs)
- _ = parallel.Parallel(trainSet.Count(), config.Jobs, func(workerId, _ int) error {
+ numJobs := config.AvailableJobs(config.Task)
+ cost := make([]float32, numJobs)
+ _ = parallel.Parallel(trainSet.Count(), numJobs, func(workerId, _ int) error {
// Select a user
var userIndex int32
var ratingCount int
@@ -479,7 +482,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
// Cross validation
if epoch%config.Verbose == 0 || epoch == bpr.nEpochs {
evalStart = time.Now()
- scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall)
+ scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall)
evalTime = time.Since(evalStart)
log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", epoch, bpr.nEpochs),
zap.String("fit_time", fitTime.String()),
@@ -721,12 +724,13 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
zap.Any("config", config))
ccd.Init(trainSet)
// Create temporary matrix
+ maxJobs := config.MaxJobs()
s := base.NewMatrix32(ccd.nFactors, ccd.nFactors)
- userPredictions := make([][]float32, config.Jobs)
- itemPredictions := make([][]float32, config.Jobs)
- userRes := make([][]float32, config.Jobs)
- itemRes := make([][]float32, config.Jobs)
- for i := 0; i < config.Jobs; i++ {
+ userPredictions := make([][]float32, maxJobs)
+ itemPredictions := make([][]float32, maxJobs)
+ userRes := make([][]float32, maxJobs)
+ itemRes := make([][]float32, maxJobs)
+ for i := 0; i < maxJobs; i++ {
userPredictions[i] = make([]float32, trainSet.ItemCount())
itemPredictions[i] = make([]float32, trainSet.UserCount())
userRes[i] = make([]float32, trainSet.ItemCount())
@@ -735,7 +739,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
// evaluate initial model
snapshots := SnapshotManger{}
evalStart := time.Now()
- scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall)
+ scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall)
evalTime := time.Since(evalStart)
log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", 0, ccd.nEpochs),
zap.String("eval_time", evalTime.String()),
@@ -757,7 +761,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
}
}
}
- _ = parallel.Parallel(trainSet.UserCount(), config.Jobs, func(workerId, userIndex int) error {
+ _ = parallel.Parallel(trainSet.UserCount(), config.AvailableJobs(config.Task), func(workerId, userIndex int) error {
userFeedback := trainSet.UserFeedback[userIndex]
for _, i := range userFeedback {
userPredictions[workerId][i] = ccd.InternalPredict(int32(userIndex), i)
@@ -798,7 +802,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
}
}
}
- _ = parallel.Parallel(trainSet.ItemCount(), config.Jobs, func(workerId, itemIndex int) error {
+ _ = parallel.Parallel(trainSet.ItemCount(), config.AvailableJobs(config.Task), func(workerId, itemIndex int) error {
itemFeedback := trainSet.ItemFeedback[itemIndex]
for _, u := range itemFeedback {
itemPredictions[workerId][u] = ccd.InternalPredict(u, int32(itemIndex))
@@ -831,7 +835,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score {
// Cross validation
if ep%config.Verbose == 0 || ep == ccd.nEpochs {
evalStart = time.Now()
- scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.Jobs, NDCG, Precision, Recall)
+ scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall)
evalTime = time.Since(evalStart)
log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", ep, ccd.nEpochs),
zap.String("fit_time", fitTime.String()),
diff --git a/model/ranking/model_test.go b/model/ranking/model_test.go
index 1fa686c8f..0508429d8 100644
--- a/model/ranking/model_test.go
+++ b/model/ranking/model_test.go
@@ -31,7 +31,7 @@ const (
func newFitConfig(numEpoch int) *FitConfig {
t := task.NewTask("test", numEpoch)
- cfg := NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU()).SetTask(t)
+ cfg := NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())).SetTask(t)
return cfg
}
diff --git a/model/ranking/search.go b/model/ranking/search.go
index fdc2cc903..e1f65c2c7 100644
--- a/model/ranking/search.go
+++ b/model/ranking/search.go
@@ -16,13 +16,14 @@ package ranking
import (
"fmt"
+ "sync"
+ "time"
+
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
"go.uber.org/zap"
- "sync"
- "time"
)
// ParamsSearchResult contains the return of grid search.
@@ -47,7 +48,7 @@ func (r *ParamsSearchResult) AddScore(params model.Params, score Score) {
// GridSearchCV finds the best parameters for a model.
func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid,
- _ int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult {
+ _ int64, fitConfig *FitConfig) ParamsSearchResult {
// Retrieve parameter names and length
paramNames := make([]model.ParamName, 0, len(paramGrid))
count := 1
@@ -70,11 +71,7 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat
// Cross validate
estimator.Clear()
estimator.SetParams(estimator.GetParams().Overwrite(params))
- fitConfig.Task.Suspend(true)
- runner.Lock()
- fitConfig.Task.Suspend(false)
score := estimator.Fit(trainSet, testSet, fitConfig)
- runner.UnLock()
// Create GridSearch result
results.Scores = append(results.Scores, score)
results.Params = append(results.Params, params.Copy())
@@ -100,10 +97,10 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat
// RandomSearchCV searches hyper-parameters by random.
func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid,
- numTrials int, seed int64, fitConfig *FitConfig, runner model.Runner) ParamsSearchResult {
+ numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult {
// if the number of combination is less than number of trials, use grid search
if paramGrid.NumCombinations() < numTrials {
- return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig, runner)
+ return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig)
}
rng := base.NewRandomGenerator(seed)
results := ParamsSearchResult{
@@ -122,11 +119,7 @@ func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *D
zap.Any("params", params))
estimator.Clear()
estimator.SetParams(estimator.GetParams().Overwrite(params))
- fitConfig.Task.Suspend(true)
- runner.Lock()
- fitConfig.Task.Suspend(false)
score := estimator.Fit(trainSet, testSet, fitConfig)
- runner.UnLock()
results.Scores = append(results.Scores, score)
results.Params = append(results.Params, params.Copy())
if len(results.Scores) == 0 || score.NDCG > results.BestScore.NDCG {
@@ -145,7 +138,6 @@ type ModelSearcher struct {
// arguments
numEpochs int
numTrials int
- numJobs int
searchSize bool
// results
bestMutex sync.Mutex
@@ -155,11 +147,10 @@ type ModelSearcher struct {
}
// NewModelSearcher creates a thread-safe personal ranking model searcher.
-func NewModelSearcher(nEpoch, nTrials, nJobs int, searchSize bool) *ModelSearcher {
+func NewModelSearcher(nEpoch, nTrials int, searchSize bool) *ModelSearcher {
searcher := &ModelSearcher{
numTrials: nTrials,
numEpochs: nEpoch,
- numJobs: nJobs,
searchSize: searchSize,
}
searcher.models = append(searcher.models, NewBPR(model.Params{model.NEpochs: searcher.numEpochs}))
@@ -178,7 +169,7 @@ func (searcher *ModelSearcher) Complexity() int {
return len(searcher.models) * searcher.numEpochs * searcher.numTrials
}
-func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, runner model.Runner) error {
+func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, j *task.JobsAllocator) error {
log.Logger().Info("ranking model search",
zap.Int("n_users", trainSet.UserCount()),
zap.Int("n_items", trainSet.ItemCount()))
@@ -186,8 +177,8 @@ func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, runn
for _, m := range searcher.models {
r := RandomSearchCV(m, trainSet, valSet, m.GetParamsGrid(searcher.searchSize), searcher.numTrials, 0,
NewFitConfig().
- SetJobs(searcher.numJobs).
- SetTask(t), runner)
+ SetJobsAllocator(j).
+ SetTask(t))
searcher.bestMutex.Lock()
if searcher.bestModel == nil || r.BestScore.NDCG > searcher.bestScore.NDCG {
searcher.bestModel = r.BestModel
diff --git a/model/ranking/search_test.go b/model/ranking/search_test.go
index f91b500ce..71141b4e8 100644
--- a/model/ranking/search_test.go
+++ b/model/ranking/search_test.go
@@ -15,7 +15,6 @@ package ranking
import (
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/model"
@@ -104,22 +103,9 @@ func (m *mockMatrixFactorizationForSearch) GetParamsGrid(_ bool) model.ParamsGri
}
}
-type mockRunner struct {
- mock.Mock
-}
-
-func (r *mockRunner) Lock() {
- r.Called()
-}
-
-func (r *mockRunner) UnLock() {
- r.Called()
-}
-
func newFitConfigForSearch() *FitConfig {
t := task.NewTask("test", 100)
return &FitConfig{
- Jobs: 1,
Verbose: 1,
Task: t,
}
@@ -128,12 +114,7 @@ func newFitConfigForSearch() *FitConfig {
func TestGridSearchCV(t *testing.T) {
m := &mockMatrixFactorizationForSearch{}
fitConfig := newFitConfigForSearch()
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig, runner)
- runner.AssertCalled(t, "Lock")
- runner.AssertCalled(t, "UnLock")
+ r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig)
assert.Equal(t, float32(12), r.BestScore.NDCG)
assert.Equal(t, model.Params{
model.NFactors: 4,
@@ -145,12 +126,7 @@ func TestGridSearchCV(t *testing.T) {
func TestRandomSearchCV(t *testing.T) {
m := &mockMatrixFactorizationForSearch{}
fitConfig := newFitConfigForSearch()
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig, runner)
- runner.AssertCalled(t, "Lock")
- runner.AssertCalled(t, "UnLock")
+ r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig)
assert.Equal(t, float32(12), r.BestScore.NDCG)
assert.Equal(t, model.Params{
model.NFactors: 4,
@@ -160,13 +136,10 @@ func TestRandomSearchCV(t *testing.T) {
}
func TestModelSearcher(t *testing.T) {
- runner := new(mockRunner)
- runner.On("Lock")
- runner.On("UnLock")
- searcher := NewModelSearcher(2, 63, 1, false)
+ searcher := NewModelSearcher(2, 63, false)
searcher.models = []MatrixFactorization{newMockMatrixFactorizationForSearch(2)}
tk := task.NewTask("test", searcher.Complexity())
- err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, runner)
+ err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1))
assert.NoError(t, err)
_, m, score := searcher.GetBestModel()
assert.Equal(t, float32(12), score.NDCG)
diff --git a/misc/database_test/docker-compose.yml b/storage/docker-compose.yml
similarity index 100%
rename from misc/database_test/docker-compose.yml
rename to storage/docker-compose.yml