diff --git a/Makefile b/Makefile index a21519c51..593d2bd54 100644 --- a/Makefile +++ b/Makefile @@ -143,6 +143,9 @@ vet: generate: controller-gen $(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths="./..." +mock-generate: gomock + go generate ./... + CONTROLLER_GEN = $(BIN_DIR)/controller-gen controller-gen: ## Download controller-gen locally if necessary. $(call go-install-tool,$(CONTROLLER_GEN),sigs.k8s.io/controller-tools/cmd/controller-gen@v0.9.0) @@ -155,6 +158,10 @@ ENVTEST = $(BIN_DIR)/setup-envtest envtest: ## Download envtest-setup locally if necessary. $(call go-install-tool,$(ENVTEST),sigs.k8s.io/controller-runtime/tools/setup-envtest@latest) +GOMOCK = $(shell pwd)/bin/mockgen +gomock: + $(call go-get-tool,$(GOMOCK),github.com/golang/mock/mockgen@v1.6.0) + # go-install-tool will 'go install' any package $2 and install it to $1. define go-install-tool @[ -f $(1) ] || { \ diff --git a/api/v1/helper.go b/api/v1/helper.go index 6479b4c0a..2087ca669 100644 --- a/api/v1/helper.go +++ b/api/v1/helper.go @@ -57,6 +57,13 @@ const ( SupportedNicIDConfigmap = "supported-nic-ids" ) +type ConfigurationModeType string + +const ( + DaemonConfigurationMode ConfigurationModeType = "daemon" + SystemdConfigurationMode ConfigurationModeType = "systemd" +) + func (e NetFilterType) String() string { switch e { case OpenstackNetworkID: @@ -66,7 +73,7 @@ func (e NetFilterType) String() string { } } -func InitNicIDMap(client kubernetes.Interface, namespace string) error { +func InitNicIDMapFromConfigMap(client kubernetes.Interface, namespace string) error { cm, err := client.CoreV1().ConfigMaps(namespace).Get( context.Background(), SupportedNicIDConfigmap, @@ -79,9 +86,14 @@ func InitNicIDMap(client kubernetes.Interface, namespace string) error { for _, v := range cm.Data { NicIDMap = append(NicIDMap, v) } + return nil } +func InitNicIDMapFromList(idList []string) { + NicIDMap = append(NicIDMap, idList...) +} + func IsSupportedVendor(vendorID string) bool { for _, n := range NicIDMap { ids := strings.Split(n, " ") diff --git a/api/v1/sriovoperatorconfig_types.go b/api/v1/sriovoperatorconfig_types.go index 220c2bbd3..fc4fe4b64 100644 --- a/api/v1/sriovoperatorconfig_types.go +++ b/api/v1/sriovoperatorconfig_types.go @@ -39,6 +39,10 @@ type SriovOperatorConfigSpec struct { DisableDrain bool `json:"disableDrain,omitempty"` // Flag to enable OVS hardware offload. Set to 'true' to provision switchdev-configuration.service and enable OpenvSwitch hw-offload on nodes. EnableOvsOffload bool `json:"enableOvsOffload,omitempty"` + // Flag to enable the sriov-network-config-daemon to use a systemd service to configure SR-IOV devices on boot + // Default mode: daemon + // +kubebuilder:validation:Enum=daemon;systemd + ConfigurationMode ConfigurationModeType `json:"configurationMode,omitempty"` } // SriovOperatorConfigStatus defines the observed state of SriovOperatorConfig diff --git a/bindata/manifests/daemon/daemonset.yaml b/bindata/manifests/daemon/daemonset.yaml index ed5948900..4853a9811 100644 --- a/bindata/manifests/daemon/daemonset.yaml +++ b/bindata/manifests/daemon/daemonset.yaml @@ -28,7 +28,7 @@ spec: hostPID: true nodeSelector: kubernetes.io/os: linux - node-role.kubernetes.io/worker: + node-role.kubernetes.io/worker: "" tolerations: - operator: Exists serviceAccountName: sriov-network-config-daemon @@ -39,16 +39,66 @@ spec: - name: {{ . }} {{- end }} {{- end }} + initContainers: + - name: sriov-cni + image: {{.SRIOVCNIImage}} + command: + - /bin/sh + - -c + - cp /usr/bin/sriov /host/opt/cni/bin/ + securityContext: + privileged: true + resources: + requests: + cpu: 10m + memory: 10Mi + volumeMounts: + - name: cnibin + mountPath: /host/opt/cni/bin + - name: sriov-infiniband-cni + image: {{.SRIOVInfiniBandCNIImage}} + command: + - /bin/sh + - -c + - cp /usr/bin/ib-sriov /host/opt/cni/bin/ + securityContext: + privileged: true + resources: + requests: + cpu: 10m + memory: 10Mi + volumeMounts: + - name: cnibin + mountPath: /host/opt/cni/bin + {{- if .UsedSystemdMode}} + - name: sriov-service-copy + image: {{.Image}} + command: + - /bin/bash + - -c + - mkdir -p /host/var/lib/sriov/ && cp /usr/bin/sriov-network-config-daemon /host/var/lib/sriov/sriov-network-config-daemon && chcon -t bin_t /host/var/lib/sriov/sriov-network-config-daemon | true # Allow systemd to run the file, use pipe true to not failed if the system doesn't have selinux or apparmor enabled + securityContext: + privileged: true + resources: + requests: + cpu: 10m + memory: 10Mi + volumeMounts: + - name: host + mountPath: /host + {{- end }} containers: - name: sriov-network-config-daemon image: {{.Image}} command: - sriov-network-config-daemon - imagePullPolicy: IfNotPresent securityContext: privileged: true args: - "start" + {{- if .UsedSystemdMode}} + - --use-systemd-service + {{- end }} env: - name: NODE_NAME valueFrom: @@ -73,28 +123,6 @@ spec: preStop: exec: command: ["/bindata/scripts/clean-k8s-services.sh"] - - name: sriov-cni - image: {{.SRIOVCNIImage}} - securityContext: - privileged: true - resources: - requests: - cpu: 10m - memory: 10Mi - volumeMounts: - - name: cnibin - mountPath: /host/opt/cni/bin - - name: sriov-infiniband-cni - image: {{.SRIOVInfiniBandCNIImage}} - securityContext: - privileged: true - resources: - requests: - cpu: 10m - memory: 10Mi - volumeMounts: - - name: cnibin - mountPath: /host/opt/cni/bin volumes: - name: host hostPath: diff --git a/bindata/manifests/sriov-config-service/kubernetes/sriov-config-service.yaml b/bindata/manifests/sriov-config-service/kubernetes/sriov-config-service.yaml new file mode 100644 index 000000000..5c7470831 --- /dev/null +++ b/bindata/manifests/sriov-config-service/kubernetes/sriov-config-service.yaml @@ -0,0 +1,15 @@ +contents: | + [Unit] + Description=Configures SRIOV NIC + Wants=network-pre.target + Before=network-pre.target + + [Service] + Type=oneshot + ExecStart=/var/lib/sriov/sriov-network-config-daemon service + StandardOutput=journal+console + + [Install] + WantedBy=multi-user.target +enabled: true +name: sriov-config.service diff --git a/bindata/manifests/sriov-config-service/openshift/sriov-config-service.yaml b/bindata/manifests/sriov-config-service/openshift/sriov-config-service.yaml new file mode 100644 index 000000000..d91a08707 --- /dev/null +++ b/bindata/manifests/sriov-config-service/openshift/sriov-config-service.yaml @@ -0,0 +1,30 @@ +apiVersion: machineconfiguration.openshift.io/v1 +kind: MachineConfig +metadata: + labels: + machineconfiguration.openshift.io/role: worker + name: sriov-config-service +spec: + config: + ignition: + version: 3.2.0 + systemd: + units: + - contents: | + [Unit] + Description=Configures SRIOV NIC + # Removal of this file signals firstboot completion + ConditionPathExists=!/etc/ignition-machine-config-encapsulated.json + # This service is used to configure the SR-IOV VFs on NICs + Wants=network-pre.target + Before=network-pre.target + + [Service] + Type=oneshot + ExecStart=/var/lib/sriov/sriov-network-config-daemon service -v {{ .LogLevel }} + StandardOutput=journal+console + + [Install] + WantedBy=multi-user.target + enabled: true + name: "sriov-config.service" diff --git a/bindata/scripts/enable-rdma.sh b/bindata/scripts/enable-rdma.sh deleted file mode 100755 index 60fc1915a..000000000 --- a/bindata/scripts/enable-rdma.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash - -REDHAT_RELEASE_FILE="/host/etc/redhat-release" - -RDMA_CONDITION_FILE="" -RDMA_SERVICE_NAME="" -PACKAGE_MANAGER="" - -function kmod_isloaded { - if grep --quiet '\(^ib\|^rdma\)' <(chroot /host/ lsmod); then - echo "RDMA kernel modules loaded" - true - else - echo "RDMA kernel modules not loaded" - false - fi -} - -function trigger_udev_event { - echo "Trigger udev event" - chroot /host/ modprobe -r mlx4_en && chroot /host/ modprobe mlx4_en - chroot /host/ modprobe -r mlx5_core && chroot /host/ modprobe mlx5_core -} - -function enable_rdma { - if [ -f "$RDMA_CONDITION_FILE" ]; then - echo "$RDMA_SERVICE_NAME.service installed" - if kmod_isloaded; then - exit - else - trigger_udev_event - fi - else - chroot /host/ $PACKAGE_MANAGER install -y rdma-core - trigger_udev_event - fi - - if kmod_isloaded; then - exit - else - exit 1 - fi -} - -if ! grep --quiet 'mlx4_en' <(chroot /host/ lsmod) && ! grep --quiet 'mlx5_core' <(chroot /host/ lsmod); then - echo "No RDMA capable device" - exit 1 -fi - -if [ -f "$REDHAT_RELEASE_FILE" ]; then - if grep --quiet CoreOS "$REDHAT_RELEASE_FILE"; then - echo "It's CoreOS, exit" - if kmod_isloaded; then - exit - else - exit 1 - fi - else - RDMA_CONDITION_FILE="/host/usr/libexec/rdma-init-kernel" - RDMA_SERVICE_NAME="rdma" - PACKAGE_MANAGER=yum - - enable_rdma - fi -elif grep -i --quiet 'ubuntu' /host/etc/os-release ; then - RDMA_CONDITION_FILE="/host/usr/sbin/rdma-ndd" - RDMA_SERVICE_NAME="rdma-ndd" - PACKAGE_MANAGER=apt-get - - enable_rdma -else - os=$(cat /etc/os-release | grep PRETTY_NAME | cut -c 13-) - echo "Unsupported OS: $os" - exit 1 -fi diff --git a/bindata/scripts/load-kmod.sh b/bindata/scripts/load-kmod.sh deleted file mode 100755 index f91762d5f..000000000 --- a/bindata/scripts/load-kmod.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/sh -# chroot /host/ modprobe $1 -kmod_name=$(echo $1 | tr "-" "_") -kmod_args="${@:2}" -chroot /host/ lsmod | grep "^$1" >& /dev/null - -if [ $? -eq 0 ] -then - # NOTE: We do not check if the module is loaded with specific options - # so a manual reload is required if the module is loaded with - # new or different options. - echo "Module $kmod_name already loaded; no change will be applied..." - exit 0 -else - chroot /host/ modprobe $kmod_name $kmod_args -fi diff --git a/cmd/sriov-network-config-daemon/main.go b/cmd/sriov-network-config-daemon/main.go index 07719701b..3ceeed3ae 100644 --- a/cmd/sriov-network-config-daemon/main.go +++ b/cmd/sriov-network-config-daemon/main.go @@ -1,3 +1,18 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package main import ( diff --git a/cmd/sriov-network-config-daemon/service.go b/cmd/sriov-network-config-daemon/service.go new file mode 100644 index 000000000..b300254ac --- /dev/null +++ b/cmd/sriov-network-config-daemon/service.go @@ -0,0 +1,182 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "errors" + "flag" + "fmt" + "os" + + "github.com/golang/glog" + "github.com/spf13/cobra" + + sriovv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/host" + plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins/generic" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins/virtual" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/systemd" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/version" +) + +var ( + serviceCmd = &cobra.Command{ + Use: "service", + Short: "Starts SR-IOV service Config", + Long: "", + RunE: runServiceCmd, + } +) + +func init() { + rootCmd.AddCommand(serviceCmd) +} + +func runServiceCmd(cmd *cobra.Command, args []string) error { + flag.Set("logtostderr", "true") + flag.Parse() + + // To help debugging, immediately log version + glog.V(2).Infof("Version: %+v", version.Version) + + glog.V(0).Info("Starting sriov-config-service") + supportedNicIds, err := systemd.ReadSriovSupportedNics() + if err != nil { + glog.Errorf("failed to read list of supported nic ids") + sriovResult := &systemd.SriovResult{ + SyncStatus: "Failed", + LastSyncError: fmt.Sprintf("failed to read list of supported nic ids: %v", err), + } + err = systemd.WriteSriovResult(sriovResult) + if err != nil { + glog.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + return fmt.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + } + return fmt.Errorf("sriov-config-service failed to read list of supported nic ids: %v", err) + } + sriovv1.InitNicIDMapFromList(supportedNicIds) + + nodeStateSpec, err := systemd.ReadConfFile() + if err != nil { + if _, err := os.Stat(systemd.SriovSystemdConfigPath); !errors.Is(err, os.ErrNotExist) { + glog.Errorf("failed to read the sriov configuration file in path %s: %v", systemd.SriovSystemdConfigPath, err) + sriovResult := &systemd.SriovResult{ + SyncStatus: "Failed", + LastSyncError: fmt.Sprintf("failed to read the sriov configuration file in path %s: %v", systemd.SriovSystemdConfigPath, err), + } + err = systemd.WriteSriovResult(sriovResult) + if err != nil { + glog.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + return fmt.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + } + } + + nodeStateSpec = &systemd.SriovConfig{ + Spec: sriovv1.SriovNetworkNodeStateSpec{}, + UnsupportedNics: false, + PlatformType: utils.Baremetal, + } + } + + glog.V(2).Infof("sriov-config-service read config: %v", nodeStateSpec) + + // Load kernel modules + hostManager := host.NewHostManager(true) + _, err = hostManager.TryEnableRdma() + if err != nil { + glog.Warningf("failed to enable RDMA: %v", err) + } + hostManager.TryEnableTun() + hostManager.TryEnableVhostNet() + + var configPlugin plugin.VendorPlugin + var ifaceStatuses []sriovv1.InterfaceExt + if nodeStateSpec.PlatformType == utils.Baremetal { + // Bare metal support + ifaceStatuses, err = utils.DiscoverSriovDevices(nodeStateSpec.UnsupportedNics) + if err != nil { + glog.Errorf("sriov-config-service: failed to discover sriov devices on the host: %v", err) + return fmt.Errorf("sriov-config-service: failed to discover sriov devices on the host: %v", err) + } + + // Create the generic plugin + configPlugin, err = generic.NewGenericPlugin(true) + if err != nil { + glog.Errorf("sriov-config-service: failed to create generic plugin %v", err) + return fmt.Errorf("sriov-config-service failed to create generic plugin %v", err) + } + } else if nodeStateSpec.PlatformType == utils.VirtualOpenStack { + // Openstack support + metaData, networkData, err := utils.GetOpenstackData(false) + if err != nil { + glog.Errorf("sriov-config-service: failed to read OpenStack data: %v", err) + return fmt.Errorf("sriov-config-service failed to read OpenStack data: %v", err) + } + + openStackDevicesInfo, err := utils.CreateOpenstackDevicesInfo(metaData, networkData) + if err != nil { + glog.Errorf("failed to read OpenStack data: %v", err) + return fmt.Errorf("sriov-config-service failed to read OpenStack data: %v", err) + } + + ifaceStatuses, err = utils.DiscoverSriovDevicesVirtual(openStackDevicesInfo) + if err != nil { + glog.Errorf("sriov-config-service:failed to read OpenStack data: %v", err) + return fmt.Errorf("sriov-config-service: failed to read OpenStack data: %v", err) + } + + // Create the virtual plugin + configPlugin, err = virtual.NewVirtualPlugin(true) + if err != nil { + glog.Errorf("sriov-config-service: failed to create virtual plugin %v", err) + return fmt.Errorf("sriov-config-service: failed to create virtual plugin %v", err) + } + } + + nodeState := &sriovv1.SriovNetworkNodeState{ + Spec: nodeStateSpec.Spec, + Status: sriovv1.SriovNetworkNodeStateStatus{Interfaces: ifaceStatuses}, + } + + _, _, err = configPlugin.OnNodeStateChange(nodeState) + if err != nil { + glog.Errorf("sriov-config-service: failed to run OnNodeStateChange to update the generic plugin status %v", err) + return fmt.Errorf("sriov-config-service: failed to run OnNodeStateChange to update the generic plugin status %v", err) + } + + sriovResult := &systemd.SriovResult{ + SyncStatus: "Succeeded", + LastSyncError: "", + } + + err = configPlugin.Apply() + if err != nil { + glog.Errorf("sriov-config-service failed to run apply node configuration %v", err) + sriovResult.SyncStatus = "Failed" + sriovResult.LastSyncError = err.Error() + } + + err = systemd.WriteSriovResult(sriovResult) + if err != nil { + glog.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + return fmt.Errorf("sriov-config-service failed to write sriov result file with content %v error: %v", *sriovResult, err) + } + + glog.V(0).Info("Shutting down sriov-config-service") + return nil +} diff --git a/cmd/sriov-network-config-daemon/start.go b/cmd/sriov-network-config-daemon/start.go index 0ab621970..b9ee61ac5 100644 --- a/cmd/sriov-network-config-daemon/start.go +++ b/cmd/sriov-network-config-daemon/start.go @@ -1,3 +1,18 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package main import ( @@ -39,13 +54,15 @@ var ( startOpts struct { kubeconfig string nodeName string + systemd bool } ) func init() { rootCmd.AddCommand(startCmd) startCmd.PersistentFlags().StringVar(&startOpts.kubeconfig, "kubeconfig", "", "Kubeconfig file to access a remote cluster (testing only)") - startCmd.PersistentFlags().StringVar(&startOpts.nodeName, "node-name", "", "kubernetes node name daemon is managing.") + startCmd.PersistentFlags().StringVar(&startOpts.nodeName, "node-name", "", "kubernetes node name daemon is managing") + startCmd.PersistentFlags().BoolVar(&startOpts.systemd, "use-systemd-service", false, "use config daemon in systemd mode") } func runStartCmd(cmd *cobra.Command, args []string) { @@ -165,7 +182,7 @@ func runStartCmd(cmd *cobra.Command, args []string) { glog.V(0).Infof("Running on platform: %s", platformType.String()) var namespace = os.Getenv("NAMESPACE") - if err := sriovnetworkv1.InitNicIDMap(kubeclient, namespace); err != nil { + if err := sriovnetworkv1.InitNicIDMapFromConfigMap(kubeclient, namespace); err != nil { glog.Errorf("failed to run init NicIdMap: %v", err) panic(err.Error()) } @@ -189,6 +206,8 @@ func runStartCmd(cmd *cobra.Command, args []string) { syncCh, refreshCh, platformType, + startOpts.systemd, + devMode, ).Run(stopCh, exitCh) if err != nil { glog.Errorf("failed to run daemon: %v", err) diff --git a/config/crd/bases/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml b/config/crd/bases/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml index b47354a88..e6864e5d0 100644 --- a/config/crd/bases/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml +++ b/config/crd/bases/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml @@ -41,6 +41,13 @@ spec: type: string description: NodeSelector selects the nodes to be configured type: object + configurationMode: + description: Flag to enable the sriov-network-config-daemon to use + a systemd mode instead of the regular method + enum: + - daemon + - systemd + type: string disableDrain: description: Flag to disable nodes drain during debugging type: boolean diff --git a/controllers/helper.go b/controllers/helper.go index a6d93c2ac..58634116c 100644 --- a/controllers/helper.go +++ b/controllers/helper.go @@ -30,6 +30,14 @@ var webhooks = map[string](string){ constants.OperatorWebHookName: constants.OperatorWebHookPath, } +const ( + clusterRoleResourceName = "ClusterRole" + clusterRoleBindingResourceName = "ClusterRoleBinding" + mutatingWebhookConfigurationCRDName = "MutatingWebhookConfiguration" + validatingWebhookConfigurationCRDName = "ValidatingWebhookConfiguration" + machineConfigCRDName = "MachineConfig" +) + var namespace = os.Getenv("NAMESPACE") func GetImagePullSecrets() []string { diff --git a/controllers/sriovoperatorconfig_controller.go b/controllers/sriovoperatorconfig_controller.go index e7ff7e9b7..4ed76447f 100644 --- a/controllers/sriovoperatorconfig_controller.go +++ b/controllers/sriovoperatorconfig_controller.go @@ -34,6 +34,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" + machinev1 "github.com/openshift/machine-config-operator/pkg/apis/machineconfiguration.openshift.io/v1" + sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" apply "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/apply" constants "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" @@ -89,6 +91,7 @@ func (r *SriovOperatorConfigReconciler) Reconcile(ctx context.Context, req ctrl. ConfigDaemonNodeSelector: map[string]string{}, LogLevel: 2, DisableDrain: singleNode, + ConfigurationMode: sriovnetworkv1.DaemonConfigurationMode, } err = r.Create(ctx, defaultConfig) @@ -121,6 +124,17 @@ func (r *SriovOperatorConfigReconciler) Reconcile(ctx context.Context, req ctrl. return reconcile.Result{}, err } + // For Openshift we need to create the systemd files using a machine config + if utils.ClusterType == utils.ClusterTypeOpenshift { + // TODO: add support for hypershift as today there is no MCO on hypershift clusters + if r.OpenshiftContext.IsHypershift() { + return ctrl.Result{}, fmt.Errorf("systemd mode is not supported on hypershift") + } + + if err = r.syncOpenShiftSystemdService(ctx, defaultConfig); err != nil { + return reconcile.Result{}, err + } + } return reconcile.Result{RequeueAfter: constants.ResyncPeriod}, nil } @@ -176,6 +190,12 @@ func (r *SriovOperatorConfigReconciler) syncConfigDaemonSet(ctx context.Context, data.Data["ClusterType"] = utils.ClusterType data.Data["DevMode"] = os.Getenv("DEV_MODE") data.Data["ImagePullSecrets"] = GetImagePullSecrets() + if dc.Spec.ConfigurationMode == sriovnetworkv1.SystemdConfigurationMode { + data.Data["UsedSystemdMode"] = true + } else { + data.Data["UsedSystemdMode"] = false + } + envCniBinPath := os.Getenv("SRIOV_CNI_BIN_PATH") if envCniBinPath == "" { data.Data["CNIBinPath"] = "/var/lib/cni/bin" @@ -299,7 +319,7 @@ func (r *SriovOperatorConfigReconciler) deleteK8sResource(ctx context.Context, i func (r *SriovOperatorConfigReconciler) syncK8sResource(ctx context.Context, cr *sriovnetworkv1.SriovOperatorConfig, in *uns.Unstructured) error { switch in.GetKind() { - case "ClusterRole", "ClusterRoleBinding", "MutatingWebhookConfiguration", "ValidatingWebhookConfiguration": + case clusterRoleResourceName, clusterRoleBindingResourceName, mutatingWebhookConfigurationCRDName, validatingWebhookConfigurationCRDName, machineConfigCRDName: default: // set owner-reference only for namespaced objects if err := controllerutil.SetControllerReference(cr, in, r.Scheme); err != nil { @@ -311,3 +331,70 @@ func (r *SriovOperatorConfigReconciler) syncK8sResource(ctx context.Context, cr } return nil } + +// syncOpenShiftSystemdService creates the Machine Config to deploy the systemd service on openshift ONLY +func (r *SriovOperatorConfigReconciler) syncOpenShiftSystemdService(ctx context.Context, cr *sriovnetworkv1.SriovOperatorConfig) error { + logger := log.Log.WithName("syncSystemdService") + + if cr.Spec.ConfigurationMode != sriovnetworkv1.SystemdConfigurationMode { + obj := &machinev1.MachineConfig{} + err := r.Get(context.TODO(), types.NamespacedName{Name: constants.SystemdServiceOcpMachineConfigName}, obj) + if err != nil { + if apierrors.IsNotFound(err) { + return nil + } + + logger.Error(err, "failed to get machine config for the sriov-systemd-service") + return err + } + + logger.Info("Systemd service was deployed but the operator is now operating on daemonset mode, removing the machine config") + err = r.Delete(context.TODO(), obj) + if err != nil { + logger.Error(err, "failed to remove the systemd service machine config") + return err + } + + return nil + } + + logger.Info("Start to sync config systemd machine config for openshift") + data := render.MakeRenderData() + data.Data["LogLevel"] = cr.Spec.LogLevel + objs, err := render.RenderDir(constants.SystemdServiceOcpPath, &data) + if err != nil { + logger.Error(err, "Fail to render config daemon manifests") + return err + } + + // Sync machine config + return r.setLabelInsideObject(ctx, cr, objs) +} + +func (r SriovOperatorConfigReconciler) setLabelInsideObject(ctx context.Context, cr *sriovnetworkv1.SriovOperatorConfig, objs []*uns.Unstructured) error { + logger := log.Log.WithName("setLabelInsideObject") + for _, obj := range objs { + if obj.GetKind() == machineConfigCRDName && len(cr.Spec.ConfigDaemonNodeSelector) > 0 { + scheme := kscheme.Scheme + mc := &machinev1.ControllerConfig{} + err := scheme.Convert(obj, mc, nil) + if err != nil { + logger.Error(err, "Fail to convert to MachineConfig") + return err + } + mc.Labels = cr.Spec.ConfigDaemonNodeSelector + err = scheme.Convert(mc, obj, nil) + if err != nil { + logger.Error(err, "Fail to convert to Unstructured") + return err + } + } + err := r.syncK8sResource(ctx, cr, obj) + if err != nil { + logger.Error(err, "Couldn't sync SR-IoV daemons objects") + return err + } + } + + return nil +} diff --git a/deploy/operator.yaml b/deploy/operator.yaml index 220b758b1..c076d4c1a 100644 --- a/deploy/operator.yaml +++ b/deploy/operator.yaml @@ -41,7 +41,6 @@ spec: image: $SRIOV_NETWORK_OPERATOR_IMAGE command: - sriov-network-operator - imagePullPolicy: IfNotPresent resources: requests: cpu: 100m diff --git a/deployment/sriov-network-operator/crds/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml b/deployment/sriov-network-operator/crds/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml index b47354a88..e6864e5d0 100644 --- a/deployment/sriov-network-operator/crds/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml +++ b/deployment/sriov-network-operator/crds/sriovnetwork.openshift.io_sriovoperatorconfigs.yaml @@ -41,6 +41,13 @@ spec: type: string description: NodeSelector selects the nodes to be configured type: object + configurationMode: + description: Flag to enable the sriov-network-config-daemon to use + a systemd mode instead of the regular method + enum: + - daemon + - systemd + type: string disableDrain: description: Flag to disable nodes drain during debugging type: boolean diff --git a/deployment/sriov-network-operator/templates/operator.yaml b/deployment/sriov-network-operator/templates/operator.yaml index 293a233eb..eb75be182 100644 --- a/deployment/sriov-network-operator/templates/operator.yaml +++ b/deployment/sriov-network-operator/templates/operator.yaml @@ -43,7 +43,6 @@ spec: image: {{ .Values.images.operator }} command: - sriov-network-operator - imagePullPolicy: IfNotPresent resources: requests: cpu: 100m diff --git a/go.mod b/go.mod index 29f97e274..35fbac8c4 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/coreos/go-systemd/v22 v22.4.0 github.com/fsnotify/fsnotify v1.6.0 github.com/golang/glog v1.0.0 + github.com/golang/mock v1.4.4 github.com/google/go-cmp v0.5.9 github.com/hashicorp/go-retryablehttp v0.7.0 github.com/jaypipes/ghw v0.9.0 @@ -26,6 +27,7 @@ require ( github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae golang.org/x/time v0.0.0-20220922220347-f3bd1da661af gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.25.2 k8s.io/apiextensions-apiserver v0.25.2 k8s.io/apimachinery v0.25.2 @@ -146,7 +148,6 @@ require ( google.golang.org/grpc v1.50.1 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect howett.net/plist v1.0.0 // indirect k8s.io/cli-runtime v0.25.1 // indirect k8s.io/component-base v0.25.2 // indirect diff --git a/go.sum b/go.sum index e7fddb1c8..d0b001817 100644 --- a/go.sum +++ b/go.sum @@ -213,6 +213,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/hack/build-go.sh b/hack/build-go.sh index e7f9ced83..bb1cead41 100755 --- a/hack/build-go.sh +++ b/hack/build-go.sh @@ -32,7 +32,7 @@ fi mkdir -p ${BIN_PATH} -CGO_ENABLED=${CGO_ENABLED:-1} +CGO_ENABLED=${CGO_ENABLED:-0} if [[ ${WHAT} == "manager" ]]; then diff --git a/main.go b/main.go index f14dee49c..4e96fbb74 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func main() { var metricsAddr string var enableLeaderElection bool var probeAddr string + flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.BoolVar(&enableLeaderElection, "leader-elect", false, @@ -82,6 +83,7 @@ func main() { ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts))) restConfig := ctrl.GetConfigOrDie() + kubeClient, err := client.New(restConfig, client.Options{Scheme: scheme}) if err != nil { setupLog.Error(err, "couldn't create client") @@ -210,7 +212,7 @@ func main() { func initNicIDMap() error { namespace := os.Getenv("NAMESPACE") kubeclient := kubernetes.NewForConfigOrDie(ctrl.GetConfigOrDie()) - if err := sriovnetworkv1.InitNicIDMap(kubeclient, namespace); err != nil { + if err := sriovnetworkv1.InitNicIDMapFromConfigMap(kubeclient, namespace); err != nil { return err } diff --git a/pkg/consts/constants.go b/pkg/consts/constants.go index be19c4594..72e847d82 100644 --- a/pkg/consts/constants.go +++ b/pkg/consts/constants.go @@ -3,23 +3,25 @@ package consts import "time" const ( - ResyncPeriod = 5 * time.Minute - DefaultConfigName = "default" - ConfigDaemonPath = "./bindata/manifests/daemon" - InjectorWebHookPath = "./bindata/manifests/webhook" - OperatorWebHookPath = "./bindata/manifests/operator-webhook" - ServiceCAConfigMapAnnotation = "service.beta.openshift.io/inject-cabundle" - InjectorWebHookName = "network-resources-injector-config" - OperatorWebHookName = "sriov-operator-webhook-config" - DeprecatedOperatorWebHookName = "operator-webhook-config" - PluginPath = "./bindata/manifests/plugins" - DaemonPath = "./bindata/manifests/daemon" - DefaultPolicyName = "default" - ConfigMapName = "device-plugin-config" - DaemonSet = "DaemonSet" - ServiceAccount = "ServiceAccount" - DPConfigFileName = "config.json" - OVSHWOLMachineConfigNameSuffix = "ovs-hw-offload" + ResyncPeriod = 5 * time.Minute + DefaultConfigName = "default" + ConfigDaemonPath = "./bindata/manifests/daemon" + InjectorWebHookPath = "./bindata/manifests/webhook" + OperatorWebHookPath = "./bindata/manifests/operator-webhook" + SystemdServiceOcpPath = "./bindata/manifests/sriov-config-service/openshift" + SystemdServiceOcpMachineConfigName = "sriov-config-service" + ServiceCAConfigMapAnnotation = "service.beta.openshift.io/inject-cabundle" + InjectorWebHookName = "network-resources-injector-config" + OperatorWebHookName = "sriov-operator-webhook-config" + DeprecatedOperatorWebHookName = "operator-webhook-config" + PluginPath = "./bindata/manifests/plugins" + DaemonPath = "./bindata/manifests/daemon" + DefaultPolicyName = "default" + ConfigMapName = "device-plugin-config" + DaemonSet = "DaemonSet" + ServiceAccount = "ServiceAccount" + DPConfigFileName = "config.json" + OVSHWOLMachineConfigNameSuffix = "ovs-hw-offload" LinkTypeEthernet = "ether" LinkTypeInfiniband = "infiniband" diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go index c0455bdbd..07fcb7dd9 100644 --- a/pkg/daemon/daemon.go +++ b/pkg/daemon/daemon.go @@ -41,7 +41,10 @@ import ( sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" snclientset "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/clientset/versioned" sninformer "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/informers/externalversions" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/host" plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/service" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/systemd" "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" ) @@ -65,6 +68,10 @@ type Daemon struct { platform utils.PlatformType + useSystemdService bool + + devMode bool + client snclientset.Interface // kubeClient allows interaction with Kubernetes, including the node we are running on. kubeClient kubernetes.Interface @@ -75,6 +82,8 @@ type Daemon struct { enabledPlugins map[string]plugin.VendorPlugin + serviceManager service.ServiceManager + // channel used by callbacks to signal Run() of an error exitCh chan<- error @@ -103,7 +112,6 @@ type Daemon struct { } const ( - rdmaScriptsPath = "/bindata/scripts/enable-rdma.sh" udevScriptsPath = "/bindata/scripts/load-udev.sh" annoKey = "sriovnetwork.openshift.io/state" annoIdle = "Idle" @@ -140,18 +148,23 @@ func New( syncCh <-chan struct{}, refreshCh chan<- Message, platformType utils.PlatformType, + useSystemdService bool, + devMode bool, ) *Daemon { return &Daemon{ - name: nodeName, - platform: platformType, - client: client, - kubeClient: kubeClient, - openshiftContext: openshiftContext, - exitCh: exitCh, - stopCh: stopCh, - syncCh: syncCh, - refreshCh: refreshCh, - nodeState: &sriovnetworkv1.SriovNetworkNodeState{}, + name: nodeName, + platform: platformType, + useSystemdService: useSystemdService, + devMode: devMode, + client: client, + kubeClient: kubeClient, + openshiftContext: openshiftContext, + serviceManager: service.NewServiceManager("/host"), + exitCh: exitCh, + stopCh: stopCh, + syncCh: syncCh, + refreshCh: refreshCh, + nodeState: &sriovnetworkv1.SriovNetworkNodeState{}, drainer: &drain.Helper{ Client: kubeClient, Force: true, @@ -210,13 +223,24 @@ func (dn *Daemon) Run(stopCh <-chan struct{}, exitCh <-chan error) error { } else { glog.V(0).Infof("Run(): start daemon.") } + + if dn.useSystemdService { + glog.V(0).Info("Run(): daemon running in systemd mode") + } // Only watch own SriovNetworkNodeState CR defer utilruntime.HandleCrash() defer dn.workqueue.ShutDown() - tryEnableRdma() - tryEnableTun() - tryEnableVhostNet() + if !dn.useSystemdService { + hostManager := host.NewHostManager(dn.useSystemdService) + hostManager.TryEnableRdma() + hostManager.TryEnableTun() + hostManager.TryEnableVhostNet() + err := systemd.CleanSriovFilesFromHost(utils.ClusterType == utils.ClusterTypeOpenshift) + if err != nil { + glog.Warningf("failed to remove all the systemd sriov files error: %v", err) + } + } if err := dn.tryCreateUdevRuleWrapper(); err != nil { return err @@ -274,7 +298,7 @@ func (dn *Daemon) Run(stopCh <-chan struct{}, exitCh <-chan error) error { } glog.Info("Starting workers") - // Launch one workers to process + // Launch one worker to process go wait.Until(dn.runWorker, time.Second, stopCh) glog.Info("Started workers") @@ -412,6 +436,7 @@ func (dn *Daemon) nodeStateSyncHandler() error { var err error // Get the latest NodeState var latestState *sriovnetworkv1.SriovNetworkNodeState + var sriovResult = &systemd.SriovResult{SyncStatus: syncStatusSucceeded, LastSyncError: ""} latestState, err = dn.client.SriovnetworkV1().SriovNetworkNodeStates(namespace).Get(context.Background(), dn.name, metav1.GetOptions{}) if err != nil { glog.Warningf("nodeStateSyncHandler(): Failed to fetch node state %s: %v", dn.name, err) @@ -420,7 +445,45 @@ func (dn *Daemon) nodeStateSyncHandler() error { latest := latestState.GetGeneration() glog.V(0).Infof("nodeStateSyncHandler(): new generation is %d", latest) + if utils.ClusterType == utils.ClusterTypeOpenshift && !dn.openshiftContext.IsHypershift() { + if err = dn.getNodeMachinePool(); err != nil { + return err + } + } + if dn.nodeState.GetGeneration() == latest { + if dn.useSystemdService { + serviceExist, err := dn.serviceManager.IsServiceExist(systemd.SriovServicePath) + if err != nil { + glog.Errorf("nodeStateSyncHandler(): failed to check if sriov-config service exist on host: %v", err) + return err + } + + // if the service doesn't exist we should continue to let the k8s plugin to create the service files + // this is only for k8s base environments, for openshift the sriov-operator creates a machine config to will apply + // the system service and reboot the node the config-daemon doesn't need to do anything. + if !serviceExist { + sriovResult = &systemd.SriovResult{SyncStatus: syncStatusFailed, LastSyncError: "sriov-config systemd service doesn't exist on node"} + } else { + sriovResult, err = systemd.ReadSriovResult() + if err != nil { + glog.Errorf("nodeStateSyncHandler(): failed to load sriov result file from host: %v", err) + return err + } + } + if sriovResult.LastSyncError != "" || sriovResult.SyncStatus == syncStatusFailed { + glog.Infof("nodeStateSyncHandler(): sync failed systemd service error: %s", sriovResult.LastSyncError) + + // add the error but don't requeue + dn.refreshCh <- Message{ + syncStatus: syncStatusFailed, + lastSyncError: sriovResult.LastSyncError, + } + <-dn.syncCh + return nil + } + return nil + } glog.V(0).Infof("nodeStateSyncHandler(): Interface not changed") if latestState.Status.LastSyncError != "" || latestState.Status.SyncStatus != syncStatusSucceeded { @@ -453,9 +516,9 @@ func (dn *Daemon) nodeStateSyncHandler() error { lastSyncError: "", } - // load plugins if has not loaded + // load plugins if it has not loaded if len(dn.enabledPlugins) == 0 { - dn.enabledPlugins, err = enablePlugins(dn.platform, latestState) + dn.enabledPlugins, err = enablePlugins(dn.platform, dn.useSystemdService, latestState) if err != nil { glog.Errorf("nodeStateSyncHandler(): failed to enable vendor plugins error: %v", err) return err @@ -464,6 +527,8 @@ func (dn *Daemon) nodeStateSyncHandler() error { reqReboot := false reqDrain := false + + // check if any of the plugins required to drain or reboot the node for k, p := range dn.enabledPlugins { d, r := false, false if dn.nodeState.GetName() == "" { @@ -480,10 +545,31 @@ func (dn *Daemon) nodeStateSyncHandler() error { reqDrain = reqDrain || d reqReboot = reqReboot || r } - glog.V(0).Infof("nodeStateSyncHandler(): reqDrain %v, reqReboot %v disableDrain %v", reqDrain, reqReboot, dn.disableDrain) + + // When running using systemd check if the applied configuration is the latest one + // or there is a new config we need to apply + // When using systemd configuration we write the file + if dn.useSystemdService { + r, err := systemd.WriteConfFile(latestState, dn.devMode, dn.platform) + if err != nil { + glog.Errorf("nodeStateSyncHandler(): failed to write configuration file for systemd mode: %v", err) + return err + } + reqDrain = reqDrain || r + reqReboot = reqReboot || r + glog.V(0).Infof("nodeStateSyncHandler(): systemd mode reqDrain %v, reqReboot %v disableDrain %v", r, r, dn.disableDrain) + + err = systemd.WriteSriovSupportedNics() + if err != nil { + glog.Errorf("nodeStateSyncHandler(): failed to write supported nic ids file for systemd mode: %v", err) + return err + } + } + glog.V(0).Infof("nodeStateSyncHandler(): aggregated daemon reqDrain %v, reqReboot %v disableDrain %v", reqDrain, reqReboot, dn.disableDrain) for k, p := range dn.enabledPlugins { - if k != GenericPluginName { + // Skip both the general and virtual plugin apply them last + if k != GenericPluginName && k != VirtualPluginName { err := p.Apply() if err != nil { glog.Errorf("nodeStateSyncHandler(): plugin %s fail to apply: %v", k, err) @@ -526,7 +612,8 @@ func (dn *Daemon) nodeStateSyncHandler() error { } } - if !reqReboot { + if !reqReboot && !dn.useSystemdService { + // For BareMetal machines apply the generic plugin selectedPlugin, ok := dn.enabledPlugins[GenericPluginName] if ok { // Apply generic_plugin last @@ -536,6 +623,17 @@ func (dn *Daemon) nodeStateSyncHandler() error { return err } } + + // For Virtual machines apply the virtual plugin + selectedPlugin, ok = dn.enabledPlugins[VirtualPluginName] + if ok { + // Apply virtual_plugin last + err = selectedPlugin.Apply() + if err != nil { + glog.Errorf("nodeStateSyncHandler(): generic_plugin fail to apply: %v", err) + return err + } + } } if reqReboot { @@ -565,9 +663,16 @@ func (dn *Daemon) nodeStateSyncHandler() error { } glog.Info("nodeStateSyncHandler(): sync succeeded") dn.nodeState = latestState.DeepCopy() - dn.refreshCh <- Message{ - syncStatus: syncStatusSucceeded, - lastSyncError: "", + if dn.useSystemdService { + dn.refreshCh <- Message{ + syncStatus: sriovResult.SyncStatus, + lastSyncError: sriovResult.LastSyncError, + } + } else { + dn.refreshCh <- Message{ + syncStatus: syncStatusSucceeded, + lastSyncError: "", + } } // wait for writer to refresh the status <-dn.syncCh @@ -932,44 +1037,6 @@ func (dn *Daemon) drainNode() error { return nil } -func tryEnableTun() { - if err := utils.LoadKernelModule("tun"); err != nil { - glog.Errorf("tryEnableTun(): TUN kernel module not loaded: %v", err) - } -} - -func tryEnableVhostNet() { - if err := utils.LoadKernelModule("vhost_net"); err != nil { - glog.Errorf("tryEnableVhostNet(): VHOST_NET kernel module not loaded: %v", err) - } -} - -func tryEnableRdma() (bool, error) { - glog.V(2).Infof("tryEnableRdma()") - var stdout, stderr bytes.Buffer - - cmd := exec.Command("/bin/bash", path.Join(filesystemRoot, rdmaScriptsPath)) - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - glog.Errorf("tryEnableRdma(): fail to enable rdma %v: %v", err, cmd.Stderr) - return false, err - } - glog.V(2).Infof("tryEnableRdma(): %v", cmd.Stdout) - - i, err := strconv.Atoi(strings.TrimSpace(stdout.String())) - if err == nil { - if i == 0 { - glog.V(2).Infof("tryEnableRdma(): RDMA kernel modules loaded") - return true, nil - } else { - glog.V(2).Infof("tryEnableRdma(): RDMA kernel modules not loaded") - return false, nil - } - } - return false, err -} - func tryCreateSwitchdevUdevRule(nodeState *sriovnetworkv1.SriovNetworkNodeState) error { glog.V(2).Infof("tryCreateSwitchdevUdevRule()") var newContent string diff --git a/pkg/daemon/daemon_test.go b/pkg/daemon/daemon_test.go index fb70cc14b..c23a7b6dd 100644 --- a/pkg/daemon/daemon_test.go +++ b/pkg/daemon/daemon_test.go @@ -101,7 +101,7 @@ var _ = Describe("Config Daemon", func() { kubeClient := fakek8s.NewSimpleClientset(&FakeSupportedNicIDs, &SriovDevicePluginPod) client := fakesnclientset.NewSimpleClientset() - err = sriovnetworkv1.InitNicIDMap(kubeClient, namespace) + err = sriovnetworkv1.InitNicIDMapFromConfigMap(kubeClient, namespace) Expect(err).ToNot(HaveOccurred()) sut = New("test-node", @@ -113,6 +113,8 @@ var _ = Describe("Config Daemon", func() { syncCh, refreshCh, utils.Baremetal, + false, + false, ) sut.enabledPlugins = map[string]plugin.VendorPlugin{generic.PluginName: &fake.FakePlugin{}} diff --git a/pkg/daemon/plugin.go b/pkg/daemon/plugin.go index 307a35343..9639db88e 100644 --- a/pkg/daemon/plugin.go +++ b/pkg/daemon/plugin.go @@ -24,15 +24,16 @@ var ( GenericPlugin = genericplugin.NewGenericPlugin GenericPluginName = genericplugin.PluginName VirtualPlugin = virtualplugin.NewVirtualPlugin + VirtualPluginName = virtualplugin.PluginName K8sPlugin = k8splugin.NewK8sPlugin ) -func enablePlugins(platform utils.PlatformType, ns *sriovnetworkv1.SriovNetworkNodeState) (map[string]plugin.VendorPlugin, error) { +func enablePlugins(platform utils.PlatformType, useSystemdService bool, ns *sriovnetworkv1.SriovNetworkNodeState) (map[string]plugin.VendorPlugin, error) { glog.Infof("enableVendorPlugins(): enabling plugins") enabledPlugins := map[string]plugin.VendorPlugin{} if platform == utils.VirtualOpenStack { - virtualPlugin, err := VirtualPlugin() + virtualPlugin, err := VirtualPlugin(false) if err != nil { glog.Errorf("enableVendorPlugins(): failed to load the virtual plugin error: %v", err) return nil, err @@ -46,14 +47,14 @@ func enablePlugins(platform utils.PlatformType, ns *sriovnetworkv1.SriovNetworkN enabledPlugins = enabledVendorPlugins if utils.ClusterType != utils.ClusterTypeOpenshift { - k8sPlugin, err := K8sPlugin() + k8sPlugin, err := K8sPlugin(useSystemdService) if err != nil { glog.Errorf("enableVendorPlugins(): failed to load the k8s plugin error: %v", err) return nil, err } enabledPlugins[k8sPlugin.Name()] = k8sPlugin } - genericPlugin, err := GenericPlugin() + genericPlugin, err := GenericPlugin(false) if err != nil { glog.Errorf("enableVendorPlugins(): failed to load the generic plugin error: %v", err) return nil, err diff --git a/pkg/daemon/writer.go b/pkg/daemon/writer.go index d70ea8843..2a6fc6384 100644 --- a/pkg/daemon/writer.go +++ b/pkg/daemon/writer.go @@ -54,7 +54,7 @@ func (w *NodeStateStatusWriter) RunOnce(destDir string, platformType utils.Platf } if ns == nil { - metaData, networkData, err := utils.GetOpenstackData() + metaData, networkData, err := utils.GetOpenstackData(true) if err != nil { glog.Errorf("RunOnce(): failed to read OpenStack data: %v", err) } @@ -109,10 +109,7 @@ func (w *NodeStateStatusWriter) Run(stop <-chan struct{}, refresh <-chan Message if err := w.pollNicStatus(platformType); err != nil { continue } - _, err := w.setNodeStateStatus(msg) - if err != nil { - glog.Errorf("Run() period: writing to node status failed: %v", err) - } + w.setNodeStateStatus(msg) } } } diff --git a/pkg/host/host.go b/pkg/host/host.go new file mode 100644 index 000000000..37af5ddba --- /dev/null +++ b/pkg/host/host.go @@ -0,0 +1,455 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package host + +import ( + "fmt" + "os" + pathlib "path" + "strings" + + "github.com/golang/glog" + + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" +) + +const ( + hostPathFromDaemon = "/host" + redhatReleaseFile = "/etc/redhat-release" + rhelRDMAConditionFile = "/usr/libexec/rdma-init-kernel" + rhelRDMAServiceName = "rdma" + rhelPackageManager = "yum" + + ubuntuRDMAConditionFile = "/usr/sbin/rdma-ndd" + ubuntuRDMAServiceName = "rdma-ndd" + ubuntuPackageManager = "apt-get" + + genericOSReleaseFile = "/etc/os-release" +) + +// Contains all the host manipulation functions +// +//go:generate ../../bin/mockgen -destination mock/mock_host.go -source host.go +type HostManagerInterface interface { + TryEnableTun() + TryEnableVhostNet() + TryEnableRdma() (bool, error) + + // private functions + // part of the interface for the mock generation + LoadKernelModule(name string, args ...string) error + isKernelModuleLoaded(string) (bool, error) + isRHELSystem() (bool, error) + isUbuntuSystem() (bool, error) + isCoreOS() (bool, error) + rdmaIsLoaded() (bool, error) + enableRDMA(string, string, string) (bool, error) + installRDMA(string) error + triggerUdevEvent() error + reloadDriver(string) error + enableRDMAOnRHELMachine() (bool, error) + getOSPrettyName() (string, error) +} + +type HostManager struct { + RunOnHost bool + cmd utils.CommandInterface +} + +func NewHostManager(runOnHost bool) HostManagerInterface { + return &HostManager{ + RunOnHost: runOnHost, + cmd: &utils.Command{}, + } +} + +func (h *HostManager) LoadKernelModule(name string, args ...string) error { + glog.Infof("LoadKernelModule(): try to load kernel module %s with arguments '%s'", name, args) + chrootDefinition := getChrootExtention(h.RunOnHost) + cmdArgs := strings.Join(args, " ") + + // check if the driver is already loaded in to the system + isLoaded, err := h.isKernelModuleLoaded(name) + if err != nil { + glog.Errorf("LoadKernelModule(): failed to check if kernel module %s is already loaded", name) + } + if isLoaded { + glog.Infof("LoadKernelModule(): kernel module %s already loaded", name) + return nil + } + + _, _, err = h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("%s modprobe %s %s", chrootDefinition, name, cmdArgs)) + if err != nil { + glog.Errorf("LoadKernelModule(): failed to load kernel module %s with arguments '%s': %v", name, args, err) + return err + } + return nil +} + +func (h *HostManager) isKernelModuleLoaded(kernelModuleName string) (bool, error) { + glog.Infof("isKernelModuleLoaded(): check if kernel module %s is loaded", kernelModuleName) + chrootDefinition := getChrootExtention(h.RunOnHost) + + stdout, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("%s lsmod | grep \"^%s\"", chrootDefinition, kernelModuleName)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("isKernelModuleLoaded(): failed to check if kernel module %s is loaded: error: %v stderr %s", kernelModuleName, err, stderr.String()) + return false, err + } + glog.V(2).Infof("isKernelModuleLoaded(): %v", stdout.String()) + if stderr.Len() != 0 { + glog.Errorf("isKernelModuleLoaded(): failed to check if kernel module %s is loaded: error: %v stderr %s", kernelModuleName, err, stderr.String()) + return false, fmt.Errorf(stderr.String()) + } + + if stdout.Len() != 0 { + glog.Infof("isKernelModuleLoaded(): kernel module %s already loaded", kernelModuleName) + return true, nil + } + + return false, nil +} + +func (h *HostManager) TryEnableTun() { + if err := h.LoadKernelModule("tun"); err != nil { + glog.Errorf("tryEnableTun(): TUN kernel module not loaded: %v", err) + } +} + +func (h *HostManager) TryEnableVhostNet() { + if err := h.LoadKernelModule("vhost_net"); err != nil { + glog.Errorf("tryEnableVhostNet(): VHOST_NET kernel module not loaded: %v", err) + } +} + +func (h *HostManager) TryEnableRdma() (bool, error) { + glog.V(2).Infof("tryEnableRdma()") + chrootDefinition := getChrootExtention(h.RunOnHost) + + // check if the driver is already loaded in to the system + _, stderr, mlx4Err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("grep --quiet 'mlx4_en' <(%s lsmod)", chrootDefinition)) + if mlx4Err != nil && stderr.Len() != 0 { + glog.Errorf("tryEnableRdma(): failed to check for kernel module 'mlx4_en': error: %v stderr %s", mlx4Err, stderr.String()) + return false, fmt.Errorf(stderr.String()) + } + + _, stderr, mlx5Err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("grep --quiet 'mlx5_core' <(%s lsmod)", chrootDefinition)) + if mlx5Err != nil && stderr.Len() != 0 { + glog.Errorf("tryEnableRdma(): failed to check for kernel module 'mlx5_core': error: %v stderr %s", mlx5Err, stderr.String()) + return false, fmt.Errorf(stderr.String()) + } + + if mlx4Err != nil && mlx5Err != nil { + glog.Errorf("tryEnableRdma(): no RDMA capable devices") + return false, nil + } + + isRhelSystem, err := h.isRHELSystem() + if err != nil { + glog.Errorf("tryEnableRdma(): failed to check if the machine is base on RHEL: %v", err) + return false, err + } + + // RHEL check + if isRhelSystem { + return h.enableRDMAOnRHELMachine() + } + + isUbuntuSystem, err := h.isUbuntuSystem() + if err != nil { + glog.Errorf("tryEnableRdma(): failed to check if the machine is base on Ubuntu: %v", err) + return false, err + } + + if isUbuntuSystem { + return h.enableRDMAOnUbuntuMachine() + } + + osName, err := h.getOSPrettyName() + if err != nil { + glog.Errorf("tryEnableRdma(): failed to check OS name: %v", err) + return false, err + } + + glog.Errorf("tryEnableRdma(): Unsupported OS: %s", osName) + return false, fmt.Errorf("unable to load RDMA unsupported OS: %s", osName) +} + +func (h *HostManager) enableRDMAOnRHELMachine() (bool, error) { + glog.Infof("enableRDMAOnRHELMachine()") + isCoreOsSystem, err := h.isCoreOS() + if err != nil { + glog.Errorf("enableRDMAOnRHELMachine(): failed to check if the machine runs CoreOS: %v", err) + return false, err + } + + // CoreOS check + if isCoreOsSystem { + isRDMALoaded, err := h.rdmaIsLoaded() + if err != nil { + glog.Errorf("enableRDMAOnRHELMachine(): failed to check if RDMA kernel modules are loaded: %v", err) + return false, err + } + + return isRDMALoaded, nil + } + + // RHEL + glog.Infof("enableRDMAOnRHELMachine(): enabling RDMA on RHEL machine") + isRDMAEnable, err := h.enableRDMA(rhelRDMAConditionFile, rhelRDMAServiceName, rhelPackageManager) + if err != nil { + glog.Errorf("enableRDMAOnRHELMachine(): failed to enable RDMA on RHEL machine: %v", err) + return false, err + } + + // check if we need to install rdma-core package + if isRDMAEnable { + isRDMALoaded, err := h.rdmaIsLoaded() + if err != nil { + glog.Errorf("enableRDMAOnRHELMachine(): failed to check if RDMA kernel modules are loaded: %v", err) + return false, err + } + + // if ib kernel module is not loaded trigger a loading + if isRDMALoaded { + err = h.triggerUdevEvent() + if err != nil { + glog.Errorf("enableRDMAOnRHELMachine() failed to trigger udev event: %v", err) + return false, err + } + } + } + + return true, nil +} + +func (h *HostManager) enableRDMAOnUbuntuMachine() (bool, error) { + glog.Infof("enableRDMAOnUbuntuMachine(): enabling RDMA on RHEL machine") + isRDMAEnable, err := h.enableRDMA(ubuntuRDMAConditionFile, ubuntuRDMAServiceName, ubuntuPackageManager) + if err != nil { + glog.Errorf("enableRDMAOnUbuntuMachine(): failed to enable RDMA on Ubuntu machine: %v", err) + return false, err + } + + // check if we need to install rdma-core package + if isRDMAEnable { + isRDMALoaded, err := h.rdmaIsLoaded() + if err != nil { + glog.Errorf("enableRDMAOnUbuntuMachine(): failed to check if RDMA kernel modules are loaded: %v", err) + return false, err + } + + // if ib kernel module is not loaded trigger a loading + if isRDMALoaded { + err = h.triggerUdevEvent() + if err != nil { + glog.Errorf("enableRDMAOnUbuntuMachine() failed to trigger udev event: %v", err) + return false, err + } + } + } + + return true, nil +} + +func (h *HostManager) isRHELSystem() (bool, error) { + glog.Infof("isRHELSystem(): checking for RHEL machine") + path := redhatReleaseFile + if !h.RunOnHost { + path = pathlib.Join(hostPathFromDaemon, path) + } + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("isRHELSystem() not a RHEL machine") + return false, nil + } + + glog.Errorf("isRHELSystem() failed to check for os release file on path %s: %v", path, err) + return false, err + } + + return true, nil +} + +func (h *HostManager) isCoreOS() (bool, error) { + glog.Infof("isCoreOS(): checking for CoreOS machine") + path := redhatReleaseFile + if !h.RunOnHost { + path = pathlib.Join(hostPathFromDaemon, path) + } + + data, err := os.ReadFile(path) + if err != nil { + glog.Errorf("isCoreOS(): failed to read RHEL release file on path %s: %v", path, err) + return false, err + } + + if strings.Contains(string(data), "CoreOS") { + return true, nil + } + + return false, nil +} + +func (h *HostManager) isUbuntuSystem() (bool, error) { + glog.Infof("isUbuntuSystem(): checking for Ubuntu machine") + path := genericOSReleaseFile + if !h.RunOnHost { + path = pathlib.Join(hostPathFromDaemon, path) + } + + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + glog.Errorf("isUbuntuSystem() os-release on path %s doesn't exist: %v", path, err) + return false, err + } + + glog.Errorf("isUbuntuSystem() failed to check for os release file on path %s: %v", path, err) + return false, err + } + + stdout, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("grep -i --quiet 'ubuntu' %s", path)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("isUbuntuSystem(): failed to check for ubuntu operating system name in os-releasae file: error: %v stderr %s", err, stderr.String()) + return false, fmt.Errorf(stderr.String()) + } + + if stdout.Len() > 0 { + return true, nil + } + + return false, nil +} + +func (h *HostManager) rdmaIsLoaded() (bool, error) { + glog.V(2).Infof("rdmaIsLoaded()") + chrootDefinition := getChrootExtention(h.RunOnHost) + + // check if the driver is already loaded in to the system + _, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("grep --quiet '\\(^ib\\|^rdma\\)' <(%s lsmod)", chrootDefinition)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("rdmaIsLoaded(): fail to check if ib and rdma kernel modules are loaded: error: %v stderr %s", err, stderr.String()) + return false, fmt.Errorf(stderr.String()) + } + + if err != nil { + return false, nil + } + + return true, nil +} + +func (h *HostManager) enableRDMA(conditionFilePath, serviceName, packageManager string) (bool, error) { + path := conditionFilePath + if !h.RunOnHost { + path = pathlib.Join(hostPathFromDaemon, path) + } + glog.Infof("enableRDMA(): checking for service file on path %s", path) + + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("enableRDMA(): RDMA server doesn't exist") + err = h.installRDMA(packageManager) + if err != nil { + glog.Errorf("enableRDMA() failed to install RDMA package: %v", err) + return false, err + } + + err = h.triggerUdevEvent() + if err != nil { + glog.Errorf("enableRDMA() failed to trigger udev event: %v", err) + return false, err + } + + return false, nil + } + + glog.Errorf("enableRDMA() failed to check for os release file on path %s: %v", path, err) + return false, err + } + + glog.Infof("enableRDMA(): service %s.service installed", serviceName) + return true, nil +} + +func (h *HostManager) installRDMA(packageManager string) error { + glog.Infof("installRDMA(): installing RDMA") + chrootDefinition := getChrootExtention(h.RunOnHost) + + stdout, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("%s %s install -y rdma-core", chrootDefinition, packageManager)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("installRDMA(): failed to install RDMA package output %s: error %v stderr %s", stdout.String(), err, stderr.String()) + return err + } + + return nil +} + +func (h *HostManager) triggerUdevEvent() error { + glog.Infof("triggerUdevEvent(): installing RDMA") + + err := h.reloadDriver("mlx4_en") + if err != nil { + return err + } + + err = h.reloadDriver("mlx5_core") + if err != nil { + return err + } + + return nil +} + +func (h *HostManager) reloadDriver(driverName string) error { + glog.Infof("reloadDriver(): reload driver %s", driverName) + chrootDefinition := getChrootExtention(h.RunOnHost) + + _, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("%s modprobe -r %s && %s modprobe %s", chrootDefinition, driverName, chrootDefinition, driverName)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("installRDMA(): failed to reload %s kernel module: error %v stderr %s", driverName, err, stderr.String()) + return err + } + + return nil +} + +func (h *HostManager) getOSPrettyName() (string, error) { + path := genericOSReleaseFile + if !h.RunOnHost { + path = pathlib.Join(hostPathFromDaemon, path) + } + + glog.Infof("getOSPrettyName(): getting os name from os-release file") + + stdout, stderr, err := h.cmd.Run("/bin/sh", "-c", fmt.Sprintf("cat %s | grep PRETTY_NAME | cut -c 13-", path)) + if err != nil && stderr.Len() != 0 { + glog.Errorf("isUbuntuSystem(): failed to check for ubuntu operating system name in os-releasae file: error: %v stderr %s", err, stderr.String()) + return "", fmt.Errorf(stderr.String()) + } + + if stdout.Len() > 0 { + return stdout.String(), nil + } + + return "", fmt.Errorf("failed to find pretty operating system name") +} + +func getChrootExtention(runOnHost bool) string { + if !runOnHost { + return "chroot /host/" + } + return "" +} diff --git a/pkg/host/mock/mock_host.go b/pkg/host/mock/mock_host.go new file mode 100644 index 000000000..06ed3858a --- /dev/null +++ b/pkg/host/mock/mock_host.go @@ -0,0 +1,254 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: host.go + +// Package mock_host is a generated GoMock package. +package mock_host + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockHostManagerInterface is a mock of HostManagerInterface interface. +type MockHostManagerInterface struct { + ctrl *gomock.Controller + recorder *MockHostManagerInterfaceMockRecorder +} + +// MockHostManagerInterfaceMockRecorder is the mock recorder for MockHostManagerInterface. +type MockHostManagerInterfaceMockRecorder struct { + mock *MockHostManagerInterface +} + +// NewMockHostManagerInterface creates a new mock instance. +func NewMockHostManagerInterface(ctrl *gomock.Controller) *MockHostManagerInterface { + mock := &MockHostManagerInterface{ctrl: ctrl} + mock.recorder = &MockHostManagerInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHostManagerInterface) EXPECT() *MockHostManagerInterfaceMockRecorder { + return m.recorder +} + +// LoadKernelModule mocks base method. +func (m *MockHostManagerInterface) LoadKernelModule(name string, args ...string) error { + m.ctrl.T.Helper() + varargs := []interface{}{name} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LoadKernelModule", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// LoadKernelModule indicates an expected call of LoadKernelModule. +func (mr *MockHostManagerInterfaceMockRecorder) LoadKernelModule(name interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{name}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadKernelModule", reflect.TypeOf((*MockHostManagerInterface)(nil).LoadKernelModule), varargs...) +} + +// TryEnableRdma mocks base method. +func (m *MockHostManagerInterface) TryEnableRdma() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryEnableRdma") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryEnableRdma indicates an expected call of TryEnableRdma. +func (mr *MockHostManagerInterfaceMockRecorder) TryEnableRdma() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryEnableRdma", reflect.TypeOf((*MockHostManagerInterface)(nil).TryEnableRdma)) +} + +// TryEnableTun mocks base method. +func (m *MockHostManagerInterface) TryEnableTun() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TryEnableTun") +} + +// TryEnableTun indicates an expected call of TryEnableTun. +func (mr *MockHostManagerInterfaceMockRecorder) TryEnableTun() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryEnableTun", reflect.TypeOf((*MockHostManagerInterface)(nil).TryEnableTun)) +} + +// TryEnableVhostNet mocks base method. +func (m *MockHostManagerInterface) TryEnableVhostNet() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TryEnableVhostNet") +} + +// TryEnableVhostNet indicates an expected call of TryEnableVhostNet. +func (mr *MockHostManagerInterfaceMockRecorder) TryEnableVhostNet() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryEnableVhostNet", reflect.TypeOf((*MockHostManagerInterface)(nil).TryEnableVhostNet)) +} + +// enableRDMA mocks base method. +func (m *MockHostManagerInterface) enableRDMA(arg0, arg1, arg2 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "enableRDMA", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// enableRDMA indicates an expected call of enableRDMA. +func (mr *MockHostManagerInterfaceMockRecorder) enableRDMA(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "enableRDMA", reflect.TypeOf((*MockHostManagerInterface)(nil).enableRDMA), arg0, arg1, arg2) +} + +// enableRDMAOnRHELMachine mocks base method. +func (m *MockHostManagerInterface) enableRDMAOnRHELMachine() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "enableRDMAOnRHELMachine") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// enableRDMAOnRHELMachine indicates an expected call of enableRDMAOnRHELMachine. +func (mr *MockHostManagerInterfaceMockRecorder) enableRDMAOnRHELMachine() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "enableRDMAOnRHELMachine", reflect.TypeOf((*MockHostManagerInterface)(nil).enableRDMAOnRHELMachine)) +} + +// getOSPrettyName mocks base method. +func (m *MockHostManagerInterface) getOSPrettyName() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getOSPrettyName") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// getOSPrettyName indicates an expected call of getOSPrettyName. +func (mr *MockHostManagerInterfaceMockRecorder) getOSPrettyName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getOSPrettyName", reflect.TypeOf((*MockHostManagerInterface)(nil).getOSPrettyName)) +} + +// installRDMA mocks base method. +func (m *MockHostManagerInterface) installRDMA(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "installRDMA", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// installRDMA indicates an expected call of installRDMA. +func (mr *MockHostManagerInterfaceMockRecorder) installRDMA(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "installRDMA", reflect.TypeOf((*MockHostManagerInterface)(nil).installRDMA), arg0) +} + +// isCoreOS mocks base method. +func (m *MockHostManagerInterface) isCoreOS() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "isCoreOS") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// isCoreOS indicates an expected call of isCoreOS. +func (mr *MockHostManagerInterfaceMockRecorder) isCoreOS() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isCoreOS", reflect.TypeOf((*MockHostManagerInterface)(nil).isCoreOS)) +} + +// isKernelModuleLoaded mocks base method. +func (m *MockHostManagerInterface) isKernelModuleLoaded(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "isKernelModuleLoaded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// isKernelModuleLoaded indicates an expected call of isKernelModuleLoaded. +func (mr *MockHostManagerInterfaceMockRecorder) isKernelModuleLoaded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isKernelModuleLoaded", reflect.TypeOf((*MockHostManagerInterface)(nil).isKernelModuleLoaded), arg0) +} + +// isRHELSystem mocks base method. +func (m *MockHostManagerInterface) isRHELSystem() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "isRHELSystem") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// isRHELSystem indicates an expected call of isRHELSystem. +func (mr *MockHostManagerInterfaceMockRecorder) isRHELSystem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isRHELSystem", reflect.TypeOf((*MockHostManagerInterface)(nil).isRHELSystem)) +} + +// isUbuntuSystem mocks base method. +func (m *MockHostManagerInterface) isUbuntuSystem() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "isUbuntuSystem") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// isUbuntuSystem indicates an expected call of isUbuntuSystem. +func (mr *MockHostManagerInterfaceMockRecorder) isUbuntuSystem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isUbuntuSystem", reflect.TypeOf((*MockHostManagerInterface)(nil).isUbuntuSystem)) +} + +// rdmaIsLoaded mocks base method. +func (m *MockHostManagerInterface) rdmaIsLoaded() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "rdmaIsLoaded") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// rdmaIsLoaded indicates an expected call of rdmaIsLoaded. +func (mr *MockHostManagerInterfaceMockRecorder) rdmaIsLoaded() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "rdmaIsLoaded", reflect.TypeOf((*MockHostManagerInterface)(nil).rdmaIsLoaded)) +} + +// reloadDriver mocks base method. +func (m *MockHostManagerInterface) reloadDriver(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "reloadDriver", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// reloadDriver indicates an expected call of reloadDriver. +func (mr *MockHostManagerInterfaceMockRecorder) reloadDriver(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "reloadDriver", reflect.TypeOf((*MockHostManagerInterface)(nil).reloadDriver), arg0) +} + +// triggerUdevEvent mocks base method. +func (m *MockHostManagerInterface) triggerUdevEvent() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "triggerUdevEvent") + ret0, _ := ret[0].(error) + return ret0 +} + +// triggerUdevEvent indicates an expected call of triggerUdevEvent. +func (mr *MockHostManagerInterfaceMockRecorder) triggerUdevEvent() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "triggerUdevEvent", reflect.TypeOf((*MockHostManagerInterface)(nil).triggerUdevEvent)) +} diff --git a/pkg/plugins/generic/generic_plugin.go b/pkg/plugins/generic/generic_plugin.go index a2a7e1c00..c5b8c3596 100644 --- a/pkg/plugins/generic/generic_plugin.go +++ b/pkg/plugins/generic/generic_plugin.go @@ -12,6 +12,7 @@ import ( sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" constants "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/host" plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins" "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" ) @@ -25,6 +26,8 @@ type GenericPlugin struct { LastState *sriovnetworkv1.SriovNetworkNodeState LoadVfioDriver uint LoadVirtioVdpaDriver uint + RunningOnHost bool + HostManager host.HostManagerInterface } const scriptsPath = "bindata/scripts/enable-kargs.sh" @@ -36,12 +39,14 @@ const ( ) // Initialize our plugin and set up initial values -func NewGenericPlugin() (plugin.VendorPlugin, error) { +func NewGenericPlugin(runningOnHost bool) (plugin.VendorPlugin, error) { return &GenericPlugin{ PluginName: PluginName, SpecVersion: "1.0", LoadVfioDriver: unloaded, LoadVirtioVdpaDriver: unloaded, + RunningOnHost: runningOnHost, + HostManager: host.NewHostManager(runningOnHost), }, nil } @@ -76,7 +81,7 @@ func (p *GenericPlugin) OnNodeStateChange(new *sriovnetworkv1.SriovNetworkNodeSt func (p *GenericPlugin) Apply() error { glog.Infof("generic-plugin Apply(): desiredState=%v", p.DesireState.Spec) if p.LoadVfioDriver == loading { - if err := utils.LoadKernelModule("vfio_pci"); err != nil { + if err := p.HostManager.LoadKernelModule("vfio_pci"); err != nil { glog.Errorf("generic-plugin Apply(): fail to load vfio_pci kmod: %v", err) return err } @@ -84,7 +89,7 @@ func (p *GenericPlugin) Apply() error { } if p.LoadVirtioVdpaDriver == loading { - if err := utils.LoadKernelModule("virtio_vdpa"); err != nil { + if err := p.HostManager.LoadKernelModule("virtio_vdpa"); err != nil { glog.Errorf("generic-plugin Apply(): fail to load virtio_vdpa kmod: %v", err) return err } @@ -107,11 +112,15 @@ func (p *GenericPlugin) Apply() error { return err } - exit, err := utils.Chroot("/host") - if err != nil { - return err + // When calling from systemd do not try to chroot + if !p.RunningOnHost { + exit, err := utils.Chroot("/host") + if err != nil { + return err + } + defer exit() } - defer exit() + if err := utils.SyncNodeState(p.DesireState, pfsToSkip); err != nil { return err } diff --git a/pkg/plugins/generic/generic_plugin_test.go b/pkg/plugins/generic/generic_plugin_test.go index 69e850154..0d8fc0000 100644 --- a/pkg/plugins/generic/generic_plugin_test.go +++ b/pkg/plugins/generic/generic_plugin_test.go @@ -20,7 +20,7 @@ var _ = Describe("Generic plugin", func() { var genericPlugin plugin.VendorPlugin var err error BeforeEach(func() { - genericPlugin, err = generic.NewGenericPlugin() + genericPlugin, err = generic.NewGenericPlugin(false) Expect(err).ToNot(HaveOccurred()) }) diff --git a/pkg/plugins/k8s/k8s_plugin.go b/pkg/plugins/k8s/k8s_plugin.go index 948e51a5e..bd10408ff 100644 --- a/pkg/plugins/k8s/k8s_plugin.go +++ b/pkg/plugins/k8s/k8s_plugin.go @@ -29,7 +29,9 @@ type K8sPlugin struct { switchdevAfterNMService *service.Service openVSwitchService *service.Service networkManagerService *service.Service + sriovService *service.Service updateTarget *k8sUpdateTarget + useSystemdService bool } type k8sUpdateTarget struct { @@ -38,15 +40,16 @@ type k8sUpdateTarget struct { switchdevBeforeNMRunScript bool switchdevAfterNMRunScript bool switchdevUdevScript bool + sriovScript bool systemServices []*service.Service } func (u *k8sUpdateTarget) needUpdate() bool { - return u.switchdevBeforeNMService || u.switchdevAfterNMService || u.switchdevBeforeNMRunScript || u.switchdevAfterNMRunScript || u.switchdevUdevScript || len(u.systemServices) > 0 + return u.switchdevBeforeNMService || u.switchdevAfterNMService || u.switchdevBeforeNMRunScript || u.switchdevAfterNMRunScript || u.switchdevUdevScript || u.sriovScript || len(u.systemServices) > 0 } func (u *k8sUpdateTarget) needReboot() bool { - return u.switchdevBeforeNMService || u.switchdevAfterNMService || u.switchdevBeforeNMRunScript || u.switchdevAfterNMRunScript || u.switchdevUdevScript + return u.switchdevBeforeNMService || u.switchdevAfterNMService || u.switchdevBeforeNMRunScript || u.switchdevAfterNMRunScript || u.switchdevUdevScript || u.sriovScript } func (u *k8sUpdateTarget) reset() { @@ -54,6 +57,8 @@ func (u *k8sUpdateTarget) reset() { u.switchdevAfterNMService = false u.switchdevBeforeNMRunScript = false u.switchdevAfterNMRunScript = false + u.switchdevUdevScript = false + u.sriovScript = false u.systemServices = []*service.Service{} } @@ -76,8 +81,11 @@ func (u *k8sUpdateTarget) String() string { } const ( - switchdevManifestPath = "bindata/manifests/switchdev-config/" + bindataManifestPath = "bindata/manifests/" + switchdevManifestPath = bindataManifestPath + "switchdev-config/" switchdevUnits = switchdevManifestPath + "switchdev-units/" + sriovUnits = bindataManifestPath + "sriov-config-service/kubernetes/" + sriovUnitFile = sriovUnits + "sriov-config-service.yaml" switchdevBeforeNMUnitFile = switchdevUnits + "switchdev-configuration-before-nm.yaml" switchdevAfterNMUnitFile = switchdevUnits + "switchdev-configuration-after-nm.yaml" networkManagerUnitFile = switchdevUnits + "NetworkManager.service.yaml" @@ -90,12 +98,13 @@ const ( ) // Initialize our plugin and set up initial values -func NewK8sPlugin() (plugins.VendorPlugin, error) { +func NewK8sPlugin(useSystemdService bool) (plugins.VendorPlugin, error) { k8sPluging := &K8sPlugin{ - PluginName: PluginName, - SpecVersion: "1.0", - serviceManager: service.NewServiceManager(chroot), - updateTarget: &k8sUpdateTarget{}, + PluginName: PluginName, + SpecVersion: "1.0", + serviceManager: service.NewServiceManager(chroot), + updateTarget: &k8sUpdateTarget{}, + useSystemdService: useSystemdService, } return k8sPluging, k8sPluging.readManifestFiles() @@ -120,15 +129,26 @@ func (p *K8sPlugin) OnNodeStateChange(new *sriovnetworkv1.SriovNetworkNodeState) p.updateTarget.reset() // TODO add check for enableOvsOffload in OperatorConfig later // Update services if switchdev required - if !utils.IsSwitchdevModeSpec(new.Spec) { + if !p.useSystemdService && !utils.IsSwitchdevModeSpec(new.Spec) { return } - // Check services - err = p.servicesStateUpdate() - if err != nil { - glog.Errorf("k8s-plugin OnNodeStateChange(): failed : %v", err) - return + if utils.IsSwitchdevModeSpec(new.Spec) { + // Check services + err = p.switchDevServicesStateUpdate() + if err != nil { + glog.Errorf("k8s-plugin OnNodeStateChange(): failed : %v", err) + return + } + } + + if p.useSystemdService { + // Check sriov service + err = p.sriovServiceStateUpdate() + if err != nil { + glog.Errorf("k8s-plugin OnNodeStateChange(): failed : %v", err) + return + } } if p.updateTarget.needUpdate() { @@ -151,6 +171,12 @@ func (p *K8sPlugin) Apply() error { return err } + if p.useSystemdService { + if err := p.updateSriovService(); err != nil { + return err + } + } + for _, systemService := range p.updateTarget.systemServices { if err := p.updateSystemService(systemService); err != nil { return err @@ -230,6 +256,16 @@ func (p *K8sPlugin) readOpenVSwitchdManifest() error { return nil } +func (p *K8sPlugin) readSriovServiceManifest() error { + sriovService, err := service.ReadServiceManifestFile(sriovUnitFile) + if err != nil { + return err + } + + p.sriovService = sriovService + return nil +} + func (p *K8sPlugin) readManifestFiles() error { if err := p.readSwitchdevManifest(); err != nil { return err @@ -243,6 +279,10 @@ func (p *K8sPlugin) readManifestFiles() error { return err } + if err := p.readSriovServiceManifest(); err != nil { + return err + } + return nil } @@ -281,7 +321,27 @@ func (p *K8sPlugin) switchdevServiceStateUpdate() error { return nil } -func (p *K8sPlugin) getSystemServices() []*service.Service { +func (p *K8sPlugin) sriovServiceStateUpdate() error { + glog.Info("sriovServiceStateUpdate()") + exist, err := p.serviceManager.IsServiceExist(p.sriovService.Path) + if err != nil { + return err + } + + // create the service if it doesn't exist + if !exist { + p.updateTarget.sriovScript = true + } else { + p.updateTarget.sriovScript = p.isSystemServiceNeedUpdate(p.sriovService) + } + + if p.updateTarget.sriovScript { + p.updateTarget.systemServices = append(p.updateTarget.systemServices, p.sriovService) + } + return nil +} + +func (p *K8sPlugin) getSwitchDevSystemServices() []*service.Service { return []*service.Service{p.networkManagerService, p.openVSwitchService} } @@ -316,16 +376,17 @@ func (p *K8sPlugin) isSwitchdevServiceNeedUpdate(serviceObj *service.Service) (n } func (p *K8sPlugin) isSystemServiceNeedUpdate(serviceObj *service.Service) bool { + glog.Infof("isSystemServiceNeedUpdate()") systemService, err := p.serviceManager.ReadService(serviceObj.Path) if err != nil { - glog.Warningf("k8s-plugin isSystemServiceNeedUpdate(): failed to read switchdev service file %q: %v", + glog.Warningf("k8s-plugin isSystemServiceNeedUpdate(): failed to read sriov-config service file %q: %v", serviceObj.Path, err) return false } if systemService != nil { needChange, err := service.CompareServices(systemService, serviceObj) if err != nil { - glog.Warningf("k8s-plugin isSystemServiceNeedUpdate(): failed to compare switchdev service : %v", err) + glog.Warningf("k8s-plugin isSystemServiceNeedUpdate(): failed to compare sriov-config service: %v", err) return false } return needChange @@ -336,7 +397,7 @@ func (p *K8sPlugin) isSystemServiceNeedUpdate(serviceObj *service.Service) bool func (p *K8sPlugin) systemServicesStateUpdate() error { var services []*service.Service - for _, systemService := range p.getSystemServices() { + for _, systemService := range p.getSwitchDevSystemServices() { exist, err := p.serviceManager.IsServiceExist(systemService.Path) if err != nil { return err @@ -353,7 +414,7 @@ func (p *K8sPlugin) systemServicesStateUpdate() error { return nil } -func (p *K8sPlugin) servicesStateUpdate() error { +func (p *K8sPlugin) switchDevServicesStateUpdate() error { // Check switchdev err := p.switchdevServiceStateUpdate() if err != nil { @@ -369,6 +430,17 @@ func (p *K8sPlugin) servicesStateUpdate() error { return nil } +func (p *K8sPlugin) updateSriovService() error { + if p.updateTarget.sriovScript { + err := p.serviceManager.EnableService(p.sriovService) + if err != nil { + return err + } + } + + return nil +} + func (p *K8sPlugin) updateSwitchdevService() error { if p.updateTarget.switchdevBeforeNMService { err := p.serviceManager.EnableService(p.switchdevBeforeNMService) @@ -418,7 +490,7 @@ func (p *K8sPlugin) updateSystemService(serviceObj *service.Service) error { } if systemService == nil { // Invalid case to reach here - return fmt.Errorf("k8s-plugin Apply(): can't update non-existing service %q", serviceObj.Name) + return fmt.Errorf("k8s-plugin updateSystemService(): can't update non-existing service %q", serviceObj.Name) } serviceOptions, err := unit.Deserialize(strings.NewReader(serviceObj.Content)) if err != nil { @@ -431,3 +503,11 @@ func (p *K8sPlugin) updateSystemService(serviceObj *service.Service) error { return p.serviceManager.EnableService(updatedService) } + +func (p *K8sPlugin) SetSystemdFlag() { + p.useSystemdService = true +} + +func (p *K8sPlugin) IsSystemService() bool { + return p.useSystemdService +} diff --git a/pkg/plugins/virtual/virtual_plugin.go b/pkg/plugins/virtual/virtual_plugin.go index 2d74cbbd3..dbe0d9883 100644 --- a/pkg/plugins/virtual/virtual_plugin.go +++ b/pkg/plugins/virtual/virtual_plugin.go @@ -7,6 +7,7 @@ import ( sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" constants "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/host" plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins" "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" ) @@ -20,6 +21,8 @@ type VirtualPlugin struct { DesireState *sriovnetworkv1.SriovNetworkNodeState LastState *sriovnetworkv1.SriovNetworkNodeState LoadVfioDriver uint + RunningOnHost bool + HostManager host.HostManagerInterface } const ( @@ -29,11 +32,13 @@ const ( ) // Initialize our plugin and set up initial values -func NewVirtualPlugin() (plugin.VendorPlugin, error) { +func NewVirtualPlugin(runningOnHost bool) (plugin.VendorPlugin, error) { return &VirtualPlugin{ PluginName: PluginName, SpecVersion: "1.0", LoadVfioDriver: unloaded, + RunningOnHost: runningOnHost, + HostManager: host.NewHostManager(runningOnHost), }, nil } @@ -74,12 +79,12 @@ func (p *VirtualPlugin) Apply() error { // This is the case for OpenStack deployments where the underlying virtualization platform is KVM. // NOTE: if VFIO was already loaded for some reason, we will not try to load it again with the new options. kernelArgs := "enable_unsafe_noiommu_mode=1" - if err := utils.LoadKernelModule("vfio", kernelArgs); err != nil { + if err := p.HostManager.LoadKernelModule("vfio", kernelArgs); err != nil { glog.Errorf("virtual-plugin Apply(): fail to load vfio kmod: %v", err) return err } - if err := utils.LoadKernelModule("vfio_pci"); err != nil { + if err := p.HostManager.LoadKernelModule("vfio_pci"); err != nil { glog.Errorf("virtual-plugin Apply(): fail to load vfio_pci kmod: %v", err) return err } @@ -107,6 +112,13 @@ func (p *VirtualPlugin) Apply() error { return nil } +func (p *VirtualPlugin) SetSystemdFlag() { +} + +func (p *VirtualPlugin) IsSystemService() bool { + return false +} + func needVfioDriver(state *sriovnetworkv1.SriovNetworkNodeState) bool { for _, iface := range state.Spec.Interfaces { for i := range iface.VfGroups { diff --git a/pkg/service/utils.go b/pkg/service/utils.go index 8b9e82c72..f9dc28f85 100644 --- a/pkg/service/utils.go +++ b/pkg/service/utils.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/coreos/go-systemd/v22/unit" + "github.com/golang/glog" "gopkg.in/yaml.v2" ) @@ -28,7 +29,7 @@ OUTER: continue OUTER } } - + glog.Infof("DEBUG: %+v %v", optsA, *optB) return true, nil } diff --git a/pkg/systemd/systemd.go b/pkg/systemd/systemd.go new file mode 100644 index 000000000..127b897be --- /dev/null +++ b/pkg/systemd/systemd.go @@ -0,0 +1,288 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package systemd + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/golang/glog" + "gopkg.in/yaml.v3" + + sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" +) + +const ( + SriovSystemdConfigPath = utils.SriovConfBasePath + "/sriov-interface-config.yaml" + SriovSystemdResultPath = utils.SriovConfBasePath + "/sriov-interface-result.yaml" + sriovSystemdSupportedNicPath = utils.SriovConfBasePath + "/sriov-supported-nics-ids.yaml" + sriovSystemdServiceBinaryPath = "/var/lib/sriov/sriov-network-config-daemon" + + SriovHostSystemdConfigPath = "/host" + SriovSystemdConfigPath + SriovHostSystemdResultPath = "/host" + SriovSystemdResultPath + sriovHostSystemdSupportedNicPath = "/host" + sriovSystemdSupportedNicPath + sriovHostSystemdServiceBinaryPath = "/host" + sriovSystemdServiceBinaryPath + + SriovServicePath = "/etc/systemd/system/sriov-config.service" + SriovHostServicePath = "/host" + SriovServicePath +) + +type SriovConfig struct { + Spec sriovnetworkv1.SriovNetworkNodeStateSpec `yaml:"spec"` + UnsupportedNics bool `yaml:"unsupportedNics"` + PlatformType utils.PlatformType `yaml:"platformType"` +} + +type SriovResult struct { + SyncStatus string `yaml:"syncStatus"` + LastSyncError string `yaml:"lastSyncError"` +} + +func ReadConfFile() (spec *SriovConfig, err error) { + rawConfig, err := ioutil.ReadFile(SriovSystemdConfigPath) + if err != nil { + return nil, err + } + + err = yaml.Unmarshal(rawConfig, &spec) + + return spec, err +} + +func WriteConfFile(newState *sriovnetworkv1.SriovNetworkNodeState, unsupportedNics bool, platformType utils.PlatformType) (bool, error) { + newFile := false + // remove the device plugin revision as we don't need it here + newState.Spec.DpConfigVersion = "" + + sriovConfig := &SriovConfig{ + newState.Spec, + unsupportedNics, + platformType, + } + + _, err := os.Stat(SriovHostSystemdConfigPath) + if err != nil { + if os.IsNotExist(err) { + // Create the sriov-operator folder on the host if it doesn't exist + if _, err := os.Stat(utils.HostSriovConfBasePath); os.IsNotExist(err) { + err = os.Mkdir(utils.HostSriovConfBasePath, os.ModeDir) + if err != nil { + glog.Errorf("WriteConfFile(): fail to create sriov-operator folder: %v", err) + return false, err + } + } + + glog.V(2).Infof("WriteConfFile(): file not existed, create it") + _, err = os.Create(SriovHostSystemdConfigPath) + if err != nil { + glog.Errorf("WriteConfFile(): fail to create file: %v", err) + return false, err + } + newFile = true + } else { + return false, err + } + } + + oldContent, err := ioutil.ReadFile(SriovHostSystemdConfigPath) + if err != nil { + glog.Errorf("WriteConfFile(): fail to read file: %v", err) + return false, err + } + + oldContentObj := &SriovConfig{} + err = yaml.Unmarshal(oldContent, oldContentObj) + if err != nil { + glog.Errorf("WriteConfFile(): fail to unmarshal old file: %v", err) + return false, err + } + + var newContent []byte + newContent, err = yaml.Marshal(sriovConfig) + if err != nil { + glog.Errorf("WriteConfFile(): fail to marshal config: %v", err) + return false, err + } + + if bytes.Equal(newContent, oldContent) { + glog.V(2).Info("WriteConfFile(): no update") + return false, nil + } + glog.V(2).Infof("WriteConfFile(): previews configuration is not equal: old config:\n%s\nnew config:\n%s\n", string(oldContent), string(newContent)) + + glog.V(2).Infof("WriteConfFile(): write '%s' to %s", newContent, SriovHostSystemdConfigPath) + err = ioutil.WriteFile(SriovHostSystemdConfigPath, newContent, 0644) + if err != nil { + glog.Errorf("WriteConfFile(): fail to write file: %v", err) + return false, err + } + + // this will be used to mark the first time we create this file. + // this helps to avoid the first reboot after installation + if newFile && len(sriovConfig.Spec.Interfaces) == 0 { + glog.V(2).Info("WriteConfFile(): first file creation and no interfaces to configure returning reboot false") + return false, nil + } + + return true, nil +} + +func WriteSriovResult(result *SriovResult) error { + _, err := os.Stat(SriovSystemdResultPath) + if err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("WriteSriovResult(): file not existed, create it") + _, err = os.Create(SriovSystemdResultPath) + if err != nil { + glog.Errorf("WriteSriovResult(): failed to create sriov result file on path %s: %v", SriovSystemdResultPath, err) + return err + } + } else { + glog.Errorf("WriteSriovResult(): failed to check sriov result file on path %s: %v", SriovSystemdResultPath, err) + return err + } + } + + out, err := yaml.Marshal(result) + if err != nil { + glog.Errorf("WriteSriovResult(): failed to marshal sriov result file: %v", err) + return err + } + + glog.V(2).Infof("WriteSriovResult(): write '%s' to %s", string(out), SriovSystemdResultPath) + err = ioutil.WriteFile(SriovSystemdResultPath, out, 0644) + if err != nil { + glog.Errorf("WriteSriovResult(): failed to write sriov result file on path %s: %v", SriovSystemdResultPath, err) + return err + } + + return nil +} + +func ReadSriovResult() (*SriovResult, error) { + _, err := os.Stat(SriovHostSystemdResultPath) + if err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("ReadSriovResult(): file not existed, return empty result") + return &SriovResult{}, nil + } else { + glog.Errorf("ReadSriovResult(): failed to check sriov result file on path %s: %v", SriovHostSystemdResultPath, err) + return nil, err + } + } + + rawConfig, err := ioutil.ReadFile(SriovHostSystemdResultPath) + if err != nil { + glog.Errorf("ReadSriovResult(): failed to read sriov result file on path %s: %v", SriovHostSystemdResultPath, err) + return nil, err + } + + result := &SriovResult{} + err = yaml.Unmarshal(rawConfig, &result) + if err != nil { + glog.Errorf("ReadSriovResult(): failed to unmarshal sriov result file on path %s: %v", SriovHostSystemdResultPath, err) + return nil, err + } + return result, err +} + +func WriteSriovSupportedNics() error { + _, err := os.Stat(sriovHostSystemdSupportedNicPath) + if err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("WriteSriovSupportedNics(): file not existed, create it") + _, err = os.Create(sriovHostSystemdSupportedNicPath) + if err != nil { + glog.Errorf("WriteSriovSupportedNics(): failed to create sriov supporter nics ids file on path %s: %v", sriovHostSystemdSupportedNicPath, err) + return err + } + } else { + glog.Errorf("WriteSriovSupportedNics(): failed to check sriov supporter nics ids file on path %s: %v", sriovHostSystemdSupportedNicPath, err) + return err + } + } + + rawNicList := []byte{} + for _, line := range sriovnetworkv1.NicIDMap { + rawNicList = append(rawNicList, []byte(fmt.Sprintf("%s\n", line))...) + } + + err = ioutil.WriteFile(sriovHostSystemdSupportedNicPath, rawNicList, 0644) + if err != nil { + glog.Errorf("WriteSriovSupportedNics(): failed to write sriov supporter nics ids file on path %s: %v", sriovHostSystemdSupportedNicPath, err) + return err + } + + return nil +} + +func ReadSriovSupportedNics() ([]string, error) { + _, err := os.Stat(sriovSystemdSupportedNicPath) + if err != nil { + if os.IsNotExist(err) { + glog.V(2).Infof("ReadSriovSupportedNics(): file not existed, return empty result") + return nil, err + } else { + glog.Errorf("ReadSriovSupportedNics(): failed to check sriov supporter nics file on path %s: %v", sriovSystemdSupportedNicPath, err) + return nil, err + } + } + + rawConfig, err := ioutil.ReadFile(sriovSystemdSupportedNicPath) + if err != nil { + glog.Errorf("ReadSriovSupportedNics(): failed to read sriov supporter nics file on path %s: %v", sriovSystemdSupportedNicPath, err) + return nil, err + } + + lines := strings.Split(string(rawConfig), "\n") + return lines, nil +} + +func CleanSriovFilesFromHost(isOpenShift bool) error { + err := os.Remove(SriovHostSystemdConfigPath) + if err != nil && !os.IsNotExist(err) { + return err + } + + err = os.Remove(SriovHostSystemdResultPath) + if err != nil && !os.IsNotExist(err) { + return err + } + + err = os.Remove(sriovHostSystemdSupportedNicPath) + if err != nil && !os.IsNotExist(err) { + return err + } + + err = os.Remove(sriovHostSystemdServiceBinaryPath) + if err != nil && !os.IsNotExist(err) { + return err + } + + // in openshift we should not remove the systemd service it will be done by the machine config operator + if !isOpenShift { + err = os.Remove(SriovHostServicePath) + if err != nil && !os.IsNotExist(err) { + return err + } + } + + return nil +} diff --git a/pkg/utils/command.go b/pkg/utils/command.go new file mode 100644 index 000000000..085cf5881 --- /dev/null +++ b/pkg/utils/command.go @@ -0,0 +1,41 @@ +/* +Copyright 2021. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package utils + +import ( + "bytes" + "os/exec" +) + +// Interface to run commands +// +//go:generate ../../bin/mockgen -destination mock/mock_command.go -source command.go +type CommandInterface interface { + Run(string, ...string) (stdout bytes.Buffer, stderr bytes.Buffer, err error) +} + +type Command struct { +} + +func (c *Command) Run(name string, args ...string) (stdout bytes.Buffer, stderr bytes.Buffer, err error) { + var stdoutbuff, stderrbuff bytes.Buffer + cmd := exec.Command(name, args...) + cmd.Stdout = &stdoutbuff + cmd.Stderr = &stderrbuff + + err = cmd.Run() + return +} diff --git a/pkg/utils/mock/mock_command.go b/pkg/utils/mock/mock_command.go new file mode 100644 index 000000000..d408a48b0 --- /dev/null +++ b/pkg/utils/mock/mock_command.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: command.go + +// Package mock_utils is a generated GoMock package. +package mock_utils + +import ( + bytes "bytes" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockCommandInterface is a mock of CommandInterface interface. +type MockCommandInterface struct { + ctrl *gomock.Controller + recorder *MockCommandInterfaceMockRecorder +} + +// MockCommandInterfaceMockRecorder is the mock recorder for MockCommandInterface. +type MockCommandInterfaceMockRecorder struct { + mock *MockCommandInterface +} + +// NewMockCommandInterface creates a new mock instance. +func NewMockCommandInterface(ctrl *gomock.Controller) *MockCommandInterface { + mock := &MockCommandInterface{ctrl: ctrl} + mock.recorder = &MockCommandInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCommandInterface) EXPECT() *MockCommandInterfaceMockRecorder { + return m.recorder +} + +// Run mocks base method. +func (m *MockCommandInterface) Run(arg0 string, arg1 ...string) (bytes.Buffer, bytes.Buffer, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Run", varargs...) + ret0, _ := ret[0].(bytes.Buffer) + ret1, _ := ret[1].(bytes.Buffer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Run indicates an expected call of Run. +func (mr *MockCommandInterfaceMockRecorder) Run(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCommandInterface)(nil).Run), varargs...) +} diff --git a/pkg/utils/switchdev.go b/pkg/utils/sriov.go similarity index 55% rename from pkg/utils/switchdev.go rename to pkg/utils/sriov.go index 58eabae22..c18122140 100644 --- a/pkg/utils/switchdev.go +++ b/pkg/utils/sriov.go @@ -1,8 +1,24 @@ +/* +Copyright 2021. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package utils import ( "bytes" "encoding/json" + "fmt" "io/ioutil" "os" @@ -12,7 +28,10 @@ import ( ) const ( - switchDevConfPath = "/host/etc/sriov_config.json" + SriovConfBasePath = "/etc/sriov-operator" + HostSriovConfBasePath = "/host" + SriovConfBasePath + SriovSwitchDevConfPath = SriovConfBasePath + "/sriov_config.json" + SriovHostSwitchDevConfPath = "/host" + SriovSwitchDevConfPath ) type config struct { @@ -28,13 +47,21 @@ func IsSwitchdevModeSpec(spec sriovnetworkv1.SriovNetworkNodeStateSpec) bool { return false } +func findInterface(interfaces sriovnetworkv1.Interfaces, name string) (iface sriovnetworkv1.Interface, err error) { + for _, i := range interfaces { + if i.Name == name { + return i, nil + } + } + return sriovnetworkv1.Interface{}, fmt.Errorf("unable to find interface: %v", name) +} + func WriteSwitchdevConfFile(newState *sriovnetworkv1.SriovNetworkNodeState) (update bool, err error) { // Create a map with all the PFs we will need to SKIP for systemd configuration pfsToSkip, err := GetPfsToSkip(newState) if err != nil { return false, err } - cfg := config{} for _, iface := range newState.Spec.Interfaces { for _, ifaceStatus := range newState.Status.Interfaces { @@ -48,12 +75,20 @@ func WriteSwitchdevConfFile(newState *sriovnetworkv1.SriovNetworkNodeState) (upd i := sriovnetworkv1.Interface{} if iface.NumVfs > 0 { + var vfGroups []sriovnetworkv1.VfGroup = nil + ifc, err := findInterface(newState.Spec.Interfaces, iface.Name) + if err != nil { + glog.Errorf("WriteSwitchdevConfFile(): fail find interface: %v", err) + } else { + vfGroups = ifc.VfGroups + } i = sriovnetworkv1.Interface{ // Not passing all the contents, since only NumVfs and EswitchMode can be configured by configure-switchdev.sh currently. Name: iface.Name, PciAddress: iface.PciAddress, NumVfs: iface.NumVfs, - VfGroups: iface.VfGroups, + Mtu: iface.Mtu, + VfGroups: vfGroups, } if iface.EswitchMode == sriovnetworkv1.ESwithModeSwitchDev { @@ -63,15 +98,25 @@ func WriteSwitchdevConfFile(newState *sriovnetworkv1.SriovNetworkNodeState) (upd } } } - _, err = os.Stat(switchDevConfPath) + _, err = os.Stat(SriovHostSwitchDevConfPath) if err != nil { if os.IsNotExist(err) { if len(cfg.Interfaces) == 0 { err = nil return } + + // Create the sriov-operator folder on the host if it doesn't exist + if _, err := os.Stat("/host" + SriovConfBasePath); os.IsNotExist(err) { + err = os.Mkdir("/host"+SriovConfBasePath, os.ModeDir) + if err != nil { + glog.Errorf("WriteConfFile(): fail to create sriov-operator folder: %v", err) + return false, err + } + } + glog.V(2).Infof("WriteSwitchdevConfFile(): file not existed, create it") - _, err = os.Create(switchDevConfPath) + _, err = os.Create(SriovHostSwitchDevConfPath) if err != nil { glog.Errorf("WriteSwitchdevConfFile(): fail to create file: %v", err) return @@ -80,7 +125,7 @@ func WriteSwitchdevConfFile(newState *sriovnetworkv1.SriovNetworkNodeState) (upd return } } - oldContent, err := ioutil.ReadFile(switchDevConfPath) + oldContent, err := ioutil.ReadFile(SriovHostSwitchDevConfPath) if err != nil { glog.Errorf("WriteSwitchdevConfFile(): fail to read file: %v", err) return @@ -100,7 +145,7 @@ func WriteSwitchdevConfFile(newState *sriovnetworkv1.SriovNetworkNodeState) (upd } update = true glog.V(2).Infof("WriteSwitchdevConfFile(): write '%s' to switchdev.conf", newContent) - err = ioutil.WriteFile(switchDevConfPath, newContent, 0644) + err = ioutil.WriteFile(SriovHostSwitchDevConfPath, newContent, 0644) if err != nil { glog.Errorf("WriteSwitchdevConfFile(): fail to write file: %v", err) return diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 37888e4b1..257f4c0ca 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -34,7 +34,7 @@ const ( sysClassNet = "/sys/class/net" netClass = 0x02 numVfsFile = "sriov_numvfs" - scriptsPath = "bindata/scripts/load-kmod.sh" + ClusterTypeOpenshift = "openshift" ClusterTypeKubernetes = "kubernetes" VendorMellanox = "15b3" @@ -148,14 +148,18 @@ func DiscoverSriovDevices(withUnsupported bool) ([]sriovnetworkv1.InterfaceExt, // SyncNodeState Attempt to update the node state to match the desired state func SyncNodeState(newState *sriovnetworkv1.SriovNetworkNodeState, pfsToConfig map[string]bool) error { - if IsKernelLockdownMode(true) && hasMellanoxInterfacesInSpec(newState) { + return ConfigSriovInterfaces(newState.Spec.Interfaces, newState.Status.Interfaces, pfsToConfig) +} + +func ConfigSriovInterfaces(interfaces []sriovnetworkv1.Interface, ifaceStatuses []sriovnetworkv1.InterfaceExt, pfsToConfig map[string]bool) error { + if IsKernelLockdownMode(true) && hasMellanoxInterfacesInSpec(ifaceStatuses, interfaces) { glog.Warningf("cannot use mellanox devices when in kernel lockdown mode") return fmt.Errorf("cannot use mellanox devices when in kernel lockdown mode") } var err error - for _, ifaceStatus := range newState.Status.Interfaces { + for _, ifaceStatus := range ifaceStatuses { configured := false - for _, iface := range newState.Spec.Interfaces { + for _, iface := range interfaces { if iface.PciAddress == ifaceStatus.PciAddress { configured = true @@ -589,18 +593,6 @@ func getVfInfo(pciAddr string, devices []*ghw.PCIDevice) sriovnetworkv1.VirtualF return vf } -func LoadKernelModule(name string, args ...string) error { - glog.Infof("LoadKernelModule(): try to load kernel module %s with arguments '%s'", name, args) - cmdArgs := strings.Join(args, " ") - cmd := exec.Command("/bin/sh", scriptsPath, name, cmdArgs) - err := cmd.Run() - if err != nil { - glog.Errorf("LoadKernelModule(): fail to load kernel module %s with arguments '%s': %v", name, args, err) - return err - } - return nil -} - func Chroot(path string) (func() error, error) { root, err := os.Open("/") if err != nil { @@ -788,10 +780,10 @@ func RunCommand(command string, args ...string) (string, error) { return stdout.String(), err } -func hasMellanoxInterfacesInSpec(newState *sriovnetworkv1.SriovNetworkNodeState) bool { - for _, ifaceStatus := range newState.Status.Interfaces { +func hasMellanoxInterfacesInSpec(ifaceStatuses sriovnetworkv1.InterfaceExts, ifaceSpecs sriovnetworkv1.Interfaces) bool { + for _, ifaceStatus := range ifaceStatuses { if ifaceStatus.Vendor == VendorMellanox { - for _, iface := range newState.Spec.Interfaces { + for _, iface := range ifaceSpecs { if iface.PciAddress == ifaceStatus.PciAddress { glog.V(2).Infof("hasMellanoxInterfacesInSpec(): Mellanox device %s (pci: %s) specified in SriovNetworkNodeState spec", ifaceStatus.Name, ifaceStatus.PciAddress) return true diff --git a/pkg/utils/utils_virtual.go b/pkg/utils/utils_virtual.go index 9c710aee3..f2e9eb821 100644 --- a/pkg/utils/utils_virtual.go +++ b/pkg/utils/utils_virtual.go @@ -49,12 +49,15 @@ var ( ) const ( - ospMetaDataDir = "/host/var/config/openstack/2018-08-27" - ospMetaDataBaseURL = "http://169.254.169.254/openstack/2018-08-27" - ospNetworkDataFile = ospMetaDataDir + "/network_data.json" - ospMetaDataFile = ospMetaDataDir + "/meta_data.json" - ospNetworkDataURL = ospMetaDataBaseURL + "/network_data.json" - ospMetaDataURL = ospMetaDataBaseURL + "/meta_data.json" + ospHostMetaDataDir = "/host/var/config/openstack/2018-08-27" + ospMetaDataDir = "/var/config/openstack/2018-08-27" + ospMetaDataBaseURL = "http://169.254.169.254/openstack/2018-08-27" + ospHostNetworkDataFile = ospHostMetaDataDir + "/network_data.json" + ospHostMetaDataFile = ospHostMetaDataDir + "/meta_data.json" + ospNetworkDataFile = ospMetaDataDir + "/network_data.json" + ospMetaDataFile = ospMetaDataDir + "/meta_data.json" + ospNetworkDataURL = ospMetaDataBaseURL + "/network_data.json" + ospMetaDataURL = ospMetaDataBaseURL + "/meta_data.json" ) // OSPMetaDataDevice -- Device structure within meta_data.json @@ -111,8 +114,8 @@ type OSPDeviceInfo struct { } // GetOpenstackData gets the metadata and network_data -func GetOpenstackData() (metaData *OSPMetaData, networkData *OSPNetworkData, err error) { - metaData, networkData, err = getOpenstackDataFromConfigDrive() +func GetOpenstackData(useHostPath bool) (metaData *OSPMetaData, networkData *OSPNetworkData, err error) { + metaData, networkData, err = getOpenstackDataFromConfigDrive(useHostPath) if err != nil { metaData, networkData, err = getOpenstackDataFromMetadataService() } @@ -120,37 +123,45 @@ func GetOpenstackData() (metaData *OSPMetaData, networkData *OSPNetworkData, err } // getOpenstackDataFromConfigDrive reads the meta_data and network_data files -func getOpenstackDataFromConfigDrive() (metaData *OSPMetaData, networkData *OSPNetworkData, err error) { +func getOpenstackDataFromConfigDrive(useHostPath bool) (metaData *OSPMetaData, networkData *OSPNetworkData, err error) { metaData = &OSPMetaData{} networkData = &OSPNetworkData{} glog.Infof("reading OpenStack meta_data from config-drive") var metadataf *os.File - metadataf, err = os.Open(ospMetaDataFile) + ospMetaDataFilePath := ospMetaDataFile + if useHostPath { + ospMetaDataFilePath = ospHostMetaDataFile + } + metadataf, err = os.Open(ospMetaDataFilePath) if err != nil { - return metaData, networkData, fmt.Errorf("error opening file %s: %w", ospMetaDataFile, err) + return metaData, networkData, fmt.Errorf("error opening file %s: %w", ospHostMetaDataFile, err) } defer func() { if e := metadataf.Close(); err == nil && e != nil { - err = fmt.Errorf("error closing file %s: %w", ospMetaDataFile, e) + err = fmt.Errorf("error closing file %s: %w", ospHostMetaDataFile, e) } }() if err = json.NewDecoder(metadataf).Decode(&metaData); err != nil { - return metaData, networkData, fmt.Errorf("error unmarshalling metadata from file %s: %w", ospMetaDataFile, err) + return metaData, networkData, fmt.Errorf("error unmarshalling metadata from file %s: %w", ospHostMetaDataFile, err) } glog.Infof("reading OpenStack network_data from config-drive") var networkDataf *os.File - networkDataf, err = os.Open(ospNetworkDataFile) + ospNetworkDataFilePath := ospNetworkDataFile + if useHostPath { + ospNetworkDataFilePath = ospHostNetworkDataFile + } + networkDataf, err = os.Open(ospNetworkDataFilePath) if err != nil { - return metaData, networkData, fmt.Errorf("error opening file %s: %w", ospNetworkDataFile, err) + return metaData, networkData, fmt.Errorf("error opening file %s: %w", ospHostNetworkDataFile, err) } defer func() { if e := networkDataf.Close(); err == nil && e != nil { - err = fmt.Errorf("error closing file %s: %w", ospNetworkDataFile, e) + err = fmt.Errorf("error closing file %s: %w", ospHostNetworkDataFile, e) } }() if err = json.NewDecoder(networkDataf).Decode(&networkData); err != nil { - return metaData, networkData, fmt.Errorf("error unmarshalling metadata from file %s: %w", ospNetworkDataFile, err) + return metaData, networkData, fmt.Errorf("error unmarshalling metadata from file %s: %w", ospHostNetworkDataFile, err) } return metaData, networkData, err } diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 270e0903d..c4d3686af 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -11,7 +11,7 @@ import ( snclientset "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/clientset/versioned" ) -var snclient *snclientset.Clientset +var snclient snclientset.Interface var kubeclient *kubernetes.Clientset func SetupInClusterClient() error { diff --git a/pkg/webhook/validate.go b/pkg/webhook/validate.go index 01f753608..91b3922fa 100644 --- a/pkg/webhook/validate.go +++ b/pkg/webhook/validate.go @@ -17,6 +17,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" ) @@ -35,17 +37,60 @@ func validateSriovOperatorConfig(cr *sriovnetworkv1.SriovOperatorConfig, operati glog.V(2).Infof("validateSriovOperatorConfig: %v", cr) var warnings []string - if cr.GetName() == constants.DefaultConfigName { - if operation == v1.Delete { - return false, warnings, fmt.Errorf("default SriovOperatorConfig shouldn't be deleted") + if cr.GetName() != constants.DefaultConfigName { + return false, warnings, fmt.Errorf("only default SriovOperatorConfig is used") + } + + if operation == v1.Delete { + return false, warnings, fmt.Errorf("default SriovOperatorConfig shouldn't be deleted") + } + + if cr.Spec.DisableDrain { + warnings = append(warnings, "Node draining is disabled for applying SriovNetworkNodePolicy, it may result in workload interruption.") + } + + err := validateSriovOperatorConfigDisableDrain(cr) + if err != nil { + return false, warnings, err + } + + return true, warnings, nil +} + +// validateSriovOperatorConfigDisableDrain checks if the user is setting `.Spec.DisableDrain` from false to true while +// operator is updating one or more nodes. Disabling the drain at this stage would prevent the operator to uncordon a node at +// the end of the update operation, keeping nodes un-schedulable until manual intervention. +func validateSriovOperatorConfigDisableDrain(cr *sriovnetworkv1.SriovOperatorConfig) error { + if !cr.Spec.DisableDrain { + return nil + } + + previousConfig, err := snclient.SriovnetworkV1().SriovOperatorConfigs(cr.Namespace).Get(context.Background(), cr.Name, metav1.GetOptions{}) + if err != nil { + if k8serrors.IsNotFound(err) { + return nil } + return fmt.Errorf("can't validate SriovOperatorConfig[%s] DisableDrain against its previous value: %q", cr.Name, err) + } + + if previousConfig.Spec.DisableDrain == cr.Spec.DisableDrain { + // DisableDrain didn't change + return nil + } + + // DisableDrain has been changed `false -> true`, check if any node is updating + nodeStates, err := snclient.SriovnetworkV1().SriovNetworkNodeStates(namespace).List(context.Background(), metav1.ListOptions{}) + if err != nil { + return fmt.Errorf("can't validate SriovOperatorConfig[%s] DisableDrain transition to true: %q", cr.Name, err) + } - if cr.Spec.DisableDrain { - warnings = append(warnings, "Node draining is disabled for applying SriovNetworkNodePolicy, it may result in workload interruption.") + for _, nodeState := range nodeStates.Items { + if nodeState.Status.SyncStatus == "InProgress" { + return fmt.Errorf("can't set Spec.DisableDrain = true while node[%s] is updating", nodeState.Name) } - return true, warnings, nil } - return false, warnings, fmt.Errorf("only default SriovOperatorConfig is used") + + return nil } func validateSriovNetworkNodePolicy(cr *sriovnetworkv1.SriovNetworkNodePolicy, operation v1.Operation) (bool, []string, error) { diff --git a/pkg/webhook/validate_test.go b/pkg/webhook/validate_test.go index 028d62a5e..a82796312 100644 --- a/pkg/webhook/validate_test.go +++ b/pkg/webhook/validate_test.go @@ -1,6 +1,7 @@ package webhook import ( + "context" "fmt" "os" "testing" @@ -14,6 +15,8 @@ import ( . "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" constants "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" + + fakesnclientset "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/clientset/versioned/fake" ) func TestMain(m *testing.M) { @@ -129,11 +132,8 @@ func NewNode() *corev1.Node { return &corev1.Node{Spec: corev1.NodeSpec{ProviderID: "openstack"}} } -func TestValidateSriovOperatorConfigWithDefaultOperatorConfig(t *testing.T) { - var err error - var ok bool - var w []string - config := &SriovOperatorConfig{ +func newDefaultOperatorConfig() *SriovOperatorConfig { + return &SriovOperatorConfig{ ObjectMeta: metav1.ObjectMeta{ Name: "default", }, @@ -145,8 +145,15 @@ func TestValidateSriovOperatorConfigWithDefaultOperatorConfig(t *testing.T) { LogLevel: 2, }, } +} + +func TestValidateSriovOperatorConfigWithDefaultOperatorConfig(t *testing.T) { g := NewGomegaWithT(t) - ok, _, err = validateSriovOperatorConfig(config, "DELETE") + + config := newDefaultOperatorConfig() + snclient = fakesnclientset.NewSimpleClientset() + + ok, _, err := validateSriovOperatorConfig(config, "DELETE") g.Expect(err).To(HaveOccurred()) g.Expect(ok).To(Equal(false)) @@ -154,7 +161,7 @@ func TestValidateSriovOperatorConfigWithDefaultOperatorConfig(t *testing.T) { g.Expect(err).NotTo(HaveOccurred()) g.Expect(ok).To(Equal(true)) - ok, w, err = validateSriovOperatorConfig(config, "UPDATE") + ok, w, err := validateSriovOperatorConfig(config, "UPDATE") g.Expect(err).NotTo(HaveOccurred()) g.Expect(ok).To(Equal(true)) g.Expect(w[0]).To(ContainSubstring("Node draining is disabled")) @@ -164,6 +171,39 @@ func TestValidateSriovOperatorConfigWithDefaultOperatorConfig(t *testing.T) { g.Expect(ok).To(Equal(true)) } +func TestValidateSriovOperatorConfigDisableDrain(t *testing.T) { + g := NewGomegaWithT(t) + + config := newDefaultOperatorConfig() + config.Spec.DisableDrain = false + + nodeState := &SriovNetworkNodeState{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-1", Namespace: namespace}, + Status: SriovNetworkNodeStateStatus{ + SyncStatus: "InProgress", + }, + } + + snclient = fakesnclientset.NewSimpleClientset( + config, + nodeState, + ) + + config.Spec.DisableDrain = true + ok, _, err := validateSriovOperatorConfig(config, "UPDATE") + g.Expect(err).To(MatchError("can't set Spec.DisableDrain = true while node[worker-1] is updating")) + g.Expect(ok).To(Equal(false)) + + // Simulate node update finished + nodeState.Status.SyncStatus = "Succeeded" + snclient.SriovnetworkV1().SriovNetworkNodeStates(namespace). + Update(context.Background(), nodeState, metav1.UpdateOptions{}) + + ok, _, err = validateSriovOperatorConfig(config, "UPDATE") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(ok).To(Equal(true)) +} + func TestValidateSriovNetworkNodePolicyWithDefaultPolicy(t *testing.T) { var err error var ok bool diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go index 5dbbe7440..8acd2ac4f 100644 --- a/pkg/webhook/webhook.go +++ b/pkg/webhook/webhook.go @@ -14,7 +14,7 @@ import ( var namespace = os.Getenv("NAMESPACE") func RetriveSupportedNics() error { - if err := sriovnetworkv1.InitNicIDMap(kubeclient, namespace); err != nil { + if err := sriovnetworkv1.InitNicIDMapFromConfigMap(kubeclient, namespace); err != nil { return err } return nil diff --git a/test/conformance/tests/test_sriov_operator.go b/test/conformance/tests/test_sriov_operator.go index d093ba326..91c9a99ce 100644 --- a/test/conformance/tests/test_sriov_operator.go +++ b/test/conformance/tests/test_sriov_operator.go @@ -36,7 +36,7 @@ import ( "github.com/k8snetworkplumbingwg/sriov-network-operator/test/util/pod" ) -var waitingTime time.Duration = 20 * time.Minute +var waitingTime = 20 * time.Minute var sriovNetworkName = "test-sriovnetwork" var snoTimeoutMultiplier time.Duration = 0 diff --git a/test/util/cluster/cluster.go b/test/util/cluster/cluster.go index 430a96077..d5b9895ca 100644 --- a/test/util/cluster/cluster.go +++ b/test/util/cluster/cluster.go @@ -51,7 +51,7 @@ func DiscoverSriov(clients *testclient.ClientSet, operatorNamespace string) (*En return nil, fmt.Errorf("failed to find matching node states %v", err) } - err = sriovv1.InitNicIDMap(kubernetes.NewForConfigOrDie(clients.Config), operatorNamespace) + err = sriovv1.InitNicIDMapFromConfigMap(kubernetes.NewForConfigOrDie(clients.Config), operatorNamespace) if err != nil { return nil, fmt.Errorf("failed to InitNicIdMap %v", err) } diff --git a/vendor/github.com/golang/mock/AUTHORS b/vendor/github.com/golang/mock/AUTHORS new file mode 100644 index 000000000..660b8ccc8 --- /dev/null +++ b/vendor/github.com/golang/mock/AUTHORS @@ -0,0 +1,12 @@ +# This is the official list of GoMock authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. + +# Names should be added to this file as +# Name or Organization +# The email address is not required for organizations. + +# Please keep the list sorted. + +Alex Reece +Google Inc. diff --git a/vendor/github.com/golang/mock/CONTRIBUTORS b/vendor/github.com/golang/mock/CONTRIBUTORS new file mode 100644 index 000000000..def849cab --- /dev/null +++ b/vendor/github.com/golang/mock/CONTRIBUTORS @@ -0,0 +1,37 @@ +# This is the official list of people who can contribute (and typically +# have contributed) code to the gomock repository. +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# The submission process automatically checks to make sure +# that people submitting code are listed in this file (by email address). +# +# Names should be added to this file only after verifying that +# the individual or the individual's organization has agreed to +# the appropriate Contributor License Agreement, found here: +# +# http://code.google.com/legal/individual-cla-v1.0.html +# http://code.google.com/legal/corporate-cla-v1.0.html +# +# The agreement for individuals can be filled out on the web. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file, depending on whether the +# individual or corporate CLA was used. + +# Names should be added to this file like so: +# Name +# +# An entry with two email addresses specifies that the +# first address should be used in the submit logs and +# that the second address should be recognized as the +# same person when interacting with Rietveld. + +# Please keep the list sorted. + +Aaron Jacobs +Alex Reece +David Symonds +Ryan Barrett diff --git a/vendor/github.com/golang/mock/LICENSE b/vendor/github.com/golang/mock/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/vendor/github.com/golang/mock/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/golang/mock/gomock/call.go b/vendor/github.com/golang/mock/gomock/call.go new file mode 100644 index 000000000..7345f6540 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/call.go @@ -0,0 +1,427 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// Call represents an expected call to a mock. +type Call struct { + t TestHelper // for triggering test failures on invalid call setup + + receiver interface{} // the receiver of the method call + method string // the name of the method + methodType reflect.Type // the type of the method + args []Matcher // the args + origin string // file and line number of call setup + + preReqs []*Call // prerequisite calls + + // Expectations + minCalls, maxCalls int + + numCalls int // actual number made + + // actions are called when this Call is called. Each action gets the args and + // can set the return values by returning a non-nil slice. Actions run in the + // order they are created. + actions []func([]interface{}) []interface{} +} + +// newCall creates a *Call. It requires the method type in order to support +// unexported methods. +func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { + t.Helper() + + // TODO: check arity, types. + margs := make([]Matcher, len(args)) + for i, arg := range args { + if m, ok := arg.(Matcher); ok { + margs[i] = m + } else if arg == nil { + // Handle nil specially so that passing a nil interface value + // will match the typed nils of concrete args. + margs[i] = Nil() + } else { + margs[i] = Eq(arg) + } + } + + origin := callerInfo(3) + actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} { + // Synthesize the zero value for each of the return args' types. + rets := make([]interface{}, methodType.NumOut()) + for i := 0; i < methodType.NumOut(); i++ { + rets[i] = reflect.Zero(methodType.Out(i)).Interface() + } + return rets + }} + return &Call{t: t, receiver: receiver, method: method, methodType: methodType, + args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} +} + +// AnyTimes allows the expectation to be called 0 or more times +func (c *Call) AnyTimes() *Call { + c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity + return c +} + +// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes +// was previously called with 1, MinTimes also sets the maximum number of calls to infinity. +func (c *Call) MinTimes(n int) *Call { + c.minCalls = n + if c.maxCalls == 1 { + c.maxCalls = 1e8 + } + return c +} + +// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was +// previously called with 1, MaxTimes also sets the minimum number of calls to 0. +func (c *Call) MaxTimes(n int) *Call { + c.maxCalls = n + if c.minCalls == 1 { + c.minCalls = 0 + } + return c +} + +// DoAndReturn declares the action to run when the call is matched. +// The return values from this function are returned by the mocked function. +// It takes an interface{} argument to support n-arity functions. +func (c *Call) DoAndReturn(f interface{}) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []interface{}) []interface{} { + vargs := make([]reflect.Value, len(args)) + ft := v.Type() + for i := 0; i < len(args); i++ { + if args[i] != nil { + vargs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vargs[i] = reflect.Zero(ft.In(i)) + } + } + vrets := v.Call(vargs) + rets := make([]interface{}, len(vrets)) + for i, ret := range vrets { + rets[i] = ret.Interface() + } + return rets + }) + return c +} + +// Do declares the action to run when the call is matched. The function's +// return values are ignored to retain backward compatibility. To use the +// return values call DoAndReturn. +// It takes an interface{} argument to support n-arity functions. +func (c *Call) Do(f interface{}) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []interface{}) []interface{} { + vargs := make([]reflect.Value, len(args)) + ft := v.Type() + for i := 0; i < len(args); i++ { + if args[i] != nil { + vargs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vargs[i] = reflect.Zero(ft.In(i)) + } + } + v.Call(vargs) + return nil + }) + return c +} + +// Return declares the values to be returned by the mocked function call. +func (c *Call) Return(rets ...interface{}) *Call { + c.t.Helper() + + mt := c.methodType + if len(rets) != mt.NumOut() { + c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, len(rets), mt.NumOut(), c.origin) + } + for i, ret := range rets { + if got, want := reflect.TypeOf(ret), mt.Out(i); got == want { + // Identical types; nothing to do. + } else if got == nil { + // Nil needs special handling. + switch want.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + // ok + default: + c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]", + i, c.receiver, c.method, want, c.origin) + } + } else if got.AssignableTo(want) { + // Assignable type relation. Make the assignment now so that the generated code + // can return the values with a type assertion. + v := reflect.New(want).Elem() + v.Set(reflect.ValueOf(ret)) + rets[i] = v.Interface() + } else { + c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]", + i, c.receiver, c.method, got, want, c.origin) + } + } + + c.addAction(func([]interface{}) []interface{} { + return rets + }) + + return c +} + +// Times declares the exact number of times a function call is expected to be executed. +func (c *Call) Times(n int) *Call { + c.minCalls, c.maxCalls = n, n + return c +} + +// SetArg declares an action that will set the nth argument's value, +// indirected through a pointer. Or, in the case of a slice, SetArg +// will copy value's elements into the nth argument. +func (c *Call) SetArg(n int, value interface{}) *Call { + c.t.Helper() + + mt := c.methodType + // TODO: This will break on variadic methods. + // We will need to check those at invocation time. + if n < 0 || n >= mt.NumIn() { + c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]", + n, mt.NumIn(), c.origin) + } + // Permit setting argument through an interface. + // In the interface case, we don't (nay, can't) check the type here. + at := mt.In(n) + switch at.Kind() { + case reflect.Ptr: + dt := at.Elem() + if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) { + c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]", + n, vt, dt, c.origin) + } + case reflect.Interface: + // nothing to do + case reflect.Slice: + // nothing to do + default: + c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]", + n, at, c.origin) + } + + c.addAction(func(args []interface{}) []interface{} { + v := reflect.ValueOf(value) + switch reflect.TypeOf(args[n]).Kind() { + case reflect.Slice: + setSlice(args[n], v) + default: + reflect.ValueOf(args[n]).Elem().Set(v) + } + return nil + }) + return c +} + +// isPreReq returns true if other is a direct or indirect prerequisite to c. +func (c *Call) isPreReq(other *Call) bool { + for _, preReq := range c.preReqs { + if other == preReq || preReq.isPreReq(other) { + return true + } + } + return false +} + +// After declares that the call may only match after preReq has been exhausted. +func (c *Call) After(preReq *Call) *Call { + c.t.Helper() + + if c == preReq { + c.t.Fatalf("A call isn't allowed to be its own prerequisite") + } + if preReq.isPreReq(c) { + c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq) + } + + c.preReqs = append(c.preReqs, preReq) + return c +} + +// Returns true if the minimum number of calls have been made. +func (c *Call) satisfied() bool { + return c.numCalls >= c.minCalls +} + +// Returns true if the maximum number of calls have been made. +func (c *Call) exhausted() bool { + return c.numCalls >= c.maxCalls +} + +func (c *Call) String() string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + arguments := strings.Join(args, ", ") + return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin) +} + +// Tests if the given call matches the expected call. +// If yes, returns nil. If no, returns error with message explaining why it does not match. +func (c *Call) matches(args []interface{}) error { + if !c.methodType.IsVariadic() { + if len(args) != len(c.args) { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + + for i, m := range c.args { + if !m.Matches(args[i]) { + got := fmt.Sprintf("%v", args[i]) + if gs, ok := m.(GotFormatter); ok { + got = gs.Got(args[i]) + } + + return fmt.Errorf( + "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", + c.origin, i, got, m, + ) + } + } + } else { + if len(c.args) < c.methodType.NumIn()-1 { + return fmt.Errorf("expected call at %s has the wrong number of matchers. Got: %d, want: %d", + c.origin, len(c.args), c.methodType.NumIn()-1) + } + if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + if len(args) < len(c.args)-1 { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d", + c.origin, len(args), len(c.args)-1) + } + + for i, m := range c.args { + if i < c.methodType.NumIn()-1 { + // Non-variadic args + if !m.Matches(args[i]) { + return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i], m) + } + continue + } + // The last arg has a possibility of a variadic argument, so let it branch + + // sample: Foo(a int, b int, c ...int) + if i < len(c.args) && i < len(args) { + if m.Matches(args[i]) { + // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC) + // Got Foo(a, b) want Foo(matcherA, matcherB) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD) + continue + } + } + + // The number of actual args don't match the number of matchers, + // or the last matcher is a slice and the last arg is not. + // If this function still matches it is because the last matcher + // matches all the remaining arguments or the lack of any. + // Convert the remaining arguments, if any, into a slice of the + // expected type. + vargsType := c.methodType.In(c.methodType.NumIn() - 1) + vargs := reflect.MakeSlice(vargsType, 0, len(args)-i) + for _, arg := range args[i:] { + vargs = reflect.Append(vargs, reflect.ValueOf(arg)) + } + if m.Matches(vargs.Interface()) { + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher) + break + } + // Wrong number of matchers or not match. Fail. + // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB) + return fmt.Errorf("Expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), args[i:], c.args[i]) + + } + } + + // Check that all prerequisite calls have been satisfied. + for _, preReqCall := range c.preReqs { + if !preReqCall.satisfied() { + return fmt.Errorf("Expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v", + c.origin, preReqCall, c) + } + } + + // Check that the call is not exhausted. + if c.exhausted() { + return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin) + } + + return nil +} + +// dropPrereqs tells the expected Call to not re-check prerequisite calls any +// longer, and to return its current set. +func (c *Call) dropPrereqs() (preReqs []*Call) { + preReqs = c.preReqs + c.preReqs = nil + return +} + +func (c *Call) call() []func([]interface{}) []interface{} { + c.numCalls++ + return c.actions +} + +// InOrder declares that the given calls should occur in order. +func InOrder(calls ...*Call) { + for i := 1; i < len(calls); i++ { + calls[i].After(calls[i-1]) + } +} + +func setSlice(arg interface{}, v reflect.Value) { + va := reflect.ValueOf(arg) + for i := 0; i < v.Len(); i++ { + va.Index(i).Set(v.Index(i)) + } +} + +func (c *Call) addAction(action func([]interface{}) []interface{}) { + c.actions = append(c.actions, action) +} diff --git a/vendor/github.com/golang/mock/gomock/callset.go b/vendor/github.com/golang/mock/gomock/callset.go new file mode 100644 index 000000000..b046b525e --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/callset.go @@ -0,0 +1,108 @@ +// Copyright 2011 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "bytes" + "fmt" +) + +// callSet represents a set of expected calls, indexed by receiver and method +// name. +type callSet struct { + // Calls that are still expected. + expected map[callSetKey][]*Call + // Calls that have been exhausted. + exhausted map[callSetKey][]*Call +} + +// callSetKey is the key in the maps in callSet +type callSetKey struct { + receiver interface{} + fname string +} + +func newCallSet() *callSet { + return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)} +} + +// Add adds a new expected call. +func (cs callSet) Add(call *Call) { + key := callSetKey{call.receiver, call.method} + m := cs.expected + if call.exhausted() { + m = cs.exhausted + } + m[key] = append(m[key], call) +} + +// Remove removes an expected call. +func (cs callSet) Remove(call *Call) { + key := callSetKey{call.receiver, call.method} + calls := cs.expected[key] + for i, c := range calls { + if c == call { + // maintain order for remaining calls + cs.expected[key] = append(calls[:i], calls[i+1:]...) + cs.exhausted[key] = append(cs.exhausted[key], call) + break + } + } +} + +// FindMatch searches for a matching call. Returns error with explanation message if no call matched. +func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) { + key := callSetKey{receiver, method} + + // Search through the expected calls. + expected := cs.expected[key] + var callsErrors bytes.Buffer + for _, call := range expected { + err := call.matches(args) + if err != nil { + _, _ = fmt.Fprintf(&callsErrors, "\n%v", err) + } else { + return call, nil + } + } + + // If we haven't found a match then search through the exhausted calls so we + // get useful error messages. + exhausted := cs.exhausted[key] + for _, call := range exhausted { + if err := call.matches(args); err != nil { + _, _ = fmt.Fprintf(&callsErrors, "\n%v", err) + } + } + + if len(expected)+len(exhausted) == 0 { + _, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method) + } + + return nil, fmt.Errorf(callsErrors.String()) +} + +// Failures returns the calls that are not satisfied. +func (cs callSet) Failures() []*Call { + failures := make([]*Call, 0, len(cs.expected)) + for _, calls := range cs.expected { + for _, call := range calls { + if !call.satisfied() { + failures = append(failures, call) + } + } + } + return failures +} diff --git a/vendor/github.com/golang/mock/gomock/controller.go b/vendor/github.com/golang/mock/gomock/controller.go new file mode 100644 index 000000000..d7c3c656a --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/controller.go @@ -0,0 +1,264 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomock is a mock framework for Go. +// +// Standard usage: +// (1) Define an interface that you wish to mock. +// type MyInterface interface { +// SomeMethod(x int64, y string) +// } +// (2) Use mockgen to generate a mock from the interface. +// (3) Use the mock in a test: +// func TestMyThing(t *testing.T) { +// mockCtrl := gomock.NewController(t) +// defer mockCtrl.Finish() +// +// mockObj := something.NewMockMyInterface(mockCtrl) +// mockObj.EXPECT().SomeMethod(4, "blah") +// // pass mockObj to a real object and play with it. +// } +// +// By default, expected calls are not enforced to run in any particular order. +// Call order dependency can be enforced by use of InOrder and/or Call.After. +// Call.After can create more varied call order dependencies, but InOrder is +// often more convenient. +// +// The following examples create equivalent call order dependencies. +// +// Example of using Call.After to chain expected call order: +// +// firstCall := mockObj.EXPECT().SomeMethod(1, "first") +// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall) +// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall) +// +// Example of using InOrder to declare expected call order: +// +// gomock.InOrder( +// mockObj.EXPECT().SomeMethod(1, "first"), +// mockObj.EXPECT().SomeMethod(2, "second"), +// mockObj.EXPECT().SomeMethod(3, "third"), +// ) +// +// TODO: +// - Handle different argument/return types (e.g. ..., chan, map, interface). +package gomock + +import ( + "context" + "fmt" + "reflect" + "runtime" + "sync" +) + +// A TestReporter is something that can be used to report test failures. It +// is satisfied by the standard library's *testing.T. +type TestReporter interface { + Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) +} + +// TestHelper is a TestReporter that has the Helper method. It is satisfied +// by the standard library's *testing.T. +type TestHelper interface { + TestReporter + Helper() +} + +// A Controller represents the top-level control of a mock ecosystem. It +// defines the scope and lifetime of mock objects, as well as their +// expectations. It is safe to call Controller's methods from multiple +// goroutines. Each test should create a new Controller and invoke Finish via +// defer. +// +// func TestFoo(t *testing.T) { +// ctrl := gomock.NewController(t) +// defer ctrl.Finish() +// // .. +// } +// +// func TestBar(t *testing.T) { +// t.Run("Sub-Test-1", st) { +// ctrl := gomock.NewController(st) +// defer ctrl.Finish() +// // .. +// }) +// t.Run("Sub-Test-2", st) { +// ctrl := gomock.NewController(st) +// defer ctrl.Finish() +// // .. +// }) +// }) +type Controller struct { + // T should only be called within a generated mock. It is not intended to + // be used in user code and may be changed in future versions. T is the + // TestReporter passed in when creating the Controller via NewController. + // If the TestReporter does not implement a TestHelper it will be wrapped + // with a nopTestHelper. + T TestHelper + mu sync.Mutex + expectedCalls *callSet + finished bool +} + +// NewController returns a new Controller. It is the preferred way to create a +// Controller. +func NewController(t TestReporter) *Controller { + h, ok := t.(TestHelper) + if !ok { + h = nopTestHelper{t} + } + + return &Controller{ + T: h, + expectedCalls: newCallSet(), + } +} + +type cancelReporter struct { + TestHelper + cancel func() +} + +func (r *cancelReporter) Errorf(format string, args ...interface{}) { + r.TestHelper.Errorf(format, args...) +} +func (r *cancelReporter) Fatalf(format string, args ...interface{}) { + defer r.cancel() + r.TestHelper.Fatalf(format, args...) +} + +// WithContext returns a new Controller and a Context, which is cancelled on any +// fatal failure. +func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) { + h, ok := t.(TestHelper) + if !ok { + h = nopTestHelper{t} + } + + ctx, cancel := context.WithCancel(ctx) + return NewController(&cancelReporter{h, cancel}), ctx +} + +type nopTestHelper struct { + TestReporter +} + +func (h nopTestHelper) Helper() {} + +// RecordCall is called by a mock. It should not be called by user code. +func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call { + ctrl.T.Helper() + + recv := reflect.ValueOf(receiver) + for i := 0; i < recv.Type().NumMethod(); i++ { + if recv.Type().Method(i).Name == method { + return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...) + } + } + ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver) + panic("unreachable") +} + +// RecordCallWithMethodType is called by a mock. It should not be called by user code. +func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { + ctrl.T.Helper() + + call := newCall(ctrl.T, receiver, method, methodType, args...) + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + ctrl.expectedCalls.Add(call) + + return call +} + +// Call is called by a mock. It should not be called by user code. +func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} { + ctrl.T.Helper() + + // Nest this code so we can use defer to make sure the lock is released. + actions := func() []func([]interface{}) []interface{} { + ctrl.T.Helper() + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args) + if err != nil { + origin := callerInfo(2) + ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err) + } + + // Two things happen here: + // * the matching call no longer needs to check prerequite calls, + // * and the prerequite calls are no longer expected, so remove them. + preReqCalls := expected.dropPrereqs() + for _, preReqCall := range preReqCalls { + ctrl.expectedCalls.Remove(preReqCall) + } + + actions := expected.call() + if expected.exhausted() { + ctrl.expectedCalls.Remove(expected) + } + return actions + }() + + var rets []interface{} + for _, action := range actions { + if r := action(args); r != nil { + rets = r + } + } + + return rets +} + +// Finish checks to see if all the methods that were expected to be called +// were called. It should be invoked for each Controller. It is not idempotent +// and therefore can only be invoked once. +func (ctrl *Controller) Finish() { + ctrl.T.Helper() + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + if ctrl.finished { + ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.") + } + ctrl.finished = true + + // If we're currently panicking, probably because this is a deferred call, + // pass through the panic. + if err := recover(); err != nil { + panic(err) + } + + // Check that all remaining expected calls are satisfied. + failures := ctrl.expectedCalls.Failures() + for _, call := range failures { + ctrl.T.Errorf("missing call(s) to %v", call) + } + if len(failures) != 0 { + ctrl.T.Fatalf("aborting test due to missing call(s)") + } +} + +func callerInfo(skip int) string { + if _, file, line, ok := runtime.Caller(skip + 1); ok { + return fmt.Sprintf("%s:%d", file, line) + } + return "unknown file" +} diff --git a/vendor/github.com/golang/mock/gomock/matchers.go b/vendor/github.com/golang/mock/gomock/matchers.go new file mode 100644 index 000000000..7bfc07be4 --- /dev/null +++ b/vendor/github.com/golang/mock/gomock/matchers.go @@ -0,0 +1,255 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" + "strings" +) + +// A Matcher is a representation of a class of values. +// It is used to represent the valid or expected arguments to a mocked method. +type Matcher interface { + // Matches returns whether x is a match. + Matches(x interface{}) bool + + // String describes what the matcher matches. + String() string +} + +// WantFormatter modifies the given Matcher's String() method to the given +// Stringer. This allows for control on how the "Want" is formatted when +// printing . +func WantFormatter(s fmt.Stringer, m Matcher) Matcher { + type matcher interface { + Matches(x interface{}) bool + } + + return struct { + matcher + fmt.Stringer + }{ + matcher: m, + Stringer: s, + } +} + +// StringerFunc type is an adapter to allow the use of ordinary functions as +// a Stringer. If f is a function with the appropriate signature, +// StringerFunc(f) is a Stringer that calls f. +type StringerFunc func() string + +// String implements fmt.Stringer. +func (f StringerFunc) String() string { + return f() +} + +// GotFormatter is used to better print failure messages. If a matcher +// implements GotFormatter, it will use the result from Got when printing +// the failure message. +type GotFormatter interface { + // Got is invoked with the received value. The result is used when + // printing the failure message. + Got(got interface{}) string +} + +// GotFormatterFunc type is an adapter to allow the use of ordinary +// functions as a GotFormatter. If f is a function with the appropriate +// signature, GotFormatterFunc(f) is a GotFormatter that calls f. +type GotFormatterFunc func(got interface{}) string + +// Got implements GotFormatter. +func (f GotFormatterFunc) Got(got interface{}) string { + return f(got) +} + +// GotFormatterAdapter attaches a GotFormatter to a Matcher. +func GotFormatterAdapter(s GotFormatter, m Matcher) Matcher { + return struct { + GotFormatter + Matcher + }{ + GotFormatter: s, + Matcher: m, + } +} + +type anyMatcher struct{} + +func (anyMatcher) Matches(interface{}) bool { + return true +} + +func (anyMatcher) String() string { + return "is anything" +} + +type eqMatcher struct { + x interface{} +} + +func (e eqMatcher) Matches(x interface{}) bool { + return reflect.DeepEqual(e.x, x) +} + +func (e eqMatcher) String() string { + return fmt.Sprintf("is equal to %v", e.x) +} + +type nilMatcher struct{} + +func (nilMatcher) Matches(x interface{}) bool { + if x == nil { + return true + } + + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice: + return v.IsNil() + } + + return false +} + +func (nilMatcher) String() string { + return "is nil" +} + +type notMatcher struct { + m Matcher +} + +func (n notMatcher) Matches(x interface{}) bool { + return !n.m.Matches(x) +} + +func (n notMatcher) String() string { + // TODO: Improve this if we add a NotString method to the Matcher interface. + return "not(" + n.m.String() + ")" +} + +type assignableToTypeOfMatcher struct { + targetType reflect.Type +} + +func (m assignableToTypeOfMatcher) Matches(x interface{}) bool { + return reflect.TypeOf(x).AssignableTo(m.targetType) +} + +func (m assignableToTypeOfMatcher) String() string { + return "is assignable to " + m.targetType.Name() +} + +type allMatcher struct { + matchers []Matcher +} + +func (am allMatcher) Matches(x interface{}) bool { + for _, m := range am.matchers { + if !m.Matches(x) { + return false + } + } + return true +} + +func (am allMatcher) String() string { + ss := make([]string, 0, len(am.matchers)) + for _, matcher := range am.matchers { + ss = append(ss, matcher.String()) + } + return strings.Join(ss, "; ") +} + +type lenMatcher struct { + i int +} + +func (m lenMatcher) Matches(x interface{}) bool { + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == m.i + default: + return false + } +} + +func (m lenMatcher) String() string { + return fmt.Sprintf("has length %d", m.i) +} + +// Constructors + +// All returns a composite Matcher that returns true if and only all of the +// matchers return true. +func All(ms ...Matcher) Matcher { return allMatcher{ms} } + +// Any returns a matcher that always matches. +func Any() Matcher { return anyMatcher{} } + +// Eq returns a matcher that matches on equality. +// +// Example usage: +// Eq(5).Matches(5) // returns true +// Eq(5).Matches(4) // returns false +func Eq(x interface{}) Matcher { return eqMatcher{x} } + +// Len returns a matcher that matches on length. This matcher returns false if +// is compared to a type that is not an array, chan, map, slice, or string. +func Len(i int) Matcher { + return lenMatcher{i} +} + +// Nil returns a matcher that matches if the received value is nil. +// +// Example usage: +// var x *bytes.Buffer +// Nil().Matches(x) // returns true +// x = &bytes.Buffer{} +// Nil().Matches(x) // returns false +func Nil() Matcher { return nilMatcher{} } + +// Not reverses the results of its given child matcher. +// +// Example usage: +// Not(Eq(5)).Matches(4) // returns true +// Not(Eq(5)).Matches(5) // returns false +func Not(x interface{}) Matcher { + if m, ok := x.(Matcher); ok { + return notMatcher{m} + } + return notMatcher{Eq(x)} +} + +// AssignableToTypeOf is a Matcher that matches if the parameter to the mock +// function is assignable to the type of the parameter to this function. +// +// Example usage: +// var s fmt.Stringer = &bytes.Buffer{} +// AssignableToTypeOf(s).Matches(time.Second) // returns true +// AssignableToTypeOf(s).Matches(99) // returns false +// +// var ctx = reflect.TypeOf((*context.Context)).Elem() +// AssignableToTypeOf(ctx).Matches(context.Background()) // returns true +func AssignableToTypeOf(x interface{}) Matcher { + if xt, ok := x.(reflect.Type); ok { + return assignableToTypeOfMatcher{xt} + } + return assignableToTypeOfMatcher{reflect.TypeOf(x)} +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 1d9fed62c..dd08de875 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -200,6 +200,9 @@ github.com/golang/glog # github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da ## explicit github.com/golang/groupcache/lru +# github.com/golang/mock v1.4.4 +## explicit; go 1.11 +github.com/golang/mock/gomock # github.com/golang/protobuf v1.5.3 ## explicit; go 1.9 github.com/golang/protobuf/jsonpb