Skip to content

Commit

Permalink
Merge pull request #1 from NumexaHQ/feat-key-mgmt
Browse files Browse the repository at this point in the history
feat: keymgmt
  • Loading branch information
pandyamarut authored Sep 21, 2023
2 parents f1ad754 + 1e8ecf9 commit 2d6e648
Show file tree
Hide file tree
Showing 26 changed files with 1,581 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ up:
docker-compose up -d

all: auth monger vibe
.PHONY: auth monger vibe up
.PHONY: all auth monger vibe up

110 changes: 98 additions & 12 deletions auth/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,8 @@ func generateAPIKey() string {
return fmt.Sprintf("sk-%s", string(b))
}

// func hashPassword(password string) string {
// hashPassword := utils.HashPassword(password)
// }

func (h *Handler) CreateApiKey(c *fiber.Ctx) error {
type RequestBody postgresql_db.NxaApiKey
var reqBody RequestBody
var reqBody model.GenerateNXTokenRequest
if err := c.BodyParser(&reqBody); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"message": "Invalid request body",
Expand Down Expand Up @@ -113,16 +108,61 @@ func (h *Handler) CreateApiKey(c *fiber.Ctx) error {
})
}

_, err = h.DB.CreateApiKey(c.Context(), postgresql_db.NxaApiKey{
nxaAPIKey := postgresql_db.CreateApiKeyParams{
Name: reqBody.Name,
ApiKey: apiKey,
UserID: user.ID,
ProjectID: reqBody.ProjectID,
ExpiresAt: time.Now().Add(time.Hour * 24 * 365),
Revoked: false,
Disabled: false,
ExpiresAt: time.Now().Add(time.Hour * 24 * 365), // this might not be respected, since expiry is set in the key property
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
}

if reqBody.NxaProviderKeyID != 0 {
_, err := h.DB.GetProviderKeyById(c.Context(), reqBody.NxaProviderKeyID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error getting provider key by id: %v", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Error getting provider key by id",
})
} else {
log.Errorf("provider key not found: %v", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Something went wrong. Please contact the administrator",
})
}
}
nxaAPIKey.ProviderKeyID = sql.NullInt32{Int32: reqBody.NxaProviderKeyID, Valid: true}
}

