Skip to content

Commit

Permalink
plugin: add interactivity methods to framework
Browse files Browse the repository at this point in the history
  • Loading branch information
FiloSottile committed Jun 18, 2024
1 parent 0fbe2ac commit 7eedd92
Showing 1 changed file with 143 additions and 37 deletions.
180 changes: 143 additions & 37 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package plugin

import (
"bufio"
"encoding/base64"
"errors"
"flag"
"fmt"
Expand All @@ -12,13 +13,6 @@ import (
"filippo.io/age/internal/format"
)

// TODO: implement interaction methods.
//
// // Can only be used during a Wrap or Unwrap invoked by Plugin.
// func (*Plugin) DisplayMessage(message string) error
// func (*Plugin) RequestValue(prompt string, secret bool) (string, error)
// func (*Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error)

// TODO: add examples.

// Plugin is a framework for writing age plugins. It allows exposing regular
Expand All @@ -32,6 +26,11 @@ type Plugin struct {
recipient func([]byte) (age.Recipient, error)
idAsRecipient func([]byte) (age.Recipient, error)
identity func([]byte) (age.Identity, error)

sr *format.StanzaReader
// broken is set if the protocol broke down during an interaction function
// called by a Recipient or Identity.
broken bool
}

// New creates a new Plugin with the given name.
Expand Down Expand Up @@ -137,10 +136,10 @@ func (p *Plugin) RecipientV1() int {
var fileKeys [][]byte
var supportsLabels bool

sr := format.NewStanzaReader(bufio.NewReader(os.Stdin))
p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin))
ReadLoop:
for {
s, err := sr.ReadStanza()
s, err := p.sr.ReadStanza()
if err != nil {
return fatalf("failed to read stanza: %v", err)
}
Expand Down Expand Up @@ -187,34 +186,34 @@ ReadLoop:
for i, s := range recipientStrings {
name, data, err := ParseRecipient(s)
if err != nil {
return recipientError(sr, i, err)
return recipientError(p.sr, i, err)
}
if name != p.name {
return recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name))
return recipientError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name))
}
if p.recipient == nil {
return recipientError(sr, i, fmt.Errorf("recipient encodings not supported"))
return recipientError(p.sr, i, fmt.Errorf("recipient encodings not supported"))
}
r, err := p.recipient(data)
if err != nil {
return recipientError(sr, i, err)
return recipientError(p.sr, i, err)
}
recipients = append(recipients, r)
}
for i, s := range identityStrings {
name, data, err := ParseIdentity(s)
if err != nil {
return identityError(sr, i, err)
return identityError(p.sr, i, err)
}
if name != p.name {
return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name))
return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name))
}
if p.idAsRecipient == nil {
return identityError(sr, i, fmt.Errorf("identity encodings not supported"))
return identityError(p.sr, i, fmt.Errorf("identity encodings not supported"))
}
r, err := p.idAsRecipient(data)
if err != nil {
return identityError(sr, i, err)
return identityError(p.sr, i, err)
}
identities = append(identities, r)
}
Expand All @@ -227,25 +226,29 @@ ReadLoop:
for i, fk := range fileKeys {
for j, r := range recipients {
ss, ll, err := wrapWithLabels(r, fk)
if err != nil {
return recipientError(sr, j, err)
if p.broken {
return 2
} else if err != nil {
return recipientError(p.sr, j, err)
}
if i == 0 && j == 0 {
labels = ll
} else if err := checkLabels(ll, labels); err != nil {
return recipientError(sr, j, err)
return recipientError(p.sr, j, err)
}
stanzas[i] = append(stanzas[i], ss...)
}
for j, r := range identities {
ss, ll, err := wrapWithLabels(r, fk)
if err != nil {
return identityError(sr, j, err)
if p.broken {
return 2
} else if err != nil {
return identityError(p.sr, j, err)
}
if i == 0 && j == 0 && len(recipients) == 0 {
labels = ll
} else if err := checkLabels(ll, labels); err != nil {
return identityError(sr, j, err)
return identityError(p.sr, j, err)
}
stanzas[i] = append(stanzas[i], ss...)
}
Expand All @@ -254,7 +257,7 @@ ReadLoop:
if sent, err := writeGrease(os.Stdout); err != nil {
return fatalf("failed to write grease: %v", err)
} else if sent {
if err := expectUnsupported(sr); err != nil {
if err := expectUnsupported(p.sr); err != nil {
return fatalf("%v", err)
}
}
Expand All @@ -263,7 +266,7 @@ ReadLoop:
if err := writeStanza(os.Stdout, "labels", labels...); err != nil {
return fatalf("failed to write labels stanza: %v", err)
}
if err := expectOk(sr); err != nil {
if err := expectOk(p.sr); err != nil {
return fatalf("%v", err)
}
}
Expand All @@ -275,14 +278,14 @@ ReadLoop:
Body: s.Body}).Marshal(os.Stdout); err != nil {
return fatalf("failed to write recipient-stanza: %v", err)
}
if err := expectOk(sr); err != nil {
if err := expectOk(p.sr); err != nil {
return fatalf("%v", err)
}
}
if sent, err := writeGrease(os.Stdout); err != nil {
return fatalf("failed to write grease: %v", err)
} else if sent {
if err := expectUnsupported(sr); err != nil {
if err := expectUnsupported(p.sr); err != nil {
return fatalf("%v", err)
}
}
Expand Down Expand Up @@ -321,10 +324,10 @@ func (p *Plugin) IdentityV1() int {
var files [][]*age.Stanza
var identityStrings []string

sr := format.NewStanzaReader(bufio.NewReader(os.Stdin))
p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin))
ReadLoop:
for {
s, err := sr.ReadStanza()
s, err := p.sr.ReadStanza()
if err != nil {
return fatalf("failed to read stanza: %v", err)
}
Expand Down Expand Up @@ -373,17 +376,17 @@ ReadLoop:
for i, s := range identityStrings {
name, data, err := ParseIdentity(s)
if err != nil {
return identityError(sr, i, err)
return identityError(p.sr, i, err)
}
if name != p.name {
return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name))
return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name))
}
if p.identity == nil {
return identityError(sr, i, fmt.Errorf("identity encodings not supported"))
return identityError(p.sr, i, fmt.Errorf("identity encodings not supported"))
}
r, err := p.identity(data)
if err != nil {
return identityError(sr, i, err)
return identityError(p.sr, i, err)
}
identities = append(identities, r)
}
Expand All @@ -392,7 +395,7 @@ ReadLoop:
if sent, err := writeGrease(os.Stdout); err != nil {
return fatalf("failed to write grease: %v", err)
} else if sent {
if err := expectUnsupported(sr); err != nil {
if err := expectUnsupported(p.sr); err != nil {
return fatalf("%v", err)
}
}
Expand All @@ -401,10 +404,12 @@ ReadLoop:
// in which identities are tried.
for _, id := range identities {
fk, err := id.Unwrap(ss)
if errors.Is(err, age.ErrIncorrectIdentity) {
if p.broken {
return 2
} else if errors.Is(err, age.ErrIncorrectIdentity) {
continue
} else if err != nil {
if err := writeError(sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil {
if err := writeError(p.sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil {
return fatalf("%v", err)
}
// Note that we don't exit here, as the protocol allows
Expand All @@ -416,7 +421,7 @@ ReadLoop:
if err := s.Marshal(os.Stdout); err != nil {
return fatalf("failed to write file-key: %v", err)
}
if err := expectOk(sr); err != nil {
if err := expectOk(p.sr); err != nil {
return fatalf("%v", err)
}
break
Expand All @@ -429,6 +434,86 @@ ReadLoop:
return 0
}

// DisplayMessage requests that the client display a message to the user. The
// message should start with a lowercase letter and have no final period.
// DisplayMessage returns an error if the client can't display the message, and
// may return before the message has been displayed to the user.
//
// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main].
func (p *Plugin) DisplayMessage(message string) error {
if err := writeStanzaWithBody(os.Stdout, "msg", []byte(message)); err != nil {
return p.fatalInteractf("failed to write msg stanza: %v", err)
}
s, err := readOkOrFail(p.sr)
if err != nil {
return p.fatalInteractf("%v", err)
}
if s.Type == "fail" {
return fmt.Errorf("client failed to display message")
}
return nil
}

// RequestValue requests a secret or public input from the user through the
// client, with the provided prompt. It returns an error if the client can't
// request the input or if the user dismisses the prompt.
//
// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main].
func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) {
t := "request-public"
if secret {
t = "request-secret"
}
if err := writeStanzaWithBody(os.Stdout, t, []byte(prompt)); err != nil {
return "", p.fatalInteractf("failed to write stanza: %v", err)
}
s, err := readOkOrFail(p.sr)
if err != nil {
return "", p.fatalInteractf("%v", err)
}
if s.Type == "fail" {
return "", fmt.Errorf("client failed to request value")
}
return string(s.Body), nil
}

// Confirm requests a confirmation from the user through the client, with the
// provided prompt. The yes and no value are the choices provided to the user.
// no may be empty. The return value choseYes indicates whether the user
// selected the yes or no option. Confirm returns an error if the client can't
// request the confirmation.
//
// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main].
func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) {
args := []string{base64.StdEncoding.EncodeToString([]byte(yes))}
if no != "" {
args = append(args, base64.StdEncoding.EncodeToString([]byte(no)))
}
s := &format.Stanza{Type: "confirm", Args: args, Body: []byte(prompt)}
if err := s.Marshal(os.Stdout); err != nil {
return false, p.fatalInteractf("failed to write confirm stanza: %v", err)
}
s, err = readOkOrFail(p.sr)
if err != nil {
return false, p.fatalInteractf("%v", err)
}
if s.Type == "fail" {
return false, fmt.Errorf("client failed to request confirmation")
}
if err := expectStanzaWithNoBody(s, 1); err != nil {
return false, p.fatalInteractf("%v", err)
}
return s.Args[0] == "yes", nil
}

// fatalInteractf prints the error to stderr and sets the broken flag, so the
// Wrap/Unwrap caller can exit with an error.
func (p *Plugin) fatalInteractf(format string, args ...interface{}) error {
p.broken = true
fmt.Fprintf(os.Stderr, format, args...)
return fmt.Errorf(format, args...)
}

func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) error {
if len(s.Args) != wantArgs {
return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs)
Expand Down Expand Up @@ -474,6 +559,27 @@ func expectOk(sr *format.StanzaReader) error {
return expectStanzaWithNoBody(ok, 0)
}

func readOkOrFail(sr *format.StanzaReader) (*format.Stanza, error) {
s, err := sr.ReadStanza()
if err != nil {
return nil, fmt.Errorf("failed to read response stanza: %v", err)
}
switch s.Type {
case "fail":
if err := expectStanzaWithNoBody(s, 0); err != nil {
return nil, fmt.Errorf("%v", err)
}
return s, nil
case "ok":
if s.Body != nil {
return nil, fmt.Errorf("ok stanza has %d bytes of body, want 0", len(s.Body))
}
return s, nil
default:
return nil, fmt.Errorf("expected ok or fail stanza, got %q", s.Type)
}
}

func expectUnsupported(sr *format.StanzaReader) error {
unsupported, err := sr.ReadStanza()
if err != nil {
Expand Down

0 comments on commit 7eedd92

Please sign in to comment.