diff --git a/api/base/v1alpha1/common.go b/api/base/v1alpha1/common.go index 876b40c17..62e5056d8 100644 --- a/api/base/v1alpha1/common.go +++ b/api/base/v1alpha1/common.go @@ -52,11 +52,11 @@ type Provider struct { // GetType returnes the type of this provider func (p Provider) GetType() ProviderType { - // if endpoint provided,then 3rd_party + // if endpoint provided, then 3rd_party if p.Enpoint != nil { return ProviderType3rdParty } - // if worker provided,then worker + // if worker provided, then worker if p.Worker != nil { return ProviderTypeWorker } diff --git a/controllers/knowledgebase_controller.go b/controllers/knowledgebase_controller.go index e4281ff30..b89afdf1e 100644 --- a/controllers/knowledgebase_controller.go +++ b/controllers/knowledgebase_controller.go @@ -29,6 +29,7 @@ import ( "github.com/minio/minio-go/v7" "github.com/tmc/langchaingo/documentloaders" langchainembeddings "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/textsplitter" "github.com/tmc/langchaingo/vectorstores/chroma" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -383,15 +384,47 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge return errVectorStoreNotReady } var em langchainembeddings.Embedder - switch embedder.Spec.Type { // nolint: gocritic - case embeddings.ZhiPuAI: - apiKey, err := embedder.AuthAPIKey(ctx, r.Client) + switch embedder.Spec.Provider.GetType() { + case arcadiav1alpha1.ProviderType3rdParty: + switch embedder.Spec.Type { // nolint: gocritic + case embeddings.ZhiPuAI: + apiKey, err := embedder.AuthAPIKey(ctx, r.Client) + if err != nil { + return err + } + em, err = zhipuaiembeddings.NewZhiPuAI( + zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), + ) + if err != nil { + return err + } + } + case arcadiav1alpha1.ProviderTypeWorker: + gatway, err := config.GetGateway(ctx, r.Client) if err != nil { return err } - em, err = zhipuaiembeddings.NewZhiPuAI( - zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), - ) + if gatway == nil { + return fmt.Errorf("global config gateway not found") + } + refWorker := embedder.Spec.Worker + if refWorker == nil { + return fmt.Errorf("embedder.spec.worker not defined") + } + worker := &arcadiav1alpha1.Worker{} + if err := r.Client.Get(ctx, types.NamespacedName{Namespace: refWorker.GetNamespace(), Name: refWorker.Name}, worker); err != nil { + return err + } + refModel := worker.Spec.Model + if refModel == nil { + return fmt.Errorf("worker.spec.model not defined") + } + modelName := refModel.Name + llm, err := openai.New(openai.WithModel(modelName), openai.WithBaseURL(gatway.APIServer), openai.WithToken("fake")) + if err != nil { + return err + } + em, err = langchainembeddings.NewEmbedder(llm) if err != nil { return err } @@ -424,7 +457,6 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge textsplitter.WithChunkSize(300), textsplitter.WithChunkOverlap(30), ) - // TODO tags -> qa or fulltext // switch { // case "token": // split = textsplitter.NewTokenSplitter( diff --git a/go.mod b/go.mod index 5142c2656..bd868aa40 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/r3labs/sse/v2 v2.10.0 github.com/spf13/cobra v1.4.0 github.com/stretchr/testify v1.8.4 - github.com/tmc/langchaingo v0.0.0-20231017212009-949349d5ef9c + github.com/tmc/langchaingo v0.0.0-20231209214832-00f364f27fe2 github.com/valyala/fasthttp v1.49.0 github.com/vektah/gqlparser/v2 v2.5.10 k8s.io/api v0.24.2 @@ -45,7 +45,6 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.3 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/klauspost/cpuid/v2 v2.2.5 // indirect - github.com/kr/pretty v0.3.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/minio/sha256-simd v1.0.1 // indirect diff --git a/go.sum b/go.sum index b0f86b897..7d9ad89ef 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= +cloud.google.com/go v0.110.2 h1:sdFPBr6xG9/wkBbfhmUz/JmZC7X6LavQgcrVINrKiVA= +cloud.google.com/go/aiplatform v1.42.0 h1:otuKi5bgONobl5+3bMSrapkTJGL8zNZqtr7M0tfXbt4= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= @@ -31,6 +33,8 @@ cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2Aawl cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= +cloud.google.com/go/iam v1.0.1 h1:lyeCAU6jpnVNrE9zGQkTl3WgNgK/X+uWwaw0kynZJMU= +cloud.google.com/go/longrunning v0.4.2 h1:WDKiiNXFTaQ6qz/G8FCOkuY9kJmOJGY67wPUC1M2RbE= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -319,12 +323,15 @@ github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gax-go/v2 v2.11.0 h1:9V9PWXEsWnPpQhu/PeQIkS4eGzMlTLGgt80cUUI8Ki4= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= @@ -405,7 +412,6 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -542,7 +548,6 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -604,8 +609,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tmc/langchaingo v0.0.0-20231017212009-949349d5ef9c h1:CgTequ9Xl8+3sOsjXeKj2g2jH+P3QLNXtATONeDMziE= -github.com/tmc/langchaingo v0.0.0-20231017212009-949349d5ef9c/go.mod h1:SiwyRS7sBSSi6f3NB4dKENw69X6br/wZ2WRkM+8pZWk= +github.com/tmc/langchaingo v0.0.0-20231209214832-00f364f27fe2 h1:jYxFk98N3864Zq88+6seoUke6IUCUn8s1meYtSJuGdk= +github.com/tmc/langchaingo v0.0.0-20231209214832-00f364f27fe2/go.mod h1:VQf9L5xRny7iSOWD2qn7mAU/N7PJILIXD0RgdD9mV2k= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= @@ -657,6 +662,7 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opentelemetry.io/contrib v0.20.0/go.mod h1:G/EtFaa6qaN7+LxqfIAT3GiZa7Wv5DTBUzl5H4LY0Kc= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.20.0/go.mod h1:oVGt1LRbBOBq1A5BQLlUg9UaU/54aiHw8cgjV3aWZ/E= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.20.0/go.mod h1:2AboqHi0CiIZU0qwhtUfCYD1GeUzvvIXWNkhDt7ZMG4= @@ -826,6 +832,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1016,6 +1023,7 @@ google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34q google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= +google.golang.org/api v0.126.0 h1:q4GJq+cAdMAC7XP7njvQ4tvohGLiSlytuL4BQxbIZ+o= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1071,6 +1079,9 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210831024726-fe130286e0e2/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao= +google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1093,6 +1104,7 @@ google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= +google.golang.org/grpc v1.57.1 h1:upNTNqv0ES+2ZOOqACwVtS3Il8M12/+Hz41RCPzAjQg= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pkg/arctl/chat.go b/pkg/arctl/chat.go index 7b7b29e60..7c722faa8 100644 --- a/pkg/arctl/chat.go +++ b/pkg/arctl/chat.go @@ -24,7 +24,7 @@ import ( "github.com/spf13/cobra" "github.com/tmc/langchaingo/embeddings" - openaiEmbeddings "github.com/tmc/langchaingo/embeddings/openai" + "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" "github.com/tmc/langchaingo/vectorstores/chroma" @@ -146,7 +146,11 @@ func SimilaritySearch(ctx context.Context, homePath string) ([]schema.Document, return nil, err } case "openai": - embedder, err = openaiEmbeddings.NewOpenAI() + llm, err := openai.New() + if err != nil { + return nil, err + } + embedder, err = embeddings.NewEmbedder(llm) if err != nil { return nil, err } diff --git a/pkg/arctl/dataset.go b/pkg/arctl/dataset.go index 123c32c32..0e83fa8d7 100644 --- a/pkg/arctl/dataset.go +++ b/pkg/arctl/dataset.go @@ -35,7 +35,7 @@ import ( "github.com/spf13/cobra" "github.com/tmc/langchaingo/documentloaders" "github.com/tmc/langchaingo/embeddings" - openaiEmbeddings "github.com/tmc/langchaingo/embeddings/openai" + "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/textsplitter" "github.com/tmc/langchaingo/vectorstores/chroma" @@ -534,7 +534,11 @@ func (cachedDS *Dataset) embedDocuments(ctx context.Context, documents []schema. return err } case "openai": - embedder, err = openaiEmbeddings.NewOpenAI() + llm, err := openai.New() + if err != nil { + return err + } + embedder, err = embeddings.NewEmbedder(llm) if err != nil { return err }