diff --git a/pkg/client/nbd.go b/pkg/client/nbd.go index 48d2872..18509e6 100644 --- a/pkg/client/nbd.go +++ b/pkg/client/nbd.go @@ -7,7 +7,12 @@ import ( "io" "net" "os" + "path" + "path/filepath" + "strconv" + "strings" "syscall" + "time" "github.com/pilebones/go-udev/netlink" "github.com/pojntfx/go-nbd/pkg/ioctl" @@ -24,10 +29,12 @@ var ( ) type Options struct { - ExportName string - BlockSize uint32 - OnConnected func() - Timeout int + ExportName string + BlockSize uint32 + OnConnected func() + ReadyCheckUdev bool + ReadyCheckPollInterval time.Duration + Timeout int } func negotiateNewstyle(conn net.Conn) error { @@ -60,6 +67,10 @@ func Connect(conn net.Conn, device *os.File, options *Options) error { options.ExportName = "default" } + if !options.ReadyCheckUdev && options.ReadyCheckPollInterval <= 0 { + options.ReadyCheckPollInterval = time.Millisecond + } + var cfd uintptr switch c := conn.(type) { case *net.TCPConn: @@ -82,41 +93,83 @@ func Connect(conn net.Conn, device *os.File, options *Options) error { fatal := make(chan error) if options.OnConnected != nil { - udevConn := new(netlink.UEventConn) - if err := udevConn.Connect(netlink.UdevEvent); err != nil { - return err - } - defer udevConn.Close() - - var ( - udevReadyCh = make(chan netlink.UEvent) - udevErrCh = make(chan error) - udevQuit = udevConn.Monitor(udevReadyCh, udevErrCh, &netlink.RuleDefinitions{ - Rules: []netlink.RuleDefinition{ - { - Env: map[string]string{ - "DEVNAME": device.Name(), + if options.ReadyCheckUdev { + udevConn := new(netlink.UEventConn) + if err := udevConn.Connect(netlink.UdevEvent); err != nil { + return err + } + defer udevConn.Close() + + var ( + udevReadyCh = make(chan netlink.UEvent) + udevErrCh = make(chan error) + udevQuit = udevConn.Monitor(udevReadyCh, udevErrCh, &netlink.RuleDefinitions{ + Rules: []netlink.RuleDefinition{ + { + Env: map[string]string{ + "DEVNAME": device.Name(), + }, }, }, - }, - }) - ) - defer close(udevQuit) + }) + ) + defer close(udevQuit) - go func() { - select { - case <-udevReadyCh: - close(udevQuit) + go func() { + select { + case <-udevReadyCh: + close(udevQuit) - options.OnConnected() + options.OnConnected() - return - case err := <-udevErrCh: - fatal <- err + return + case err := <-udevErrCh: + fatal <- err - return - } - }() + return + } + }() + } else { + go func() { + sizeFile, err := os.Open(path.Join("/sys", "block", filepath.Base(device.Name()), "size")) + if err != nil { + fatal <- err + + return + } + defer sizeFile.Close() + + for { + if _, err := sizeFile.Seek(0, io.SeekStart); err != nil { + fatal <- err + + return + } + + rsize, err := io.ReadAll(sizeFile) + if err != nil { + fatal <- err + + return + } + + size, err := strconv.ParseInt(strings.TrimSpace(string(rsize)), 10, 64) + if err != nil { + fatal <- err + + return + } + + if size > 0 { + options.OnConnected() + + return + } + + time.Sleep(options.ReadyCheckPollInterval) + } + }() + } } if _, _, err := syscall.Syscall(