Skip to content

Commit

Permalink
feat: add exported Parse(...) which returns a WhereClause
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Aug 14, 2023
1 parent ad64cca commit efbc392
Show file tree
Hide file tree
Showing 18 changed files with 939 additions and 70 deletions.
81 changes: 75 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,57 @@

Model Query Language (MQL) is a query language for your database models.

The `mql` Go package that provides a language that end users can use to query your
The `mql` Go package provides a language that end users can use to query your
database models, without them having to learn SQL or exposing your
application to sql injection.
application to SQL injection.

## Examples

### github.com/go-gorm/gorm

```Go
w, err := mql.Parse("name=alice or name=bob)",User{})
if err != nil {
return nil, err
}
err = db.Where(w.Condition, w.Args...).Find(&users).Error
```

### database/sql

```Go
w, err := mql.Parse("name=alice or name=bob)",User{})
if err != nil {
return nil, err
}
q := fmt.Sprintf("select * from users where %s", w.Condition)
rows, err := db.Query(q, w.Args...)
```

### github.com/hashicorp/go-dbw

```Go
w, err := mql.Parse("name=alice or name=bob)",User{})
if err != nil {
return nil, err
}
err := rw.SearchWhere(ctx, &users, w.Condition, w.Args)
```

## Some bits about usage

First, you define a model you wish to query as a Go `struct` and then provide a `mql`
query. The package then uses the query along with a model to generate a
parameterized SQL where clause.

Fields in your model can be compared with the following operators: `=`, `!=`,
`>=`, `<=`, `<`, `>`, `%` .
`>=`, `<=`, `<`, `>`, `%` .

Double quotes `"` can be used to quote strings.

Comparison operators can have optional leading/trailing whitespace.

The `%` operator allows you to do partial string matching using LIKE and and this
The `%` operator allows you to do partial string matching using LIKE "%value%". This
matching is case insensitive.

The `=` equality operator is case insensitive when used with string fields.
Expand All @@ -26,14 +61,48 @@ Comparisons can be combined using: `and`, `or`.

More complex queries can be created using parentheses.

Example query:
`name=alice and age > 11 and (region % Boston or region="south shore")`
Example query:
`name=alice and age > 11 and (region % Boston or region="south shore")`

### Date/Time fields

If your model contains a time.Time field, then we'll append `::date` to the
column name when generating a where clause and the comparison value must be in
an `ISO-8601` format. Currently, this is the only supported way to compare
dates, if you need something different then you'll need to provide your own
custom validator/converter via `WithConverter(...)` when calling
`mql.Parse(...)`.

We provide default validation+conversion of fields in a model when parsing
and generating a `WhereClause`. You can provide optional validation+conversion
functions for fields in your model via `WithConverter(...)`.

### Mapping column names

You can also provide an optional map from query column identifiers to model
field names via `WithColumnMap(...)` if needed.

