Last active
September 14, 2025 05:25
-
-
Save tung/10d23faf34e853bb451c009b66317203 to your computer and use it in GitHub Desktop.
Math expression calculator in Rust, using precedence climbing for operator precedence.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Math expression calculator in Rust, using precedence climbing for operator precedence. | |
| // | |
| // rustc -C opt-level=z -C lto=fat -C strip=symbols -C panic=abort calc.rs | |
/// A binary operator recognised by the calculator.
///
/// It is a field-less tag, so it derives `Copy`. It also derives
/// `Debug`/`PartialEq`/`Eq` so operators can be printed and compared, for
/// example in `assert_eq!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Operator {
    /// `+`
    Add,
    /// `-` (binary subtraction)
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
    /// `^`
    Power,
}

impl Operator {
    /// Returns `true` if a chain of this operator groups from the left, so that
    /// `8 - 4 - 2` means `(8 - 4) - 2`. Only `^` groups from the right.
    fn left_associative(&self) -> bool {
        match self {
            Self::Add | Self::Subtract | Self::Multiply | Self::Divide | Self::Modulo => true,
            Self::Power => false,
        }
    }

    /// Binding strength, where higher binds tighter: `+ -` = 1, `* / %` = 2, `^` = 3.
    fn precedence(&self) -> u8 {
        match self {
            Self::Add | Self::Subtract => 1,
            Self::Multiply | Self::Divide | Self::Modulo => 2,
            Self::Power => 3,
        }
    }

    /// The largest value returned by [`Operator::precedence`].
    ///
    /// A minimum precedence above this value admits no binary operator at all.
    fn max_precedence() -> u8 {
        Self::Power.precedence()
    }

