diff --git a/api/base/v1alpha1/datasource.go b/api/base/v1alpha1/datasource.go index 024e6a18a..73452c54e 100644 --- a/api/base/v1alpha1/datasource.go +++ b/api/base/v1alpha1/datasource.go @@ -23,19 +23,21 @@ const ( type DatasourceType string const ( - DatasourceTypeOSS DatasourceType = "oss" - DatasourceTypeRDMA DatasourceType = "RDMA" - DatasourceTypeUnknown DatasourceType = "unknown" + DatasourceTypeOSS DatasourceType = "oss" + DatasourceTypeRDMA DatasourceType = "RDMA" + DatasourceTypePostgreSQL DatasourceType = "postgresql" + DatasourceTypeUnknown DatasourceType = "unknown" ) func (ds DatasourceSpec) Type() DatasourceType { - // Object storage service - if ds.OSS != nil { + switch { + case ds.OSS != nil: return DatasourceTypeOSS - } - if ds.RDMA != nil { + case ds.RDMA != nil: return DatasourceTypeRDMA + case ds.PostgreSQL != nil: + return DatasourceTypePostgreSQL + default: + return DatasourceTypeUnknown } - - return DatasourceTypeUnknown } diff --git a/api/base/v1alpha1/datasource_types.go b/api/base/v1alpha1/datasource_types.go index 41eb7da16..5cf4117fe 100644 --- a/api/base/v1alpha1/datasource_types.go +++ b/api/base/v1alpha1/datasource_types.go @@ -20,9 +20,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN! -// NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized. - // DatasourceSpec defines the desired state of Datasource type DatasourceSpec struct { CommonSpec `json:",inline"` @@ -35,6 +32,9 @@ type DatasourceSpec struct { // RDMA configure RDMA pulls the model file directly from the remote service to the host node. RDMA *RDMA `json:"rdma,omitempty"` + + // PostgreSQL defines info for PostgreSQL + PostgreSQL *PostgreSQL `json:"postgresql,omitempty"` } type RDMA struct { @@ -54,6 +54,37 @@ type OSS struct { Object string `json:"object,omitempty"` } +// PostgreSQL defines info for PostgreSQL +// +// ref: https://github.com/jackc/pgx/blame/v5.5.1/pgconn/config.go#L409 +// they are common standard PostgreSQL environment variables +// For convenience, we use the same name. +// +// The PGUSER/PGPASSWORD/PGPASSFILE/PGSSLPASSWORD parameters have been intentionally excluded +// because they contain sensitive information and are stored in the secret pointed to by `endpoint.authSecret`. +type PostgreSQL struct { + Host string `json:"PGHOST,omitempty"` + Port string `json:"PGPORT,omitempty"` + Database string `json:"PGDATABASE,omitempty"` + AppName string `json:"PGAPPNAME,omitempty"` + ConnectTimeout string `json:"PGCONNECT_TIMEOUT,omitempty"` + SSLMode string `json:"PGSSLMODE,omitempty"` + SSLKey string `json:"PGSSLKEY,omitempty"` + SSLCert string `json:"PGSSLCERT,omitempty"` + SSLSni string `json:"PGSSLSNI,omitempty"` + SSLRootCert string `json:"PGSSLROOTCERT,omitempty"` + TargetSessionAttrs string `json:"PGTARGETSESSIONATTRS,omitempty"` + Service string `json:"PGSERVICE,omitempty"` + ServiceFile string `json:"PGSERVICEFILE,omitempty"` +} + +const ( + PGUSER = "PGUSER" + PGPASSWORD = "PGPASSWORD" + PGPASSFILE = "PGPASSFILE" + PGSSLPASSWORD = "PGSSLPASSWORD" +) + // DatasourceStatus defines the observed state of Datasource type DatasourceStatus struct { // ConditionedStatus is the current status diff --git a/api/base/v1alpha1/zz_generated.deepcopy.go b/api/base/v1alpha1/zz_generated.deepcopy.go index c27531b81..b0a10f176 100644 --- a/api/base/v1alpha1/zz_generated.deepcopy.go +++ b/api/base/v1alpha1/zz_generated.deepcopy.go @@ -359,6 +359,11 @@ func (in *DatasourceSpec) DeepCopyInto(out *DatasourceSpec) { *out = new(RDMA) (*in).DeepCopyInto(*out) } + if in.PostgreSQL != nil { + in, out := &in.PostgreSQL, &out.PostgreSQL + *out = new(PostgreSQL) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DatasourceSpec. @@ -959,6 +964,21 @@ func (in *OSS) DeepCopy() *OSS { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PostgreSQL) DeepCopyInto(out *PostgreSQL) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PostgreSQL. +func (in *PostgreSQL) DeepCopy() *PostgreSQL { + if in == nil { + return nil + } + out := new(PostgreSQL) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Prompt) DeepCopyInto(out *Prompt) { *out = *in diff --git a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_datasources.yaml b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_datasources.yaml index b9d83b7d0..9a3e464ce 100644 --- a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_datasources.yaml +++ b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_datasources.yaml @@ -99,6 +99,36 @@ spec: description: Object must end with a slash "/" if it is a directory type: string type: object + postgresql: + description: PostgreSQL defines info for PostgreSQL + properties: + PGAPPNAME: + type: string + PGCONNECT_TIMEOUT: + type: string + PGDATABASE: + type: string + PGHOST: + type: string + PGPORT: + type: string + PGSERVICE: + type: string + PGSERVICEFILE: + type: string + PGSSLCERT: + type: string + PGSSLKEY: + type: string + PGSSLMODE: + type: string + PGSSLROOTCERT: + type: string + PGSSLSNI: + type: string + PGTARGETSESSIONATTRS: + type: string + type: object rdma: description: RDMA configure RDMA pulls the model file directly from the remote service to the host node. diff --git a/config/samples/arcadia_v1alpha1_datasource_postgresql.yaml b/config/samples/arcadia_v1alpha1_datasource_postgresql.yaml new file mode 100644 index 000000000..6cd69ab53 --- /dev/null +++ b/config/samples/arcadia_v1alpha1_datasource_postgresql.yaml @@ -0,0 +1,24 @@ +## Datasource secret +apiVersion: v1 +kind: Secret +metadata: + name: datasource-postgresql-sample-authsecret + namespace: arcadia +data: + rootUser: YWRtaW4= + rootPassword: UGFzc3cwcmQh +--- +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Datasource +metadata: + name: datasource-postgresql-sample + namespace: arcadia +spec: + displayName: "postgresql 数据源示例" + endpoint: + url: postgres://arcadia-postgresql.arcadia.svc.cluster.local:5432 + authSecret: + kind: Secret + name: datasource-postgresql-sample-authsecret + postgresql: + PGDATABASE: arcadia diff --git a/controllers/datasource_controller.go b/controllers/datasource_controller.go index 236980efb..3a4fc570c 100644 --- a/controllers/datasource_controller.go +++ b/controllers/datasource_controller.go @@ -154,16 +154,16 @@ func (r *DatasourceReconciler) Checkdatasource(ctx context.Context, logger logr. logger.V(5).Info("check datasource") var err error + endpoint := instance.Spec.Endpoint.DeepCopy() + // set auth secret's namespace to the datasource's namespace + if endpoint.AuthSecret != nil { + endpoint.AuthSecret.WithNameSpace(instance.Namespace) + } // create datasource var ds datasource.Datasource var info any switch instance.Spec.Type() { case arcadiav1alpha1.DatasourceTypeOSS: - endpoint := instance.Spec.Endpoint.DeepCopy() - // set auth secret's namespace to the datasource's namespace - if endpoint.AuthSecret != nil { - endpoint.AuthSecret.WithNameSpace(instance.Namespace) - } ds, err = datasource.NewOSS(ctx, r.Client, nil, endpoint) if err != nil { return r.UpdateStatus(ctx, instance, err) @@ -171,6 +171,9 @@ func (r *DatasourceReconciler) Checkdatasource(ctx context.Context, logger logr. info = instance.Spec.OSS.DeepCopy() case arcadiav1alpha1.DatasourceTypeRDMA: return r.UpdateStatus(ctx, instance, nil) + case arcadiav1alpha1.DatasourceTypePostgreSQL: + _, err = datasource.NewPostgreSQL(ctx, r.Client, nil, instance.Spec.PostgreSQL, endpoint) + return r.UpdateStatus(ctx, instance, err) default: ds, err = datasource.NewUnknown(ctx, r.Client) if err != nil { diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index c4a2bce21..e65e358a5 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.2.2 +version: 0.2.3 appVersion: "0.1.0" keywords: @@ -15,6 +15,10 @@ sources: maintainers: - name: bjwswang url: https://github.com/bjwswang + - name: Abirdcfly + url: https://github.com/Abirdcfly + - name: 0xff-dev + url: https://github.com/0xff-dev - name: lanture1064 url: https://github.com/lanture1064 diff --git a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_datasources.yaml b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_datasources.yaml index b9d83b7d0..9a3e464ce 100644 --- a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_datasources.yaml +++ b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_datasources.yaml @@ -99,6 +99,36 @@ spec: description: Object must end with a slash "/" if it is a directory type: string type: object + postgresql: + description: PostgreSQL defines info for PostgreSQL + properties: + PGAPPNAME: + type: string + PGCONNECT_TIMEOUT: + type: string + PGDATABASE: + type: string + PGHOST: + type: string + PGPORT: + type: string + PGSERVICE: + type: string + PGSERVICEFILE: + type: string + PGSSLCERT: + type: string + PGSSLKEY: + type: string + PGSSLMODE: + type: string + PGSSLROOTCERT: + type: string + PGSSLSNI: + type: string + PGTARGETSESSIONATTRS: + type: string + type: object rdma: description: RDMA configure RDMA pulls the model file directly from the remote service to the host node. diff --git a/go.mod b/go.mod index 599bb9372..999dc6e2d 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/go-logr/logr v1.2.3 github.com/gofiber/fiber/v2 v2.49.1 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/jackc/pgx/v5 v5.4.1 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.3 github.com/r3labs/sse/v2 v2.10.0 @@ -54,6 +55,8 @@ require ( github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.3 // indirect github.com/huandu/xstrings v1.3.3 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/minio/md5-simd v1.1.2 // indirect diff --git a/go.sum b/go.sum index d29ac8a5c..abbfa71bb 100644 --- a/go.sum +++ b/go.sum @@ -422,6 +422,12 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.1 h1:oKfB/FhuVtit1bBM3zNRRsZ925ZkMN3HXL+LgLUM9lE= +github.com/jackc/pgx/v5 v5.4.1/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= diff --git a/pkg/datasource/oss.go b/pkg/datasource/oss.go index c3beb080d..2d0d3b770 100644 --- a/pkg/datasource/oss.go +++ b/pkg/datasource/oss.go @@ -17,7 +17,6 @@ package datasource import ( "context" "crypto/tls" - "encoding/base64" "errors" "fmt" "io" @@ -75,13 +74,8 @@ func NewOSS(ctx context.Context, c client.Client, dc dynamic.Interface, endpoint return nil, err } data, _, _ := unstructured.NestedStringMap(secret.Object, "data") - - if ds, err := base64.StdEncoding.DecodeString(data["rootUser"]); err == nil { - accessKeyID = string(ds) - } - if ds, err := base64.StdEncoding.DecodeString(data["rootPassword"]); err == nil { - secretAccessKey = string(ds) - } + accessKeyID = utils.DecodeBase64Str(data["rootUser"]) + secretAccessKey = utils.DecodeBase64Str(data["rootPassword"]) } if c != nil { secret := corev1.Secret{} diff --git a/pkg/datasource/postgresql.go b/pkg/datasource/postgresql.go new file mode 100644 index 000000000..acde3d097 --- /dev/null +++ b/pkg/datasource/postgresql.go @@ -0,0 +1,185 @@ +/* +Copyright 2024 KubeAGI. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datasource + +import ( + "context" + "errors" + "io" + "os" + + "github.com/jackc/pgx/v5" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/dynamic" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/pkg/utils" +) + +var ( + _ Datasource = (*PostgreSQL)(nil) +) + +// PostgreSQL is a wrapper to PostgreSQL +type PostgreSQL struct { + *pgx.Conn +} + +func NewPostgreSQL(ctx context.Context, c client.Client, dc dynamic.Interface, config *v1alpha1.PostgreSQL, endpoint *v1alpha1.Endpoint) (*PostgreSQL, error) { + if err := SetPGEnv(ctx, c, dc, config, endpoint); err != nil { + return nil, err + } + conn, err := pgx.Connect(context.Background(), endpoint.URL) + if err != nil { + return nil, err + } + defer conn.Close(ctx) + if err := conn.Ping(ctx); err != nil { + return nil, err + } + return &PostgreSQL{conn}, nil +} +func (p *PostgreSQL) Stat(ctx context.Context, info any) error { + // TODO implement me + panic("implement me") +} + +func (p *PostgreSQL) Remove(ctx context.Context, info any) error { + // TODO implement me + panic("implement me") +} + +func (p *PostgreSQL) ReadFile(ctx context.Context, info any) (io.ReadCloser, error) { + // TODO implement me + panic("implement me") +} + +func (p *PostgreSQL) StatFile(ctx context.Context, info any) (any, error) { + // TODO implement me + panic("implement me") +} + +func (p *PostgreSQL) GetTags(ctx context.Context, info any) (map[string]string, error) { + // TODO implement me + panic("implement me") +} + +func (p *PostgreSQL) ListObjects(ctx context.Context, source string, info any) (any, error) { + // TODO implement me + panic("implement me") +} + +// SetPGEnv will export all pg setting to environment variable +func SetPGEnv(ctx context.Context, c client.Client, dc dynamic.Interface, config *v1alpha1.PostgreSQL, endpoint *v1alpha1.Endpoint) error { + var pgUser, pgPassword, pgPassFile, pgSSLPassword string + if endpoint.AuthSecret != nil { + if endpoint.AuthSecret.Namespace == nil { + return errors.New("no namespace found for endpoint.authsecret") + } + if err := utils.ValidateClient(c, dc); err != nil { + return err + } + if dc != nil { + secret, err := dc.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}). + Namespace(*endpoint.AuthSecret.Namespace).Get(ctx, endpoint.AuthSecret.Name, v1.GetOptions{}) + if err != nil { + return err + } + data, _, _ := unstructured.NestedStringMap(secret.Object, "data") + pgUser = utils.DecodeBase64Str(data[v1alpha1.PGUSER]) + pgPassword = utils.DecodeBase64Str(data[v1alpha1.PGPASSWORD]) + pgPassFile = utils.DecodeBase64Str(data[v1alpha1.PGPASSFILE]) + pgSSLPassword = utils.DecodeBase64Str(data[v1alpha1.PGSSLPASSWORD]) + } + if c != nil { + secret := corev1.Secret{} + if err := c.Get(ctx, types.NamespacedName{ + Namespace: *endpoint.AuthSecret.Namespace, + Name: endpoint.AuthSecret.Name, + }, &secret); err != nil { + return err + } + pgUser = string(secret.Data[v1alpha1.PGUSER]) + pgPassword = string(secret.Data[v1alpha1.PGPASSWORD]) + pgPassFile = string(secret.Data[v1alpha1.PGPASSFILE]) + pgSSLPassword = string(secret.Data[v1alpha1.PGSSLPASSWORD]) + } + } + if err := setenv("PGUSER", pgUser); err != nil { + return err + } + if err := setenv("PGPASSWORD", pgPassword); err != nil { + return err + } + if err := setenv("PGPASSFILE", pgPassFile); err != nil { + return err + } + if err := setenv("PGSSLPASSWORD", pgSSLPassword); err != nil { + return err + } + if err := setenv("PGHOST", config.Host); err != nil { + return err + } + if err := setenv("PGPORT", config.Port); err != nil { + return err + } + if err := setenv("PGDATABASE", config.Database); err != nil { + return err + } + if err := setenv("PGAPPNAME", config.AppName); err != nil { + return err + } + if err := setenv("PGCONNECT_TIMEOUT", config.ConnectTimeout); err != nil { + return err + } + if err := setenv("PGSSLMODE", config.SSLMode); err != nil { + return err + } + if err := setenv("PGSSLKEY", config.SSLKey); err != nil { + return err + } + if err := setenv("PGSSLCERT", config.SSLCert); err != nil { + return err + } + if err := setenv("PGSSLSNI", config.SSLSni); err != nil { + return err + } + if err := setenv("PGSSLROOTCERT", config.SSLRootCert); err != nil { + return err + } + if err := setenv("PGTARGETSESSIONATTRS", config.TargetSessionAttrs); err != nil { + return err + } + if err := setenv("PGSERVICE", config.Service); err != nil { + return err + } + if err := setenv("PGSERVICEFILE", config.ServiceFile); err != nil { + return err + } + return nil +} + +func setenv(key, value string) error { + if len(value) == 0 { + return nil + } + return os.Setenv(key, value) +} diff --git a/pkg/utils/structured.go b/pkg/utils/structured.go index 2dadef192..361192c51 100644 --- a/pkg/utils/structured.go +++ b/pkg/utils/structured.go @@ -17,6 +17,7 @@ limitations under the License. package utils import ( + "encoding/base64" "encoding/json" "fmt" "reflect" @@ -57,3 +58,11 @@ func ValidateClient(c client.Client, cli dynamic.Interface) error { } return nil } + +func DecodeBase64Str(s string) string { + ds, err := base64.StdEncoding.DecodeString(s) + if err == nil { + return string(ds) + } + return "" +} diff --git a/tests/example-test.sh b/tests/example-test.sh index b4f170e58..bbcd765a9 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -224,7 +224,8 @@ helm install -narcadia arcadia deploy/charts/arcadia -f tests/deploy-values.yaml info "4. check system datasource arcadia-minio(system datasource)" waitCRDStatusReady "Datasource" "arcadia" "arcadia-minio" -info "5. create and verify a oss datasource" +info "5. create and verify datasource" +info "5.1 oss datasource" kubectl apply -f config/samples/arcadia_v1alpha1_datasource.yaml waitCRDStatusReady "Datasource" "arcadia" "datasource-sample" datasourceType=$(kubectl get datasource -n arcadia datasource-sample -o=jsonpath='{.metadata.labels.arcadia\.kubeagi\.k8s\.com\.cn/datasource-type}') @@ -232,6 +233,14 @@ if [[ $datasourceType != "oss" ]]; then error "Datasource should be oss but got $datasourceType" exit 1 fi +info "5.2 PostgreSQL datasource" +kubectl apply -f config/samples/arcadia_v1alpha1_datasource_postgresql.yaml +waitCRDStatusReady "Datasource" "arcadia" "datasource-postgresql-sample" +datasourceType=$(kubectl get datasource -n arcadia datasource-postgresql-sample -o=jsonpath='{.metadata.labels.arcadia\.kubeagi\.k8s\.com\.cn/datasource-type}') +if [[ $datasourceType != "postgresql" ]]; then + error "Datasource should be oss but got $datasourceType" + exit 1 +fi info "6. verify default vectorstore" waitCRDStatusReady "VectorStore" "arcadia" "arcadia-vectorstore"