Skip to content

Instantly share code, notes, and snippets.

@tung
Last active September 14, 2025 05:25
Show Gist options
  • Select an option

  • Save tung/10d23faf34e853bb451c009b66317203 to your computer and use it in GitHub Desktop.

Select an option

Save tung/10d23faf34e853bb451c009b66317203 to your computer and use it in GitHub Desktop.
Math expression calculator in Rust, using precedence climbing for operator precedence.
// Math expression calculator in Rust, using precedence climbing for operator precedence.
//
// rustc -C opt-level=z -C lto=fat -C strip=symbols -C panic=abort calc.rs
/// A binary arithmetic operator understood by the calculator.
enum Operator {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Subtract,
    /// Multiplication, written `*`.
    Multiply,
    /// Division, written `/`.
    Divide,
    /// Remainder, written `%`.
    Modulo,
    /// Exponentiation, written `^`.
    Power,
}

impl Operator {
    /// Reports whether a chain of this operator groups from the left.
    ///
    /// `Power` is the only operator that groups from the right.
    fn left_associative(&self) -> bool {
        match *self {
            Self::Power => false,
            Self::Add | Self::Subtract | Self::Multiply | Self::Divide | Self::Modulo => true,
        }
    }

    /// Binding strength of the operator; a larger value binds tighter.
    fn precedence(&self) -> u8 {
        match *self {
            Self::Power => 3,
            Self::Multiply | Self::Divide | Self::Modulo => 2,
            Self::Add | Self::Subtract => 1,
        }
    }

    /// The largest value `precedence` returns for any operator.
    fn max_precedence() -> u8 {
        // `Power` carries the highest precedence of all variants.
        Self::Power.precedence()
    }

