diff --git a/bin/dune b/bin/dune
new file mode 100644
index 00000000..fcbdd84d
--- /dev/null
+++ b/bin/dune
@@ -0,0 +1,5 @@
+(executable
+ (public_name u2dl)
+ (name u2dl)
+ (modules u2dl)
+ (libraries birds sql))
diff --git a/bin/u2dl.ml b/bin/u2dl.ml
new file mode 100644
index 00000000..b8c8b4b8
--- /dev/null
+++ b/bin/u2dl.ml
@@ -0,0 +1,85 @@
+open Birds
+
+(* Command-line tool: reads an SQL UPDATE statement and a Datalog schema file,
+   and prints the schema extended with the Datalog rules for that update.
+   Usage: u2dl SQL_FILE DATALOG_FILE *)
+
+(** [check_arguments_count argv] is [Ok ()] when [argv] holds the program name
+    and at least two further arguments, and an [Error] message otherwise. *)
+let check_arguments_count argv =
+  if Array.length argv < 3 then
+    Result.Error "Invalid arguments. Both SQL file name and Datalog file name must be passed."
+  else
+    Result.Ok ()
+
+(** [with_lexbuf filename parse] opens [filename], applies [parse] to a lexer
+    buffer reading from it, and closes the channel whether [parse] returns or
+    raises. A [Sys_error] raised while opening the file becomes an [Error]. *)
+let with_lexbuf filename parse =
+  match open_in filename with
+  | exception Sys_error msg -> Result.Error msg
+  | chan ->
+    Fun.protect
+      ~finally:(fun () -> close_in_noerr chan)
+      (fun () ->
+        let lexbuf = Lexing.from_channel chan in
+        (* Record the file name so reported positions can mention it. *)
+        lexbuf.Lexing.lex_curr_p <-
+          { lexbuf.Lexing.lex_curr_p with Lexing.pos_fname = filename };
+        parse lexbuf)
+
+(** [open_sql_ast filename] parses the SQL UPDATE statement stored in
+    [filename]. Lexical and syntax errors are returned as [Error] messages
+    instead of escaping as exceptions. *)
+let open_sql_ast filename =
+  with_lexbuf filename (fun lexbuf ->
+    match Sql.Parser.update Sql.Lexer.token lexbuf with
+    | ast -> Result.Ok ast
+    | exception Sql.Lexer.LexErr msg -> Result.Error msg
+    | exception Sql.Parser.Error ->
+      let pos = Lexing.lexeme_start_p lexbuf in
+      Result.Error
+        (Printf.sprintf "File \"%s\", line %d, character %d: syntax error"
+           filename
+           pos.Lexing.pos_lnum
+           (pos.Lexing.pos_cnum - pos.Lexing.pos_bol)))
+
+(** [open_view_ast filename] parses the Datalog schema stored in [filename].
+    NOTE(review): exceptions raised by the Datalog lexer/parser are defined
+    outside this file and still propagate to the caller. *)
+let open_view_ast filename =
+  with_lexbuf filename (fun lexbuf ->
+    Result.Ok (Parser.main Lexer.token lexbuf))
+
+(** [extract_schema expr] returns the first component of each column entry of
+    the view declared in [expr], or an [Error] when [expr] declares no view. *)
+let extract_schema expr =
+  match expr.Expr.view with
+  | Some (_, cols) -> Result.Ok (List.map fst cols)
+  | None -> Result.Error "Invalid schema file. A view definition must be present."
+
+(** [convert_to_dl sql cols] translates the update [sql] over the columns
+    [cols] into Datalog rules, rendering a conversion error as a string. *)
+let convert_to_dl sql cols =
+  Sql2ast.update_to_datalog sql cols
+  |> Result.map_error Sql2ast.string_of_error
+
+(* Runs the whole pipeline once at start-up: validate the arguments, parse both
+   input files, convert the update, and print the resulting program. *)
+let main =
+  let open Utils.ResultMonad in
+  check_arguments_count Sys.argv >>= fun () ->
+  open_sql_ast Sys.argv.(1) >>= fun sql ->
+  open_view_ast Sys.argv.(2) >>= fun expr ->
+  extract_schema expr >>= fun cols ->
+  convert_to_dl sql cols >>= fun rules ->
+  return @@ print_endline @@ Expr.to_string Expr.{ expr with rules }
+
+(* Report a failure on stderr with a non-zero exit status so that callers can
+   tell it apart from Datalog written to stdout. *)
+let () =
+  match main with
+  | Result.Ok () -> ()
+  | Result.Error err ->
+    prerr_endline err;
+    exit 1
diff --git a/examples/schema.dl b/examples/schema.dl
new file mode 100644
index 00000000..050753b1
--- /dev/null
+++ b/examples/schema.dl
@@ -0,0 +1 @@
+view ced('EMP_NAME':string, 'DEPT_NAME':string).
diff --git a/examples/sql_sample.dl b/examples/sql_sample.dl new file mode 100644 index 00000000..050753b1 --- /dev/null +++ b/examples/sql_sample.dl @@ -0,0 +1 @@ +view ced('EMP_NAME':string, 'DEPT_NAME':string). diff --git a/examples/sql_sample.sql b/examples/sql_sample.sql new file mode 100644 index 00000000..eb92b64d --- /dev/null +++ b/examples/sql_sample.sql @@ -0,0 +1,7 @@ +UPDATE + ced +SET + DEPT_NAME = 'R&D' +WHERE + DEPT_NAME = 'Dev' +; diff --git a/src/dune b/src/dune index d292db99..929a094f 100644 --- a/src/dune +++ b/src/dune @@ -5,7 +5,7 @@ (library (name birds) - (libraries logic ocamlgraph num str)) + (libraries logic sql ocamlgraph num str)) (env (dev diff --git a/src/sql/ast.ml b/src/sql/ast.ml new file mode 100644 index 00000000..8dda59a8 --- /dev/null +++ b/src/sql/ast.ml @@ -0,0 +1,113 @@ + +type binary_operator = + | Plus (* + *) + | Minus (* - *) + | Times (* * *) + | Divides (* / *) + | Lor (* || *) + +type unary_operator = + | Negate (* - *) + +type operator = + | RelEqual + | RelNotEqual + | RelGeneral of string + +type table_name = string + +type column_name = string + +type instance_name = string + +type column = instance_name option * column_name + +type const = + | Int of int + | Real of float + | String of string + | Bool of bool + | Null + +type vterm = + | Const of const + | Column of column + | UnaryOp of unary_operator * vterm + | BinaryOp of binary_operator * vterm * vterm + +type sql_constraint = + | Constraint of vterm * operator * vterm + +type where_clause = + | Where of sql_constraint list + +type update = + | UpdateSet of table_name * (column * vterm) list * where_clause option + +let string_of_binary_operator = function + | Plus -> "+" + | Minus -> "-" + | Times -> "*" + | Divides -> "/" + | Lor -> "||" + +let string_of_unary_operator = function + | Negate -> "-" + +let string_of_operator = function + | RelEqual -> "=" + | RelNotEqual -> "<>" + | RelGeneral op -> op + +let string_of_column (instance_name, column) = + 
match instance_name with + | Some instance_name -> Printf.sprintf "%s.%s" instance_name column + | None -> column + +let string_of_column_ignore_instance (_, column) = column + +let string_of_const = function + | Int i -> string_of_int i + | Real f -> string_of_float f + | String s -> s + | Bool b -> string_of_bool b + | Null -> "NULL" + +let rec string_of_vterm = function + | Const c -> string_of_const c + | Column c -> string_of_column c + | UnaryOp (op, e) -> string_of_unary_operator op ^ string_of_vterm e + | BinaryOp (op, left, right) -> + Printf.sprintf "%s %s %s" + (string_of_vterm left) + (string_of_binary_operator op) + (string_of_vterm right) + +let string_of_constraint = function + | Constraint (left, op, right) -> + Printf.sprintf "%s %s %s" + (string_of_vterm left) + (string_of_operator op) + (string_of_vterm right) + +let to_string = function + | UpdateSet (table_name, sets, where) -> + let string_of_set (col, vterm) = + Printf.sprintf " %s = %s" (string_of_column col) (string_of_vterm vterm) + in + "UPDATE\n" ^ + " " ^ table_name ^ "\n" ^ + "SET\n" ^ ( + sets + |> List.map string_of_set + |> String.concat "\n" + ) ^ + match where with + | None -> "" + | Some (Where cs) -> + "\nWHERE\n" ^ ( + cs + |> List.map (fun c -> " " ^ string_of_constraint c) + |> String.concat "\n" + ) + ^ "\n;" diff --git a/src/sql/dune b/src/sql/dune new file mode 100644 index 00000000..6f04656c --- /dev/null +++ b/src/sql/dune @@ -0,0 +1,7 @@ +(menhir + (modules parser)) + +(ocamllex lexer) + +(library + (name sql)) diff --git a/src/sql/lexer.mll b/src/sql/lexer.mll new file mode 100644 index 00000000..2f1fa82d --- /dev/null +++ b/src/sql/lexer.mll @@ -0,0 +1,59 @@ +{ + open Parser;; + open Lexing;; + + let spec_error msg start finish = + Printf.sprintf + "File \"%s\", line %d, characters %d-%d: '%s'" + start.pos_fname + start.pos_lnum + (start.pos_cnum - start.pos_bol) + (finish.pos_cnum - finish.pos_bol) + msg + + exception LexErr of string + let spec_lex_error lexbuf = + 
raise (LexErr (spec_error (lexeme lexbuf) (lexeme_start_p lexbuf) (lexeme_end_p lexbuf))) + + let keywords = [ + "update", UPDATE; + "UPDATE", UPDATE; + "where", WHERE; + "WHERE", WHERE; + "set", SET; + "SET", SET; + "and", AND; + "AND", AND; + ] +} +let digit = ['0'-'9'] +let alpha = ['a'-'z' 'A'-'Z'] +let ident = (alpha) (alpha | digit | '_' )* +let wsp = [' ' '\r' '\t'] + +rule token = parse + | wsp { token lexbuf } + | '\n' | ';' { Lexing.new_line lexbuf; token lexbuf } + | "--" (wsp | alpha | digit) ('\n' | eof) { Lexing.new_line lexbuf; token lexbuf } + | digit+ as lxm { INTEGER (int_of_string lxm) } + | digit* '.'? digit+ (['e' 'E'] ['-' '+']? digit+)? as lxm { FLOAT (float_of_string (lxm)) } + | '\'' (('\'' '\'') | [^'\n''\''])* '\'' as lxm { TEXT lxm } + | ident as lxm { + match List.assoc_opt lxm keywords with + | Some t -> t + | None -> IDENT lxm + } + | '(' { LPAREN } + | ')' { RPAREN } + | ',' { COMMA } + | '.' { DOT } + | "NULL" | "null" { NULL } + | '=' { EQUAL } + | '*' { ASTERISK } + | "||" { CONCAT_OP } + | '/' { NUM_DIV_OP } + | "!=" | "<>" { NUM_NEQ_OP } + | '+' { PLUS } + | '-' { MINUS } + | eof { EOF } + | _ { spec_lex_error lexbuf } diff --git a/src/sql/parser.mly b/src/sql/parser.mly new file mode 100644 index 00000000..ec1656be --- /dev/null +++ b/src/sql/parser.mly @@ -0,0 +1,79 @@ +%token INTEGER +%token IDENT TEXT +%token FLOAT +%token LPAREN RPAREN COMMA EOF DOT NULL +%token UPDATE WHERE EQUAL ASTERISK SET AND CONCAT_OP +%token NUM_DIV_OP NUM_NEQ_OP PLUS MINUS + +%left CONCAT_OP +%left AND +%nonassoc EQUAL NUM_NEQ_OP +%left PLUS MINUS +%left ASTERISK NUM_DIV_OP +%nonassoc UNARY_MINUS + +%start update + +%% + + update: + | update_stmt EOF { $1 } + ; + + update_stmt: + | UPDATE table=IDENT SET ss=commas(set_column) w=where? 
{ Ast.UpdateSet (table, ss, w) } + ; + + set_column: + | c=column EQUAL e=vterm { c, e } + ; + + vterm: + | const { Ast.Const $1 } + | column { Ast.Column $1 } + | unary_op { $1 } + | left=vterm op=binary_op right=vterm { Ast.BinaryOp (op, left, right) } + | LPAREN e=vterm RPAREN { e } + ; + + + const: + | INTEGER { Ast.Int $1 } + | FLOAT { Ast.Real $1 } + | TEXT { Ast.String $1 } + | NULL { Ast.Null } + ; + + column: + | table=IDENT DOT cname=IDENT { (Some table), cname } + | cname=IDENT { None, cname } + ; + + unary_op: + | MINUS e=vterm %prec UNARY_MINUS { Ast.UnaryOp (Ast.Negate, e) } + ; + + binary_op: + | PLUS { Ast.Plus } + | MINUS { Ast.Minus } + | ASTERISK { Ast.Times } + | NUM_DIV_OP { Ast.Divides } + | CONCAT_OP { Ast.Lor } + ; + + where: + | WHERE cs=ands(sql_constraint) { Ast.Where cs } + ; + + sql_constraint: + | left=vterm op=operator right=vterm { Ast.Constraint (left, op, right) } + ; + + operator: + | EQUAL { Ast.RelEqual } + | NUM_NEQ_OP { Ast.RelNotEqual } + | op=IDENT { Ast.RelGeneral op } + ; + +%inline commas(X): l=separated_nonempty_list(COMMA, X) { l } +%inline ands(X): l=separated_nonempty_list(AND, X) { l } diff --git a/src/sql2ast.ml b/src/sql2ast.ml index 2ff5e4d4..5ee52716 100644 --- a/src/sql2ast.ml +++ b/src/sql2ast.ml @@ -1,61 +1,8 @@ open Utils -let ( >>= ) = ResultMonad.( >>= ) - -type sql_binary_operator = - | SqlPlus (* + *) - | SqlMinus (* - *) - | SqlTimes (* * *) - | SqlDivides (* / *) - | SqlLor (* || *) - -type sql_unary_operator = - | SqlNegate (* - *) - -type sql_operator = - | SqlRelEqual - | SqlRelNotEqual - | SqlRelGeneral of string - -type sql_table_name = string - -type sql_column_name = string - -type sql_instance_name = string - -type sql_column = sql_instance_name option * sql_column_name +module Sql = Sql.Ast -type sql_vterm = - | SqlConst of Expr.const - | SqlColumn of sql_column - | SqlUnaryOp of sql_unary_operator * sql_vterm - | SqlBinaryOp of sql_binary_operator * sql_vterm * sql_vterm - -type sql_constraint 
= - | SqlConstraint of sql_vterm * sql_operator * sql_vterm - -type sql_where_clause = - | SqlWhere of sql_constraint list - -type sql_update = - | SqlUpdateSet of sql_table_name * (sql_column * sql_vterm) list * sql_where_clause option - -let string_of_sql_binary_operator = function - | SqlPlus -> "+" - | SqlMinus -> "-" - | SqlTimes -> "*" - | SqlDivides -> "/" - | SqlLor -> "||" - -let string_of_sql_unary_operator = function - | SqlNegate -> "-" - -let string_of_sql_operator = function - | SqlRelEqual -> "=" - | SqlRelNotEqual -> "<>" - | SqlRelGeneral op -> op - -let string_of_sql_column_ignore_instance (_, column) = column +let ( >>= ) = ResultMonad.( >>= ) type error = | InvalidColumnName of string @@ -67,28 +14,35 @@ let string_of_error = function module ColumnVarMap = Map.Make(String) let rec ast_vterm_of_sql_vterm colvarmap = function - | SqlConst const -> - ResultMonad.return (Expr.Const const) - | SqlColumn column -> - let column_name = string_of_sql_column_ignore_instance column in + | Sql.Const const -> + ResultMonad.return @@ Expr.Const + begin match const with + | Sql.Int n -> Expr.Int n + | Sql.Real f -> Expr.Real f + | Sql.String s -> Expr.String s + | Sql.Bool b -> Expr.Bool b + | Sql.Null -> Expr.Null + end + | Sql.Column column -> + let column_name = Sql.string_of_column_ignore_instance column in ColumnVarMap.find_opt column_name colvarmap |> Option.map (fun var -> Expr.Var var) |> Option.to_result ~none:(InvalidColumnName column_name) - | SqlUnaryOp (op, sql_vterm) -> + | Sql.UnaryOp (op, sql_vterm) -> ast_vterm_of_sql_vterm colvarmap sql_vterm >>= fun vterm -> - let op = string_of_sql_unary_operator op in + let op = Sql.string_of_unary_operator op in ResultMonad.return (Expr.UnaryOp (op, vterm)) - | SqlBinaryOp (op, left, right) -> + | Sql.BinaryOp (op, left, right) -> ast_vterm_of_sql_vterm colvarmap left >>= fun left -> ast_vterm_of_sql_vterm colvarmap right >>= fun right -> - let op = string_of_sql_binary_operator op in + let op = 
Sql.string_of_binary_operator op in ResultMonad.return (Expr.BinaryOp (op, left, right)) let ast_terms_of_sql_where_clause colvarmap = function - | SqlWhere sql_constraints -> + | Sql.Where sql_constraints -> let ast_term_of_sql_constraint = function - | SqlConstraint (left, op, right) -> - let op = string_of_sql_operator op in + | Sql.Constraint (left, op, right) -> + let op = Sql.string_of_operator op in ast_vterm_of_sql_vterm colvarmap left >>= fun left -> ast_vterm_of_sql_vterm colvarmap right >>= fun right -> ResultMonad.return (Expr.Equat (Expr.Equation (op, left, right))) in @@ -105,7 +59,7 @@ let build_effects colvarmap column_and_vterms = column_and_vterms |> ResultMonad.mapM (fun (sql_col, sql_vterm) -> ast_vterm_of_sql_vterm colvarmap sql_vterm >>= fun vterm -> - let column_name = string_of_sql_column_ignore_instance sql_col in + let column_name = Sql.string_of_column_ignore_instance sql_col in ColumnVarMap.find_opt column_name colvarmap |> Option.to_result ~none:(InvalidColumnName column_name) >>= fun var -> @@ -128,7 +82,7 @@ let build_creation_rule colvarmap colvarmap' column_and_vterms table_name column column_and_vterms |> ResultMonad.mapM (fun (column, vterm) -> ast_vterm_of_sql_vterm colvarmap' vterm >>= fun vterm -> - let column_name = string_of_sql_column_ignore_instance column in + let column_name = Sql.string_of_column_ignore_instance column in ColumnVarMap.find_opt column_name colvarmap |> Option.map (fun var -> Expr.Equat (Expr.Equation ("=", Expr.Var var, vterm))) |> Option.to_result ~none:(InvalidColumnName column_name) @@ -137,7 +91,7 @@ let build_creation_rule colvarmap colvarmap' column_and_vterms table_name column (* Create a rule corresponding to the operation to insert the record to be updated. 
*) columns |> ResultMonad.mapM (fun column -> - let column_name = string_of_sql_column_ignore_instance (None, column) in + let column_name = Sql.string_of_column_ignore_instance (None, column) in ColumnVarMap.find_opt column_name colvarmap' |> Option.to_result ~none:(InvalidColumnName column_name) ) >>= fun delete_var_list -> @@ -148,8 +102,8 @@ let build_creation_rule colvarmap colvarmap' column_and_vterms table_name column module ColumnSet = Set.Make(String) -let update_to_datalog (update : sql_update) (columns : sql_column_name list) : (Expr.rule list, error) result = - let SqlUpdateSet (table_name, column_and_vterms, where_clause) = update in +let update_to_datalog (update : Sql.update) (columns : Sql.column_name list) : (Expr.rule list, error) result = + let Sql.UpdateSet (table_name, column_and_vterms, where_clause) = update in (* Create (column name as String, Expr.var) list. *) let make_column_var_list make_var = @@ -159,7 +113,7 @@ let update_to_datalog (update : sql_update) (columns : sql_column_name list) : ( ) in let make_colvarmap column_var_list = column_var_list - |> List.map (fun (col, var) -> string_of_sql_column_ignore_instance col, var) + |> List.map (fun (col, var) -> Sql.string_of_column_ignore_instance col, var) |> List.to_seq |> ColumnVarMap.of_seq in @@ -181,7 +135,7 @@ let update_to_datalog (update : sql_update) (columns : sql_column_name list) : ( | None -> (var :: varlist), in_set | Some _ -> - let column_name = string_of_sql_column_ignore_instance column in + let column_name = Sql.string_of_column_ignore_instance column in (var :: varlist), (ColumnSet.add column_name in_set) ) column_var_list ([], ColumnSet.empty) in @@ -193,7 +147,7 @@ let update_to_datalog (update : sql_update) (columns : sql_column_name list) : ( *) let column_var_list' = columns |> make_column_var_list (fun idx column_name -> - let column_name = string_of_sql_column_ignore_instance (None, column_name) in + let column_name = Sql.string_of_column_ignore_instance (None, 
column_name) in if ColumnSet.exists (fun c -> c = column_name) in_set then Expr.NamedVar (Printf.sprintf "GENV%d_2" (idx + 1)) else diff --git a/src/sql2ast.mli b/src/sql2ast.mli index 33465e9d..c4fc6848 100644 --- a/src/sql2ast.mli +++ b/src/sql2ast.mli @@ -1,41 +1,3 @@ -type sql_binary_operator = - | SqlPlus (* + *) - | SqlMinus (* - *) - | SqlTimes (* * *) - | SqlDivides (* / *) - | SqlLor (* || *) - -type sql_unary_operator = - | SqlNegate (* - *) - -type sql_operator = - | SqlRelEqual - | SqlRelNotEqual - | SqlRelGeneral of string - -type sql_table_name = string - -type sql_column_name = string - -type sql_instance_name = string - -type sql_column = sql_instance_name option * sql_column_name - -type sql_vterm = - | SqlConst of Expr.const - | SqlColumn of sql_column - | SqlUnaryOp of sql_unary_operator * sql_vterm - | SqlBinaryOp of sql_binary_operator * sql_vterm * sql_vterm - -type sql_constraint = - | SqlConstraint of sql_vterm * sql_operator * sql_vterm - -type sql_where_clause = - | SqlWhere of sql_constraint list - -type sql_update = - | SqlUpdateSet of sql_table_name * (sql_column * sql_vterm) list * sql_where_clause option - type error val string_of_error : error -> string @@ -46,4 +8,4 @@ val string_of_error : error -> string * * @return List of rules in datalog language, or failure. 
*) -val update_to_datalog : sql_update -> sql_column_name list -> (Expr.rule list, error) result +val update_to_datalog : Sql.Ast.update -> Sql.Ast.column_name list -> (Expr.rule list, error) result diff --git a/test/sql2ast_test.ml b/test/sql2ast_test.ml index 2f4eca2f..952e7239 100644 --- a/test/sql2ast_test.ml +++ b/test/sql2ast_test.ml @@ -1,9 +1,11 @@ open Birds open Utils +module Sql = Sql.Ast + type test_case = { title : string; - input : Sql2ast.sql_update * Sql2ast.sql_column_name list; + input : Sql.update * Sql.column_name list; expected : Expr.rule list } @@ -69,14 +71,14 @@ let main () = * *) input = ( - SqlUpdateSet ( + Sql.UpdateSet ( "ced", - [(None, "dname"), SqlConst (String "'R&D'")], - Some (SqlWhere ([ - SqlConstraint ( - SqlColumn (None, "dname"), - SqlRelEqual, - SqlConst (String "'Dev'") + [(None, "dname"), Sql.Const (String "'R&D'")], + Some (Sql.Where ([ + Sql.Constraint ( + Sql.Column (None, "dname"), + Sql.RelEqual, + Sql.Const (String "'Dev'") ) ])) ), @@ -122,23 +124,23 @@ let main () = * *) input = ( - SqlUpdateSet ( + Sql.UpdateSet ( "t", [ - (None, "c1"), SqlConst (String "'v1'"); - (None, "c3"), SqlConst (String "'v3'"); - (None, "c5"), SqlConst (String "'v5'") + (None, "c1"), Sql.Const (String "'v1'"); + (None, "c3"), Sql.Const (String "'v3'"); + (None, "c5"), Sql.Const (String "'v5'") ], - Some (SqlWhere ([ - SqlConstraint ( - SqlColumn (None, "c2"), - SqlRelEqual, - SqlConst (String "'v2'") + Some (Sql.Where ([ + Sql.Constraint ( + Sql.Column (None, "c2"), + Sql.RelEqual, + Sql.Const (String "'v2'") ); - SqlConstraint ( - SqlColumn (None, "c3"), - SqlRelEqual, - SqlConst (String "'v100'") + Sql.Constraint ( + Sql.Column (None, "c3"), + Sql.RelEqual, + Sql.Const (String "'v100'") ) ])) ), @@ -200,13 +202,13 @@ let main () = * *) input = ( - SqlUpdateSet ( + Sql.UpdateSet ( "t", [ - (None, "c1"), SqlColumn (None, "c2"); - (None, "c2"), SqlColumn (None, "c3") + (None, "c1"), Sql.Column (None, "c2"); + (None, "c2"), Sql.Column 
(None, "c3") ], - Some (SqlWhere ([])) + Some (Sql.Where ([])) ), ["c1"; "c2"; "c3"; "c4"] );