Skip to content

Commit

Permalink
Merge pull request #23 from srinidhis94/fk-support
Browse files Browse the repository at this point in the history
Fk support
  • Loading branch information
PumpkinSeed authored Apr 18, 2021
2 parents a6449b8 + 70400ea commit ad5950a
Show file tree
Hide file tree
Showing 13 changed files with 630 additions and 84 deletions.
9 changes: 7 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
GOFILES := $(shell find . -name "*.go" -type f ! -path "./vendor/*")
GOFMT ?= gofmt -s

.PHONY: all
PACKAGES = $(shell go list ./... | grep -v /vendor/)

.PHONY: all test
all: slqfuzz_darwin_amd64 slqfuzz_windows_amd64 slqfuzz_linux_amd64 slqfuzz_linux_arm64

slqfuzz_darwin_amd64:
Expand All @@ -24,4 +26,7 @@ clean:
rm -rf sqlfuzz*

fmt:
@$(GOFMT) -w ${GOFILES}
@$(GOFMT) -w ${GOFILES}

test:
@go test -v -coverprofile cover.out ${PACKAGES}
42 changes: 30 additions & 12 deletions drivers/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,33 @@ type Field struct {
Enum []string
}

type FKDescriptor struct {
ConstraintName string
TableName string
ColumnName string
ForeignTableName string
ForeignColumnName string
}

//FieldDescriptor represents a field described by the table in the SQL database
type FieldDescriptor struct {
Field string
Type string
Null string
Key string
Length null.Int
Default null.String
Extra string
Precision null.Int
Scale null.Int
HasDefaultValue bool
Field string
Type string
Null string
Key string
Length null.Int
Default null.String
Extra string
Precision null.Int
Scale null.Int
HasDefaultValue bool
ForeignKeyDescriptor *FKDescriptor
}

// TestCase has a map of table to its create table query and table creation order
type TestCase struct {
TableToCreateQueryMap map[string]string
TableCreationOrder []string
}

// Driver is the interface should satisfied by a certain driver
Expand All @@ -65,11 +80,14 @@ type Driver interface {
Driver() string
Insert(fields []string, table string) string
MapField(descriptor FieldDescriptor) Field
DescribeFields(table string, db *sql.DB) ([]FieldDescriptor, error)
Describe(table string, db *sql.DB) ([]FieldDescriptor, error)
MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error)
GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error)
}

type Testable interface {
TestTable(conn *sql.DB, table string) error
GetTestCase(name string) (TestCase, error)
TestTable(conn *sql.DB, testCase, table string) error
}

// New creates a new driver instance based on the flags
Expand Down
107 changes: 87 additions & 20 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,40 @@
package drivers

import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
)

const (
MySQLDescribeTableQuery = "SHOW TABLES;"
mysqlFKQuery = "SELECT CONSTRAINT_NAME,TABLE_NAME,COLUMN_NAME,REFERENCED_TABLE_NAME,REFERENCED_COLUMN_NAME from INFORMATION_SCHEMA.KEY_COLUMN_USAGE where REFERENCED_TABLE_NAME <> 'NULL' and REFERENCED_COLUMN_NAME <> 'NULL' and TABLE_NAME = '%s'"
)

var (
mySQLNameToTestCase = map[string]TestCase{
"single": {
TableToCreateQueryMap: map[string]string{DefaultTableCreateQueryKey: `CREATE TABLE %s (
id INT(6) UNSIGNED,
firstname VARCHAR(30),
lastname VARCHAR(30),
email VARCHAR(50),
reg_date TIMESTAMP
)`},
TableCreationOrder: nil,
},
"multi": {
TableToCreateQueryMap: map[string]string{
"t_currency": "CREATE TABLE IF NOT EXISTS t_currency ( id int not null,shortcut char (3) not null,PRIMARY KEY (id));",
"t_location": "CREATE TABLE IF NOT EXISTS t_location ( id int not null,location_name text not null,PRIMARY KEY (id));",
"t_product": "CREATE TABLE IF NOT EXISTS t_product( id int not null,name text not null,currency_id int ,PRIMARY KEY (id), FOREIGN KEY (currency_id) REFERENCES t_currency(id));",
"t_product_desc": "CREATE TABLE IF NOT EXISTS t_product_desc (id int not null,product_id int , description text not null, PRIMARY KEY (id), FOREIGN KEY (product_id) REFERENCES t_currency(id) );",
"t_product_stock": "CREATE TABLE IF NOT EXISTS t_product_stock(product_id int , location_id int ,amount numeric not null, FOREIGN KEY (product_id) REFERENCES t_currency(id),FOREIGN KEY(location_id) REFERENCES t_location(id));",
},
TableCreationOrder: []string{"t_currency", "t_location", "t_product", "t_product_desc", "t_product_stock"},
},
}
)

