From a4ed802c82abb0490b305b6b4647335a8bc229fb Mon Sep 17 00:00:00 2001 From: l0calh0st8080 <142974401+l0calh0st8080@users.noreply.github.com> Date: Wed, 20 Sep 2023 00:15:19 +0530 Subject: [PATCH 1/3] [wip] feat: keymgmt --- Makefile | 2 +- auth/handlers/auth.go | 50 ++- auth/handlers/keys.go | 160 +++++++++ auth/main.go | 10 +- auth/model/keys.go | 8 + auth/model/setting.go | 82 +++++ auth/pkg/constants/aes.go | 5 + auth/pkg/constants/providerkeys.go | 12 + auth/pkg/db/db.go | 7 + auth/pkg/db/postgres/providerkeys.go | 32 ++ auth/pkg/db/postgres/setting.go | 17 + auth/pkg/providerkeys/openai/openai.go | 197 +++++++++++ auth/pkg/providerkeys/openai/types.go | 17 + auth/pkg/providerkeys/provider.go | 28 ++ auth/routes/routes.go | 4 + numexa-common/encryption/aes.go | 74 ++++ .../postgresql/0004_key_tables.down.sql | 11 + .../postgresql/0004_key_tables.up.sql | 49 +++ .../postgresql/postgresql-db/models.go | 60 +++- .../postgresql/postgresql-db/queries.sql.go | 320 +++++++++++++++++- numexa-common/postgresql/queries.sql | 42 +++ 21 files changed, 1171 insertions(+), 16 deletions(-) create mode 100644 auth/handlers/keys.go create mode 100644 auth/model/keys.go create mode 100644 auth/model/setting.go create mode 100644 auth/pkg/constants/aes.go create mode 100644 auth/pkg/constants/providerkeys.go create mode 100644 auth/pkg/db/postgres/providerkeys.go create mode 100644 auth/pkg/db/postgres/setting.go create mode 100644 auth/pkg/providerkeys/openai/openai.go create mode 100644 auth/pkg/providerkeys/openai/types.go create mode 100644 auth/pkg/providerkeys/provider.go create mode 100644 numexa-common/encryption/aes.go create mode 100644 numexa-common/postgresql/0004_key_tables.down.sql create mode 100644 numexa-common/postgresql/0004_key_tables.up.sql diff --git a/Makefile b/Makefile index 1e482ba..0bf6f0c 100644 --- a/Makefile +++ b/Makefile @@ -13,5 +13,5 @@ up: docker-compose up -d all: auth monger vibe -.PHONY: auth monger vibe up +.PHONY: all auth monger vibe up diff --git a/auth/handlers/auth.go b/auth/handlers/auth.go index 33dd128..04adb41 100644 --- a/auth/handlers/auth.go +++ b/auth/handlers/auth.go @@ -325,6 +325,54 @@ func generateJWTToken(user postgresql_db.User, jwtSigningKey string) (string, er return tokenString, err } +// SHOULD ONLY BE USED FOR TESTING +// DONOT USE IN PRODUCTION +func (h *Handler) DummyAuthMiddleware(c *fiber.Ctx) error { + tokenString := c.Get("Authorization") + + // Remove the "Bearer " prefix from the token string + tokenString = strings.TrimPrefix(tokenString, "Bearer ") + + // Parse the token + token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Check the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(h.JWTSigningKey), nil + }) + + // if err != nil { + // return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + // "message": "Unauthorized", + // }) + // } + + // Get the user ID from the token's claims + var userID float64 + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if id, exists := claims["user_id"]; exists { + userID = id.(float64) + c.Locals("user_id", userID) // Set the user ID in locals for other handlers to access + } + if email, exists := claims["email"]; exists { + c.Locals("user_email", email) // Set the user ID in locals for other handlers to access + } + if name, exists := claims["name"]; exists { + c.Locals("name", name) // Set the user ID in locals for other handlers to access + } + if organizationID, exists := claims["organization_id"]; exists { + c.Locals("organization_id", organizationID) // Set the user ID in locals for other handlers to access + } + + } + + // Check if the token is still valid (not invalidated by logout) + + // Token is valid, proceed to the next handler + return c.Next() +} + func (h *Handler) AuthMiddleware(c *fiber.Ctx) error { tokenString := c.Get("Authorization") @@ -357,7 +405,7 @@ func (h *Handler) AuthMiddleware(c *fiber.Ctx) error { c.Locals("user_email", email) // Set the user ID in locals for other handlers to access } if name, exists := claims["name"]; exists { - c.Locals("user_name", name) // Set the user ID in locals for other handlers to access + c.Locals("name", name) // Set the user ID in locals for other handlers to access } if organizationID, exists := claims["organization_id"]; exists { c.Locals("organization_id", organizationID) // Set the user ID in locals for other handlers to access diff --git a/auth/handlers/keys.go b/auth/handlers/keys.go new file mode 100644 index 0000000..ffe4d3d --- /dev/null +++ b/auth/handlers/keys.go @@ -0,0 +1,160 @@ +package handlers + +import ( + "encoding/json" + "strconv" + + "github.com/NumexaHQ/captainCache/model" + "github.com/NumexaHQ/captainCache/pkg/providerkeys" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/sirupsen/logrus" +) + +func (h *Handler) AddProviderKeys(c *fiber.Ctx) error { + var reqBody model.ProviderKeys + if err := c.BodyParser(&reqBody); err != nil { + logrus.WithError(err).Error("Error parsing request body") + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "message": "Invalid request body", + }) + } + + userId := c.Locals("user_id").(float64) + orgId := c.Locals("organization_id").(float64) + + rByte, err := json.Marshal(reqBody) + if err != nil { + logrus.WithError(err).Error("Error marshalling request body") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + keyProvider, err := providerkeys.GetProvider(reqBody.Provider, rByte, false) + if err != nil { + logrus.WithError(err).Error("Error getting provider") + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "message": "Invalid provider", + }) + } + + if keyProvider.KeyExists(c.Context(), h.DB, reqBody.Name) { + logrus.WithError(err).Error("Key already exists") + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "message": "Key already exists with this name", + }) + } + + // generate uuid for key + keyuuid, err := generateUUID() + if err != nil { + logrus.WithError(err).Error("Error generating uuid") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + err = keyProvider.PushKeysToDB(c.Context(), h.DB, reqBody.Name, keyuuid, int32(userId), reqBody.ProjectId, int32(orgId)) + if err != nil { + logrus.WithError(err).Error("Error pushing keys to db") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "message": "Keys added successfully", + }) +} + +func (h *Handler) GetProviderKeys(c *fiber.Ctx) error { + // userId := c.Locals("user_id").(float64) + // orgId := c.Locals("organization_id").(float64) + projectId := c.Params("project_id") + projectIdInt, err := strconv.Atoi(projectId) + if err != nil { + logrus.WithError(err).Error("Error converting project id to int") + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "message": "Invalid project id", + }) + } + + keys, err := h.DB.GetProviderKeysByProjectId(c.Context(), int32(projectIdInt)) + if err != nil { + logrus.WithError(err).Error("Error getting provider keys") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + resp := []model.ProviderKeys{} + + for _, key := range keys { + secrets, err := h.DB.GetProviderSecretByProviderId(c.Context(), key.ID) + if err != nil { + if err.Error() == "no rows in result set" { + logrus.WithError(err).Error("Error getting provider secret") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + logrus.WithError(err).Error("Error getting provider secret") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + secretsMap := make(map[string]string) + for _, sv := range secrets { + secretsMap[sv.Type] = sv.Key + } + + // here kp.Keys is encrypted keys + kp := model.ProviderKeys{ + Name: key.Name, + Provider: key.Provider, + Keys: secretsMap, + ProjectId: key.ProjectID, + } + + kpB, err := json.Marshal(kp) + if err != nil { + logrus.WithError(err).Error("Error marshalling provider keys") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + provider, err := providerkeys.GetProvider(key.Provider, kpB, true) + if err != nil { + logrus.WithError(err).Error("Error getting provider") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + decryptedKeys, err := provider.GetDecryptedKeys(c.Context(), h.DB) + if err != nil { + logrus.WithError(err).Error("Error getting decrypted keys") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + // updating kp.Keys with decrypted keys + kp.Keys = decryptedKeys + + resp = append(resp, kp) + } + + return c.Status(fiber.StatusOK).JSON(resp) +} + +func generateUUID() (string, error) { + uuid, err := uuid.NewRandom() + if err != nil { + logrus.WithError(err).Error("Error generating uuid") + return "", err + } + return uuid.String(), nil +} diff --git a/auth/main.go b/auth/main.go index d4c6938..6a99dbb 100644 --- a/auth/main.go +++ b/auth/main.go @@ -1,8 +1,10 @@ package main import ( + "context" "os" + "github.com/NumexaHQ/captainCache/model" commonConstants "github.com/NumexaHQ/captainCache/numexa-common/constants" nxdb "github.com/NumexaHQ/captainCache/pkg/db" "github.com/NumexaHQ/captainCache/routes" @@ -36,7 +38,13 @@ func main() { err := db.Init() if err != nil { - log.Fatal("Failed to initialize database") + log.Fatal("Failed to initialize database: ", err) + } + + // init AES setting + err = model.InitializeAESSetting(context.Background(), db) + if err != nil { + log.Fatal("Failed to initialize AES setting: ", err) } // Create a new Fiber app diff --git a/auth/model/keys.go b/auth/model/keys.go new file mode 100644 index 0000000..bc651de --- /dev/null +++ b/auth/model/keys.go @@ -0,0 +1,8 @@ +package model + +type ProviderKeys struct { + Name string `json:"name"` + Provider string `json:"provider" validate:"required" enum:"openai"` + Keys map[string]string `json:"keys"` + ProjectId int32 `json:"project_id"` +} diff --git a/auth/model/setting.go b/auth/model/setting.go new file mode 100644 index 0000000..b0af7b8 --- /dev/null +++ b/auth/model/setting.go @@ -0,0 +1,82 @@ +package model + +import ( + "context" + "crypto/aes" + "crypto/rand" + "encoding/hex" + "encoding/json" + + postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" + "github.com/NumexaHQ/captainCache/pkg/constants" + nxDB "github.com/NumexaHQ/captainCache/pkg/db" +) + +type SettingValue struct { + Label string `json:"label"` + Value interface{} `json:"value"` + Description string `json:"description"` +} + +func InitializeAESSetting(ctx context.Context, db nxDB.DB) error { + // set aes_secret in setting table, if !exists + _, err := db.GetSetting(ctx, constants.AES_SECRET) + if err != nil { + key := make([]byte, 32) // 32 bytes for AES-256 + iv := make([]byte, aes.BlockSize) + _, err = rand.Read(key) + if err != nil { + return err + } + + _, err = rand.Read(iv) + if err != nil { + return err + } + + aesValue := &SettingValue{ + Label: "AES Encryption Setting", + Description: "AES Encryption Key-IV pair", + Value: map[string]string{ + "aes_iv": hex.EncodeToString(iv), + "aes_key": hex.EncodeToString(key), + }, + } + + rawAES, err := json.Marshal(aesValue) + if err != nil { + return err + } + rawMessageAES := json.RawMessage(rawAES) + + _, err = db.CreateSetting(ctx, postgresql_db.CreateSettingParams{ + Key: constants.AES_SECRET, + Value: rawMessageAES, + Visible: false, + }) + if err != nil { + return err + } + } + return nil +} + +func GetAESSettingValue(ctx context.Context, db nxDB.DB) (json.RawMessage, error) { + setting, err := db.GetSetting(ctx, constants.AES_SECRET) + if err != nil { + return nil, err + } + + var aesValue SettingValue + err = json.Unmarshal(setting.Value, &aesValue) + if err != nil { + return nil, err + } + + b, err := json.Marshal(aesValue.Value) + if err != nil { + return nil, err + } + + return json.RawMessage(b), nil +} diff --git a/auth/pkg/constants/aes.go b/auth/pkg/constants/aes.go new file mode 100644 index 0000000..cf3c9c4 --- /dev/null +++ b/auth/pkg/constants/aes.go @@ -0,0 +1,5 @@ +package constants + +const ( + AES_SECRET = "aes_secret" +) diff --git a/auth/pkg/constants/providerkeys.go b/auth/pkg/constants/providerkeys.go new file mode 100644 index 0000000..3a6a919 --- /dev/null +++ b/auth/pkg/constants/providerkeys.go @@ -0,0 +1,12 @@ +package constants + +// providers +const ( + PROVIDER_OPENAI = "openai" +) + +// keys +const ( + KEY_OPENAI_ORG = "openai_org" + KEY_OPENAI_KEY = "openai_key" +) diff --git a/auth/pkg/db/db.go b/auth/pkg/db/db.go index 65dffc3..6a3619a 100644 --- a/auth/pkg/db/db.go +++ b/auth/pkg/db/db.go @@ -35,4 +35,11 @@ type DB interface { GetAllApiKeysByUserId(ctx context.Context, userID int32) ([]postgresql_db.GetAllApiKeysByUserIdRow, error) GetProjectsByOrgId(ctx context.Context, orgID int32) ([]postgresql_db.Project, error) UpdateUserLastLogin(ctx context.Context, user postgresql_db.User) error + AddProviderKeys(ctx context.Context, pk postgresql_db.CreateProviderKeyParams) (postgresql_db.ProviderKey, error) + AddProviderSecrets(ctx context.Context, ps postgresql_db.CreateProviderSecretParams) (postgresql_db.ProviderSecret, error) + GetProviderSecretByProviderId(ctx context.Context, id int32) ([]postgresql_db.ProviderSecret, error) + GetProviderKeyByName(ctx context.Context, name string) (postgresql_db.ProviderKey, error) + GetProviderKeysByProjectId(ctx context.Context, projectID int32) ([]postgresql_db.ProviderKey, error) + CreateSetting(ctx context.Context, setting postgresql_db.CreateSettingParams) (postgresql_db.Setting, error) + GetSetting(ctx context.Context, key string) (postgresql_db.Setting, error) } diff --git a/auth/pkg/db/postgres/providerkeys.go b/auth/pkg/db/postgres/providerkeys.go new file mode 100644 index 0000000..6321ba5 --- /dev/null +++ b/auth/pkg/db/postgres/providerkeys.go @@ -0,0 +1,32 @@ +package postgres + +import ( + "context" + + postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" +) + +func (p *Postgres) AddProviderKeys(ctx context.Context, pk postgresql_db.CreateProviderKeyParams) (postgresql_db.ProviderKey, error) { + queries := getPostgresQueries(p.db) + return queries.CreateProviderKey(ctx, pk) +} + +func (p *Postgres) AddProviderSecrets(ctx context.Context, ps postgresql_db.CreateProviderSecretParams) (postgresql_db.ProviderSecret, error) { + queries := getPostgresQueries(p.db) + return queries.CreateProviderSecret(ctx, ps) +} + +func (p *Postgres) GetProviderKeyByName(ctx context.Context, name string) (postgresql_db.ProviderKey, error) { + queries := getPostgresQueries(p.db) + return queries.GetProviderKeyByName(ctx, name) +} + +func (p *Postgres) GetProviderKeysByProjectId(ctx context.Context, projectID int32) ([]postgresql_db.ProviderKey, error) { + queries := getPostgresQueries(p.db) + return queries.GetProviderKeysByProjectID(ctx, projectID) +} + +func (p *Postgres) GetProviderSecretByProviderId(ctx context.Context, id int32) ([]postgresql_db.ProviderSecret, error) { + queries := getPostgresQueries(p.db) + return queries.GetProviderSecretsByProviderKeyID(ctx, id) +} diff --git a/auth/pkg/db/postgres/setting.go b/auth/pkg/db/postgres/setting.go new file mode 100644 index 0000000..3d87341 --- /dev/null +++ b/auth/pkg/db/postgres/setting.go @@ -0,0 +1,17 @@ +package postgres + +import ( + "context" + + postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" +) + +func (p *Postgres) CreateSetting(ctx context.Context, setting postgresql_db.CreateSettingParams) (postgresql_db.Setting, error) { + queries := getPostgresQueries(p.db) + return queries.CreateSetting(ctx, setting) +} + +func (p *Postgres) GetSetting(ctx context.Context, key string) (postgresql_db.Setting, error) { + queries := getPostgresQueries(p.db) + return queries.GetSetting(ctx, key) +} diff --git a/auth/pkg/providerkeys/openai/openai.go b/auth/pkg/providerkeys/openai/openai.go new file mode 100644 index 0000000..8b43ec5 --- /dev/null +++ b/auth/pkg/providerkeys/openai/openai.go @@ -0,0 +1,197 @@ +package openai + +import ( + "context" + "encoding/json" + "time" + + "github.com/NumexaHQ/captainCache/model" + "github.com/NumexaHQ/captainCache/numexa-common/encryption" + postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" + "github.com/NumexaHQ/captainCache/pkg/constants" + nxDB "github.com/NumexaHQ/captainCache/pkg/db" + "github.com/sirupsen/logrus" +) + +// todo: call validate on the keys +// Be careful passing this struct around, it contains sensitive, unencrypted information +func New(b []byte, isEncrypted bool) (*ProviderOpenAI, error) { + var payload Payload + err := json.Unmarshal(b, &payload) + if err != nil { + return nil, err + } + return &ProviderOpenAI{ + Payload: payload, + encrypted: isEncrypted, + }, nil +} + +func (o *ProviderOpenAI) EncryptKeys(ctx context.Context, db nxDB.DB) error { + // get the aes key from the setting table + aesValue, err := model.GetAESSettingValue(ctx, db) + if err != nil { + return err + } + + aes := encryption.AES{} + err = json.Unmarshal(aesValue, &aes) + if err != nil { + return err + } + + // encrypt the keys + // and set the encrypted flag to true + o.Payload.Keys.OpenAIOrg, err = aes.Encrypt(o.Payload.Keys.OpenAIOrg) + if err != nil { + return err + } + o.Payload.Keys.OpenAIKey, err = aes.Encrypt(o.Payload.Keys.OpenAIKey) + if err != nil { + return err + } + o.encrypted = true + + return nil + +} + +func (o *ProviderOpenAI) IsEncrypted() bool { + return o.encrypted +} + +func (o *ProviderOpenAI) GetEncryptedKeys(ctx context.Context, db nxDB.DB) (map[string]string, error) { + if !o.encrypted { + err := o.EncryptKeys(ctx, db) + if err != nil { + logrus.WithError(err).Error("error encrypting keys") + return nil, err + } + } + return map[string]string{ + constants.KEY_OPENAI_ORG: o.Payload.Keys.OpenAIOrg, + constants.KEY_OPENAI_KEY: o.Payload.Keys.OpenAIKey, + }, nil + +} + +func (o *ProviderOpenAI) GetDecryptedKeys(ctx context.Context, db nxDB.DB) (map[string]string, error) { + if o.encrypted { + err := o.DecryptKeys(ctx, db, o.Payload.Name) + if err != nil { + logrus.WithError(err).Error("error decrypting keys") + return nil, err + } + } + return map[string]string{ + constants.KEY_OPENAI_ORG: o.Payload.Keys.OpenAIOrg, + constants.KEY_OPENAI_KEY: o.Payload.Keys.OpenAIKey, + }, nil +} + +func (o *ProviderOpenAI) DecryptKeys(ctx context.Context, db nxDB.DB, name string) error { + // get the aes key from the setting table + // get the aes key from the setting table + aesValue, err := model.GetAESSettingValue(ctx, db) + if err != nil { + return err + } + + aes := encryption.AES{} + err = json.Unmarshal(aesValue, &aes) + if err != nil { + return err + } + + // decrypt the keys + // and set the encrypted flag to false + o.Payload.Keys.OpenAIOrg, err = aes.Decrypt(o.Payload.Keys.OpenAIOrg) + if err != nil { + return err + } + o.Payload.Keys.OpenAIKey, err = aes.Decrypt(o.Payload.Keys.OpenAIKey) + if err != nil { + return err + } + + o.encrypted = false + + return nil +} + +func (o *ProviderOpenAI) GetKeys() map[string]string { + return map[string]string{ + constants.KEY_OPENAI_ORG: o.Payload.Keys.OpenAIOrg, + constants.KEY_OPENAI_KEY: o.Payload.Keys.OpenAIKey, + } +} + +func (o *ProviderOpenAI) PushKeysToDB(ctx context.Context, db nxDB.DB, name, keyuuid string, userId, projectId, orgId int32) error { + // create a entry in the provider_keys table + // with the provider name, user id, project id, and the keys + // return error if any + pk, err := db.AddProviderKeys(ctx, postgresql_db.CreateProviderKeyParams{ + Name: name, + KeyUuid: keyuuid, + Provider: o.GetProviderName(), + CreatorID: userId, + ProjectID: projectId, + OrganizationID: orgId, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + if err != nil { + return err + } + + // create entry in the provider_secrets + // incase of openai, we have two keys + // openai_org and openai_key, so we need to loop through the keys + // and create a entry for each key + keys, err := o.GetEncryptedKeys(ctx, db) + if err != nil { + return err + } + for k, v := range keys { + _, err := db.AddProviderSecrets(ctx, postgresql_db.CreateProviderSecretParams{ + ProviderKeyID: pk.ID, + Key: v, + Type: k, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + if err != nil { + return err + } + } + + return nil +} + +func (o *ProviderOpenAI) GetProviderName() string { + return "openai" +} + +func (o *ProviderOpenAI) KeyExists(ctx context.Context, db nxDB.DB, name string) bool { + _, err := db.GetProviderKeyByName(ctx, name) + if err != nil { + logrus.WithError(err).Error("error getting provider key by name") + return false + } + return true +} + +// example request body + +/* +{ + "name": "openai-prod-key", + "provider": "openai", + "keys": { + "openai_org": "org-5asdf6asdfghasf7as6as8d", + "openai_key": "key-287asfdvuas66hasd6767agsd7aa" + } + +} + +*/ diff --git a/auth/pkg/providerkeys/openai/types.go b/auth/pkg/providerkeys/openai/types.go new file mode 100644 index 0000000..6f25c34 --- /dev/null +++ b/auth/pkg/providerkeys/openai/types.go @@ -0,0 +1,17 @@ +package openai + +type ProviderOpenAI struct { + Payload Payload `json:"payload"` + encrypted bool `json:"encrypted"` +} + +type Payload struct { + Provider string `json:"provider" validate:"required" enum:"openai"` + Keys OpenAIKeys `json:"keys" validate:"required"` + Name string `json:"name" validate:"required"` +} + +type OpenAIKeys struct { + OpenAIOrg string `json:"openai_org" validate:"required"` + OpenAIKey string `json:"openai_key" validate:"required"` +} diff --git a/auth/pkg/providerkeys/provider.go b/auth/pkg/providerkeys/provider.go new file mode 100644 index 0000000..02789f3 --- /dev/null +++ b/auth/pkg/providerkeys/provider.go @@ -0,0 +1,28 @@ +package providerkeys + +import ( + "context" + "fmt" + + "github.com/NumexaHQ/captainCache/pkg/constants" + nxDB "github.com/NumexaHQ/captainCache/pkg/db" + "github.com/NumexaHQ/captainCache/pkg/providerkeys/openai" +) + +func GetProvider(provider string, b []byte, isEncrypted bool) (Provider, error) { + switch provider { + case constants.PROVIDER_OPENAI: + return openai.New(b, isEncrypted) + default: + return nil, fmt.Errorf("invalid provider: %s", provider) + } +} + +type Provider interface { + GetKeys() map[string]string + IsEncrypted() bool + GetEncryptedKeys(ctx context.Context, db nxDB.DB) (map[string]string, error) + GetDecryptedKeys(ctx context.Context, db nxDB.DB) (map[string]string, error) + PushKeysToDB(ctx context.Context, db nxDB.DB, name, keyuuid string, userId, projectId, orgId int32) error + KeyExists(ctx context.Context, db nxDB.DB, name string) bool +} diff --git a/auth/routes/routes.go b/auth/routes/routes.go index c2e9b94..f7028ed 100644 --- a/auth/routes/routes.go +++ b/auth/routes/routes.go @@ -28,4 +28,8 @@ func Setup(app *fiber.App, db db.DB, jwtSigningKey string) { //GenerateApiKey app.Post("/generate_api_key", nxHandler.AuthMiddleware, nxHandler.CreateApiKey) app.Get("/get_api_key", nxHandler.AuthMiddleware, nxHandler.GetAPIkeyByUserId) + + // keymgmt + app.Post("/add_provider_keys", nxHandler.AuthMiddleware, nxHandler.AddProviderKeys) + app.Get("/get_provider_keys/:project_id", nxHandler.AuthMiddleware, nxHandler.GetProviderKeys) } diff --git a/numexa-common/encryption/aes.go b/numexa-common/encryption/aes.go new file mode 100644 index 0000000..2752e4e --- /dev/null +++ b/numexa-common/encryption/aes.go @@ -0,0 +1,74 @@ +package encryption + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/hex" +) + +type AES struct { + IV string `json:"aes_iv"` + Key string `json:"aes_key"` +} + +const blockSize = aes.BlockSize + +func (a *AES) Encrypt(plaintext string) (string, error) { + bKey, err := hex.DecodeString(a.Key) + if err != nil { + return "", err + } + bIV, err := hex.DecodeString(a.IV) + if err != nil { + return "", err + } + bPlaintext := PKCS5Padding([]byte(plaintext), blockSize, len(plaintext)) + block, err := aes.NewCipher(bKey) + if err != nil { + return "", err + } + ciphertext := make([]byte, len(bPlaintext)) + mode := cipher.NewCBCEncrypter(block, bIV) + mode.CryptBlocks(ciphertext, bPlaintext) + return hex.EncodeToString(ciphertext), nil +} + +func (a *AES) Decrypt(cipherText string) (string, error) { + bKey, err := hex.DecodeString(a.Key) + if err != nil { + return "", err + } + bIV, err := hex.DecodeString(a.IV) + if err != nil { + return "", err + } + cipherTextDecoded, err := hex.DecodeString(cipherText) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(bKey) + if err != nil { + return "", err + } + + mode := cipher.NewCBCDecrypter(block, bIV) + mode.CryptBlocks([]byte(cipherTextDecoded), []byte(cipherTextDecoded)) + return string(PKCS5UnPadding(cipherTextDecoded)), nil +} + +func PKCS5Padding(ciphertext []byte, blockSize int, after int) []byte { + padding := (blockSize - len(ciphertext)%blockSize) + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padtext...) +} + +func PKCS5UnPadding(src []byte) []byte { + length := len(src) + unpadding := int(src[length-1]) + if unpadding > length { + return src + } + return src[:(length - unpadding)] +} diff --git a/numexa-common/postgresql/0004_key_tables.down.sql b/numexa-common/postgresql/0004_key_tables.down.sql new file mode 100644 index 0000000..74a5412 --- /dev/null +++ b/numexa-common/postgresql/0004_key_tables.down.sql @@ -0,0 +1,11 @@ +START TRANSACTION; + +DROP TABLE IF EXISTS "public"."provider_keys"; + +DROP TABLE IF EXISTS "public"."provider_secrets"; + +DROP TABLE IF EXISTS "public"."nxa_api_key_property"; + +ALTER TABLE "public"."nxa_api_key" DROP COLUMN "nxa_api_key_property_id"; + +COMMIT; \ No newline at end of file diff --git a/numexa-common/postgresql/0004_key_tables.up.sql b/numexa-common/postgresql/0004_key_tables.up.sql new file mode 100644 index 0000000..3428558 --- /dev/null +++ b/numexa-common/postgresql/0004_key_tables.up.sql @@ -0,0 +1,49 @@ +START TRANSACTION; + +CREATE TABLE IF NOT EXISTS "public"."setting" ( + "id" SERIAL PRIMARY KEY, + "key" VARCHAR(255) NOT NULL UNIQUE, + "value" JSONB NOT NULL, + "visible" BOOLEAN NOT NULL DEFAULT FALSE, + "created_at" TIMESTAMP NOT NULL, + "updated_at" TIMESTAMP NOT NULL +); + + +CREATE TABLE IF NOT EXISTS "public"."provider_keys" ( + "id" SERIAL PRIMARY KEY, + "key_uuid" VARCHAR(255) NOT NULL UNIQUE, + "name" VARCHAR(255) NOT NULL, + "provider" VARCHAR(255) NOT NULL, + "creator_id" INTEGER NOT NULL REFERENCES users(id), + "organization_id" INTEGER NOT NULL REFERENCES organizations(id), + "project_id" INTEGER NOT NULL REFERENCES projects(id), + "created_at" TIMESTAMP NOT NULL, + "updated_at" TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS "public"."provider_secrets" ( + "id" SERIAL PRIMARY KEY, + "provider_key_id" INTEGER NOT NULL REFERENCES provider_keys(id), + "type" VARCHAR(255) NOT NULL, + "key" VARCHAR(255) NOT NULL, + "created_at" TIMESTAMP NOT NULL, + "updated_at" TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS "public"."nxa_api_key_property" ( + "id" SERIAL PRIMARY KEY, + "rate_limit" INTEGER NOT NULL, + "rate_limit_period" VARCHAR(255) NOT NULL, + "enforce_caching" BOOLEAN NOT NULL, + "overall_cost_limit" INTEGER NOT NULL, + "alert_on_threshold" INTEGER NOT NULL, + "provider_key_id" INTEGER NULL REFERENCES provider_keys(id), + "expires_at" TIMESTAMP NOT NULL, + "created_at" TIMESTAMP NOT NULL, + "updated_at" TIMESTAMP NOT NULL +); + +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "nxa_api_key_property_id" INTEGER NULL REFERENCES nxa_api_key_property(id); + +COMMIT; \ No newline at end of file diff --git a/numexa-common/postgresql/postgresql-db/models.go b/numexa-common/postgresql/postgresql-db/models.go index 566a0f5..776aa1c 100644 --- a/numexa-common/postgresql/postgresql-db/models.go +++ b/numexa-common/postgresql/postgresql-db/models.go @@ -6,18 +6,32 @@ package postgresql_db import ( "database/sql" + "encoding/json" "time" ) type NxaApiKey struct { - ID int32 `json:"id"` - Name string `json:"name"` - ApiKey string `json:"api_key"` - UserID int32 `json:"user_id"` - ProjectID int32 `json:"project_id"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - ExpiresAt time.Time `json:"expires_at"` + ID int32 `json:"id"` + Name string `json:"name"` + ApiKey string `json:"api_key"` + UserID int32 `json:"user_id"` + ProjectID int32 `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ExpiresAt time.Time `json:"expires_at"` + NxaApiKeyPropertyID sql.NullInt32 `json:"nxa_api_key_property_id"` +} + +type NxaApiKeyProperty struct { + ID int32 `json:"id"` + RateLimit int32 `json:"rate_limit"` + RateLimitPeriod string `json:"rate_limit_period"` + EnforceCaching bool `json:"enforce_caching"` + OverallCostLimit int32 `json:"overall_cost_limit"` + AlertOnThreshold int32 `json:"alert_on_threshold"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } type Organization struct { @@ -41,6 +55,27 @@ type ProjectUser struct { RoleID int32 `json:"role_id"` } +type ProviderKey struct { + ID int32 `json:"id"` + KeyUuid string `json:"key_uuid"` + Name string `json:"name"` + Provider string `json:"provider"` + CreatorID int32 `json:"creator_id"` + OrganizationID int32 `json:"organization_id"` + ProjectID int32 `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type ProviderSecret struct { + ID int32 `json:"id"` + ProviderKeyID int32 `json:"provider_key_id"` + Type string `json:"type"` + Key string `json:"key"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type Role struct { ID int32 `json:"id"` Name string `json:"name"` @@ -48,6 +83,15 @@ type Role struct { UpdatedAt time.Time `json:"updated_at"` } +type Setting struct { + ID int32 `json:"id"` + Key string `json:"key"` + Value json.RawMessage `json:"value"` + Visible bool `json:"visible"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type User struct { ID int32 `json:"id"` OrganizationID int32 `json:"organization_id"` diff --git a/numexa-common/postgresql/postgresql-db/queries.sql.go b/numexa-common/postgresql/postgresql-db/queries.sql.go index 5292562..62846a7 100644 --- a/numexa-common/postgresql/postgresql-db/queries.sql.go +++ b/numexa-common/postgresql/postgresql-db/queries.sql.go @@ -8,13 +8,14 @@ package postgresql_db import ( "context" "database/sql" + "encoding/json" "time" ) const createApiKey = `-- name: CreateApiKey :one INSERT INTO nxa_api_key (name, api_key, user_id, project_id, created_at, updated_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, name, api_key, user_id, project_id, created_at, updated_at, expires_at +RETURNING id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id ` type CreateApiKeyParams struct { @@ -47,6 +48,7 @@ func (q *Queries) CreateApiKey(ctx context.Context, arg CreateApiKeyParams) (Nxa &i.CreatedAt, &i.UpdatedAt, &i.ExpiresAt, + &i.NxaApiKeyPropertyID, ) return i, err } @@ -123,6 +125,117 @@ func (q *Queries) CreateProjectUser(ctx context.Context, arg CreateProjectUserPa return i, err } +const createProviderKey = `-- name: CreateProviderKey :one +INSERT INTO provider_keys (name, key_uuid, provider, creator_id, organization_id, project_id, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +RETURNING id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at +` + +type CreateProviderKeyParams struct { + Name string `json:"name"` + KeyUuid string `json:"key_uuid"` + Provider string `json:"provider"` + CreatorID int32 `json:"creator_id"` + OrganizationID int32 `json:"organization_id"` + ProjectID int32 `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (q *Queries) CreateProviderKey(ctx context.Context, arg CreateProviderKeyParams) (ProviderKey, error) { + row := q.db.QueryRowContext(ctx, createProviderKey, + arg.Name, + arg.KeyUuid, + arg.Provider, + arg.CreatorID, + arg.OrganizationID, + arg.ProjectID, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i ProviderKey + err := row.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createProviderSecret = `-- name: CreateProviderSecret :one +INSERT INTO provider_secrets (provider_key_id, type, key, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, provider_key_id, type, key, created_at, updated_at +` + +type CreateProviderSecretParams struct { + ProviderKeyID int32 `json:"provider_key_id"` + Type string `json:"type"` + Key string `json:"key"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (q *Queries) CreateProviderSecret(ctx context.Context, arg CreateProviderSecretParams) (ProviderSecret, error) { + row := q.db.QueryRowContext(ctx, createProviderSecret, + arg.ProviderKeyID, + arg.Type, + arg.Key, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i ProviderSecret + err := row.Scan( + &i.ID, + &i.ProviderKeyID, + &i.Type, + &i.Key, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createSetting = `-- name: CreateSetting :one +INSERT INTO setting (key, value, visible, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5) +RETURNING id, key, value, visible, created_at, updated_at +` + +type CreateSettingParams struct { + Key string `json:"key"` + Value json.RawMessage `json:"value"` + Visible bool `json:"visible"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (q *Queries) CreateSetting(ctx context.Context, arg CreateSettingParams) (Setting, error) { + row := q.db.QueryRowContext(ctx, createSetting, + arg.Key, + arg.Value, + arg.Visible, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i Setting + err := row.Scan( + &i.ID, + &i.Key, + &i.Value, + &i.Visible, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const createUser = `-- name: CreateUser :one INSERT INTO users (name, organization_id, email, password, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6) @@ -163,7 +276,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e } const getAPIKeyByNameAndProjectId = `-- name: GetAPIKeyByNameAndProjectId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at FROM nxa_api_key WHERE name = $1 AND project_id = $2 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE name = $1 AND project_id = $2 ` type GetAPIKeyByNameAndProjectIdParams struct { @@ -183,12 +296,13 @@ func (q *Queries) GetAPIKeyByNameAndProjectId(ctx context.Context, arg GetAPIKey &i.CreatedAt, &i.UpdatedAt, &i.ExpiresAt, + &i.NxaApiKeyPropertyID, ) return i, err } const getAPIkeyByApiKey = `-- name: GetAPIkeyByApiKey :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at FROM nxa_api_key WHERE api_key = $1 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE api_key = $1 ` func (q *Queries) GetAPIkeyByApiKey(ctx context.Context, apiKey string) (NxaApiKey, error) { @@ -203,12 +317,13 @@ func (q *Queries) GetAPIkeyByApiKey(ctx context.Context, apiKey string) (NxaApiK &i.CreatedAt, &i.UpdatedAt, &i.ExpiresAt, + &i.NxaApiKeyPropertyID, ) return i, err } const getAPIkeyByUserId = `-- name: GetAPIkeyByUserId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at FROM nxa_api_key WHERE user_id = $1 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE user_id = $1 ` func (q *Queries) GetAPIkeyByUserId(ctx context.Context, userID int32) (NxaApiKey, error) { @@ -223,6 +338,7 @@ func (q *Queries) GetAPIkeyByUserId(ctx context.Context, userID int32) (NxaApiKe &i.CreatedAt, &i.UpdatedAt, &i.ExpiresAt, + &i.NxaApiKeyPropertyID, ) return i, err } @@ -440,8 +556,201 @@ func (q *Queries) GetProjects(ctx context.Context, organizationID int32) ([]Proj return items, nil } +const getProviderKey = `-- name: GetProviderKey :one +SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE id = $1 +` + +func (q *Queries) GetProviderKey(ctx context.Context, id int32) (ProviderKey, error) { + row := q.db.QueryRowContext(ctx, getProviderKey, id) + var i ProviderKey + err := row.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getProviderKeyByName = `-- name: GetProviderKeyByName :one +SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE name = $1 +` + +func (q *Queries) GetProviderKeyByName(ctx context.Context, name string) (ProviderKey, error) { + row := q.db.QueryRowContext(ctx, getProviderKeyByName, name) + var i ProviderKey + err := row.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getProviderKeyByUUID = `-- name: GetProviderKeyByUUID :one +SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE key_uuid = $1 +` + +func (q *Queries) GetProviderKeyByUUID(ctx context.Context, keyUuid string) (ProviderKey, error) { + row := q.db.QueryRowContext(ctx, getProviderKeyByUUID, keyUuid) + var i ProviderKey + err := row.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getProviderKeysByProjectID = `-- name: GetProviderKeysByProjectID :many +SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE project_id = $1 +` + +func (q *Queries) GetProviderKeysByProjectID(ctx context.Context, projectID int32) ([]ProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getProviderKeysByProjectID, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ProviderKey + for rows.Next() { + var i ProviderKey + if err := rows.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getProviderSecret = `-- name: GetProviderSecret :one +SELECT id, provider_key_id, type, key, created_at, updated_at FROM provider_secrets WHERE id = $1 +` + +func (q *Queries) GetProviderSecret(ctx context.Context, id int32) (ProviderSecret, error) { + row := q.db.QueryRowContext(ctx, getProviderSecret, id) + var i ProviderSecret + err := row.Scan( + &i.ID, + &i.ProviderKeyID, + &i.Type, + &i.Key, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getProviderSecretByProviderKeyIDAndType = `-- name: GetProviderSecretByProviderKeyIDAndType :one +SELECT id, provider_key_id, type, key, created_at, updated_at FROM provider_secrets WHERE provider_key_id = $1 AND type = $2 +` + +type GetProviderSecretByProviderKeyIDAndTypeParams struct { + ProviderKeyID int32 `json:"provider_key_id"` + Type string `json:"type"` +} + +func (q *Queries) GetProviderSecretByProviderKeyIDAndType(ctx context.Context, arg GetProviderSecretByProviderKeyIDAndTypeParams) (ProviderSecret, error) { + row := q.db.QueryRowContext(ctx, getProviderSecretByProviderKeyIDAndType, arg.ProviderKeyID, arg.Type) + var i ProviderSecret + err := row.Scan( + &i.ID, + &i.ProviderKeyID, + &i.Type, + &i.Key, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getProviderSecretsByProviderKeyID = `-- name: GetProviderSecretsByProviderKeyID :many +SELECT id, provider_key_id, type, key, created_at, updated_at FROM provider_secrets WHERE provider_key_id = $1 +` + +func (q *Queries) GetProviderSecretsByProviderKeyID(ctx context.Context, providerKeyID int32) ([]ProviderSecret, error) { + rows, err := q.db.QueryContext(ctx, getProviderSecretsByProviderKeyID, providerKeyID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ProviderSecret + for rows.Next() { + var i ProviderSecret + if err := rows.Scan( + &i.ID, + &i.ProviderKeyID, + &i.Type, + &i.Key, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSetting = `-- name: GetSetting :one +SELECT id, key, value, visible, created_at, updated_at FROM setting WHERE key = $1 +` + +func (q *Queries) GetSetting(ctx context.Context, key string) (Setting, error) { + row := q.db.QueryRowContext(ctx, getSetting, key) + var i Setting + err := row.Scan( + &i.ID, + &i.Key, + &i.Value, + &i.Visible, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getTokenByProjectId = `-- name: GetTokenByProjectId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at FROM nxa_api_key WHERE project_id = $1 AND api_key = $2 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE project_id = $1 AND api_key = $2 ` type GetTokenByProjectIdParams struct { @@ -461,6 +770,7 @@ func (q *Queries) GetTokenByProjectId(ctx context.Context, arg GetTokenByProject &i.CreatedAt, &i.UpdatedAt, &i.ExpiresAt, + &i.NxaApiKeyPropertyID, ) return i, err } diff --git a/numexa-common/postgresql/queries.sql b/numexa-common/postgresql/queries.sql index 7cbdd32..b1e966b 100644 --- a/numexa-common/postgresql/queries.sql +++ b/numexa-common/postgresql/queries.sql @@ -79,4 +79,46 @@ SELECT created_at, updated_at, expires_at, name, project_id, user_id FROM nxa_a -- name: UpdateUserLastLogin :one UPDATE users SET last_login = $1, total_logins = total_logins + 1 WHERE id = $2 RETURNING *; +-- name: CreateProviderKey :one +INSERT INTO provider_keys (name, key_uuid, provider, creator_id, organization_id, project_id, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +RETURNING *; + +-- name: CreateProviderSecret :one +INSERT INTO provider_secrets (provider_key_id, type, key, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5) +RETURNING *; + +-- name: GetProviderKey :one +SELECT * FROM provider_keys WHERE id = $1; + +-- name: GetProviderKeyByUUID :one +SELECT * FROM provider_keys WHERE key_uuid = $1; + +-- name: GetProviderKeyByName :one +SELECT * FROM provider_keys WHERE name = $1; + +-- name: GetProviderKeysByProjectID :many +SELECT * FROM provider_keys WHERE project_id = $1; + +-- name: GetProviderSecret :one +SELECT * FROM provider_secrets WHERE id = $1; + +-- name: GetProviderSecretsByProviderKeyID :many +SELECT * FROM provider_secrets WHERE provider_key_id = $1; + +-- name: GetProviderSecretsByProviderKeyID :many +SELECT * FROM provider_secrets WHERE provider_key_id = $1; + +-- name: GetProviderSecretByProviderKeyIDAndType :one +SELECT * FROM provider_secrets WHERE provider_key_id = $1 AND type = $2; + +-- name: CreateSetting :one +INSERT INTO setting (key, value, visible, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5) +RETURNING *; + +-- name: GetSetting :one +SELECT * FROM setting WHERE key = $1; + From 5b12df634918e92411c453022913ab068991c6a4 Mon Sep 17 00:00:00 2001 From: l0calh0st8080 <142974401+l0calh0st8080@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:12:21 +0530 Subject: [PATCH 2/3] wip: add option to associate nxa key with provider key --- auth/handlers/auth.go | 60 ++++++-- auth/model/nxtoken.go | 42 +++++- auth/pkg/db/db.go | 4 +- auth/pkg/db/postgres/nxaapi.go | 12 ++ auth/pkg/db/postgres/postgres.go | 12 +- auth/pkg/db/postgres/providerkeys.go | 5 + .../postgresql/0004_key_tables.up.sql | 5 + .../postgresql/postgresql-db/models.go | 24 ++-- .../postgresql/postgresql-db/queries.sql.go | 132 ++++++++++++++++-- numexa-common/postgresql/queries.sql | 15 +- 10 files changed, 255 insertions(+), 56 deletions(-) create mode 100644 auth/pkg/db/postgres/nxaapi.go diff --git a/auth/handlers/auth.go b/auth/handlers/auth.go index 04adb41..bdd0a1e 100644 --- a/auth/handlers/auth.go +++ b/auth/handlers/auth.go @@ -29,13 +29,8 @@ func generateAPIKey() string { return fmt.Sprintf("sk-%s", string(b)) } -// func hashPassword(password string) string { -// hashPassword := utils.HashPassword(password) -// } - func (h *Handler) CreateApiKey(c *fiber.Ctx) error { - type RequestBody postgresql_db.NxaApiKey - var reqBody RequestBody + var reqBody model.GenerateNXTokenRequest if err := c.BodyParser(&reqBody); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ "message": "Invalid request body", @@ -113,16 +108,61 @@ func (h *Handler) CreateApiKey(c *fiber.Ctx) error { }) } - _, err = h.DB.CreateApiKey(c.Context(), postgresql_db.NxaApiKey{ + nxaAPIKey := postgresql_db.CreateApiKeyParams{ Name: reqBody.Name, ApiKey: apiKey, UserID: user.ID, ProjectID: reqBody.ProjectID, - ExpiresAt: time.Now().Add(time.Hour * 24 * 365), + Revoked: false, + Disabled: false, + ExpiresAt: time.Now().Add(time.Hour * 24 * 365), // this might not be respected, since expiry is set in the key property CreatedAt: time.Now(), UpdatedAt: time.Now(), - }) + } + + if reqBody.NxaProviderKeyID != 0 { + _, err := h.DB.GetProviderKeyById(c.Context(), reqBody.NxaProviderKeyID) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + log.Errorf("error getting provider key by id: %v", err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Error getting provider key by id", + }) + } else { + log.Errorf("provider key not found: %v", err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Something went wrong. Please contact the administrator", + }) + } + } + nxaAPIKey.ProviderKeyID = sql.NullInt32{Int32: reqBody.NxaProviderKeyID, Valid: true} + } + + // check if property exists + if reqBody.Property != (model.NXTokenPropertyRequest{}) { + expiry := time.Now().Add(time.Hour * 24 * 365) // todo: default expiry is 1 year, need to change this to never expire + if reqBody.Property.ExpiresAt.IsZero() { + reqBody.Property.ExpiresAt = expiry + } + nxkp, err := h.DB.CreateNXAKeyProperty(c.Context(), postgresql_db.CreateNXAKeyPropertyParams{ + RateLimit: reqBody.Property.RateLimit, + RateLimitPeriod: reqBody.Property.RateLimitPeriod, + EnforceCaching: reqBody.Property.EnforceCaching, + OverallCostLimit: reqBody.Property.OverallCostLimit, + AlertOnThreshold: reqBody.Property.AlertOnThreshold, + ExpiresAt: expiry, + }) + if err != nil { + log.Errorf("error creating nxa key property: %v", err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Internal server error", + }) + } + + nxaAPIKey.NxaApiKeyPropertyID = sql.NullInt32{Int32: nxkp.ID, Valid: true} + } + _, err = h.DB.CreateApiKey(c.Context(), nxaAPIKey) if err != nil { log.Errorf("error generating api key: %v", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ @@ -177,8 +217,6 @@ func (h *Handler) RegisterHandler(c *fiber.Ctx) error { // create organization organization.Name = utils.GenerateOrganizationName() - - log.Infof("organization: %+v", organization) organization, err = h.DB.CreateOrganization(c.Context(), organization) if err != nil { log.Errorf("error creating organization: %v", err) diff --git a/auth/model/nxtoken.go b/auth/model/nxtoken.go index 99860de..40ac2d0 100644 --- a/auth/model/nxtoken.go +++ b/auth/model/nxtoken.go @@ -1,14 +1,44 @@ package model +import "time" + type GenerateNXTokenRequest struct { - Count int `json:"count"` - Name string `json:"name"` - Description string `json:"description"` - ExpiresAt string `json:"expires_at"` - UserID string `json:"user_id"` - Email string `json:"email"` + Name string `json:"name"` + Description string `json:"description"` + UserID int32 `json:"user_id"` + ProjectID int32 `json:"project_id"` + NxaProviderKeyID int32 `json:"nxa_provider_key_id"` + Property NXTokenPropertyRequest `json:"property"` +} + +type NXTokenPropertyRequest struct { + RateLimit int32 `json:"rate_limit"` + RateLimitPeriod string `json:"rate_limit_period"` + EnforceCaching bool `json:"enforce_caching"` + OverallCostLimit int32 `json:"overall_cost_limit"` + AlertOnThreshold int32 `json:"alert_on_threshold"` + ExpiresAt time.Time `json:"expires_at"` } type GenerateNXTokenResponse struct { Token string `json:"token"` } + +/* +{ + "name": "prod-key", + "description": "production key", + "user_id": "1", + "project_id": "1", + "property": { + "rate_limit": "1000", + "rate_limit_period": "day", + "enforce_caching": "true", + "overall_cost_limit": "1000", + "alert_on_threshold": "1000", + "expires_at": "2021-01-01 00:00:00" + } + +} + +*/ diff --git a/auth/pkg/db/db.go b/auth/pkg/db/db.go index 6a3619a..d2924ec 100644 --- a/auth/pkg/db/db.go +++ b/auth/pkg/db/db.go @@ -26,7 +26,7 @@ type DB interface { GetUsersByProjectId(ctx context.Context, projectID int32) ([]postgresql_db.User, error) GetProject(ctx context.Context, projectID int32) (postgresql_db.Project, error) CreateOrganization(ctx context.Context, organization postgresql_db.Organization) (postgresql_db.Organization, error) - CreateApiKey(ctx context.Context, apiKey postgresql_db.NxaApiKey) (postgresql_db.NxaApiKey, error) + CreateApiKey(ctx context.Context, apiKeyParam postgresql_db.CreateApiKeyParams) (postgresql_db.NxaApiKey, error) GetAPIkeyByUserId(ctx context.Context, userID int32) ([]postgresql_db.NxaApiKey, error) GetUserById(ctx context.Context, id int32) (postgresql_db.User, error) GetAPIkeyByApiKey(ctx context.Context, apiKey string) (postgresql_db.NxaApiKey, error) @@ -42,4 +42,6 @@ type DB interface { GetProviderKeysByProjectId(ctx context.Context, projectID int32) ([]postgresql_db.ProviderKey, error) CreateSetting(ctx context.Context, setting postgresql_db.CreateSettingParams) (postgresql_db.Setting, error) GetSetting(ctx context.Context, key string) (postgresql_db.Setting, error) + GetProviderKeyById(ctx context.Context, id int32) (postgresql_db.ProviderKey, error) + CreateNXAKeyProperty(ctx context.Context, nxaKeyProperty postgresql_db.CreateNXAKeyPropertyParams) (postgresql_db.NxaApiKeyProperty, error) } diff --git a/auth/pkg/db/postgres/nxaapi.go b/auth/pkg/db/postgres/nxaapi.go new file mode 100644 index 0000000..c62d8ad --- /dev/null +++ b/auth/pkg/db/postgres/nxaapi.go @@ -0,0 +1,12 @@ +package postgres + +import ( + "context" + + postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" +) + +func (p *Postgres) CreateNXAKeyProperty(ctx context.Context, nxaKeyProperty postgresql_db.CreateNXAKeyPropertyParams) (postgresql_db.NxaApiKeyProperty, error) { + queries := getPostgresQueries(p.db) + return queries.CreateNXAKeyProperty(ctx, nxaKeyProperty) +} diff --git a/auth/pkg/db/postgres/postgres.go b/auth/pkg/db/postgres/postgres.go index 2c0d174..bbfbf9d 100644 --- a/auth/pkg/db/postgres/postgres.go +++ b/auth/pkg/db/postgres/postgres.go @@ -171,18 +171,10 @@ func (p *Postgres) GetOrganization(ctx context.Context, organizationID int32) (p return organization, nil } -func (p *Postgres) CreateApiKey(ctx context.Context, apiKey postgresql_db.NxaApiKey) (postgresql_db.NxaApiKey, error) { +func (p *Postgres) CreateApiKey(ctx context.Context, apiKeyParam postgresql_db.CreateApiKeyParams) (postgresql_db.NxaApiKey, error) { queries := getPostgresQueries(p.db) - apiKey, err := queries.CreateApiKey(ctx, postgresql_db.CreateApiKeyParams{ - Name: apiKey.Name, - ApiKey: apiKey.ApiKey, - UserID: apiKey.UserID, - ProjectID: apiKey.ProjectID, - ExpiresAt: apiKey.ExpiresAt, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) + apiKey, err := queries.CreateApiKey(ctx, apiKeyParam) return apiKey, err } diff --git a/auth/pkg/db/postgres/providerkeys.go b/auth/pkg/db/postgres/providerkeys.go index 6321ba5..de81403 100644 --- a/auth/pkg/db/postgres/providerkeys.go +++ b/auth/pkg/db/postgres/providerkeys.go @@ -30,3 +30,8 @@ func (p *Postgres) GetProviderSecretByProviderId(ctx context.Context, id int32) queries := getPostgresQueries(p.db) return queries.GetProviderSecretsByProviderKeyID(ctx, id) } + +func (p *Postgres) GetProviderKeyById(ctx context.Context, id int32) (postgresql_db.ProviderKey, error) { + queries := getPostgresQueries(p.db) + return queries.GetProviderKeyByID(ctx, id) +} diff --git a/numexa-common/postgresql/0004_key_tables.up.sql b/numexa-common/postgresql/0004_key_tables.up.sql index 3428558..6b1ed03 100644 --- a/numexa-common/postgresql/0004_key_tables.up.sql +++ b/numexa-common/postgresql/0004_key_tables.up.sql @@ -45,5 +45,10 @@ CREATE TABLE IF NOT EXISTS "public"."nxa_api_key_property" ( ); ALTER TABLE "public"."nxa_api_key" ADD COLUMN "nxa_api_key_property_id" INTEGER NULL REFERENCES nxa_api_key_property(id); +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "provider_key_id" INTEGER NULL REFERENCES provider_keys(id); +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "revoked" BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "revoked_at" TIMESTAMP NULL; +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "revoked_by" INTEGER NULL REFERENCES users(id); +ALTER TABLE "public"."nxa_api_key" ADD COLUMN "disabled" BOOLEAN NOT NULL DEFAULT FALSE; COMMIT; \ No newline at end of file diff --git a/numexa-common/postgresql/postgresql-db/models.go b/numexa-common/postgresql/postgresql-db/models.go index 776aa1c..825e15e 100644 --- a/numexa-common/postgresql/postgresql-db/models.go +++ b/numexa-common/postgresql/postgresql-db/models.go @@ -20,18 +20,24 @@ type NxaApiKey struct { UpdatedAt time.Time `json:"updated_at"` ExpiresAt time.Time `json:"expires_at"` NxaApiKeyPropertyID sql.NullInt32 `json:"nxa_api_key_property_id"` + ProviderKeyID sql.NullInt32 `json:"provider_key_id"` + Revoked bool `json:"revoked"` + RevokedAt sql.NullTime `json:"revoked_at"` + RevokedBy sql.NullInt32 `json:"revoked_by"` + Disabled bool `json:"disabled"` } type NxaApiKeyProperty struct { - ID int32 `json:"id"` - RateLimit int32 `json:"rate_limit"` - RateLimitPeriod string `json:"rate_limit_period"` - EnforceCaching bool `json:"enforce_caching"` - OverallCostLimit int32 `json:"overall_cost_limit"` - AlertOnThreshold int32 `json:"alert_on_threshold"` - ExpiresAt time.Time `json:"expires_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int32 `json:"id"` + RateLimit int32 `json:"rate_limit"` + RateLimitPeriod string `json:"rate_limit_period"` + EnforceCaching bool `json:"enforce_caching"` + OverallCostLimit int32 `json:"overall_cost_limit"` + AlertOnThreshold int32 `json:"alert_on_threshold"` + ProviderKeyID sql.NullInt32 `json:"provider_key_id"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } type Organization struct { diff --git a/numexa-common/postgresql/postgresql-db/queries.sql.go b/numexa-common/postgresql/postgresql-db/queries.sql.go index 62846a7..621bd79 100644 --- a/numexa-common/postgresql/postgresql-db/queries.sql.go +++ b/numexa-common/postgresql/postgresql-db/queries.sql.go @@ -13,19 +13,25 @@ import ( ) const createApiKey = `-- name: CreateApiKey :one -INSERT INTO nxa_api_key (name, api_key, user_id, project_id, created_at, updated_at, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7) -RETURNING id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id +INSERT INTO nxa_api_key (name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +RETURNING id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled ` type CreateApiKeyParams struct { - Name string `json:"name"` - ApiKey string `json:"api_key"` - UserID int32 `json:"user_id"` - ProjectID int32 `json:"project_id"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - ExpiresAt time.Time `json:"expires_at"` + Name string `json:"name"` + ApiKey string `json:"api_key"` + UserID int32 `json:"user_id"` + ProjectID int32 `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ExpiresAt time.Time `json:"expires_at"` + NxaApiKeyPropertyID sql.NullInt32 `json:"nxa_api_key_property_id"` + ProviderKeyID sql.NullInt32 `json:"provider_key_id"` + Revoked bool `json:"revoked"` + RevokedAt sql.NullTime `json:"revoked_at"` + RevokedBy sql.NullInt32 `json:"revoked_by"` + Disabled bool `json:"disabled"` } func (q *Queries) CreateApiKey(ctx context.Context, arg CreateApiKeyParams) (NxaApiKey, error) { @@ -37,6 +43,12 @@ func (q *Queries) CreateApiKey(ctx context.Context, arg CreateApiKeyParams) (Nxa arg.CreatedAt, arg.UpdatedAt, arg.ExpiresAt, + arg.NxaApiKeyPropertyID, + arg.ProviderKeyID, + arg.Revoked, + arg.RevokedAt, + arg.RevokedBy, + arg.Disabled, ) var i NxaApiKey err := row.Scan( @@ -49,6 +61,57 @@ func (q *Queries) CreateApiKey(ctx context.Context, arg CreateApiKeyParams) (Nxa &i.UpdatedAt, &i.ExpiresAt, &i.NxaApiKeyPropertyID, + &i.ProviderKeyID, + &i.Revoked, + &i.RevokedAt, + &i.RevokedBy, + &i.Disabled, + ) + return i, err +} + +const createNXAKeyProperty = `-- name: CreateNXAKeyProperty :one +INSERT INTO nxa_api_key_property (rate_limit, rate_limit_period, enforce_caching, overall_cost_limit, alert_on_threshold, provider_key_id, expires_at, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id, rate_limit, rate_limit_period, enforce_caching, overall_cost_limit, alert_on_threshold, provider_key_id, expires_at, created_at, updated_at +` + +type CreateNXAKeyPropertyParams struct { + RateLimit int32 `json:"rate_limit"` + RateLimitPeriod string `json:"rate_limit_period"` + EnforceCaching bool `json:"enforce_caching"` + OverallCostLimit int32 `json:"overall_cost_limit"` + AlertOnThreshold int32 `json:"alert_on_threshold"` + ProviderKeyID sql.NullInt32 `json:"provider_key_id"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (q *Queries) CreateNXAKeyProperty(ctx context.Context, arg CreateNXAKeyPropertyParams) (NxaApiKeyProperty, error) { + row := q.db.QueryRowContext(ctx, createNXAKeyProperty, + arg.RateLimit, + arg.RateLimitPeriod, + arg.EnforceCaching, + arg.OverallCostLimit, + arg.AlertOnThreshold, + arg.ProviderKeyID, + arg.ExpiresAt, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i NxaApiKeyProperty + err := row.Scan( + &i.ID, + &i.RateLimit, + &i.RateLimitPeriod, + &i.EnforceCaching, + &i.OverallCostLimit, + &i.AlertOnThreshold, + &i.ProviderKeyID, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, ) return i, err } @@ -276,7 +339,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e } const getAPIKeyByNameAndProjectId = `-- name: GetAPIKeyByNameAndProjectId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE name = $1 AND project_id = $2 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled FROM nxa_api_key WHERE name = $1 AND project_id = $2 ` type GetAPIKeyByNameAndProjectIdParams struct { @@ -297,12 +360,17 @@ func (q *Queries) GetAPIKeyByNameAndProjectId(ctx context.Context, arg GetAPIKey &i.UpdatedAt, &i.ExpiresAt, &i.NxaApiKeyPropertyID, + &i.ProviderKeyID, + &i.Revoked, + &i.RevokedAt, + &i.RevokedBy, + &i.Disabled, ) return i, err } const getAPIkeyByApiKey = `-- name: GetAPIkeyByApiKey :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE api_key = $1 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled FROM nxa_api_key WHERE api_key = $1 ` func (q *Queries) GetAPIkeyByApiKey(ctx context.Context, apiKey string) (NxaApiKey, error) { @@ -318,12 +386,17 @@ func (q *Queries) GetAPIkeyByApiKey(ctx context.Context, apiKey string) (NxaApiK &i.UpdatedAt, &i.ExpiresAt, &i.NxaApiKeyPropertyID, + &i.ProviderKeyID, + &i.Revoked, + &i.RevokedAt, + &i.RevokedBy, + &i.Disabled, ) return i, err } const getAPIkeyByUserId = `-- name: GetAPIkeyByUserId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE user_id = $1 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled FROM nxa_api_key WHERE user_id = $1 ` func (q *Queries) GetAPIkeyByUserId(ctx context.Context, userID int32) (NxaApiKey, error) { @@ -339,6 +412,11 @@ func (q *Queries) GetAPIkeyByUserId(ctx context.Context, userID int32) (NxaApiKe &i.UpdatedAt, &i.ExpiresAt, &i.NxaApiKeyPropertyID, + &i.ProviderKeyID, + &i.Revoked, + &i.RevokedAt, + &i.RevokedBy, + &i.Disabled, ) return i, err } @@ -577,6 +655,27 @@ func (q *Queries) GetProviderKey(ctx context.Context, id int32) (ProviderKey, er return i, err } +const getProviderKeyByID = `-- name: GetProviderKeyByID :one +SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE id = $1 +` + +func (q *Queries) GetProviderKeyByID(ctx context.Context, id int32) (ProviderKey, error) { + row := q.db.QueryRowContext(ctx, getProviderKeyByID, id) + var i ProviderKey + err := row.Scan( + &i.ID, + &i.KeyUuid, + &i.Name, + &i.Provider, + &i.CreatorID, + &i.OrganizationID, + &i.ProjectID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getProviderKeyByName = `-- name: GetProviderKeyByName :one SELECT id, key_uuid, name, provider, creator_id, organization_id, project_id, created_at, updated_at FROM provider_keys WHERE name = $1 ` @@ -750,7 +849,7 @@ func (q *Queries) GetSetting(ctx context.Context, key string) (Setting, error) { } const getTokenByProjectId = `-- name: GetTokenByProjectId :one -SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id FROM nxa_api_key WHERE project_id = $1 AND api_key = $2 +SELECT id, name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled FROM nxa_api_key WHERE project_id = $1 AND api_key = $2 ` type GetTokenByProjectIdParams struct { @@ -771,6 +870,11 @@ func (q *Queries) GetTokenByProjectId(ctx context.Context, arg GetTokenByProject &i.UpdatedAt, &i.ExpiresAt, &i.NxaApiKeyPropertyID, + &i.ProviderKeyID, + &i.Revoked, + &i.RevokedAt, + &i.RevokedBy, + &i.Disabled, ) return i, err } diff --git a/numexa-common/postgresql/queries.sql b/numexa-common/postgresql/queries.sql index b1e966b..6d840bc 100644 --- a/numexa-common/postgresql/queries.sql +++ b/numexa-common/postgresql/queries.sql @@ -57,8 +57,8 @@ SELECT * FROM project_users WHERE project_id = $1 AND user_id = $2; SELECT * FROM nxa_api_key WHERE project_id = $1 AND api_key = $2; -- name: CreateApiKey :one -INSERT INTO nxa_api_key (name, api_key, user_id, project_id, created_at, updated_at, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7) +INSERT INTO nxa_api_key (name, api_key, user_id, project_id, created_at, updated_at, expires_at, nxa_api_key_property_id, provider_key_id, revoked, revoked_at, revoked_by, disabled) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING *; -- name: GetAPIkeyByUserId :one @@ -107,17 +107,22 @@ SELECT * FROM provider_secrets WHERE id = $1; -- name: GetProviderSecretsByProviderKeyID :many SELECT * FROM provider_secrets WHERE provider_key_id = $1; --- name: GetProviderSecretsByProviderKeyID :many -SELECT * FROM provider_secrets WHERE provider_key_id = $1; - -- name: GetProviderSecretByProviderKeyIDAndType :one SELECT * FROM provider_secrets WHERE provider_key_id = $1 AND type = $2; +-- name: GetProviderKeyByID :one +SELECT * FROM provider_keys WHERE id = $1; + -- name: CreateSetting :one INSERT INTO setting (key, value, visible, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *; +-- name: CreateNXAKeyProperty :one +INSERT INTO nxa_api_key_property (rate_limit, rate_limit_period, enforce_caching, overall_cost_limit, alert_on_threshold, provider_key_id, expires_at, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING *; + -- name: GetSetting :one SELECT * FROM setting WHERE key = $1; From 1e8ecf92466435c7e81cd1de74a1f51600050404 Mon Sep 17 00:00:00 2001 From: l0calh0st8080 <142974401+l0calh0st8080@users.noreply.github.com> Date: Wed, 20 Sep 2023 22:55:33 +0530 Subject: [PATCH 3/3] feat: add support for keymgmt in proxy(monger) --- auth/pkg/db/db.go | 1 + auth/pkg/db/postgres/providerkeys.go | 66 +++++++++++++++++ monger/handlers/proxy.go | 73 +++++++++++++++++++ monger/routes/middleware.go | 7 ++ .../postgresql/postgresql-db/queries.sql.go | 22 ++++++ numexa-common/postgresql/queries.sql | 3 + 6 files changed, 172 insertions(+) diff --git a/auth/pkg/db/db.go b/auth/pkg/db/db.go index d2924ec..f7df48f 100644 --- a/auth/pkg/db/db.go +++ b/auth/pkg/db/db.go @@ -43,5 +43,6 @@ type DB interface { CreateSetting(ctx context.Context, setting postgresql_db.CreateSettingParams) (postgresql_db.Setting, error) GetSetting(ctx context.Context, key string) (postgresql_db.Setting, error) GetProviderKeyById(ctx context.Context, id int32) (postgresql_db.ProviderKey, error) + CheckProviderAndNXAKeyPropertyFromNXAKey(ctx context.Context, nxaKey string, providerName string) (bool, postgresql_db.ProviderKey, postgresql_db.NxaApiKeyProperty, []postgresql_db.ProviderSecret, error) CreateNXAKeyProperty(ctx context.Context, nxaKeyProperty postgresql_db.CreateNXAKeyPropertyParams) (postgresql_db.NxaApiKeyProperty, error) } diff --git a/auth/pkg/db/postgres/providerkeys.go b/auth/pkg/db/postgres/providerkeys.go index de81403..c485faa 100644 --- a/auth/pkg/db/postgres/providerkeys.go +++ b/auth/pkg/db/postgres/providerkeys.go @@ -2,8 +2,11 @@ package postgres import ( "context" + "database/sql" + "time" postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db" + "github.com/sirupsen/logrus" ) func (p *Postgres) AddProviderKeys(ctx context.Context, pk postgresql_db.CreateProviderKeyParams) (postgresql_db.ProviderKey, error) { @@ -35,3 +38,66 @@ func (p *Postgres) GetProviderKeyById(ctx context.Context, id int32) (postgresql queries := getPostgresQueries(p.db) return queries.GetProviderKeyByID(ctx, id) } + +func (p *Postgres) CheckProviderAndNXAKeyPropertyFromNXAKey(ctx context.Context, nxaKey string, providerName string) (bool, postgresql_db.ProviderKey, postgresql_db.NxaApiKeyProperty, []postgresql_db.ProviderSecret, error) { + notValid := false + var keyProperty postgresql_db.NxaApiKeyProperty + queries := getPostgresQueries(p.db) + + nxaKeyRow, err := queries.GetAPIkeyByApiKey(ctx, nxaKey) + if err != nil { + if err == sql.ErrNoRows { + logrus.Errorf("api key not found: %s", nxaKey) + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err + } + + if nxaKeyRow.ExpiresAt.Before(time.Now()) { + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } + + if nxaKeyRow.Revoked { + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } + + if nxaKeyRow.Disabled { + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } + + if nxaKeyRow.NxaApiKeyPropertyID.Valid { + // api key might not have a provider key associated with it + // and can still be valid + keyProperty, err = queries.GetNXAKeyPropertyByID(ctx, nxaKeyRow.NxaApiKeyPropertyID.Int32) + if err != nil { + if err == sql.ErrNoRows { + return true, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err + } + } + + // get provider key + providerKey, err := queries.GetProviderKeyByID(ctx, nxaKeyRow.ProviderKeyID.Int32) + if err != nil { + if err == sql.ErrNoRows { + logrus.Errorf("provider key not found for api key: %s", nxaKey) + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } else { + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err + } + } + + // get provider secrets + providerSecret, err := queries.GetProviderSecretsByProviderKeyID(ctx, nxaKeyRow.ProviderKeyID.Int32) + if err != nil { + if err == sql.ErrNoRows { + logrus.Errorf("provider secrets not found for provider key: %s", providerKey.Name) + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil + } else { + return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err + } + } + + return true, providerKey, keyProperty, providerSecret, nil +} diff --git a/monger/handlers/proxy.go b/monger/handlers/proxy.go index 146327a..4f8feb5 100644 --- a/monger/handlers/proxy.go +++ b/monger/handlers/proxy.go @@ -11,7 +11,10 @@ import ( "strings" "time" + authModel "github.com/NumexaHQ/captainCache/model" commonConstants "github.com/NumexaHQ/captainCache/numexa-common/constants" + authConstants "github.com/NumexaHQ/captainCache/pkg/constants" + "github.com/NumexaHQ/captainCache/pkg/providerkeys" "github.com/NumexaHQ/monger/model" nxopenaiModel "github.com/NumexaHQ/monger/model/openai" gptcache "github.com/NumexaHQ/monger/pkg/cache" @@ -26,6 +29,76 @@ func (h *Handler) OpenAIProxy(w http.ResponseWriter, r *http.Request) { apiKey := r.Header.Get("X-Numexa-Api-Key") + // check if openai key is present + // todo: pprof profile this!!! + if r.Header.Get("Organization") == "" || r.Header.Get("Authorization") == "" { + // check if api key has associated openai key + // if not, return error + // else, set the openai key in the header + // + isValid, providerKey, keyProperty, providerSecrets, err := h.AuthDB.CheckProviderAndNXAKeyPropertyFromNXAKey(r.Context(), apiKey, authConstants.PROVIDER_OPENAI) + if err != nil { + logrus.Errorf("Error checking provider and nxa key property from nxa key: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + if !isValid { + http.Error(w, "Invalid/Expired API key", http.StatusUnauthorized) + return + } + + logrus.Infof("provider key: %+v", providerKey) + logrus.Infof("providerSecrets: %+v", providerSecrets) + + if providerKey.Provider != authConstants.PROVIDER_OPENAI { + http.Error(w, "Invalid provider", http.StatusUnauthorized) + return + } + + // todo: use keyproperty, to enforce rules + logrus.Debugf("key property: %v", keyProperty) + + keys := make(map[string]string) + + for _, secrets := range providerSecrets { + keys[secrets.Type] = secrets.Key + } + + // here kp.Keys is encrypted keys + kp := authModel.ProviderKeys{ + Name: providerKey.Name, + Provider: providerKey.Provider, + Keys: keys, + ProjectId: providerKey.ProjectID, + } + + kpB, err := json.Marshal(kp) + if err != nil { + logrus.Errorf("Error marshalling provider keys: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + provider, err := providerkeys.GetProvider(providerKey.Provider, kpB, true) + if err != nil { + logrus.Errorf("Error getting provider: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + decryptedKeys, err := provider.GetDecryptedKeys(r.Context(), h.AuthDB) + if err != nil { + logrus.Errorf("Error getting decrypted keys: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // set the openai key in the header + r.Header.Set("Organization", decryptedKeys["openai_org"]) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", decryptedKeys["openai_key"])) + } + index := strings.Index(originalURL, "/v1/openai/") if index == -1 { http.Error(w, "Invalid URL", http.StatusNotFound) diff --git a/monger/routes/middleware.go b/monger/routes/middleware.go index 3fd5f99..bc8fd87 100644 --- a/monger/routes/middleware.go +++ b/monger/routes/middleware.go @@ -12,6 +12,13 @@ func Middleware(h http.Handler) http.Handler { r = r.WithContext(model.AssignRequestID(ctx)) + // validate api key + apiKey := r.Header.Get("X-Numexa-Api-Key") + if apiKey == "" { + http.Error(w, "Missing API key", http.StatusUnauthorized) + return + } + h.ServeHTTP(w, r) }) } diff --git a/numexa-common/postgresql/postgresql-db/queries.sql.go b/numexa-common/postgresql/postgresql-db/queries.sql.go index 621bd79..4a46773 100644 --- a/numexa-common/postgresql/postgresql-db/queries.sql.go +++ b/numexa-common/postgresql/postgresql-db/queries.sql.go @@ -464,6 +464,28 @@ func (q *Queries) GetAllApiKeysByUserId(ctx context.Context, userID int32) ([]Ge return items, nil } +const getNXAKeyPropertyByID = `-- name: GetNXAKeyPropertyByID :one +SELECT id, rate_limit, rate_limit_period, enforce_caching, overall_cost_limit, alert_on_threshold, provider_key_id, expires_at, created_at, updated_at FROM nxa_api_key_property WHERE id = $1 +` + +func (q *Queries) GetNXAKeyPropertyByID(ctx context.Context, id int32) (NxaApiKeyProperty, error) { + row := q.db.QueryRowContext(ctx, getNXAKeyPropertyByID, id) + var i NxaApiKeyProperty + err := row.Scan( + &i.ID, + &i.RateLimit, + &i.RateLimitPeriod, + &i.EnforceCaching, + &i.OverallCostLimit, + &i.AlertOnThreshold, + &i.ProviderKeyID, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getOrganization = `-- name: GetOrganization :one SELECT id, name, created_at, updated_at FROM organizations WHERE id = $1 ` diff --git a/numexa-common/postgresql/queries.sql b/numexa-common/postgresql/queries.sql index 6d840bc..e475b6f 100644 --- a/numexa-common/postgresql/queries.sql +++ b/numexa-common/postgresql/queries.sql @@ -123,6 +123,9 @@ INSERT INTO nxa_api_key_property (rate_limit, rate_limit_period, enforce_caching VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; +-- name: GetNXAKeyPropertyByID :one +SELECT * FROM nxa_api_key_property WHERE id = $1; + -- name: GetSetting :one SELECT * FROM setting WHERE key = $1;