diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index 395886042c4..88fefa09c49 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -107,7 +107,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status return nil, err } modelAddr := bms.modelLoader.CheckIsLoaded(backendId) - if modelAddr == "" { + if modelAddr == nil { return nil, fmt.Errorf("backend %s is not currently loaded", backendId) } diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 5abc34abf38..3821678cbad 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -18,10 +18,10 @@ func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) if bc, ok := embeds[address]; ok { return bc } - return NewGrpcClient(address, parallel, wd, enableWatchDog) + return buildClient(address, parallel, wd, enableWatchDog) } -func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend { +func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend { if !enableWatchDog { wd = nil } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 827275cfcbb..b654e9c9792 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -39,6 +39,18 @@ func (c *Client) setBusy(v bool) { c.Unlock() } +func (c *Client) wdMark() { + if c.wd != nil { + c.wd.Mark(c.address) + } +} + +func (c *Client) wdUnMark() { + if c.wd != nil { + c.wd.UnMark(c.address) + } +} + func (c *Client) HealthCheck(ctx context.Context) (bool, error) { if !c.parallel { c.opMutex.Lock() @@ -76,10 +88,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ... } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -97,10 +107,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -118,10 +126,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -138,10 +144,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err @@ -177,10 +181,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -197,10 +199,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -217,10 +217,8 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -237,10 +235,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -277,10 +273,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts } c.setBusy(true) defer c.setBusy(false) - if c.wd != nil { - c.wd.Mark(c.address) - defer c.wd.UnMark(c.address) - } + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -319,6 +313,8 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts .. } c.setBusy(true) defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -333,6 +329,8 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o c.opMutex.Lock() defer c.opMutex.Unlock() } + c.wdMark() + defer c.wdUnMark() c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -351,6 +349,8 @@ func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts .. } c.setBusy(true) defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -367,6 +367,8 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts } c.setBusy(true) defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -383,6 +385,8 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc. } c.setBusy(true) defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index d85da6c10f2..de0662e62c1 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -80,6 +80,9 @@ ENTRY: if e.IsDir() { continue } + if strings.HasSuffix(e.Name(), ".log") { + continue + } // Skip the llama.cpp variants if we are autoDetecting // But we always load the fallback variant if it exists @@ -265,12 +268,12 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string { // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) { - return func(modelName, modelFile string) (ModelAddress, error) { +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) { + return func(modelName, modelFile string) (*Model, error) { log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o) - var client ModelAddress + var client *Model getFreeAddress := func() (string, error) { port, err := freeport.GetFreePort() @@ -298,26 +301,26 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string log.Debug().Msgf("external backend is file: %+v", fi) serverAddress, err := getFreeAddress() if err != nil { - return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } // Make sure the process is executable if err := ml.startProcess(uri, o.model, serverAddress); err != nil { log.Error().Err(err).Str("path", uri).Msg("failed to launch ") - return "", err + return nil, err } log.Debug().Msgf("GRPC Service Started") - client = ModelAddress(serverAddress) + client = NewModel(serverAddress) } else { log.Debug().Msg("external backend is uri") // address - client = ModelAddress(uri) + client = NewModel(uri) } } else { grpcProcess := backendPath(o.assetDir, backend) if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil { - return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error()) + return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error()) } if autoDetect { @@ -329,12 +332,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Check if the file exists if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { - return "", fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) } serverAddress, err := getFreeAddress() if err != nil { - return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } args := []string{} @@ -344,12 +347,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Make sure the process is executable in any circumstance if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil { - return "", err + return nil, err } log.Debug().Msgf("GRPC Service Started") - client = ModelAddress(serverAddress) + client = NewModel(serverAddress) } // Wait for the service to start up @@ -369,7 +372,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if !ready { log.Debug().Msgf("GRPC Service NOT ready") - return "", fmt.Errorf("grpc service not ready") + return nil, fmt.Errorf("grpc service not ready") } options := *o.gRPCOptions @@ -380,27 +383,16 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options) if err != nil { - return "", fmt.Errorf("could not load model: %w", err) + return nil, fmt.Errorf("could not load model: %w", err) } if !res.Success { - return "", fmt.Errorf("could not load model (no success): %s", res.Message) + return nil, fmt.Errorf("could not load model (no success): %s", res.Message) } return client, nil } } -func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) { - if parallel { - return addr.GRPC(parallel, ml.wd), nil - } - - if _, ok := ml.grpcClients[string(addr)]; !ok { - ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd) - } - return ml.grpcClients[string(addr)], nil -} - func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) @@ -425,7 +417,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel") return nil, err } - } var backendToConsume string @@ -438,26 +429,28 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e backendToConsume = backend } - addr, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o)) + model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o)) if err != nil { return nil, err } - return ml.resolveAddress(addr, o.parallelRequests) + return model.GRPC(o.parallelRequests, ml.wd), nil } func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) ml.mu.Lock() + // Return earlier if we have a model already loaded // (avoid looping through all the backends) - if m := ml.CheckIsLoaded(o.model); m != "" { + if m := ml.CheckIsLoaded(o.model); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.model) ml.mu.Unlock() - return ml.resolveAddress(m, o.parallelRequests) + return m.GRPC(o.parallelRequests, ml.wd), nil } + // If we can have only one backend active, kill all the others (except external backends) if o.singleActiveBackend { log.Debug().Msgf("Stopping all backends except '%s'", o.model) diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b2570c715d2..c1ed01dc0be 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,67 +10,28 @@ import ( "github.com/mudler/LocalAI/pkg/templates" - "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/utils" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) -// Rather than pass an interface{} to the prompt template: -// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file -// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values. -type PromptTemplateData struct { - SystemPrompt string - SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_ - Input string - Instruction string - Functions []functions.Function - MessageIndex int -} - -type ChatMessageTemplateData struct { - SystemPrompt string - Role string - RoleName string - FunctionName string - Content string - MessageIndex int - Function bool - FunctionCall interface{} - LastMessage bool -} - // new idea: what if we declare a struct of these here, and use a loop to check? // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl type ModelLoader struct { - ModelPath string - mu sync.Mutex - // TODO: this needs generics - grpcClients map[string]grpc.Backend - models map[string]ModelAddress + ModelPath string + mu sync.Mutex + models map[string]*Model grpcProcesses map[string]*process.Process templates *templates.TemplateCache wd *WatchDog } -type ModelAddress string - -func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend { - enableWD := false - if wd != nil { - enableWD = true - } - return grpc.NewClient(string(m), parallel, wd, enableWD) -} - func NewModelLoader(modelPath string) *ModelLoader { nml := &ModelLoader{ ModelPath: modelPath, - grpcClients: make(map[string]grpc.Backend), - models: make(map[string]ModelAddress), + models: make(map[string]*Model), templates: templates.NewTemplateCache(modelPath), grpcProcesses: make(map[string]*process.Process), } @@ -141,12 +102,12 @@ FILE: return models, nil } -func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (ModelAddress, error)) (ModelAddress, error) { +func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) { ml.mu.Lock() defer ml.mu.Unlock() // Check if we already have a loaded model - if model := ml.CheckIsLoaded(modelName); model != "" { + if model := ml.CheckIsLoaded(modelName); model != nil { return model, nil } @@ -156,17 +117,9 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( model, err := loader(modelName, modelFile) if err != nil { - return "", err + return nil, err } - // TODO: Add a helper method to iterate all prompt templates associated with a config if and only if it's YAML? - // Minor perf loss here until this is fixed, but we initialize on first request - - // // If there is a prompt template, load it - // if err := ml.loadTemplateIfExists(modelName); err != nil { - // return nil, err - // } - ml.models[modelName] = model return model, nil } @@ -184,55 +137,29 @@ func (ml *ModelLoader) stopModel(modelName string) error { return fmt.Errorf("model %s not found", modelName) } return nil - //return ml.deleteProcess(modelName) } -func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { - var client grpc.Backend - if m, ok := ml.models[s]; ok { - log.Debug().Msgf("Model already loaded in memory: %s", s) - if c, ok := ml.grpcClients[s]; ok { - client = c - } else { - client = m.GRPC(false, ml.wd) - } - alive, err := client.HealthCheck(context.Background()) - if !alive { - log.Warn().Msgf("GRPC Model not responding: %s", err.Error()) - log.Warn().Msgf("Deleting the process in order to recreate it") - if !ml.grpcProcesses[s].IsAlive() { - log.Debug().Msgf("GRPC Process is not responding: %s", s) - // stop and delete the process, this forces to re-load the model and re-create again the service - err := ml.deleteProcess(s) - if err != nil { - log.Error().Err(err).Str("process", s).Msg("error stopping process") - } - return "" - } - } - - return m +func (ml *ModelLoader) CheckIsLoaded(s string) *Model { + m, ok := ml.models[s] + if !ok { + return nil } - return "" -} - -const ( - ChatPromptTemplate templates.TemplateType = iota - ChatMessageTemplate - CompletionPromptTemplate - EditPromptTemplate - FunctionsPromptTemplate -) - -func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { - // TODO: should this check be improved? - if templateType == ChatMessageTemplate { - return "", fmt.Errorf("invalid templateType: ChatMessage") + log.Debug().Msgf("Model already loaded in memory: %s", s) + alive, err := m.GRPC(false, ml.wd).HealthCheck(context.Background()) + if !alive { + log.Warn().Msgf("GRPC Model not responding: %s", err.Error()) + log.Warn().Msgf("Deleting the process in order to recreate it") + if !ml.grpcProcesses[s].IsAlive() { + log.Debug().Msgf("GRPC Process is not responding: %s", s) + // stop and delete the process, this forces to re-load the model and re-create again the service + err := ml.deleteProcess(s) + if err != nil { + log.Error().Err(err).Str("process", s).Msg("error stopping process") + } + return nil + } } - return ml.templates.EvaluateTemplate(templateType, templateName, in) -} -func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { - return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) + return m } diff --git a/pkg/model/model.go b/pkg/model/model.go new file mode 100644 index 00000000000..26ddb8cc7c4 --- /dev/null +++ b/pkg/model/model.go @@ -0,0 +1,29 @@ +package model + +import grpc "github.com/mudler/LocalAI/pkg/grpc" + +type Model struct { + address string + client grpc.Backend +} + +func NewModel(address string) *Model { + return &Model{ + address: address, + } +} + +func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend { + if m.client != nil { + return m.client + } + + enableWD := false + if wd != nil { + enableWD = true + } + + client := grpc.NewClient(m.address, parallel, wd, enableWD) + m.client = client + return client +} diff --git a/pkg/model/process.go b/pkg/model/process.go index 6a4fd326cff..5b751de8495 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -33,7 +33,7 @@ func (ml *ModelLoader) StopAllExcept(s string) error { func (ml *ModelLoader) deleteProcess(s string) error { if _, exists := ml.grpcProcesses[s]; exists { if err := ml.grpcProcesses[s].Stop(); err != nil { - return err + log.Error().Err(err).Msgf("(deleteProcess) error while deleting grpc process %s", s) } } delete(ml.grpcProcesses, s) diff --git a/pkg/model/template.go b/pkg/model/template.go new file mode 100644 index 00000000000..3dc850cf2ec --- /dev/null +++ b/pkg/model/template.go @@ -0,0 +1,52 @@ +package model + +import ( + "fmt" + + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/templates" +) + +// Rather than pass an interface{} to the prompt template: +// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file +// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values. +type PromptTemplateData struct { + SystemPrompt string + SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_ + Input string + Instruction string + Functions []functions.Function + MessageIndex int +} + +type ChatMessageTemplateData struct { + SystemPrompt string + Role string + RoleName string + FunctionName string + Content string + MessageIndex int + Function bool + FunctionCall interface{} + LastMessage bool +} + +const ( + ChatPromptTemplate templates.TemplateType = iota + ChatMessageTemplate + CompletionPromptTemplate + EditPromptTemplate + FunctionsPromptTemplate +) + +func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { + // TODO: should this check be improved? + if templateType == ChatMessageTemplate { + return "", fmt.Errorf("invalid templateType: ChatMessage") + } + return ml.templates.EvaluateTemplate(templateType, templateName, in) +} + +func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { + return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) +} diff --git a/pkg/model/loader_test.go b/pkg/model/template_test.go similarity index 100% rename from pkg/model/loader_test.go rename to pkg/model/template_test.go diff --git a/pkg/model/watchdog.go b/pkg/model/watchdog.go index b5381832e01..5702dda5443 100644 --- a/pkg/model/watchdog.go +++ b/pkg/model/watchdog.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog/log" ) +// WatchDog tracks all the requests from GRPC clients. // All GRPC Clients created by ModelLoader should have an associated injected // watchdog that will keep track of the state of each backend (busy or not) // and for how much time it has been busy. @@ -15,7 +16,6 @@ import ( // force a reload of the model // The watchdog runs as a separate go routine, // and the GRPC client talks to it via a channel to send status updates - type WatchDog struct { sync.Mutex timetable map[string]time.Time