    /// Applies the operator to `left` and `right` with IEEE-754 `f64` arithmetic.
    ///
    /// Division by zero yields an infinity or NaN rather than an error. `%` is the
    /// truncated remainder, so the result has the sign of `left` (`-7 % 3 == -1`).
    fn compute(&self, left: f64, right: f64) -> f64 {
        match self {
            Self::Add => left + right,
            Self::Subtract => left - right,
            Self::Multiply => left * right,
            Self::Divide => left / right,
            Self::Modulo => left % right,
            Self::Power => left.powf(right),
        }
    }
}
/// The ways parsing an expression can fail; `main` turns each one into a message.
#[derive(PartialEq, Eq, Debug)]
enum ParserError {
    /// `sqrt` was not followed by `(`.
    ExpectedOpenParen,
    /// A parenthesised group was not closed with `)`.
    ExpectedCloseParen,
    /// No number, `(`, `-` or `sqrt` was found where an operand was required.
    ExpectedExpression,
    /// A complete expression was parsed but unconsumed input remains.
    TrailingInput,
}
| #[derive(Clone, Debug)] | |
| struct Parser<'a> { | |
| bytes_read: usize, | |
| rest: &'a [u8], | |
| } | |
| impl<'a> Parser<'a> { | |
| fn new(input: &'a str) -> Self { | |
| Self { | |
| bytes_read: 0, | |
| rest: input.as_bytes(), | |
| } | |
| } | |
| fn whitespace(&mut self) { | |
| while !self.rest.is_empty() && self.rest[0] == b' ' { | |
| self.bytes_read += 1; | |
| self.rest = &self.rest[1..]; | |
| } | |
| } | |
| fn literal(&mut self, s: &str) -> bool { | |
| self.whitespace(); | |
| if self.rest.len() < s.len() { | |
| false | |
| } else if self.rest.starts_with(s.as_bytes()) { | |
| if self.rest.len() == s.len() { | |
| self.bytes_read += s.len(); | |
| self.rest = &self.rest[s.len()..]; | |
| true | |
| } else { | |
| let s_0 = s.as_bytes()[0]; | |
| let s_starts_with_alpha = | |
| (s_0 >= b'a' && s_0 <= b'z') || (s_0 >= b'A' && s_0 <= b'Z'); | |
| let c = self.rest[s.len()]; | |
| let trailing_alpha = (c >= b'a' && c <= b'z') || (c >= b'A' && c <= b'Z'); | |
| if !s_starts_with_alpha || !trailing_alpha { | |
| self.bytes_read += s.len(); | |
| self.rest = &self.rest[s.len()..]; | |
| true | |
| } else { | |
| false | |
| } | |
| } | |
| } else { | |
| false | |
| } | |
| } | |
| fn number(&mut self) -> Option<f64> { | |
| let mut end: usize = 0; | |
| while let Some(c) = self.rest.get(end) { | |
| if *c < b'0' || *c > b'9' { | |
| break; | |
| } | |
| end += 1; | |
| } | |
| match (self.rest.get(end), self.rest.get(end + 1)) { | |
| (Some(b'.'), Some(c)) if *c >= b'0' && *c <= b'9' => end += 2, | |
| _ => { | |
| let result = std::str::from_utf8(&self.rest[0..end]) | |
| .ok()? | |
| .parse::<f64>() | |
| .ok()?; | |
| self.bytes_read += end; | |
| self.rest = &self.rest[end..]; | |
| return Some(result); | |
| } | |
| } | |
| while let Some(c) = self.rest.get(end) { | |
| if *c < b'0' || *c > b'9' { | |
| break; | |
| } | |
| end += 1; | |
| } | |
| let result = std::str::from_utf8(&self.rest[0..end]) | |
| .ok()? | |
| .parse::<f64>() | |
| .ok()?; | |
| self.bytes_read += end; | |
| self.rest = &self.rest[end..]; | |
| return Some(result); | |
| } | |
| fn operator(&mut self, min_precedence: u8) -> Option<Operator> { | |
| let mut op_parser = self.clone(); | |
| let op = if op_parser.literal("+") { | |
| Operator::Add | |
| } else if op_parser.literal("-") { | |
| Operator::Subtract | |
| } else if op_parser.literal("*") { | |
| Operator::Multiply | |
| } else if op_parser.literal("/") { | |
| Operator::Divide | |
| } else if op_parser.literal("%") { | |
| Operator::Modulo | |
| } else if op_parser.literal("^") { | |
| Operator::Power | |
| } else { | |
| return None; | |
| }; | |
| if op.precedence() < min_precedence { | |
| None | |
| } else { | |
| self.bytes_read = op_parser.bytes_read; | |
| self.rest = op_parser.rest; | |
| Some(op) | |
| } | |
| } | |
| fn sqrt(&mut self) -> Result<f64, ParserError> { | |
| if !self.literal("(") { | |
| return Err(ParserError::ExpectedOpenParen); | |
| } | |
| let x = self.expr(0)?; | |
| if !self.literal(")") { | |
| return Err(ParserError::ExpectedCloseParen); | |
| } | |
| Ok(x.sqrt()) | |
| } | |
| fn atom(&mut self) -> Result<f64, ParserError> { | |
| if self.literal("(") { | |
| let x = self.expr(0)?; | |
| if !self.literal(")") { | |
| return Err(ParserError::ExpectedCloseParen); | |
| } | |
| Ok(x) | |
| } else if self.literal("-") { | |
| self.expr(Operator::max_precedence() + 1).map(|n| -n) | |
| } else if self.literal("sqrt") { | |
| self.sqrt() | |
| } else if let Some(n) = self.number() { | |
| Ok(n) | |
| } else { | |
| Err(ParserError::ExpectedExpression) | |
| } | |
| } | |
| fn expr(&mut self, min_precedence: u8) -> Result<f64, ParserError> { | |
| let mut result = self.atom()?; | |
| loop { | |
| let Some(op) = self.operator(min_precedence) else { | |
| break; | |
| }; | |
| let next_min_precedence = if op.left_associative() { | |
| op.precedence() + 1 | |
| } else { | |
| op.precedence() | |
| }; | |
| let rhs = self.expr(next_min_precedence)?; | |
| result = op.compute(result, rhs); | |
| } | |
| Ok(result) | |
| } | |
| } | |
| fn main() -> Result<(), &'static str> { | |
| let mut args = std::env::args(); | |
| let program = args | |
| .next() | |
| .expect("first argument should be the program name"); | |
| let Some(input) = args.next() else { | |
| eprintln!("Usage: {} [expression]", program); | |
| return Err("missing argument"); | |
| }; | |
| if args.next().is_some() { | |
| eprintln!("Usage: {} [expression]", program); | |
| return Err("too many arguments"); | |
| } | |
| let mut parser = Parser::new(&input); | |
| let mut result = parser.expr(0); | |
| if result.is_ok() { | |
| parser.whitespace(); | |
| if parser.rest != b"" { | |
| result = Err(ParserError::TrailingInput); | |
| } | |
| } | |
| match result { | |
| Ok(v) => println!("{}", v), | |
| Err(e) => { | |
| eprintln!("{}", input); | |
| for _ in 0..parser.bytes_read { | |
| eprint!(" "); | |
| } | |
| eprint!("^ "); | |
| return Err(match e { | |
| ParserError::ExpectedOpenParen => "expected open parenthesis", | |
| ParserError::ExpectedCloseParen => "expected close parenthesis", | |
| ParserError::ExpectedExpression => "expected expression", | |
| ParserError::TrailingInput => "trailing input", | |
| }); | |
| } | |
| } | |
| Ok(()) | |
| } | |
/// `whitespace` consumes exactly the leading spaces and nothing else.
#[test]
fn parser_whitespace() {
    // (input, bytes `whitespace` should consume, remainder left over)
    let cases: [(&str, usize, &[u8]); 5] = [
        ("", 0, b""),
        ("1", 0, b"1"),
        (" ", 1, b""),
        (" 1", 1, b"1"),
        // Two spaces: a single-space input cannot consume the expected 2 bytes.
        ("  ", 2, b""),
    ];
    for (input, consumed, remainder) in cases {
        let mut p = Parser::new(input);
        p.whitespace();
        assert_eq!(p.bytes_read, consumed, "input {:?}", input);
        assert_eq!(p.rest, remainder, "input {:?}", input);
    }
}
/// `literal` skips leading spaces, matches exact text, and rejects a keyword
/// that is glued to further letters.
#[test]
fn parser_literal() {
    // (input, literal, should match, bytes consumed, remainder)
    let cases: [(&str, &str, bool, usize, &[u8]); 9] = [
        ("", "", true, 0, b""),
        ("(", "(", true, 1, b""),
        ("((", "(", true, 1, b"("),
        ("aaa", "aaa", true, 3, b""),
        ("aaaa", "aaa", false, 0, b"aaaa"),
        (" aaa", "aaa", true, 4, b""),
        ("aaa", "bbb", false, 0, b"aaa"),
        ("aaa bbb", "aaa", true, 3, b" bbb"),
        ("aaabbb", "aaa", false, 0, b"aaabbb"),
    ];
    for (input, lit, matched, consumed, remainder) in cases {
        let mut p = Parser::new(input);
        assert_eq!(p.literal(lit), matched, "literal {:?} in {:?}", lit, input);
        assert_eq!(p.bytes_read, consumed, "literal {:?} in {:?}", lit, input);
        assert_eq!(p.rest, remainder, "literal {:?} in {:?}", lit, input);
    }
}
/// `number` reads an integer or decimal prefix and leaves a dangling `.` unread.
#[test]
fn parser_number() {
    let table: [(&str, Option<f64>, usize, &[u8]); 5] = [
        ("", None, 0, b""),
        ("123", Some(123.0), 3, b""),
        ("123.", Some(123.0), 3, b"."),
        ("123.456", Some(123.456), 7, b""),
        ("123.abc", Some(123.0), 3, b".abc"),
    ];
    for (text, expected, consumed, left) in table {
        let mut parser = Parser::new(text);
        assert_eq!(parser.number(), expected, "number in {:?}", text);
        assert_eq!(parser.bytes_read, consumed, "number in {:?}", text);
        assert_eq!(parser.rest, left, "number in {:?}", text);
    }
}
/// `atom` parses a bare number or a parenthesised one and reports empty input.
#[test]
fn parser_atom() {
    fn check(src: &str, want: Result<f64, ParserError>, consumed: usize, left: &[u8]) {
        let mut p = Parser::new(src);
        assert_eq!(p.atom(), want, "atom in {:?}", src);
        assert_eq!(p.bytes_read, consumed, "atom in {:?}", src);
        assert_eq!(p.rest, left, "atom in {:?}", src);
    }
    check("", Err(ParserError::ExpectedExpression), 0, b"");
    check("123.456", Ok(123.456), 7, b"");
    check("(123.456)", Ok(123.456), 9, b"");
}
/// An empty input is not an expression.
#[test]
fn expr_empty() {
    let parsed = Parser::new("").expr(0);
    assert_eq!(parsed, Err(ParserError::ExpectedExpression));
}
/// A lone number evaluates to itself.
#[test]
fn expr_solo() {
    for (src, value) in [("1", 1.0), ("0.2", 0.2)] {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// Each binary operator applied to a single pair of operands.
#[test]
fn expr_ops() {
    let table = [
        ("12+3", 15.0),
        ("12-3", 9.0),
        ("12*3", 36.0),
        ("12/3", 4.0),
        ("12%5", 2.0),
        ("12^3", 1728.0),
    ];
    for (src, value) in table {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// Mixed operators respect precedence; unary minus binds tighter than `^`.
#[test]
fn expr_precedence() {
    let table = [
        ("5+4*3/2-1", 10.0),
        ("4+3*6/9-1", 5.0),
        ("5+2*6/4-7", 1.0),
        ("-9^2", 81.0),
    ];
    for (src, value) in table {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// Parentheses group and may be nested.
#[test]
fn expr_paren() {
    let table = [("(1)", 1.0), ("(((1)))", 1.0), ("(2+3)*(5+7)", 60.0)];
    for (src, value) in table {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// Unary minus stacks, tolerates spaces, and mixes with binary minus.
#[test]
fn expr_prefix() {
    let table = [
        ("-1", -1.0),
        ("--1", 1.0),
        ("---1", -1.0),
        ("- 1", -1.0),
        ("(-1)", -1.0),
        ("-(1)", -1.0),
        ("-(-1)", 1.0),
        ("1--1", 2.0),
        ("1---1", 0.0),
    ];
    for (src, value) in table {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// Exponentiation on small integer operands, with and without spaces.
#[test]
fn expr_pow() {
    let table = [
        ("1 ^ 1", 1.0),
        ("2^2", 4.0),
        ("2^3", 8.0),
        ("3^2", 9.0),
        ("3^3", 27.0),
        ("4^0", 1.0),
    ];
    for (src, value) in table {
        assert_eq!(Parser::new(src).expr(0), Ok(value), "expression {:?}", src);
    }
}
/// `sqrt(n*n)` evaluates to `n`, and a space before `(` is allowed.
#[test]
fn expr_sqrt() {
    for root in 1..=5u32 {
        let src = format!("sqrt({})", root * root);
        assert_eq!(Parser::new(&src).expr(0), Ok(root as f64), "expression {:?}", src);
    }
    assert_eq!(Parser::new("sqrt (1)").expr(0), Ok(1.0));
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment