Skip to content

Commit

Permalink
Add JAX controller tests and webhook validations JAXJob
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Aug 16, 2024
1 parent 1b9766b commit cf76911
Show file tree
Hide file tree
Showing 8 changed files with 765 additions and 1 deletion.
8 changes: 8 additions & 0 deletions PROJECT
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,12 @@ resources:
kind: TFJob
path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1
version: v1
- api:
crdVersion: v1
namespaced: true
controller: true
group: kubeflow.org
kind: JAXJob
path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1
version: v1
version: "3"
2 changes: 1 addition & 1 deletion cmd/training-operator.v1/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func main() {
"Enabling this will ensure there is only one active controller manager.")
flag.StringVar(&leaderElectionID, "leader-election-id", "1ca428e5.training-operator.kubeflow.org", "The ID for leader election.")
flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+
" Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob. By default, all supported schemes will be enabled.")
" Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob, JAXJob. By default, all supported schemes will be enabled.")
flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "", "Now Supporting volcano and scheduler-plugins."+
" Note: If you set another scheduler name, the training-operator assumes it's the scheduler-plugins.")
flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+
Expand Down
20 changes: 20 additions & 0 deletions manifests/base/webhook/manifests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@ kind: ValidatingWebhookConfiguration
metadata:
name: validating-webhook-configuration
webhooks:
- admissionReviewVersions:
- v1
clientConfig:
service:
name: webhook-service
namespace: system
path: /validate-kubeflow-org-v1-jaxjob
failurePolicy: Fail
name: validator.jaxjob.training-operator.kubeflow.org
rules:
- apiGroups:
- kubeflow.org
apiVersions:
- v1
operations:
- CREATE
- UPDATE
resources:
- jaxjobs
sideEffects: None
- admissionReviewVersions:
- v1
clientConfig:
Expand Down
128 changes: 128 additions & 0 deletions pkg/controller.v1/jax/jaxjob_controller_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package jax

import (
"context"
"crypto/tls"
"fmt"
"net"
"path/filepath"
"testing"
"time"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/kubeflow/training-operator/pkg/controller.v1/common"
jaxwebhook "github.com/kubeflow/training-operator/pkg/webhooks/jax"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/envtest"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
"sigs.k8s.io/controller-runtime/pkg/webhook"
"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
//+kubebuilder:scaffold:imports
)

// These tests use Ginkgo, a BDD-style Go testing framework. Refer to
// https://onsi.github.io/ginkgo/ to learn more about Ginkgo.

// Suite-wide state. All four values are assigned in BeforeSuite; testCancel
// and testEnv are used again in AfterSuite for teardown.
var (
// testK8sClient is the client built in BeforeSuite from the envtest REST
// config and the shared client-go scheme.
testK8sClient client.Client
// testEnv is the envtest environment started in BeforeSuite and stopped in
// AfterSuite.
testEnv *envtest.Environment
// testCtx is the context handed to the manager's Start call in BeforeSuite.
testCtx context.Context
// testCancel cancels testCtx; AfterSuite calls it before stopping testEnv.
testCancel context.CancelFunc
)

// TestAPIs is the `go test` entry point for this package's Ginkgo suite. It
// wires Gomega assertion failures into Ginkgo's Fail handler and then runs the
// registered specs under the suite description "Controller Suite".
func TestAPIs(t *testing.T) {
	// The fail handler has to be registered before any spec runs, otherwise
	// Gomega has nowhere to report a failed expectation.
	RegisterFailHandler(Fail)
	RunSpecs(t, "Controller Suite")
}

// BeforeSuite builds the shared fixture used by every spec in this suite:
//
//  1. starts an envtest environment loaded with the CRDs and the validating
//     webhook configuration from manifests/base;
//  2. registers the Volcano scheduling and Kubeflow v1 types in the shared
//     client-go scheme and creates testK8sClient from it;
//  3. creates a controller manager whose webhook server listens on the
//     host/port/cert directory that envtest allocated;
//  4. registers the JAXJob reconciler and the JAXJob webhook with the manager;
//  5. runs the manager in the background until testCtx is cancelled;
//  6. waits until the webhook server accepts TLS connections.
var _ = BeforeSuite(func() {
	logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))

	// testCancel is invoked from AfterSuite to stop the manager started below.
	testCtx, testCancel = context.WithCancel(context.TODO())

	By("bootstrapping test environment")
	testEnv = &envtest.Environment{
		CRDDirectoryPaths:     []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")},
		ErrorIfCRDPathMissing: true,
		WebhookInstallOptions: envtest.WebhookInstallOptions{
			Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")},
		},
	}

	cfg, err := testEnv.Start()
	Expect(err).NotTo(HaveOccurred())
	Expect(cfg).NotTo(BeNil())

	err = v1beta1.AddToScheme(scheme.Scheme)
	Expect(err).NotTo(HaveOccurred())
	err = kubeflowv1.AddToScheme(scheme.Scheme)
	Expect(err).NotTo(HaveOccurred())

	//+kubebuilder:scaffold:scheme

	testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
	Expect(err).NotTo(HaveOccurred())
	Expect(testK8sClient).NotTo(BeNil())

	mgr, err := ctrl.NewManager(cfg, ctrl.Options{
		Metrics: metricsserver.Options{
			// NOTE(review): "0" is documented by controller-runtime as
			// disabling the metrics listener — confirm for the pinned version.
			BindAddress: "0",
		},
		// Serve webhooks on the endpoint envtest registered with the API
		// server so admission requests reach this process.
		WebhookServer: webhook.NewServer(
			webhook.Options{
				Host:    testEnv.WebhookInstallOptions.LocalServingHost,
				Port:    testEnv.WebhookInstallOptions.LocalServingPort,
				CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir,
			}),
	})
	Expect(err).NotTo(HaveOccurred())

	gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc()
	r := NewReconciler(mgr, gangSchedulingSetupFunc)

	// NOTE(review): the second argument is presumably the number of
	// controller workers — confirm against SetupWithManager.
	Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred())
	Expect(jaxwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred())

	go func() {
		// Failed assertions panic; GinkgoRecover reports them to Ginkgo
		// instead of crashing the test binary from this goroutine.
		defer GinkgoRecover()
		// Keep the result in a goroutine-local variable. Assigning to the
		// enclosing closure's err would share one variable between this
		// goroutine and the setup code.
		startErr := mgr.Start(testCtx)
		Expect(startErr).ToNot(HaveOccurred(), "failed to run manager")
	}()

	// Block until the webhook server is reachable so specs do not race the
	// manager start-up.
	dialer := &net.Dialer{Timeout: time.Second}
	addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort)
	Eventually(func(g Gomega) {
		// This is only a liveness probe of the TLS listener, so certificate
		// verification is skipped.
		conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true})
		g.Expect(err).NotTo(HaveOccurred())
		g.Expect(conn.Close()).NotTo(HaveOccurred())
	}).Should(Succeed())
})

// AfterSuite releases what BeforeSuite created: it cancels the suite context
// and then stops the envtest environment, failing the suite if the stop
// reports an error.
var _ = AfterSuite(func() {
	By("tearing down the test environment")

	// Cancel first so that whatever was started with testCtx is signalled
	// before the environment goes away.
	testCancel()

	Expect(testEnv.Stop()).NotTo(HaveOccurred())
})
Loading

0 comments on commit cf76911

Please sign in to comment.