diff --git a/src/ast/ast.rs b/src/ast/ast.rs index 0b57280..8d2b3cf 100644 --- a/src/ast/ast.rs +++ b/src/ast/ast.rs @@ -326,6 +326,9 @@ pub enum Expression { Binary(BinaryExpression), Require(RequireExpression), Function(FunctionExpression), + Table(TableExpression), + Member(MemberExpression), + Index(IndexExpression), } impl Expression { @@ -349,6 +352,18 @@ impl Expression { Expression::Grouped(GroupedExpression::new(expressions, range)) } + pub fn new_table(values: Vec<(Expression, Option)>, range: Range) -> Self { + Expression::Table(TableExpression::new(values, range)) + } + + pub fn new_member(base: Expression, member: Expression) -> Self { + Expression::Member(MemberExpression::new(Box::new(base), Box::new(member))) + } + + pub fn new_index(base: Expression, index: Expression, bracket_range: Range) -> Self { + Expression::Index(IndexExpression::new(Box::new(base), Box::new(index), bracket_range)) + } + pub fn new_function( arguments: Vec<(Token, Option)>, return_type: Option, @@ -369,6 +384,9 @@ impl Expression { Expression::Grouped(grouped) => grouped.get_range(), Expression::Unary(unary) => unary.get_range(), Expression::Function(function) => function.get_range(), + Expression::Table(table) => table.get_range(), + Expression::Member(member) => member.get_range(), + Expression::Index(index) => index.get_range(), } } @@ -468,6 +486,56 @@ impl UnaryExpression { } } +#[derive(Debug, Serialize, Deserialize)] +pub struct MemberExpression { + pub base: Box, + pub member: Box, +} +impl MemberExpression { + pub fn new(base: Box, member: Box) -> Self { + MemberExpression { base, member } + } + + pub fn get_range(&self) -> Range { + let left_range = self.base.get_range(); + let right_range = self.member.get_range(); + create_middle_range(&left_range, &right_range) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexExpression { + pub base: Box, + pub index: Box, + pub bracket_range: Range, +} +impl IndexExpression { + pub fn new(base: Box, index: Box, bracket_range: Range) -> Self { + IndexExpression { base, index, bracket_range } + } + + pub fn get_range(&self) -> Range { + let left_range = self.base.get_range(); + let right_range = self.index.get_range(); + create_middle_range(&left_range, &right_range) + } +} +#[derive(Debug, Serialize, Deserialize)] +pub struct TableExpression { + pub values: Vec<(Expression, Option)>, + pub range: Range, +} + +impl TableExpression { + pub fn new(values: Vec<(Expression, Option)>, range: Range) -> Self { + TableExpression { values, range } + } + + pub fn get_range(&self) -> Range { + self.range.clone() + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct FunctionExpression { pub arguments: Vec<(Token, Option)>, diff --git a/src/checker/check_call_expression.rs b/src/checker/check_call_expression.rs index 68ddd9e..a8678c7 100644 --- a/src/checker/check_call_expression.rs +++ b/src/checker/check_call_expression.rs @@ -3,72 +3,90 @@ use crate::{ ast::ast, diagnostics::{Diagnostic, TypeError}, types::Type, + utils::range::Range, }; impl<'a> Checker<'a> { pub fn check_call_expression(&mut self, call_expr: &ast::CallExpression) -> Result { let name = call_expr.name.lexeme(); - let (defined, scope_pointer) = self.ctx.defined_in_any_scope(name); + if !defined { - let diagnostic = TypeError::UndeclaredVariable(name.to_string(), Some(call_expr.name.range.clone())); - return Err(self.create_diagnostic(diagnostic)); + return Err( + self.create_diagnostic(TypeError::UndeclaredVariable(name.to_string(), Some(call_expr.name.range.clone()))), + ); } - let call_type = self.ctx.get_variable(name, Some(scope_pointer)); - - if call_type.is_none() { - let diagnostic = TypeError::UndeclaredVariable(name.to_string(), Some(call_expr.name.range.clone())); - return Err(self.create_diagnostic(diagnostic)); - } - match call_type.unwrap() { - Type::Function(call_type) => { - let return_t = *call_type.return_type.clone(); - self.check_call_arguments(&call_expr.args, &call_type.params.to_vec())?; - return Ok(return_t); - } - Type::Unknown => { - return Ok(Type::Unknown); + let call_type = match self.ctx.get_variable(name, Some(scope_pointer)) { + Some(call_type) => call_type.clone(), + None => { + return Err( + self.create_diagnostic(TypeError::UndeclaredVariable(name.to_string(), Some(call_expr.name.range.clone()))), + ); } - _ => { - let diagnostic = TypeError::ExpectedFunction(name.to_string(), Some(call_expr.name.range.clone())); - return Err(self.create_diagnostic(diagnostic)); + }; + + self.check_call_type(&call_type, &call_expr.args, call_expr.get_range()) + } + + pub fn check_call_type( + &mut self, + call_type: &Type, + args: &ast::Expression, + range: Range, + ) -> Result { + match call_type { + Type::Function(func_type) => { + self.check_call_arguments(args, &func_type.params)?; + Ok(*func_type.return_type.clone()) } + Type::Unknown => Ok(Type::Unknown), + _ => Err(self.create_diagnostic(TypeError::ExpectedFunction(call_type.to_string(), Some(range)))), } } - pub fn check_call_arguments(&mut self, args_call: &ast::Expression, params_tt: &[Type]) -> Result<(), Diagnostic> { - if let ast::Expression::Grouped(ast::GroupedExpression { expressions, range }) = args_call { - if expressions.len() != params_tt.len() { - let diagnostic = TypeError::FunctionArityMismatch(params_tt.len(), expressions.len(), Some(range.clone())); - return Err(self.create_diagnostic(diagnostic)); + pub fn check_call_arguments(&mut self, args: &ast::Expression, params: &[Type]) -> Result<(), Diagnostic> { + if let ast::Expression::Grouped(ast::GroupedExpression { expressions, range }) = args { + if expressions.len() != params.len() { + return Err(self.create_diagnostic(TypeError::FunctionArityMismatch( + params.len(), + expressions.len(), + Some(range.clone()), + ))); } - for (arg_expr, param_t) in expressions.iter().zip(params_tt.iter()) { - let param_t = self.check_type(param_t.clone())?; - let arg_t = self.check_expression(arg_expr)?; - if !arg_t.check_match(¶m_t) { - let range = arg_expr.get_range(); - let diagnostic = TypeError::MismatchedTypes(param_t.to_string(), arg_t.to_string(), Some(range.clone())); - return Err(self.create_diagnostic(diagnostic)); + for (arg_expr, param_type) in expressions.iter().zip(params.iter()) { + let inferred_type = self.check_expression(arg_expr)?; + let param_type_checked = self.check_type(param_type.clone())?; + + if !inferred_type.check_match(¶m_type_checked) { + return Err(self.create_diagnostic(TypeError::MismatchedTypes( + param_type_checked.to_string(), + inferred_type.to_string(), + Some(arg_expr.get_range()), + ))); } } - return Ok(()); + Ok(()) + } else { + self.check_single_argument(args, params) } + } - let arg_t = self.check_expression(args_call)?; - - if params_tt.len() != 1 { - let range = args_call.get_range(); - let diagnostic = TypeError::FunctionArityMismatch(params_tt.len(), 1, Some(range)); - return Err(self.create_diagnostic(diagnostic)); + fn check_single_argument(&mut self, arg: &ast::Expression, params: &[Type]) -> Result<(), Diagnostic> { + if params.len() != 1 { + return Err(self.create_diagnostic(TypeError::FunctionArityMismatch(params.len(), 1, Some(arg.get_range())))); } - let param_tt = params_tt.first().unwrap(); + let param_type = self.check_type(params.first().unwrap().clone())?; + let arg_type = self.check_expression(arg)?; - if !arg_t.check_match(¶m_tt) { - let diagnostic = TypeError::MismatchedTypes(param_tt.to_string(), arg_t.to_string(), None); - return Err(self.create_diagnostic(diagnostic)); + if !arg_type.check_match(¶m_type) { + return Err(self.create_diagnostic(TypeError::MismatchedTypes( + param_type.to_string(), + arg_type.to_string(), + Some(arg.get_range()), + ))); } Ok(()) diff --git a/src/checker/check_expression.rs b/src/checker/check_expression.rs index a166933..c035bda 100644 --- a/src/checker/check_expression.rs +++ b/src/checker/check_expression.rs @@ -14,6 +14,9 @@ impl<'a> Checker<'a> { ast::Expression::Unary(unary_expr) => self.check_unary_expression(unary_expr), ast::Expression::Grouped(grup_expr) => self.check_grouped_expression(grup_expr), ast::Expression::Function(function) => self.check_function_expression(function), + ast::Expression::Table(table) => self.check_table_expression(table), + ast::Expression::Member(member) => self.check_member_expression(member), + ast::Expression::Index(index) => self.check_index_expression(index), } } } diff --git a/src/checker/check_index_expression.rs b/src/checker/check_index_expression.rs new file mode 100644 index 0000000..a1382c9 --- /dev/null +++ b/src/checker/check_index_expression.rs @@ -0,0 +1,82 @@ +use super::Checker; +use crate::{ + ast::ast, + diagnostics::{Diagnostic, TypeError}, + types::{TableType, Type}, + utils::range::Range, +}; + +pub enum Accessor { + String(String), + Number(usize), +} + +impl<'a> Checker<'a> { + pub fn check_index_expression(&mut self, table_expr: &ast::IndexExpression) -> Result { + let base_type = self.check_expression(&table_expr.base)?; + let base_range = table_expr.base.get_range(); + + match base_type { + Type::Table(ref table_type) => { + let acc = self.extract_accessor(&table_expr.index)?; + self.check_index_access(table_type, acc, table_expr.index.get_range()) + } + _ => Err(self.create_diagnostic(TypeError::ExpectedTable(base_type.to_string(), Some(base_range)))), + } + } + + fn check_index_access(&self, table: &TableType, acc: Option, range: Range) -> Result { + match acc { + Some(Accessor::String(name)) => self.check_index_access_string(table, &name, range), + Some(Accessor::Number(index)) => self.check_index_access_number(table, index, range), + // todo: return union type based on table values + None => Ok(Type::Unknown), + } + } + + fn check_index_access_string(&self, table: &TableType, name: &str, range: Range) -> Result { + if let Some(value_type) = table.get_type(name) { + Ok(value_type.clone()) + } else { + Err(self.create_diagnostic(TypeError::KeyNotFoundInTable(name.to_string(), table.to_string(), Some(range)))) + } + } + + fn check_index_access_number(&self, table: &TableType, index: usize, range: Range) -> Result { + if let Some(element_type) = table.get_array_type(index) { + Ok(element_type.clone()) + } else { + Err(self.create_diagnostic(TypeError::KeyNotFoundInTable(index.to_string(), table.to_string(), Some(range)))) + } + } + + fn extract_accessor(&mut self, index_expr: &ast::Expression) -> Result, Diagnostic> { + match index_expr { + ast::Expression::Literal(literal) => match literal { + ast::LiteralExpression::String(string) => Ok(Some(Accessor::String(string.value.clone()))), + ast::LiteralExpression::Number(number) => { + let number = number.value.parse::(); + // todo: create diagnostic ?? + return if number.is_err() { Ok(None) } else { Ok(Some(Accessor::Number(number.unwrap()))) }; + } + _ => self.handle_non_literal_index(index_expr), + }, + _ => self.handle_non_literal_index(index_expr), + } + } + + fn handle_non_literal_index(&mut self, index_expr: &ast::Expression) -> Result, Diagnostic> { + self.check_expression_index(index_expr)?; + Ok(None) + } + + fn check_expression_index(&mut self, key_expr: &ast::Expression) -> Result<(), Diagnostic> { + let expr_type = self.check_expression(key_expr)?; + match expr_type { + Type::String | Type::Number => Ok(()), + _ => Err( + self.create_diagnostic(TypeError::MismatchedAccessorType(expr_type.to_string(), Some(key_expr.get_range()))), + ), + } + } +} diff --git a/src/checker/check_member_expression.rs b/src/checker/check_member_expression.rs new file mode 100644 index 0000000..8b4f154 --- /dev/null +++ b/src/checker/check_member_expression.rs @@ -0,0 +1,68 @@ +use super::Checker; +use crate::{ + ast::ast, + diagnostics::{Diagnostic, TypeError}, + types::{TableType, Type}, + utils::range::Range, +}; + +impl<'a> Checker<'a> { + pub fn check_member_expression(&mut self, member: &ast::MemberExpression) -> Result { + let base_type = self.check_expression(&member.base)?; + let base_range = member.base.get_range(); + + match base_type { + Type::Table(ref table_type) => self.check_member(table_type, &member.member), + _ => Err(self.create_diagnostic(TypeError::ExpectedTable(base_type.to_string(), Some(base_range)))), + } + } + + fn check_member(&mut self, table: &TableType, member: &ast::Expression) -> Result { + match member { + ast::Expression::Literal(literal) => { + if let ast::LiteralExpression::String(string) = literal { + self.check_identifier_member(&string.value, table, string.range.clone()) + } else { + Err(self.create_member_error(member)) + } + } + ast::Expression::Call(call) => self.check_call_member(call, table), + ast::Expression::Identifier(identifier) => { + self.check_identifier_member(&identifier.name, table, identifier.range.clone()) + } + _ => Err(self.create_member_error(member)), + } + } + + fn check_identifier_member(&mut self, name: &str, table: &TableType, range: Range) -> Result { + if let Some(member_type) = table.get_type(name) { + Ok(member_type.clone()) + } else { + Err(self.create_not_found_key_error(name, table, range)) + } + } + + fn create_member_error(&mut self, member: &ast::Expression) -> Diagnostic { + let range = member.get_range(); + match self.check_expression(member) { + Ok(member_type) => { + let diagnostic = TypeError::MismatchedKeyType(member_type.to_string(), Some(range)); + self.create_diagnostic(diagnostic) + } + Err(diagnostic) => diagnostic, + } + } + + fn create_not_found_key_error(&self, name: &str, table: &TableType, range: Range) -> Diagnostic { + self.create_diagnostic(TypeError::KeyNotFoundInTable(name.to_string(), table.to_string(), Some(range))) + } + + fn check_call_member(&mut self, call: &ast::CallExpression, table: &TableType) -> Result { + let name = call.name.lexeme(); + if let Some(member_type) = table.get_type(name) { + self.check_call_type(member_type, &call.args, call.name.range.clone()) + } else { + Err(self.create_not_found_key_error(name, table, call.get_range())) + } + } +} diff --git a/src/checker/check_table_expression.rs b/src/checker/check_table_expression.rs new file mode 100644 index 0000000..7949706 --- /dev/null +++ b/src/checker/check_table_expression.rs @@ -0,0 +1,51 @@ +use std::collections::BTreeMap; + +use super::Checker; +use crate::{ + ast::ast, + diagnostics::{Diagnostic, TypeError}, + types::Type, +}; + +impl<'a> Checker<'a> { + pub fn check_table_expression(&mut self, table_expr: &ast::TableExpression) -> Result { + let mut array_elements = vec![]; + let mut map_elements = BTreeMap::new(); + + for (key_expr, value_expr) in &table_expr.values { + if let Some(value_expr) = value_expr { + let value_type = self.check_expression(value_expr)?; + let key_str = self.extract_table_key(key_expr)?; + map_elements.insert(key_str, value_type); + } else { + let array_element_type = self.check_expression(key_expr)?; + array_elements.push(array_element_type); + } + } + + let table_type = Type::new_table( + if array_elements.is_empty() { None } else { Some(array_elements) }, + if map_elements.is_empty() { None } else { Some(map_elements) }, + ); + + Ok(table_type) + } + + fn extract_table_key(&mut self, key_expr: &ast::Expression) -> Result { + match key_expr { + ast::Expression::Identifier(identifier) => Ok(identifier.name.clone()), + ast::Expression::Literal(literal) => match literal { + ast::LiteralExpression::String(string) => Ok(string.value.clone()), + _ => self.create_invalid_literal_key_error(key_expr), + }, + _ => self.create_invalid_literal_key_error(key_expr), + } + } + + fn create_invalid_literal_key_error(&mut self, key_expr: &ast::Expression) -> Result { + let range = key_expr.get_range(); + let expr_type = self.check_expression(key_expr)?; + let diagnostic = TypeError::MismatchedKeyType(expr_type.to_string(), Some(range)); + Err(self.create_diagnostic(diagnostic)) + } +} diff --git a/src/checker/check_type.rs b/src/checker/check_type.rs index 9c62a98..5b93670 100644 --- a/src/checker/check_type.rs +++ b/src/checker/check_type.rs @@ -1,11 +1,13 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use super::Checker; use crate::{ diagnostics::{Diagnostic, TypeError}, - types::{self, Type}, + types::{self, GenericCallType, TableType, Type}, }; +type GenericBinds = HashMap; + impl<'a> Checker<'a> { pub fn check_optional_type(&mut self, ty: &Option, assume_nil: bool) -> Result { match ty { @@ -43,11 +45,37 @@ impl<'a> Checker<'a> { } } - pub fn create_generic_table(&self, types: &[Type], variables: &[String]) -> HashMap { + pub fn create_generic_table(&self, types: &[Type], variables: &[String]) -> GenericBinds { variables.iter().cloned().zip(types.iter().cloned()).collect() } - pub fn apply_generic_binds(&self, generic_value: &Type, binds: &HashMap) -> Result { + pub fn apply_generic_bind_table(&self, table: &TableType, binds: &GenericBinds) -> Result { + let array = table + .array + .as_ref() + .map(|array| array.iter().map(|ty| self.apply_generic_binds(ty, binds)).collect::, _>>()) + .transpose()?; + + let map = table + .map + .as_ref() + .map(|map| { + map + .iter() + .map(|(key, ty)| Ok((key.clone(), self.apply_generic_binds(ty, binds)?))) + .collect::, Diagnostic>>() + }) + .transpose()?; + + Ok(Type::Table(TableType { array, map })) + } + + pub fn apply_generic_bind_call(&self, call: &GenericCallType, binds: &GenericBinds) -> Result { + let types = call.types.iter().map(|ty| self.apply_generic_binds(ty, binds)).collect::, _>>()?; + Ok(Type::GenericCall(types::GenericCallType { name: call.name.clone(), types, range: call.range.clone() })) + } + + pub fn apply_generic_binds(&self, generic_value: &Type, binds: &GenericBinds) -> Result { match generic_value { Type::Identifier(identifier) => { if let Some(bound_type) = binds.get(&identifier.name) { @@ -62,11 +90,7 @@ impl<'a> Checker<'a> { let return_type = self.apply_generic_binds(&function.return_type, binds)?; Ok(Type::Function(types::FunctionType { params, return_type: Box::new(return_type) })) } - Type::Table(table) => { - let key_type = self.apply_generic_binds(&table.key_type, binds)?; - let value_type = self.apply_generic_binds(&table.value_type, binds)?; - Ok(Type::Table(types::TableType { key_type: Box::new(key_type), value_type: Box::new(value_type) })) - } + Type::Table(table) => self.apply_generic_bind_table(table, binds), Type::Union(union) => { let types = union.types.iter().map(|ty| self.apply_generic_binds(ty, binds)).collect::, _>>()?; Ok(Type::Union(types::UnionType { types })) @@ -79,16 +103,7 @@ impl<'a> Checker<'a> { let types = group.types.iter().map(|ty| self.apply_generic_binds(ty, binds)).collect::, _>>()?; Ok(Type::Grup(types::GrupType { types })) } - Type::GenericCall(generic_call) => { - let types = - generic_call.types.iter().map(|ty| self.apply_generic_binds(ty, binds)).collect::, _>>()?; - - Ok(Type::GenericCall(types::GenericCallType { - name: generic_call.name.clone(), - types, - range: generic_call.range.clone(), - })) - } + Type::GenericCall(generic_call) => self.apply_generic_bind_call(generic_call, binds), _ => Ok(generic_value.clone()), } } @@ -108,7 +123,7 @@ impl<'a> Checker<'a> { }) } - pub fn create_generic_table_type(&self, generics: &[String], inferred: &[Type]) -> HashMap { + pub fn create_generic_table_type(&self, generics: &[String], inferred: &[Type]) -> GenericBinds { generics.iter().cloned().zip(inferred.iter().cloned()).collect() } } diff --git a/src/checker/mod.rs b/src/checker/mod.rs index b89c97b..74cfcbe 100644 --- a/src/checker/mod.rs +++ b/src/checker/mod.rs @@ -12,11 +12,14 @@ pub mod check_function_statement; pub mod check_grouped_expression; pub mod check_identifier; pub mod check_if_statement; +pub mod check_index_expression; pub mod check_literal_expression; +pub mod check_member_expression; pub mod check_repeat_statement; pub mod check_require_expression; pub mod check_return_statement; pub mod check_statement; +pub mod check_table_expression; pub mod check_type; pub mod check_type_declaration; pub mod check_unary_expression; diff --git a/src/diagnostics/format.rs b/src/diagnostics/format.rs index 8d39be9..3b1a8ed 100644 --- a/src/diagnostics/format.rs +++ b/src/diagnostics/format.rs @@ -5,6 +5,9 @@ pub fn format_mismatched_types(expected: &str, found: &str) -> String { pub fn format_expected_function(name: &str) -> String { format!("expected function, found `{}`", name) } +pub fn format_expected_table(message: &str) -> String { + format!("expected table, found `{}`", message) +} pub fn format_undeclared_variable(name: &str) -> String { format!("cannot find value `{}` in this scope", name) @@ -14,6 +17,17 @@ pub fn format_undeclared_type(name: &str) -> String { format!("cannot find type `{}` in this scope", name) } +pub fn format_mismatched_key_type(key: &str) -> String { + format!("expected `string` key, but found `{}`", key) +} + +pub fn format_mismatched_accessor_type(index: &str) -> String { + format!("expected `number` or `string`, but found `{}`", index) +} +pub fn format_no_field(base: &str, member: &str) -> String { + format!("no field `{}` on type `{}`", member, base) +} + pub fn format_function_arity_mismatch(expected: usize, found: usize) -> String { if expected > 1 { format!("expected `{}` arguments, found `{}`", expected, found) @@ -22,6 +36,14 @@ pub fn format_function_arity_mismatch(expected: usize, found: usize) -> String { } } +pub fn format_cannot_index_non_array(type_name: &str) -> String { + format!("cannot index into non-array type `{}`", type_name) +} + +pub fn format_key_not_found_in_table(key: &str, table: &str) -> String { + format!("key `{}` not found in table `{}`", key, table) +} + pub fn format_unsupported_operator(left: &str, right: &str, oper: &str) -> String { format!("unsupported operator `{}` for `{}` and `{}`", oper, left, right) } diff --git a/src/diagnostics/mod.rs b/src/diagnostics/mod.rs index d57fe87..f838677 100644 --- a/src/diagnostics/mod.rs +++ b/src/diagnostics/mod.rs @@ -10,8 +10,9 @@ use report::report_error; use std::fmt::{self, Debug}; use format::{ - format_expected_function, format_function_arity_mismatch, format_mismatched_types, - format_missing_variable_declaration, format_module_not_exported, format_module_not_found, + format_cannot_index_non_array, format_expected_function, format_expected_table, format_function_arity_mismatch, + format_key_not_found_in_table, format_mismatched_accessor_type, format_mismatched_key_type, format_mismatched_types, + format_missing_variable_declaration, format_module_not_exported, format_module_not_found, format_no_field, format_redeclared_in_same_scope, format_type_mismatch_assignment, format_undeclared_type, format_undeclared_variable, format_unsupported_operator, format_warning_shadow_warning, format_warning_unused_variable, }; @@ -126,8 +127,13 @@ pub enum TypeError { FunctionArityMismatch(usize, usize, Option), UnsupportedOperator(String, String, String, Option), MissingVariableDeclaration(Option), + ExpectedTable(String, Option), + MismatchedKeyType(String, Option), + NoField(String, String, Option), + CantIndexNonArray(String, Option), + KeyNotFoundInTable(String, String, Option), + MismatchedAccessorType(String, Option), } - impl From for Diagnostic { fn from(error: TypeError) -> Self { let (message, range) = match error { @@ -144,6 +150,12 @@ impl From for Diagnostic { } TypeError::UndeclaredType(name, rg) => (format_undeclared_type(&name), rg), TypeError::ExpectedFunction(name, rg) => (format_expected_function(&name), rg), + TypeError::NoField(base, member, rg) => (format_no_field(&base, &member), rg), + TypeError::MismatchedKeyType(key, rg) => (format_mismatched_key_type(&key), rg), + TypeError::ExpectedTable(message, rg) => (format_expected_table(&message), rg), + TypeError::CantIndexNonArray(type_name, rg) => (format_cannot_index_non_array(&type_name), rg), + TypeError::KeyNotFoundInTable(key, table, rg) => (format_key_not_found_in_table(&key, &table), rg), + TypeError::MismatchedAccessorType(index, rg) => (format_mismatched_accessor_type(&index), rg), }; Diagnostic::new(DiagnosticLevel::Error, message, range) diff --git a/src/parser/parser.rs b/src/parser/parser.rs index 06549f5..da86967 100644 --- a/src/parser/parser.rs +++ b/src/parser/parser.rs @@ -12,8 +12,6 @@ pub struct Parser<'a> { raw: &'a str, } -type ParseExpression = fn(&mut Parser) -> ast::Expression; - impl<'a> Parser<'a> { pub fn new(raw: &'a str, file_name: &'a str) -> Self { Self { lexer: Lexer::new(raw, file_name), raw } @@ -45,13 +43,14 @@ impl<'a> Parser<'a> { TokenKind::Return => self.parse_return_statement(), TokenKind::Function => self.parse_function_statement(false), TokenKind::Type => self.parse_type_declaration(), - _ => self.parse_assign_or_call(), + _ => self.parse_assign_expression_or_call_expression(), }; - self.match_token_and_consume(TokenKind::Semicolon); statement } + // declarations + // fn parse_local_declaration(&mut self) -> ast::Statement { self.consume_expect_token(TokenKind::Local); if self.match_token(&TokenKind::Function) { @@ -60,30 +59,127 @@ impl<'a> Parser<'a> { self.parse_variable_declaration(true) } - fn parse_expression_statement(&mut self) -> ast::Expression { - self.parse_or_expression() + fn parse_type_declaration(&mut self) -> ast::Statement { + let range = self.consume_expect_token(TokenKind::Type).range.clone(); + let name = self.consume_token(); + let generics = self.parse_simple_generic_str(); + self.consume_expect_token(TokenKind::Assign); + // we can accept function type + let initilizer = self.parse_type(false); + ast::Statement::TypeDeclaration(ast::TypeDeclaration::new(name, generics, initilizer, range)) + } + fn parse_variable_declaration(&mut self, local: bool) -> ast::Statement { + let valyes = self.parse_variable_and_type(None); + let mut initializer = None; + if self.match_token_and_consume(TokenKind::Assign).is_some() { + initializer = Some(self.parse_expression_statement()); + } + ast::Statement::VariableDeclaration(ast::VariableDeclaration::new(valyes, local, initializer)) } - fn parse_or_expression(&mut self) -> ast::Expression { - // self.parse_binary_expression(Self::parse_and_expression, vec![BinaryOperator::Or]) - self.parse_binary_expression(Self::parse_and_expression) + fn parse_function_statement(&mut self, local: bool) -> ast::Statement { + let start = self.consume_expect_token(TokenKind::Function).range; + let name = self.consume_token(); + let generics = self.parse_generic_type(); + self.consume_expect_token(TokenKind::LeftParen); + let arguments = self.parse_arguments_with_option_type(); + self.consume_expect_token(TokenKind::RightParen); + + let mut return_type = None; + + let mut range_return_type = None; + if self.match_token(&TokenKind::Colon) { + self.consume_expect_token(TokenKind::Colon); + range_return_type = Some(self.lexer.peek_token().range.clone()); + return_type = Some(self.parse_type(true)) + } + let body = self.parse_block_statement(&[TokenKind::End]); + let end_range = self.consume_expect_token(TokenKind::End).range; + let range = create_middle_range(&start, &end_range); + ast::Statement::Function(ast::FunctionStatement::new( + name, + local, + generics, + arguments, + return_type, + body, + range, + range_return_type, + )) } - fn parse_and_expression(&mut self) -> ast::Expression { - // self.parse_binary_expression(Self::parse_unary_expression, vec![BinaryOperator::And]) - self.parse_binary_expression(Self::parse_unary_expression) + // flow control + // + + fn parse_if_statement(&mut self) -> ast::Statement { + let start_range = self.consume_expect_token(TokenKind::If).range; + let condition = self.parse_expression_statement(); + self.consume_expect_token(TokenKind::Then); + let then_body = self.parse_block_statement(&[TokenKind::Else, TokenKind::End]); + let else_body = if self.match_token_and_consume(TokenKind::Else).is_some() { + Some(self.parse_block_statement(&[TokenKind::End])) + } else { + None + }; + let end_range = self.consume_expect_token(TokenKind::End).range; + let range = create_middle_range(&start_range, &end_range); + ast::Statement::If(ast::IfStatement::new(condition, then_body, else_body, range)) } - fn parse_binary_expression(&mut self, parse_sub_expression: fn(&mut Self) -> ast::Expression) -> ast::Expression { - let mut expression = parse_sub_expression(self); - while let Some((operator, range)) = self.parse_operator() { - let right_expression = parse_sub_expression(self); - let binary_exp = ast::BinaryExpression::new(operator, Box::new(expression), Box::new(right_expression), range); - expression = ast::Expression::Binary(binary_exp); - } - expression + fn parse_while_statement(&mut self) -> ast::Statement { + let while_token = self.consume_expect_token(TokenKind::While); + let condition = self.parse_expression_statement(); + self.consume_expect_token(TokenKind::Do); + let body = self.parse_block_statement(&[TokenKind::End]); + self.consume_expect_token(TokenKind::End); + ast::Statement::While(ast::WhileStatement::new(condition, body, while_token.range)) + } + + fn parse_repeat_statement(&mut self) -> ast::Statement { + let repeat_token = self.consume_expect_token(TokenKind::Repeat); + let body = self.parse_block_statement(&[TokenKind::Until]); + self.consume_expect_token(TokenKind::Until); + let condition = self.parse_expression_statement(); + ast::Statement::Repeat(ast::RepeatStatement::new(body, condition, repeat_token.range)) + } + + fn parse_for_statement(&mut self) -> ast::Statement { + let start_range = self.consume_expect_token(TokenKind::For).range; + let variable = self.parse_identifier(); + self.consume_expect_token(TokenKind::Assign); + let init = self.parse_expression_statement(); + self.consume_expect_token(TokenKind::Comma); + let limit = self.parse_expression_statement(); + let step = if self.match_token_and_consume(TokenKind::Comma).is_some() { + Some(self.parse_expression_statement()) + } else { + None + }; + self.consume_expect_token(TokenKind::Do); + let body = self.parse_block_statement(&[TokenKind::End]); + let end_range = self.consume_expect_token(TokenKind::End).range; + let range = create_middle_range(&start_range, &end_range); + ast::Statement::For(ast::ForStatement::new(variable, init, limit, step, body, range)) + } + + fn parse_break_statement(&mut self) -> ast::Statement { + let break_token = self.consume_expect_token(TokenKind::Break); + ast::Statement::Break(ast::BreakStatement::new(break_token.range)) + } + + fn parse_continue_statement(&mut self) -> ast::Statement { + // Implementação ausente + unimplemented!() + } + + fn parse_return_statement(&mut self) -> ast::Statement { + let return_token = self.consume_expect_token(TokenKind::Return); + let values = self.parse_return_value(); + ast::Statement::Return(ast::ReturnStatement::new(values, return_token.range)) } + // expressions + // fn parse_function_expression(&mut self) -> ast::Expression { let start_range = self.consume_expect_token(TokenKind::Function).range; self.consume_expect_token(TokenKind::LeftParen); @@ -103,28 +199,25 @@ impl<'a> Parser<'a> { ast::Expression::new_function(arguments, return_type, body, range, range_return_type) } - fn parse_operator(&mut self) -> Option<(ast::BinaryOperator, Range)> { - let token = self.lexer.peek_token(); - let operator = match token.kind { - TokenKind::Plus => ast::BinaryOperator::Add, - TokenKind::Minus => ast::BinaryOperator::Subtract, - TokenKind::Star => ast::BinaryOperator::Multiply, - TokenKind::Slash => ast::BinaryOperator::Divide, - TokenKind::Percent => ast::BinaryOperator::Modulus, - TokenKind::And => ast::BinaryOperator::And, - TokenKind::Or => ast::BinaryOperator::Or, - TokenKind::Equal => ast::BinaryOperator::Equal, - TokenKind::NotEqual => ast::BinaryOperator::NotEqual, - TokenKind::Less => ast::BinaryOperator::LessThan, - TokenKind::Greater => ast::BinaryOperator::GreaterThan, - TokenKind::LessEqual => ast::BinaryOperator::LessThanOrEqual, - TokenKind::GreaterEqual => ast::BinaryOperator::GreaterThanOrEqual, - TokenKind::DoubleDot => ast::BinaryOperator::DoubleDot, - TokenKind::DoubleSlash => ast::BinaryOperator::DoubleSlash, - _ => return None, - }; - self.lexer.next_token(); - return Some((operator, token.range)); + + fn parse_or_expression(&mut self) -> ast::Expression { + // self.parse_binary_expression(Self::parse_and_expression, vec![BinaryOperator::Or]) + self.parse_binary_expression(Self::parse_and_expression) + } + + fn parse_and_expression(&mut self) -> ast::Expression { + // self.parse_binary_expression(Self::parse_unary_expression, vec![BinaryOperator::And]) + self.parse_binary_expression(Self::parse_unary_expression) + } + + fn parse_binary_expression(&mut self, parse_sub_expression: fn(&mut Self) -> ast::Expression) -> ast::Expression { + let mut expression = parse_sub_expression(self); + while let Some((operator, range)) = self.parse_operator() { + let right_expression = parse_sub_expression(self); + let binary_exp = ast::BinaryExpression::new(operator, Box::new(expression), Box::new(right_expression), range); + expression = ast::Expression::Binary(binary_exp); + } + expression } fn parse_unary_expression(&mut self) -> ast::Expression { @@ -137,30 +230,69 @@ impl<'a> Parser<'a> { self.parse_primary_expression() } } - - fn parse_unary_operator(&mut self) -> Option { - let token = self.lexer.peek_token(); - let operator = match token.kind { - TokenKind::Minus => ast::UnaryOperator::Negate, - TokenKind::Not => ast::UnaryOperator::Not, - TokenKind::Hash => ast::UnaryOperator::Hash, - _ => return None, - }; - self.lexer.next_token(); - Some(operator) - } - fn parse_primary_expression(&mut self) -> ast::Expression { let token = self.lexer.peek_token(); - match token.kind { + let mut expression = match token.kind { TokenKind::Number(_) | TokenKind::String(_) => self.parse_literal_expression(), TokenKind::True | TokenKind::False => self.parse_literal_expression(), TokenKind::Identifier(_) => self.parse_identifier(), TokenKind::LeftParen => self.parse_grouped_expression(), TokenKind::Require => self.parse_require_expression(), TokenKind::Function => self.parse_function_expression(), + TokenKind::LeftBrace => self.parse_table_expression(), _ => self.report_unexpected_token(token), + }; + + while self.match_token(&TokenKind::Dot) || self.match_token(&TokenKind::LeftBracket) { + if self.match_token(&TokenKind::Dot) { + expression = self.parse_member_expression(expression); + } else if self.match_token(&TokenKind::LeftBracket) { + expression = self.parse_index_expression(expression); + } } + expression + } + + fn parse_member_expression(&mut self, base: ast::Expression) -> ast::Expression { + self.consume_expect_token(TokenKind::Dot); // consume '.' + let member_expression = self.parse_expression_statement(); + ast::Expression::new_member(base, member_expression) + } + + fn parse_index_expression(&mut self, base: ast::Expression) -> ast::Expression { + let start_range = self.consume_expect_token(TokenKind::LeftBracket).range; // consume '[' + let index_expression = self.parse_expression_statement(); + let end_range = self.consume_expect_token(TokenKind::RightBracket).range; // consume ']' + + let bracket_range = create_middle_range(&start_range, &end_range); + + ast::Expression::Index(ast::IndexExpression { + base: Box::new(base), + index: Box::new(index_expression), + bracket_range, + }) + } + + fn parse_table_expression(&mut self) -> ast::Expression { + let left_range = self.consume_expect_token(TokenKind::LeftBrace).range; + let mut values = vec![]; + while !self.match_token(&TokenKind::RightBrace) { + let value_or_key = self.parse_expression_statement(); + if self.match_token_and_consume(TokenKind::Assign).is_some() { + let value = self.parse_expression_statement(); + values.push((value_or_key, Some(value))); + } else { + values.push((value_or_key, None)); + } + if self.match_token(&TokenKind::RightBrace) { + break; + } + // skip comma + self.consume_expect_token(TokenKind::Comma); + } + let right_range = self.consume_expect_token(TokenKind::RightBrace).range; + let range = create_middle_range(&left_range, &right_range); + ast::Expression::new_table(values, range) } fn parse_literal_expression(&mut self) -> ast::Expression { @@ -174,23 +306,13 @@ impl<'a> Parser<'a> { } } - fn parse_identifier(&mut self) -> ast::Expression { - let token = self.lexer.next_token(); - if self.match_token(&TokenKind::LeftParen) { - return self.parse_call_expression(token); - } - match token.kind { - TokenKind::Identifier(name) => ast::Expression::new_identifier(name, token.range), - _ => self.report_unexpected_token(token), - } - } - fn parse_grouped_expression(&mut self) -> ast::Expression { let left_range = self.consume_expect_token(TokenKind::LeftParen).range; let mut expressions = vec![]; - if self.match_token_and_consume(TokenKind::RightParen).is_some() { - return ast::Expression::new_grouped(expressions, left_range); + if self.match_token(&TokenKind::RightParen) { + let right_range = self.consume_expect_token(TokenKind::RightParen).range; + return ast::Expression::new_grouped(expressions, create_middle_range(&left_range, &right_range)); } expressions.push(self.parse_expression_statement()); @@ -212,17 +334,39 @@ impl<'a> Parser<'a> { ast::Expression::new_require(module_name, range) } - fn parse_assign_or_call(&mut self) -> ast::Statement { + fn parse_assign_expression_or_call_expression(&mut self) -> ast::Statement { let token = self.lexer.peek_token(); match token.kind { - TokenKind::Identifier(_) => self.parse_possible_assign_or_call(token), + TokenKind::Identifier(_) => self.parse_assign_statement_or_call_statement(token), + TokenKind::LeftBrace => { + let table_expression = self.parse_table_expression(); + ast::Statement::Expression(table_expression) + } _ => self.report_unexpected_token(token), } } - fn parse_possible_assign_or_call(&mut self, ident_token: Token) -> ast::Statement { + fn parse_call_expression(&mut self, ident: Token) -> ast::Expression { + let args = self.parse_grouped_expression(); + ast::Expression::new_call(ident, args) + } + + // statements + // + fn parse_block_statement(&mut self, end_tokens: &[TokenKind]) -> ast::Statement { + let mut statements = Vec::new(); + // skip block comments + self.skip_comments(); + while !self.contains_token(end_tokens) { + statements.push(self.parse_statement()); + // skip block comments + self.skip_comments(); + } + ast::Statement::Block(ast::BlockStatement::new(statements)) + } + + fn parse_assign_statement_or_call_statement(&mut self, ident_token: Token) -> ast::Statement { self.lexer.next_token(); // consume identifier - // if next is `=` or `:` then it's an assign statement if self.match_token(&TokenKind::Assign) || self.match_token(&TokenKind::Colon) { self.parse_assign_statement(ident_token) } else if self.match_token(&TokenKind::LeftParen) { @@ -250,92 +394,14 @@ impl<'a> Parser<'a> { ast::Statement::Assign(ast::AssignStatement::new(values, init)) } - fn parse_variable_and_type(&mut self, token: Option) -> Vec<(Token, Option)> { - let mut names: Vec<(Token, Option)> = Vec::new(); - - let token = if token.is_some() { token.unwrap() } else { self.consume_token() }; - - let mut name = (token, None); - if self.match_token_and_consume(TokenKind::Colon).is_some() { - name.1 = Some(self.parse_type(false)); - }; - names.push(name); - while self.match_token_and_consume(TokenKind::Comma).is_some() { - name = (self.consume_token(), None); - if self.match_token_and_consume(TokenKind::Colon).is_some() { - name.1 = Some(self.parse_type(false)); - }; - names.push(name); - } - names - } - - fn parse_call_expression(&mut self, ident: Token) -> ast::Expression { - let args = self.parse_grouped_expression(); - ast::Expression::new_call(ident, args) - } - - fn parse_function_statement(&mut self, local: bool) -> ast::Statement { - let start = self.consume_expect_token(TokenKind::Function).range; - let name = self.consume_token(); - let generics = self.parse_generic_type(); - self.consume_expect_token(TokenKind::LeftParen); - let arguments = self.parse_arguments_with_option_type(); - self.consume_expect_token(TokenKind::RightParen); - - let mut return_type = None; - - let mut range_return_type = None; - if self.match_token(&TokenKind::Colon) { - self.consume_expect_token(TokenKind::Colon); - range_return_type = Some(self.lexer.peek_token().range.clone()); - return_type = Some(self.parse_type(true)) - } - let body = self.parse_block_statement(&[TokenKind::End]); - let end_range = self.consume_expect_token(TokenKind::End).range; - let range = create_middle_range(&start, &end_range); - ast::Statement::Function(ast::FunctionStatement::new( - name, - local, - generics, - arguments, - return_type, - body, - range, - range_return_type, - )) + fn parse_expression_statement(&mut self) -> ast::Expression { + self.parse_or_expression() } - fn parse_type(&mut self, excepted_paren: bool) -> Type { - let token = self.lexer.peek_token(); - if !excepted_paren && token.kind == TokenKind::LeftParen { - self.report_unexpected_token(token) - } - match token.kind { - TokenKind::Identifier(name) => { - self.lexer.next_token(); - if self.match_token(&TokenKind::Less) { - self.consume_expect_token(TokenKind::Less); - let mut types = Vec::new(); - while !self.match_token(&TokenKind::Greater) { - let ty = self.parse_type(false); - types.push(ty); - self.match_token_and_consume(TokenKind::Comma); - } - let right_range = self.consume_expect_token(TokenKind::Greater).range; - let range = create_middle_range(&token.range, &right_range); - return Type::new_generic_call(name, types, range); - } + // utils functions + // - Type::new_type(&name, token.range) - } - TokenKind::LeftParen => self.parse_grup_return_type(), - TokenKind::Function => self.parse_function_type(), - _ => self.report_unexpected_token(token), - } - } - - pub fn parse_arguments_with_function_type(&mut self) -> Vec { + pub fn parse_arguments_with_type(&mut self) -> Vec { let mut arguments = Vec::new(); while !self.match_token(&TokenKind::RightParen) { // todo: what I should do here? @@ -351,7 +417,7 @@ impl<'a> Parser<'a> { pub fn parse_function_type(&mut self) -> Type { self.consume_expect_token(TokenKind::Function); self.consume_expect_token(TokenKind::LeftParen); - let params = self.parse_arguments_with_function_type(); + let params = self.parse_arguments_with_type(); self.consume_expect_token(TokenKind::RightParen); self.consume_expect_token(TokenKind::Colon); let return_type = self.parse_type(true); @@ -415,113 +481,114 @@ impl<'a> Parser<'a> { generics } - fn parse_type_declaration(&mut self) -> ast::Statement { - let range = self.consume_expect_token(TokenKind::Type).range.clone(); - let name = self.consume_token(); - let generics = self.parse_simple_generic_str(); - self.consume_expect_token(TokenKind::Assign); - // we can accept function type - let initilizer = self.parse_type(false); - ast::Statement::TypeDeclaration(ast::TypeDeclaration::new(name, generics, initilizer, range)) - } - fn parse_variable_declaration(&mut self, local: bool) -> ast::Statement { - let valyes = self.parse_variable_and_type(None); - let mut initializer = None; - if self.match_token_and_consume(TokenKind::Assign).is_some() { - initializer = Some(self.parse_expression_statement()); + fn parse_return_value(&mut self) -> Vec { + let mut values = Vec::new(); + if self.match_token(&TokenKind::Semicolon) { + return values; } - ast::Statement::VariableDeclaration(ast::VariableDeclaration::new(valyes, local, initializer)) - } - - fn parse_block_statement(&mut self, end_tokens: &[TokenKind]) -> ast::Statement { - let mut statements = Vec::new(); - // skip block comments - self.skip_comments(); - while !self.contains_token(end_tokens) { - statements.push(self.parse_statement()); - // skip block comments - self.skip_comments(); + values.push(self.parse_expression_statement()); + while self.match_token_and_consume(TokenKind::Comma).is_some() { + values.push(self.parse_expression_statement()); } - ast::Statement::Block(ast::BlockStatement::new(statements)) + values } - fn parse_if_statement(&mut self) -> ast::Statement { - let start_range = self.consume_expect_token(TokenKind::If).range; - let condition = self.parse_expression_statement(); - self.consume_expect_token(TokenKind::Then); - let then_body = self.parse_block_statement(&[TokenKind::Else, TokenKind::End]); - let else_body = if self.match_token_and_consume(TokenKind::Else).is_some() { - Some(self.parse_block_statement(&[TokenKind::End])) - } else { - None + fn parse_operator(&mut self) -> Option<(ast::BinaryOperator, Range)> { + let token = self.lexer.peek_token(); + let operator = match token.kind { + TokenKind::Plus => ast::BinaryOperator::Add, + TokenKind::Minus => ast::BinaryOperator::Subtract, + TokenKind::Star => ast::BinaryOperator::Multiply, + TokenKind::Slash => ast::BinaryOperator::Divide, + TokenKind::Percent => ast::BinaryOperator::Modulus, + TokenKind::And => ast::BinaryOperator::And, + TokenKind::Or => ast::BinaryOperator::Or, + TokenKind::Equal => ast::BinaryOperator::Equal, + TokenKind::NotEqual => ast::BinaryOperator::NotEqual, + TokenKind::Less => ast::BinaryOperator::LessThan, + TokenKind::Greater => ast::BinaryOperator::GreaterThan, + TokenKind::LessEqual => ast::BinaryOperator::LessThanOrEqual, + TokenKind::GreaterEqual => ast::BinaryOperator::GreaterThanOrEqual, + TokenKind::DoubleDot => ast::BinaryOperator::DoubleDot, + TokenKind::DoubleSlash => ast::BinaryOperator::DoubleSlash, + _ => return None, }; - let end_range = self.consume_expect_token(TokenKind::End).range; - let range = create_middle_range(&start_range, &end_range); - ast::Statement::If(ast::IfStatement::new(condition, then_body, else_body, range)) + self.lexer.next_token(); + return Some((operator, token.range)); } - fn parse_while_statement(&mut self) -> ast::Statement { - let while_token = self.consume_expect_token(TokenKind::While); - let condition = self.parse_expression_statement(); - self.consume_expect_token(TokenKind::Do); - let body = self.parse_block_statement(&[TokenKind::End]); - self.consume_expect_token(TokenKind::End); - ast::Statement::While(ast::WhileStatement::new(condition, body, while_token.range)) + fn parse_unary_operator(&mut self) -> Option { + let token = self.lexer.peek_token(); + let operator = match token.kind { + TokenKind::Minus => ast::UnaryOperator::Negate, + TokenKind::Not => ast::UnaryOperator::Not, + TokenKind::Hash => ast::UnaryOperator::Hash, + _ => return None, + }; + self.lexer.next_token(); + Some(operator) } - fn parse_repeat_statement(&mut self) -> ast::Statement { - let repeat_token = self.consume_expect_token(TokenKind::Repeat); - let body = self.parse_block_statement(&[TokenKind::Until]); - self.consume_expect_token(TokenKind::Until); - let condition = self.parse_expression_statement(); - ast::Statement::Repeat(ast::RepeatStatement::new(body, condition, repeat_token.range)) - } + fn parse_identifier(&mut self) -> ast::Expression { + let token = self.lexer.next_token(); - fn parse_for_statement(&mut self) -> ast::Statement { - let start_range = self.consume_expect_token(TokenKind::For).range; - let variable = self.parse_identifier(); - self.consume_expect_token(TokenKind::Assign); - let init = self.parse_expression_statement(); - self.consume_expect_token(TokenKind::Comma); - let limit = self.parse_expression_statement(); - let step = if self.match_token_and_consume(TokenKind::Comma).is_some() { - Some(self.parse_expression_statement()) - } else { - None - }; - self.consume_expect_token(TokenKind::Do); - let body = self.parse_block_statement(&[TokenKind::End]); - let end_range = self.consume_expect_token(TokenKind::End).range; - let range = create_middle_range(&start_range, &end_range); - ast::Statement::For(ast::ForStatement::new(variable, init, limit, step, body, range)) - } + if self.match_token(&TokenKind::LeftParen) { + return self.parse_call_expression(token); + } - fn parse_break_statement(&mut self) -> ast::Statement { - let break_token = self.consume_expect_token(TokenKind::Break); - ast::Statement::Break(ast::BreakStatement::new(break_token.range)) + match token.kind { + TokenKind::Identifier(name) => ast::Expression::new_identifier(name, token.range), + _ => self.report_unexpected_token(token), + } } - fn parse_continue_statement(&mut self) -> ast::Statement { - // Implementação ausente - unimplemented!() - } + fn parse_variable_and_type(&mut self, token: Option) -> Vec<(Token, Option)> { + let mut names: Vec<(Token, Option)> = Vec::new(); - fn parse_return_statement(&mut self) -> ast::Statement { - let return_token = self.consume_expect_token(TokenKind::Return); - let values = self.parse_return_value(); - ast::Statement::Return(ast::ReturnStatement::new(values, return_token.range)) + let token = if token.is_some() { token.unwrap() } else { self.consume_token() }; + + let mut name = (token, None); + if self.match_token_and_consume(TokenKind::Colon).is_some() { + name.1 = Some(self.parse_type(false)); + }; + names.push(name); + while self.match_token_and_consume(TokenKind::Comma).is_some() { + name = (self.consume_token(), None); + if self.match_token_and_consume(TokenKind::Colon).is_some() { + name.1 = Some(self.parse_type(false)); + }; + names.push(name); + } + names } - fn parse_return_value(&mut self) -> Vec { - let mut values = Vec::new(); - if self.match_token(&TokenKind::Semicolon) { - return values; + fn parse_type(&mut self, excepted_paren: bool) -> Type { + let token = self.lexer.peek_token(); + if !excepted_paren && token.kind == TokenKind::LeftParen { + self.report_unexpected_token(token) } - values.push(self.parse_expression_statement()); - while self.match_token_and_consume(TokenKind::Comma).is_some() { - values.push(self.parse_expression_statement()); + match token.kind { + TokenKind::Identifier(name) => { + self.lexer.next_token(); + if self.match_token(&TokenKind::Less) { + self.consume_expect_token(TokenKind::Less); + let mut types = Vec::new(); + while !self.match_token(&TokenKind::Greater) { + let ty = self.parse_type(false); + types.push(ty); + self.match_token_and_consume(TokenKind::Comma); + } + let right_range = self.consume_expect_token(TokenKind::Greater).range; + let range = create_middle_range(&token.range, &right_range); + return Type::new_generic_call(name, types, range); + } + + Type::new_type(&name, token.range) + } + TokenKind::LeftParen => self.parse_grup_return_type(), + TokenKind::Function => self.parse_function_type(), + _ => self.report_unexpected_token(token), } - values } fn consume_expect_token(&mut self, kind: TokenKind) -> Token { diff --git a/src/types/mod.rs b/src/types/mod.rs index c3171c4..ae5b7e3 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,5 +1,7 @@ #![allow(dead_code, unused_variables)] +use std::collections::BTreeMap; + use crate::{ast::ast::BinaryOperator, utils::range::Range}; use serde::{Deserialize, Serialize}; @@ -27,6 +29,18 @@ impl Type { pub fn is_nil(&self) -> bool { matches!(self, Type::Nil) } + + pub fn is_string(&self) -> bool { + matches!(self, Type::String) + } + + pub fn is_number(&self) -> bool { + matches!(self, Type::Number) + } + + pub fn is_boolean(&self) -> bool { + matches!(self, Type::Boolean) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -59,8 +73,80 @@ pub struct OptionalType { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TableType { - pub key_type: Box, - pub value_type: Box, + pub array: Option>, + pub map: Option>, +} + +impl TableType { + pub fn new(array: Option>, map: Option>) -> Self { + TableType { array, map } + } + + pub fn new_array(array: Vec) -> Self { + TableType { array: Some(array), map: None } + } + + pub fn new_map(map: BTreeMap) -> Self { + TableType { array: None, map: Some(map) } + } + + pub fn is_array(&self) -> bool { + self.array.is_some() + } + + pub fn is_map(&self) -> bool { + self.map.is_some() + } + + pub fn get_type(&self, key: &str) -> Option<&Type> { + if let Some(map) = &self.map { + map.get(key) + } else { + None + } + } + pub fn get_array_len(&self) -> Option { + self.array.as_ref().map(|array| array.len()) + } + + pub fn get_array_type(&self, index: usize) -> Option<&Type> { + self.array.as_ref().and_then(|array| array.get(index)) + } + + pub fn get_array(&self) -> Option<&Vec> { + self.array.as_ref() + } + + pub fn get_map(&self) -> Option<&BTreeMap> { + self.map.as_ref() + } + + pub fn to_string(&self) -> String { + let mut map_str = String::new(); + if let Some(map) = &self.map { + map_str = format!( + "<{}>", + map.iter().map(|(k, v)| format!("{}: {}", k, v.to_string())).collect::>().join(", ") + ); + } + + if let Some(array) = &self.array { + map_str = format!("<{}>", array.iter().map(Type::to_string).collect::>().join(", ")); + } + + format!( + "table{}", + format!( + "{}{}", + map_str, + self + .array + .as_ref() + .map(|array| format!("<{}>", array.iter().map(Type::to_string).collect::>().join(", "))) + .unwrap_or(String::new()) + ) + ) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -243,8 +329,8 @@ impl Type { Type::Boolean } - pub fn new_table(key_type: Type, value_type: Type) -> Self { - Type::Table(TableType { key_type: Box::new(key_type), value_type: Box::new(value_type) }) + pub fn new_table(array: Option>, map: Option>) -> Self { + Type::Table(TableType { array, map }) } pub fn new_function(params: Vec, return_type: Type) -> Type { @@ -273,7 +359,13 @@ fn check_match_union(left: &Vec, right: &Vec) -> bool { } fn check_match_table(left: &TableType, right: &TableType) -> bool { - left.key_type.check_match(&right.key_type) && left.value_type.check_match(&right.value_type) + if let (Some(left_array), Some(right_array)) = (&left.array, &right.array) { + left_array.len() == right_array.len() && left_array.iter().zip(right_array).all(|(l, r)| l.check_match(r)) + } else if let (Some(left_map), Some(right_map)) = (&left.map, &right.map) { + left_map.len() == right_map.len() && left_map.iter().zip(right_map).all(|(l, r)| l.1.check_match(&r.1)) + } else { + false + } } fn check_match_function(left: &FunctionType, right: &FunctionType) -> bool { @@ -312,7 +404,19 @@ fn format_function_type(function: &FunctionType) -> String { } fn format_table_type(table: &TableType) -> String { - format!("table<{}, {}>", table.key_type.to_string(), table.value_type.to_string()) + let mut array_str = String::new(); + let mut map_str = String::new(); + + if let Some(array) = &table.array { + array_str = format!("<{}>", array.iter().map(Type::to_string).collect::>().join(", ")); + } + + if let Some(map) = &table.map { + map_str = + format!("<{}>", map.iter().map(|(k, v)| format!("{}: {}", k, v.to_string())).collect::>().join(", ")); + } + + format!("table{}", format!("{}{}", array_str, map_str)) } fn format_union_type(union: &UnionType) -> String {