From db1159b6511e8fa09e594f9db0fec6ab4e142468 Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 16 Sep 2024 23:29:07 -0400 Subject: [PATCH] feat: auth v2 - supersedes #2894 (#3476) feat: auth v2 - supercedes #2894, metrics to follow later Signed-off-by: Dave Lee --- core/cli/run.go | 56 ++++++++++--------- core/config/application_config.go | 40 +++++++++++-- core/http/app.go | 49 +++++----------- core/http/middleware/auth.go | 93 +++++++++++++++++++++++++++++++ core/http/routes/elevenlabs.go | 7 +-- core/http/routes/jina.go | 3 +- core/http/routes/localai.go | 41 +++++++------- core/http/routes/openai.go | 89 +++++++++++++++-------------- core/http/routes/ui.go | 41 +++++++------- go.mod | 1 + go.sum | 2 + 11 files changed, 264 insertions(+), 158 deletions(-) create mode 100644 core/http/middleware/auth.go diff --git a/core/cli/run.go b/core/cli/run.go index 55ae0fd56b7..afb7204cdbd 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -41,31 +41,34 @@ type RunCMD struct { Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"` ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"` - Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` - CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` - CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` - LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"` - CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` - UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` - APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` - DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` - DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` - OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` - Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` - Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"` - Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"` - Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"` - Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"` - ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` - SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"` - PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` - ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` - EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"` - WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"` - EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"` - WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"` - Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"` - DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"` + Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` + CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` + CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` + LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"` + CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` + UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` + APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` + DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` + DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` + OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` + UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"` + DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"` + HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"` + Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` + Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"` + Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"` + Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"` + Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"` + ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` + SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"` + PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` + ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` + EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"` + WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"` + EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"` + WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"` + Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"` + DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"` } func (r *RunCMD) Run(ctx *cliContext.Context) error { @@ -97,6 +100,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithModelsURL(append(r.Models, r.ModelArgs...)...), config.WithOpaqueErrors(r.OpaqueErrors), config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan), + config.WithSubtleKeyComparison(r.UseSubtleKeyComparison), + config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet), + config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints), config.WithP2PNetworkID(r.Peer2PeerNetworkID), } diff --git a/core/config/application_config.go b/core/config/application_config.go index 947c4f136ba..afbf325f271 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -4,6 +4,7 @@ import ( "context" "embed" "encoding/json" + "regexp" "time" "github.com/mudler/LocalAI/pkg/xsysinfo" @@ -16,7 +17,6 @@ type ApplicationConfig struct { ModelPath string LibPath string UploadLimitMB, Threads, ContextSize int - DisableWebUI bool F16 bool Debug bool ImageDir string @@ -31,11 +31,17 @@ type ApplicationConfig struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string - EnforcePredownloadScans bool - OpaqueErrors bool P2PToken string P2PNetworkID string + DisableWebUI bool + EnforcePredownloadScans bool + OpaqueErrors bool + UseSubtleKeyComparison bool + DisableApiKeyRequirementForHttpGet bool + HttpGetExemptedEndpoints []*regexp.Regexp + DisableGalleryEndpoint bool + ModelLibraryURL string Galleries []Gallery @@ -57,8 +63,6 @@ type ApplicationConfig struct { ModelsURL []string WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration - - DisableGalleryEndpoint bool } type AppOption func(*ApplicationConfig) @@ -327,6 +331,32 @@ func WithOpaqueErrors(opaque bool) AppOption { } } +func WithSubtleKeyComparison(subtle bool) AppOption { + return func(o *ApplicationConfig) { + o.UseSubtleKeyComparison = subtle + } +} + +func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption { + return func(o *ApplicationConfig) { + o.DisableApiKeyRequirementForHttpGet = required + } +} + +func WithHttpGetExemptedEndpoints(endpoints []string) AppOption { + return func(o *ApplicationConfig) { + o.HttpGetExemptedEndpoints = []*regexp.Regexp{} + for _, epr := range endpoints { + r, err := regexp.Compile(epr) + if err == nil && r != nil { + o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r) + } else { + log.Warn().Err(err).Str("regex", epr).Msg("Error while compiling HTTP Get Exemption regex, skipping this entry.") + } + } + } +} + // ToConfigLoaderOptions returns a slice of ConfigLoader Option. // Some options defined at the application level are going to be passed as defaults for // all the configuration for the models. diff --git a/core/http/app.go b/core/http/app.go index 6eb9c956351..fa9cd866926 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -3,13 +3,15 @@ package http import ( "embed" "errors" + "fmt" "net/http" - "strings" + "github.com/dave-gray101/v2keyauth" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/openai" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/config" @@ -137,37 +139,14 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi }) } - // Auth middleware checking if API key is valid. If no API key is set, no auth is required. - auth := func(c *fiber.Ctx) error { - if len(appConfig.ApiKeys) == 0 { - return c.Next() - } - - if len(appConfig.ApiKeys) == 0 { - return c.Next() - } - - authHeader := readAuthHeader(c) - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - - // If it's a bearer token - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) - } - - apiKey := authHeaderParts[1] - for _, key := range appConfig.ApiKeys { - if apiKey == key { - return c.Next() - } - } - - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + kaConfig, err := middleware.GetKeyAuthConfig(appConfig) + if err != nil || kaConfig == nil { + return nil, fmt.Errorf("failed to create key auth config: %w", err) } + // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration + app.Use(v2keyauth.New(*kaConfig)) + if appConfig.CORS { var c func(ctx *fiber.Ctx) error if appConfig.CORSAllowOrigins == "" { @@ -192,13 +171,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi galleryService := services.NewGalleryService(appConfig) galleryService.Start(appConfig.Context, cl) - routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth) - routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth) - routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth) + routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig) + routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService) + routes.RegisterOpenAIRoutes(app, cl, ml, appConfig) if !appConfig.DisableWebUI { - routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth) + routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService) } - routes.RegisterJINARoutes(app, cl, ml, appConfig, auth) + routes.RegisterJINARoutes(app, cl, ml, appConfig) httpFS := http.FS(embedDirStatic) diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go new file mode 100644 index 00000000000..bc8bcf80c20 --- /dev/null +++ b/core/http/middleware/auth.go @@ -0,0 +1,93 @@ +package middleware + +import ( + "crypto/subtle" + "errors" + + "github.com/dave-gray101/v2keyauth" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/keyauth" + "github.com/mudler/LocalAI/core/config" +) + +// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware +// Currently this requires an upstream patch - and feature patches are no longer accepted to v2 +// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate. + +func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) { + customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key"}, keyauth.ConfigDefault.AuthScheme) + if err != nil { + return nil, err + } + + return &v2keyauth.Config{ + CustomKeyLookup: customLookup, + Next: getApiKeyRequiredFilterFunction(applicationConfig), + Validator: getApiKeyValidationFunction(applicationConfig), + ErrorHandler: getApiKeyErrorHandler(applicationConfig), + AuthScheme: "Bearer", + }, nil +} + +func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler { + return func(ctx *fiber.Ctx, err error) error { + if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) { + if len(applicationConfig.ApiKeys) == 0 { + return ctx.Next() // if no keys are set up, any error we get here is not an error. + } + if applicationConfig.OpaqueErrors { + return ctx.SendStatus(403) + } + } + if applicationConfig.OpaqueErrors { + return ctx.SendStatus(500) + } + return err + } +} + +func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) { + + if applicationConfig.UseSubtleKeyComparison { + return func(ctx *fiber.Ctx, apiKey string) (bool, error) { + if len(applicationConfig.ApiKeys) == 0 { + return true, nil // If no keys are setup, accept everything + } + for _, validKey := range applicationConfig.ApiKeys { + if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 { + return true, nil + } + } + return false, v2keyauth.ErrMissingOrMalformedAPIKey + } + } + + return func(ctx *fiber.Ctx, apiKey string) (bool, error) { + if len(applicationConfig.ApiKeys) == 0 { + return true, nil // If no keys are setup, accept everything + } + for _, validKey := range applicationConfig.ApiKeys { + if apiKey == validKey { + return true, nil + } + } + return false, v2keyauth.ErrMissingOrMalformedAPIKey + } +} + +func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool { + if applicationConfig.DisableApiKeyRequirementForHttpGet { + return func(c *fiber.Ctx) bool { + if c.Method() != "GET" { + return false + } + for _, rx := range applicationConfig.HttpGetExemptedEndpoints { + if rx.MatchString(c.Path()) { + return true + } + } + return false + } + } + return func(c *fiber.Ctx) bool { return false } +} \ No newline at end of file diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go index b20dec75240..73387c7bb76 100644 --- a/core/http/routes/elevenlabs.go +++ b/core/http/routes/elevenlabs.go @@ -10,12 +10,11 @@ import ( func RegisterElevenLabsRoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, - appConfig *config.ApplicationConfig, - auth func(*fiber.Ctx) error) { + appConfig *config.ApplicationConfig) { // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig)) - app.Post("/v1/sound-generation", auth, elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) + app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go index 92f29224b0e..93125e6cb91 100644 --- a/core/http/routes/jina.go +++ b/core/http/routes/jina.go @@ -11,8 +11,7 @@ import ( func RegisterJINARoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, - appConfig *config.ApplicationConfig, - auth func(*fiber.Ctx) error) { + appConfig *config.ApplicationConfig) { // POST endpoint to mimic the reranking app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig)) diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index f85fa8076ee..29fef37876d 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -15,33 +15,32 @@ func RegisterLocalAIRoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, - galleryService *services.GalleryService, - auth func(*fiber.Ctx) error) { + galleryService *services.GalleryService) { app.Get("/swagger/*", swagger.HandlerDefault) // default // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) - app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) - app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint()) + app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) - app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint()) - app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint()) + app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint()) + app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint()) + app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint()) + app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint()) + app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint()) + app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) } - app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) // Stores sl := model.NewModelLoader("") - app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig)) - app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig)) - app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig)) - app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig)) + app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) + app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) + app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) + app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) // Kubernetes health checks ok := func(c *fiber.Ctx) error { @@ -51,20 +50,20 @@ func RegisterLocalAIRoutes(app *fiber.App, app.Get("/healthz", ok) app.Get("/readyz", ok) - app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) + app.Get("/metrics", localai.LocalAIMetricsEndpoint()) // Experimental Backend Statistics Module backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now - app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService)) - app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService)) + app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) + app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) // p2p if p2p.IsP2PEnabled() { - app.Get("/api/p2p", auth, localai.ShowP2PNodes(appConfig)) - app.Get("/api/p2p/token", auth, localai.ShowP2PToken(appConfig)) + app.Get("/api/p2p", localai.ShowP2PNodes(appConfig)) + app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig)) } - app.Get("/version", auth, func(c *fiber.Ctx) error { + app.Get("/version", func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index e190bc6d352..081daf70d80 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -11,66 +11,65 @@ import ( func RegisterOpenAIRoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, - appConfig *config.ApplicationConfig, - auth func(*fiber.Ctx) error) { + appConfig *config.ApplicationConfig) { // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) - app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig)) // assistant - app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) - app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) // files - app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) - app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) - app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) - app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) - app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) - app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) - app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) - app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig)) + app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig)) + app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/files", openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) + app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig)) + app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) + app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig)) // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig)) - app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig)) // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig)) + app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig)) if appConfig.ImageDir != "" { app.Static("/generated-images", appConfig.ImageDir) @@ -81,6 +80,6 @@ func RegisterOpenAIRoutes(app *fiber.App, } // List models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml)) + app.Get("/models", openai.ListModelsEndpoint(cl, ml)) } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index 6dfb3f433df..7b2c6ae70bc 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -59,8 +59,7 @@ func RegisterUIRoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, - galleryService *services.GalleryService, - auth func(*fiber.Ctx) error) { + galleryService *services.GalleryService) { // keeps the state of models that are being installed from the UI var processingModels = NewModelOpCache() @@ -85,10 +84,10 @@ func RegisterUIRoutes(app *fiber.App, return processingModelsData, taskTypes } - app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus)) + app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus)) if p2p.IsP2PEnabled() { - app.Get("/p2p", auth, func(c *fiber.Ctx) error { + app.Get("/p2p", func(c *fiber.Ctx) error { summary := fiber.Map{ "Title": "LocalAI - P2P dashboard", "Version": internal.PrintableVersion(), @@ -104,17 +103,17 @@ func RegisterUIRoutes(app *fiber.App, }) /* show nodes live! */ - app.Get("/p2p/ui/workers", auth, func(c *fiber.Ctx) error { + app.Get("/p2p/ui/workers", func(c *fiber.Ctx) error { return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)))) }) - app.Get("/p2p/ui/workers-federation", auth, func(c *fiber.Ctx) error { + app.Get("/p2p/ui/workers-federation", func(c *fiber.Ctx) error { return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)))) }) - app.Get("/p2p/ui/workers-stats", auth, func(c *fiber.Ctx) error { + app.Get("/p2p/ui/workers-stats", func(c *fiber.Ctx) error { return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)))) }) - app.Get("/p2p/ui/workers-federation-stats", auth, func(c *fiber.Ctx) error { + app.Get("/p2p/ui/workers-federation-stats", func(c *fiber.Ctx) error { return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)))) }) } @@ -122,7 +121,7 @@ func RegisterUIRoutes(app *fiber.App, if !appConfig.DisableGalleryEndpoint { // Show the Models page (all models) - app.Get("/browse", auth, func(c *fiber.Ctx) error { + app.Get("/browse", func(c *fiber.Ctx) error { term := c.Query("term") models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) @@ -167,7 +166,7 @@ func RegisterUIRoutes(app *fiber.App, // Show the models, filtered from the user input // https://htmx.org/examples/active-search/ - app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error { + app.Post("/browse/search/models", func(c *fiber.Ctx) error { form := struct { Search string `form:"search"` }{} @@ -188,7 +187,7 @@ func RegisterUIRoutes(app *fiber.App, // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service // https://htmx.org/examples/progress-bar/ - app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error { + app.Post("/browse/install/model/:id", func(c *fiber.Ctx) error { galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests! log.Debug().Msgf("UI job submitted to install : %+v\n", galleryID) @@ -215,7 +214,7 @@ func RegisterUIRoutes(app *fiber.App, // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service // https://htmx.org/examples/progress-bar/ - app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error { + app.Post("/browse/delete/model/:id", func(c *fiber.Ctx) error { galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests! log.Debug().Msgf("UI job submitted to delete : %+v\n", galleryID) var galleryName = galleryID @@ -255,7 +254,7 @@ func RegisterUIRoutes(app *fiber.App, // Display the job current progress status // If the job is done, we trigger the /browse/job/:uid route // https://htmx.org/examples/progress-bar/ - app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error { + app.Get("/browse/job/progress/:uid", func(c *fiber.Ctx) error { jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests! status := galleryService.GetStatus(jobUID) @@ -279,7 +278,7 @@ func RegisterUIRoutes(app *fiber.App, // this route is hit when the job is done, and we display the // final state (for now just displays "Installation completed") - app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error { + app.Get("/browse/job/:uid", func(c *fiber.Ctx) error { jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests! status := galleryService.GetStatus(jobUID) @@ -303,7 +302,7 @@ func RegisterUIRoutes(app *fiber.App, } // Show the Chat page - app.Get("/chat/:model", auth, func(c *fiber.Ctx) error { + app.Get("/chat/:model", func(c *fiber.Ctx) error { backendConfigs, _ := services.ListModels(cl, ml, "", true) summary := fiber.Map{ @@ -318,7 +317,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/chat", summary) }) - app.Get("/talk/", auth, func(c *fiber.Ctx) error { + app.Get("/talk/", func(c *fiber.Ctx) error { backendConfigs, _ := services.ListModels(cl, ml, "", true) if len(backendConfigs) == 0 { @@ -338,7 +337,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/talk", summary) }) - app.Get("/chat/", auth, func(c *fiber.Ctx) error { + app.Get("/chat/", func(c *fiber.Ctx) error { backendConfigs, _ := services.ListModels(cl, ml, "", true) @@ -359,7 +358,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/chat", summary) }) - app.Get("/text2image/:model", auth, func(c *fiber.Ctx) error { + app.Get("/text2image/:model", func(c *fiber.Ctx) error { backendConfigs := cl.GetAllBackendConfigs() summary := fiber.Map{ @@ -374,7 +373,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/text2image", summary) }) - app.Get("/text2image/", auth, func(c *fiber.Ctx) error { + app.Get("/text2image/", func(c *fiber.Ctx) error { backendConfigs := cl.GetAllBackendConfigs() @@ -395,7 +394,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/text2image", summary) }) - app.Get("/tts/:model", auth, func(c *fiber.Ctx) error { + app.Get("/tts/:model", func(c *fiber.Ctx) error { backendConfigs := cl.GetAllBackendConfigs() summary := fiber.Map{ @@ -410,7 +409,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/tts", summary) }) - app.Get("/tts/", auth, func(c *fiber.Ctx) error { + app.Get("/tts/", func(c *fiber.Ctx) error { backendConfigs := cl.GetAllBackendConfigs() diff --git a/go.mod b/go.mod index 57202ad2990..a3359abf661 100644 --- a/go.mod +++ b/go.mod @@ -74,6 +74,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect + github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect diff --git a/go.sum b/go.sum index ab64b84a441..1dd44a5b2ed 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/creachadair/otp v0.4.2 h1:ngNMaD6Tzd7UUNRFyed7ykZFn/Wr5sSs5ffqZWm9pu8 github.com/creachadair/otp v0.4.2/go.mod h1:DqV9hJyUbcUme0pooYfiFvvMe72Aua5sfhNzwfZvk40= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0= +github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=