feat: tpool - thread pool for cpu bound tasks
joway committed Jul 18, 2023
1 parent a129727 commit 2a6966c
Showing 4 changed files with 216 additions and 0 deletions.
80 changes: 80 additions & 0 deletions util/tpool/scheduler.go
@@ -0,0 +1,80 @@
package tpool

import "sync"

type task func()

func newScheduler(threads int) *scheduler {
	s := &scheduler{
		threads:  threads,
		tasks:    make([]task, 0, 1024),
		notifier: make([]chan struct{}, 0, 1024),
	}
	return s
}

type scheduler struct {
	mu       sync.Mutex
	state    int    // 0: running, -1: closed
	threads  int    // number of OS threads
	tasks    []task // LIFO: the most recently submitted task is scheduled first, so most tasks see the lowest latency
	notifier []chan struct{}
}

func (s *scheduler) Close() {
	s.mu.Lock()
	s.state = -1
	// Wake every sleeping thread so it can observe the closed state and exit.
	// Each notifier channel is buffered (cap 1) and empty while registered,
	// so these sends cannot block even though we hold the lock.
	for i := 0; i < len(s.notifier); i++ {
		notify := s.notifier[i]
		notify <- struct{}{}
	}
	s.mu.Unlock()
}

func (s *scheduler) Add(t task) {
	var notify chan struct{}
	s.mu.Lock()
	if s.state < 0 { // closed
		s.mu.Unlock()
		return
	}

	waits := len(s.notifier)
	s.tasks = append(s.tasks, t)
	if waits > 0 {
		notify = s.notifier[waits-1]
		s.notifier = s.notifier[:waits-1]
	}
	s.mu.Unlock()
	if notify != nil {
		notify <- struct{}{}
	}
}

func (s *scheduler) Get() (t task) {
	var notify chan struct{}
GET:
	s.mu.Lock()
	if s.state < 0 { // closed
		s.mu.Unlock()
		return
	}

	size := len(s.tasks)
	if size > 0 {
		t = s.tasks[size-1]
		s.tasks = s.tasks[:size-1]
		s.mu.Unlock()
		if notify != nil {
			close(notify)
		}
		return t
	}
	if notify == nil {
		notify = make(chan struct{}, 1)
	}
	s.notifier = append(s.notifier, notify)
	s.mu.Unlock()

	<-notify // the thread goes to sleep
	// ...and wakes up here when notified
	goto GET
}
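
For illustration only (not part of this commit), a minimal in-package sketch of the LIFO behavior: when no thread is sleeping on a notifier, Get returns immediately and hands back the most recently added task first.

// lifoSketch is a hypothetical helper demonstrating the scheduler's
// LIFO ordering; it lives in package tpool because scheduler is unexported.
func lifoSketch() {
	s := newScheduler(1)
	defer s.Close()
	s.Add(func() { println("first submitted") })
	s.Add(func() { println("second submitted") })
	s.Get()() // runs "second submitted" (LIFO)
	s.Get()() // runs "first submitted"
}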
28 changes: 28 additions & 0 deletions util/tpool/thread.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package tpool

import "runtime"

// Thread is a worker bound to a dedicated OS thread.
type Thread struct {
	scheduler *scheduler
}

func newThread(scheduler *scheduler) *Thread {
	t := &Thread{
		scheduler: scheduler,
	}
	go t.run()
	return t
}

func (t *Thread) run() {
	// Pin this goroutine to its OS thread so CPU-bound tasks do not migrate.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	for {
		tsk := t.scheduler.Get()
		if tsk == nil {
			// scheduler closed
			return
		}
		tsk()
	}
}
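
The LockOSThread call is what makes the pool "thread"-shaped: it wires the worker goroutine to a single OS thread for its whole loop. A minimal standalone sketch of the same pattern (hypothetical, outside the commit):

package main

import (
	"fmt"
	"runtime"
)

func main() {
	work := make(chan func())
	done := make(chan struct{})
	go func() {
		runtime.LockOSThread() // pin this goroutine to one OS thread
		defer runtime.UnlockOSThread()
		for f := range work {
			f()
		}
		close(done)
	}()
	work <- func() { fmt.Println("ran on a locked OS thread") }
	close(work)
	<-done
}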
28 changes: 28 additions & 0 deletions util/tpool/thread_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package tpool

// New creates a thread pool backed by size dedicated OS threads.
func New(size int) *threadPool {
	s := newScheduler(size)
	threads := make([]*Thread, size)
	for i := 0; i < size; i++ {
		threads[i] = newThread(s)
	}

	pool := &threadPool{
		scheduler: s,
		threads:   threads,
	}
	return pool
}

type threadPool struct {
	scheduler *scheduler
	threads   []*Thread
}

// Submit queues a task for execution on one of the pool's threads.
func (tp *threadPool) Submit(task task) {
	tp.scheduler.Add(task)
}

// Close stops the pool; sleeping threads wake up and exit.
func (tp *threadPool) Close() {
	tp.scheduler.Close()
}
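
A usage sketch of the public surface (the import path is assumed; adjust it to wherever util/tpool lives in your module):

package main

import (
	"sync"

	"example.com/yourmodule/util/tpool" // hypothetical import path
)

func main() {
	p := tpool.New(4) // four dedicated, OS-locked threads
	defer p.Close()

	var wg sync.WaitGroup
	wg.Add(1)
	p.Submit(func() {
		defer wg.Done()
		var sum int
		for i := 0; i < 100000000; i++ { // CPU-bound work pinned to one thread
			sum += i
		}
		_ = sum
	})
	wg.Wait()
}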
80 changes: 80 additions & 0 deletions util/tpool/thread_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package tpool

import (
	"runtime"
	"runtime/pprof"
	"sync"
	"testing"
	"time"
)

func TestSleep(t *testing.T) {
	var threadProfile = pprof.Lookup("threadcreate")
	runtime.GOMAXPROCS(2)
	threads := 4
	p := New(threads)
	defer p.Close()

	var wg sync.WaitGroup
	wg.Add(threads)
	for i := 0; i < threads; i++ {
		p.Submit(func() {
			time.Sleep(time.Millisecond * 10)
			wg.Done()
		})
	}
	wg.Wait()
	t.Logf("Current thread count: %d", threadProfile.Count())
}

func TestCPUBound(t *testing.T) {
	var threadProfile = pprof.Lookup("threadcreate")
	runtime.GOMAXPROCS(2)
	threads := 4
	p := New(threads)
	defer p.Close()

	for round := threads; round <= 128; round *= 2 {
		var wg sync.WaitGroup
		begin := time.Now()
		wg.Add(round)
		for i := 0; i < round; i++ {
			p.Submit(func() {
				var sum int
				for x := 0; x <= 100000000; x++ {
					sum += x
				}
				_ = sum
				wg.Done()
			})
		}
		wg.Wait()
		cost := time.Since(begin)
		t.Logf("Round[%d]: cost %d ms", round, cost.Milliseconds())
	}
	t.Logf("Current thread count: %d", threadProfile.Count())
}

func TestHugeThreads(t *testing.T) {
	var threadProfile = pprof.Lookup("threadcreate")
	runtime.GOMAXPROCS(2)
	threads := 32
	p := New(threads)
	defer p.Close()

	var wg sync.WaitGroup
	round := threads * 16
	wg.Add(round)
	for i := 0; i < round; i++ {
		p.Submit(func() {
			var sum int
			for x := 0; x <= 100000000; x++ {
				sum += x
			}
			_ = sum
			wg.Done()
		})
	}
	wg.Wait()
	t.Logf("Current thread count: %d", threadProfile.Count())
}
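
To run these tests locally (package path taken from the diff), something like the following should work:

go test -v -run 'TestSleep|TestCPUBound|TestHugeThreads' ./util/tpool/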
