feat: initial watchdog implementation (#1341)
* feat: initial watchdog implementation

Signed-off-by: Ettore Di Giacinto <[email protected]>

* fixups

* Add more output

* wip: idletime checker

* wire idle watchdog checks

* enlarge watchdog time window

* small fixes

* Use stopmodel

* Always delete process

Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
Signed-off-by: Ettore Di Giacinto <[email protected]>
mudler authored Nov 26, 2023
1 parent 9482acf commit 824612f
Showing 10 changed files with 341 additions and 13 deletions.
16 changes: 15 additions & 1 deletion .env
@@ -72,4 +72,18 @@ MODELS_PATH=/models
# LLAMACPP_PARALLEL=1

### Enable to run parallel requests
# PARALLEL_REQUESTS=true
# PARALLEL_REQUESTS=true

### Watchdog settings
###
# Enables watchdog to kill backends that are inactive for too long
# WATCHDOG_IDLE=true
#
# Enables watchdog to kill backends that are busy for too long
# WATCHDOG_BUSY=true
#
# Time in duration format (e.g. 1h30m) a backend may stay idle before it is stopped
# WATCHDOG_IDLE_TIMEOUT=5m
#
# Time in duration format (e.g. 1h30m) a backend may stay busy before it is stopped
# WATCHDOG_BUSY_TIMEOUT=5m
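Both timeouts take Go duration strings. In this commit they are read through the new CLI flags shown further down and parsed with time.ParseDuration; the standalone snippet below is only an illustration of the accepted syntax, not code from the commit.

// Hypothetical snippet, not part of this commit: shows the duration syntax
// accepted for WATCHDOG_IDLE_TIMEOUT / WATCHDOG_BUSY_TIMEOUT.
package main

import (
	"fmt"
	"os"
	"time"
)

func main() {
	raw := os.Getenv("WATCHDOG_IDLE_TIMEOUT")
	if raw == "" {
		raw = "15m" // default of the new --watchdog-idle-timeout flag
	}
	d, err := time.ParseDuration(raw) // accepts values such as "5m", "1h30m", "90s"
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid duration %q: %v\n", raw, err)
		os.Exit(1)
	}
	fmt.Println("idle timeout:", d)
}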
17 changes: 17 additions & 0 deletions api/api.go
@@ -13,6 +13,7 @@ import (
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
@@ -79,6 +80,22 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
options.Loader.StopAllGRPC()
}()

if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}

return options, cl, nil
}

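The watchdog constructed above lives in pkg/model, in one of the changed files whose diff is not expanded on this page. The sketch below is only a rough guess at what the calls used here (NewWatchDog, Run, Shutdown, plus the Mark/UnMark pair seen later in pkg/grpc/client.go) imply; every name not visible in this diff, including the process-manager interface, the field names and the check interval, is an assumption rather than the actual implementation.

// Hypothetical sketch of the watchdog wired up above; only NewWatchDog, Run, Shutdown,
// Mark and UnMark are confirmed by this diff, everything else is assumed.
package model

import (
	"sync"
	"time"
)

// processManager is an assumed view of the ModelLoader: the watchdog only needs
// a way to stop a backend it has flagged.
type processManager interface {
	StopModel(modelName string) error
}

type WatchDog struct {
	sync.Mutex
	pm                       processManager
	busyTimeout, idleTimeout time.Duration
	busyCheck, idleCheck     bool
	busySince, idleSince     map[string]time.Time // keyed by backend address
	stop                     chan struct{}
}

func NewWatchDog(pm processManager, busyTimeout, idleTimeout time.Duration, busy, idle bool) *WatchDog {
	return &WatchDog{
		pm:          pm,
		busyTimeout: busyTimeout, idleTimeout: idleTimeout,
		busyCheck: busy, idleCheck: idle,
		busySince: map[string]time.Time{}, idleSince: map[string]time.Time{},
		stop: make(chan struct{}),
	}
}

// Mark is called by the gRPC client when a request starts on a backend.
func (wd *WatchDog) Mark(address string) {
	wd.Lock()
	defer wd.Unlock()
	wd.busySince[address] = time.Now()
	delete(wd.idleSince, address)
}

// UnMark is called when the request finishes; the backend counts as idle from now on.
func (wd *WatchDog) UnMark(address string) {
	wd.Lock()
	defer wd.Unlock()
	delete(wd.busySince, address)
	wd.idleSince[address] = time.Now()
}

// Run periodically compares the recorded timestamps with the configured timeouts.
func (wd *WatchDog) Run() {
	ticker := time.NewTicker(time.Minute) // the real check interval is not visible in this diff
	defer ticker.Stop()
	for {
		select {
		case <-wd.stop:
			return
		case <-ticker.C:
			wd.check()
		}
	}
}

// Shutdown stops the Run loop; called when the application context is canceled.
func (wd *WatchDog) Shutdown() { close(wd.stop) }

func (wd *WatchDog) check() {
	wd.Lock()
	defer wd.Unlock()
	now := time.Now()
	for address, since := range wd.busySince {
		if wd.busyCheck && now.Sub(since) > wd.busyTimeout {
			_ = wd.pm.StopModel(address) // mapping address to model name is elided here
			delete(wd.busySince, address)
		}
	}
	for address, since := range wd.idleSince {
		if wd.idleCheck && now.Sub(since) > wd.idleTimeout {
			_ = wd.pm.StopModel(address)
			delete(wd.idleSince, address)
		}
	}
}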
2 changes: 1 addition & 1 deletion api/localai/backend_monitor.go
@@ -128,7 +128,7 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}

status, rpcErr := model.GRPC(false).Status(context.TODO())
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
32 changes: 32 additions & 0 deletions api/options/options.go
@@ -4,6 +4,7 @@ import (
"context"
"embed"
"encoding/json"
"time"

"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
@@ -38,6 +39,11 @@ type Option struct {

SingleBackend bool
ParallelBackendRequests bool

WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}

type AppOption func(*Option)
@@ -63,6 +69,32 @@ func WithCors(b bool) AppOption {
}
}

var EnableWatchDog = func(o *Option) {
o.WatchDog = true
}

var EnableWatchDogIdleCheck = func(o *Option) {
o.WatchDog = true
o.WatchDogIdle = true
}

var EnableWatchDogBusyCheck = func(o *Option) {
o.WatchDog = true
o.WatchDogBusy = true
}

func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) {
o.WatchDogBusyTimeout = t
}
}

func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) {
o.WatchDogIdleTimeout = t
}
}

var EnableSingleBackend = func(o *Option) {
o.SingleBackend = true
}
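The new toggles and setters follow the package's existing functional-options pattern. Below is an illustrative sketch (not part of the commit; the import path is inferred from the file layout) showing how they compose, much as main.go does in the next file.

// Illustrative only: composing the new watchdog AppOptions.
package main

import (
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/api/options"
)

// watchdogOptions mirrors what main.go assembles from the CLI flags below;
// the timeout values here are just examples.
func watchdogOptions() []options.AppOption {
	return []options.AppOption{
		options.EnableWatchDogIdleCheck, // also sets WatchDog = true
		options.SetWatchDogIdleTimeout(15 * time.Minute),
		options.EnableWatchDogBusyCheck,
		options.SetWatchDogBusyTimeout(5 * time.Minute),
	}
}

func main() {
	o := &options.Option{}
	for _, opt := range watchdogOptions() {
		opt(o)
	}
	fmt.Printf("watchdog enabled: %v (idle=%v busy=%v)\n", o.WatchDog, o.WatchDogIdle, o.WatchDogBusy)
}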
47 changes: 47 additions & 0 deletions main.go
@@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"syscall"
"time"

api "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/api/backend"
@@ -154,6 +155,30 @@ func main() {
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
EnvVars: []string{"API_KEY"},
},
&cli.BoolFlag{
Name: "enable-watchdog-idle",
Usage: "Enable watchdog for stopping idle backends. This will stop the backends if are in idle state for too long.",
EnvVars: []string{"WATCHDOG_IDLE"},
Value: false,
},
&cli.BoolFlag{
Name: "enable-watchdog-busy",
Usage: "Enable watchdog for stopping busy backends that exceed a defined threshold.",
EnvVars: []string{"WATCHDOG_BUSY"},
Value: false,
},
&cli.StringFlag{
Name: "watchdog-busy-timeout",
Usage: "Watchdog timeout. This will restart the backend if it crashes.",
EnvVars: []string{"WATCHDOG_BUSY_TIMEOUT"},
Value: "5m",
},
&cli.StringFlag{
Name: "watchdog-idle-timeout",
Usage: "Watchdog idle timeout. This will restart the backend if it crashes.",
EnvVars: []string{"WATCHDOG_IDLE_TIMEOUT"},
Value: "15m",
},
&cli.BoolFlag{
Name: "preload-backend-only",
Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups.",
@@ -198,6 +223,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithUploadLimitMB(ctx.Int("upload-limit")),
options.WithApiKeys(ctx.StringSlice("api-keys")),
}

idleWatchDog := ctx.Bool("enable-watchdog-idle")
busyWatchDog := ctx.Bool("enable-watchdog-busy")
if idleWatchDog || busyWatchDog {
opts = append(opts, options.EnableWatchDog)
if idleWatchDog {
opts = append(opts, options.EnableWatchDogIdleCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
if err != nil {
return err
}
opts = append(opts, options.SetWatchDogIdleTimeout(dur))
}
if busyWatchDog {
opts = append(opts, options.EnableWatchDogBusyCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
if err != nil {
return err
}
opts = append(opts, options.SetWatchDogBusyTimeout(dur))
}
}
if ctx.Bool("parallel-requests") {
opts = append(opts, options.EnableParallelBackendRequests)
}
44 changes: 43 additions & 1 deletion pkg/grpc/client.go
@@ -19,12 +19,22 @@ type Client struct {
parallel bool
sync.Mutex
opMutex sync.Mutex
wd WatchDog
}

func NewClient(address string, parallel bool) *Client {
type WatchDog interface {
Mark(address string)
UnMark(address string)
}

func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
if !enableWatchDog {
wd = nil
}
return &Client{
address: address,
parallel: parallel,
wd: wd,
}
}

@@ -79,6 +89,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -96,6 +110,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -113,6 +131,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -129,6 +151,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
@@ -164,6 +190,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -180,6 +210,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -196,6 +230,10 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@@ -232,6 +270,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
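Each client call now brackets itself with Mark/UnMark so the watchdog can tell how long a backend has been busy or idle. The following is an illustrative stand-in satisfying the new WatchDog interface; the stub and address are examples, not part of the commit.

// Illustrative stub only: a minimal type satisfying the WatchDog interface above,
// handy for observing when a client marks a backend busy.
package main

import (
	"log"

	"github.com/go-skynet/LocalAI/pkg/grpc"
)

type loggingWatchDog struct{}

func (loggingWatchDog) Mark(address string)   { log.Printf("backend busy: %s", address) }
func (loggingWatchDog) UnMark(address string) { log.Printf("backend idle: %s", address) }

func main() {
	// The address is a placeholder. With enableWatchDog=false the client drops the
	// watchdog entirely; with true, every RPC wraps itself in Mark/UnMark as shown above.
	c := grpc.NewClient("127.0.0.1:50051", false, loggingWatchDog{}, true)
	_ = c
}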
8 changes: 4 additions & 4 deletions pkg/model/initializers.go
@@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
// Wait for the service to start up
ready := false
for i := 0; i < o.grpcAttempts; i++ {
if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) {
if client.GRPC(o.parallelRequests, ml.wd).HealthCheck(context.Background()) {
log.Debug().Msgf("GRPC Service Ready")
ready = true
break
@@ -140,7 +140,7 @@

log.Debug().Msgf("GRPC: Loading model with options: %+v", options)

res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options)
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
if err != nil {
return "", fmt.Errorf("could not load model: %w", err)
}
@@ -154,11 +154,11 @@

func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
if parallel {
return addr.GRPC(parallel), nil
return addr.GRPC(parallel, ml.wd), nil
}

if _, ok := ml.grpcClients[string(addr)]; !ok {
ml.grpcClients[string(addr)] = addr.GRPC(parallel)
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
}
return ml.grpcClients[string(addr)], nil
}
26 changes: 21 additions & 5 deletions pkg/model/loader.go
@@ -63,12 +63,17 @@ type ModelLoader struct {
models map[string]ModelAddress
grpcProcesses map[string]*process.Process
templates map[TemplateType]map[string]*template.Template
wd *WatchDog
}

type ModelAddress string

func (m ModelAddress) GRPC(parallel bool) *grpc.Client {
return grpc.NewClient(string(m), parallel)
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
enableWD := false
if wd != nil {
enableWD = true
}
return grpc.NewClient(string(m), parallel, wd, enableWD)
}

func NewModelLoader(modelPath string) *ModelLoader {
@@ -79,10 +84,15 @@ func NewModelLoader(modelPath string) *ModelLoader {
templates: make(map[TemplateType]map[string]*template.Template),
grpcProcesses: make(map[string]*process.Process),
}

nml.initializeTemplateMap()
return nml
}

func (ml *ModelLoader) SetWatchDog(wd *WatchDog) {
ml.wd = wd
}

func (ml *ModelLoader) ExistsInModelPath(s string) bool {
return existsInPath(ml.ModelPath, s)
}
@@ -139,11 +149,17 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
func (ml *ModelLoader) ShutdownModel(modelName string) error {
ml.mu.Lock()
defer ml.mu.Unlock()

return ml.StopModel(modelName)
}

func (ml *ModelLoader) StopModel(modelName string) error {
defer ml.deleteProcess(modelName)
if _, ok := ml.models[modelName]; !ok {
return fmt.Errorf("model %s not found", modelName)
}

return ml.deleteProcess(modelName)
return nil
//return ml.deleteProcess(modelName)
}

func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
@@ -153,7 +169,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
if c, ok := ml.grpcClients[s]; ok {
client = c
} else {
client = m.GRPC(false)
client = m.GRPC(false, ml.wd)
}

if !client.HealthCheck(context.Background()) {
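Because GRPC(parallel, ml.wd) captures the watchdog at the moment a client is created, the loader should receive it before any model is loaded, which is the order Startup uses above. A short sketch of that ordering follows; the model path and timeout values are placeholders, and the NewWatchDog call simply mirrors the one in api/api.go.

// Illustrative ordering only, not part of this commit.
package main

import (
	"time"

	"github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	ml := model.NewModelLoader("/models")

	// Attach the watchdog before any backend client is created, so that
	// GRPC(parallel, ml.wd) hands each new client a non-nil watchdog.
	wd := model.NewWatchDog(ml, 5*time.Minute, 15*time.Minute, true, true)
	ml.SetWatchDog(wd)
	go wd.Run()

	// Models loaded after this point are tracked by the watchdog.
}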
(Diffs for the remaining changed files were not loaded in this view.)
