diff --git a/plugin/plugin.go b/plugin/plugin.go index 55c697b..e042003 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -2,6 +2,7 @@ package plugin import ( "bufio" + "encoding/base64" "errors" "flag" "fmt" @@ -12,13 +13,6 @@ import ( "filippo.io/age/internal/format" ) -// TODO: implement interaction methods. -// -// // Can only be used during a Wrap or Unwrap invoked by Plugin. -// func (*Plugin) DisplayMessage(message string) error -// func (*Plugin) RequestValue(prompt string, secret bool) (string, error) -// func (*Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) - // TODO: add examples. // Plugin is a framework for writing age plugins. It allows exposing regular @@ -32,6 +26,11 @@ type Plugin struct { recipient func([]byte) (age.Recipient, error) idAsRecipient func([]byte) (age.Recipient, error) identity func([]byte) (age.Identity, error) + + sr *format.StanzaReader + // broken is set if the protocol broke down during an interaction function + // called by a Recipient or Identity. + broken bool } // New creates a new Plugin with the given name. @@ -137,10 +136,10 @@ func (p *Plugin) RecipientV1() int { var fileKeys [][]byte var supportsLabels bool - sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) ReadLoop: for { - s, err := sr.ReadStanza() + s, err := p.sr.ReadStanza() if err != nil { return fatalf("failed to read stanza: %v", err) } @@ -187,34 +186,34 @@ ReadLoop: for i, s := range recipientStrings { name, data, err := ParseRecipient(s) if err != nil { - return recipientError(sr, i, err) + return recipientError(p.sr, i, err) } if name != p.name { - return recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return recipientError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.recipient == nil { - return recipientError(sr, i, fmt.Errorf("recipient encodings not supported")) + return recipientError(p.sr, i, fmt.Errorf("recipient encodings not supported")) } r, err := p.recipient(data) if err != nil { - return recipientError(sr, i, err) + return recipientError(p.sr, i, err) } recipients = append(recipients, r) } for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } if name != p.name { - return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.idAsRecipient == nil { - return identityError(sr, i, fmt.Errorf("identity encodings not supported")) + return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) } r, err := p.idAsRecipient(data) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } identities = append(identities, r) } @@ -227,25 +226,29 @@ ReadLoop: for i, fk := range fileKeys { for j, r := range recipients { ss, ll, err := wrapWithLabels(r, fk) - if err != nil { - return recipientError(sr, j, err) + if p.broken { + return 2 + } else if err != nil { + return recipientError(p.sr, j, err) } if i == 0 && j == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return recipientError(sr, j, err) + return recipientError(p.sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } for j, r := range identities { ss, ll, err := wrapWithLabels(r, fk) - if err != nil { - return identityError(sr, j, err) + if p.broken { + return 2 + } else if err != nil { + return identityError(p.sr, j, err) } if i == 0 && j == 0 && len(recipients) == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return identityError(sr, j, err) + return identityError(p.sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } @@ -254,7 +257,7 @@ ReadLoop: if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -263,7 +266,7 @@ ReadLoop: if err := writeStanza(os.Stdout, "labels", labels...); err != nil { return fatalf("failed to write labels stanza: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } } @@ -275,14 +278,14 @@ ReadLoop: Body: s.Body}).Marshal(os.Stdout); err != nil { return fatalf("failed to write recipient-stanza: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } } if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -321,10 +324,10 @@ func (p *Plugin) IdentityV1() int { var files [][]*age.Stanza var identityStrings []string - sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) ReadLoop: for { - s, err := sr.ReadStanza() + s, err := p.sr.ReadStanza() if err != nil { return fatalf("failed to read stanza: %v", err) } @@ -373,17 +376,17 @@ ReadLoop: for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } if name != p.name { - return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.identity == nil { - return identityError(sr, i, fmt.Errorf("identity encodings not supported")) + return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) } r, err := p.identity(data) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } identities = append(identities, r) } @@ -392,7 +395,7 @@ ReadLoop: if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -401,10 +404,12 @@ ReadLoop: // in which identities are tried. for _, id := range identities { fk, err := id.Unwrap(ss) - if errors.Is(err, age.ErrIncorrectIdentity) { + if p.broken { + return 2 + } else if errors.Is(err, age.ErrIncorrectIdentity) { continue } else if err != nil { - if err := writeError(sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { + if err := writeError(p.sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { return fatalf("%v", err) } // Note that we don't exit here, as the protocol allows @@ -416,7 +421,7 @@ ReadLoop: if err := s.Marshal(os.Stdout); err != nil { return fatalf("failed to write file-key: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } break @@ -429,6 +434,86 @@ ReadLoop: return 0 } +// DisplayMessage requests that the client display a message to the user. The +// message should start with a lowercase letter and have no final period. +// DisplayMessage returns an error if the client can't display the message, and +// may return before the message has been displayed to the user. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) DisplayMessage(message string) error { + if err := writeStanzaWithBody(os.Stdout, "msg", []byte(message)); err != nil { + return p.fatalInteractf("failed to write msg stanza: %v", err) + } + s, err := readOkOrFail(p.sr) + if err != nil { + return p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return fmt.Errorf("client failed to display message") + } + return nil +} + +// RequestValue requests a secret or public input from the user through the +// client, with the provided prompt. It returns an error if the client can't +// request the input or if the user dismisses the prompt. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) { + t := "request-public" + if secret { + t = "request-secret" + } + if err := writeStanzaWithBody(os.Stdout, t, []byte(prompt)); err != nil { + return "", p.fatalInteractf("failed to write stanza: %v", err) + } + s, err := readOkOrFail(p.sr) + if err != nil { + return "", p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return "", fmt.Errorf("client failed to request value") + } + return string(s.Body), nil +} + +// Confirm requests a confirmation from the user through the client, with the +// provided prompt. The yes and no value are the choices provided to the user. +// no may be empty. The return value choseYes indicates whether the user +// selected the yes or no option. Confirm returns an error if the client can't +// request the confirmation. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) { + args := []string{base64.StdEncoding.EncodeToString([]byte(yes))} + if no != "" { + args = append(args, base64.StdEncoding.EncodeToString([]byte(no))) + } + s := &format.Stanza{Type: "confirm", Args: args, Body: []byte(prompt)} + if err := s.Marshal(os.Stdout); err != nil { + return false, p.fatalInteractf("failed to write confirm stanza: %v", err) + } + s, err = readOkOrFail(p.sr) + if err != nil { + return false, p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return false, fmt.Errorf("client failed to request confirmation") + } + if err := expectStanzaWithNoBody(s, 1); err != nil { + return false, p.fatalInteractf("%v", err) + } + return s.Args[0] == "yes", nil +} + +// fatalInteractf prints the error to stderr and sets the broken flag, so the +// Wrap/Unwrap caller can exit with an error. +func (p *Plugin) fatalInteractf(format string, args ...interface{}) error { + p.broken = true + fmt.Fprintf(os.Stderr, format, args...) + return fmt.Errorf(format, args...) +} + func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) error { if len(s.Args) != wantArgs { return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) @@ -474,6 +559,27 @@ func expectOk(sr *format.StanzaReader) error { return expectStanzaWithNoBody(ok, 0) } +func readOkOrFail(sr *format.StanzaReader) (*format.Stanza, error) { + s, err := sr.ReadStanza() + if err != nil { + return nil, fmt.Errorf("failed to read response stanza: %v", err) + } + switch s.Type { + case "fail": + if err := expectStanzaWithNoBody(s, 0); err != nil { + return nil, fmt.Errorf("%v", err) + } + return s, nil + case "ok": + if s.Body != nil { + return nil, fmt.Errorf("ok stanza has %d bytes of body, want 0", len(s.Body)) + } + return s, nil + default: + return nil, fmt.Errorf("expected ok or fail stanza, got %q", s.Type) + } +} + func expectUnsupported(sr *format.StanzaReader) error { unsupported, err := sr.ReadStanza() if err != nil {