diff --git a/.travis.yml b/.travis.yml index 6a426e7..02e9d8e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,3 +16,4 @@ install: - ./download_samples.sh script: - go install && ./test.sh + - go test diff --git a/pgfutter.go b/pgfutter.go index e96d798..3cfc7b7 100644 --- a/pgfutter.go +++ b/pgfutter.go @@ -106,7 +106,7 @@ func main() { { Name: "json", Usage: "Import newline-delimited JSON objects into database", - Action: func(c *cli.Context) { + Action: func(c *cli.Context) error { cli.CommandHelpTemplate = strings.Replace(cli.CommandHelpTemplate, "[arguments...]", "", -1) filename := c.Args().First() @@ -118,13 +118,13 @@ func main() { connStr := parseConnStr(c) err := importJSON(filename, connStr, schema, tableName, ignoreErrors, dataType) - exitOnError(err) + return err }, }, { Name: "jsonobj", Usage: "Import single JSON object into database", - Action: func(c *cli.Context) { + Action: func(c *cli.Context) error { cli.CommandHelpTemplate = strings.Replace(cli.CommandHelpTemplate, "[arguments...]", "", -1) filename := c.Args().First() @@ -135,7 +135,7 @@ func main() { connStr := parseConnStr(c) err := importJSONObject(filename, connStr, schema, tableName, dataType) - exitOnError(err) + return err }, }, { @@ -160,7 +160,7 @@ func main() { Usage: "skip parsing escape sequences in the given delimiter", }, }, - Action: func(c *cli.Context) { + Action: func(c *cli.Context) error { cli.CommandHelpTemplate = strings.Replace(cli.CommandHelpTemplate, "[arguments...]", "", -1) filename := c.Args().First() @@ -176,7 +176,7 @@ func main() { fmt.Println(delimiter) connStr := parseConnStr(c) err := importCSV(filename, connStr, schema, tableName, ignoreErrors, skipHeader, fields, delimiter) - exitOnError(err) + return err }, }, } diff --git a/postgres.go b/postgres.go index 755255d..4b53b1b 100644 --- a/postgres.go +++ b/postgres.go @@ -3,7 +3,9 @@ package main import ( "database/sql" "fmt" + "log" "math/rand" + "regexp" "strconv" "strings" @@ -53,24 +55,17 @@ func postgresify(identifier string) string { "-": "_", ",": "_", "#": "_", - - "[": "", - "]": "", - "{": "", - "}": "", - "(": "", - ")": "", - "?": "", - "!": "", - "$": "", - "%": "", - "*": "", - "\"": "", } for oldString, newString := range replacements { str = strings.Replace(str, oldString, newString, -1) } + reg, err := regexp.Compile("[^A-Za-z0-9_]+") + if err != nil { + log.Fatal(err) + } + str = reg.ReplaceAllString(str, "") + if len(str) == 0 { str = fmt.Sprintf("_col%d", rand.Intn(10000)) } else { diff --git a/postgres_test.go b/postgres_test.go new file mode 100644 index 0000000..cb3ca41 --- /dev/null +++ b/postgres_test.go @@ -0,0 +1,24 @@ +package main + +import "testing" + +type testpair struct { + columnName string + sanitizedName string +} + +var tests = []testpair{ + {"Starting Date & Time", "starting_date__time"}, + {"[$MYCOLUMN]", "mycolumn"}, + {"({colname?!})", "colname"}, + {"m4 * 4 / 3", "m4__4___3"}, +} + +func TestPostgresify(t *testing.T) { + for _, pair := range tests { + str := postgresify(pair.columnName) + if str != pair.sanitizedName { + t.Error("Invalid PostgreSQL identifier expected ", pair.sanitizedName, "got ", str) + } + } +}