diff --git a/README.cn.md b/README.cn.md index d4ed8e4..8055b4f 100644 --- a/README.cn.md +++ b/README.cn.md @@ -3,13 +3,15 @@ [![MIT License](https://img.shields.io/badge/license-MIT-green.svg?style=flat)](https://choosealicense.com/licenses/mit/) [![GitHub Release](https://img.shields.io/github/v/release/trzsz/tsshd)](https://github.com/trzsz/tsshd/releases) -[`tssh --udp`](https://github.com/trzsz/trzsz-ssh) 类似于 [`mosh`](https://github.com/mobile-shell/mosh), 而 `tsshd` 类似于 `mosh-server`. +`tsshd` 类似于 `mosh-server`,而 [`tssh --udp`](https://github.com/trzsz/trzsz-ssh) 类似于 [`mosh`](https://github.com/mobile-shell/mosh)。 ## 优点简介 -- 低延迟( 基于 QUIC / KCP ) +- 降低延迟( 基于 [QUIC](https://github.com/quic-go/quic-go) / [KCP](https://github.com/xtaci/kcp-go) ) -- 端口转发( 与 openssh 相同 ) +- 端口转发( 与 openssh 相同,包括 ssh agent 转发和 X11 转发 ) + +- _[TODO]_ 连接迁移( 支持网络切换和掉线重连,依赖于 [quic-go#234](https://github.com/quic-go/quic-go/issues/234) ) ## 如何使用 diff --git a/README.md b/README.md index b8a2121..540a4f3 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,15 @@ [![GitHub Release](https://img.shields.io/github/v/release/trzsz/tsshd)](https://github.com/trzsz/tsshd/releases) [![中文文档](https://img.shields.io/badge/%E4%B8%AD%E6%96%87-%E6%96%87%E6%A1%A3-blue?style=flat)](https://github.com/trzsz/tsshd/blob/main/README.cn.md) -The [`tssh --udp`](https://github.com/trzsz/trzsz-ssh) works like [`mosh`](https://github.com/mobile-shell/mosh), and the `tsshd` works like `mosh-server`. +The `tsshd` works like `mosh-server`, while the [`tssh --udp`](https://github.com/trzsz/trzsz-ssh) works like [`mosh`](https://github.com/mobile-shell/mosh). -## Advanced Features +## Advantages -- Low latency ( based on QUIC / KCP ) +- Low Latency ( based on [QUIC](https://github.com/quic-go/quic-go) / [KCP](https://github.com/xtaci/kcp-go) ) -- Port forwarding ( same as openssh ) +- Port Forwarding ( same as openssh, includes ssh agent forwarding and X11 forwarding ) + +- _[TODO]_ Connection Migration ( supports network switching and reconnection, depends on [quic-go#234](https://github.com/quic-go/quic-go/issues/234) ) ## How to use diff --git a/go.mod b/go.mod index 77f6c64..6397887 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,13 @@ module github.com/trzsz/tsshd go 1.21 require ( - github.com/UserExistsError/conpty v0.1.3 github.com/creack/pty v1.1.21 github.com/quic-go/quic-go v0.45.1 github.com/trzsz/go-arg v1.5.3 github.com/xtaci/kcp-go/v5 v5.6.8 github.com/xtaci/smux v1.5.24 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.25.0 + golang.org/x/sys v0.22.0 ) require ( @@ -18,7 +17,7 @@ require ( github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/klauspost/reedsolomon v1.12.1 // indirect + github.com/klauspost/reedsolomon v1.12.2 // indirect github.com/onsi/ginkgo/v2 v2.19.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/templexxx/cpu v0.1.1-0.20240303154708-598a14b050c5 // indirect @@ -26,7 +25,7 @@ require ( github.com/tjfoc/gmsm v1.4.1 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect - golang.org/x/mod v0.18.0 // indirect - golang.org/x/net v0.26.0 // indirect + golang.org/x/mod v0.19.0 // indirect + golang.org/x/net v0.27.0 // indirect golang.org/x/tools v0.22.0 // indirect ) diff --git a/go.sum b/go.sum index f1f038a..32ba523 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/UserExistsError/conpty v0.1.3 h1:YzGQkHAiBBkAihOCO5J2cAnahzb8ePvje2YxG7et1E0= -github.com/UserExistsError/conpty v0.1.3/go.mod h1:PDglKIkX3O/2xVk0MV9a6bCWxRmPVfxqZoTG/5sSd9I= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -39,8 +37,8 @@ github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 h1:e+8XbKB6IMn8A4OAyZ github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/klauspost/reedsolomon v1.12.1 h1:NhWgum1efX1x58daOBGCFWcxtEhOhXKKl1HAPQUp03Q= -github.com/klauspost/reedsolomon v1.12.1/go.mod h1:nEi5Kjb6QqtbofI6s+cbG/j1da11c96IBYBSnVGtuBs= +github.com/klauspost/reedsolomon v1.12.2 h1:TC0hlL/tTRxiMNnqHCzKsY11E0fIIKGCoZ2vQoPKIEM= +github.com/klauspost/reedsolomon v1.12.2/go.mod h1:nEi5Kjb6QqtbofI6s+cbG/j1da11c96IBYBSnVGtuBs= github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= @@ -75,24 +73,24 @@ go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= -golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= +golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -104,9 +102,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= diff --git a/internal/conpty/conpty.go b/internal/conpty/conpty.go new file mode 100644 index 0000000..2c55ea3 --- /dev/null +++ b/internal/conpty/conpty.go @@ -0,0 +1,363 @@ +//go:build windows +// +build windows + +// Forked From: https://github.com/UserExistsError/conpty + +package conpty + +import ( + "context" + "errors" + "fmt" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + modKernel32 = windows.NewLazySystemDLL("kernel32.dll") + fCreatePseudoConsole = modKernel32.NewProc("CreatePseudoConsole") + fResizePseudoConsole = modKernel32.NewProc("ResizePseudoConsole") + fClosePseudoConsole = modKernel32.NewProc("ClosePseudoConsole") + fInitializeProcThreadAttributeList = modKernel32.NewProc("InitializeProcThreadAttributeList") + fUpdateProcThreadAttribute = modKernel32.NewProc("UpdateProcThreadAttribute") + ErrConPtyUnsupported = errors.New("ConPty is not available on this version of Windows") +) + +func IsConPtyAvailable() bool { + return fCreatePseudoConsole.Find() == nil && + fResizePseudoConsole.Find() == nil && + fClosePseudoConsole.Find() == nil && + fInitializeProcThreadAttributeList.Find() == nil && + fUpdateProcThreadAttribute.Find() == nil +} + +const ( + _STILL_ACTIVE uint32 = 259 + _S_OK uintptr = 0 + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE uintptr = 0x20016 + defaultConsoleWidth = 80 // in characters + defaultConsoleHeight = 40 // in characters +) + +type _COORD struct { + X, Y int16 +} + +func (c *_COORD) Pack() uintptr { + return uintptr((int32(c.Y) << 16) | int32(c.X)) +} + +type _HPCON windows.Handle + +type handleIO struct { + handle windows.Handle +} + +func (h *handleIO) Read(p []byte) (int, error) { + var numRead uint32 = 0 + err := windows.ReadFile(h.handle, p, &numRead, nil) + return int(numRead), err +} + +func (h *handleIO) Write(p []byte) (int, error) { + var numWritten uint32 = 0 + err := windows.WriteFile(h.handle, p, &numWritten, nil) + return int(numWritten), err +} + +func (h *handleIO) Close() error { + return windows.CloseHandle(h.handle) +} + +type ConPty struct { + hpc _HPCON + pi *windows.ProcessInformation + ptyIn, ptyOut, cmdIn, cmdOut *handleIO +} + +func win32ClosePseudoConsole(hPc _HPCON) { + if fClosePseudoConsole.Find() != nil { + return + } + // this kills the attached process. there is no return value. + fClosePseudoConsole.Call(uintptr(hPc)) +} + +func win32ResizePseudoConsole(hPc _HPCON, coord *_COORD) error { + if fResizePseudoConsole.Find() != nil { + return fmt.Errorf("ResizePseudoConsole not found") + } + ret, _, _ := fResizePseudoConsole.Call(uintptr(hPc), coord.Pack()) + if ret != _S_OK { + return fmt.Errorf("ResizePseudoConsole failed with status 0x%x", ret) + } + return nil +} + +func win32CreatePseudoConsole(c *_COORD, hIn, hOut windows.Handle) (_HPCON, error) { + if fCreatePseudoConsole.Find() != nil { + return 0, fmt.Errorf("CreatePseudoConsole not found") + } + var hPc _HPCON + ret, _, _ := fCreatePseudoConsole.Call( + c.Pack(), + uintptr(hIn), + uintptr(hOut), + 0, + uintptr(unsafe.Pointer(&hPc))) + if ret != _S_OK { + return 0, fmt.Errorf("CreatePseudoConsole() failed with status 0x%x", ret) + } + return hPc, nil +} + +type _StartupInfoEx struct { + startupInfo windows.StartupInfo + attributeList []byte +} + +func getStartupInfoExForPTY(hpc _HPCON) (*_StartupInfoEx, error) { + if fInitializeProcThreadAttributeList.Find() != nil { + return nil, fmt.Errorf("InitializeProcThreadAttributeList not found") + } + if fUpdateProcThreadAttribute.Find() != nil { + return nil, fmt.Errorf("UpdateProcThreadAttribute not found") + } + var siEx _StartupInfoEx + siEx.startupInfo.Cb = uint32(unsafe.Sizeof(windows.StartupInfo{}) + unsafe.Sizeof(&siEx.attributeList[0])) + siEx.startupInfo.Flags |= windows.STARTF_USESTDHANDLES + var size uintptr + + // first call is to get required size. this should return false. + ret, _, _ := fInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&size))) + siEx.attributeList = make([]byte, size, size) + ret, _, err := fInitializeProcThreadAttributeList.Call( + uintptr(unsafe.Pointer(&siEx.attributeList[0])), + 1, + 0, + uintptr(unsafe.Pointer(&size))) + if ret != 1 { + return nil, fmt.Errorf("InitializeProcThreadAttributeList: %v", err) + } + + ret, _, err = fUpdateProcThreadAttribute.Call( + uintptr(unsafe.Pointer(&siEx.attributeList[0])), + 0, + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, + uintptr(hpc), + unsafe.Sizeof(hpc), + 0, + 0) + if ret != 1 { + return nil, fmt.Errorf("InitializeProcThreadAttributeList: %v", err) + } + return &siEx, nil +} + +func createConsoleProcessAttachedToPTY(hpc _HPCON, commandLine, workDir string, env []string) (*windows.ProcessInformation, error) { + cmdLine, err := windows.UTF16PtrFromString(commandLine) + if err != nil { + return nil, err + } + var currentDirectory *uint16 + if workDir != "" { + currentDirectory, err = windows.UTF16PtrFromString(workDir) + if err != nil { + return nil, err + } + } + var envBlock *uint16 + flags := uint32(windows.EXTENDED_STARTUPINFO_PRESENT) + if env != nil { + flags |= uint32(windows.CREATE_UNICODE_ENVIRONMENT) + envBlock = createEnvBlock(env) + } + siEx, err := getStartupInfoExForPTY(hpc) + if err != nil { + return nil, err + } + var pi windows.ProcessInformation + err = windows.CreateProcess( + nil, // use this if no args + cmdLine, + nil, + nil, + false, // inheritHandle + flags, + envBlock, + currentDirectory, + &siEx.startupInfo, + &pi) + if err != nil { + return nil, err + } + return &pi, nil +} + +// createEnvBlock refers to syscall.createEnvBlock in go/src/syscall/exec_windows.go +// Sourced From: https://github.com/creack/pty/pull/155 +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length += 1 + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// This will only return the first error. +func closeHandles(handles ...windows.Handle) error { + var err error + for _, h := range handles { + if h != windows.InvalidHandle { + if err == nil { + err = windows.CloseHandle(h) + } else { + windows.CloseHandle(h) + } + } + } + return err +} + +// Close all open handles and terminate the process. +func (cpty *ConPty) Close() error { + // there is no return code + win32ClosePseudoConsole(cpty.hpc) + return closeHandles( + cpty.pi.Process, + cpty.pi.Thread, + cpty.ptyIn.handle, + cpty.ptyOut.handle, + cpty.cmdIn.handle, + cpty.cmdOut.handle) +} + +// Wait for the process to exit and return the exit code. If context is canceled, +// Wait() will return STILL_ACTIVE and an error indicating the context was canceled. +func (cpty *ConPty) Wait(ctx context.Context) (uint32, error) { + var exitCode uint32 = _STILL_ACTIVE + for { + if err := ctx.Err(); err != nil { + return _STILL_ACTIVE, fmt.Errorf("wait canceled: %v", err) + } + ret, _ := windows.WaitForSingleObject(cpty.pi.Process, 1000) + if ret != uint32(windows.WAIT_TIMEOUT) { + err := windows.GetExitCodeProcess(cpty.pi.Process, &exitCode) + return exitCode, err + } + } +} + +func (cpty *ConPty) Resize(width, height int) error { + coords := _COORD{ + int16(width), + int16(height), + } + + return win32ResizePseudoConsole(cpty.hpc, &coords) +} + +func (cpty *ConPty) Read(p []byte) (int, error) { + return cpty.cmdOut.Read(p) +} + +func (cpty *ConPty) Write(p []byte) (int, error) { + return cpty.cmdIn.Write(p) +} + +func (cpty *ConPty) Pid() int { + return int(cpty.pi.ProcessId) +} + +type conPtyArgs struct { + coords _COORD + workDir string + env []string +} + +type ConPtyOption func(args *conPtyArgs) + +func ConPtyDimensions(width, height int) ConPtyOption { + return func(args *conPtyArgs) { + args.coords.X = int16(width) + args.coords.Y = int16(height) + } +} + +func ConPtyWorkDir(workDir string) ConPtyOption { + return func(args *conPtyArgs) { + args.workDir = workDir + } +} + +func ConPtyEnv(env []string) ConPtyOption { + return func(args *conPtyArgs) { + args.env = env + } +} + +// Start a new process specified in `commandLine` and attach a pseudo console using the Windows +// ConPty API. If ConPty is not available, ErrConPtyUnsupported will be returned. +// +// On successful return, an instance of ConPty is returned. You must call Close() on this to release +// any resources associated with the process. To get the exit code of the process, you can call Wait(). +func Start(commandLine string, options ...ConPtyOption) (*ConPty, error) { + if !IsConPtyAvailable() { + return nil, ErrConPtyUnsupported + } + args := &conPtyArgs{ + coords: _COORD{defaultConsoleWidth, defaultConsoleHeight}, + } + for _, opt := range options { + opt(args) + } + + var cmdIn, cmdOut, ptyIn, ptyOut windows.Handle + if err := windows.CreatePipe(&ptyIn, &cmdIn, nil, 0); err != nil { + return nil, fmt.Errorf("CreatePipe: %v", err) + } + if err := windows.CreatePipe(&cmdOut, &ptyOut, nil, 0); err != nil { + closeHandles(ptyIn, cmdIn) + return nil, fmt.Errorf("CreatePipe: %v", err) + } + + hPc, err := win32CreatePseudoConsole(&args.coords, ptyIn, ptyOut) + if err != nil { + closeHandles(ptyIn, ptyOut, cmdIn, cmdOut) + return nil, err + } + + pi, err := createConsoleProcessAttachedToPTY(hPc, commandLine, args.workDir, args.env) + if err != nil { + closeHandles(ptyIn, ptyOut, cmdIn, cmdOut) + win32ClosePseudoConsole(hPc) + return nil, fmt.Errorf("Failed to create console process: %v", err) + } + + cpty := &ConPty{ + hpc: hPc, + pi: pi, + ptyIn: &handleIO{ptyIn}, + ptyOut: &handleIO{ptyOut}, + cmdIn: &handleIO{cmdIn}, + cmdOut: &handleIO{cmdOut}, + } + return cpty, nil +} diff --git a/tsshd/forward.go b/tsshd/forward.go index acc86ea..3adeb72 100644 --- a/tsshd/forward.go +++ b/tsshd/forward.go @@ -30,8 +30,13 @@ import ( "net" "sync" "sync/atomic" + "time" ) +type closeWriter interface { + CloseWrite() error +} + var acceptMutex sync.Mutex var acceptID atomic.Uint64 var acceptMap = make(map[uint64]net.Conn) @@ -94,15 +99,14 @@ func handleListenEvent(stream net.Conn) { trySendErrorMessage("listener %s [%s] accept failed: %v", msg.Network, msg.Addr, err) continue } - acceptMutex.Lock() - id := acceptID.Add(1) - 1 - acceptMap[id] = conn + id := addAcceptConn(conn) if err := SendMessage(stream, AcceptMessage{id}); err != nil { - acceptMutex.Unlock() + if conn := getAcceptConn(id); conn != nil { + conn.Close() + } trySendErrorMessage("send accept message failed: %v", err) return } - acceptMutex.Unlock() } } @@ -113,16 +117,12 @@ func handleAcceptEvent(stream net.Conn) { return } - acceptMutex.Lock() - defer acceptMutex.Unlock() - - conn, ok := acceptMap[msg.ID] - if !ok { + conn := getAcceptConn(msg.ID) + if conn == nil { SendError(stream, fmt.Errorf("invalid accept id: %d", msg.ID)) return } - delete(acceptMap, msg.ID) defer conn.Close() if err := SendSuccess(stream); err != nil { // ack ok @@ -133,15 +133,47 @@ func handleAcceptEvent(stream net.Conn) { forwardConnection(stream, conn) } +func addAcceptConn(conn net.Conn) uint64 { + acceptMutex.Lock() + defer acceptMutex.Unlock() + id := acceptID.Add(1) - 1 + acceptMap[id] = conn + return id +} + +func getAcceptConn(id uint64) net.Conn { + acceptMutex.Lock() + defer acceptMutex.Unlock() + if conn, ok := acceptMap[id]; ok { + delete(acceptMap, id) + return conn + } + return nil +} + func forwardConnection(stream net.Conn, conn net.Conn) { var wg sync.WaitGroup wg.Add(2) go func() { _, _ = io.Copy(conn, stream) + if cw, ok := conn.(closeWriter); ok { + _ = cw.CloseWrite() + } else { + // close the entire stream since there is no half-close + time.Sleep(200 * time.Millisecond) + _ = conn.Close() + } wg.Done() }() go func() { _, _ = io.Copy(stream, conn) + if cw, ok := stream.(closeWriter); ok { + _ = cw.CloseWrite() + } else { + // close the entire stream since there is no half-close + time.Sleep(200 * time.Millisecond) + _ = stream.Close() + } wg.Done() }() wg.Wait() diff --git a/tsshd/main.go b/tsshd/main.go index 2f0cb9d..e41db09 100644 --- a/tsshd/main.go +++ b/tsshd/main.go @@ -68,6 +68,14 @@ func background() (bool, io.ReadCloser, error) { return true, stdout, nil } +var onExitFuncs []func() + +func cleanupOnExit() { + for i := len(onExitFuncs) - 1; i >= 0; i-- { + onExitFuncs[i]() + } +} + // TsshdMain is the main function of `tsshd` binary. func TsshdMain() int { var args tsshdArgs @@ -88,6 +96,9 @@ func TsshdMain() int { return 0 } + // cleanup on exit + defer cleanupOnExit() + kcpListener, quicListener, err := initServer(&args) if err != nil { fmt.Println(err) diff --git a/tsshd/proto.go b/tsshd/proto.go index 75bb93e..64c6540 100644 --- a/tsshd/proto.go +++ b/tsshd/proto.go @@ -65,6 +65,18 @@ type BusMessage struct { Timeout time.Duration } +type X11Request struct { + ChannelType string + SingleConnection bool + AuthProtocol string + AuthCookie string + ScreenNumber uint32 +} + +type AgentRequest struct { + ChannelType string +} + type StartMessage struct { ID uint64 Pty bool @@ -74,6 +86,8 @@ type StartMessage struct { Cols int Rows int Envs map[string]string + X11 *X11Request + Agent *AgentRequest } type ExitMessage struct { @@ -91,6 +105,11 @@ type StderrMessage struct { ID uint64 } +type ChannelMessage struct { + ChannelType string + ID uint64 +} + type DialMessage struct { Network string Addr string diff --git a/tsshd/server.go b/tsshd/server.go index e20545f..58b95af 100644 --- a/tsshd/server.go +++ b/tsshd/server.go @@ -62,9 +62,9 @@ var quicConfig = quic.Config{ func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) { portRangeLow := kDefaultPortRangeLow portRangeHigh := kDefaultPortRangeHigh - conn, port := listenOnFreePort(portRangeLow, portRangeHigh) - if conn == nil { - return nil, nil, fmt.Errorf("no free udp port in [%d, %d]", portRangeLow, portRangeHigh) + conn, port, err := listenUdpOnFreePort(portRangeLow, portRangeHigh) + if err != nil { + return nil, nil, err } info := &ServerInfo{ @@ -72,7 +72,6 @@ func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) { Port: port, } - var err error var kcpListener *kcp.Listener var quicListener *quic.Listener if args.KCP { @@ -99,34 +98,40 @@ func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) { return kcpListener, quicListener, nil } -func listenOnFreePort(low, high int) (*net.UDPConn, int) { +func listenUdpOnFreePort(low, high int) (*net.UDPConn, int, error) { if high < low { - return nil, -1 + return nil, 0, fmt.Errorf("no port in [%d,%d]", low, high) } + var err error + var conn *net.UDPConn size := high - low + 1 port := low + math_rand.Intn(size) for i := 0; i < size; i++ { - if conn := listenOnPort(port); conn != nil { - return conn, port + if conn, err = listenUdpOnPort(port); err == nil { + return conn, port, nil } port++ if port > high { port = low } } - return nil, -1 + if err != nil { + return nil, 0, fmt.Errorf("listen udp on [%d,%d] failed: %v", low, high, err) + } + return nil, 0, fmt.Errorf("listen udp on [%d,%d] failed", low, high) } -func listenOnPort(port int) *net.UDPConn { - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) +func listenUdpOnPort(port int) (*net.UDPConn, error) { + addr := fmt.Sprintf(":%d", port) + udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { - return nil + return nil, fmt.Errorf("resolve udp addr [%s] failed: %v", addr, err) } - conn, err := net.ListenUDP("udp", addr) + conn, err := net.ListenUDP("udp", udpAddr) if err != nil { - return nil + return nil, fmt.Errorf("listen udp on [%s] failed: %v", addr, err) } - return conn + return conn, nil } func listenKCP(conn *net.UDPConn, info *ServerInfo) (*kcp.Listener, error) { diff --git a/tsshd/session.go b/tsshd/session.go index 9475958..52b71ee 100644 --- a/tsshd/session.go +++ b/tsshd/session.go @@ -34,6 +34,7 @@ import ( "runtime" "strings" "sync" + "time" ) type sessionContext struct { @@ -173,6 +174,10 @@ func handleSessionEvent(stream net.Conn) { return } + handleX11Request(&msg) + + handleAgentRequest(&msg) + if errStream := getStderrStream(msg.ID); errStream != nil { defer errStream.Close() } @@ -329,3 +334,116 @@ func handleResizeEvent(stream net.Conn) error { } return fmt.Errorf("invalid session id: %d", msg.ID) } + +func handleX11Request(msg *StartMessage) { + if msg.X11 == nil { + return + } + listener, port, err := listenTcpOnFreePort("localhost", 6020, 6999) + if err != nil { + trySendErrorMessage("X11 forwarding listen failed: %v", err) + return + } + onExitFuncs = append(onExitFuncs, func() { + listener.Close() + }) + displayNumber := port - 6000 + if msg.X11.AuthProtocol != "" && msg.X11.AuthCookie != "" { + authDisplay := fmt.Sprintf("unix:%d.%d", displayNumber, msg.X11.ScreenNumber) + input := fmt.Sprintf("remove %s\nadd %s %s %s\n", authDisplay, authDisplay, msg.X11.AuthProtocol, msg.X11.AuthCookie) + if err := writeXauthData(input); err == nil { + onExitFuncs = append(onExitFuncs, func() { + _ = writeXauthData(fmt.Sprintf("remove %s\n", authDisplay)) + }) + } + } + go handleChannelAccept(listener, msg.X11.ChannelType) + msg.Envs["DISPLAY"] = fmt.Sprintf("localhost:%d.%d", displayNumber, msg.X11.ScreenNumber) +} + +func listenTcpOnFreePort(host string, low, high int) (net.Listener, int, error) { + var err error + var listener net.Listener + for port := low; port <= high; port++ { + listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) + if err == nil { + return listener, port, nil + } + } + if err != nil { + return nil, 0, fmt.Errorf("listen tcp on %s:[%d,%d] failed: %v", host, low, high, err) + } + return nil, 0, fmt.Errorf("listen tcp on %s:[%d,%d] failed", host, low, high) +} + +func writeXauthData(input string) error { + cmd := exec.Command("xauth", "-q", "-") + stdin, err := cmd.StdinPipe() + if err != nil { + return err + } + defer stdin.Close() + if err := cmd.Start(); err != nil { + return err + } + if _, err := stdin.Write([]byte(input)); err != nil { + return err + } + stdin.Close() + done := make(chan struct{}, 1) + go func() { + defer close(done) + _ = cmd.Wait() + done <- struct{}{} + }() + select { + case <-time.After(200 * time.Millisecond): + case <-done: + } + return nil +} + +func handleAgentRequest(msg *StartMessage) { + if msg.Agent == nil { + return + } + tempDir, err := os.MkdirTemp("", "tsshd-") + if err != nil { + trySendErrorMessage("agent forwarding mkdir temp failed: %v", err) + return + } + onExitFuncs = append(onExitFuncs, func() { + _ = os.RemoveAll(tempDir) + }) + agentPath := filepath.Join(tempDir, fmt.Sprintf("agent.%d", os.Getpid())) + listener, err := net.Listen("unix", agentPath) + if err != nil { + trySendErrorMessage("agent forwarding listen on [%s] failed: %v", agentPath, err) + return + } + if err := os.Chmod(agentPath, 0600); err != nil { + trySendErrorMessage("agent forwarding chmod [%s] failed: %v", agentPath, err) + } + onExitFuncs = append(onExitFuncs, func() { + listener.Close() + _ = os.Remove(agentPath) + }) + go handleChannelAccept(listener, msg.Agent.ChannelType) + msg.Envs["SSH_AUTH_SOCK"] = agentPath +} + +func handleChannelAccept(listener net.Listener, channelType string) { + for { + conn, err := listener.Accept() + if err != nil { + trySendErrorMessage("channel accept failed: %v", err) + break + } + go func(conn net.Conn) { + id := addAcceptConn(conn) + if err := sendBusMessage("channel", &ChannelMessage{ChannelType: channelType, ID: id}); err != nil { + trySendErrorMessage("send channel message failed: %v", err) + } + }(conn) + } +} diff --git a/tsshd/utils_windows.go b/tsshd/utils_windows.go index af84057..6430343 100644 --- a/tsshd/utils_windows.go +++ b/tsshd/utils_windows.go @@ -31,7 +31,7 @@ import ( "strings" "syscall" - "github.com/UserExistsError/conpty" + "github.com/trzsz/tsshd/internal/conpty" "golang.org/x/sys/windows" ) @@ -63,7 +63,7 @@ func newTsshdPty(cmd *exec.Cmd, cols, rows int) (*tsshdPty, error) { } cmdLine.WriteString(windows.EscapeArg(arg)) } - cpty, err := conpty.Start(cmdLine.String(), conpty.ConPtyDimensions(cols, rows)) + cpty, err := conpty.Start(cmdLine.String(), conpty.ConPtyDimensions(cols, rows), conpty.ConPtyEnv(cmd.Env)) if err != nil { return nil, err }