Skip to content

Commit

Permalink
support dynamic jobs allocation (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Aug 15, 2022
1 parent d2c61aa commit 6f265b4
Show file tree
Hide file tree
Showing 43 changed files with 1,055 additions and 979 deletions.
7 changes: 7 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.github
assets

LICENSE

*.yml
*.md
2 changes: 1 addition & 1 deletion .github/workflows/build_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
with:
context: .
platforms: linux/amd64,linux/arm64
file: docker/${{ matrix.image }}/Dockerfile
file: cmd/${{ matrix.image }}/Dockerfile
push: true
tags: |
zhenghaoz/${{ matrix.image }}:latest
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ jobs:
- uses: actions/checkout@v1

- name: Build the stack
working-directory: ./docker
run: docker-compose up -d

- name: Check the deployed service URL
Expand Down
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ These following installations are required:
- **Docker Compose**: Multiple databases are required for unit tests. It's convenient to manage databases on Docker Compose.

```bash
cd misc/database_test
cd storage

docker-compose up -d
```
Expand Down Expand Up @@ -73,7 +73,7 @@ Most logics in Gorse are covered by unit tests. Run unit tests by the following
go test -v ./...
```

The default database URLs point to the databases defined in `misc/database_test/docker-compose.yml`. Test databases can be overridden by setting the following environment variables:
The default database URLs point to the databases defined in `storage/docker-compose.yml`. Test databases can be overridden by setting the following environment variables:

| Environment Value | Default Value |
|-------------------|----------------------------------------------|
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

<img width=160 src="assets/gorse.png"/>

![](https://img.shields.io/github/go-mod/go-version/zhenghaoz/gorse)
[![build](https://github.com/zhenghaoz/gorse/workflows/build/badge.svg)](https://github.com/zhenghaoz/gorse/actions?query=workflow%3Abuild)
[![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse)
[![Go Report Card](https://goreportcard.com/badge/github.com/zhenghaoz/gorse)](https://goreportcard.com/report/github.com/zhenghaoz/gorse)
Expand Down
14 changes: 12 additions & 2 deletions base/copier/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ func copyValue(dst, src reflect.Value) error {
dstPointer := reflect.New(dst.Type())
srcPointer := reflect.New(src.Type())
srcPointer.Elem().Set(src)
srcMarshaller, hasSrcMarshaler := srcPointer.Interface().(encoding.BinaryMarshaler)
srcMarshaller, hasSrcMarshaller := srcPointer.Interface().(encoding.BinaryMarshaler)
dstUnmarshaler, hasDstUnmarshaler := dstPointer.Interface().(encoding.BinaryUnmarshaler)

if hasDstUnmarshaler && hasSrcMarshaler {
if hasDstUnmarshaler && hasSrcMarshaller {
dstByte, err := srcMarshaller.MarshalBinary()
if err != nil {
return err
Expand All @@ -114,6 +114,11 @@ func copyValue(dst, src reflect.Value) error {
}
}
case reflect.Ptr:
if src.IsNil() {
// If source is nil, set dst to nil.
dst.Set(reflect.Zero(dst.Type()))
return nil
}
if dst.IsNil() {
dst.Set(reflect.New(src.Elem().Type()))
}
Expand All @@ -124,6 +129,11 @@ func copyValue(dst, src reflect.Value) error {
return err
}
case reflect.Interface:
if src.IsNil() {
// If source is nil, set dst to nil.
dst.Set(reflect.Zero(dst.Type()))
return nil
}
if !dst.IsNil() {
switch dst.Elem().Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint,
Expand Down
19 changes: 19 additions & 0 deletions base/copier/copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,22 @@ func TestPrivate(t *testing.T) {
*a.text = "world"
assert.Equal(t, "hello", *b.text)
}

// NilInterface is an empty interface type; it is used as the type of the
// interface-valued field in NilStruct.
type NilInterface interface{}

// NilStruct is a test fixture with one interface-typed field and one pointer
// field, so both kinds of nil-able members can be exercised by TestNil.
type NilStruct struct {
	Interface NilInterface
	Pointer   *Foo
}

// TestNil copies a zero-valued NilStruct over a populated one and checks that
// the destination's interface and pointer fields both end up nil.
func TestNil(t *testing.T) {
	dst := NilStruct{Interface: 100, Pointer: &Foo{A: 100}}
	src := NilStruct{}
	assert.NoError(t, Copy(&dst, src))
	assert.Nil(t, dst.Interface)
	assert.Nil(t, dst.Pointer)
}
83 changes: 64 additions & 19 deletions base/parallel/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@
package parallel

import (
"sync"

"github.com/juju/errors"
"github.com/zhenghaoz/gorse/base/task"
"go.uber.org/atomic"
"modernc.org/mathutil"
"sync"
)

const (
chanSize = 1024
allocPeriod = 128
)

/* Parallel Schedulers */
Expand All @@ -33,19 +41,13 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
}
}
} else {
const chanSize = 64
const chanEOF = -1
c := make(chan int, chanSize)
// producer
go func() {
// send jobs
for i := 0; i < nJobs; i++ {
c <- i
}
// send EOF
for i := 0; i < nWorkers; i++ {
c <- chanEOF
}
close(c)
}()
// consumer
var wg sync.WaitGroup
Expand All @@ -57,8 +59,8 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
defer wg.Done()
for {
// read job
jobId := <-c
if jobId == chanEOF {
jobId, ok := <-c
if !ok {
return
}
// run job
Expand All @@ -80,6 +82,55 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error
return nil
}

// DynamicParallel runs jobs 0..nJobs-1 with a worker pool whose size is
// re-read from jobsAlloc between rounds.
//
// Job IDs are fed through a buffered channel by a producer goroutine. In each
// round, jobsAlloc.AvailableJobs(nil) workers are started and every worker
// handles at most allocPeriod jobs before the round ends, so the next round
// can pick up a different worker count. worker is invoked with the worker's
// index within the current round and the job ID.
//
// The function returns nil after a round in which no worker received a job.
// If a worker returns an error, that worker stops, the current round is
// allowed to finish, and one of the round's errors is returned wrapped by
// errors.Trace; jobs not yet consumed are not run.
//
// NOTE(review): if AvailableJobs ever returns 0, the round starts no workers
// and the function returns nil without running the remaining jobs — confirm
// that the allocator never reports 0.
func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(workerId, jobId int) error) error {
	c := make(chan int, chanSize)
	// done is closed when this function returns, so the producer is never
	// left blocked on a send after the consumers have stopped reading
	// (e.g. on an early error return).
	done := make(chan struct{})
	defer close(done)
	// producer
	go func() {
		defer close(c)
		for i := 0; i < nJobs; i++ {
			select {
			case c <- i:
			case <-done:
				return
			}
		}
	}()
	// consumer
	for {
		// exit stays true only if no worker receives a job in this round.
		exit := atomic.NewBool(true)
		numJobs := jobsAlloc.AvailableJobs(nil)
		var wg sync.WaitGroup
		wg.Add(numJobs)
		// One slot per worker is enough: a worker stops at its first error,
		// so it records at most one error per round.
		errs := make([]error, numJobs)
		for j := 0; j < numJobs; j++ {
			// start workers
			go func(workerId int) {
				defer wg.Done()
				for i := 0; i < allocPeriod; i++ {
					// read job
					jobId, ok := <-c
					if !ok {
						// channel closed and drained: nothing left to do
						return
					}
					exit.Store(false)
					// run job
					if err := worker(workerId, jobId); err != nil {
						errs[workerId] = err
						return
					}
				}
			}(j)
		}
		wg.Wait()
		// check errors
		for _, err := range errs {
			if err != nil {
				return errors.Trace(err)
			}
		}
		// exit if finished
		if exit.Load() {
			return nil
		}
	}
}

type batchJob struct {
beginId int
endId int
Expand All @@ -90,19 +141,13 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo
if nWorkers == 1 {
return worker(0, 0, nJobs)
}
const chanSize = 64
const chanEOF = -1
c := make(chan batchJob, chanSize)
// producer
go func() {
// send jobs
for i := 0; i < nJobs; i += batchSize {
c <- batchJob{beginId: i, endId: mathutil.Min(i+batchSize, nJobs)}
}
// send EOF
for i := 0; i < nWorkers; i++ {
c <- batchJob{beginId: chanEOF, endId: chanEOF}
}
close(c)
}()
// consumer
var wg sync.WaitGroup
Expand All @@ -114,8 +159,8 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo
defer wg.Done()
for {
// read job
job := <-c
if job.beginId == chanEOF {
job, ok := <-c
if !ok {
return
}
// run job
Expand Down
17 changes: 17 additions & 0 deletions base/parallel/parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/scylladb/go-set"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base"
"github.com/zhenghaoz/gorse/base/task"
"testing"
"time"
)
Expand Down Expand Up @@ -48,6 +49,22 @@ func TestParallel(t *testing.T) {
assert.Equal(t, 1, workersSet.Size())
}

// TestDynamicParallel runs 10000 jobs through DynamicParallel with a constant
// allocator of 3 jobs, then checks that every job was executed and that more
// than one, but at most four, distinct worker IDs were observed.
func TestDynamicParallel(t *testing.T) {
	input := base.RangeInt(10000)
	output := make([]int, len(input))
	seenWorkers := make([]int, len(input))
	alloc := task.NewConstantJobsAllocator(3)
	record := func(workerId, jobId int) error {
		output[jobId] = input[jobId]
		seenWorkers[jobId] = workerId
		time.Sleep(time.Microsecond)
		return nil
	}
	_ = DynamicParallel(len(input), alloc, record)
	distinctWorkers := set.NewIntSet(seenWorkers...)
	assert.Equal(t, input, output)
	assert.GreaterOrEqual(t, 4, distinctWorkers.Size())
	assert.Less(t, 1, distinctWorkers.Size())
}

func TestBatchParallel(t *testing.T) {
a := base.RangeInt(10000)
b := make([]int, len(a))
Expand Down
12 changes: 2 additions & 10 deletions base/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
package base

import (
"math/rand"

"github.com/scylladb/go-set/i32set"
"github.com/scylladb/go-set/iset"
"math/rand"
)

// RandomGenerator is the random generator for gorse.
Expand Down Expand Up @@ -76,15 +77,6 @@ func (rng RandomGenerator) NormalVector64(size int, mean, stdDev float64) []floa
return ret
}

// NormalMatrix64 builds a matrix with `row` rows, each produced by a call to
// NormalVector64(col, mean, stdDev).
func (rng RandomGenerator) NormalMatrix64(row, col int, mean, stdDev float64) [][]float64 {
	matrix := make([][]float64, 0, row)
	for r := 0; r < row; r++ {
		matrix = append(matrix, rng.NormalVector64(col, mean, stdDev))
	}
	return matrix
}

// Sample n values between low and high, but not in exclude.
func (rng RandomGenerator) Sample(low, high, n int, exclude ...*iset.Set) []int {
intervalLength := high - low
Expand Down
18 changes: 6 additions & 12 deletions base/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
package base

import (
"testing"

"github.com/chewxy/math32"
"github.com/scylladb/go-set"
"github.com/stretchr/testify/assert"
"github.com/thoas/go-funk"
"gonum.org/v1/gonum/stat"
"math"
"testing"
)

const randomEpsilon = 0.1
Expand All @@ -39,13 +38,6 @@ func TestRandomGenerator_MakeUniformMatrix(t *testing.T) {
assert.False(t, funk.MaxFloat32(vec) > 2)
}

// TestRandomGenerator_MakeNormalMatrix64 draws one 1000-element row with
// mean 1 and standard deviation 2 and checks that the sample mean and sample
// standard deviation are each within randomEpsilon of the requested values.
func TestRandomGenerator_MakeNormalMatrix64(t *testing.T) {
	const (
		wantMean   = 1.0
		wantStdDev = 2.0
	)
	sample := NewRandomGenerator(0).NormalMatrix64(1, 1000, wantMean, wantStdDev)[0]
	meanErr := math.Abs(stat.Mean(sample, nil) - wantMean)
	stdDevErr := math.Abs(stat.StdDev(sample, nil) - wantStdDev)
	assert.False(t, meanErr > randomEpsilon)
	assert.False(t, stdDevErr > randomEpsilon)
}

func TestRandomGenerator_Sample(t *testing.T) {
excludeSet := set.NewIntSet(0, 1, 2, 3, 4)
rng := NewRandomGenerator(0)
Expand Down Expand Up @@ -80,8 +72,10 @@ func stdDev(x []float32) float32 {
}

// meanVariance computes the sample mean and unbiased variance, where the mean and variance are
// \sum_i w_i * x_i / (sum_i w_i)
// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
//
// \sum_i w_i * x_i / (sum_i w_i)
// \sum_i w_i (x_i - mean)^2 / (sum_i w_i - 1)
//
// respectively.
// If weights is nil then all of the weights are 1. If weights is not nil, then
// len(x) must equal len(weights).
Expand Down
2 changes: 1 addition & 1 deletion base/search/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestHNSW_InnerProduct(t *testing.T) {
model.InitMean: 0,
model.InitStdDev: 0.001,
})
fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobs(runtime.NumCPU())
fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU()))
m.Fit(trainSet, testSet, fitConfig)
var vectors []Vector
for i, itemFactor := range m.ItemFactor {
Expand Down
Loading

0 comments on commit 6f265b4

Please sign in to comment.