Skip to content

Commit

Permalink
fix: resolve tilda paths and validate datadir config
Browse files Browse the repository at this point in the history
- closes #4484
  • Loading branch information
frrist committed Sep 19, 2024
1 parent ce04ff7 commit 0387640
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 32 deletions.
28 changes: 28 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/spf13/pflag"
"github.com/spf13/viper"

"github.com/bacalhau-project/bacalhau/pkg/config/types"
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/util/idgen"
)

Expand Down Expand Up @@ -208,6 +210,29 @@ func New(opts ...Option) (*Config, error) {
absoluteConfigPaths[i] = path
}
}

// allow the users to set datadir to a path like ~/.bacalhau or ~/something/idk/whatever
// and expand the path for them
dataDirPath := c.base.GetString(types.DataDirKey)
if dataDirPath[0] == '~' {
log.Info().Msgf("configuration field 'DataDir' contains '~': (%s). Attempting to expand to the home directory...", dataDirPath)
expanded, err := homedir.Expand(dataDirPath)
if err == nil {
dataDirPath = expanded
c.base.Set(types.DataDirKey, dataDirPath)
log.Info().Msgf("successfully expanded data directory to %s", dataDirPath)
}
}

// validate the config
var cfg types.Bacalhau
if err := c.base.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("config invalid: %w", err)
}

