From e13aeb8c2a77e5f74d10d523bbf70cc323d0d66d Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 12 Dec 2023 10:59:50 -0800 Subject: [PATCH] Add -print flag Add a -print flag which accepts no additional arguments and prints any found SSM env vars to stdout instead of exec-ing a process with the env vars set. ```sh -print Print the decrypted env vars without exporting them and exit ``` The use-case for this is in places like CI jobs where you may want to resolve SSM parameters and then write them to a config file, or persist them elsewhere for subsequent use. ssm-env is already a bit architecturally overloaded, and this strains it further. I'm not inclined to do a major refactor/rewrite at this point, but if we want to continue extending it that may be required at some point. I'd prob start by separating the interfaces for outputs and fallibility to avoid overloading the expandEnviron/setEnviron functions the way they currently are. --- main.go | 99 +++++++++++++++++++++++++++++------ main_test.go | 144 +++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 206 insertions(+), 37 deletions(-) diff --git a/main.go b/main.go index ae9c3d1..871ace7 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,7 @@ func main() { template = flag.String("template", DefaultTemplate, "The template used to determine what the SSM parameter name is for an environment variable. When this template returns an empty string, the env variable is not an SSM parameter") decrypt = flag.Bool("with-decryption", false, "Will attempt to decrypt the parameter, and set the env var as plaintext") nofail = flag.Bool("no-fail", false, "Don't fail if error retrieving parameter") + print = flag.Bool("print", false, "Print the decrypted env vars without exporting them and exit") print_version = flag.Bool("V", false, "Print the version and exit") ) flag.Parse() @@ -61,26 +62,51 @@ func main() { return } - if len(args) <= 0 { + if !*print && len(args) <= 0 { flag.Usage() - os.Exit(1) + fmt.Fprintf(os.Stderr, "\nmissing program to execute\n") + os.Exit(2) } - path, err := exec.LookPath(args[0]) - must(err) + if *print && len(args) > 0 { + flag.Usage() + fmt.Fprintf(os.Stderr, "\n-print is incompatible with arguments\n") + os.Exit(3) + } - var os osEnviron + var osEnv osEnviron + // Construct the template we'll use for extracting the ssm params we need to + // fetch. t, err := parseTemplate(*template) must(err) + + // Construct an expander with the configs for fetching/replacing env vars. e := &expander{ batchSize: defaultBatchSize, t: t, ssm: &lazySSMClient{}, - os: os, + os: osEnv, } - must(e.expandEnviron(*decrypt, *nofail)) - must(syscall.Exec(path, args[0:], os.Environ())) + // Attempt to "expand" ssm vars. + vars, err := e.expandEnviron(*decrypt, *nofail) + must(err) + + // Actually set the env vars for the process. + e.setEnviron(*print, vars) + // If -print was passed, we're done. + if *print { + os.Exit(0) + } + + // Make sure that we're invoking ssm-env with an executable that actually + // exists. + path, err := exec.LookPath(args[0]) + must(err) + + // Exec whatever command was passed, using the current process' env vars + // (which are now expanded). + must(syscall.Exec(path, args[0:], osEnv.Environ())) } // lazySSMClient wraps the AWS SDK SSM client such that the AWS session and @@ -124,6 +150,9 @@ func (c *lazySSMClient) awsSession() (*session.Session, error) { return sess, nil } +// Construct the template we use for parsing out ssm env var strings (by +// default, `DefaultTemplate`, which works with values like +// "ssm://:"). func parseTemplate(templateText string) (*template.Template, error) { return template.New("template").Funcs(TemplateFuncs).Parse(templateText) } @@ -134,7 +163,9 @@ type ssmClient interface { type environ interface { Environ() []string - Setenv(key, vale string) + Setenv(key, val string) + Getenv(key string) string + Write(s string) error } type osEnviron int @@ -147,6 +178,16 @@ func (e osEnviron) Setenv(key, val string) { os.Setenv(key, val) } +func (e osEnviron) Getenv(key string) string { + return os.Getenv(key) +} + +func (e osEnviron) Write(s string) error { + _, err := fmt.Println(s) + + return err +} + type ssmVar struct { envvar string parameter string @@ -172,7 +213,22 @@ func (e *expander) parameter(k, v string) (*string, error) { return nil, nil } -func (e *expander) expandEnviron(decrypt bool, nofail bool) error { +func (e *expander) setEnviron(print bool, vars map[string]string) { + // If -print was passed, just dump the decrypted env vars to stdout and return. + if print { + for k, v := range vars { + e.os.Write(fmt.Sprintf("%s=%s", k, v)) + } + + return + } + + for k, v := range vars { + e.os.Setenv(k, v) + } +} + +func (e *expander) expandEnviron(decrypt bool, nofail bool) (map[string]string, error) { // Environment variables that point to some SSM parameters. var ssmVars []ssmVar @@ -183,7 +239,7 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { parameter, err := e.parameter(k, v) if err != nil { // TODO: Should this _also_ not error if nofail is passed? - return fmt.Errorf("determining name of parameter: %v", err) + return make(map[string]string), fmt.Errorf("determining name of parameter: %v", err) } if parameter != nil { @@ -194,16 +250,20 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { if len(uniqNames) == 0 { // Nothing to do, no SSM parameters. - return nil + return make(map[string]string), nil } + // Construct a string slice to hold each ssm value. names := make([]string, len(uniqNames)) + // Go through and extract the values from uniqNames into the string slice. i := 0 for k := range uniqNames { names[i] = k i++ } + // For each chunk of batched ssm params, get the decrypted values. + decryptedVars := make(map[string]string) for i := 0; i < len(names); i += e.batchSize { j := i + e.batchSize if j > len(names) { @@ -212,18 +272,26 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { values, err := e.getParameters(names[i:j], decrypt, nofail) if err != nil { - return err + return make(map[string]string), err + } + + if nofail && len(values) == 0 { + for _, v := range ssmVars { + decryptedVars[v.envvar] = e.os.Getenv(v.envvar) + } + + return decryptedVars, nil } for _, v := range ssmVars { val, ok := values[v.parameter] if ok { - e.os.Setenv(v.envvar, val) + decryptedVars[v.envvar] = val } } } - return nil + return decryptedVars, nil } func (e *expander) getParameters(names []string, decrypt bool, nofail bool) (map[string]string, error) { @@ -287,6 +355,7 @@ func splitVar(v string) (key, val string) { return parts[0], parts[1] } +// Abort with an error message if err is not nill. func must(err error) { if err != nil { fmt.Fprintf(os.Stderr, "ssm-env: %v\n", err) diff --git a/main_test.go b/main_test.go index bd52d01..48e9782 100644 --- a/main_test.go +++ b/main_test.go @@ -17,14 +17,16 @@ func TestExpandEnviron_NoSSMParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -35,12 +37,38 @@ func TestExpandEnviron_NoSSMParameters(t *testing.T) { c.AssertExpectations(t) } +func TestExpandEnviron_NoSSMParametersPrint(t *testing.T) { + os := newFakeEnviron() + c := new(mockSSM) + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: &os, + ssm: c, + batchSize: defaultBatchSize, + } + + decrypt := false + nofail := false + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "SHELL=/bin/bash", + "TERM=screen-256color", + }, os.Environ()) + assert.Equal(t, "", os.Stdout()) + + c.AssertExpectations(t) +} + func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -58,7 +86,9 @@ func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -70,12 +100,50 @@ func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { c.AssertExpectations(t) } +func TestExpandEnviron_SimpleSSMParameterPrint(t *testing.T) { + os := newFakeEnviron() + c := new(mockSSM) + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: &os, + ssm: c, + batchSize: defaultBatchSize, + } + + os.Setenv("SUPER_SECRET", "ssm://secret") + + c.On("GetParameters", &ssm.GetParametersInput{ + Names: []*string{aws.String("secret")}, + WithDecryption: aws.Bool(true), + }).Return(&ssm.GetParametersOutput{ + Parameters: []*ssm.Parameter{ + {Name: aws.String("secret"), Value: aws.String("hehe")}, + }, + }, nil) + + decrypt := true + nofail := false + print := true + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "SHELL=/bin/bash", + "SUPER_SECRET=ssm://secret", + "TERM=screen-256color", + }, os.Environ()) + assert.Equal(t, "SUPER_SECRET=hehe", os.Stdout()) + + c.AssertExpectations(t) +} + func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -93,7 +161,9 @@ func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -110,7 +180,7 @@ func TestExpandEnviron_CustomTemplate(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(`{{ if eq .Name "SUPER_SECRET" }}secret{{end}}`)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -128,7 +198,9 @@ func TestExpandEnviron_CustomTemplate(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -145,7 +217,7 @@ func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -164,7 +236,9 @@ func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -182,7 +256,7 @@ func TestExpandEnviron_InvalidParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -198,7 +272,9 @@ func TestExpandEnviron_InvalidParameters(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.Equal(t, &invalidParametersError{InvalidParameters: []string{"secret"}}, err) c.AssertExpectations(t) @@ -209,7 +285,7 @@ func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -225,9 +301,11 @@ func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { decrypt := false nofail := true - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) - assert.NoError(t, err) + assert.NoError(t, err) assert.Equal(t, []string{ "SHELL=/bin/bash", "SUPER_SECRET=ssm://secret", @@ -242,7 +320,7 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: 1, } @@ -270,7 +348,9 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -283,18 +363,24 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { c.AssertExpectations(t) } -type fakeEnviron map[string]string +type fakeEnviron struct { + env map[string]string + stdout string +} func newFakeEnviron() fakeEnviron { return fakeEnviron{ - "SHELL": "/bin/bash", - "TERM": "screen-256color", + env: map[string]string{ + "SHELL": "/bin/bash", + "TERM": "screen-256color", + }, + stdout: "", } } func (e fakeEnviron) Environ() []string { var env sort.StringSlice - for k, v := range e { + for k, v := range e.env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } env.Sort() @@ -302,7 +388,21 @@ func (e fakeEnviron) Environ() []string { } func (e fakeEnviron) Setenv(key, val string) { - e[key] = val + e.env[key] = val +} + +func (e fakeEnviron) Getenv(key string) string { + return e.env[key] +} + +func (e *fakeEnviron) Write(s string) error { + e.stdout += s + + return nil +} + +func (e fakeEnviron) Stdout() string { + return e.stdout } type mockSSM struct {