diff --git a/PROJECT b/PROJECT index 8eb19d434..b7cd7f796 100644 --- a/PROJECT +++ b/PROJECT @@ -34,4 +34,7 @@ resources: kind: Prompt path: github.com/kubeagi/arcadia/api/v1alpha1 version: v1alpha1 + webhooks: + validation: true + webhookVersion: v1 version: "3" diff --git a/api/v1alpha1/llm.go b/api/v1alpha1/llm.go index 045b716d3..69f5f4eaf 100644 --- a/api/v1alpha1/llm.go +++ b/api/v1alpha1/llm.go @@ -26,13 +26,6 @@ var ( ErrMissingAPIKey = errors.New("missing apikey in auth info") ) -type LLMType string - -const ( - OpenAI LLMType = "openai" - ZhiPuAI LLMType = "zhipuai" -) - func (o *AuthInfo) FromSecret(secret corev1.Secret) error { o.APIKey = string(secret.Data["apiKey"]) if o.APIKey == "" { diff --git a/api/v1alpha1/llm_types.go b/api/v1alpha1/llm_types.go index d5cf14484..9c109e928 100644 --- a/api/v1alpha1/llm_types.go +++ b/api/v1alpha1/llm_types.go @@ -17,6 +17,7 @@ limitations under the License. package v1alpha1 import ( + "github.com/kubeagi/arcadia/pkg/llms" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -24,8 +25,8 @@ import ( type LLMSpec struct { DisplayName string `json:"displayName,omitempty"` // Type defines the type of llm - Type LLMType `json:"type"` - // URL keeps the URL of the llm service(required) + Type llms.LLMType `json:"type"` + // URL keeps the URL of the llm service(Must required) URL string `json:"url"` // Auth keeps the authentication credentials when access llm // keeps in k8s secret diff --git a/api/v1alpha1/prompt_types.go b/api/v1alpha1/prompt_types.go index 68a04da44..1a294a277 100644 --- a/api/v1alpha1/prompt_types.go +++ b/api/v1alpha1/prompt_types.go @@ -17,19 +17,16 @@ limitations under the License. package v1alpha1 import ( + llmzhipuai "github.com/kubeagi/arcadia/pkg/llms/zhipuai" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN! -// NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized. 
-
 // PromptSpec defines the desired state of Prompt
 type PromptSpec struct {
-	// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
-	// Important: Run "make" to regenerate code after modifying this file
-
-	// Foo is an example field of Prompt. Edit prompt_types.go to remove/update
-	Foo string `json:"foo,omitempty"`
+	// LLM service name(CRD LLM)
+	LLM string `json:"llm"`
+	// ZhiPuAIParams defines the params of ZhiPuAI
+	ZhiPuAIParams *llmzhipuai.ModelParams `json:"zhiPuAIParams,omitempty"`
 }
 
 // PromptStatus defines the observed state of Prompt
diff --git a/api/v1alpha1/prompt_webhook.go b/api/v1alpha1/prompt_webhook.go
new file mode 100644
index 000000000..9b3641672
--- /dev/null
+++ b/api/v1alpha1/prompt_webhook.go
@@ -0,0 +1,121 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package v1alpha1
+
+import (
+	"context"
+	"fmt"
+
+	llmzhipuai "github.com/kubeagi/arcadia/pkg/llms/zhipuai"
+	"k8s.io/apimachinery/pkg/runtime"
+	ctrl "sigs.k8s.io/controller-runtime"
+	logf "sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/controller-runtime/pkg/webhook"
+)
+
+// promptlog is the logger used by the Prompt admission webhooks.
+var promptlog = logf.Log.WithName("prompt-resource")
+
+// SetupWebhookWithManager registers the Prompt defaulting and validating
+// webhooks with mgr. The receiver is only the type prototype and the
+// CustomDefaulter/CustomValidator implementation; the object under admission
+// is passed to those methods as an argument.
+func (p *Prompt) SetupWebhookWithManager(mgr ctrl.Manager) error {
+	return ctrl.NewWebhookManagedBy(mgr).
+		For(p).
+		WithDefaulter(p).
+		WithValidator(p).
+		Complete()
+}
+
+//+kubebuilder:webhook:path=/mutate-arcadia-kubeagi-k8s-com-cn-v1alpha1-prompt,mutating=true,failurePolicy=fail,sideEffects=None,groups=arcadia.kubeagi.k8s.com.cn,resources=prompts,verbs=create;update,versions=v1alpha1,name=mprompt.kb.io,admissionReviewVersions=v1
+
+var _ webhook.CustomDefaulter = &Prompt{}
+
+// promptFromObject returns obj as a *Prompt, or an error when the admission
+// request carried an object of another type.
+func promptFromObject(obj runtime.Object) (*Prompt, error) {
+	prompt, ok := obj.(*Prompt)
+	if !ok {
+		return nil, fmt.Errorf("expected a Prompt but got %T", obj)
+	}
+	return prompt, nil
+}
+
+// Default implements webhook.CustomDefaulter. It works on obj, the Prompt
+// under admission, because the receiver is just the empty instance given to
+// WithDefaulter. When ZhiPuAIParams is set it is replaced by the result of
+// merging it with llmzhipuai.DefaultModelParams().
+func (p *Prompt) Default(ctx context.Context, obj runtime.Object) error {
+	prompt, err := promptFromObject(obj)
+	if err != nil {
+		return err
+	}
+	promptlog.Info("default", "name", prompt.Name)
+
+	// Leave the spec untouched when no ZhiPuAI params were supplied.
+	if prompt.Spec.ZhiPuAIParams != nil {
+		merged := llmzhipuai.MergeParams(*prompt.Spec.ZhiPuAIParams, llmzhipuai.DefaultModelParams())
+		prompt.Spec.ZhiPuAIParams = &merged
+	}
+
+	return nil
+}
+
+//+kubebuilder:webhook:path=/validate-arcadia-kubeagi-k8s-com-cn-v1alpha1-prompt,mutating=false,failurePolicy=fail,sideEffects=None,groups=arcadia.kubeagi.k8s.com.cn,resources=prompts,verbs=create;update;delete,versions=v1alpha1,name=vprompt.kb.io,admissionReviewVersions=v1
+
+var _ webhook.CustomValidator = &Prompt{}
+
+// validatePrompt logs op together with the name of the Prompt in obj and
+// validates its ZhiPuAIParams when they are set.
+func validatePrompt(op string, obj runtime.Object) error {
+	prompt, err := promptFromObject(obj)
+	if err != nil {
+		return err
+	}
+	promptlog.Info(op, "name", prompt.Name)
+
+	if prompt.Spec.ZhiPuAIParams == nil {
+		return nil
+	}
+	if err := llmzhipuai.ValidateModelParams(*prompt.Spec.ZhiPuAIParams); err != nil {
+		promptlog.Error(err, "validate model params")
+		return err
+	}
+	return nil
+}
+
+// ValidateCreate implements webhook.CustomValidator; it validates obj, the
+// Prompt being created.
+func (p *Prompt) ValidateCreate(ctx context.Context, obj runtime.Object) error {
+	return validatePrompt("validate create", obj)
+}
+
+// ValidateUpdate implements webhook.CustomValidator; it validates newObj,
+// the updated Prompt. oldObj is not inspected.
+func (p *Prompt) ValidateUpdate(ctx context.Context, oldObj runtime.Object, newObj runtime.Object) error {
+	return validatePrompt("validate update", newObj)
+}
+
+// ValidateDelete implements webhook.CustomValidator; deletion is always
+// allowed and only logged.
+func (p *Prompt) ValidateDelete(ctx context.Context, obj runtime.Object) error {
+	if prompt, ok := obj.(*Prompt); ok {
+		promptlog.Info("validate delete", "name", prompt.Name)
+	}
+	return nil
+}
diff --git a/api/v1alpha1/webhook_suite_test.go b/api/v1alpha1/webhook_suite_test.go
new file mode 100644
index 000000000..a5725cd76
--- /dev/null
+++ b/api/v1alpha1/webhook_suite_test.go
@@ -0,0 +1,135 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package v1alpha1
+
+import (
+	"context"
+	"crypto/tls"
+	"fmt"
+	"net"
+	"path/filepath"
+	"testing"
+	"time"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+
+	admissionv1beta1 "k8s.io/api/admission/v1beta1"
+	//+kubebuilder:scaffold:imports
+	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/client-go/rest"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/envtest"
+	"sigs.k8s.io/controller-runtime/pkg/envtest/printer"
+	logf "sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/controller-runtime/pkg/log/zap"
+)
+
+// These tests use Ginkgo (BDD-style Go testing framework). Refer to
+// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
+ +var cfg *rest.Config +var k8sClient client.Client +var testEnv *envtest.Environment +var ctx context.Context +var cancel context.CancelFunc + +func TestAPIs(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecsWithDefaultAndCustomReporters(t, + "Webhook Suite", + []Reporter{printer.NewlineReporter{}}) +} + +var _ = BeforeSuite(func() { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + ctx, cancel = context.WithCancel(context.TODO()) + + By("bootstrapping test environment") + testEnv = &envtest.Environment{ + CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, + ErrorIfCRDPathMissing: false, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "config", "webhook")}, + }, + } + + var err error + // cfg is defined in this file globally. + cfg, err = testEnv.Start() + Expect(err).NotTo(HaveOccurred()) + Expect(cfg).NotTo(BeNil()) + + scheme := runtime.NewScheme() + err = AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + + err = admissionv1beta1.AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:scheme + + k8sClient, err = client.New(cfg, client.Options{Scheme: scheme}) + Expect(err).NotTo(HaveOccurred()) + Expect(k8sClient).NotTo(BeNil()) + + // start webhook server using Manager + webhookInstallOptions := &testEnv.WebhookInstallOptions + mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Scheme: scheme, + Host: webhookInstallOptions.LocalServingHost, + Port: webhookInstallOptions.LocalServingPort, + CertDir: webhookInstallOptions.LocalServingCertDir, + LeaderElection: false, + MetricsBindAddress: "0", + }) + Expect(err).NotTo(HaveOccurred()) + + err = (&Prompt{}).SetupWebhookWithManager(mgr) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:webhook + + go func() { + defer GinkgoRecover() + err = mgr.Start(ctx) + Expect(err).NotTo(HaveOccurred()) + }() + + // wait for the webhook server to get ready + 
dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", webhookInstallOptions.LocalServingHost, webhookInstallOptions.LocalServingPort) + Eventually(func() error { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return err + } + conn.Close() + return nil + }).Should(Succeed()) + +}, 60) + +var _ = AfterSuite(func() { + cancel() + By("tearing down the test environment") + err := testEnv.Stop() + Expect(err).NotTo(HaveOccurred()) +}) diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index de25254b1..0e1437b9e 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -22,7 +22,8 @@ limitations under the License. package v1alpha1 import ( - runtime "k8s.io/apimachinery/pkg/runtime" + "github.com/kubeagi/arcadia/pkg/llms/zhipuai" + "k8s.io/apimachinery/pkg/runtime" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -263,7 +264,7 @@ func (in *Prompt) DeepCopyInto(out *Prompt) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - out.Spec = in.Spec + in.Spec.DeepCopyInto(&out.Spec) out.Status = in.Status } @@ -320,6 +321,11 @@ func (in *PromptList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *PromptSpec) DeepCopyInto(out *PromptSpec) { *out = *in + if in.ZhiPuAIParams != nil { + in, out := &in.ZhiPuAIParams, &out.ZhiPuAIParams + *out = new(zhipuai.ModelParams) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PromptSpec. 
diff --git a/assets/arch.drawio b/assets/arch.drawio index b203a5068..6e54f3036 100644 --- a/assets/arch.drawio +++ b/assets/arch.drawio @@ -1 +1 @@ -7V1bc6M4Fv41rtp9CIW4CHiM3d2Z3klvZSe1MztPUzKWbTYYeUHu2P3rVwIJA5JvCRCoGU9NtRFXn+9cPx2RiT3b7B9StF1/IwscTyxzsZ/YnyaWBRwLsn/4yKEY8V2rGFil0UIcdBx4jn5gMWiK0V20wFntQEpITKNtfTAkSYJDWhtDaUpe64ctSVy/6xatsDLwHKJYHf0tWtC1GA0s87jjJxyt1vLWlin2bJA8Wgxka7Qgr5XL2p8n9iwlhBbfNvsZjrn0pGCK876c2Fs+WYoTes0JYYDmy+0fK7x5AdEfm3/+a0qzOysoLvMdxTvxk8XT0oOUQUp2yQLzq4CJPX1dRxQ/b1HI974y1NnYmm5isTt7wTRciw31EcVTf8cpxfumLJgWYbLBND2wQ8RexxXiEwrkAWCIR3w9AgJ9cdS6goXriUEklGBVXv0oJvZFSOoGqbnOLVIzL0ttGcXxjMQkzc+1ly7/j0uTpuQFV/bA/MPPIAmtjBefDuQPHasmf9+ChmMeP0DBwrZ9FQvLcwy/IzSANLnWlHjAcPh1cwCm1h4sqLEHvytzCG4yh9vlv0DYX4Za+Yc+ni87kHPpO6TbcSzDU1UdMuG7Gs/jG25Xsm5b19uWnNV02J5OQYHOYQO3Kw0FviK1n3dzfP/wVREe+5m0LqEUZ9EPNM8P4L4c7SjJikyF70ZxtErY95DJDDO9nHJZRSyFuBc7NtFiwU+ebkmU0PynudOJ+6mh6AlJsKLjYrAeTLizeRZPDOS2eCCnixgcWFKfz0ZhHaZ2Z5A6joIdXrDkTWySlK7JiiQo/nwcnR5HHwnZChH9F1N6EALk4NbhZ3JLD//hojdcufm7QCLf+LSvbR3E1nUoZGSXhvhchiYkT1G6wvSKpIQL4SyqKY4Rjb7XM9wO8FFtzjYAz9VTjCjm92T/5yZoRklGUcLkoEP0Ec1ZVVHD5HqbaxqvaoMXcDqvfIoRlaWIuOdEpOUnjevONDzXqpvXnbCtq5ESV3/iP65yCFkuM6YzTSjLh3g7uvKBK+CiVXTHTEpBMI5ZsYYvByKUbYsKbhntuatTsgAX+wtHlwX41tzOs7DWsy2W7Dag8VwgfWHF8zmlg6z6vuNo69bltJ1ytZ5AOaApujKnqgoOejrJdVa5qeXuQ0TXu/kZ4d1evn2Aprq+Ku5AFTcA/YZojZLCmKdXTOBwxb/k/r8YY/coh8eYlrWNqQdhDdDAvFR6O32i63kKusDgv3mXrUuYUwnohlOD2cSaLRBFeVCahWyo7xSulsAd87mOUzjJXF5K4ZzWU7hGYiBrNKdeo0G7TOflRYofJc5rP3/w/L6RN0wT1tG382LpHP751hNOI/ZzuQ/5EKWQbrxrpXCCRuFuMwMJKp/6BYvnVhTkPk3RoXKYcK6nVdGu39W3XKGKR40rrtmu/qmJgMWdV0xWUaJVzQFWIdKIWilCTC+oh5vhFyHACXr3Ih9CAVztKvxOXMWtNu1Dpx5efNCHTZf0oT4jMW9JQQZo7aW2t2Pukp4RGA3e2Mv8+QjuvzOGQxM9sqNxlLBMXU4IcwQWKFuXebrELcf4idUMNCIcvzmhlGw0wFLuJlT8qwVghYG1GgytqeM2sjXa8p2b/YpPnBvoNXMMpje5K/ga8ufhalR8qx+1y/K7t87bW/VArKvWdbNKsKv6wlEppp92KxaeV0ukYQpHVrB75gALdkdLj8Ro9PSIbw9Q2q6tSPvznmKWmLBBJve/PX55/rsi+REQIc35qfYnwRVa1nU0aFq9oqnazj922wMT+WPL9rP0Qxxqp8Pnvsv7ZTrwVm6DirIvUVGeRvZWV2w4MFX2u9tCoNy4jkjC+4
iWp7HvlbPY1vEkvtFy8SBz8cGQT8B0GlMqgcI+tUQuBI2cht3aqVciV5ziggunsKu+/xxown5qJNVN2UYu+jDEWaY1oyEWQ9Li318MmYble3YNDLH1VkU/aE/okAex+mZTG+7Pa4dI/yh+4xq/cS1BfjNX4runbnXKd6inXOFuzDe4qMY5nuf34qI0PcmjdFFWa/SsadguHBs7KxOPCoyriCroja2SdD3DBQ1bsm3Dd89lwzrixOpsYtZ9QzSow/ABqbH8Xpxl9ZMbe1cGHq+b3PjWYAGDpuY58KwPh0AJZLed4HugB48P1VYC23AK6n4eRzl7XzQKoq2mw2yY/l+aYRvu3wV2PUN9p/ff164iQwrw6hfoMDaogL/s5nj+TqKkdaJJsbeLnl7Le3TWYCUrlV7bKLy6g7ehfcnF69ooSr9f0iC/T44ESadu37uWEgHA7cTvK4ZoN7ogAgsYQOXXLtAj7WXgUDHPwhvHKJlU+7ZH5I5LW2ln+tSVLmws86eemo7PyGZLEpw/wDeUvuB2s/Ny2dzJxXEnl9O1PrkJG2mNCw1H8d2g1/7JoC/frXW0Zh+OFsDgSk9r9eRolXV9tmnA0472BMmjXraxLNODXVHadlOTm7zxxTO8JvPTVVKvW/+jCyJjW/wTtBlILF82Vg43dAQqkg8sEDAEZw9kpRJyY6N0TN+Q0bG0kcCwz85vAqBf+mu5ht1RwGDVh4pDihYRaTdsLzHUTy8vvGBudjG97Nv1Omtw08tquZruEsrqmbFLHvoDl7zqesJ1SjbofYJvvT/Cq7PzQ6OEy8amc0ux5ABP12uyhf/bEbnjrugsumcHAGe7L04T++WF+NL66bSyrKu4Yv0uf632uvzeBFBZfeErGtTvcnug9t/SFCXZkqQbTRvu6Pxgg4Bh0h+SHwSWIn2WRq/CNdIsVxmZ7JXmsqHJXm1EPiPyUb9eyGmUx6w4McyzYSzQQAHcDl/2BFRusigj77dbtRT5EwYSxweNt7VcRlFnUN2FEktNRhh4QAFvZHWk04wh1lvE3l0OaKsRnIldZSHHLnbfM3z1FR09y1pl25msz8WNccraHZiKq+Xi1ySiUfG+oadoi/naMAWF7DXasHSK++ZwHcWLR3QgO/6wGUXhi9yarkka/WCRGOkAEb79eMwzP7cadPCTFH1z6BuSLHwx+ogyuXItJHGMtll0DFjakJI/6JSki3x+FZhyiCtTlKzY2F2+MK0a/bThThPV6rlHAfRlZe02yEEXGJI4lVHuInVhyRXOvbxqBthqH2MYE536icWAGcVciNvKTDkfqsycT/P1g8/HdEWFq01YLhjZ1Wg1mopkL0uVTjV1fsLrChpHzUDyWQh+6p8NnMAfGjhqnrLA25gc/nzYAB8ODBxXBSdftJhFJPmr/JqobAbD51Jc6pfIg6rva4NS//jkFHqBMo8HTWMYc3dQJfCKjse72aP6Etp3Sb7/tyX7ttVsig882whuJ/G603p17tTIP2PX+mZJBkwwrJpMsw7kjMxvp08/4u3gzdVUwNa9D9l1deSoCRg8nQlb5aofH7+xgWecfo80PTYjiNBdv07gSjR1THeHUVotHguae8byLBK33Hgw4E5N8+M7NYGveqzO3zMGy5cTFm32ZnnAB7ysEMgF0pdb5mHrS6XeiZ3qEW1D33o7wF7DUvXaWEVkAq8+MSSgGm7nYal31XnveHOHtlHLbYf9JxEedI3mQm6o/bMKgSGTuFrG1mUeoVkt0GrSNqA5b9f3mlQyYGZXVoi1v/eiCTtOZ2HHU13XbI3oA0/o3qP7rRfepkaCgVaC/ZYentor8MtvP/86MOkx+25wsAORntr1+msU7pKh9V7CoWqfOhnZBuXQOnnjuIYzRPnJ1uSK/L6gjP12xJOP+6ev51KAkYckDSTOiZCkITLfUAmxzeOfoizSt+Nf9LQ//x8= \ No newline at end of file  \ No newline at end 
of file diff --git a/assets/arch.drawio.png b/assets/arch.drawio.png index a8621dfb4..b5367e1a6 100644 Binary files a/assets/arch.drawio.png and b/assets/arch.drawio.png differ diff --git a/config/certmanager/certificate.yaml b/config/certmanager/certificate.yaml new file mode 100644 index 000000000..52d866183 --- /dev/null +++ b/config/certmanager/certificate.yaml @@ -0,0 +1,25 @@ +# The following manifests contain a self-signed issuer CR and a certificate CR. +# More document can be found at https://docs.cert-manager.io +# WARNING: Targets CertManager v1.0. Check https://cert-manager.io/docs/installation/upgrading/ for breaking changes. +apiVersion: cert-manager.io/v1 +kind: Issuer +metadata: + name: selfsigned-issuer + namespace: system +spec: + selfSigned: {} +--- +apiVersion: cert-manager.io/v1 +kind: Certificate +metadata: + name: serving-cert # this name should match the one appeared in kustomizeconfig.yaml + namespace: system +spec: + # $(SERVICE_NAME) and $(SERVICE_NAMESPACE) will be substituted by kustomize + dnsNames: + - $(SERVICE_NAME).$(SERVICE_NAMESPACE).svc + - $(SERVICE_NAME).$(SERVICE_NAMESPACE).svc.cluster.local + issuerRef: + kind: Issuer + name: selfsigned-issuer + secretName: webhook-server-cert # this secret will not be prefixed, since it's not managed by kustomize diff --git a/config/certmanager/kustomization.yaml b/config/certmanager/kustomization.yaml new file mode 100644 index 000000000..bebea5a59 --- /dev/null +++ b/config/certmanager/kustomization.yaml @@ -0,0 +1,5 @@ +resources: +- certificate.yaml + +configurations: +- kustomizeconfig.yaml diff --git a/config/certmanager/kustomizeconfig.yaml b/config/certmanager/kustomizeconfig.yaml new file mode 100644 index 000000000..90d7c313c --- /dev/null +++ b/config/certmanager/kustomizeconfig.yaml @@ -0,0 +1,16 @@ +# This configuration is for teaching kustomize how to update name ref and var substitution +nameReference: +- kind: Issuer + group: cert-manager.io + fieldSpecs: + - kind: 
Certificate + group: cert-manager.io + path: spec/issuerRef/name + +varReference: +- kind: Certificate + group: cert-manager.io + path: spec/commonName +- kind: Certificate + group: cert-manager.io + path: spec/dnsNames diff --git a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_prompts.yaml b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_prompts.yaml index 103ffa99f..f3369cb8c 100644 --- a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_prompts.yaml +++ b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_prompts.yaml @@ -35,10 +35,44 @@ spec: spec: description: PromptSpec defines the desired state of Prompt properties: - foo: - description: Foo is an example field of Prompt. Edit prompt_types.go - to remove/update + llm: + description: LLM service name(CRD LLM) type: string + zhiPuAIParams: + description: ZhiPuAIParams defines the params of ZhiPuAI + properties: + incremental: + description: Incremental is only Used for SSE Invoke + type: boolean + method: + description: Method used for this prompt call + type: string + model: + description: Model used for this prompt call + type: string + prompt: + description: Contents + items: + description: Prompt defines the content of ZhiPuAI Prompt Call + properties: + content: + type: string + role: + type: string + type: object + type: array + task_id: + description: TaskID is used for getting result of AsyncInvoke + type: string + temperature: + description: Temperature is float in zhipuai + top_p: + description: TopP is float in zhipuai + required: + - prompt + type: object + required: + - llm type: object status: description: PromptStatus defines the observed state of Prompt diff --git a/config/default/manager_webhook_patch.yaml b/config/default/manager_webhook_patch.yaml new file mode 100644 index 000000000..738de350b --- /dev/null +++ b/config/default/manager_webhook_patch.yaml @@ -0,0 +1,23 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: controller-manager + namespace: system +spec: + template: + spec: + containers: 
+ - name: manager + ports: + - containerPort: 9443 + name: webhook-server + protocol: TCP + volumeMounts: + - mountPath: /tmp/k8s-webhook-server/serving-certs + name: cert + readOnly: true + volumes: + - name: cert + secret: + defaultMode: 420 + secretName: webhook-server-cert diff --git a/config/default/webhookcainjection_patch.yaml b/config/default/webhookcainjection_patch.yaml new file mode 100644 index 000000000..02ab515d4 --- /dev/null +++ b/config/default/webhookcainjection_patch.yaml @@ -0,0 +1,15 @@ +# This patch add annotation to admission webhook config and +# the variables $(CERTIFICATE_NAMESPACE) and $(CERTIFICATE_NAME) will be substituted by kustomize. +apiVersion: admissionregistration.k8s.io/v1 +kind: MutatingWebhookConfiguration +metadata: + name: mutating-webhook-configuration + annotations: + cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME) +--- +apiVersion: admissionregistration.k8s.io/v1 +kind: ValidatingWebhookConfiguration +metadata: + name: validating-webhook-configuration + annotations: + cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME) diff --git a/config/webhook/kustomization.yaml b/config/webhook/kustomization.yaml new file mode 100644 index 000000000..9cf26134e --- /dev/null +++ b/config/webhook/kustomization.yaml @@ -0,0 +1,6 @@ +resources: +- manifests.yaml +- service.yaml + +configurations: +- kustomizeconfig.yaml diff --git a/config/webhook/kustomizeconfig.yaml b/config/webhook/kustomizeconfig.yaml new file mode 100644 index 000000000..25e21e3c9 --- /dev/null +++ b/config/webhook/kustomizeconfig.yaml @@ -0,0 +1,25 @@ +# the following config is for teaching kustomize where to look at when substituting vars. +# It requires kustomize v2.1.0 or newer to work properly. 
+nameReference:
+- kind: Service
+  version: v1
+  fieldSpecs:
+  - kind: MutatingWebhookConfiguration
+    group: admissionregistration.k8s.io
+    path: webhooks/clientConfig/service/name
+  - kind: ValidatingWebhookConfiguration
+    group: admissionregistration.k8s.io
+    path: webhooks/clientConfig/service/name
+
+namespace:
+- kind: MutatingWebhookConfiguration
+  group: admissionregistration.k8s.io
+  path: webhooks/clientConfig/service/namespace
+  create: true
+- kind: ValidatingWebhookConfiguration
+  group: admissionregistration.k8s.io
+  path: webhooks/clientConfig/service/namespace
+  create: true
+
+varReference:
+- path: metadata/annotations
diff --git a/config/webhook/manifests.yaml b/config/webhook/manifests.yaml
new file mode 100644
index 000000000..b93950c16
--- /dev/null
+++ b/config/webhook/manifests.yaml
@@ -0,0 +1,55 @@
+---
+apiVersion: admissionregistration.k8s.io/v1
+kind: MutatingWebhookConfiguration
+metadata:
+  creationTimestamp: null
+  name: mutating-webhook-configuration
+webhooks:
+- admissionReviewVersions:
+  - v1
+  clientConfig:
+    service:
+      name: webhook-service
+      namespace: system
+      path: /mutate-arcadia-kubeagi-k8s-com-cn-v1alpha1-prompt
+  failurePolicy: Fail
+  name: mprompt.kb.io
+  rules:
+  - apiGroups:
+    - arcadia.kubeagi.k8s.com.cn
+    apiVersions:
+    - v1alpha1
+    operations:
+    - CREATE
+    - UPDATE
+    resources:
+    - prompts
+  sideEffects: None
+---
+apiVersion: admissionregistration.k8s.io/v1
+kind: ValidatingWebhookConfiguration
+metadata:
+  creationTimestamp: null
+  name: validating-webhook-configuration
+webhooks:
+- admissionReviewVersions:
+  - v1
+  clientConfig:
+    service:
+      name: webhook-service
+      namespace: system
+      path: /validate-arcadia-kubeagi-k8s-com-cn-v1alpha1-prompt
+  failurePolicy: Fail
+  name: vprompt.kb.io
+  rules:
+  - apiGroups:
+    - arcadia.kubeagi.k8s.com.cn
+    apiVersions:
+    - v1alpha1
+    operations:
+    - CREATE
+    - UPDATE
+    - DELETE
+    resources:
+    - prompts
+  sideEffects: None
diff --git a/config/webhook/service.yaml
b/config/webhook/service.yaml new file mode 100644 index 000000000..3f638bd9c --- /dev/null +++ b/config/webhook/service.yaml @@ -0,0 +1,13 @@ + +apiVersion: v1 +kind: Service +metadata: + name: webhook-service + namespace: system +spec: + ports: + - port: 443 + protocol: TCP + targetPort: 9443 + selector: + control-plane: controller-manager diff --git a/controllers/prompt_controller.go b/controllers/prompt_controller.go index 1a9562d04..17e92a79a 100644 --- a/controllers/prompt_controller.go +++ b/controllers/prompt_controller.go @@ -49,8 +49,6 @@ type PromptReconciler struct { func (r *PromptReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { _ = log.FromContext(ctx) - // TODO(user): your logic here - return ctrl.Result{}, nil } diff --git a/examples/zhipuai/main.go b/examples/zhipuai/main.go new file mode 100644 index 000000000..a75cb8dee --- /dev/null +++ b/examples/zhipuai/main.go @@ -0,0 +1,46 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package main
+
+import (
+	"fmt"
+	"os"
+)
+
+func main() {
+	if len(os.Args) < 2 { // os.Args[0] is the program name; the API key is os.Args[1]
+		panic("api key is empty")
+	}
+	apiKey := os.Args[1]
+	resp, err := sampleInvoke(apiKey)
+	if err != nil {
+		panic(err)
+	}
+	fmt.Printf("SampleInvoke: \n %+v\n", resp)
+
+	resp, err = sampleInvokeAsync(apiKey)
+	if err != nil {
+		panic(err)
+	}
+	fmt.Printf("sampleInvokeAsync: \n %+v\n", resp)
+	// taskID := "76997570932704279317856632766629711813"
+	// resp, err = getInvokeAsyncResult(apiKey, taskID)
+	// if err != nil {
+	// 	panic(err)
+	// }
+	// fmt.Printf("getInvokeAsyncResult: \n %+v\n", resp)
+}
diff --git a/examples/zhipuai/samples.go b/examples/zhipuai/samples.go
new file mode 100644
index 000000000..967bb2b26
--- /dev/null
+++ b/examples/zhipuai/samples.go
@@ -0,0 +1,46 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package main + +import "github.com/kubeagi/arcadia/pkg/llms/zhipuai" + +func sampleInvoke(apiKey string) (map[string]interface{}, error) { + client := zhipuai.NewZhiPuAI(apiKey) + params := zhipuai.DefaultModelParams() + params.Prompt = []zhipuai.Prompt{ + {Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."}, + } + return client.Invoke(params) +} + +func sampleInvokeAsync(apiKey string) (map[string]interface{}, error) { + client := zhipuai.NewZhiPuAI(apiKey) + params := zhipuai.DefaultModelParams() + params.Method = zhipuai.ZhiPuAIAsyncInvoke + params.Prompt = []zhipuai.Prompt{ + {Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."}, + } + return client.AsyncInvoke(params) +} + +func getInvokeAsyncResult(apiKey string, taskID string) (map[string]interface{}, error) { + client := zhipuai.NewZhiPuAI(apiKey) + params := zhipuai.DefaultModelParams() + params.Method = zhipuai.ZhiPuAIAsyncGet + params.TaskID = taskID + return client.Get(params) +} diff --git a/go.mod b/go.mod index 43820737a..9979e9ef7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kubeagi/arcadia go 1.20 require ( + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.18.1 k8s.io/api v0.24.2 diff --git a/go.sum b/go.sum index c27022bfb..77285e76c 100644 --- a/go.sum +++ b/go.sum @@ -179,6 +179,8 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog 
v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/main.go b/main.go index 37ab1c1d4..6de900844 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,8 @@ package main import ( "flag" "os" + "path/filepath" + "strconv" // Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.) // to ensure that exec-entrypoint and run can make use of them. @@ -72,6 +74,35 @@ func main() { } } + var enableWebhooks bool + // 1. Environment variable has the highest priority + v, ok := os.LookupEnv("ENABLE_WEBHOOKS") + if !ok { + // 2. options.CertDir can be configured through the config file, priority 2 + if options.CertDir != "" { + enableWebhooks = true + } else { + // 3. The default directory has a value of priority 3 + defaultPath := filepath.Join(os.TempDir(), "k8s-webhook-server", "serving-certs") + _, err := os.Stat(defaultPath) + if err == nil { + enableWebhooks = true + } + if err != nil { + if os.IsNotExist(err) { + enableWebhooks = false + } + } + } + } else { + // 4. If the environment variable is configured, but there is a configuration error, exit directly. 
+ enableWebhooks, err = strconv.ParseBool(v) + if err != nil { + setupLog.Error(err, "unable to parse ENABLE_WEBHOOKS") + os.Exit(1) + } + } + mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), options) if err != nil { setupLog.Error(err, "unable to start manager") @@ -99,6 +130,13 @@ func main() { setupLog.Error(err, "unable to create controller", "controller", "Prompt") os.Exit(1) } + if enableWebhooks { + if err = (&arcadiav1alpha1.Prompt{}).SetupWebhookWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create webhook", "webhook", "Prompt") + os.Exit(1) + } + } + //+kubebuilder:scaffold:builder if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { diff --git a/pkg/llms/llms.go b/pkg/llms/llms.go new file mode 100644 index 000000000..86d0d5ea3 --- /dev/null +++ b/pkg/llms/llms.go @@ -0,0 +1,24 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llms + +type LLMType string + +const ( + OpenAI LLMType = "openai" + ZhiPuAI LLMType = "zhipuai" +) diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go new file mode 100644 index 000000000..cd140071b --- /dev/null +++ b/pkg/llms/zhipuai/api.go @@ -0,0 +1,101 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// NOTE: Reference zhipuai's python sdk: model_api/api.py + +package zhipuai + +import ( + "errors" + "fmt" + "time" +) + +const ( + ZHIPUAI_MODEL_API_URL = "https://open.bigmodel.cn/api/paas/v3/model-api" + ZHIPUAI_MODEL_Default_Timeout = 300 * time.Second +) + +type Model string + +const ( + ZhiPuAILite Model = "chatglm_lite" + ZhiPuAIStd Model = "chatglm_std" + ZhiPuAIPro Model = "chatglm_pro" +) + +type Method string + +const ( + // POST + ZhiPuAIInvoke Method = "invoke" + ZhiPuAIAsyncInvoke Method = "async-invoke" + ZhiPuAISSEInvoke Method = "sse-invoke" + // GET + ZhiPuAIAsyncGet Method = "async-get" +) + +func BuildAPIURL(model Model, method Method) string { + return fmt.Sprintf("%s/%s/%s", ZHIPUAI_MODEL_API_URL, model, method) +} + +type ZhiPuAI struct { + apiKey string +} + +func NewZhiPuAI(apiKey string) *ZhiPuAI { + return &ZhiPuAI{ + apiKey: apiKey, + } +} + +// Invoke calls zhipuai and returns result immediately +func (z *ZhiPuAI) Invoke(params ModelParams) (map[string]interface{}, error) { + url := BuildAPIURL(params.Model, ZhiPuAIInvoke) + token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS) + if err != nil { + return nil, err + } + + return Post(url, token, params, ZHIPUAI_MODEL_Default_Timeout) +} + +// AsyncInvoke only returns a task id which can be used to get result of task later +func (z *ZhiPuAI) AsyncInvoke(params ModelParams) (map[string]interface{}, error) { + url := BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke) + token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS) + if err != nil { + return nil, err + } + + return 
Post(url, token, params, ZHIPUAI_MODEL_Default_Timeout) +} + +// Get result of task async-invoke +func (z *ZhiPuAI) Get(params ModelParams) (map[string]interface{}, error) { + if params.TaskID == "" { + return nil, errors.New("TaskID is required when running Get with method AsyncInvoke") + } + + // url with task id + url := fmt.Sprintf("%s/%s", BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke), params.TaskID) + token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS) + if err != nil { + return nil, err + } + + return Get(url, token, ZHIPUAI_MODEL_Default_Timeout) +} diff --git a/pkg/llms/zhipuai/http_client.go b/pkg/llms/zhipuai/http_client.go new file mode 100644 index 000000000..9bc3271ff --- /dev/null +++ b/pkg/llms/zhipuai/http_client.go @@ -0,0 +1,95 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// NOTE: Reference zhipuai's python sdk: utils/http_client.py + +package zhipuai + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "time" +) + +func setHeadersWithToken(req *http.Request, token string) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", token) +} + +func parseHTTPResponse(resp *http.Response) (map[string]interface{}, error) { + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("exception: %s", resp.Status) + } + + var data map[string]interface{} + err := json.NewDecoder(resp.Body).Decode(&data) + if err != nil { + return nil, err + } + + return data, nil +} + +func Post(apiURL, token string, params ModelParams, timeout time.Duration) (map[string]interface{}, error) { + jsonParams, err := json.Marshal(params) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonParams)) + if err != nil { + return nil, err + } + + setHeadersWithToken(req, token) + + client := http.Client{ + Timeout: timeout, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return parseHTTPResponse(resp) +} +func Get(apiURL, token string, timeout time.Duration) (map[string]interface{}, error) { + req, err := http.NewRequest("GET", apiURL, nil) + if err != nil { + return nil, err + } + + setHeadersWithToken(req, token) + + client := http.Client{ + Timeout: timeout, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return parseHTTPResponse(resp) +} + +// TODO: impl stream +func Stream(apiURL, token string, params ModelParams, timeout time.Duration) (*http.Response, error) { + return nil, nil +} diff --git a/pkg/llms/zhipuai/jwt_token.go b/pkg/llms/zhipuai/jwt_token.go new file mode 100644 index 000000000..b311dc161 --- /dev/null +++ b/pkg/llms/zhipuai/jwt_token.go @@ -0,0 +1,58 @@ +/* +Copyright 2023 KubeAGI. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// NOTE: Reference zhipuai's python sdk: utils/jwt_token.py +package zhipuai + +import ( + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt" +) + +const ( + API_TOKEN_TTL_SECONDS = 3 * 60 + // FIXME: impl TLL Cache + CACHE_TTL_SECONDS = (API_TOKEN_TTL_SECONDS - 30) +) + +func GenerateToken(apikey string, expSeconds int64) (string, error) { + parts := strings.Split(apikey, ".") + if len(parts) != 2 { + return "", fmt.Errorf("invalid apikey") + } + + id := parts[0] + secret := parts[1] + + currentTime := time.Now().UnixMilli() + claims := jwt.MapClaims{ + "api_key": id, + "exp": currentTime + expSeconds*1000, + "timestamp": currentTime, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["sign_type"] = "SIGN" + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/pkg/llms/zhipuai/params.go b/pkg/llms/zhipuai/params.go new file mode 100644 index 000000000..9f8524e02 --- /dev/null +++ b/pkg/llms/zhipuai/params.go @@ -0,0 +1,125 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// NOTE: Reference zhipuai's python sdk: model_api/params.py + +package zhipuai + +import "errors" + +type Role string + +const ( + User Role = "user" + Assistant Role = "assistant" +) + +// +kubebuilder:object:generate=true +// ZhiPuAIParams defines the params of ZhiPuAI Prompt Call +type ModelParams struct { + // Method used for this prompt call + Method Method `json:"method,omitempty"` + + // Model used for this prompt call + Model Model `json:"model,omitempty"` + + //Temperature is float in zhipuai + Temperature float32 `json:"temperature,omitempty"` + // TopP is float in zhipuai + TopP float32 `json:"top_p,omitempty"` + // Contents + Prompt []Prompt `json:"prompt"` + + // TaskID is used for getting result of AsyncInvoke + TaskID string `json:"task_id,omitempty"` + + // Incremental is only Used for SSE Invoke + Incremental bool `json:"incremental,omitempty"` +} + +// +kubebuilder:object:generate=true +// Prompt defines the content of ZhiPuAI Prompt Call +type Prompt struct { + Role Role `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +func DefaultModelParams() ModelParams { + return ModelParams{ + Model: ZhiPuAILite, + Method: ZhiPuAIInvoke, + Temperature: 0.95, // zhipuai official + TopP: 0.7, // zhipuai official + Prompt: []Prompt{}, + } +} + +// MergeZhiPuAI merges b to a with this rule +// - if a.x is emtpy and b.x is not, then a.x = b.x +func MergeParams(a, b ModelParams) ModelParams { + if a.Model == "" && b.Model != "" { + a.Model = b.Model + } + if a.Method == "" && b.Method != "" { + a.Method = b.Method + } + if a.Temperature 
== 0 && b.Temperature != 0 { + a.Temperature = b.Temperature + } + if a.TopP == 0 && b.TopP != 0 { + a.TopP = b.TopP + } + if !a.Incremental && b.Incremental { + a.Incremental = b.Incremental + } + if len(a.Prompt) == 0 && len(b.Prompt) > 0 { + a.Prompt = b.Prompt + } + return a +} + +func ValidateModelParams(params ModelParams) error { + if params.Model == "" || params.Method == "" { + return errors.New("model or method is required") + } + + if params.Temperature < 0 || params.Temperature > 1 { + return errors.New("temperature must be in [0, 1]") + } + + if params.TopP < 0 || params.TopP > 1 { + return errors.New("top_p must be in [0, 1]") + } + + switch params.Method { + case ZhiPuAIInvoke, ZhiPuAIAsyncInvoke: + if len(params.Prompt) == 0 { + return errors.New("prompt is required") + } + if len(params.Prompt) > 1 { + return errors.New("only one prompt is allowed") + } + case ZhiPuAISSEInvoke: + case ZhiPuAIAsyncGet: + if params.TaskID == "" { + return errors.New("task_id is required") + } + default: + return errors.New("method must be one of [invoke, async-invoke, sse-invoke,get]") + } + + return nil +} diff --git a/pkg/llms/zhipuai/zz_generated.deepcopy.go b/pkg/llms/zhipuai/zz_generated.deepcopy.go new file mode 100644 index 000000000..9e29eb645 --- /dev/null +++ b/pkg/llms/zhipuai/zz_generated.deepcopy.go @@ -0,0 +1,59 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// Code generated by controller-gen. DO NOT EDIT. + +package zhipuai + +import () + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ModelParams) DeepCopyInto(out *ModelParams) { + *out = *in + if in.Prompt != nil { + in, out := &in.Prompt, &out.Prompt + *out = make([]Prompt, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelParams. +func (in *ModelParams) DeepCopy() *ModelParams { + if in == nil { + return nil + } + out := new(ModelParams) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Prompt) DeepCopyInto(out *Prompt) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Prompt. +func (in *Prompt) DeepCopy() *Prompt { + if in == nil { + return nil + } + out := new(Prompt) + in.DeepCopyInto(out) + return out +}