    /// Applies the operator to `left` and `right` using `f64` arithmetic.
    fn compute(&self, left: f64, right: f64) -> f64 {
        let (lhs, rhs) = (left, right);
        match *self {
            Self::Power => lhs.powf(rhs),
            Self::Modulo => lhs % rhs,
            Self::Divide => lhs / rhs,
            Self::Multiply => lhs * rhs,
            Self::Subtract => lhs - rhs,
            Self::Add => lhs + rhs,
        }
    }
}
/// Reasons an expression can be rejected.
#[derive(Debug, Eq, PartialEq)]
enum ParserError {
/// `sqrt` was not followed by `(`.
ExpectedOpenParen,
/// A parenthesised expression or a `sqrt(` argument was not closed with `)`.
ExpectedCloseParen,
/// A value was required but no `(`, unary `-`, `sqrt` or number was found.
ExpectedExpression,
/// Input was left over after a complete expression had been parsed.
TrailingInput,
}
/// Cursor over the bytes of an expression that evaluates the expression as it
/// parses.
///
/// `rest` is the input that has not been consumed yet and `bytes_read` counts
/// the bytes consumed so far, so after a failure it is the offset at which
/// parsing stopped.
#[derive(Clone, Debug)]
struct Parser<'a> {
    bytes_read: usize,
    rest: &'a [u8],
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the start of `input`.
    fn new(input: &'a str) -> Self {
        Self {
            bytes_read: 0,
            rest: input.as_bytes(),
        }
    }

    /// Skips the run of space characters (`' '` only) at the cursor.
    fn whitespace(&mut self) {
        let spaces = self.rest.iter().take_while(|&&b| b == b' ').count();
        self.bytes_read += spaces;
        self.rest = &self.rest[spaces..];
    }

    /// Skips leading spaces, then consumes `s` if the input continues with it.
    ///
    /// Returns `true` when `s` was consumed. When `s` starts with an ASCII
    /// letter it only matches if the byte following it is not an ASCII letter,
    /// so `sqrt` does not match the front of `sqrtx`. An empty `s` always
    /// matches and consumes nothing.
    ///
    /// Leading spaces stay consumed even when `false` is returned; `atom`
    /// relies on this because `number` does not skip spaces itself.
    fn literal(&mut self, s: &str) -> bool {
        self.whitespace();
        let Some(after) = self.rest.strip_prefix(s.as_bytes()) else {
            return false;
        };
        // `first()` rather than `[0]`: indexing panicked for an empty `s`
        // whenever unconsumed input remained.
        let word_like = s.as_bytes().first().map_or(false, u8::is_ascii_alphabetic);
        let runs_on = after.first().map_or(false, u8::is_ascii_alphabetic);
        if word_like && runs_on {
            return false;
        }
        self.bytes_read += s.len();
        self.rest = after;
        true
    }

    /// Consumes an unsigned decimal number at the cursor: a run of digits,
    /// optionally followed by `.` and at least one more digit.
    ///
    /// A `.` that is not followed by a digit is left unconsumed. Returns
    /// `None` and consumes nothing when no number starts at the cursor. Does
    /// not skip leading spaces.
    fn number(&mut self) -> Option<f64> {
        /// Length of the run of ASCII digits at the start of `bytes`.
        fn digit_run(bytes: &[u8]) -> usize {
            bytes.iter().take_while(|b| b.is_ascii_digit()).count()
        }

        let mut end = digit_run(self.rest);
        if self.rest.get(end) == Some(&b'.') {
            let fraction = digit_run(&self.rest[end + 1..]);
            // Only take the point when a fractional digit follows it.
            if fraction > 0 {
                end += 1 + fraction;
            }
        }
        let value = std::str::from_utf8(&self.rest[..end])
            .ok()?
            .parse::<f64>()
            .ok()?;
        self.bytes_read += end;
        self.rest = &self.rest[end..];
        Some(value)
    }

    /// Consumes the next binary operator if its precedence is at least
    /// `min_precedence`.
    ///
    /// Returns `None` and leaves the parser untouched when there is no
    /// operator at the cursor or when it binds too loosely.
    fn operator(&mut self, min_precedence: u8) -> Option<Operator> {
        // Work on a copy so nothing is consumed if the operator is rejected.
        let mut lookahead = self.clone();
        let op = if lookahead.literal("+") {
            Operator::Add
        } else if lookahead.literal("-") {
            Operator::Subtract
        } else if lookahead.literal("*") {
            Operator::Multiply
        } else if lookahead.literal("/") {
            Operator::Divide
        } else if lookahead.literal("%") {
            Operator::Modulo
        } else if lookahead.literal("^") {
            Operator::Power
        } else {
            return None;
        };
        if op.precedence() < min_precedence {
            return None;
        }
        *self = lookahead;
        Some(op)
    }

    /// Parses the parenthesised argument that follows the `sqrt` keyword and
    /// returns its square root.
    ///
    /// # Errors
    ///
    /// `ExpectedOpenParen` / `ExpectedCloseParen` when the respective
    /// parenthesis is missing, or any error from the argument expression.
    fn sqrt(&mut self) -> Result<f64, ParserError> {
        if !self.literal("(") {
            return Err(ParserError::ExpectedOpenParen);
        }
        let radicand = self.expr(0)?;
        if !self.literal(")") {
            return Err(ParserError::ExpectedCloseParen);
        }
        Ok(radicand.sqrt())
    }

    /// Parses one operand: a parenthesised expression, a negated operand, a
    /// `sqrt(...)` call or a number, tried in that order.
    ///
    /// # Errors
    ///
    /// `ExpectedExpression` when none of those forms is present, otherwise
    /// whatever the nested parse reports.
    fn atom(&mut self) -> Result<f64, ParserError> {
        if self.literal("(") {
            let inner = self.expr(0)?;
            return if self.literal(")") {
                Ok(inner)
            } else {
                Err(ParserError::ExpectedCloseParen)
            };
        }
        if self.literal("-") {
            // A minimum above every operator's precedence means the negation
            // applies to the next atom only.
            let operand = self.expr(Operator::max_precedence() + 1)?;
            return Ok(-operand);
        }
        if self.literal("sqrt") {
            return self.sqrt();
        }
        self.number().ok_or(ParserError::ExpectedExpression)
    }

    /// Parses and evaluates an expression made of operators whose precedence
    /// is at least `min_precedence`.
    ///
    /// Stops, without error, in front of the first operator that binds more
    /// loosely than `min_precedence` or at anything that is not an operator.
    ///
    /// # Errors
    ///
    /// Propagates the first error reported while parsing an operand.
    fn expr(&mut self, min_precedence: u8) -> Result<f64, ParserError> {
        let mut acc = self.atom()?;
        while let Some(op) = self.operator(min_precedence) {
            // A left-associative operator must not absorb another operator of
            // the same precedence on its right; a right-associative one may.
            let rhs_min_precedence = if op.left_associative() {
                op.precedence() + 1
            } else {
                op.precedence()
            };
            let rhs = self.expr(rhs_min_precedence)?;
            acc = op.compute(acc, rhs);
        }
        Ok(acc)
    }
}
/// Evaluates the single expression given on the command line.
///
/// Prints the value to stdout on success. On a parse failure the expression is
/// echoed to stderr followed by a `^` marker under the offset where parsing
/// stopped, and the returned error names the problem. A missing or surplus
/// argument prints a usage line and returns an error.
fn main() -> Result<(), &'static str> {
    let mut args = std::env::args();
    let program = args
        .next()
        .expect("first argument should be the program name");
    // Exactly one expression argument is accepted.
    let input = match (args.next(), args.next()) {
        (Some(expression), None) => expression,
        (first, _) => {
            eprintln!("Usage: {} [expression]", program);
            return Err(if first.is_none() {
                "missing argument"
            } else {
                "too many arguments"
            });
        }
    };

    let mut parser = Parser::new(&input);
    let outcome = parser.expr(0).and_then(|value| {
        // Anything other than spaces after the expression is an error.
        parser.whitespace();
        if parser.rest.is_empty() {
            Ok(value)
        } else {
            Err(ParserError::TrailingInput)
        }
    });

    let failure = match outcome {
        Ok(value) => {
            println!("{}", value);
            return Ok(());
        }
        Err(failure) => failure,
    };

    eprintln!("{}", input);
    // Pad with one space per consumed byte so the caret sits at the failure offset.
    eprint!("{:width$}^ ", "", width = parser.bytes_read);
    Err(match failure {
        ParserError::ExpectedOpenParen => "expected open parenthesis",
        ParserError::ExpectedCloseParen => "expected close parenthesis",
        ParserError::ExpectedExpression => "expected expression",
        ParserError::TrailingInput => "trailing input",
    })
}
/// `Parser::whitespace` consumes exactly the leading run of spaces.
#[test]
fn parser_whitespace() {
    // (input, bytes consumed, remaining input)
    let cases = [
        ("", 0, ""),
        ("1", 0, "1"),
        (" ", 1, ""),
        (" 1", 1, "1"),
        // Two spaces: a single-space input cannot yield `bytes_read == 2`.
        ("  ", 2, ""),
    ];
    for &(input, consumed, remaining) in cases.iter() {
        let mut p = Parser::new(input);
        p.whitespace();
        assert_eq!(p.bytes_read, consumed, "input: {:?}", input);
        assert_eq!(p.rest, remaining.as_bytes(), "input: {:?}", input);
    }
}
/// `Parser::literal` matches after leading spaces and refuses to split a word.
#[test]
fn parser_literal() {
    // (input, literal, matched, bytes consumed, remaining input)
    let cases = [
        ("", "", true, 0, ""),
        ("(", "(", true, 1, ""),
        ("((", "(", true, 1, "("),
        ("aaa", "aaa", true, 3, ""),
        ("aaaa", "aaa", false, 0, "aaaa"),
        (" aaa", "aaa", true, 4, ""),
        ("aaa", "bbb", false, 0, "aaa"),
        ("aaa bbb", "aaa", true, 3, " bbb"),
        ("aaabbb", "aaa", false, 0, "aaabbb"),
    ];
    for &(input, literal, matched, consumed, remaining) in cases.iter() {
        let mut p = Parser::new(input);
        assert_eq!(
            p.literal(literal),
            matched,
            "literal {:?} on {:?}",
            literal,
            input
        );
        assert_eq!(p.bytes_read, consumed, "literal {:?} on {:?}", literal, input);
        assert_eq!(
            p.rest,
            remaining.as_bytes(),
            "literal {:?} on {:?}",
            literal,
            input
        );
    }
}
/// `Parser::number` reads digits with an optional fraction and leaves a
/// dangling `.` unconsumed.
#[test]
fn parser_number() {
    // (input, parsed value, bytes consumed, remaining input)
    let cases = [
        ("", None, 0, ""),
        ("123", Some(123.0), 3, ""),
        ("123.", Some(123.0), 3, "."),
        ("123.456", Some(123.456), 7, ""),
        ("123.abc", Some(123.0), 3, ".abc"),
    ];
    for &(input, value, consumed, remaining) in cases.iter() {
        let mut p = Parser::new(input);
        assert_eq!(p.number(), value, "input: {:?}", input);
        assert_eq!(p.bytes_read, consumed, "input: {:?}", input);
        assert_eq!(p.rest, remaining.as_bytes(), "input: {:?}", input);
    }
}
/// `Parser::atom` accepts a number or a parenthesised number and rejects
/// empty input.
#[test]
fn parser_atom() {
    /// Runs `atom` on `input` and checks the result and the cursor afterwards.
    fn check(input: &str, expected: Result<f64, ParserError>, consumed: usize, remaining: &str) {
        let mut p = Parser::new(input);
        assert_eq!(p.atom(), expected, "input: {:?}", input);
        assert_eq!(p.bytes_read, consumed, "input: {:?}", input);
        assert_eq!(p.rest, remaining.as_bytes(), "input: {:?}", input);
    }

    check("", Err(ParserError::ExpectedExpression), 0, "");
    check("123.456", Ok(123.456), 7, "");
    check("(123.456)", Ok(123.456), 9, "");
}
/// Empty input is rejected because an expression is required.
#[test]
fn expr_empty() {
    let outcome = Parser::new("").expr(0);
    assert_eq!(outcome, Err(ParserError::ExpectedExpression));
}
/// A lone number evaluates to itself.
#[test]
fn expr_solo() {
    for &(input, expected) in [("1", 1.0), ("0.2", 0.2)].iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// Each binary operator applied once.
#[test]
fn expr_ops() {
    let cases = [
        ("12+3", 15.0),
        ("12-3", 9.0),
        ("12*3", 36.0),
        ("12/3", 4.0),
        ("12%5", 2.0),
        ("12^3", 1728.0),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// Mixed operators follow their precedence levels.
#[test]
fn expr_precedence() {
    let cases = [
        ("5+4*3/2-1", 10.0),
        ("4+3*6/9-1", 5.0),
        ("5+2*6/4-7", 1.0),
        // Unary minus applies to the 9 before `^` does, giving (-9)^2.
        ("-9^2", 81.0),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// Parentheses group sub-expressions and may nest.
#[test]
fn expr_paren() {
    let cases = [("(1)", 1.0), ("(((1)))", 1.0), ("(2+3)*(5+7)", 60.0)];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// Unary minus can be stacked, spaced out and combined with binary minus.
#[test]
fn expr_prefix() {
    let cases = [
        ("-1", -1.0),
        ("--1", 1.0),
        ("---1", -1.0),
        ("- 1", -1.0),
        ("(-1)", -1.0),
        ("-(1)", -1.0),
        ("-(-1)", 1.0),
        ("1--1", 2.0),
        ("1---1", 0.0),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// Exponentiation with small integer bases and exponents.
#[test]
fn expr_pow() {
    let cases = [
        ("1 ^ 1", 1.0),
        ("2^2", 4.0),
        ("2^3", 8.0),
        ("3^2", 9.0),
        ("3^3", 27.0),
        ("4^0", 1.0),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
/// `sqrt(...)` of perfect squares, including a space before the parenthesis.
#[test]
fn expr_sqrt() {
    let cases = [
        ("sqrt(1)", 1.0),
        ("sqrt(4)", 2.0),
        ("sqrt(9)", 3.0),
        ("sqrt(16)", 4.0),
        ("sqrt(25)", 5.0),
        ("sqrt (1)", 1.0),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Parser::new(input).expr(0), Ok(expected), "input: {:?}", input);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment