diff --git a/shared/yeast/src/lib.rs b/shared/yeast/src/lib.rs index 0f0ea45a9857..281f44a98b26 100644 --- a/shared/yeast/src/lib.rs +++ b/shared/yeast/src/lib.rs @@ -34,44 +34,48 @@ pub const CHILD_FIELD: u16 = u16::MAX; #[derive(Debug)] pub struct AstCursor<'a> { ast: &'a Ast, - /// A stack of parents, along with iterators for their children - parents: Vec<(&'a Node, ChildrenIter<'a>)>, - node: &'a Node, + /// A stack of parents, along with iterators for their children. + parents: Vec<(Id, ChildrenIter<'a>)>, + node_id: Id, } impl<'a> AstCursor<'a> { pub fn new(ast: &'a Ast) -> Self { - // TODO: handle non-zero root - let node = ast.get_node(ast.root).unwrap(); Self { ast, parents: vec![], - node, + node_id: ast.root, } } + /// The Id of the node currently under the cursor. + pub fn node_id(&self) -> Id { + self.node_id + } + fn goto_next_sibling_opt(&mut self) -> Option<()> { - self.node = self.parents.last_mut()?.1.next()?; + self.node_id = self.parents.last_mut()?.1.next()?; Some(()) } fn goto_first_child_opt(&mut self) -> Option<()> { - let parent = self.node; - let mut children = ChildrenIter::new(self.ast, parent); + let parent_id = self.node_id; + let parent = self.ast.get_node(parent_id)?; + let mut children = ChildrenIter::new(parent); let first_child = children.next()?; - self.node = first_child; - self.parents.push((parent, children)); + self.node_id = first_child; + self.parents.push((parent_id, children)); Some(()) } fn goto_parent_opt(&mut self) -> Option<()> { - self.node = self.parents.pop()?.0; + self.node_id = self.parents.pop()?.0; Some(()) } } impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> { fn node(&self) -> &'a Node { - self.node + &self.ast.nodes[self.node_id] } fn field_id(&self) -> Option { @@ -101,36 +105,30 @@ impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> { } } -/// An iterator over all the child nodes of a node. +/// An iterator over the child Ids of a node. #[derive(Debug)] struct ChildrenIter<'a> { - ast: &'a Ast, current_field: Option, fields: std::collections::btree_map::Iter<'a, FieldId, Vec>, field_children: Option>, } impl<'a> ChildrenIter<'a> { - fn new(ast: &'a Ast, node: &'a Node) -> Self { + fn new(node: &'a Node) -> Self { Self { - ast, current_field: None, fields: node.fields.iter(), field_children: None, } } - fn get_node(&self, id: Id) -> &'a Node { - self.ast.get_node(id).unwrap() - } - fn current_field(&self) -> Option { self.current_field } } -impl<'a> Iterator for ChildrenIter<'a> { - type Item = &'a Node; +impl Iterator for ChildrenIter<'_> { + type Item = Id; fn next(&mut self) -> Option { match self.field_children.as_mut() { @@ -151,7 +149,7 @@ impl<'a> Iterator for ChildrenIter<'a> { self.next() } }, - Some(child_id) => Some(self.get_node(*child_id)), + Some(child_id) => Some(*child_id), }, } } @@ -236,7 +234,6 @@ impl Ast { ) -> Id { let id = self.nodes.len(); self.nodes.push(Node { - id, kind, kind_name: self.schema.node_kind_for_id(kind).unwrap(), fields, @@ -265,7 +262,6 @@ impl Ast { }); let id = self.nodes.len(); self.nodes.push(Node { - id, kind: kind_id, kind_name: kind, is_named: true, @@ -345,7 +341,6 @@ impl Ast { /// A node in our AST #[derive(PartialEq, Eq, Debug, Clone, Serialize)] pub struct Node { - id: Id, kind: KindId, kind_name: &'static str, pub(crate) fields: BTreeMap>, @@ -361,10 +356,6 @@ pub struct Node { } impl Node { - pub fn id(&self) -> Id { - self.id - } - pub fn kind(&self) -> &'static str { self.kind_name } @@ -600,39 +591,41 @@ fn apply_rules_inner( } } - // Collect fields before recursing (avoids borrowing ast immutably during mutation) - let field_entries: Vec<(FieldId, Vec)> = ast.nodes[id] - .fields - .iter() - .map(|(&fid, children)| (fid, children.clone())) - .collect(); - - // recursively descend into all the fields + // Take the parent's fields by ownership: the recursion will rewrite + // each child Id, and we'll write the (possibly mutated) field map back + // when we're done. Avoids cloning the whole BTreeMap and its child + // Vecs on entry. Each child Vec is only re-allocated if a rewrite + // actually changes its contents. + // // Child traversal does not increment rewrite depth and starts fresh // (no rule is skipped on child subtrees). - let mut changed = false; - let mut new_fields = BTreeMap::new(); - for (field_id, children) in field_entries { - let mut new_children = Vec::new(); - for child_id in children { + let mut fields = std::mem::take(&mut ast.nodes[id].fields); + for children in fields.values_mut() { + let mut new_children: Option> = None; + for (i, &child_id) in children.iter().enumerate() { let result = apply_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?; - if result.len() != 1 || result[0] != child_id { - changed = true; + let unchanged = result.len() == 1 && result[0] == child_id; + match (&mut new_children, unchanged) { + (None, true) => {} // unchanged so far, no allocation needed + (None, false) => { + // First divergence — copy already-processed Ids and + // start collecting the rewritten sequence. + let mut new = Vec::with_capacity(children.len()); + new.extend_from_slice(&children[..i]); + new.extend(result); + new_children = Some(new); + } + (Some(new), _) => { + new.extend(result); + } } - new_children.extend(result); } - new_fields.insert(field_id, new_children); - } - - if !changed { - return Ok(vec![id]); + if let Some(new) = new_children { + *children = new; + } } - - let mut node = ast.nodes[id].clone(); - node.fields = new_fields; - node.id = ast.nodes.len(); - ast.nodes.push(node); - Ok(vec![ast.nodes.len() - 1]) + ast.nodes[id].fields = fields; + Ok(vec![id]) } /// One phase of a desugaring pass: a named bundle of rules that runs to diff --git a/shared/yeast/src/visitor.rs b/shared/yeast/src/visitor.rs index 655aa01e6b3e..1b9c5eab3627 100644 --- a/shared/yeast/src/visitor.rs +++ b/shared/yeast/src/visitor.rs @@ -49,7 +49,7 @@ impl Visitor { pub fn build_with_schema(self, schema: crate::schema::Schema) -> Ast { Ast { - root: self.nodes[0].inner.id, + root: 0, schema, nodes: self.nodes.into_iter().map(|n| n.inner).collect(), } @@ -59,7 +59,6 @@ impl Visitor { let id = self.nodes.len(); self.nodes.push(VisitorNode { inner: Node { - id, kind: self.language.id_for_node_kind(n.kind(), is_named), kind_name: n.kind(), content, @@ -82,11 +81,10 @@ impl Visitor { } fn leave_node(&mut self, field_name: Option<&'static str>, _node: tree_sitter::Node<'_>) { - let node = self.current.map(|i| &self.nodes[i]).unwrap(); - let node_id = node.inner.id; - let node_parent = node.parent; + let node_id = self.current.unwrap(); + let node_parent = self.nodes[node_id].parent; - if let Some(parent_id) = node.parent { + if let Some(parent_id) = node_parent { let parent = self.nodes.get_mut(parent_id).unwrap(); if let Some(field) = field_name { let field_id = self.language.field_id_for_name(field).unwrap().get(); diff --git a/shared/yeast/tests/test.rs b/shared/yeast/tests/test.rs index 6ba2cd3e9f08..ed4202493a46 100644 --- a/shared/yeast/tests/test.rs +++ b/shared/yeast/tests/test.rs @@ -182,7 +182,7 @@ fn test_query_repeated_capture() { // Match against the assignment node (first named child of program) let mut cursor = AstCursor::new(&ast); cursor.goto_first_child(); - let assignment_id = cursor.node().id(); + let assignment_id = cursor.node_id(); let mut captures = yeast::captures::Captures::new(); let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap(); @@ -206,7 +206,7 @@ fn test_capture_unnamed_node_parenthesized() { let mut cursor = AstCursor::new(&ast); cursor.goto_first_child(); - let assignment_id = cursor.node().id(); + let assignment_id = cursor.node_id(); let mut captures = yeast::captures::Captures::new(); let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap(); @@ -233,7 +233,7 @@ fn test_capture_unnamed_node_bare_literal() { let mut cursor = AstCursor::new(&ast); cursor.goto_first_child(); - let assignment_id = cursor.node().id(); + let assignment_id = cursor.node_id(); let mut captures = yeast::captures::Captures::new(); let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap(); @@ -254,7 +254,7 @@ fn test_bare_underscore_matches_unnamed() { let mut cursor = AstCursor::new(&ast); cursor.goto_first_child(); - let assignment_id = cursor.node().id(); + let assignment_id = cursor.node_id(); // `(_)` skips unnamed children, so a query containing a single `(_)` // bare pattern fails to match the assignment (whose only unfielded @@ -293,7 +293,7 @@ fn test_bare_forms_in_field_position() { let mut cursor = AstCursor::new(&ast); cursor.goto_first_child(); - let assignment_id = cursor.node().id(); + let assignment_id = cursor.node_id(); // Bare `_` in field position. Captures the named `identifier "x"` // child of the `left` field — bare `_` admits unnamed too, but the @@ -337,7 +337,7 @@ fn test_forward_scan_finds_unnamed_token_late() { while cursor.node().kind() != "do" || !cursor.node().is_named() { assert!(cursor.goto_next_sibling(), "expected to find named `do`"); } - let do_id = cursor.node().id(); + let do_id = cursor.node_id(); let query = yeast::query!((do ("end") @kw)); let mut captures = yeast::captures::Captures::new(); @@ -363,7 +363,7 @@ fn test_forward_scan_preserves_order() { while cursor.node().kind() != "do" || !cursor.node().is_named() { assert!(cursor.goto_next_sibling(), "expected to find named `do`"); } - let do_id = cursor.node().id(); + let do_id = cursor.node_id(); let query = yeast::query!((do ("end") @first ("do") @second)); let mut captures = yeast::captures::Captures::new();