Skip to content

Commit

Permalink
Merge pull request kubeagi#869 from 0xff-dev/main
Browse files Browse the repository at this point in the history
chore: get the model files based on the modelsource field
  • Loading branch information
bjwswang authored Mar 15, 2024
2 parents 0515f7f + d867df0 commit 30cbb0c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
12 changes: 7 additions & 5 deletions pkg/worker/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
Expand Down Expand Up @@ -79,11 +78,14 @@ func (loader *LoaderOSS) Build(ctx context.Context, model *arcadiav1alpha1.Typed
Object: fmt.Sprintf("model/%s/", model.Name),
})
if err != nil {
if err == datasource.ErrOSSNoSuchObject {
klog.Info("No object was found, So it could pull the model file from other places.")
return nil, nil
}
return nil, err
/*
if err == datasource.ErrOSSNoSuchObject {
klog.Info("No object was found, So it could pull the model file from other places.")
return nil, nil
}
return nil, err
*/
}

var accessKeyID, secretAccessKey string
Expand Down
14 changes: 8 additions & 6 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
if m.Spec.Revision != "" {
extraArgs += fmt.Sprintf(" --revision %s ", m.Spec.Revision)
}
if m.Spec.ModelSource == modelSourceFromHugginfFace {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
if m.Spec.ModelSource == modelSourceFromModelScope {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"})
extraArgs += fmt.Sprintf(" --revision %s ", m.Spec.Revision)
}
}

Expand Down Expand Up @@ -220,12 +222,12 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
if m.Spec.Revision != "" {
extraAgrs += fmt.Sprintf(" --revision %s", m.Spec.Revision)
}
if m.Spec.HuggingFaceRepo != "" {
if m.Spec.ModelSource == modelSourceFromHugginfFace {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
if m.Spec.ModelSource == modelSourceFromModelScope {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"}, corev1.EnvVar{Name: "VLLM_USE_MODELSCOPE", Value: "True"})
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"})
}
}

Expand Down
18 changes: 14 additions & 4 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ const (
WokerCommonSuffix = "-worker"

RDMANodeLabel = "arcadia.kubeagi.k8s.com.cn/rdma"

modelSourceFromLocal = "local"
modelSourceFromHugginfFace = "huggingface"
modelSourceFromModelScope = "modelscope"
)

var (
Expand Down Expand Up @@ -372,13 +376,19 @@ func (podWorker *PodWorker) BeforeStart(ctx context.Context) error {

// Start will build and create worker pod which will host model service
func (podWorker *PodWorker) Start(ctx context.Context) error {
var err error
var (
err error
loader any
)

// define the way to load model
loader, err := podWorker.l.Build(ctx, &arcadiav1alpha1.TypedObjectReference{Namespace: &podWorker.m.Namespace, Name: podWorker.m.Name})
if err != nil {
return fmt.Errorf("failed to build loader with %w", err)
if podWorker.m.Spec.ModelSource == "" || podWorker.m.Spec.ModelSource == modelSourceFromLocal {
loader, err = podWorker.l.Build(ctx, &arcadiav1alpha1.TypedObjectReference{Namespace: &podWorker.m.Namespace, Name: podWorker.m.Name})
if err != nil {
return fmt.Errorf("failed to build loader with %w", err)
}
}

switch podWorker.w.Type() {
case arcadiav1alpha1.WorkerTypeFastchatVLLM:
r, err := NewRunnerFastchatVLLM(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
Expand Down

0 comments on commit 30cbb0c

Please sign in to comment.