diff --git a/cspp/README.md b/cspp/README.md index 27201d5..8b5a3b4 100644 --- a/cspp/README.md +++ b/cspp/README.md @@ -34,7 +34,9 @@ You need to set the following enviornment variables. | `CSPP_UPLOADS_DIR` | directory to save images | optional | `/var/lib/cspp/uploads` | `./data/uploads` | | `CSPP_CREDENTIALS_DIR`| directory to save API keys as json blobs | optional | `/var/lib/cspp/credentials` | `./data/credentials` | -:warning: If you specifiy `CSPP_BASE_URL` with a port on the string and specify `CSPP_PORT` and they do not match, you may get unpredictable results +:warning: How do ports work? + +If you specify a port via `CSPP_PORT` and `CSPP_BASE_URL` the one found in `CSPP_BASE_URL` will be used. If you don't specify a port in `CSPP_BASE_URL` the one found in `CSPP_PORT` will be used. If neither is specified, the default port `8080` will be used. ## Slack Specifics diff --git a/cspp/main.go b/cspp/main.go index e65476e..db72de5 100644 --- a/cspp/main.go +++ b/cspp/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "net/url" "os" "github.com/fsnotify/fsnotify" @@ -66,6 +67,32 @@ func init() { setupDirectory(viper.GetString("processed_dir")) setupDirectory(viper.GetString("uploads_dir")) setupDirectory(viper.GetString("credentials_dir")) + + validatePortVsBaseURL() + +} + +func validatePortVsBaseURL() { + log.Debugln("validatePortVsBaseURL") + baseurl := viper.GetString("base_url") + port := viper.GetString("port") + if baseurl != "" && port != "" { + parsedURL, err := url.Parse(baseurl) + if err != nil { + log.Errorln("Error parsing base URL:", err) + os.Exit(1) + } + baseport := parsedURL.Port() + if baseport == "" && port != "" { + return + } + if baseport != port { + viper.Set("port", baseport) + if port != "8080" { + log.Infoln("CSPP_PORT overridden by value specified in CSPP_BASE_URL.") + } + } + } } func main() { diff --git a/cspp/main_test.go b/cspp/main_test.go new file mode 100644 index 0000000..152cdb4 --- /dev/null +++ b/cspp/main_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "testing" + + "github.com/spf13/viper" +) + +// MockLogger is a mock logger for testing purposes +type MockLogger struct{} + +func (l *MockLogger) Debugln(args ...interface{}) {} +func (l *MockLogger) Errorln(args ...interface{}) {} + +func TestValidatePortVsBaseURL(t *testing.T) { + // Mock configuration + viper.Set("base_url", "http://example.com:8080") + viper.Set("port", "8081") + + validatePortVsBaseURL() + + if port := viper.GetString("port"); port != "8080" { + t.Errorf("Expected port to be set to 8080, got %s", port) + } +} + +func TestValidatePortVsBaseURL_NoBaseURL(t *testing.T) { + // Mock configuration + viper.Set("base_url", "") + viper.Set("port", "8081") + + validatePortVsBaseURL() + + if port := viper.GetString("port"); port != "8081" { + t.Errorf("Expected port to remain unchanged, got %s", port) + } +} + +func TestValidatePortVsBaseURL_InvalidBaseURL(t *testing.T) { + // Mock configuration + viper.Set("base_url", "invalid-url") + viper.Set("port", "8081") + + validatePortVsBaseURL() + + // Expect the error message to be logged +} + +func TestValidatePortVsBaseURL_BaseURLWithoutPort(t *testing.T) { + // Mock configuration + viper.Set("base_url", "http://example.com") + viper.Set("port", "8081") + + validatePortVsBaseURL() + + if port := viper.GetString("port"); port != "8081" { + t.Errorf("Expected port to remain unchanged, got %s", port) + } +} + +func TestValidatePortVsBaseURL_Port8080(t *testing.T) { + // Mock configuration + viper.Set("base_url", "http://example.com:8080") + viper.Set("port", "8080") + + validatePortVsBaseURL() + + // Expect no message to be logged +} + +func TestValidatePortVsBaseURL_CustomPort(t *testing.T) { + // Mock configuration + viper.Set("base_url", "http://example.com:9000") + viper.Set("port", "8081") + + validatePortVsBaseURL() + + // Expect the overridden message to be logged +}