// check if property exists
if reqBody.Property != (model.NXTokenPropertyRequest{}) {
expiry := time.Now().Add(time.Hour * 24 * 365) // todo: default expiry is 1 year, need to change this to never expire
if reqBody.Property.ExpiresAt.IsZero() {
reqBody.Property.ExpiresAt = expiry
}
nxkp, err := h.DB.CreateNXAKeyProperty(c.Context(), postgresql_db.CreateNXAKeyPropertyParams{
RateLimit: reqBody.Property.RateLimit,
RateLimitPeriod: reqBody.Property.RateLimitPeriod,
EnforceCaching: reqBody.Property.EnforceCaching,
OverallCostLimit: reqBody.Property.OverallCostLimit,
AlertOnThreshold: reqBody.Property.AlertOnThreshold,
ExpiresAt: expiry,
})
if err != nil {
log.Errorf("error creating nxa key property: %v", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

nxaAPIKey.NxaApiKeyPropertyID = sql.NullInt32{Int32: nxkp.ID, Valid: true}
}

_, err = h.DB.CreateApiKey(c.Context(), nxaAPIKey)
if err != nil {
log.Errorf("error generating api key: %v", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
Expand Down Expand Up @@ -177,8 +217,6 @@ func (h *Handler) RegisterHandler(c *fiber.Ctx) error {

// create organization
organization.Name = utils.GenerateOrganizationName()

log.Infof("organization: %+v", organization)
organization, err = h.DB.CreateOrganization(c.Context(), organization)
if err != nil {
log.Errorf("error creating organization: %v", err)
Expand Down Expand Up @@ -325,6 +363,54 @@ func generateJWTToken(user postgresql_db.User, jwtSigningKey string) (string, er
return tokenString, err
}

// SHOULD ONLY BE USED FOR TESTING
// DONOT USE IN PRODUCTION
func (h *Handler) DummyAuthMiddleware(c *fiber.Ctx) error {
tokenString := c.Get("Authorization")

// Remove the "Bearer " prefix from the token string
tokenString = strings.TrimPrefix(tokenString, "Bearer ")

// Parse the token
token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Check the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(h.JWTSigningKey), nil
})

// if err != nil {
// return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
// "message": "Unauthorized",
// })
// }

// Get the user ID from the token's claims
var userID float64
if claims, ok := token.Claims.(jwt.MapClaims); ok {
if id, exists := claims["user_id"]; exists {
userID = id.(float64)
c.Locals("user_id", userID) // Set the user ID in locals for other handlers to access
}
if email, exists := claims["email"]; exists {
c.Locals("user_email", email) // Set the user ID in locals for other handlers to access
}
if name, exists := claims["name"]; exists {
c.Locals("name", name) // Set the user ID in locals for other handlers to access
}
if organizationID, exists := claims["organization_id"]; exists {
c.Locals("organization_id", organizationID) // Set the user ID in locals for other handlers to access
}

}

// Check if the token is still valid (not invalidated by logout)

// Token is valid, proceed to the next handler
return c.Next()
}

func (h *Handler) AuthMiddleware(c *fiber.Ctx) error {
tokenString := c.Get("Authorization")

Expand Down Expand Up @@ -357,7 +443,7 @@ func (h *Handler) AuthMiddleware(c *fiber.Ctx) error {
c.Locals("user_email", email) // Set the user ID in locals for other handlers to access
}
if name, exists := claims["name"]; exists {
c.Locals("user_name", name) // Set the user ID in locals for other handlers to access
c.Locals("name", name) // Set the user ID in locals for other handlers to access
}
if organizationID, exists := claims["organization_id"]; exists {
c.Locals("organization_id", organizationID) // Set the user ID in locals for other handlers to access
Expand Down
160 changes: 160 additions & 0 deletions auth/handlers/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package handlers

import (
"encoding/json"
"strconv"

"github.com/NumexaHQ/captainCache/model"
"github.com/NumexaHQ/captainCache/pkg/providerkeys"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)

func (h *Handler) AddProviderKeys(c *fiber.Ctx) error {
var reqBody model.ProviderKeys
if err := c.BodyParser(&reqBody); err != nil {
logrus.WithError(err).Error("Error parsing request body")
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"message": "Invalid request body",
})
}

userId := c.Locals("user_id").(float64)
orgId := c.Locals("organization_id").(float64)

rByte, err := json.Marshal(reqBody)
if err != nil {
logrus.WithError(err).Error("Error marshalling request body")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

keyProvider, err := providerkeys.GetProvider(reqBody.Provider, rByte, false)
if err != nil {
logrus.WithError(err).Error("Error getting provider")
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"message": "Invalid provider",
})
}

if keyProvider.KeyExists(c.Context(), h.DB, reqBody.Name) {
logrus.WithError(err).Error("Key already exists")
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"message": "Key already exists with this name",
})
}

// generate uuid for key
keyuuid, err := generateUUID()
if err != nil {
logrus.WithError(err).Error("Error generating uuid")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}
err = keyProvider.PushKeysToDB(c.Context(), h.DB, reqBody.Name, keyuuid, int32(userId), reqBody.ProjectId, int32(orgId))
if err != nil {
logrus.WithError(err).Error("Error pushing keys to db")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

return c.Status(fiber.StatusOK).JSON(fiber.Map{
"message": "Keys added successfully",
})
}

func (h *Handler) GetProviderKeys(c *fiber.Ctx) error {
// userId := c.Locals("user_id").(float64)
// orgId := c.Locals("organization_id").(float64)
projectId := c.Params("project_id")
projectIdInt, err := strconv.Atoi(projectId)
if err != nil {
logrus.WithError(err).Error("Error converting project id to int")
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"message": "Invalid project id",
})
}

keys, err := h.DB.GetProviderKeysByProjectId(c.Context(), int32(projectIdInt))
if err != nil {
logrus.WithError(err).Error("Error getting provider keys")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

resp := []model.ProviderKeys{}

for _, key := range keys {
secrets, err := h.DB.GetProviderSecretByProviderId(c.Context(), key.ID)
if err != nil {
if err.Error() == "no rows in result set" {
logrus.WithError(err).Error("Error getting provider secret")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}
logrus.WithError(err).Error("Error getting provider secret")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

secretsMap := make(map[string]string)
for _, sv := range secrets {
secretsMap[sv.Type] = sv.Key
}

// here kp.Keys is encrypted keys
kp := model.ProviderKeys{
Name: key.Name,
Provider: key.Provider,
Keys: secretsMap,
ProjectId: key.ProjectID,
}

kpB, err := json.Marshal(kp)
if err != nil {
logrus.WithError(err).Error("Error marshalling provider keys")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

provider, err := providerkeys.GetProvider(key.Provider, kpB, true)
if err != nil {
logrus.WithError(err).Error("Error getting provider")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

decryptedKeys, err := provider.GetDecryptedKeys(c.Context(), h.DB)
if err != nil {
logrus.WithError(err).Error("Error getting decrypted keys")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Internal server error",
})
}

// updating kp.Keys with decrypted keys
kp.Keys = decryptedKeys

resp = append(resp, kp)
}

return c.Status(fiber.StatusOK).JSON(resp)
}

func generateUUID() (string, error) {
uuid, err := uuid.NewRandom()
if err != nil {
logrus.WithError(err).Error("Error generating uuid")
return "", err
}
return uuid.String(), nil
}
10 changes: 9 additions & 1 deletion auth/main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package main

import (
"context"
"os"

"github.com/NumexaHQ/captainCache/model"
commonConstants "github.com/NumexaHQ/captainCache/numexa-common/constants"
nxdb "github.com/NumexaHQ/captainCache/pkg/db"
"github.com/NumexaHQ/captainCache/routes"
Expand Down Expand Up @@ -36,7 +38,13 @@ func main() {

err := db.Init()
if err != nil {
log.Fatal("Failed to initialize database")
log.Fatal("Failed to initialize database: ", err)
}

// init AES setting
err = model.InitializeAESSetting(context.Background(), db)
if err != nil {
log.Fatal("Failed to initialize AES setting: ", err)
}

// Create a new Fiber app
Expand Down
8 changes: 8 additions & 0 deletions auth/model/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package model

type ProviderKeys struct {
Name string `json:"name"`
Provider string `json:"provider" validate:"required" enum:"openai"`
Keys map[string]string `json:"keys"`
ProjectId int32 `json:"project_id"`
}
Loading

0 comments on commit 2d6e648

Please sign in to comment.