From 579fdf2688c5393073caa335df62410a6c70b02d Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Fri, 14 Apr 2023 16:10:56 -0400 Subject: [PATCH] fix(output): report terminal status when writer is not a file The underlying writer doesn't have to be a *os.File for it to be a TTY. For example, a PTY ssh session is a TTY. However, the std library returns a io.ReadWriter for the ssh session. Combined with the WithUnsafe() option, we can query the terminal of an ssh session using Termenv. --- termenv_unix.go | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/termenv_unix.go b/termenv_unix.go index 24d519a..ce8f8a8 100644 --- a/termenv_unix.go +++ b/termenv_unix.go @@ -6,6 +6,7 @@ package termenv import ( "fmt" "io" + "os" "strconv" "strings" "time" @@ -113,8 +114,8 @@ func (o Output) backgroundColor() Color { return ANSIColor(0) } -func (o *Output) waitForData(timeout time.Duration) error { - fd := o.TTY().Fd() +func (o *Output) waitForData(f *os.File, timeout time.Duration) error { + fd := f.Fd() tv := unix.NsecToTimeval(int64(timeout)) var readfds unix.FdSet readfds.Set(int(fd)) @@ -137,15 +138,15 @@ func (o *Output) waitForData(timeout time.Duration) error { return nil } -func (o *Output) readNextByte() (byte, error) { - if !o.unsafe { - if err := o.waitForData(OSCTimeout); err != nil { +func (o *Output) readNextByte(rw io.ReadWriter) (byte, error) { + if f, ok := rw.(*os.File); ok && !o.unsafe { + if err := o.waitForData(f, OSCTimeout); err != nil { return 0, err } } var b [1]byte - n, err := o.TTY().Read(b[:]) + n, err := rw.Read(b[:]) if err != nil { return 0, err } @@ -160,15 +161,15 @@ func (o *Output) readNextByte() (byte, error) { // readNextResponse reads either an OSC response or a cursor position response: // - OSC response: "\x1b]11;rgb:1111/1111/1111\x1b\\" // - cursor position response: "\x1b[42;1R" -func (o *Output) readNextResponse() (response string, isOSC bool, err error) { - start, err := o.readNextByte() +func (o *Output) readNextResponse(rw io.ReadWriter) (response string, isOSC bool, err error) { + start, err := o.readNextByte(rw) if err != nil { return "", false, err } // first byte must be ESC for start != ESC { - start, err = o.readNextByte() + start, err = o.readNextByte(rw) if err != nil { return "", false, err } @@ -177,7 +178,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) { response += string(start) // next byte is either '[' (cursor position response) or ']' (OSC response) - tpe, err := o.readNextByte() + tpe, err := o.readNextByte(rw) if err != nil { return "", false, err } @@ -195,7 +196,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) { } for { - b, err := o.readNextByte() + b, err := o.readNextByte(rw) if err != nil { return "", false, err } @@ -231,13 +232,17 @@ func (o Output) termStatusReport(sequence int) (string, error) { return "", ErrStatusReport } - tty := o.TTY() - if tty == nil { + tty, ok := o.Writer().(io.ReadWriter) + if tty == nil || !ok { return "", ErrStatusReport } if !o.unsafe { - fd := int(tty.Fd()) + f, ok := tty.(*os.File) + if !ok { + return "", ErrStatusReport + } + fd := int(f.Fd()) // if in background, we can't control the terminal if !isForeground(fd) { return "", ErrStatusReport @@ -264,7 +269,7 @@ func (o Output) termStatusReport(sequence int) (string, error) { fmt.Fprintf(tty, CSI+"6n") // read the next response - res, isOSC, err := o.readNextResponse() + res, isOSC, err := o.readNextResponse(tty) if err != nil { return "", fmt.Errorf("%s: %s", ErrStatusReport, err) } @@ -275,7 +280,7 @@ func (o Output) termStatusReport(sequence int) (string, error) { } // read the cursor query response next and discard the result - _, _, err = o.readNextResponse() + _, _, err = o.readNextResponse(tty) if err != nil { return "", err }