diff --git a/main.go b/main.go index ae9c3d1..871ace7 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,7 @@ func main() { template = flag.String("template", DefaultTemplate, "The template used to determine what the SSM parameter name is for an environment variable. When this template returns an empty string, the env variable is not an SSM parameter") decrypt = flag.Bool("with-decryption", false, "Will attempt to decrypt the parameter, and set the env var as plaintext") nofail = flag.Bool("no-fail", false, "Don't fail if error retrieving parameter") + print = flag.Bool("print", false, "Print the decrypted env vars without exporting them and exit") print_version = flag.Bool("V", false, "Print the version and exit") ) flag.Parse() @@ -61,26 +62,51 @@ func main() { return } - if len(args) <= 0 { + if !*print && len(args) <= 0 { flag.Usage() - os.Exit(1) + fmt.Fprintf(os.Stderr, "\nmissing program to execute\n") + os.Exit(2) } - path, err := exec.LookPath(args[0]) - must(err) + if *print && len(args) > 0 { + flag.Usage() + fmt.Fprintf(os.Stderr, "\n-print is incompatible with arguments\n") + os.Exit(3) + } - var os osEnviron + var osEnv osEnviron + // Construct the template we'll use for extracting the ssm params we need to + // fetch. t, err := parseTemplate(*template) must(err) + + // Construct an expander with the configs for fetching/replacing env vars. e := &expander{ batchSize: defaultBatchSize, t: t, ssm: &lazySSMClient{}, - os: os, + os: osEnv, } - must(e.expandEnviron(*decrypt, *nofail)) - must(syscall.Exec(path, args[0:], os.Environ())) + // Attempt to "expand" ssm vars. + vars, err := e.expandEnviron(*decrypt, *nofail) + must(err) + + // Actually set the env vars for the process. + e.setEnviron(*print, vars) + // If -print was passed, we're done. + if *print { + os.Exit(0) + } + + // Make sure that we're invoking ssm-env with an executable that actually + // exists. + path, err := exec.LookPath(args[0]) + must(err) + + // Exec whatever command was passed, using the current process' env vars + // (which are now expanded). + must(syscall.Exec(path, args[0:], osEnv.Environ())) } // lazySSMClient wraps the AWS SDK SSM client such that the AWS session and @@ -124,6 +150,9 @@ func (c *lazySSMClient) awsSession() (*session.Session, error) { return sess, nil } +// Construct the template we use for parsing out ssm env var strings (by +// default, `DefaultTemplate`, which works with values like +// "ssm://:"). func parseTemplate(templateText string) (*template.Template, error) { return template.New("template").Funcs(TemplateFuncs).Parse(templateText) } @@ -134,7 +163,9 @@ type ssmClient interface { type environ interface { Environ() []string - Setenv(key, vale string) + Setenv(key, val string) + Getenv(key string) string + Write(s string) error } type osEnviron int @@ -147,6 +178,16 @@ func (e osEnviron) Setenv(key, val string) { os.Setenv(key, val) } +func (e osEnviron) Getenv(key string) string { + return os.Getenv(key) +} + +func (e osEnviron) Write(s string) error { + _, err := fmt.Println(s) + + return err +} + type ssmVar struct { envvar string parameter string @@ -172,7 +213,22 @@ func (e *expander) parameter(k, v string) (*string, error) { return nil, nil } -func (e *expander) expandEnviron(decrypt bool, nofail bool) error { +func (e *expander) setEnviron(print bool, vars map[string]string) { + // If -print was passed, just dump the decrypted env vars to stdout and return. + if print { + for k, v := range vars { + e.os.Write(fmt.Sprintf("%s=%s", k, v)) + } + + return + } + + for k, v := range vars { + e.os.Setenv(k, v) + } +} + +func (e *expander) expandEnviron(decrypt bool, nofail bool) (map[string]string, error) { // Environment variables that point to some SSM parameters. var ssmVars []ssmVar @@ -183,7 +239,7 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { parameter, err := e.parameter(k, v) if err != nil { // TODO: Should this _also_ not error if nofail is passed? - return fmt.Errorf("determining name of parameter: %v", err) + return make(map[string]string), fmt.Errorf("determining name of parameter: %v", err) } if parameter != nil { @@ -194,16 +250,20 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { if len(uniqNames) == 0 { // Nothing to do, no SSM parameters. - return nil + return make(map[string]string), nil } + // Construct a string slice to hold each ssm value. names := make([]string, len(uniqNames)) + // Go through and extract the values from uniqNames into the string slice. i := 0 for k := range uniqNames { names[i] = k i++ } + // For each chunk of batched ssm params, get the decrypted values. + decryptedVars := make(map[string]string) for i := 0; i < len(names); i += e.batchSize { j := i + e.batchSize if j > len(names) { @@ -212,18 +272,26 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { values, err := e.getParameters(names[i:j], decrypt, nofail) if err != nil { - return err + return make(map[string]string), err + } + + if nofail && len(values) == 0 { + for _, v := range ssmVars { + decryptedVars[v.envvar] = e.os.Getenv(v.envvar) + } + + return decryptedVars, nil } for _, v := range ssmVars { val, ok := values[v.parameter] if ok { - e.os.Setenv(v.envvar, val) + decryptedVars[v.envvar] = val } } } - return nil + return decryptedVars, nil } func (e *expander) getParameters(names []string, decrypt bool, nofail bool) (map[string]string, error) { @@ -287,6 +355,7 @@ func splitVar(v string) (key, val string) { return parts[0], parts[1] } +// Abort with an error message if err is not nill. func must(err error) { if err != nil { fmt.Fprintf(os.Stderr, "ssm-env: %v\n", err) diff --git a/main_test.go b/main_test.go index bd52d01..48e9782 100644 --- a/main_test.go +++ b/main_test.go @@ -17,14 +17,16 @@ func TestExpandEnviron_NoSSMParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -35,12 +37,38 @@ func TestExpandEnviron_NoSSMParameters(t *testing.T) { c.AssertExpectations(t) } +func TestExpandEnviron_NoSSMParametersPrint(t *testing.T) { + os := newFakeEnviron() + c := new(mockSSM) + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: &os, + ssm: c, + batchSize: defaultBatchSize, + } + + decrypt := false + nofail := false + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "SHELL=/bin/bash", + "TERM=screen-256color", + }, os.Environ()) + assert.Equal(t, "", os.Stdout()) + + c.AssertExpectations(t) +} + func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -58,7 +86,9 @@ func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -70,12 +100,50 @@ func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { c.AssertExpectations(t) } +func TestExpandEnviron_SimpleSSMParameterPrint(t *testing.T) { + os := newFakeEnviron() + c := new(mockSSM) + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: &os, + ssm: c, + batchSize: defaultBatchSize, + } + + os.Setenv("SUPER_SECRET", "ssm://secret") + + c.On("GetParameters", &ssm.GetParametersInput{ + Names: []*string{aws.String("secret")}, + WithDecryption: aws.Bool(true), + }).Return(&ssm.GetParametersOutput{ + Parameters: []*ssm.Parameter{ + {Name: aws.String("secret"), Value: aws.String("hehe")}, + }, + }, nil) + + decrypt := true + nofail := false + print := true + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "SHELL=/bin/bash", + "SUPER_SECRET=ssm://secret", + "TERM=screen-256color", + }, os.Environ()) + assert.Equal(t, "SUPER_SECRET=hehe", os.Stdout()) + + c.AssertExpectations(t) +} + func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -93,7 +161,9 @@ func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -110,7 +180,7 @@ func TestExpandEnviron_CustomTemplate(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(`{{ if eq .Name "SUPER_SECRET" }}secret{{end}}`)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -128,7 +198,9 @@ func TestExpandEnviron_CustomTemplate(t *testing.T) { decrypt := true nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -145,7 +217,7 @@ func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -164,7 +236,9 @@ func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -182,7 +256,7 @@ func TestExpandEnviron_InvalidParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -198,7 +272,9 @@ func TestExpandEnviron_InvalidParameters(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.Equal(t, &invalidParametersError{InvalidParameters: []string{"secret"}}, err) c.AssertExpectations(t) @@ -209,7 +285,7 @@ func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: defaultBatchSize, } @@ -225,9 +301,11 @@ func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { decrypt := false nofail := true - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) - assert.NoError(t, err) + assert.NoError(t, err) assert.Equal(t, []string{ "SHELL=/bin/bash", "SUPER_SECRET=ssm://secret", @@ -242,7 +320,7 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { c := new(mockSSM) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), - os: os, + os: &os, ssm: c, batchSize: 1, } @@ -270,7 +348,9 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { decrypt := false nofail := false - err := e.expandEnviron(decrypt, nofail) + print := false + vars, err := e.expandEnviron(decrypt, nofail) + e.setEnviron(print, vars) assert.NoError(t, err) assert.Equal(t, []string{ @@ -283,18 +363,24 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { c.AssertExpectations(t) } -type fakeEnviron map[string]string +type fakeEnviron struct { + env map[string]string + stdout string +} func newFakeEnviron() fakeEnviron { return fakeEnviron{ - "SHELL": "/bin/bash", - "TERM": "screen-256color", + env: map[string]string{ + "SHELL": "/bin/bash", + "TERM": "screen-256color", + }, + stdout: "", } } func (e fakeEnviron) Environ() []string { var env sort.StringSlice - for k, v := range e { + for k, v := range e.env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } env.Sort() @@ -302,7 +388,21 @@ func (e fakeEnviron) Environ() []string { } func (e fakeEnviron) Setenv(key, val string) { - e[key] = val + e.env[key] = val +} + +func (e fakeEnviron) Getenv(key string) string { + return e.env[key] +} + +func (e *fakeEnviron) Write(s string) error { + e.stdout += s + + return nil +} + +func (e fakeEnviron) Stdout() string { + return e.stdout } type mockSSM struct {