log.Info().Msgf("Config loaded from: %s, and with data-dir %s", absoluteConfigPaths, c.base.Get(types.DataDirKey))
return c, nil
}
Expand Down Expand Up @@ -270,6 +295,9 @@ func (c *Config) Unmarshal(out interface{}) error {
if err := c.base.Unmarshal(&out, DecoderHook); err != nil {
return err
}
if v, ok := out.(models.Validatable); ok {
return v.Validate()
}
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import (
)

func TestConfigWithValueOverrides(t *testing.T) {
overrideRepo := "overrideRepo"
overrideRepo := "/overrideRepo"
overrideName := "puuid"
overrideClientAddress := "1.1.1.1"
overrideClientPort := 1234

defaultConfig := types.Bacalhau{
DataDir: "defaultRepo",
DataDir: "/defaultRepo",
API: types.API{
Host: "0.0.0.0",
Port: 1234,
Expand Down
15 changes: 15 additions & 0 deletions pkg/config/types/bacalhau.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package types

import (
"errors"
"fmt"
"path/filepath"
)

// NB: Developers, after making changes (comments included) to this struct or any of its children, run go generate.

//go:generate go run gen/generate.go ./
Expand All @@ -26,6 +32,15 @@ type Bacalhau struct {
DisableAnalytics bool `yaml:"DisableAnalytics,omitempty"`
}

func (b Bacalhau) Validate() error {
var err error
if !filepath.IsAbs(b.DataDir) {
err = errors.Join(err, fmt.Errorf("DataDir (%q) must be an absolute path", b.DataDir))
}

return err
}

type API struct {
// Host specifies the hostname or IP address on which the API server listens or the client connects.
Host string `yaml:"Host,omitempty"`
Expand Down
28 changes: 27 additions & 1 deletion pkg/config/types/default_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"runtime"
"time"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"

"github.com/bacalhau-project/bacalhau/pkg/publisher/local"
)

Expand Down Expand Up @@ -142,13 +145,36 @@ const defaultBacalhauDir = ".bacalhau"
// 2. User's home directory with .bacalhau appended.
// 3. If all above fail, use .bacalhau in the current directory.
func DefaultDataDir() string {
// this method runs before root.go, so we set the level to info for these calls, then return it to previous value
currentLevel := zerolog.GlobalLevel()
zerolog.SetGlobalLevel(zerolog.InfoLevel)
defer zerolog.SetGlobalLevel(currentLevel)

// Check if the BACALHAU_DIR environment variable is set
if repoDir, set := os.LookupEnv("BACALHAU_DIR"); set && repoDir != "" {
return repoDir
} else if set {
log.Warn().Msg("BACALHAU_DIR environment variable is set but empty. Falling back to default directories.")
} else {
log.Debug().Msg("BACALHAU_DIR environment variable is not set. Trying to use $HOME.")
}

if userHome, err := os.UserHomeDir(); err == nil {
// Attempt to get the user's home directory
if userHome, err := os.UserHomeDir(); err == nil && filepath.IsAbs(userHome) {
log.Trace().Str("HomeDirectory", userHome).Msg("Successfully found $HOME. Using it for the data directory.")
return filepath.Join(userHome, defaultBacalhauDir)
} else {
log.Warn().Err(err).Msg("$HOME is unset or inaccessible. Falling back to current working directory.")
}

// Fallback: attempt to use the absolute path of the default directory
path, err := filepath.Abs(defaultBacalhauDir)
if err == nil {
log.Info().Str("Directory", path).Msg("Bacalhau will initialize in current working directory.")
return path
}

// If everything fails, return the default directory string
log.Error().Err(err).Msg("Failed to determine absolute path for the default Bacalhau directory. Using the raw default path.")
return defaultBacalhauDir
}
58 changes: 29 additions & 29 deletions pkg/config/types/paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (

const UserKeyFileName = "user_id.pem"

func (c Bacalhau) UserKeyPath() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) UserKeyPath() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, UserKeyFileName)
path := filepath.Join(b.DataDir, UserKeyFileName)
if exists, err := fileExists(path); err != nil {
return "", fmt.Errorf("checking if user key exists: %w", err)
} else if exists {
Expand All @@ -25,20 +25,20 @@ func (c Bacalhau) UserKeyPath() (string, error) {

const AuthTokensFileName = "tokens.json"

func (c Bacalhau) AuthTokensPath() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) AuthTokensPath() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
return filepath.Join(c.DataDir, AuthTokensFileName), nil
return filepath.Join(b.DataDir, AuthTokensFileName), nil
}

const OrchestratorDirName = "orchestrator"

func (c Bacalhau) OrchestratorDir() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) OrchestratorDir() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, OrchestratorDirName)
path := filepath.Join(b.DataDir, OrchestratorDirName)
if err := ensureDir(path); err != nil {
return "", fmt.Errorf("getting orchestrator path: %w", err)
}
Expand All @@ -47,24 +47,24 @@ func (c Bacalhau) OrchestratorDir() (string, error) {

const JobStoreFileName = "state_boltdb.db"

func (c Bacalhau) JobStoreFilePath() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) JobStoreFilePath() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
// make sure the parent dir exists first
if _, err := c.OrchestratorDir(); err != nil {
if _, err := b.OrchestratorDir(); err != nil {
return "", fmt.Errorf("getting job store path: %w", err)
}
return filepath.Join(c.DataDir, OrchestratorDirName, JobStoreFileName), nil
return filepath.Join(b.DataDir, OrchestratorDirName, JobStoreFileName), nil
}

const NetworkTransportDirName = "nats-store"

func (c Bacalhau) NetworkTransportDir() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) NetworkTransportDir() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, OrchestratorDirName, NetworkTransportDirName)
path := filepath.Join(b.DataDir, OrchestratorDirName, NetworkTransportDirName)
if err := ensureDir(path); err != nil {
return "", fmt.Errorf("getting network transport path: %w", err)
}
Expand All @@ -73,11 +73,11 @@ func (c Bacalhau) NetworkTransportDir() (string, error) {

const ComputeDirName = "compute"

func (c Bacalhau) ComputeDir() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) ComputeDir() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, ComputeDirName)
path := filepath.Join(b.DataDir, ComputeDirName)
if err := ensureDir(path); err != nil {
return "", fmt.Errorf("getting compute path: %w", err)
}
Expand All @@ -86,11 +86,11 @@ func (c Bacalhau) ComputeDir() (string, error) {

const ExecutionDirName = "executions"

func (c Bacalhau) ExecutionDir() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) ExecutionDir() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, ComputeDirName, ExecutionDirName)
path := filepath.Join(b.DataDir, ComputeDirName, ExecutionDirName)
if err := ensureDir(path); err != nil {
return "", fmt.Errorf("getting executions path: %w", err)
}
Expand All @@ -99,11 +99,11 @@ func (c Bacalhau) ExecutionDir() (string, error) {

const PluginsDirName = "plugins"

func (c Bacalhau) PluginsDir() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) PluginsDir() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
path := filepath.Join(c.DataDir, PluginsDirName)
path := filepath.Join(b.DataDir, PluginsDirName)
if err := ensureDir(path); err != nil {
return "", fmt.Errorf("getting plugins path: %w", err)
}
Expand All @@ -112,12 +112,12 @@ func (c Bacalhau) PluginsDir() (string, error) {

const ExecutionStoreFileName = "state_boltdb.db"

func (c Bacalhau) ExecutionStoreFilePath() (string, error) {
if c.DataDir == "" {
func (b Bacalhau) ExecutionStoreFilePath() (string, error) {
if b.DataDir == "" {
return "", fmt.Errorf("data dir not set")
}
if _, err := c.ComputeDir(); err != nil {
if _, err := b.ComputeDir(); err != nil {
return "", fmt.Errorf("getting execution store path: %w", err)
}
return filepath.Join(c.DataDir, ComputeDirName, ExecutionStoreFileName), nil
return filepath.Join(b.DataDir, ComputeDirName, ExecutionStoreFileName), nil
}

0 comments on commit 0387640

Please sign in to comment.