Skip to content

Commit

Permalink
expect send text supports ssh tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Jul 13, 2024
1 parent 3ac6a1d commit fd05ceb
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 44 deletions.
15 changes: 8 additions & 7 deletions tssh/ctrl_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (c *controlMaster) handleStdout() <-chan error {
return doneCh
}

func (c *controlMaster) fillPassword(args *sshArgs, expectCount int) (cancel context.CancelFunc) {
func (c *controlMaster) fillPassword(args *sshArgs, param *sshParam, expectCount int) (cancel context.CancelFunc) {
var ctx context.Context
expectTimeout := getExpectTimeout(args, "Ctrl")
if expectTimeout > 0 {
Expand All @@ -112,7 +112,8 @@ func (c *controlMaster) fillPassword(args *sshArgs, expectCount int) (cancel con
}

expect := &sshExpect{
alias: args.Destination,
param: param,
args: args,
ctx: ctx,
pre: "Ctrl",
out: make(chan []byte, 100),
Expand Down Expand Up @@ -141,7 +142,7 @@ func (c *controlMaster) checkExit() <-chan struct{} {
return exitCh
}

func (c *controlMaster) start(args *sshArgs) error {
func (c *controlMaster) start(args *sshArgs, param *sshParam) error {
var err error
c.cmd = exec.Command(c.path, c.args...)
expectCount := getExpectCount(args, "Ctrl")
Expand All @@ -157,7 +158,7 @@ func (c *controlMaster) start(args *sshArgs) error {
defer tty.Close()
c.cmd.Stdin = tty
c.ptmx = pty
cancel := c.fillPassword(args, expectCount)
cancel := c.fillPassword(args, param, expectCount)
defer cancel()
}
if c.stdout, err = c.cmd.StdoutPipe(); err != nil {
Expand Down Expand Up @@ -250,7 +251,7 @@ func getOpenSSH() (string, int, int, error) {
return sshPath, majorVersion, minorVersion, nil
}

func startControlMaster(args *sshArgs, sshPath string) error {
func startControlMaster(args *sshArgs, param *sshParam, sshPath string) error {
cmdArgs := []string{"-T", "-oRemoteCommand=none", "-oConnectTimeout=10"}

if args.Debug {
Expand Down Expand Up @@ -311,7 +312,7 @@ func startControlMaster(args *sshArgs, sshPath string) error {
}

ctrlMaster := &controlMaster{path: sshPath, args: cmdArgs}
if err := ctrlMaster.start(args); err != nil {
if err := ctrlMaster.start(args, param); err != nil {
return err
}
debug("start control master success")
Expand Down Expand Up @@ -356,7 +357,7 @@ func connectViaControl(args *sshArgs, param *sshParam) SshClient {
}
fallthrough
case "auto", "autoask":
if err := startControlMaster(args, sshPath); err != nil {
if err := startControlMaster(args, param, sshPath); err != nil {
warning("start control master failed: %v", err)
}
}
Expand Down
61 changes: 41 additions & 20 deletions tssh/expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ const (
)

type sshExpect struct {
alias string
param *sshParam
args *sshArgs
pre string
ctx context.Context
out chan []byte
Expand All @@ -54,6 +55,11 @@ type expectSender struct {
input string
}

type expectSendText struct {
showText string
sendText string
}

type caseSend struct {
pattern string
sender *expectSender
Expand Down Expand Up @@ -81,8 +87,22 @@ func newTextSender(expect *sshExpect, input string) *expectSender {
return &expectSender{expect, false, input}
}

func (s *expectSender) decodeText(text string) [][]string {
var texts [][]string
func (s *expectSender) newSendText(showText, sendText string) *expectSendText {
var err error
showText, err = expandTokens(showText, s.expect.args, s.expect.param, "%hprnLlj")
if err != nil {
warning("expand send text [%s] failed: %v", showText, err)
} else {
sendText, err = expandTokens(sendText, s.expect.args, s.expect.param, "%hprnLlj")
if err != nil {
warning("expand send text %s failed: %v", strconv.QuoteToASCII(sendText), strconv.QuoteToASCII(err.Error()))
}
}
return &expectSendText{showText: showText, sendText: sendText}
}

func (s *expectSender) decodeText(text string) []*expectSendText {
var texts []*expectSendText
var buf strings.Builder
state := byte(0)
idx := 0
Expand All @@ -105,7 +125,7 @@ func (s *expectSender) decodeText(text string) [][]string {
case 'n':
buf.WriteRune('\n')
case '|':
texts = append(texts, []string{text[idx : i-1], buf.String()})
texts = append(texts, s.newSendText(text[idx:i-1], buf.String()))
idx = i + 1
buf.Reset()
default:
Expand All @@ -118,12 +138,12 @@ func (s *expectSender) decodeText(text string) [][]string {
warning("[%s] ends with \\ is invalid", text)
buf.WriteRune('\\')
}
texts = append(texts, []string{text[idx:], buf.String()})
texts = append(texts, s.newSendText(text[idx:], buf.String()))
return texts
}

func (s *expectSender) getExpectPsssSleep() (bool, bool) {
passSleep := getExConfig(s.expect.alias, fmt.Sprintf("%sExpectPassSleep", s.expect.pre))
passSleep := getExConfig(s.expect.args.Destination, fmt.Sprintf("%sExpectPassSleep", s.expect.pre))
switch strings.ToLower(passSleep) {
case "each":
return true, false
Expand All @@ -135,7 +155,7 @@ func (s *expectSender) getExpectPsssSleep() (bool, bool) {
}

func (s *expectSender) getExpectSleepTime() time.Duration {
expectSleepMS := getExConfig(s.expect.alias, fmt.Sprintf("%sExpectSleepMS", s.expect.pre))
expectSleepMS := getExConfig(s.expect.args.Destination, fmt.Sprintf("%sExpectSleepMS", s.expect.pre))
if expectSleepMS == "" {
return kDefaultExpectSleepMS * time.Millisecond
}
Expand Down Expand Up @@ -188,11 +208,11 @@ func (s *expectSender) sendInput(writer io.Writer, id string) bool {
debug("expect %s sleep: %v", id, sleepTime)
time.Sleep(sleepTime)
}
if text[1] == "" {
if text.sendText == "" {
continue
}
debug("expect %s send: %s", id, text[0])
if err := writeAll(writer, []byte(text[1])); err != nil {
debug("expect %s send: %s", id, text.showText)
if err := writeAll(writer, []byte(text.sendText)); err != nil {
warning("expect %s send input failed: %v", id, err)
return false
}
Expand Down Expand Up @@ -377,7 +397,7 @@ func (e *sshExpect) waitForPattern(pattern string, caseSends *caseSendList) erro
}

func (e *sshExpect) getExpectSender(idx int) *expectSender {
if pass := getExConfig(e.alias, fmt.Sprintf("%sExpectSendPass%d", e.pre, idx)); pass != "" {
if pass := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendPass%d", e.pre, idx)); pass != "" {
secret, err := decodeSecret(pass)
if err != nil {
warning("decode %sExpectSendPass%d [%s] failed: %v", e.pre, idx, pass, err)
Expand All @@ -386,11 +406,11 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender {
return newPassSender(e, secret)
}

if text := getExConfig(e.alias, fmt.Sprintf("%sExpectSendText%d", e.pre, idx)); text != "" {
if text := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendText%d", e.pre, idx)); text != "" {
return newTextSender(e, text)
}

if encTotp := getExConfig(e.alias, fmt.Sprintf("%sExpectSendEncTotp%d", e.pre, idx)); encTotp != "" {
if encTotp := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendEncTotp%d", e.pre, idx)); encTotp != "" {
secret, err := decodeSecret(encTotp)
if err != nil {
warning("decode %sExpectSendEncTotp%d [%s] failed: %v", e.pre, idx, encTotp, err)
Expand All @@ -399,7 +419,7 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender {
return newPassSender(e, getTotpCode(secret))
}

if encOtp := getExConfig(e.alias, fmt.Sprintf("%sExpectSendEncOtp%d", e.pre, idx)); encOtp != "" {
if encOtp := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendEncOtp%d", e.pre, idx)); encOtp != "" {
command, err := decodeSecret(encOtp)
if err != nil {
warning("decode %sExpectSendEncOtp%d [%s] failed: %v", e.pre, idx, encOtp, err)
Expand All @@ -408,11 +428,11 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender {
return newPassSender(e, getOtpCommandOutput(command))
}

if secret := getExConfig(e.alias, fmt.Sprintf("%sExpectSendTotp%d", e.pre, idx)); secret != "" {
if secret := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendTotp%d", e.pre, idx)); secret != "" {
return newPassSender(e, getTotpCode(secret))
}

if command := getExConfig(e.alias, fmt.Sprintf("%sExpectSendOtp%d", e.pre, idx)); command != "" {
if command := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendOtp%d", e.pre, idx)); command != "" {
return newPassSender(e, getOtpCommandOutput(command))
}

Expand All @@ -421,19 +441,19 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender {

func (e *sshExpect) execInteractions(writer io.Writer, expectCount int) {
for idx := 1; idx <= expectCount; idx++ {
pattern := getExConfig(e.alias, fmt.Sprintf("%sExpectPattern%d", e.pre, idx))
pattern := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectPattern%d", e.pre, idx))
if pattern != "" {
debug("expect %d pattern: %s", idx, pattern)
} else {
warning("expect %d pattern is empty, no output will be matched", idx)
}
caseSends := &caseSendList{e, writer, nil}
for _, cfg := range getAllExConfig(e.alias, fmt.Sprintf("%sExpectCaseSendPass%d", e.pre, idx)) {
for _, cfg := range getAllExConfig(e.args.Destination, fmt.Sprintf("%sExpectCaseSendPass%d", e.pre, idx)) {
if err := caseSends.addCaseSendPass(cfg); err != nil {
warning("Invalid ExpectCaseSendPass%d: %v", idx, err)
}
}
for _, cfg := range getAllExConfig(e.alias, fmt.Sprintf("%sExpectCaseSendText%d", e.pre, idx)) {
for _, cfg := range getAllExConfig(e.args.Destination, fmt.Sprintf("%sExpectCaseSendText%d", e.pre, idx)) {
if err := caseSends.addCaseSendText(cfg); err != nil {
warning("Invalid ExpectCaseSendText%d: %v", idx, err)
}
Expand Down Expand Up @@ -497,7 +517,8 @@ func execExpectInteractions(args *sshArgs, ss *sshClientSession) {
defer cancel()

expect := &sshExpect{
alias: args.Destination,
param: ss.param,
args: args,
ctx: ctx,
out: make(chan []byte, 10),
err: make(chan []byte, 10),
Expand Down
20 changes: 10 additions & 10 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -1153,21 +1153,21 @@ func sshAgentForward(args *sshArgs, param *sshParam, client SshClient, session S
debug("request ssh agent forwarding success")
}

func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode int, err error) {
func sshTcpLogin(args *sshArgs) (ss *sshClientSession, udpMode int, err error) {
ss = &sshClientSession{}
defer func() {
if err != nil {
ss.Close()
} else {
sshLoginSuccess.Store(true)
// execute local command if necessary
execLocalCommand(args, param)
execLocalCommand(args, ss.param)
}
}()

// ssh login
var control bool
ss.client, param, control, err = sshConnect(args, nil, "")
ss.client, ss.param, control, err = sshConnect(args, nil, "")
if err != nil {
return
}
Expand All @@ -1176,7 +1176,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode
udpMode = getUdpMode(args)

// parse cmd and tty
ss.cmd, ss.tty, err = parseCmdAndTTY(args, param)
ss.cmd, ss.tty, err = parseCmdAndTTY(args, ss.param)
if err != nil {
return
}
Expand All @@ -1194,7 +1194,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode

// ssh port forwarding
if !control && udpMode == kUdpModeNo {
if err = sshForward(ss.client, args, param); err != nil {
if err = sshForward(ss.client, args, ss.param); err != nil {
return
}
}
Expand Down Expand Up @@ -1231,7 +1231,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode

if !control && udpMode == kUdpModeNo {
// ssh agent forward
sshAgentForward(args, param, ss.client, ss.session)
sshAgentForward(args, ss.param, ss.client, ss.session)
// x11 forward
sshX11Forward(args, ss.client, ss.session)
}
Expand All @@ -1240,20 +1240,20 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode
}

func sshLogin(args *sshArgs) (*sshClientSession, error) {
ss, param, udpMode, err := sshTcpLogin(args)
ss, udpMode, err := sshTcpLogin(args)
if err != nil {
return nil, err
}

if udpMode != kUdpModeNo {
ss, err = sshUdpLogin(args, param, ss, udpMode)
ss, err = sshUdpLogin(args, ss, udpMode)
if err != nil {
return nil, err
}

// ssh port forwarding if not running as a proxy ( aka: not stdio forward ).
if args.StdioForward == "" {
if err := sshForward(ss.client, args, param); err != nil {
if err := sshForward(ss.client, args, ss.param); err != nil {
ss.Close()
return nil, err
}
Expand All @@ -1263,7 +1263,7 @@ func sshLogin(args *sshArgs) (*sshClientSession, error) {
// if not running as a proxy ( aka: not stdio forward ) and executing remote command
if args.StdioForward == "" && !args.NoCommand {
// ssh agent forward
sshAgentForward(args, param, ss.client, ss.session)
sshAgentForward(args, ss.param, ss.client, ss.session)
// x11 forward
sshX11Forward(args, ss.client, ss.session)
}
Expand Down
1 change: 1 addition & 0 deletions tssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ type sshClientSession struct {
serverIn io.WriteCloser
serverOut io.Reader
serverErr io.Reader
param *sshParam
cmd string
tty bool
}
Expand Down
10 changes: 5 additions & 5 deletions tssh/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,21 @@ func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (st
}
state = 0
if !strings.ContainsRune(tokens, c) {
return "", fmt.Errorf("token [%%%c] in [%s] is not supported", c, str)
return str, fmt.Errorf("token [%%%c] in [%s] is not supported", c, str)
}
switch c {
case '%':
buf.WriteRune('%')
case 'h':
if !isHostValid(param.host) {
return "", fmt.Errorf("hostname contains invalid characters")
return str, fmt.Errorf("hostname contains invalid characters")
}
buf.WriteString(param.host)
case 'p':
buf.WriteString(param.port)
case 'r':
if !isUserValid(param.user) {
return "", fmt.Errorf("remote username contains invalid characters")
return str, fmt.Errorf("remote username contains invalid characters")
}
buf.WriteString(param.user)
case 'n':
Expand All @@ -131,11 +131,11 @@ func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (st
}
buf.WriteString(fmt.Sprintf("%x", sha1.Sum([]byte(hashStr))))
default:
return "", fmt.Errorf("token [%%%c] in [%s] is not supported yet", c, str)
return str, fmt.Errorf("token [%%%c] in [%s] is not supported yet", c, str)
}
}
if state != 0 {
return "", fmt.Errorf("[%s] ends with %% is invalid", str)
return str, fmt.Errorf("[%s] ends with %% is invalid", str)
}
return buf.String(), nil
}
1 change: 1 addition & 0 deletions tssh/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func TestExpandTokens(t *testing.T) {
result, err := expandTokens(original, args, param, "%hnpr")
if errMsg != "" {
require.NotNil(err)
assert.Equal(original, result)
assert.Equal(errMsg, err.Error())
return
}
Expand Down
Loading

0 comments on commit fd05ceb

Please sign in to comment.