Skip to content

Commit

Permalink
Add -print flag
Browse files Browse the repository at this point in the history
Add a -print flag which accepts no additional arguments and prints any
found SSM env vars to stdout instead of exec-ing a process with the env
vars set.

```sh
-print
      Print the decrypted env vars without exporting them and exit
```

The use-case for this is in places like CI jobs where you may want to
resolve SSM parameters and then write them to a config file, or persist
them elsewhere for subsequent use.

ssm-env is already a bit architecturally overloaded, and this strains it
further. I'm not inclined to do a major refactor/rewrite at this point,
but if we want to continue extending it that may be required at some
point. I'd prob start by separating the interfaces for outputs and
fallibility to avoid overloading the expandEnviron/setEnviron functions
the way they currently are.
  • Loading branch information
aengelas committed Dec 12, 2023
1 parent 70be7a2 commit e13aeb8
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 37 deletions.
99 changes: 84 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func main() {
template = flag.String("template", DefaultTemplate, "The template used to determine what the SSM parameter name is for an environment variable. When this template returns an empty string, the env variable is not an SSM parameter")
decrypt = flag.Bool("with-decryption", false, "Will attempt to decrypt the parameter, and set the env var as plaintext")
nofail = flag.Bool("no-fail", false, "Don't fail if error retrieving parameter")
print = flag.Bool("print", false, "Print the decrypted env vars without exporting them and exit")
print_version = flag.Bool("V", false, "Print the version and exit")
)
flag.Parse()
Expand All @@ -61,26 +62,51 @@ func main() {
return
}

if len(args) <= 0 {
if !*print && len(args) <= 0 {
flag.Usage()
os.Exit(1)
fmt.Fprintf(os.Stderr, "\nmissing program to execute\n")
os.Exit(2)
}

path, err := exec.LookPath(args[0])
must(err)
if *print && len(args) > 0 {
flag.Usage()
fmt.Fprintf(os.Stderr, "\n-print is incompatible with arguments\n")
os.Exit(3)
}

var os osEnviron
var osEnv osEnviron

// Construct the template we'll use for extracting the ssm params we need to
// fetch.
t, err := parseTemplate(*template)
must(err)

// Construct an expander with the configs for fetching/replacing env vars.
e := &expander{
batchSize: defaultBatchSize,
t: t,
ssm: &lazySSMClient{},
os: os,
os: osEnv,
}
must(e.expandEnviron(*decrypt, *nofail))
must(syscall.Exec(path, args[0:], os.Environ()))
// Attempt to "expand" ssm vars.
vars, err := e.expandEnviron(*decrypt, *nofail)
must(err)

// Actually set the env vars for the process.
e.setEnviron(*print, vars)
// If -print was passed, we're done.
if *print {
os.Exit(0)
}

// Make sure that we're invoking ssm-env with an executable that actually
// exists.
path, err := exec.LookPath(args[0])
must(err)

// Exec whatever command was passed, using the current process' env vars
// (which are now expanded).
must(syscall.Exec(path, args[0:], osEnv.Environ()))
}

// lazySSMClient wraps the AWS SDK SSM client such that the AWS session and
Expand Down Expand Up @@ -124,6 +150,9 @@ func (c *lazySSMClient) awsSession() (*session.Session, error) {
return sess, nil
}

// Construct the template we use for parsing out ssm env var strings (by
// default, `DefaultTemplate`, which works with values like
// "ssm://<path>:<version>").
func parseTemplate(templateText string) (*template.Template, error) {
return template.New("template").Funcs(TemplateFuncs).Parse(templateText)
}
Expand All @@ -134,7 +163,9 @@ type ssmClient interface {

type environ interface {
Environ() []string
Setenv(key, vale string)
Setenv(key, val string)
Getenv(key string) string
Write(s string) error
}

type osEnviron int
Expand All @@ -147,6 +178,16 @@ func (e osEnviron) Setenv(key, val string) {
os.Setenv(key, val)
}

func (e osEnviron) Getenv(key string) string {
return os.Getenv(key)
}

func (e osEnviron) Write(s string) error {
_, err := fmt.Println(s)

return err
}

type ssmVar struct {
envvar string
parameter string
Expand All @@ -172,7 +213,22 @@ func (e *expander) parameter(k, v string) (*string, error) {
return nil, nil
}

func (e *expander) expandEnviron(decrypt bool, nofail bool) error {
func (e *expander) setEnviron(print bool, vars map[string]string) {
// If -print was passed, just dump the decrypted env vars to stdout and return.
if print {
for k, v := range vars {
e.os.Write(fmt.Sprintf("%s=%s", k, v))
}

return
}

for k, v := range vars {
e.os.Setenv(k, v)
}
}

func (e *expander) expandEnviron(decrypt bool, nofail bool) (map[string]string, error) {
// Environment variables that point to some SSM parameters.
var ssmVars []ssmVar

Expand All @@ -183,7 +239,7 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {
parameter, err := e.parameter(k, v)
if err != nil {
// TODO: Should this _also_ not error if nofail is passed?
return fmt.Errorf("determining name of parameter: %v", err)
return make(map[string]string), fmt.Errorf("determining name of parameter: %v", err)
}

if parameter != nil {
Expand All @@ -194,16 +250,20 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {

if len(uniqNames) == 0 {
// Nothing to do, no SSM parameters.
return nil
return make(map[string]string), nil
}

// Construct a string slice to hold each ssm value.
names := make([]string, len(uniqNames))
// Go through and extract the values from uniqNames into the string slice.
i := 0
for k := range uniqNames {
names[i] = k
i++
}

// For each chunk of batched ssm params, get the decrypted values.
decryptedVars := make(map[string]string)
for i := 0; i < len(names); i += e.batchSize {
j := i + e.batchSize
if j > len(names) {
Expand All @@ -212,18 +272,26 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error {

values, err := e.getParameters(names[i:j], decrypt, nofail)
if err != nil {
return err
return make(map[string]string), err
}

if nofail && len(values) == 0 {
for _, v := range ssmVars {
decryptedVars[v.envvar] = e.os.Getenv(v.envvar)
}

return decryptedVars, nil
}

for _, v := range ssmVars {
val, ok := values[v.parameter]
if ok {
e.os.Setenv(v.envvar, val)
decryptedVars[v.envvar] = val
}
}
}

return nil
return decryptedVars, nil
}

func (e *expander) getParameters(names []string, decrypt bool, nofail bool) (map[string]string, error) {
Expand Down Expand Up @@ -287,6 +355,7 @@ func splitVar(v string) (key, val string) {
return parts[0], parts[1]
}

// Abort with an error message if err is not nill.
func must(err error) {
if err != nil {
fmt.Fprintf(os.Stderr, "ssm-env: %v\n", err)
Expand Down
Loading

0 comments on commit e13aeb8

Please sign in to comment.