Skip to content

Commit

Permalink
feat: add support for keymgmt in proxy(monger)
Browse files Browse the repository at this point in the history
  • Loading branch information
l0calh0st8080 committed Sep 20, 2023
1 parent 5b12df6 commit 1e8ecf9
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
1 change: 1 addition & 0 deletions auth/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@ type DB interface {
CreateSetting(ctx context.Context, setting postgresql_db.CreateSettingParams) (postgresql_db.Setting, error)
GetSetting(ctx context.Context, key string) (postgresql_db.Setting, error)
GetProviderKeyById(ctx context.Context, id int32) (postgresql_db.ProviderKey, error)
CheckProviderAndNXAKeyPropertyFromNXAKey(ctx context.Context, nxaKey string, providerName string) (bool, postgresql_db.ProviderKey, postgresql_db.NxaApiKeyProperty, []postgresql_db.ProviderSecret, error)
CreateNXAKeyProperty(ctx context.Context, nxaKeyProperty postgresql_db.CreateNXAKeyPropertyParams) (postgresql_db.NxaApiKeyProperty, error)
}
66 changes: 66 additions & 0 deletions auth/pkg/db/postgres/providerkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package postgres

import (
"context"
"database/sql"
"time"

postgresql_db "github.com/NumexaHQ/captainCache/numexa-common/postgresql/postgresql-db"
"github.com/sirupsen/logrus"
)

func (p *Postgres) AddProviderKeys(ctx context.Context, pk postgresql_db.CreateProviderKeyParams) (postgresql_db.ProviderKey, error) {
Expand Down Expand Up @@ -35,3 +38,66 @@ func (p *Postgres) GetProviderKeyById(ctx context.Context, id int32) (postgresql
queries := getPostgresQueries(p.db)
return queries.GetProviderKeyByID(ctx, id)
}

func (p *Postgres) CheckProviderAndNXAKeyPropertyFromNXAKey(ctx context.Context, nxaKey string, providerName string) (bool, postgresql_db.ProviderKey, postgresql_db.NxaApiKeyProperty, []postgresql_db.ProviderSecret, error) {
notValid := false
var keyProperty postgresql_db.NxaApiKeyProperty
queries := getPostgresQueries(p.db)

nxaKeyRow, err := queries.GetAPIkeyByApiKey(ctx, nxaKey)
if err != nil {
if err == sql.ErrNoRows {
logrus.Errorf("api key not found: %s", nxaKey)
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
}
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err
}

if nxaKeyRow.ExpiresAt.Before(time.Now()) {
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
}

if nxaKeyRow.Revoked {
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
}

if nxaKeyRow.Disabled {
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
}

if nxaKeyRow.NxaApiKeyPropertyID.Valid {
// api key might not have a provider key associated with it
// and can still be valid
keyProperty, err = queries.GetNXAKeyPropertyByID(ctx, nxaKeyRow.NxaApiKeyPropertyID.Int32)
if err != nil {
if err == sql.ErrNoRows {
return true, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
}
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err
}
}

// get provider key
providerKey, err := queries.GetProviderKeyByID(ctx, nxaKeyRow.ProviderKeyID.Int32)
if err != nil {
if err == sql.ErrNoRows {
logrus.Errorf("provider key not found for api key: %s", nxaKey)
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
} else {
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err
}
}

// get provider secrets
providerSecret, err := queries.GetProviderSecretsByProviderKeyID(ctx, nxaKeyRow.ProviderKeyID.Int32)
if err != nil {
if err == sql.ErrNoRows {
logrus.Errorf("provider secrets not found for provider key: %s", providerKey.Name)
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, nil
} else {
return notValid, postgresql_db.ProviderKey{}, postgresql_db.NxaApiKeyProperty{}, []postgresql_db.ProviderSecret{}, err
}
}

return true, providerKey, keyProperty, providerSecret, nil
}
73 changes: 73 additions & 0 deletions monger/handlers/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
"strings"
"time"

authModel "github.com/NumexaHQ/captainCache/model"
commonConstants "github.com/NumexaHQ/captainCache/numexa-common/constants"
authConstants "github.com/NumexaHQ/captainCache/pkg/constants"
"github.com/NumexaHQ/captainCache/pkg/providerkeys"
"github.com/NumexaHQ/monger/model"
nxopenaiModel "github.com/NumexaHQ/monger/model/openai"
gptcache "github.com/NumexaHQ/monger/pkg/cache"
Expand All @@ -26,6 +29,76 @@ func (h *Handler) OpenAIProxy(w http.ResponseWriter, r *http.Request) {

apiKey := r.Header.Get("X-Numexa-Api-Key")

// check if openai key is present
// todo: pprof profile this!!!
if r.Header.Get("Organization") == "" || r.Header.Get("Authorization") == "" {
// check if api key has associated openai key
// if not, return error
// else, set the openai key in the header
//
isValid, providerKey, keyProperty, providerSecrets, err := h.AuthDB.CheckProviderAndNXAKeyPropertyFromNXAKey(r.Context(), apiKey, authConstants.PROVIDER_OPENAI)
if err != nil {
logrus.Errorf("Error checking provider and nxa key property from nxa key: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

if !isValid {
http.Error(w, "Invalid/Expired API key", http.StatusUnauthorized)
return
}

logrus.Infof("provider key: %+v", providerKey)
logrus.Infof("providerSecrets: %+v", providerSecrets)

if providerKey.Provider != authConstants.PROVIDER_OPENAI {
http.Error(w, "Invalid provider", http.StatusUnauthorized)
return
}

// todo: use keyproperty, to enforce rules
logrus.Debugf("key property: %v", keyProperty)

keys := make(map[string]string)

for _, secrets := range providerSecrets {
keys[secrets.Type] = secrets.Key
}

// here kp.Keys is encrypted keys
kp := authModel.ProviderKeys{
Name: providerKey.Name,
Provider: providerKey.Provider,
Keys: keys,
ProjectId: providerKey.ProjectID,
}

kpB, err := json.Marshal(kp)
if err != nil {
logrus.Errorf("Error marshalling provider keys: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

provider, err := providerkeys.GetProvider(providerKey.Provider, kpB, true)
if err != nil {
logrus.Errorf("Error getting provider: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

decryptedKeys, err := provider.GetDecryptedKeys(r.Context(), h.AuthDB)
if err != nil {
logrus.Errorf("Error getting decrypted keys: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}

// set the openai key in the header
r.Header.Set("Organization", decryptedKeys["openai_org"])
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", decryptedKeys["openai_key"]))
}

index := strings.Index(originalURL, "/v1/openai/")
if index == -1 {
http.Error(w, "Invalid URL", http.StatusNotFound)
Expand Down
7 changes: 7 additions & 0 deletions monger/routes/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ func Middleware(h http.Handler) http.Handler {

r = r.WithContext(model.AssignRequestID(ctx))

// validate api key
apiKey := r.Header.Get("X-Numexa-Api-Key")
if apiKey == "" {
http.Error(w, "Missing API key", http.StatusUnauthorized)
return
}

h.ServeHTTP(w, r)
})
}
22 changes: 22 additions & 0 deletions numexa-common/postgresql/postgresql-db/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions numexa-common/postgresql/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ INSERT INTO nxa_api_key_property (rate_limit, rate_limit_period, enforce_caching
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING *;

-- name: GetNXAKeyPropertyByID :one
SELECT * FROM nxa_api_key_property WHERE id = $1;

-- name: GetSetting :one
SELECT * FROM setting WHERE key = $1;

Expand Down

0 comments on commit 1e8ecf9

Please sign in to comment.