diff --git a/cloud/blockstore/tools/csi_driver/internal/driver/node.go b/cloud/blockstore/tools/csi_driver/internal/driver/node.go index 40d52d74c2..9193ab62a1 100644 --- a/cloud/blockstore/tools/csi_driver/internal/driver/node.go +++ b/cloud/blockstore/tools/csi_driver/internal/driver/node.go @@ -2,6 +2,7 @@ package driver import ( "context" + "encoding/json" "errors" "fmt" "io/fs" @@ -159,6 +160,23 @@ func (s *nodeService) NodeStageVolume( var err error if instanceID := req.VolumeContext[instanceIDKey]; instanceID != "" { + stageRecordPath := filepath.Join(req.StagingTargetPath, req.VolumeId+".json") + + // Backend can be empty for old disks, in this case we use NBS + backend := "nbs" + if nfsBackend { + backend = "nfs" + } + + if err = s.writeStageData(stageRecordPath, &StageData{ + Backend: backend, + InstanceId: instanceID, + RealStagePath: s.getEndpointDir(stagingDirName, req.VolumeId), + }); err != nil { + return nil, s.statusErrorf(codes.Internal, + "Failed to wriete stage record: %v", err) + } + if nfsBackend { err = s.nodeStageFileStoreAsVhostSocket(ctx, instanceID, req.VolumeId) } else { @@ -193,10 +211,14 @@ func (s *nodeService) NodeUnstageVolume( } if s.vmMode { - if err := s.nodeUnstageVhostSocket(ctx, req.VolumeId); err != nil { - return nil, s.statusErrorf( - codes.InvalidArgument, - "Failed to unstage volume: %v", err) + stageRecordPath := filepath.Join(req.StagingTargetPath, req.VolumeId+".json") + if stageData, err := s.readStageData(stageRecordPath); err == nil { + if err := s.nodeUnstageVhostSocket(ctx, req.VolumeId, stageData); err != nil { + return nil, s.statusErrorf( + codes.InvalidArgument, + "Failed to unstage volume: %v", err) + } + ignoreError(os.Remove(stageRecordPath)) } } @@ -401,6 +423,47 @@ func (s *nodeService) nodePublishDiskAsVhostSocket( return s.mountSocketDir(endpointDir, req) } +type StageData struct { + Backend string `json:"backend"` + InstanceId string `json:"instanceId"` + RealStagePath string `json:"realStagePath"` +} + +func (s *nodeService) writeStageData(path string, data *StageData) error { + bytes, err := json.Marshal(data) + if err != nil { + return err + } + + err = os.MkdirAll(filepath.Dir(path), 0750) + if err != nil { + return err + } + + err = os.WriteFile(path, bytes, 0600) + if err != nil { + return err + } + + return nil +} + +func (s *nodeService) readStageData(path string) (*StageData, error) { + data := StageData{} + + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + err = json.Unmarshal(bytes, &data) + if err != nil { + return nil, err + } + + return &data, nil +} + func (s *nodeService) nodeStageDiskAsVhostSocket( ctx context.Context, instanceId string, @@ -722,35 +785,29 @@ func (s *nodeService) mountSocketDir(sourcePath string, req *csi.NodePublishVolu func (s *nodeService) nodeUnstageVhostSocket( ctx context.Context, - volumeID string) error { + volumeID string, + stageData *StageData) error { - log.Printf("csi.nodeUnstageVhostSocket: %s", volumeID) + log.Printf("csi.nodeUnstageVhostSocket[%s]: %s %s %s", stageData.Backend, stageData.InstanceId, + volumeID, stageData.RealStagePath) - endpointDir := s.getEndpointDir(stagingDirName, volumeID) - - // Trying to stop both NBS and NFS endpoints, - // because the endpoint's backend service is unknown here. - // When we miss we get S_FALSE/S_ALREADY code (err == nil). - - if s.nbsClient != nil { + if stageData.Backend == "nbs" { _, err := s.nbsClient.StopEndpoint(ctx, &nbsapi.TStopEndpointRequest{ - UnixSocketPath: filepath.Join(endpointDir, nbsSocketName), + UnixSocketPath: filepath.Join(stageData.RealStagePath, nbsSocketName), }) if err != nil { return fmt.Errorf("failed to stop nbs endpoint: %w", err) } - } - - if s.nfsClient != nil { + } else if stageData.Backend == "nfs" { _, err := s.nfsClient.StopEndpoint(ctx, &nfsapi.TStopEndpointRequest{ - SocketPath: filepath.Join(endpointDir, nfsSocketName), + SocketPath: filepath.Join(stageData.RealStagePath, nfsSocketName), }) if err != nil { return fmt.Errorf("failed to stop nfs endpoint: %w", err) } } - if err := os.RemoveAll(endpointDir); err != nil { + if err := os.RemoveAll(stageData.RealStagePath); err != nil { return err } diff --git a/cloud/blockstore/tools/csi_driver/internal/driver/node_test.go b/cloud/blockstore/tools/csi_driver/internal/driver/node_test.go index 8f703afc45..b8f55f5a07 100644 --- a/cloud/blockstore/tools/csi_driver/internal/driver/node_test.go +++ b/cloud/blockstore/tools/csi_driver/internal/driver/node_test.go @@ -170,14 +170,6 @@ func doTestPublishUnpublishVolumeForKubevirt(t *testing.T, backend string, devic _, err = os.Stat(filepath.Join(socketsDir, podID)) assert.True(t, os.IsNotExist(err)) - nbsClient.On("StopEndpoint", ctx, &nbs.TStopEndpointRequest{ - UnixSocketPath: filepath.Join(socketsDir, stagingDirName, diskID, nbsSocketName), - }).Return(&nbs.TStopEndpointResponse{}, nil) - - nfsClient.On("StopEndpoint", ctx, &nfs.TStopEndpointRequest{ - SocketPath: filepath.Join(socketsDir, stagingDirName, diskID, nfsSocketName), - }).Return(&nfs.TStopEndpointResponse{}, nil) - _, err = nodeService.NodeUnstageVolume(ctx, &csi.NodeUnstageVolumeRequest{ VolumeId: diskID, StagingTargetPath: stagingTargetPath, @@ -216,7 +208,7 @@ func doTestStagedPublishUnpublishVolumeForKubevirt(t *testing.T, backend string, if deviceNameOpt != nil { deviceName = *deviceNameOpt } - stagingTargetPath := "testStagingTargetPath" + stagingTargetPath := filepath.Join(tempDir, "testStagingTargetPath") socketsDir := filepath.Join(tempDir, "sockets") sourcePath := filepath.Join(socketsDir, stagingDirName, diskID) targetPath := filepath.Join(tempDir, "pods", podID, "volumes", diskID, "mount") @@ -343,13 +335,17 @@ func doTestStagedPublishUnpublishVolumeForKubevirt(t *testing.T, backend string, }) require.NoError(t, err) - nbsClient.On("StopEndpoint", ctx, &nbs.TStopEndpointRequest{ - UnixSocketPath: nbsSocketPath, - }).Return(&nbs.TStopEndpointResponse{}, nil) + if backend == "nbs" { + nbsClient.On("StopEndpoint", ctx, &nbs.TStopEndpointRequest{ + UnixSocketPath: nbsSocketPath, + }).Return(&nbs.TStopEndpointResponse{}, nil) + } - nfsClient.On("StopEndpoint", ctx, &nfs.TStopEndpointRequest{ - SocketPath: nfsSocketPath, - }).Return(&nfs.TStopEndpointResponse{}, nil) + if backend == "nfs" { + nfsClient.On("StopEndpoint", ctx, &nfs.TStopEndpointRequest{ + SocketPath: nfsSocketPath, + }).Return(&nfs.TStopEndpointResponse{}, nil) + } _, err = nodeService.NodeUnstageVolume(ctx, &csi.NodeUnstageVolumeRequest{ VolumeId: diskID,