Skip to content

Commit

Permalink
feat: support for downloading model files via rdma
Browse files Browse the repository at this point in the history
  • Loading branch information
0xff-dev committed Dec 29, 2023
1 parent c8a8a3a commit 88a331c
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 25 deletions.
6 changes: 4 additions & 2 deletions api/base/v1alpha1/datasource_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ type DatasourceSpec struct {
}

type RDMA struct {
// Path on a model storage server, the usual storage path is /path/ns/mode-name, and the path field is /path/, which must end in /.
// ServerSavePath on a model storage server, the usual storage path is /path/ns/mode-name, and the path field is /path/, which must end in /.
// example: /opt/kubeagi/, /opt/, /
// +kubebuilder:validation:Pattern=(^\/$)|(^\/[a-zA-Z0-9\_.@-]+(\/[a-zA-Z0-9\_.@-]+)*\/$)
Path string `json:"path"`
ServerSavePath string `json:"path"`

NodePaths map[string]string `json:"nodePaths,omitempty"`
}

// OSS defines info for object storage service as datasource
Expand Down
9 changes: 8 additions & 1 deletion api/base/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions config/crd/bases/arcadia.kubeagi.k8s.com.cn_datasources.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,14 @@ spec:
description: RDMA configure RDMA pulls the model file directly from
the remote service to the host node.
properties:
nodePaths:
additionalProperties:
type: string
type: object
path:
description: 'Path on a model storage server, the usual storage
path is /path/ns/mode-name, and the path field is /path/, which
must end in /. example: /opt/kubeagi/, /opt/, /'
description: 'ServerSavePath on a model storage server, the usual
storage path is /path/ns/mode-name, and the path field is /path/,
which must end in /. example: /opt/kubeagi/, /opt/, /'
pattern: (^\/$)|(^\/[a-zA-Z0-9\_.@-]+(\/[a-zA-Z0-9\_.@-]+)*\/$)
type: string
required:
Expand Down
2 changes: 2 additions & 0 deletions controllers/datasource_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ func (r *DatasourceReconciler) Checkdatasource(ctx context.Context, logger logr.
return r.UpdateStatus(ctx, instance, err)
}
info = instance.Spec.OSS.DeepCopy()
case arcadiav1alpha1.DatasourceTypeRDMA:
return r.UpdateStatus(ctx, instance, nil)
default:
ds, err = datasource.NewUnknown(ctx, r.Client)
if err != nil {
Expand Down
26 changes: 23 additions & 3 deletions controllers/worker_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,29 @@ func (r *WorkerReconciler) initialize(ctx context.Context, logger logr.Logger, i

func (r *WorkerReconciler) reconcile(ctx context.Context, logger logr.Logger, worker *arcadiav1alpha1.Worker) (*arcadiav1alpha1.Worker, error) {
logger.V(5).Info("GetSystemDatasource which hosts the worker's model files")
datasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return worker, errors.Wrap(err, "Failed to get system datasource")

m := arcadiav1alpha1.Model{}
ns := worker.Namespace
if worker.Spec.Model.Namespace != nil && *worker.Spec.Model.Namespace != "" {
ns = *worker.Spec.Model.Namespace
}
if err := r.Client.Get(ctx, types.NamespacedName{Name: worker.Spec.Model.Name, Namespace: ns}, &m); err != nil {
return worker, errors.Wrap(err, "failed to get model")
}

var (
datasource = &arcadiav1alpha1.Datasource{}
err error
)
if m.Spec.Source != nil {
if err = r.Client.Get(ctx, types.NamespacedName{Namespace: ns, Name: m.Spec.Source.Name}, datasource); err != nil {
return worker, errors.Wrap(err, "model config datasource, but get it failed.")
}
} else {
datasource, err = config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return worker, errors.Wrap(err, "Failed to get system datasource")
}
}

// Only PodWorker(hosts this worker via a single pod) supported now
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,14 @@ spec:
description: RDMA configure RDMA pulls the model file directly from
the remote service to the host node.
properties:
nodePaths:
additionalProperties:
type: string
type: object
path:
description: 'Path on a model storage server, the usual storage
path is /path/ns/mode-name, and the path field is /path/, which
must end in /. example: /opt/kubeagi/, /opt/, /'
description: 'ServerSavePath on a model storage server, the usual
storage path is /path/ns/mode-name, and the path field is /path/,
which must end in /. example: /opt/kubeagi/, /opt/, /'
pattern: (^\/$)|(^\/[a-zA-Z0-9\_.@-]+(\/[a-zA-Z0-9\_.@-]+)*\/$)
type: string
required:
Expand Down
51 changes: 51 additions & 0 deletions pkg/worker/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,54 @@ type LoaderGit struct{}
// Build is not implemented for LoaderGit yet: it ignores both arguments and
// always returns a nil result together with ErrNotImplementedYet.
func (loader *LoaderGit) Build(ctx context.Context, model *arcadiav1alpha1.TypedObjectReference) (any, error) {
	return nil, ErrNotImplementedYet
}

// Compile-time check that *RDMALoader satisfies ModelLoader.
var _ ModelLoader = (*RDMALoader)(nil)

// RDMALoader is the ModelLoader used for datasources of type RDMA. Its Build
// method describes a container that fetches the model files with floo_get.
type RDMALoader struct {
	// c is the Kubernetes client supplied to NewRDMALoader.
	c client.Client

	// modelName is appended to the datasource's server save path to form the
	// remote directory that is pulled.
	modelName string

	// workerUID is the UID of the worker; UID/worker-name is the local model
	// file storage path.
	workerUID string

	// datasource provides the RDMA endpoint URL and the server save path.
	datasource *arcadiav1alpha1.Datasource
}

// NewRDMALoader returns an RDMALoader that pulls the model named modelName
// for the worker identified by workerUID, using source for the RDMA endpoint
// and server save path. The arguments are stored as given, without validation.
func NewRDMALoader(c client.Client, modelName, workerUID string, source *arcadiav1alpha1.Datasource) *RDMALoader {
	loader := new(RDMALoader)
	loader.c = c
	loader.modelName = modelName
	loader.workerUID = workerUID
	loader.datasource = source
	return loader
}

// Build returns the *corev1.Container ("rdma-loader") that downloads the
// model files by running floo_get through /bin/bash. The model reference
// argument is ignored; everything needed comes from the loader's fields.
// The returned error is always nil.
func (r *RDMALoader) Build(ctx context.Context, _ *arcadiav1alpha1.TypedObjectReference) (any, error) {
	endpointURL := r.datasource.Spec.Endpoint.URL
	// ServerSavePath is concatenated directly with the model name, so it is
	// relied upon to end in "/" (the CRD pattern on the field enforces that).
	remoteDir := r.datasource.Spec.RDMA.ServerSavePath + r.modelName

	// pulls files from the service's 'rdmaEndpoint:/remoteBaseSavePath/modelName' directory to the local 'UID' directory.
	// $TO is expanded by bash from the container environment defined below.
	pullCmd := fmt.Sprintf("floo_get --from=%s --to=$TO --srv=%s --dir=%s", endpointURL, r.workerUID, remoteDir)

	// TO is populated with the IP of the node the pod is running on.
	hostIPEnv := corev1.EnvVar{
		Name: "TO",
		ValueFrom: &corev1.EnvVarSource{
			FieldRef: &corev1.ObjectFieldSelector{FieldPath: "status.hostIP"},
		},
	}

	// The pod is expected to provide a volume named "tmp".
	tmpMount := corev1.VolumeMount{Name: "tmp", MountPath: "/tmp"}

	loaderContainer := &corev1.Container{
		Name:            "rdma-loader",
		Image:           "wetman2023/floo:23.12",
		ImagePullPolicy: "IfNotPresent",
		Command:         []string{"/bin/bash", "-c", pullCmd},
		Env:             []corev1.EnvVar{hostIPEnv},
		VolumeMounts:    []corev1.VolumeMount{tmpMount},
	}
	return loaderContainer, nil
}
83 changes: 70 additions & 13 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ import (

const (
	// WokerCommonSuffix is the "-worker" suffix string. The misspelled
	// identifier is exported, so it is kept as-is for compatibility.
	// NOTE(review): presumably appended to names of worker-owned resources — confirm at call sites.
	WokerCommonSuffix = "-worker"

	// RDMASavePath is a fixed save path for model files fetched via RDMA.
	// TODO: Currently all nodes write the same savePath
	RDMASavePath = "/raid10x12/rtranmodel"

	// RDMALabel is the node label key that must exist on a node (with any
	// value) for a worker pod using a hostPath model volume to be scheduled there.
	RDMALabel = "arcadia.kubeagi.k8s.com.cn/rdma"
)

var (
Expand Down Expand Up @@ -146,11 +151,24 @@ func NewPodWorker(ctx context.Context, c client.Client, s *runtime.Scheme, w *ar
podWorker.m = m

// default fields in a worker
storage := corev1.Volume{
Name: "models",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
var storage corev1.Volume
if d.Spec.Type() != arcadiav1alpha1.DatasourceTypeRDMA {
storage = corev1.Volume{
Name: "models",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
}
} else {
storage = corev1.Volume{
Name: "models",
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
// /rdma/abc/uid -> /data/models
Path: fmt.Sprintf("%s/%s", d.Spec.RDMA.ServerSavePath, w.GetUID()),
},
},
}
}

service := corev1.Service{
Expand Down Expand Up @@ -194,18 +212,21 @@ func NewPodWorker(ctx context.Context, c client.Client, s *runtime.Scheme, w *ar
podWorker.service = service
podWorker.deployment = deployment

// init loader(Only oss supported yet)
endpoint := d.Spec.Endpoint.DeepCopy()
if endpoint.AuthSecret != nil && endpoint.AuthSecret.Namespace == nil {
endpoint.AuthSecret.WithNameSpace(d.Namespace)
}
switch d.Spec.Type() {
case arcadiav1alpha1.DatasourceTypeOSS:
// init the OSS loader
endpoint := d.Spec.Endpoint.DeepCopy()
if endpoint.AuthSecret != nil && endpoint.AuthSecret.Namespace == nil {
endpoint.AuthSecret.WithNameSpace(d.Namespace)
}
l, err := NewLoaderOSS(ctx, c, endpoint)
if err != nil {
return nil, fmt.Errorf("failed to new a loader with %w", err)
}
podWorker.l = l
case arcadiav1alpha1.DatasourceTypeRDMA:
l := NewRDMALoader(c, w.Spec.Model.Name, string(w.GetUID()), d)
podWorker.l = l
default:
return nil, fmt.Errorf("datasource %s with type %s not supported in worker", d.Name, d.Spec.Type())
}
Expand Down Expand Up @@ -244,9 +265,8 @@ func (podWorker *PodWorker) Model() *arcadiav1alpha1.Model {
// Now we have a pvc (if configured), service, LLM (if an llm model), Embedder (if an embedding model)
func (podWorker *PodWorker) BeforeStart(ctx context.Context) error {
var err error

// prepare pvc
if podWorker.Worker().Spec.Storage != nil {
// If the local directory is mounted, there is no need to create the pvc
if podWorker.Worker().Spec.Storage != nil && podWorker.storage.HostPath == nil {
pvc := &corev1.PersistentVolumeClaim{
ObjectMeta: metav1.ObjectMeta{
Namespace: podWorker.Namespace,
Expand Down Expand Up @@ -372,6 +392,16 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
}
conRunner, _ := runner.(*corev1.Container)

if podWorker.storage.HostPath != nil {
conRunner.Lifecycle = &corev1.Lifecycle{
PreStop: &corev1.LifecycleHandler{
Exec: &corev1.ExecAction{
Command: []string{"/bin/bash", "-c", fmt.Sprintf("rm -rf /data/models/%s", podWorker.Model().Name)},
},
},
}
}

// initialize deployment
desiredDep := podWorker.deployment.DeepCopy()
// configure pod template
Expand All @@ -389,6 +419,33 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
Volumes: []corev1.Volume{podWorker.storage},
},
}
if podWorker.storage.HostPath != nil {
podSpecTempalte.Spec.Affinity = &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
{
Operator: corev1.NodeSelectorOpExists,
Key: RDMALabel,
},
},
},
},
},
},
}
podSpecTempalte.Spec.Volumes = append(podSpecTempalte.Spec.Volumes, corev1.Volume{
Name: "tmp",
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
Path: "/tmp",
},
},
})
}

desiredDep.Spec.Template = podSpecTempalte
err = controllerutil.SetControllerReference(podWorker.Worker(), desiredDep, podWorker.s)
if err != nil {
Expand Down

0 comments on commit 88a331c

Please sign in to comment.