From 1e6121f213cc025123828997ba693402b5b1bade Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 20 Apr 2024 11:08:57 +0800 Subject: [PATCH] support token %j(ProxyJump) --- tssh/agent.go | 2 +- tssh/ctrl_unix.go | 42 +++++++++++++++++++++++++++++----------- tssh/forward.go | 4 ++-- tssh/login.go | 8 ++++---- tssh/tokens.go | 7 +++++++ tssh/tokens_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++--- 6 files changed, 89 insertions(+), 21 deletions(-) diff --git a/tssh/agent.go b/tssh/agent.go index 757b7a4..8db859a 100644 --- a/tssh/agent.go +++ b/tssh/agent.go @@ -46,7 +46,7 @@ func getAgentAddr(args *sshArgs, param *sshParam) (string, error) { if strings.ToLower(addr) == "none" { return "", nil } - expandedAddr, err := expandTokens(addr, args, param, "%CdhikLlnpru") + expandedAddr, err := expandTokens(addr, args, param, "%CdhijkLlnpru") if err != nil { return "", fmt.Errorf("expand IdentityAgent [%s] failed: %v", addr, err) } diff --git a/tssh/ctrl_unix.go b/tssh/ctrl_unix.go index f004efb..72f372d 100644 --- a/tssh/ctrl_unix.go +++ b/tssh/ctrl_unix.go @@ -37,6 +37,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "strconv" "strings" "sync/atomic" @@ -226,24 +227,29 @@ func getRealPath(path string) string { return realPath } -func getOpenSSH() (string, error) { +func getOpenSSH() (string, int, error) { sshPath := "/usr/bin/ssh" tsshPath, err := os.Executable() if err != nil { - return "", err + return "", 0, err } if getRealPath(tsshPath) == getRealPath(sshPath) { - return "", fmt.Errorf("%s is the current program", sshPath) + return "", 0, fmt.Errorf("%s is the current program", sshPath) } - return sshPath, nil -} - -func startControlMaster(args *sshArgs) error { - sshPath, err := getOpenSSH() + out, err := exec.Command(sshPath, "-V").CombinedOutput() if err != nil { - return fmt.Errorf("can't find openssh program: %v", err) + return "", 0, err + } + re := regexp.MustCompile(`OpenSSH_(\d+)\.(\d+)`) + matches := re.FindStringSubmatch(string(out)) + majorVersion := -1 + if len(matches) >= 3 { + majorVersion, _ = strconv.Atoi(matches[1]) } + return sshPath, majorVersion, nil +} +func startControlMaster(args *sshArgs, sshPath string) error { cmdArgs := []string{"-T", "-oRemoteCommand=none", "-oConnectTimeout=10"} if args.Debug { @@ -320,7 +326,21 @@ func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client { return nil } - socket, err := expandTokens(ctrlPath, args, param, "%CdhikLlnpru") + sshPath, sshVersion, err := getOpenSSH() + if err != nil { + warning("can't find openssh program: %v", err) + return nil + } + if sshVersion < 0 { + warning("can't get openssh version of %s", sshPath) + return nil + } + + tokens := "%CdhijkLlnpru" + if sshVersion < 9 { + tokens = "%CdhikLlnpru" + } + socket, err := expandTokens(ctrlPath, args, param, tokens) if err != nil { warning("expand ControlPath [%s] failed: %v", socket, err) return nil @@ -335,7 +355,7 @@ func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client { } fallthrough case "auto", "autoask": - if err := startControlMaster(args); err != nil { + if err := startControlMaster(args, sshPath); err != nil { warning("start control master failed: %v", err) } } diff --git a/tssh/forward.go b/tssh/forward.go index baf39b9..3cf0fa3 100644 --- a/tssh/forward.go +++ b/tssh/forward.go @@ -406,7 +406,7 @@ func sshForward(client *ssh.Client, args *sshArgs, param *sshParam) error { localForward(client, f, args) } for _, s := range getAllOptionConfig(args, "LocalForward") { - es, err := expandTokens(s, args, param, "%CdhikLlnpru") + es, err := expandTokens(s, args, param, "%CdhijkLlnpru") if err != nil { warning("expand LocalForward [%s] failed: %v", s, err) continue @@ -424,7 +424,7 @@ func sshForward(client *ssh.Client, args *sshArgs, param *sshParam) error { remoteForward(client, f, args) } for _, s := range getAllOptionConfig(args, "RemoteForward") { - es, err := expandTokens(s, args, param, "%CdhikLlnpru") + es, err := expandTokens(s, args, param, "%CdhijkLlnpru") if err != nil { warning("expand RemoteForward [%s] failed: %v", s, err) continue diff --git a/tssh/login.go b/tssh/login.go index bf4c168..b88b644 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -326,7 +326,7 @@ func getHostKeyCallback(args *sshArgs, param *sshParam) (ssh.HostKeyCallback, kn for _, path := range strings.Fields(knownHostsFiles) { var resolvedPath string if user { - expandedPath, err := expandTokens(path, args, param, "%CdhikLlnpru") + expandedPath, err := expandTokens(path, args, param, "%CdhijkLlnpru") if err != nil { return fmt.Errorf("expand UserKnownHostsFile [%s] failed: %v", path, err) } @@ -709,7 +709,7 @@ func getPublicKeysAuthMethod(args *sshArgs, param *sshParam) ssh.AuthMethod { identities := args.Identity.values for _, identity := range getAllOptionConfig(args, "IdentityFile") { - expandedIdentity, err := expandTokens(identity, args, param, "%CdhikLlnpru") + expandedIdentity, err := expandTokens(identity, args, param, "%CdhijkLlnpru") if err != nil { warning("expand IdentityFile [%s] failed: %v", identity, err) continue @@ -847,7 +847,7 @@ func execLocalCommand(args *sshArgs, param *sshParam) { if localCmd == "" { return } - expandedCmd, err := expandTokens(localCmd, args, param, "%CdfHhIiKkLlnprTtu") + expandedCmd, err := expandTokens(localCmd, args, param, "%CdfHhIijKkLlnprTtu") if err != nil { warning("expand LocalCommand [%s] failed: %v", localCmd, err) return @@ -891,7 +891,7 @@ func parseRemoteCommand(args *sshArgs, param *sshParam) (string, error) { if command == "" { command = getConfig(args.Destination, "RemoteCommand") } - expandedCmd, err := expandTokens(command, args, param, "%CdhikLlnpru") + expandedCmd, err := expandTokens(command, args, param, "%CdhijkLlnpru") if err != nil { return "", fmt.Errorf("expand RemoteCommand [%s] failed: %v", command, err) } diff --git a/tssh/tokens.go b/tssh/tokens.go index a837154..dc94941 100644 --- a/tssh/tokens.go +++ b/tssh/tokens.go @@ -120,8 +120,15 @@ func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (st hostname = hostname[:idx] } buf.WriteString(hostname) + case 'j': + if len(param.proxy) > 0 { + buf.WriteString(param.proxy[len(param.proxy)-1]) + } case 'C': hashStr := fmt.Sprintf("%s%s%s%s", getHostname(), param.host, param.port, param.user) + if len(param.proxy) > 0 && strings.ContainsRune(tokens, 'j') { + hashStr += param.proxy[len(param.proxy)-1] + } buf.WriteString(fmt.Sprintf("%x", sha1.Sum([]byte(hashStr)))) default: return "", fmt.Errorf("token [%%%c] in [%s] is not supported yet", c, str) diff --git a/tssh/tokens_test.go b/tssh/tokens_test.go index 451b77e..9dde4ea 100644 --- a/tssh/tokens_test.go +++ b/tssh/tokens_test.go @@ -44,9 +44,10 @@ func TestExpandTokens(t *testing.T) { Destination: "dest", } param := &sshParam{ - host: "127.0.0.1", - port: "1337", - user: "penny", + host: "127.0.0.1", + port: "1337", + user: "penny", + proxy: []string{"jump"}, } assertProxyCommand := func(original, expanded, errMsg string) { t.Helper() @@ -94,6 +95,46 @@ func TestExpandTokens(t *testing.T) { assertControlPath("h%", "h%", "[h%] ends with % is invalid") } +func TestProxyJumpToken(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + originalGetHostname := getHostname + defer func() { + getHostname = originalGetHostname + }() + getHostname = func() string { return "myhostname.mydomain.com" } + + args := &sshArgs{ + Destination: "dest", + } + param := &sshParam{ + host: "127.0.0.1", + port: "1337", + user: "penny", + } + + assertProxyJumpToken := func(original, expanded string) { + t.Helper() + result, err := expandTokens(original, args, param, "%CdhijkLlnpru") + require.Nil(err) + assert.Equal(expanded, result) + } + + assertProxyJumpToken("%j", "") + assertProxyJumpToken("_%j_", "__") + assertProxyJumpToken("%C", "07f25c03a322b120bcaa54d2dd0a618f2673cb1c") + + param.proxy = []string{"jump"} + assertProxyJumpToken("%j", "jump") + assertProxyJumpToken("_%j_", "_jump_") + assertProxyJumpToken("%C", "5fa1bcd29f7fd4f17b669ffb83deb4243d52b1fa") + + param.proxy = []string{"jump", "server"} + assertProxyJumpToken("%j", "server") + assertProxyJumpToken("_%j_", "_server_") + assertProxyJumpToken("/%C/", "/dc78bc912643b984e78d7d80f9912dbc794d2455/") +} + func TestInvalidHost(t *testing.T) { assert := assert.New(t) require := require.New(t)