From ce036d1b3362a392ea38043423378b3743bde39e Mon Sep 17 00:00:00 2001 From: Konstantin Anisimov Date: Sun, 23 Oct 2022 21:23:45 +0600 Subject: [PATCH] fix listener codegen when labeled alts are involved --- src/tree.rs | 10 ++++++++-- templates/Rust.stg | 6 +++++- tests/gen/labelsparser.rs | 28 ++++++++++++++++++++++++++++ tests/general_tests.rs | 29 +++++++++++++++++++++++++---- 4 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index fce1359..b568067 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -436,11 +436,17 @@ where T: ParseTreeListener<'input, Node> + 'a + ?Sized, Node::Type: Listenable, { + pub fn walk_mut(mut listener: &mut Listener, t: &Ctx) + where + Listener: CoerceTo, + Ctx: CoerceTo, + { + Self::walk_inner(listener.coerce_mut_to(), t.coerce_ref_to()); + } + /// Walks recursively over tree `t` with `listener` pub fn walk(mut listener: Box, t: &Ctx) -> Box where - // for<'x> &'x mut Listener: CoerceUnsized<&'x mut T>, - // for<'x> &'x Ctx: CoerceUnsized<&'x Node::Type>, Listener: CoerceTo, Ctx: CoerceTo, { diff --git a/templates/Rust.stg b/templates/Rust.stg index e945963..cc91b89 100644 --- a/templates/Rust.stg +++ b/templates/Rust.stg @@ -1095,7 +1095,11 @@ impl\<'input> Context\<'input> for \<'input>{} impl\<'input,'a> Listenable\Listener\<'input> + 'a> for \<'input>{ - + + + + + } diff --git a/tests/gen/labelsparser.rs b/tests/gen/labelsparser.rs index 2f93a7a..ccfb26d 100644 --- a/tests/gen/labelsparser.rs +++ b/tests/gen/labelsparser.rs @@ -504,6 +504,10 @@ impl<'input, 'a> Listenable + 'a> for AddContext<'inp listener.enter_every_rule(self); listener.enter_add(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_add(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for AddContextExt<'input> { @@ -570,6 +574,10 @@ impl<'input, 'a> Listenable + 'a> for ParensContext<' listener.enter_every_rule(self); listener.enter_parens(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_parens(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for ParensContextExt<'input> { @@ -645,6 +653,10 @@ impl<'input, 'a> Listenable + 'a> for MultContext<'in listener.enter_every_rule(self); listener.enter_mult(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_mult(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for MultContextExt<'input> { @@ -712,6 +724,10 @@ impl<'input, 'a> Listenable + 'a> for DecContext<'inp listener.enter_every_rule(self); listener.enter_dec(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_dec(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for DecContextExt<'input> { @@ -779,6 +795,10 @@ impl<'input, 'a> Listenable + 'a> for AnIDContext<'in listener.enter_every_rule(self); listener.enter_anID(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_anID(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for AnIDContextExt<'input> { @@ -846,6 +866,10 @@ impl<'input, 'a> Listenable + 'a> for AnIntContext<'i listener.enter_every_rule(self); listener.enter_anInt(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_anInt(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for AnIntContextExt<'input> { @@ -911,6 +935,10 @@ impl<'input, 'a> Listenable + 'a> for IncContext<'inp listener.enter_every_rule(self); listener.enter_inc(self); } + fn exit(&self, listener: &mut (dyn LabelsListener<'input> + 'a)) { + listener.exit_inc(self); + listener.exit_every_rule(self); + } } impl<'input> CustomRuleContext<'input> for IncContextExt<'input> { diff --git a/tests/general_tests.rs b/tests/general_tests.rs index e841b03..4d6a21a 100644 --- a/tests/general_tests.rs +++ b/tests/general_tests.rs @@ -37,7 +37,6 @@ mod gen { }; use crate::gen::csvvisitor::CSVVisitor; use crate::gen::labelslexer::LabelsLexer; - use crate::gen::labelsparser::{EContextAll, LabelsParser}; use crate::gen::referencetoatnparser::{ ReferenceToATNParserContext, ReferenceToATNParserContextType, }; @@ -337,19 +336,41 @@ if (x < x && a > 0) then duh #[test] fn test_complex_convert() { + use labelsparser::*; + + struct Listener(String); + + impl ParseTreeListener<'_, LabelsParserContextType> for Listener {} + impl<'input> labelslistener::LabelsListener<'input> for Listener { + fn enter_add(&mut self, _ctx: &AddContext<'input>) { + self.0.push('('); + } + + fn exit_add(&mut self, ctx: &AddContext<'input>) { + self.0 += ctx.get_v(); + self.0.push(')'); + } + } + let codepoints = "(a+4)*2".chars().map(|x| x as u32).collect::>(); // let codepoints = "(a+4)*2"; let input = InputStream::new(&*codepoints); let lexer = LabelsLexer::new(input); let token_source = CommonTokenStream::new(lexer); let mut parser = LabelsParser::new(token_source); - let result = parser.s().expect("parser error"); - let string = result.q.as_ref().unwrap().get_v(); + + let root = parser.s().expect("parser error"); + let string = root.q.as_ref().unwrap().get_v(); assert_eq!("* + a 4 2", string); - let x = result.q.as_deref().unwrap(); + let x = root.q.as_deref().unwrap(); + match x { EContextAll::MultContext(x) => assert_eq!("(a+4)", x.a.as_ref().unwrap().get_text()), _ => panic!("oops"), } + + let mut listener = Listener(String::new()); + LabelsTreeWalker::walk_mut(&mut listener, &*root); + assert_eq!(listener.0, "(+ a 4)"); } }