**Please note**: We take security and our users' trust very seriously. If you
believe you have found a security issue, please *[responsibly
disclose](https://www.hashicorp.com/security#vulnerability-reporting)* by
contacting us at [email protected].

### Ignoring fields

If your model (Go struct) has fields you don't want users searching then you can
optionally provide a list of columns to be ignored via `WithIgnoreFields(...)`

### Custom converters/validators

Sometimes the default out-of-the-box bits doesn't fit your needs. If you need to
override how expressions (column name, operator and value) is converted and
validated during the generation of a WhereClause, then you can optionally
provide your own validator/convertor via `WithConverter(...)`

### Grammar

See: [GRAMMAR.md](./GRAMMER.md)

## Contributing

Thank you for your interest in contributing! Please refer to
Expand Down
3 changes: 1 addition & 2 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,5 @@ func panicIfNil(a any, caller, missing string) {
}

func pointer[T any](input T) *T {
ret := input
return &ret
return &input
}
2 changes: 1 addition & 1 deletion docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/*
Package mql provides a language that end users can use to query your
database models, without them having to learn SQL or exposing your
application to sql injection.
application to SQL injection.
You define a model you wish to query as a Go struct and provide a mql query. The
package then uses the query along with a model to generate a parameterized SQL
Expand Down
4 changes: 4 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mql
import "errors"

var (
ErrInternal = errors.New("internal error")
ErrInvalidParameter = errors.New("invalid parameter")
ErrInvalidNotEqual = errors.New(`invalid "!=" token`)
ErrMissingExpr = errors.New("missing expression")
Expand All @@ -17,7 +18,10 @@ var (
ErrUnexpectedToken = errors.New("unexpected token")
ErrInvalidComparisonOp = errors.New("invalid comparison operator")
ErrMissingComparisonOp = errors.New("missing comparison operator")
ErrMissingColumn = errors.New("missing column")
ErrInvalidLogicalOp = errors.New("invalid logical operator")
ErrMissingLogicalOp = errors.New("missing logical operator")
ErrMissingRightSideExpr = errors.New("logical operator without a right side expr")
ErrMissingComparisonValue = errors.New("missing comparison value")
ErrInvalidColumn = errors.New("invalid column")
)
95 changes: 71 additions & 24 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
type exprType int

const (
comparisonExprType exprType = iota
unknownExprType exprType = iota
comparisonExprType
logicalExprType
)

Expand All @@ -18,38 +19,39 @@ type expr interface {
String() string
}

type comparisonOp string
// ComparisonOp defines a set of comparison operators
type ComparisonOp string

const (
greaterThanOp comparisonOp = ">"
greaterThanOrEqualOp = ">="
lessThanOp = "<"
lessThanOrEqualOp = "<="
equalOp = "="
notEqualOp = "!="
containsOp = "%"
GreaterThanOp ComparisonOp = ">"
GreaterThanOrEqualOp ComparisonOp = ">="
LessThanOp ComparisonOp = "<"
LessThanOrEqualOp ComparisonOp = "<="
EqualOp ComparisonOp = "="
NotEqualOp ComparisonOp = "!="
ContainsOp ComparisonOp = "%"
)

func newComparisonOp(s string) (comparisonOp, error) {
func newComparisonOp(s string) (ComparisonOp, error) {
const op = "newComparisonOp"
switch s {
switch ComparisonOp(s) {
case
string(greaterThanOp),
string(greaterThanOrEqualOp),
string(lessThanOp),
string(lessThanOrEqualOp),
string(equalOp),
string(notEqualOp),
string(containsOp):
return comparisonOp(s), nil
GreaterThanOp,
GreaterThanOrEqualOp,
LessThanOp,
LessThanOrEqualOp,
EqualOp,
NotEqualOp,
ContainsOp:
return ComparisonOp(s), nil
default:
return "", fmt.Errorf("%s: %w %q", op, ErrInvalidComparisonOp, s)
}
}

type comparisonExpr struct {
column string
comparisonOp comparisonOp
comparisonOp ComparisonOp
value *string
}

Expand All @@ -72,19 +74,64 @@ func (e *comparisonExpr) isComplete() bool {
return e.column != "" && e.comparisonOp != "" && e.value != nil
}

// defaultValidateConvert will validate the comparison expr value, and then convert the
// expr to its SQL equivalence.
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) {
const op = "mql.(comparisonExpr).convertToSql"
switch {
case columnName == "":
return nil, fmt.Errorf("%s: %w", op, ErrMissingColumn)
case comparisonOp == "":
return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonOp)
case isNil(columnValue):
return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonValue)
case validator.fn == nil:
return nil, fmt.Errorf("%s: missing validator function: %w", op, ErrInvalidParameter)
case validator.typ == "":
return nil, fmt.Errorf("%s: missing validator type: %w", op, ErrInvalidParameter)
}

// everything was validated at the start, so we know this is a valid/complete comparisonExpr
e := &comparisonExpr{
column: columnName,
comparisonOp: comparisonOp,
value: columnValue,
}

v, err := validator.fn(*e.value)
if err != nil {
return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter)
}
if validator.typ == "time" {
columnName = fmt.Sprintf("%s::date", columnName)
}
switch e.comparisonOp {
case ContainsOp:
return &WhereClause{
Condition: fmt.Sprintf("%s like ?", columnName),
Args: []any{fmt.Sprintf("%%%s%%", v)},
}, nil
default:
return &WhereClause{
Condition: fmt.Sprintf("%s%s?", columnName, e.comparisonOp),
Args: []any{v},
}, nil
}
}

type logicalOp string

const (
andOp logicalOp = "and"
orOp = "or"
orOp logicalOp = "or"
)

func newLogicalOp(s string) (logicalOp, error) {
const op = "newLogicalOp"
switch s {
switch logicalOp(s) {
case
string(andOp),
string(orOp):
andOp,
orOp:
return logicalOp(s), nil
default:
return "", fmt.Errorf("%s: %w %q", op, ErrInvalidLogicalOp, s)
Expand Down
96 changes: 96 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package mql

import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -31,5 +32,100 @@ func Test_root(t *testing.T) {
assert.ErrorIs(t, err, ErrMissingExpr)
assert.ErrorContains(t, err, "missing expression nil in: \"raw\"")
})
}

// Test_newComparison will focus on error conditions
func Test_newLogicalOp(t *testing.T) {
t.Parallel()
t.Run("invalid-comp-op", func(t *testing.T) {
op, err := newLogicalOp("not-valid")
require.Error(t, err)
assert.Empty(t, op)
assert.ErrorIs(t, err, ErrInvalidLogicalOp)
assert.ErrorContains(t, err, `invalid logical operator "not-valid"`)
})
}

// Test_newComparisonOp will focus on error conditions
func Test_newComparisonOp(t *testing.T) {
t.Parallel()
t.Run("invalid-comp-op", func(t *testing.T) {
op, err := newComparisonOp("not-valid")
require.Error(t, err)
assert.Empty(t, op)
assert.ErrorIs(t, err, ErrInvalidComparisonOp)
assert.ErrorContains(t, err, `invalid comparison operator "not-valid"`)
})
}

func Test_comparisonExprString(t *testing.T) {
t.Run("nil-value", func(t *testing.T) {
e := &comparisonExpr{
column: "name",
comparisonOp: "=",
value: nil,
}
assert.Equal(t, "(comparisonExpr: name = nil)", e.String())
})
}

func Test_logicalExprString(t *testing.T) {
t.Run("String", func(t *testing.T) {
e := &logicalExpr{
leftExpr: &comparisonExpr{
column: "name",
comparisonOp: "=",
value: pointer("alice"),
},
logicalOp: andOp,
rightExpr: &comparisonExpr{
column: "name",
comparisonOp: "=",
value: pointer("alice"),
},
}
assert.Equal(t, "(logicalExpr: (comparisonExpr: name = alice) and (comparisonExpr: name = alice))", e.String())
})
}

// Test_defaultValidateConvert will focus on error conditions
func Test_defaultValidateConvert(t *testing.T) {
t.Parallel()
fValidators, err := fieldValidators(reflect.ValueOf(testModel{}))
require.NoError(t, err)
t.Run("missing-column", func(t *testing.T) {
e, err := defaultValidateConvert("", EqualOp, pointer("alice"), fValidators["name"])
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingColumn)
assert.ErrorContains(t, err, "missing column")
})
t.Run("missing-comparison-op", func(t *testing.T) {
e, err := defaultValidateConvert("name", "", pointer("alice"), fValidators["name"])
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingComparisonOp)
assert.ErrorContains(t, err, "missing comparison operator")
})
t.Run("missing-value", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, nil, fValidators["name"])
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrMissingComparisonValue)
assert.ErrorContains(t, err, "missing comparison value")
})
t.Run("missing-validator-func", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{typ: "string"})
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "missing validator function")
})
t.Run("missing-validator-typ", func(t *testing.T) {
e, err := defaultValidateConvert("name", EqualOp, pointer("alice"), validator{fn: fValidators["name"].fn})
require.Error(t, err)
assert.Empty(t, e)
assert.ErrorIs(t, err, ErrInvalidParameter)
assert.ErrorContains(t, err, "missing validator type")
})
}
Loading

0 comments on commit efbc392

Please sign in to comment.