diff --git a/config.yml b/config.yml index 9679a62e7..7051fef1b 100644 --- a/config.yml +++ b/config.yml @@ -1,19 +1,23 @@ nodes: - name: RBS::AST::Annotation + rust_name: AnnotationNode fields: - name: string c_type: rbs_string - name: RBS::AST::Bool + rust_name: BoolNode expose_to_ruby: false expose_location: false fields: - name: value c_type: bool - name: RBS::AST::Comment + rust_name: CommentNode fields: - name: string c_type: rbs_string - name: RBS::AST::Declarations::Class + rust_name: ClassNode fields: - name: name c_type: rbs_type_name @@ -21,19 +25,23 @@ nodes: c_type: rbs_node_list - name: super_class c_type: rbs_ast_declarations_class_super + optional: true # NULL when no superclass (e.g., `class Foo end` vs `class Foo < Bar end`) - name: members c_type: rbs_node_list - name: annotations c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Declarations::Class::Super + rust_name: ClassSuperNode fields: - name: name c_type: rbs_type_name - name: args c_type: rbs_node_list - name: RBS::AST::Declarations::ClassAlias + rust_name: ClassAliasNode fields: - name: new_name c_type: rbs_type_name @@ -41,9 +49,11 @@ nodes: c_type: rbs_type_name - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: annotations c_type: rbs_node_list - name: RBS::AST::Declarations::Constant + rust_name: ConstantNode fields: - name: name c_type: rbs_type_name @@ -51,9 +61,11 @@ nodes: c_type: rbs_node - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: annotations c_type: rbs_node_list - name: RBS::AST::Declarations::Global + rust_name: GlobalNode fields: - name: name c_type: rbs_ast_symbol @@ -61,9 +73,11 @@ nodes: c_type: rbs_node - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: annotations c_type: rbs_node_list - name: RBS::AST::Declarations::Interface + rust_name: InterfaceNode fields: - name: name c_type: rbs_type_name @@ -75,7 +89,9 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Declarations::Module + rust_name: ModuleNode fields: - name: name c_type: rbs_type_name @@ -89,13 +105,16 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Declarations::Module::Self + rust_name: ModuleSelfNode fields: - name: name c_type: rbs_type_name - name: args c_type: rbs_node_list - name: RBS::AST::Declarations::ModuleAlias + rust_name: ModuleAliasNode fields: - name: new_name c_type: rbs_type_name @@ -103,9 +122,11 @@ nodes: c_type: rbs_type_name - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: annotations c_type: rbs_node_list - name: RBS::AST::Declarations::TypeAlias + rust_name: TypeAliasNode fields: - name: name c_type: rbs_type_name @@ -117,22 +138,28 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Directives::Use + rust_name: UseNode fields: - name: clauses c_type: rbs_node_list - name: RBS::AST::Directives::Use::SingleClause + rust_name: UseSingleClauseNode fields: - name: type_name c_type: rbs_type_name - name: new_name c_type: rbs_ast_symbol + optional: true # NULL when no alias (e.g., `use Foo::Bar` vs `use Foo::Bar as Baz`) - name: RBS::AST::Directives::Use::WildcardClause + rust_name: UseWildcardClauseNode fields: - name: namespace c_type: rbs_namespace c_name: rbs_namespace - name: RBS::AST::Members::Alias + rust_name: AliasNode fields: - name: new_name c_type: rbs_ast_symbol @@ -144,7 +171,9 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::AttrAccessor + rust_name: AttrAccessorNode fields: - name: name c_type: rbs_ast_symbol @@ -152,15 +181,19 @@ nodes: c_type: rbs_node - name: ivar_name c_type: rbs_node # rbs_ast_symbol_t, NULL or rbs_ast_bool_new(false) + optional: true # NULL when omitted (`attr_accessor foo: T`); Symbol when named (`attr_accessor foo (@bar): T`); Bool(false) when empty parens (`attr_accessor foo (): T`) - name: kind c_type: rbs_keyword - name: annotations c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: visibility c_type: rbs_keyword + optional: true # NULL when no visibility prefix (e.g., `attr_accessor foo: T` vs `private attr_accessor foo: T`) - name: RBS::AST::Members::AttrReader + rust_name: AttrReaderNode fields: - name: name c_type: rbs_ast_symbol @@ -168,15 +201,19 @@ nodes: c_type: rbs_node - name: ivar_name c_type: rbs_node # rbs_ast_symbol_t, NULL or rbs_ast_bool_new(false) + optional: true # NULL when omitted (`attr_reader foo: T`); Symbol when named (`attr_reader foo (@bar): T`); Bool(false) when empty parens (`attr_reader foo (): T`) - name: kind c_type: rbs_keyword - name: annotations c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: visibility c_type: rbs_keyword + optional: true # NULL when no visibility prefix (e.g., `attr_reader foo: T` vs `private attr_reader foo: T`) - name: RBS::AST::Members::AttrWriter + rust_name: AttrWriterNode fields: - name: name c_type: rbs_ast_symbol @@ -184,15 +221,19 @@ nodes: c_type: rbs_node - name: ivar_name c_type: rbs_node # rbs_ast_symbol_t, NULL or rbs_ast_bool_new(false) + optional: true # NULL when omitted (`attr_writer foo: T`); Symbol when named (`attr_writer foo (@bar): T`); Bool(false) when empty parens (`attr_writer foo (): T`) - name: kind c_type: rbs_keyword - name: annotations c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: visibility c_type: rbs_keyword + optional: true # NULL when no visibility prefix (e.g., `attr_writer foo: T` vs `private attr_writer foo: T`) - name: RBS::AST::Members::ClassInstanceVariable + rust_name: ClassInstanceVariableNode fields: - name: name c_type: rbs_ast_symbol @@ -200,7 +241,9 @@ nodes: c_type: rbs_node - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::ClassVariable + rust_name: ClassVariableNode fields: - name: name c_type: rbs_ast_symbol @@ -208,7 +251,9 @@ nodes: c_type: rbs_node - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::Extend + rust_name: ExtendNode fields: - name: name c_type: rbs_type_name @@ -218,7 +263,9 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::Include + rust_name: IncludeNode fields: - name: name c_type: rbs_type_name @@ -228,7 +275,9 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::InstanceVariable + rust_name: InstanceVariableNode fields: - name: name c_type: rbs_ast_symbol @@ -236,7 +285,9 @@ nodes: c_type: rbs_node - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::MethodDefinition + rust_name: MethodDefinitionNode fields: - name: name c_type: rbs_ast_symbol @@ -248,11 +299,14 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: overloading c_type: bool - name: visibility c_type: rbs_keyword + optional: true # NULL when no visibility prefix (e.g., `def foo: ...` vs `private def foo: ...`) - name: RBS::AST::Members::MethodDefinition::Overload + rust_name: MethodDefinitionOverloadNode expose_location: false fields: - name: annotations @@ -260,6 +314,7 @@ nodes: - name: method_type c_type: rbs_node - name: RBS::AST::Members::Prepend + rust_name: PrependNode fields: - name: name c_type: rbs_type_name @@ -269,9 +324,13 @@ nodes: c_type: rbs_node_list - name: comment c_type: rbs_ast_comment + optional: true # NULL when no comment precedes the declaration - name: RBS::AST::Members::Private + rust_name: PrivateNode - name: RBS::AST::Members::Public + rust_name: PublicNode - name: RBS::AST::TypeParam + rust_name: TypeParamNode fields: - name: name c_type: rbs_ast_symbol @@ -279,25 +338,31 @@ nodes: c_type: rbs_keyword - name: upper_bound c_type: rbs_node + optional: true # NULL when no upper bound (e.g., `[T]` vs `[T < Bound]`) - name: lower_bound c_type: rbs_node + optional: true # NULL when no lower bound (e.g., `[T]` vs `[T > Bound]`) - name: default_type c_type: rbs_node + optional: true # NULL when no default (e.g., `[T]` vs `[T = Default]`) - name: unchecked c_type: bool - name: RBS::AST::Integer + rust_name: IntegerNode expose_to_ruby: false expose_location: false fields: - name: string_representation c_type: rbs_string - name: RBS::AST::String + rust_name: StringNode expose_to_ruby: false expose_location: false fields: - name: string c_type: rbs_string - name: RBS::MethodType + rust_name: MethodTypeNode fields: - name: type_params c_type: rbs_node_list @@ -305,7 +370,9 @@ nodes: c_type: rbs_node - name: block c_type: rbs_types_block + optional: true # NULL when no block (e.g., `() -> void` vs `() { () -> void } -> void`) - name: RBS::Namespace + rust_name: NamespaceNode expose_location: false fields: - name: path @@ -313,6 +380,7 @@ nodes: - name: absolute c_type: bool - name: RBS::Signature + rust_name: SignatureNode expose_to_ruby: false expose_location: false fields: @@ -321,6 +389,7 @@ nodes: - name: declarations c_type: rbs_node_list - name: RBS::TypeName + rust_name: TypeNameNode expose_location: false fields: - name: namespace @@ -329,24 +398,35 @@ nodes: - name: name c_type: rbs_ast_symbol - name: RBS::Types::Alias + rust_name: AliasTypeNode fields: - name: name c_type: rbs_type_name - name: args c_type: rbs_node_list - name: RBS::Types::Bases::Any + rust_name: AnyTypeNode fields: - name: todo c_type: bool - name: RBS::Types::Bases::Bool + rust_name: BoolTypeNode - name: RBS::Types::Bases::Bottom + rust_name: BottomTypeNode - name: RBS::Types::Bases::Class + rust_name: ClassTypeNode - name: RBS::Types::Bases::Instance + rust_name: InstanceTypeNode - name: RBS::Types::Bases::Nil + rust_name: NilTypeNode - name: RBS::Types::Bases::Self + rust_name: SelfTypeNode - name: RBS::Types::Bases::Top + rust_name: TopTypeNode - name: RBS::Types::Bases::Void + rust_name: VoidTypeNode - name: RBS::Types::Block + rust_name: BlockTypeNode expose_location: true fields: - name: type @@ -355,17 +435,21 @@ nodes: c_type: bool - name: self_type c_type: rbs_node + optional: true # NULL when no self binding (e.g., `{ () -> void }` vs `{ () [self: T] -> void }`) - name: RBS::Types::ClassInstance + rust_name: ClassInstanceTypeNode fields: - name: name c_type: rbs_type_name - name: args c_type: rbs_node_list - name: RBS::Types::ClassSingleton + rust_name: ClassSingletonTypeNode fields: - name: name c_type: rbs_type_name - name: RBS::Types::Function + rust_name: FunctionTypeNode expose_location: false fields: - name: required_positionals @@ -374,6 +458,7 @@ nodes: c_type: rbs_node_list - name: rest_positionals c_type: rbs_node + optional: true # NULL when no splat (e.g., `(String) -> void` vs `(*String) -> void`) - name: trailing_positionals c_type: rbs_node_list - name: required_keywords @@ -382,45 +467,57 @@ nodes: c_type: rbs_hash - name: rest_keywords c_type: rbs_node + optional: true # NULL when no double-splat (e.g., `() -> void` vs `(**String) -> void`) - name: return_type c_type: rbs_node - name: RBS::Types::Function::Param + rust_name: FunctionParamNode fields: - name: type c_type: rbs_node - name: name c_type: rbs_ast_symbol + optional: true # NULL when param is unnamed (e.g., `(String) -> void` vs `(String name) -> void`) - name: RBS::Types::Interface + rust_name: InterfaceTypeNode fields: - name: name c_type: rbs_type_name - name: args c_type: rbs_node_list - name: RBS::Types::Intersection + rust_name: IntersectionTypeNode fields: - name: types c_type: rbs_node_list - name: RBS::Types::Literal + rust_name: LiteralTypeNode fields: - name: literal c_type: rbs_node - name: RBS::Types::Optional + rust_name: OptionalTypeNode fields: - name: type c_type: rbs_node - name: RBS::Types::Proc + rust_name: ProcTypeNode fields: - name: type c_type: rbs_node - name: block c_type: rbs_types_block + optional: true # NULL when proc has no block (e.g., `^() -> void` vs `^() { () -> void } -> void`) - name: self_type c_type: rbs_node + optional: true # NULL when no self binding (e.g., `^() -> void` vs `^() [self: T] -> void`) - name: RBS::Types::Record + rust_name: RecordTypeNode fields: - name: all_fields c_type: rbs_hash - name: RBS::Types::Record::FieldType + rust_name: RecordFieldTypeNode expose_to_ruby: false expose_location: false fields: @@ -429,29 +526,35 @@ nodes: - name: required c_type: bool - name: RBS::Types::Tuple + rust_name: TupleTypeNode fields: - name: types c_type: rbs_node_list - name: RBS::Types::Union + rust_name: UnionTypeNode fields: - name: types c_type: rbs_node_list - name: RBS::Types::UntypedFunction + rust_name: UntypedFunctionTypeNode expose_location: false fields: - name: return_type c_type: rbs_node - name: RBS::Types::Variable + rust_name: VariableTypeNode fields: - name: name c_type: rbs_ast_symbol - name: RBS::AST::Ruby::Annotations::NodeTypeAssertion + rust_name: NodeTypeAssertionNode fields: - name: prefix_location c_type: rbs_location - name: type c_type: rbs_node - name: RBS::AST::Ruby::Annotations::ColonMethodTypeAnnotation + rust_name: ColonMethodTypeAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -460,6 +563,7 @@ nodes: - name: method_type c_type: rbs_node - name: RBS::AST::Ruby::Annotations::MethodTypesAnnotation + rust_name: MethodTypesAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -468,6 +572,7 @@ nodes: - name: vertical_bar_locations c_type: rbs_location_list - name: RBS::AST::Ruby::Annotations::SkipAnnotation + rust_name: SkipAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -476,6 +581,7 @@ nodes: - name: comment_location c_type: rbs_location - name: RBS::AST::Ruby::Annotations::ReturnTypeAnnotation + rust_name: ReturnTypeAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -488,6 +594,7 @@ nodes: - name: comment_location c_type: rbs_location - name: RBS::AST::Ruby::Annotations::TypeApplicationAnnotation + rust_name: TypeApplicationAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -498,6 +605,7 @@ nodes: - name: comma_locations c_type: rbs_location_list - name: RBS::AST::Ruby::Annotations::InstanceVariableAnnotation + rust_name: InstanceVariableAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -512,6 +620,7 @@ nodes: - name: comment_location c_type: rbs_location - name: RBS::AST::Ruby::Annotations::ClassAliasAnnotation + rust_name: ClassAliasAnnotationNode fields: - name: prefix_location c_type: rbs_location @@ -522,6 +631,7 @@ nodes: - name: type_name_location c_type: rbs_location - name: RBS::AST::Ruby::Annotations::ModuleAliasAnnotation + rust_name: ModuleAliasAnnotationNode fields: - name: prefix_location c_type: rbs_location diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 906fc5859..747375d18 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -78,12 +78,34 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "glob" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" + +[[package]] +name = "indexmap" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "itertools" version = "0.13.0" @@ -93,6 +115,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + [[package]] name = "libc" version = "0.2.174" @@ -194,6 +222,15 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "ruby-rbs" +version = "0.1.0" +dependencies = [ + "ruby-rbs-sys", + "serde", + "serde_yaml", +] + [[package]] name = "ruby-rbs-sys" version = "0.1.0" @@ -208,6 +245,45 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "shlex" version = "1.3.0" @@ -231,6 +307,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "windows-targets" version = "0.53.2" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 36e83a904..60895567e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "ruby-rbs", "ruby-rbs-sys", ] diff --git a/rust/ruby-rbs/Cargo.toml b/rust/ruby-rbs/Cargo.toml new file mode 100644 index 000000000..9c4731d41 --- /dev/null +++ b/rust/ruby-rbs/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ruby-rbs" +version = "0.1.0" +edition = "2024" + +[dependencies] +ruby-rbs-sys = { path = "../ruby-rbs-sys" } + +[build-dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_yaml = "0.9" diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs new file mode 100644 index 000000000..a6b72cdf9 --- /dev/null +++ b/rust/ruby-rbs/build.rs @@ -0,0 +1,579 @@ +use serde::Deserialize; +use std::{env, error::Error, fs::File, io::Write, path::Path}; + +// This config-driven code generation approach is inspired by Prism's ruby-prism crate. +// See: https://github.com/ruby/prism/blob/main/rust/ruby-prism/build.rs + +#[derive(Debug, Deserialize)] +struct Config { + nodes: Vec, +} + +#[derive(Debug, Deserialize)] +struct NodeField { + name: String, + c_type: String, + c_name: Option, + #[serde(default)] + optional: bool, +} + +impl NodeField { + fn c_name(&self) -> &str { + let name = self.c_name.as_ref().unwrap_or(&self.name); + if name == "type" { "type_" } else { name } + } +} + +#[derive(Debug, Deserialize)] +struct Node { + name: String, + rust_name: String, + fields: Option>, +} + +impl Node { + fn variant_name(&self) -> &str { + self.rust_name + .strip_suffix("Node") + .unwrap_or(&self.rust_name) + } +} + +fn main() -> Result<(), Box> { + let config_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../config.yml") + .canonicalize()?; + + println!("cargo:rerun-if-changed={}", config_path.display()); + + let config_file = File::open(&config_path)?; + let mut config: Config = serde_yaml::from_reader(config_file)?; + + // Keyword and Symbol represent identifiers (interned strings), not traditional AST nodes. + // However, the C parser defines them in `rbs_node_type` (RBS_KEYWORD, RBS_AST_SYMBOL) and + // treats them as nodes (rbs_node_t*) in many contexts (lists, hashes). + // We inject them into the config so they are generated as structs matching the Node pattern, + // allowing them to be wrapped in the Node enum and handled uniformly in Rust. + config.nodes.push(Node { + name: "RBS::Keyword".to_string(), + rust_name: "KeywordNode".to_string(), + fields: None, + }); + config.nodes.push(Node { + name: "RBS::AST::Symbol".to_string(), + rust_name: "SymbolNode".to_string(), + fields: None, + }); + + config.nodes.sort_by(|a, b| a.name.cmp(&b.name)); + generate(&config)?; + + Ok(()) +} + +enum CIdentifier { + Type, // foo_bar_t + Constant, // FOO_BAR + Method, // visit_foo_bar +} + +fn convert_name(name: &str, identifier: CIdentifier) -> String { + let type_name = name.replace("::", "_"); + let lowercase = matches!(identifier, CIdentifier::Type | CIdentifier::Method); + let mut out = String::new(); + let mut prev_is_lower = false; + + for ch in type_name.chars() { + if ch.is_ascii_uppercase() { + if prev_is_lower { + out.push('_'); + } + out.push(if lowercase { + ch.to_ascii_lowercase() + } else { + ch + }); + prev_is_lower = false; + } else if ch == '_' { + out.push(ch); + prev_is_lower = false; + } else { + out.push(if lowercase { + ch + } else { + ch.to_ascii_uppercase() + }); + prev_is_lower = ch.is_ascii_lowercase() || ch.is_ascii_digit(); + } + } + + if matches!(identifier, CIdentifier::Type) { + out.push_str("_t"); + } + out +} + +fn write_node_field_accessor( + file: &mut File, + field: &NodeField, + rust_type: &str, +) -> std::io::Result<()> { + if field.optional { + writeln!(file, " #[must_use]")?; + writeln!( + file, + " pub fn {}(&self) -> Option<{rust_type}<'a>> {{", + field.name, + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!( + file, + " Some({rust_type} {{ parser: self.parser, pointer: ptr, marker: PhantomData }})" + )?; + writeln!(file, " }}")?; + } else { + writeln!(file, " #[must_use]")?; + writeln!( + file, + " pub fn {}(&self) -> {rust_type}<'a> {{", + field.name + )?; + writeln!( + file, + " {rust_type} {{ parser: self.parser, pointer: unsafe {{ (*self.pointer).{} }}, marker: PhantomData }}", + field.c_name() + )?; + } + writeln!(file, " }}")?; + writeln!(file)?; + Ok(()) +} + +fn write_visit_trait(file: &mut File, config: &Config) -> Result<(), Box> { + writeln!(file, "/// A trait for traversing the AST using a visitor")?; + writeln!(file, "pub trait Visit {{")?; + writeln!( + file, + " /// Visit any node of the AST. Generally used to continue traversal" + )?; + writeln!(file, " fn visit(&mut self, node: &Node) {{")?; + writeln!(file, " match node {{")?; + + for node in &config.nodes { + let node_variant_name = node.variant_name(); + let method_name = convert_name(node_variant_name, CIdentifier::Method); + + writeln!(file, " Node::{node_variant_name}(it) => {{")?; + writeln!(file, " self.visit_{method_name}_node(it);")?; + writeln!(file, " }}")?; + } + + writeln!(file, " }}")?; + writeln!(file, " }}")?; + + for node in &config.nodes { + let node_variant_name = node.variant_name(); + let method_name = convert_name(node_variant_name, CIdentifier::Method); + + writeln!(file)?; + writeln!( + file, + " fn visit_{method_name}_node(&mut self, node: &{node_variant_name}Node) {{" + )?; + writeln!(file, " visit_{method_name}_node(self, node);")?; + writeln!(file, " }}")?; + } + writeln!(file, "}}")?; + writeln!(file)?; + + // Map C field types (e.g. `rbs_type_name`) to the corresponding + // visitor method name (e.g. `type_name` -> `visit_type_name_node`). + let visitor_method_names: std::collections::HashMap = config + .nodes + .iter() + .map(|node| { + let c_type = convert_name(&node.name, CIdentifier::Type); + let c_type = c_type.strip_suffix("_t").unwrap_or(&c_type).to_string(); + let method = convert_name(node.variant_name(), CIdentifier::Method); + (c_type, method) + }) + .collect(); + + let is_visitable = |c_type: &str| -> bool { + matches!(c_type, "rbs_node" | "rbs_node_list" | "rbs_hash") + || visitor_method_names.contains_key(c_type) + }; + + for node in &config.nodes { + let node_variant_name = node.variant_name(); + let method_name = convert_name(node_variant_name, CIdentifier::Method); + + let has_visitable_fields = node + .fields + .iter() + .flatten() + .any(|field| is_visitable(&field.c_type)); + + if !has_visitable_fields { + // If there's nothing to visit in this node, write the empty method with + // underscored parameters, and skip to the next iteration + writeln!( + file, + "pub fn visit_{method_name}_node(_visitor: &mut V, _node: &{node_variant_name}Node) {{}}" + )?; + + continue; + } + + writeln!( + file, + "pub fn visit_{method_name}_node(visitor: &mut V, node: &{node_variant_name}Node) {{" + )?; + + if let Some(fields) = &node.fields { + for field in fields { + let field_method_name = if field.name == "type" { + "type_" + } else { + field.name.as_str() + }; + + match field.c_type.as_str() { + "rbs_node" => { + if field.optional { + writeln!( + file, + " if let Some(item) = node.{field_method_name}() {{" + )?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " visitor.visit(&node.{field_method_name}());")?; + } + } + + "rbs_node_list" => { + if field.optional { + writeln!( + file, + " if let Some(list) = node.{field_method_name}() {{" + )?; + writeln!(file, " for item in list.iter() {{")?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " for item in node.{field_method_name}().iter() {{")?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + } + } + + "rbs_hash" => { + if field.optional { + writeln!( + file, + " if let Some(hash) = node.{field_method_name}() {{" + )?; + writeln!(file, " for (key, value) in hash.iter() {{")?; + writeln!(file, " visitor.visit(&key);")?; + writeln!(file, " visitor.visit(&value);")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " for (key, value) in node.{field_method_name}().iter() {{" + )?; + writeln!(file, " visitor.visit(&key);")?; + writeln!(file, " visitor.visit(&value);")?; + writeln!(file, " }}")?; + } + } + + _ => { + if let Some(visit_method_name) = visitor_method_names.get(&field.c_type) { + if field.optional { + writeln!( + file, + " if let Some(item) = node.{field_method_name}() {{" + )?; + writeln!( + file, + " visitor.visit_{visit_method_name}_node(&item);" + )?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " visitor.visit_{visit_method_name}_node(&node.{field_method_name}());" + )?; + } + } + } + } + } + } + writeln!(file, "}}")?; + writeln!(file)?; + } + + Ok(()) +} + +fn generate(config: &Config) -> Result<(), Box> { + let out_dir = env::var("OUT_DIR")?; + let dest_path = Path::new(&out_dir).join("bindings.rs"); + + let mut file = File::create(&dest_path)?; + + writeln!(file, "// Generated by build.rs from config.yml")?; + writeln!(file)?; + + for node in &config.nodes { + writeln!(file, "#[derive(Debug)]")?; + writeln!(file, "pub struct {}<'a> {{", node.rust_name)?; + writeln!(file, " #[allow(dead_code)]")?; + writeln!(file, " parser: NonNull,")?; + writeln!( + file, + " pointer: *mut {},", + convert_name(&node.name, CIdentifier::Type) + )?; + writeln!( + file, + " marker: PhantomData<&'a mut {}>", + convert_name(&node.name, CIdentifier::Type) + )?; + writeln!(file, "}}\n")?; + + writeln!(file, "impl<'a> {}<'a> {{", node.rust_name)?; + writeln!(file, " /// Converts this node to a generic node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn as_node(self) -> Node<'a> {{")?; + writeln!(file, " Node::{}(self)", node.variant_name())?; + writeln!(file, " }}")?; + writeln!(file)?; + writeln!(file, " /// Returns the location of this node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn location(&self) -> RBSLocation {{")?; + writeln!( + file, + " RBSLocation::new(unsafe {{ (*self.pointer).base.location }})" + )?; + writeln!(file, " }}")?; + writeln!(file)?; + + if let Some(fields) = &node.fields { + for field in fields { + match field.c_type.as_str() { + "rbs_string" => { + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn {}(&self) -> RBSString {{", field.name)?; + writeln!( + file, + " RBSString::new(unsafe {{ &(*self.pointer).{} }})", + field.c_name() + )?; + writeln!(file, " }}")?; + writeln!(file)?; + } + "bool" => { + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn {}(&self) -> bool {{", field.name)?; + writeln!(file, " unsafe {{ (*self.pointer).{} }}", field.name)?; + writeln!(file, " }}")?; + writeln!(file)?; + } + "rbs_ast_comment" => { + write_node_field_accessor(&mut file, field, "CommentNode")? + } + "rbs_ast_declarations_class_super" => { + write_node_field_accessor(&mut file, field, "ClassSuperNode")? + } + "rbs_ast_symbol" => write_node_field_accessor(&mut file, field, "SymbolNode")?, + "rbs_hash" => { + write_node_field_accessor(&mut file, field, "RBSHash")?; + } + "rbs_location" => { + if field.optional { + writeln!(file, " #[must_use]")?; + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocation {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?; + writeln!( + file, + " RBSLocation {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } + writeln!(file)?; + } + "rbs_location_list" => { + if field.optional { + writeln!(file, " #[must_use]")?; + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocationList {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " #[must_use]")?; + writeln!( + file, + " pub fn {}(&self) -> RBSLocationList {{", + field.name + )?; + writeln!( + file, + " RBSLocationList {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } + writeln!(file)?; + } + "rbs_namespace" => { + write_node_field_accessor(&mut file, field, "NamespaceNode")?; + } + "rbs_node" => { + let name = if field.name == "type" { + "type_" + } else { + field.name.as_str() + }; + if field.optional { + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn {name}(&self) -> Option> {{")?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!( + file, + " if ptr.is_null() {{ None }} else {{ Some(Node::new(self.parser, ptr)) }}" + )?; + } else { + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn {name}(&self) -> Node<'a> {{")?; + writeln!( + file, + " unsafe {{ Node::new(self.parser, (*self.pointer).{}) }}", + field.c_name() + )?; + } + writeln!(file, " }}")?; + writeln!(file)?; + } + "rbs_node_list" => { + write_node_field_accessor(&mut file, field, "NodeList")?; + } + "rbs_keyword" => write_node_field_accessor(&mut file, field, "KeywordNode")?, + "rbs_type_name" => { + write_node_field_accessor(&mut file, field, "TypeNameNode")?; + } + "rbs_types_block" => { + write_node_field_accessor(&mut file, field, "BlockTypeNode")? + } + _ => panic!("Unknown field type: {}", field.c_type), + } + } + } + writeln!(file, "}}\n")?; + } + + // Generate the Node enum to wrap all of the structs + writeln!(file, "#[derive(Debug)]")?; + writeln!(file, "pub enum Node<'a> {{")?; + for node in &config.nodes { + let variant_name = node + .rust_name + .strip_suffix("Node") + .unwrap_or(&node.rust_name); + + writeln!(file, " {variant_name}({}<'a>),", node.rust_name)?; + } + writeln!(file, "}}")?; + + writeln!(file, "impl Node<'_> {{")?; + writeln!(file, " #[allow(clippy::missing_safety_doc)]")?; + writeln!( + file, + " fn new(parser: NonNull, node: *mut rbs_node_t) -> Self {{" + )?; + writeln!(file, " match unsafe {{ (*node).type_ }} {{")?; + for node in &config.nodes { + let enum_name = convert_name(&node.name, CIdentifier::Constant); + let c_type = convert_name(&node.name, CIdentifier::Type); + + writeln!( + file, + " rbs_node_type::{enum_name} => Self::{}({} {{ parser, pointer: node.cast::<{c_type}>(), marker: PhantomData }}),", + node.variant_name(), + node.rust_name, + )?; + } + writeln!( + file, + " _ => panic!(\"Unknown node type: {{}}\", unsafe {{ (*node).type_ }})" + )?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + writeln!(file)?; + writeln!(file, " /// Returns the location of the entire node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn location(&self) -> RBSLocation {{")?; + writeln!(file, " match self {{")?; + for node in &config.nodes { + writeln!( + file, + " Node::{}(node) => node.location(),", + node.variant_name() + )?; + } + writeln!(file, " }}")?; + writeln!(file, " }}")?; + writeln!(file, "}}")?; + writeln!(file)?; + + write_visit_trait(&mut file, config)?; + + Ok(()) +} diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs new file mode 100644 index 000000000..492bc84b4 --- /dev/null +++ b/rust/ruby-rbs/src/lib.rs @@ -0,0 +1 @@ +pub mod node; diff --git a/rust/ruby-rbs/src/node/mod.rs b/rust/ruby-rbs/src/node/mod.rs new file mode 100644 index 000000000..8d6d379a0 --- /dev/null +++ b/rust/ruby-rbs/src/node/mod.rs @@ -0,0 +1,488 @@ +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); +use rbs_encoding_type_t::RBS_ENCODING_UTF_8; +use ruby_rbs_sys::bindings::*; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::Once; + +static INIT: Once = Once::new(); + +/// Parse RBS code into an AST. +/// +/// ```rust +/// use ruby_rbs::node::parse; +/// let rbs_code = r#"type foo = "hello""#; +/// let signature = parse(rbs_code.as_bytes()); +/// assert!(signature.is_ok(), "Failed to parse RBS signature"); +/// ``` +pub fn parse(rbs_code: &[u8]) -> Result, String> { + unsafe { + INIT.call_once(|| { + rbs_constant_pool_init(RBS_GLOBAL_CONSTANT_POOL, 26); + }); + + let start_ptr = rbs_code.as_ptr() as *const i8; + let end_ptr = start_ptr.add(rbs_code.len()); + + let raw_rbs_string_value = rbs_string_new(start_ptr, end_ptr); + + let encoding_ptr = &rbs_encodings[RBS_ENCODING_UTF_8 as usize] as *const rbs_encoding_t; + let parser = rbs_parser_new(raw_rbs_string_value, encoding_ptr, 0, rbs_code.len() as i32); + + let mut signature: *mut rbs_signature_t = std::ptr::null_mut(); + let result = rbs_parse_signature(parser, &mut signature); + + let signature_node = SignatureNode { + parser: NonNull::new_unchecked(parser), + pointer: signature, + marker: PhantomData, + }; + + if result { + Ok(signature_node) + } else { + Err(String::from("Failed to parse RBS signature")) + } + } +} + +impl Drop for SignatureNode<'_> { + fn drop(&mut self) { + unsafe { + rbs_parser_free(self.parser.as_ptr()); + } + } +} + +impl KeywordNode<'_> { + #[must_use] + pub fn name(&self) -> &[u8] { + unsafe { + let constant_ptr = rbs_constant_pool_id_to_constant( + &(*self.parser.as_ptr()).constant_pool, + (*self.pointer).constant_id, + ); + if constant_ptr.is_null() { + panic!("Constant ID for keyword is not present in the pool"); + } + + let constant = &*constant_ptr; + std::slice::from_raw_parts(constant.start, constant.length) + } + } +} + +pub struct NodeList<'a> { + parser: NonNull, + pointer: *mut rbs_node_list_t, + marker: PhantomData<&'a mut rbs_node_list_t>, +} + +impl<'a> NodeList<'a> { + #[must_use] + pub fn new(parser: NonNull, pointer: *mut rbs_node_list_t) -> Self { + Self { + parser, + pointer, + marker: PhantomData, + } + } + + /// Returns an iterator over the nodes. + #[must_use] + pub fn iter(&self) -> NodeListIter<'a> { + NodeListIter { + parser: self.parser, + current: unsafe { (*self.pointer).head }, + marker: PhantomData, + } + } +} + +pub struct NodeListIter<'a> { + parser: NonNull, + current: *mut rbs_node_list_node_t, + marker: PhantomData<&'a mut rbs_node_list_node_t>, +} + +impl<'a> Iterator for NodeListIter<'a> { + type Item = Node<'a>; + + fn next(&mut self) -> Option { + if self.current.is_null() { + None + } else { + let pointer_data = unsafe { *self.current }; + let node = Node::new(self.parser, pointer_data.node); + self.current = pointer_data.next; + Some(node) + } + } +} + +pub struct RBSHash<'a> { + parser: NonNull, + pointer: *mut rbs_hash, + marker: PhantomData<&'a mut rbs_hash>, +} + +impl<'a> RBSHash<'a> { + #[must_use] + pub fn new(parser: NonNull, pointer: *mut rbs_hash) -> Self { + Self { + parser, + pointer, + marker: PhantomData, + } + } + + /// Returns an iterator over the key-value pairs. + #[must_use] + pub fn iter(&self) -> RBSHashIter<'a> { + RBSHashIter { + parser: self.parser, + current: unsafe { (*self.pointer).head }, + marker: PhantomData, + } + } +} + +pub struct RBSHashIter<'a> { + parser: NonNull, + current: *mut rbs_hash_node_t, + marker: PhantomData<&'a mut rbs_hash_node_t>, +} + +impl<'a> Iterator for RBSHashIter<'a> { + type Item = (Node<'a>, Node<'a>); + + fn next(&mut self) -> Option { + if self.current.is_null() { + None + } else { + let pointer_data = unsafe { *self.current }; + let key = Node::new(self.parser, pointer_data.key); + let value = Node::new(self.parser, pointer_data.value); + self.current = pointer_data.next; + Some((key, value)) + } + } +} + +pub struct RBSLocation { + pointer: *const rbs_location_t, +} + +impl RBSLocation { + #[must_use] + pub fn new(pointer: *const rbs_location_t) -> Self { + Self { pointer } + } + + #[must_use] + pub fn start(&self) -> i32 { + unsafe { (*self.pointer).rg.start.byte_pos } + } + + #[must_use] + pub fn end(&self) -> i32 { + unsafe { (*self.pointer).rg.end.byte_pos } + } +} + +pub struct RBSLocationList { + pointer: *mut rbs_location_list, +} + +impl RBSLocationList { + #[must_use] + pub fn new(pointer: *mut rbs_location_list) -> Self { + Self { pointer } + } + + /// Returns an iterator over the locations. + #[must_use] + pub fn iter(&self) -> RBSLocationListIter { + RBSLocationListIter { + current: unsafe { (*self.pointer).head }, + } + } +} + +pub struct RBSLocationListIter { + current: *mut rbs_location_list_node_t, +} + +impl Iterator for RBSLocationListIter { + type Item = RBSLocation; + + fn next(&mut self) -> Option { + if self.current.is_null() { + None + } else { + let pointer_data = unsafe { *self.current }; + let loc = RBSLocation::new(pointer_data.loc); + self.current = pointer_data.next; + Some(loc) + } + } +} + +#[derive(Debug)] +pub struct RBSString { + pointer: *const rbs_string_t, +} + +impl RBSString { + #[must_use] + pub fn new(pointer: *const rbs_string_t) -> Self { + Self { pointer } + } + + #[must_use] + pub fn as_bytes(&self) -> &[u8] { + unsafe { + let s = *self.pointer; + std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize) + } + } +} + +impl SymbolNode<'_> { + #[must_use] + pub fn name(&self) -> &[u8] { + unsafe { + let constant_ptr = rbs_constant_pool_id_to_constant( + &(*self.parser.as_ptr()).constant_pool, + (*self.pointer).constant_id, + ); + if constant_ptr.is_null() { + panic!("Constant ID for symbol is not present in the pool"); + } + + let constant = &*constant_ptr; + std::slice::from_raw_parts(constant.start, constant.length) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse() { + let rbs_code = r#"type foo = "hello""#; + let signature = parse(rbs_code.as_bytes()); + assert!(signature.is_ok(), "Failed to parse RBS signature"); + + let rbs_code2 = r#"class Foo end"#; + let signature2 = parse(rbs_code2.as_bytes()); + assert!(signature2.is_ok(), "Failed to parse RBS signature"); + } + + #[test] + fn test_parse_integer() { + let rbs_code = r#"type foo = 1"#; + let signature = parse(rbs_code.as_bytes()); + assert!(signature.is_ok(), "Failed to parse RBS signature"); + + let signature_node = signature.unwrap(); + if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap() + && let Node::LiteralType(literal) = node.type_() + && let Node::Integer(integer) = literal.literal() + { + assert_eq!( + "1".to_string(), + String::from_utf8(integer.string_representation().as_bytes().to_vec()).unwrap() + ); + } else { + panic!("No literal type node found"); + } + } + + #[test] + fn test_rbs_hash_via_record_type() { + // RecordType stores its fields in an RBSHash via all_fields() + let rbs_code = r#"type foo = { name: String, age: Integer }"#; + let signature = parse(rbs_code.as_bytes()); + assert!(signature.is_ok(), "Failed to parse RBS signature"); + + let signature_node = signature.unwrap(); + if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap() + && let Node::RecordType(record) = type_alias.type_() + { + let hash = record.all_fields(); + let fields: Vec<_> = hash.iter().collect(); + assert_eq!(fields.len(), 2, "Expected 2 fields in record"); + + // Build a map of field names to type names + let mut field_types: Vec<(String, String)> = Vec::new(); + for (key, value) in &fields { + let Node::Symbol(sym) = key else { + panic!("Expected Symbol key"); + }; + let Node::RecordFieldType(field_type) = value else { + panic!("Expected RecordFieldType value"); + }; + let Node::ClassInstanceType(class_type) = field_type.type_() else { + panic!("Expected ClassInstanceType"); + }; + + let key_name = String::from_utf8(sym.name().to_vec()).unwrap(); + let type_name_node = class_type.name(); + let type_name_sym = type_name_node.name(); + let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap(); + field_types.push((key_name, type_name)); + } + + assert!( + field_types.contains(&("name".to_string(), "String".to_string())), + "Expected 'name: String'" + ); + assert!( + field_types.contains(&("age".to_string(), "Integer".to_string())), + "Expected 'age: Integer'" + ); + } else { + panic!("Expected TypeAlias with RecordType"); + } + } + + #[test] + fn visitor_test() { + struct Visitor { + visited: Vec, + } + + impl Visit for Visitor { + fn visit_bool_type_node(&mut self, node: &BoolTypeNode) { + self.visited.push("type:bool".to_string()); + + crate::node::visit_bool_type_node(self, node); + } + + fn visit_class_node(&mut self, node: &ClassNode) { + self.visited.push(format!( + "class:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::node::visit_class_node(self, node); + } + + fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) { + self.visited.push(format!( + "type:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::node::visit_class_instance_type_node(self, node); + } + + fn visit_class_super_node(&mut self, node: &ClassSuperNode) { + self.visited.push(format!( + "super:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::node::visit_class_super_node(self, node); + } + + fn visit_function_type_node(&mut self, node: &FunctionTypeNode) { + let count = node.required_positionals().iter().count(); + self.visited + .push(format!("function:required_positionals:{count}")); + + crate::node::visit_function_type_node(self, node); + } + + fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) { + self.visited.push(format!( + "method:{}", + String::from_utf8(node.name().name().to_vec()).unwrap() + )); + + crate::node::visit_method_definition_node(self, node); + } + + fn visit_record_type_node(&mut self, node: &RecordTypeNode) { + self.visited.push("record".to_string()); + + crate::node::visit_record_type_node(self, node); + } + + fn visit_symbol_node(&mut self, node: &SymbolNode) { + self.visited.push(format!( + "symbol:{}", + String::from_utf8(node.name().to_vec()).unwrap() + )); + + crate::node::visit_symbol_node(self, node); + } + } + + let rbs_code = r#" + class Foo < Bar + def process: ({ name: String, age: Integer }, bool) -> void + end + "#; + + let signature = parse(rbs_code.as_bytes()).unwrap(); + + let mut visitor = Visitor { + visited: Vec::new(), + }; + + visitor.visit(&signature.as_node()); + + assert_eq!( + vec![ + "class:Foo", + "symbol:Foo", + "super:Bar", + "symbol:Bar", + "method:process", + "symbol:process", + "function:required_positionals:2", + "record", + "symbol:name", + "type:String", + "symbol:String", + "symbol:age", + "type:Integer", + "symbol:Integer", + "type:bool", + ], + visitor.visited + ); + } + + #[test] + fn test_node_location_ranges() { + let rbs_code = r#"type foo = 1"#; + let signature = parse(rbs_code.as_bytes()).unwrap(); + + let declaration = signature.declarations().iter().next().unwrap(); + let Node::TypeAlias(type_alias) = declaration else { + panic!("Expected TypeAlias"); + }; + + // TypeAlias spans the entire declaration + let loc = type_alias.location(); + assert_eq!(0, loc.start()); + assert_eq!(12, loc.end()); + + // The literal "1" is at position 11-12 + let Node::LiteralType(literal) = type_alias.type_() else { + panic!("Expected LiteralType"); + }; + let Node::Integer(integer) = literal.literal() else { + panic!("Expected Integer"); + }; + + let int_loc = integer.location(); + assert_eq!(11, int_loc.start()); + assert_eq!(12, int_loc.end()); + } +}