Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expiration date to vouchers #477

Open
wants to merge 7 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions server/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ func (a *App) startBackgroundWorkers(ctx context.Context) {
go a.deployer.PeriodicRequests(ctx, substrateBlockDiffInSeconds)
go a.deployer.PeriodicDeploy(ctx, substrateBlockDiffInSeconds)

// send notification about vms and k8s expiration
go a.deployer.WarnUsersWithExpiredVMs(ctx)
go a.deployer.WarnUsersWithExpiredK8s(ctx)

// remove expired vms and k8s
go a.deployer.CleanExpiredVMs(ctx)
go a.deployer.CleanExpiredK8S(ctx)

// check pending deployments
a.deployer.ConsumeVMRequest(ctx, true)
a.deployer.ConsumeK8sRequest(ctx, true)
Expand Down
2 changes: 1 addition & 1 deletion server/app/k8s_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (a *App) K8sDeployHandler(req *http.Request) (interface{}, Response) {
return nil, InternalServerError(errors.New(internalServerErrorMsg))
}

_, err = deployer.ValidateK8sQuota(k8sDeployInput, quota.Vms, quota.PublicIPs)
_, _, err = deployer.ValidateK8sQuota(k8sDeployInput, quota.QuotaVMs, quota.PublicIPs)
if err != nil {
log.Error().Err(err).Send()
return nil, BadRequest(errors.New(err.Error()))
Expand Down
1 change: 0 additions & 1 deletion server/app/quota_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ func TestQuotaRouter(t *testing.T) {
err = app.db.CreateQuota(
&models.Quota{
UserID: user.ID.String(),
Vms: 10,
PublicIPs: 1,
},
)
Expand Down
25 changes: 18 additions & 7 deletions server/app/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http/httptest"
"os"
"path/filepath"

"testing"

c4sDeployer "github.com/codescalers/cloud4students/deployer"
Expand Down Expand Up @@ -73,23 +72,35 @@ func SetUp(t testing.TB) *App {
`, dbPath)

err := os.WriteFile(configPath, []byte(config), 0644)
assert.NoError(t, err)
if !assert.NoError(t, err) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use assert.error

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert changes the state of the testing object t, fa if I used assert.Error so every time the err = nil (meaning the test had bassed), the t object would report that the test failed

return &App{}
}

configuration, err := internal.ReadConfFile(configPath)
assert.NoError(t, err)
if !assert.NoError(t, err) {
return &App{}
}

db := models.NewDB()
err = db.Connect(configuration.Database.File)
assert.NoError(t, err)
if !assert.NoError(t, err) {
return &App{}
}

err = db.Migrate()
assert.NoError(t, err)
if !assert.NoError(t, err) {
return &App{}
}

tfPluginClient, err := deployer.NewTFPluginClient(configuration.Account.Mnemonics, "sr25519", configuration.Account.Network, "", "", "", 0, false)
assert.NoError(t, err)
if !assert.NoError(t, err) {
return &App{}
}

newDeployer, err := c4sDeployer.NewDeployer(db, streams.RedisClient{}, tfPluginClient)
assert.NoError(t, err)
if !assert.NoError(t, err) {
return &App{}
}

app := &App{
config: configuration,
Expand Down
49 changes: 36 additions & 13 deletions server/app/user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,22 @@ type EmailInput struct {

// ApplyForVoucherInput struct for user to apply for voucher
type ApplyForVoucherInput struct {
VMs int `json:"vms" binding:"required" validate:"min=0"`
PublicIPs int `json:"public_ips" binding:"required" validate:"min=0"`
Reason string `json:"reason" binding:"required" validate:"nonzero"`
VMs int `json:"vms" binding:"required" validate:"min=0"`
PublicIPs int `json:"public_ips" binding:"required" validate:"min=0"`
Reason string `json:"reason" binding:"required" validate:"nonzero"`
VoucherDurationInMonth int `json:"voucher_duration_in_month" binding:"required"`
}

// AddVoucherInput struct for voucher applied by user
type AddVoucherInput struct {
Voucher string `json:"voucher" binding:"required"`
Voucher string `json:"voucher" binding:"required"`
VoucherDurationInMonth int `json:"voucher_duration_in_month" binding:"required"`
}

// SignUpHandler creates account for user
func (a *App) SignUpHandler(req *http.Request) (interface{}, Response) {
var signUp SignUpInput
err := json.NewDecoder(req.Body).Decode(&signUp)

if err != nil {
log.Error().Err(err).Send()
return nil, BadRequest(errors.New("failed to read sign up data"))
Expand Down Expand Up @@ -163,7 +164,6 @@ func (a *App) SignUpHandler(req *http.Request) (interface{}, Response) {
// create empty quota
quota := models.Quota{
UserID: u.ID.String(),
Vms: 0,
}
err = a.db.CreateQuota(&quota)
if err != nil {
Expand Down Expand Up @@ -573,14 +573,20 @@ func (a *App) ApplyForVoucherHandler(req *http.Request) (interface{}, Response)
return nil, BadRequest(errors.New("invalid voucher data"))
}

// make sure the requested duration is less that the maximum allowed duration
if input.VoucherDurationInMonth > a.config.VouchersMaxDuration {
return nil, BadRequest(fmt.Errorf("invalid voucher duration, max duration is %d", a.config.VouchersMaxDuration))
}

// generate voucher for user but can't use it until admin approves it
v := internal.GenerateRandomVoucher(5)
voucher := models.Voucher{
Voucher: v,
UserID: userID,
VMs: input.VMs,
Reason: input.Reason,
PublicIPs: input.PublicIPs,
Voucher: v,
UserID: userID,
VMs: input.VMs,
Reason: input.Reason,
PublicIPs: input.PublicIPs,
VoucherDurationInMonth: input.VoucherDurationInMonth,
}

err = a.db.CreateVoucher(&voucher)
Expand All @@ -607,7 +613,7 @@ func (a *App) ActivateVoucherHandler(req *http.Request) (interface{}, Response)
return nil, BadRequest(errors.New("failed to read voucher data"))
}

oldQuota, err := a.db.GetUserQuota(userID)
quota, err := a.db.GetUserQuota(userID)
if err == gorm.ErrRecordNotFound {
return nil, NotFound(errors.New("user quota is not found"))
}
Expand Down Expand Up @@ -643,15 +649,32 @@ func (a *App) ActivateVoucherHandler(req *http.Request) (interface{}, Response)
return nil, InternalServerError(errors.New(internalServerErrorMsg))
}

err = a.db.UpdateUserQuota(userID, oldQuota.Vms+voucherQuota.VMs, oldQuota.PublicIPs+voucherQuota.PublicIPs)
err = a.db.UpdateUserQuota(userID, quota.PublicIPs+voucherQuota.PublicIPs)
if err != nil {
log.Error().Err(err).Send()
return nil, InternalServerError(errors.New(internalServerErrorMsg))
}

vms := getDurationVMs(quota, voucherQuota.VoucherDurationInMonth)
err = a.db.UpdateUserQuotaVMs(quota.ID, voucherQuota.VoucherDurationInMonth, vms+voucherQuota.VMs)
if err != nil {
log.Error().Err(err).Send()
return nil, InternalServerError(errors.New(internalServerErrorMsg))
}

middlewares.VoucherActivated.WithLabelValues(userID, voucherQuota.Voucher, fmt.Sprint(voucherQuota.VMs), fmt.Sprint(voucherQuota.PublicIPs)).Inc()

return ResponseMsg{
Message: "Voucher is applied successfully",
Data: nil,
}, Ok()
}

func getDurationVMs(quota models.Quota, duration int) int {
for _, q := range quota.QuotaVMs {
if duration == q.Duration {
return q.VMs
}
}
return 0
}
2 changes: 1 addition & 1 deletion server/app/vm_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (a *App) DeployVMHandler(req *http.Request) (interface{}, Response) {
return nil, InternalServerError(errors.New(internalServerErrorMsg))
}

_, err = deployer.ValidateVMQuota(input, quota.Vms, quota.PublicIPs)
_, _, err = deployer.ValidateVMQuota(input, quota.QuotaVMs, quota.PublicIPs)
if err != nil {
return nil, BadRequest(errors.New(err.Error()))
}
Expand Down
21 changes: 14 additions & 7 deletions server/app/voucher_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package app
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"

Expand All @@ -17,9 +18,10 @@ import (

// GenerateVoucherInput struct for data needed when user generate vouchers
type GenerateVoucherInput struct {
Length int `json:"length" binding:"required" validate:"min=3,max=20"`
VMs int `json:"vms" binding:"required"`
PublicIPs int `json:"public_ips" binding:"required"`
Length int `json:"length" binding:"required" validate:"min=3,max=20"`
VMs int `json:"vms" binding:"required"`
PublicIPs int `json:"public_ips" binding:"required"`
VoucherDurationInMonth int `json:"voucher_duration_in_month" binding:"required"`
}

// UpdateVoucherInput struct for data needed when user update voucher
Expand All @@ -43,11 +45,16 @@ func (a *App) GenerateVoucherHandler(req *http.Request) (interface{}, Response)
}
voucher := internal.GenerateRandomVoucher(input.Length)

if input.VoucherDurationInMonth > a.config.VouchersMaxDuration {
return nil, BadRequest(fmt.Errorf("invalid voucher duration, max duration is %d", a.config.VouchersMaxDuration))
}

v := models.Voucher{
Voucher: voucher,
VMs: input.VMs,
PublicIPs: input.PublicIPs,
Approved: true,
Voucher: voucher,
VMs: input.VMs,
PublicIPs: input.PublicIPs,
Approved: true,
VoucherDurationInMonth: input.VoucherDurationInMonth,
}

err = a.db.CreateVoucher(&v)
Expand Down
132 changes: 132 additions & 0 deletions server/deployer/deployer.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,138 @@ func (d *Deployer) CancelDeployment(contractID uint64, netContractID uint64, dlT
return nil
}

func (d *Deployer) WarnUsersWithExpiredVMs(ctx context.Context) {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
users, err := d.db.ListAllUsers()
if err != nil {
log.Error().Err(err).Msg("failed to get all users")
return
}

for _, user := range users {
vms, err := d.db.GetAllVms(user.UserID)
if err != nil {
log.Error().Err(err).Msg("failed to get all user vms")
continue
}

for _, vm := range vms {
if time.Now().Before(vm.ExpiresAt) && time.Until(vm.ExpiresAt) < time.Hour*24 {
notification := models.Notification{
UserID: user.UserID,
Msg: fmt.Sprintf("Warning: vm with id %d expires in one day", vm.ID),
Type: models.VMsType,
}

err = d.db.CreateNotification(&notification)
if err != nil {
log.Error().Err(err).Msgf("failed to create notification: %+v", notification)
}
}
}
}
}
}

func (d *Deployer) WarnUsersWithExpiredK8s(ctx context.Context) {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
users, err := d.db.ListAllUsers()
if err != nil {
log.Error().Err(err).Msg("failed to get all users")
return
}

for _, user := range users {
k8s, err := d.db.GetAllK8s(user.UserID)
if err != nil {
log.Error().Err(err).Msg("failed to get all user k8s clusters")
continue
}

for _, k := range k8s {
if time.Now().Before(k.ExpiresAt) && time.Until(k.ExpiresAt) < time.Hour*24 {
notification := models.Notification{
UserID: user.UserID,
Msg: fmt.Sprintf("Warning: k8s cluster with id %d expires in one day", k.ID),
Type: models.K8sType,
}

err = d.db.CreateNotification(&notification)
if err != nil {
log.Error().Err(err).Msgf("failed to create notification: %+v", notification)
}
}
}
}
}
}

func (d *Deployer) CleanExpiredVMs(ctx context.Context) {
Eslam-Nawara marked this conversation as resolved.
Show resolved Hide resolved
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
users, err := d.db.ListAllUsers()
if err != nil {
log.Error().Err(err).Msg("failed to get all users")
return
}

for _, user := range users {
vms, err := d.db.GetAllVms(user.UserID)
if err != nil {
log.Error().Err(err).Msg("failed to get all user vms")
continue
}

for _, vm := range vms {
if vm.ExpiresAt.Before(time.Now()) {
err = d.CancelDeployment(vm.ContractID, vm.NetworkContractID, "vm", vm.Name)
if err != nil {
log.Error().Err(err).Msg("failed to cancel contract of expired vm")
}
err := d.db.DeleteVMByID(vm.ID)
if err != nil {
log.Error().Err(err).Msg("failed to delete expired vm")
}
}
}
}
}
}

func (d *Deployer) CleanExpiredK8S(ctx context.Context) {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
users, err := d.db.ListAllUsers()
if err != nil {
log.Error().Err(err).Msg("failed to get all users")
return
}

for _, user := range users {
k8s, err := d.db.GetAllK8s(user.UserID)
if err != nil {
log.Error().Err(err).Msg("failed to get all user k8s clusters")
continue
}

for _, k := range k8s {
if k.ExpiresAt.Before(time.Now()) {
err = d.CancelDeployment(uint64(k.ClusterContract), uint64(k.NetworkContract), "k8s", k.Master.Name)
if err != nil {
log.Error().Err(err).Msg("failed to cancel contract of expired k8s cluster")
}
err := d.db.DeleteVMByID(k.ID)
if err != nil {
log.Error().Err(err).Msg("failed to delete expired k8s cluster")
}
}
}
}
}
}

func buildNetwork(node uint32, name string) workloads.ZNet {
return workloads.ZNet{
Name: name,
Expand Down
Loading
Loading