// MySQL implementation of the Driver
Expand Down Expand Up @@ -163,46 +189,87 @@ func (m MySQL) MapField(descriptor FieldDescriptor) Field {
return Field{Type: Unknown, Length: -1}
}

func (MySQL) DescribeFields(table string, db *sql.DB) ([]FieldDescriptor, error) {
func (MySQL) Describe(table string, db *sql.DB) ([]FieldDescriptor, error) {
describeQuery := fmt.Sprintf("DESCRIBE %s;", table)
results, err := db.Query(describeQuery)
if err != nil {
return nil, err
}
return parseMySQLFields(results)
fkRows, err := db.Query(fmt.Sprintf(mysqlFKQuery, strings.ToLower(table)))
if err != nil {
return nil, err
}
return parseMySQLFields(results, fkRows)
}

// TestTable only for test purposes
func (m MySQL) TestTable(db *sql.DB, table string) error {
query := `CREATE TABLE %s (
id INT(6) UNSIGNED,
firstname VARCHAR(30),
lastname VARCHAR(30),
email VARCHAR(50),
reg_date TIMESTAMP
)`

res, err := db.ExecContext(context.Background(), fmt.Sprintf(query, table))
func (m MySQL) MultiDescribe(tables []string, db *sql.DB) (map[string][]FieldDescriptor, []string, error) {
processedTables := make(map[string]struct{})
tableToDescriptorMap := make(map[string][]FieldDescriptor)
for {
newTableToDescriptorMap, newlyReferencedTables, err := multiDescribeHelper(tables, processedTables, db, m)
if err != nil {
return nil, nil, err
}
for key, val := range newTableToDescriptorMap {
tableToDescriptorMap[key] = val
}
if len(newlyReferencedTables) == 0 {
break
}
tables = newlyReferencedTables
}
insertionOrder, err := getInsertionOrder(tableToDescriptorMap)
if err != nil {
return err
return nil, nil, err
}
return tableToDescriptorMap, insertionOrder, nil
}

_, err = res.RowsAffected()
func (MySQL) GetLatestColumnValue(table, column string, db *sql.DB) (interface{}, error) {
query := fmt.Sprintf("select %v from %v order by %v desc limit 1", column, table, column)
rows, err := db.Query(query)
if err != nil {
return err
return nil, err
}
var val interface{}
for rows.Next() {
rows.Scan(&val)
}
return nil
return val, nil
}

// TestTable only for test purposes
func (m MySQL) TestTable(db *sql.DB, testCase, table string) error {
return testTable(db, testCase, table, m)
}

func parseMySQLFields(results *sql.Rows) ([]FieldDescriptor, error) {
func (MySQL) GetTestCase(name string) (TestCase, error) {
if val, ok := mySQLNameToTestCase[name]; ok {
return val, nil
}
return TestCase{}, errors.New(fmt.Sprintf("postgres: Error getting testcase with name %v", name))
}

func parseMySQLFields(results, fkRows *sql.Rows) ([]FieldDescriptor, error) {
var fields []FieldDescriptor
columnToFKMap := make(map[string]FKDescriptor)
for fkRows.Next() {
var fk FKDescriptor
err := fkRows.Scan(&fk.ConstraintName, &fk.TableName, &fk.ColumnName, &fk.ForeignTableName, &fk.ForeignColumnName)
if err != nil {
return nil, err
}
columnToFKMap[fk.ColumnName] = fk
}
for results.Next() {
var d FieldDescriptor
err := results.Scan(&d.Field, &d.Type, &d.Null, &d.Key, &d.Default, &d.Extra)
if err != nil {
return nil, err
}

if val, ok := columnToFKMap[d.Field]; ok {
d.ForeignKeyDescriptor = &val
}
fields = append(fields, d)
}
return fields, nil
Expand Down
Loading

0 comments on commit ad5950a

Please sign in to comment.