diff --git a/cmd/schema.go b/cmd/schema.go index 936cf859..58d05e71 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -3,9 +3,6 @@ package cmd import ( "context" "fmt" - "os" - "path/filepath" - "strings" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/spf13/cobra" @@ -73,21 +70,8 @@ func writeSchema(_ context.Context, dryRun bool, cfg *config.AppConfig) { } if viper.GetBool("mermaid") || viper.GetBool("mermaid-markdown") { - if cfg.SpiceDB.PolicyDir != "" { - files, err := os.ReadDir(cfg.SpiceDB.PolicyDir) - if err != nil { - logger.Fatalw("failed to read policy files from directory", "error", err) - } - - filePaths := make([]string, 0, len(files)) - - for _, file := range files { - if !file.IsDir() && (strings.EqualFold(filepath.Ext(file.Name()), ".yml") || strings.EqualFold(filepath.Ext(file.Name()), ".yaml")) { - filePaths = append(filePaths, cfg.SpiceDB.PolicyDir+"/"+file.Name()) - } - } - - outputPolicyMermaid(filePaths, viper.GetBool("mermaid-markdown")) + if policyDir := cfg.SpiceDB.PolicyDir; policyDir != "" { + outputPolicyMermaid(policyDir, viper.GetBool("mermaid-markdown")) } return diff --git a/cmd/schema_mermaid.go b/cmd/schema_mermaid.go index 46cb2b03..7e41a64e 100644 --- a/cmd/schema_mermaid.go +++ b/cmd/schema_mermaid.go @@ -3,11 +3,8 @@ package cmd import ( "bytes" "fmt" - "os" "text/template" - "gopkg.in/yaml.v3" - "go.infratographer.com/permissions-api/internal/iapl" ) @@ -69,24 +66,16 @@ type mermaidContext struct { RBAC *iapl.RBAC } -func outputPolicyMermaid(filePaths []string, markdown bool) { - policy := iapl.PolicyDocument{} - - if len(filePaths) > 0 { - for _, filePath := range filePaths { - file, err := os.Open(filePath) - if err != nil { - logger.Fatalw("failed to open policy document file", "error", err) - } - defer file.Close() - - var filePolicy iapl.PolicyDocument - - if err := yaml.NewDecoder(file).Decode(&filePolicy); err != nil { - logger.Fatalw("failed to open policy document file", "error", err) - } +func outputPolicyMermaid(dirPath string, markdown bool) { + var ( + policy iapl.PolicyDocument + err error + ) - policy = policy.MergeWithPolicyDocument(filePolicy) + if dirPath != "" { + policy, err = iapl.LoadPolicyDocumentFromDirectory(dirPath) + if err != nil { + logger.Fatalw("failed to load policy documents", "error", err) } } else { policy = iapl.DefaultPolicyDocument() diff --git a/internal/iapl/policy.go b/internal/iapl/policy.go index 17dd27e2..f79146a0 100644 --- a/internal/iapl/policy.go +++ b/internal/iapl/policy.go @@ -1,7 +1,10 @@ package iapl import ( + "errors" "fmt" + "io" + "io/fs" "os" "path/filepath" "strings" @@ -151,65 +154,116 @@ func (p PolicyDocument) MergeWithPolicyDocument(other PolicyDocument) PolicyDocu return p } -// NewPolicyFromFile reads the provided file path and returns a new Policy. -func NewPolicyFromFile(filePath string) (Policy, error) { +func loadPolicyDocumentFromFile(filePath string) (PolicyDocument, error) { file, err := os.Open(filePath) if err != nil { - return nil, err + return PolicyDocument{}, fmt.Errorf("%s: %w", filePath, err) } - var policy PolicyDocument + defer file.Close() - if err := yaml.NewDecoder(file).Decode(&policy); err != nil { - return nil, err + var ( + finalPolicyDocument = PolicyDocument{} + decoder = yaml.NewDecoder(file) + documentIndex int + ) + + for { + var policyDocument PolicyDocument + + if err = decoder.Decode(&policyDocument); err != nil { + if !errors.Is(err, io.EOF) { + return PolicyDocument{}, fmt.Errorf("%s document %d: %w", filePath, documentIndex, err) + } + + break + } + + if finalPolicyDocument.RBAC != nil && policyDocument.RBAC != nil { + return PolicyDocument{}, fmt.Errorf("%s document %d: %w", filePath, documentIndex, ErrorDuplicateRBACDefinition) + } + + finalPolicyDocument = finalPolicyDocument.MergeWithPolicyDocument(policyDocument) + + documentIndex++ } - return NewPolicy(policy), nil + return finalPolicyDocument, nil } -// NewPolicyFromFiles reads the provided file paths, merges them, and returns a new Policy. -func NewPolicyFromFiles(filePaths []string) (Policy, error) { - mergedPolicy := PolicyDocument{} +// LoadPolicyDocumentFromFiles loads all policy documents in the order provided and returns a merged PolicyDocument. +func LoadPolicyDocumentFromFiles(filePaths ...string) (PolicyDocument, error) { + var policyDocument PolicyDocument for _, filePath := range filePaths { - file, err := os.Open(filePath) + filePolicyDocument, err := loadPolicyDocumentFromFile(filePath) if err != nil { - return nil, err + return PolicyDocument{}, err } - defer file.Close() - var filePolicy PolicyDocument + policyDocument = policyDocument.MergeWithPolicyDocument(filePolicyDocument) + } + + return policyDocument, nil +} - if err := yaml.NewDecoder(file).Decode(&filePolicy); err != nil { - return nil, err +// LoadPolicyDocumentFromDirectory reads the provided directory path, reads all files in the directory, merges them, and returns a new merged PolicyDocument. +func LoadPolicyDocumentFromDirectory(directoryPath string) (PolicyDocument, error) { + var filePaths []string + + err := filepath.WalkDir(directoryPath, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err } - if mergedPolicy.RBAC != nil && filePolicy.RBAC != nil { - return nil, ErrorDuplicateRBACDefinition + if entry.IsDir() { + return nil } - mergedPolicy = mergedPolicy.MergeWithPolicyDocument(filePolicy) + ext := filepath.Ext(entry.Name()) + + if strings.EqualFold(ext, ".yml") || strings.EqualFold(ext, ".yaml") { + filePaths = append(filePaths, path) + } + + return nil + }) + + if err != nil { + return PolicyDocument{}, err } - return NewPolicy(mergedPolicy), nil + return LoadPolicyDocumentFromFiles(filePaths...) } -// NewPolicyFromDirectory reads the provided directory path, reads all files in the directory, merges them, and returns a new Policy. -func NewPolicyFromDirectory(directoryPath string) (Policy, error) { - files, err := os.ReadDir(directoryPath) +// NewPolicyFromFile reads the provided file path and returns a new Policy. +func NewPolicyFromFile(filePath string) (Policy, error) { + policyDocument, err := LoadPolicyDocumentFromFiles(filePath) if err != nil { return nil, err } - filePaths := make([]string, 0, len(files)) + return NewPolicy(policyDocument), nil +} - for _, file := range files { - if !file.IsDir() && (strings.EqualFold(filepath.Ext(file.Name()), ".yml") || strings.EqualFold(filepath.Ext(file.Name()), ".yaml")) { - filePaths = append(filePaths, directoryPath+"/"+file.Name()) - } +// NewPolicyFromFiles reads the provided file paths, merges them, and returns a new Policy. +func NewPolicyFromFiles(filePaths []string) (Policy, error) { + policyDocument, err := LoadPolicyDocumentFromFiles(filePaths...) + if err != nil { + return nil, err + } + + return NewPolicy(policyDocument), nil +} + +// NewPolicyFromDirectory reads the provided directory path, reads all files in the directory, merges them, and returns a new Policy. +func NewPolicyFromDirectory(directoryPath string) (Policy, error) { + policyDocument, err := LoadPolicyDocumentFromDirectory(directoryPath) + if err != nil { + return nil, err } - return NewPolicyFromFiles(filePaths) + return NewPolicy(policyDocument), nil } func (v *policy) validateUnions() error {