Skip to content

Commit

Permalink
fix(model-loading): keep track of open GRPC Clients (#3377)
Browse files Browse the repository at this point in the history
Due to a previous refactor we moved the client constructor tight to the
model address, however that was just a string which we would use to
build the client each time.

With this change we make the loader to return a *Model which carries a
constructor for the client and stores the client on the first
connection.

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Aug 25, 2024
1 parent 771a052 commit 7f06954
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 171 deletions.
2 changes: 1 addition & 1 deletion core/services/backend_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status
return nil, err
}
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
if modelAddr == nil {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool)
if bc, ok := embeds[address]; ok {
return bc
}
return NewGrpcClient(address, parallel, wd, enableWatchDog)
return buildClient(address, parallel, wd, enableWatchDog)
}

func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
if !enableWatchDog {
wd = nil
}
Expand Down
76 changes: 40 additions & 36 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ func (c *Client) setBusy(v bool) {
c.Unlock()
}

func (c *Client) wdMark() {
if c.wd != nil {
c.wd.Mark(c.address)
}
}

func (c *Client) wdUnMark() {
if c.wd != nil {
c.wd.UnMark(c.address)
}
}

func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
if !c.parallel {
c.opMutex.Lock()
Expand Down Expand Up @@ -76,10 +88,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -97,10 +107,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -118,10 +126,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -138,10 +144,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
Expand Down Expand Up @@ -177,10 +181,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -197,10 +199,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -217,10 +217,8 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -237,10 +235,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down Expand Up @@ -277,10 +273,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down Expand Up @@ -319,6 +313,8 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -333,6 +329,8 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.wdMark()
defer c.wdUnMark()
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand All @@ -351,6 +349,8 @@ func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ..
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -367,6 +367,8 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand All @@ -383,6 +385,8 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
Expand Down
55 changes: 24 additions & 31 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ ENTRY:
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".log") {
continue
}

// Skip the llama.cpp variants if we are autoDetecting
// But we always load the fallback variant if it exists
Expand Down Expand Up @@ -265,12 +268,12 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string {

// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
return func(modelName, modelFile string) (ModelAddress, error) {
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) {
return func(modelName, modelFile string) (*Model, error) {

log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)

var client ModelAddress
var client *Model

getFreeAddress := func() (string, error) {
port, err := freeport.GetFreePort()
Expand Down Expand Up @@ -298,26 +301,26 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
log.Debug().Msgf("external backend is file: %+v", fi)
serverAddress, err := getFreeAddress()
if err != nil {
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
}
// Make sure the process is executable
if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
return "", err
return nil, err
}

log.Debug().Msgf("GRPC Service Started")

client = ModelAddress(serverAddress)
client = NewModel(serverAddress)
} else {
log.Debug().Msg("external backend is uri")
// address
client = ModelAddress(uri)
client = NewModel(uri)
}
} else {
grpcProcess := backendPath(o.assetDir, backend)
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
}

if autoDetect {
Expand All @@ -329,12 +332,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return "", fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
}

serverAddress, err := getFreeAddress()
if err != nil {
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
}

args := []string{}
Expand All @@ -344,12 +347,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

// Make sure the process is executable in any circumstance
if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil {
return "", err
return nil, err
}

log.Debug().Msgf("GRPC Service Started")

client = ModelAddress(serverAddress)
client = NewModel(serverAddress)
}

// Wait for the service to start up
Expand All @@ -369,7 +372,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

if !ready {
log.Debug().Msgf("GRPC Service NOT ready")
return "", fmt.Errorf("grpc service not ready")
return nil, fmt.Errorf("grpc service not ready")
}

options := *o.gRPCOptions
Expand All @@ -380,27 +383,16 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
if err != nil {
return "", fmt.Errorf("could not load model: %w", err)
return nil, fmt.Errorf("could not load model: %w", err)
}
if !res.Success {
return "", fmt.Errorf("could not load model (no success): %s", res.Message)
return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
}

return client, nil
}
}

func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
if parallel {
return addr.GRPC(parallel, ml.wd), nil
}

if _, ok := ml.grpcClients[string(addr)]; !ok {
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
}
return ml.grpcClients[string(addr)], nil
}

func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)

Expand All @@ -425,7 +417,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
return nil, err
}

}

var backendToConsume string
Expand All @@ -438,26 +429,28 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
backendToConsume = backend
}

addr, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
if err != nil {
return nil, err
}

return ml.resolveAddress(addr, o.parallelRequests)
return model.GRPC(o.parallelRequests, ml.wd), nil
}

func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...)

ml.mu.Lock()

// Return earlier if we have a model already loaded
// (avoid looping through all the backends)
if m := ml.CheckIsLoaded(o.model); m != "" {
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
ml.mu.Unlock()

return ml.resolveAddress(m, o.parallelRequests)
return m.GRPC(o.parallelRequests, ml.wd), nil
}

// If we can have only one backend active, kill all the others (except external backends)
if o.singleActiveBackend {
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
Expand Down
Loading

0 comments on commit 7f06954

Please sign in to comment.