From 94bffebd8f3802a49654ebf34a495271bcd8809e Mon Sep 17 00:00:00 2001 From: Evan Peterson <77evan@gmail.com> Date: Fri, 5 Dec 2025 23:44:46 -0500 Subject: [PATCH] significant cleanup, restructuring, query parsing --- src/core/amounts.rs | 19 +- src/core/mod.rs | 2 + src/{queries/base.rs => core/value.rs} | 128 ++------ src/document/directives.rs | 27 +- src/document/ledger.rs | 172 +++++----- src/document/mod.rs | 4 +- src/document/parser/amounts.rs | 156 ---------- src/lib.rs | 6 +- src/main.rs | 51 ++- src/parser/amount.rs | 100 ++++++ src/parser/core.rs | 143 +++++++++ .../document}/base_directive.rs | 12 +- .../parser => parser/document}/directives.rs | 37 +-- .../parser => parser/document}/mod.rs | 5 +- .../parser => parser/document}/shared.rs | 4 + .../parser => parser/document}/transaction.rs | 69 ++-- src/parser/fields.rs | 48 +++ src/parser/mod.rs | 109 +------ src/parser/query.rs | 246 +++++++++++++++ src/parser/value.rs | 47 +++ src/queries/balance.rs | 169 ---------- src/queries/functions.rs | 294 ------------------ src/queries/mod.rs | 9 - src/queries/parser/functions.rs | 142 --------- src/queries/parser/mod.rs | 1 - src/queries/postings.rs | 11 - src/query/balance.rs | 75 +++++ src/query/functions_comparison.rs | 202 ++++++++++++ src/query/functions_logical.rs | 125 ++++++++ src/query/mod.rs | 11 + src/query/query.rs | 68 ++++ src/{queries => query}/transaction.rs | 14 +- 32 files changed, 1347 insertions(+), 1159 deletions(-) rename src/{queries/base.rs => core/value.rs} (54%) delete mode 100644 src/document/parser/amounts.rs create mode 100644 src/parser/amount.rs create mode 100644 src/parser/core.rs rename src/{document/parser => parser/document}/base_directive.rs (97%) rename src/{document/parser => parser/document}/directives.rs (84%) rename src/{document/parser => parser/document}/mod.rs (94%) rename src/{document/parser => parser/document}/shared.rs (88%) rename src/{document/parser => parser/document}/transaction.rs (87%) create mode 100644 src/parser/fields.rs create mode 100644 src/parser/query.rs create mode 100644 src/parser/value.rs delete mode 100644 src/queries/balance.rs delete mode 100644 src/queries/functions.rs delete mode 100644 src/queries/mod.rs delete mode 100644 src/queries/parser/functions.rs delete mode 100644 src/queries/parser/mod.rs delete mode 100644 src/queries/postings.rs create mode 100644 src/query/balance.rs create mode 100644 src/query/functions_comparison.rs create mode 100644 src/query/functions_logical.rs create mode 100644 src/query/mod.rs create mode 100644 src/query/query.rs rename src/{queries => query}/transaction.rs (85%) diff --git a/src/core/amounts.rs b/src/core/amounts.rs index 5037fc3..eae91fb 100644 --- a/src/core/amounts.rs +++ b/src/core/amounts.rs @@ -5,6 +5,13 @@ use rust_decimal_macros::dec; use super::{common::generate_id, CoreError}; +#[derive(Debug, PartialEq, Clone)] +pub struct RawAmount { + pub value: Decimal, + pub unit_symbol: String, + pub is_unit_prefix: bool, +} + #[derive(Debug, PartialEq, Clone, Copy)] pub struct Amount { pub value: Decimal, @@ -35,10 +42,7 @@ pub struct Unit { impl Amount { pub fn at_opt_price(&self, price: Option) -> Amount { if let Some(p) = price { - Amount { - value: self.value * p.value, - unit_id: p.unit_id, - } + Amount { value: self.value * p.value, unit_id: p.unit_id } } else { *self } @@ -89,7 +93,10 @@ pub fn combine_amounts(amounts: impl Iterator) -> Vec { *output_amounts.entry(amount.unit_id).or_insert(dec!(0)) += amount.value; } - output_amounts.iter().map(|(&unit_id, &value)| Amount {value, unit_id}).collect() + output_amounts + .iter() + .map(|(&unit_id, &value)| Amount { value, unit_id }) + .collect() } impl PartialOrd for Amount { @@ -100,4 +107,4 @@ impl PartialOrd for Amount { self.value.partial_cmp(&other.value) } } -} \ No newline at end of file +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 382fc23..8f742a8 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -5,6 +5,7 @@ mod errors; mod ledger; mod price; mod transaction; +mod value; pub use account::*; pub use amounts::*; @@ -12,3 +13,4 @@ pub use errors::*; pub use ledger::*; pub use price::*; pub use transaction::*; +pub use value::*; \ No newline at end of file diff --git a/src/queries/base.rs b/src/core/value.rs similarity index 54% rename from src/queries/base.rs rename to src/core/value.rs index 809972a..717a667 100644 --- a/src/queries/base.rs +++ b/src/core/value.rs @@ -1,57 +1,29 @@ -use std::collections::HashMap; - use chrono::NaiveDate; use rust_decimal::{prelude::Zero, Decimal}; - -use crate::core::{Amount, CoreError}; +use std::collections::HashMap; #[derive(Debug, Clone)] pub enum StringData<'a> { Owned(String), - Reference(&'a str) + Reference(&'a str), } #[derive(Debug, Clone, PartialEq)] -pub enum DataValue<'a> { +pub enum DataValue { Null, Integer(u32), Decimal(Decimal), Boolean(bool), - String(StringData<'a>), + String(StringData<'static>), Date(NaiveDate), - Amount(Amount), - List(Vec>), - Map(HashMap<&'static str, DataValue<'a>>), + List(Vec), + Map(HashMap<&'static str, DataValue>), + // Amount(Amount), } -pub enum Query<'a, T> { - Field(T), - Value(DataValue<'a>), - Function(Box>), -} - -pub trait Data { - fn get_field(&self, field: &T) -> Result; -} - -pub trait Function { - fn evaluate(&self, context: &dyn Data) -> Result; -} - -// impl ConstantValue { -// pub fn to_bool(&self) -> bool { -// match self { -// ConstantValue::Integer(val) => !val.is_zero(), -// ConstantValue::Decimal(val) => !val.is_zero(), -// ConstantValue::Boolean(val) => *val, -// ConstantValue::String(val) => val.is_empty(), -// ConstantValue::Date(_) => true, -// ConstantValue::Amount(val) => !val.value.is_zero(), -// ConstantValue::List(list) => !list.is_empty(), -// ConstantValue::Map(map) => !map.is_empty(), -// } -// } -// } +///////////////////// +// Implementations // +///////////////////// impl<'a> StringData<'a> { pub fn as_ref(&'a self) -> &'a str { @@ -60,6 +32,13 @@ impl<'a> StringData<'a> { StringData::Reference(val) => val, } } + + pub fn into_owned(self) -> StringData<'static> { + match self { + StringData::Owned(s) => StringData::Owned(s), + StringData::Reference(s) => StringData::Owned(s.to_string()), + } + } } impl<'a> PartialEq for StringData<'a> { @@ -90,7 +69,7 @@ impl<'a> PartialOrd for StringData<'a> { } } -impl<'a> PartialOrd for DataValue<'a> { +impl PartialOrd for DataValue { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { (DataValue::Null, DataValue::Null) => Some(std::cmp::Ordering::Equal), @@ -99,7 +78,7 @@ impl<'a> PartialOrd for DataValue<'a> { (DataValue::Boolean(val1), DataValue::Boolean(val2)) => val1.partial_cmp(val2), (DataValue::String(val1), DataValue::String(val2)) => val1.partial_cmp(val2), (DataValue::Date(val1), DataValue::Date(val2)) => val1.partial_cmp(val2), - (DataValue::Amount(val1), DataValue::Amount(val2)) => val1.partial_cmp(val2), + // (DataValue::Amount(val1), DataValue::Amount(val2)) => val1.partial_cmp(val2), (DataValue::List(val1), DataValue::List(val2)) => val1.partial_cmp(val2), _ => None, } @@ -118,50 +97,49 @@ impl<'a> From for StringData<'a> { } } -impl<'a> From for DataValue<'a> { +impl From for DataValue { fn from(value: u32) -> Self { DataValue::Integer(value) } } -impl<'a> From for DataValue<'a> { +impl From for DataValue { fn from(value: Decimal) -> Self { DataValue::Decimal(value) } } -impl<'a> From for DataValue<'a> { +impl From for DataValue { fn from(value: bool) -> Self { DataValue::Boolean(value) } } -impl<'a> From<&'a str> for DataValue<'a> { +impl<'a> From<&'a str> for DataValue { fn from(value: &'a str) -> Self { - DataValue::String(value.into()) + DataValue::String(StringData::from(value).into_owned()) } } - -impl<'a> From for DataValue<'a> { +impl From for DataValue { fn from(value: String) -> Self { DataValue::String(value.into()) } } -impl<'a> From for DataValue<'a> { +impl From for DataValue { fn from(value: NaiveDate) -> Self { DataValue::Date(value) } } -impl<'a> From for DataValue<'a> { - fn from(value: Amount) -> Self { - DataValue::Amount(value) - } -} +// impl<'a> From for DataValue<'a> { +// fn from(value: Amount) -> Self { +// DataValue::Amount(value) +// } +// } -impl<'a> From> for bool { +impl From for bool { fn from(value: DataValue) -> Self { match value { DataValue::Null => false, @@ -170,49 +148,9 @@ impl<'a> From> for bool { DataValue::Boolean(val) => val, DataValue::String(val) => val.as_ref().is_empty(), DataValue::Date(_) => true, - DataValue::Amount(val) => !val.value.is_zero(), + // DataValue::Amount(val) => !val.value.is_zero(), DataValue::List(list) => !list.is_empty(), DataValue::Map(map) => !map.is_empty(), } } } - -impl<'a, T> Query<'a, T> { - pub fn evaluate(&self, context: &'a dyn Data) -> Result { - match self { - Query::Field(field) => context.get_field(field), - Query::Value(constant) => Ok(constant.clone()), - Query::Function(function) => function.evaluate(context), - } - } - - pub fn from_field(field: T) -> Self { - Query::Field(field) - } - - pub fn from_fn + Sized + 'static>(function: F) -> Self { - Query::Function(Box::new(function)) - } -} - -impl<'a, T> From> for Query<'a, T> { - fn from(constant: DataValue<'a>) -> Self { - Query::Value(constant) - } -} - -// impl Function for T { -// fn to_value(self) -> Value { -// Value::Function(Box::new(self)) -// } -// } - -// impl> T { - -// } - -// impl + Sized + 'static> From for Value { -// fn from(function: T) -> Self { -// Value::Function(Box::new(function)) -// } -// } diff --git a/src/document/directives.rs b/src/document/directives.rs index 3e176d1..1872bae 100644 --- a/src/document/directives.rs +++ b/src/document/directives.rs @@ -1,9 +1,8 @@ use std::path::PathBuf; use chrono::NaiveDate; -use rust_decimal::Decimal; -use crate::core::UnitSymbol; +use crate::core::{RawAmount, UnitSymbol}; #[derive(Debug)] pub struct Directives { @@ -44,7 +43,7 @@ pub struct TransactionDirective { pub struct BalanceDirective { pub date: NaiveDate, pub account: String, - pub amounts: Vec, + pub amounts: Vec, } /////////////// @@ -55,16 +54,9 @@ pub struct BalanceDirective { pub struct DirectivePosting { pub date: Option, pub account: String, - pub amount: Option, - pub cost: Option, - pub price: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct DirectiveAmount { - pub value: Decimal, - pub unit_symbol: String, - pub is_unit_prefix: bool, + pub amount: Option, + pub cost: Option, + pub price: Option, } ///////////////////// @@ -73,7 +65,12 @@ pub struct DirectiveAmount { impl Directives { pub fn new() -> Self { - Directives{includes: Vec::new(), commodities: Vec::new(), transactions: Vec::new(), balances: Vec::new()} + Directives { + includes: Vec::new(), + commodities: Vec::new(), + transactions: Vec::new(), + balances: Vec::new(), + } } pub fn add_directives(&mut self, other: &Directives) { @@ -82,4 +79,4 @@ impl Directives { self.balances.extend(other.balances.clone()); self.commodities.extend(other.commodities.clone()); } -} \ No newline at end of file +} diff --git a/src/document/ledger.rs b/src/document/ledger.rs index 801928a..c643129 100644 --- a/src/document/ledger.rs +++ b/src/document/ledger.rs @@ -4,21 +4,21 @@ use rust_decimal_macros::dec; use crate::{ core::{ - Account, Amount, CoreError, Ledger, Posting, Transaction, TransactionFlag, Unit, UnitSymbol, - }, - queries::{ - self, - base::{self, DataValue}, - functions::{ - ComparisonFunction, LogicalFunction, RegexFunction, StringComparisonFunction, - SubAccountFunction, - }, - transaction::{AccountField, PostingField, TransactionField}, - Query, - }, + Account, Amount, CoreError, DataValue, Ledger, Posting, RawAmount, Transaction, TransactionFlag, Unit, UnitSymbol + }, query::{self, AccountField, ComparisonFunction, LogicalFunction, PostingField, Query, RegexFunction, TransactionField}, + // queries::{ + // self, + // base::{self, DataValue}, + // functions::{ + // ComparisonFunction, LogicalFunction, RegexFunction, StringComparisonFunction, + // SubAccountFunction, + // }, + // transaction::{AccountField, PostingField, TransactionField}, + // Query, + // }, }; -use super::{BalanceDirective, DirectiveAmount, TransactionDirective}; +use super::{BalanceDirective, TransactionDirective}; pub fn add_transaction( ledger: &mut Ledger, @@ -75,8 +75,8 @@ pub fn add_transaction( pub fn check_balance2(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { let date_query = ComparisonFunction::new( "<=", - base::Query::from_field(PostingField::Transaction(TransactionField::Date)), - base::Query::from(DataValue::from(balance.date)), + Query::from_field(PostingField::Transaction(TransactionField::Date)), + Query::from(DataValue::from(balance.date)), ) .unwrap(); // let account_fn = |str: &str| { @@ -93,23 +93,34 @@ pub fn check_balance2(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), // base::Query::from_field(PostingField::Account(AccountField::Name)), // &account_regex, // )?; - let account_query = SubAccountFunction::new( - balance.account.clone().into(), - base::Query::from_field(PostingField::Account(AccountField::Name)), - ); + + + // let account_query = SubAccountFunction::new( + // balance.account.clone().into(), + // base::Query::from_field(PostingField::Account(AccountField::Name)), + // ); + + // TODO: is this efficient enough? + let account_query = RegexFunction::new( + Query::from_field(PostingField::Account(AccountField::Name)), + format!("^{}", balance.account).as_str(), + true + // "^" + balance.account.clone(), + ).unwrap(); + let start = Instant::now(); let total_query = LogicalFunction::new( "and", - base::Query::from_fn(date_query), - base::Query::from_fn(account_query), + Query::from_fn(date_query), + Query::from_fn(account_query), ) .unwrap(); let t2 = Instant::now(); - let accounts = queries::balance3(&ledger, &base::Query::from_fn(total_query)); + let accounts = query::balance(&ledger, Some(&Query::from_fn(total_query)), None); let t3 = Instant::now(); @@ -175,76 +186,77 @@ pub fn check_balance2(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), Ok(()) } -pub fn check_balance(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { - let accounts = queries::balance(&ledger, &[Query::EndDate(balance.date)]); - // let accounts = queries::balance2(&ledger, balance.date); +// pub fn check_balance(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { +// let accounts = query::balance(&ledger, &[Query::EndDate(balance.date)]); - let accounts = accounts.iter().filter(|(&account_id, val)| { - let account = ledger.get_account(account_id).unwrap(); - account.is_under_account(&balance.account) - }); +// // let accounts = queries::balance2(&ledger, balance.date); - if accounts.clone().count() == 0 {} +// let accounts = accounts.iter().filter(|(&account_id, val)| { +// let account = ledger.get_account(account_id).unwrap(); +// account.is_under_account(&balance.account) +// }); - let mut total_amounts = HashMap::new(); - let mut account_count = 0; +// if accounts.clone().count() == 0 {} - for (_, amounts) in accounts { - account_count += 1; - for amount in amounts { - *total_amounts.entry(amount.unit_id).or_insert(dec!(0)) += amount.value; - } - } +// let mut total_amounts = HashMap::new(); +// let mut account_count = 0; - if account_count == 0 { - return Err("No accounts match balance account".into()); - } +// for (_, amounts) in accounts { +// account_count += 1; +// for amount in amounts { +// *total_amounts.entry(amount.unit_id).or_insert(dec!(0)) += amount.value; +// } +// } - // let balance_account = ledger - // .get_account_by_name(&balance.account) - // .ok_or("Account not found")?; +// if account_count == 0 { +// return Err("No accounts match balance account".into()); +// } - // let amounts = accounts - // .get(&balance_account.get_id()) - // .map(|v| v.as_slice()) - // .unwrap_or(&[]); +// // let balance_account = ledger +// // .get_account_by_name(&balance.account) +// // .ok_or("Account not found")?; - // if amounts.len() > balance.amounts.len() { - // return Err("".into()); - // } else if amounts.len() < balance.amounts.len() { - // return Err("".into()); - // } +// // let amounts = accounts +// // .get(&balance_account.get_id()) +// // .map(|v| v.as_slice()) +// // .unwrap_or(&[]); - for balance_amount in &balance.amounts { - let unit = ledger - .get_unit_by_symbol(&balance_amount.unit_symbol) - .ok_or("Unit not found")?; - let value = total_amounts - .get(&unit.get_id()) - .map(|v| *v) - .unwrap_or(dec!(0)); +// // if amounts.len() > balance.amounts.len() { +// // return Err("".into()); +// // } else if amounts.len() < balance.amounts.len() { +// // return Err("".into()); +// // } - // let value = amounts - // .iter() - // .find(|a| a.unit_id == unit.get_id()) - // .map(|a| a.value) - // .unwrap_or(dec!(0)); - let max_scale = max(value.scale(), balance_amount.value.scale()); +// for balance_amount in &balance.amounts { +// let unit = ledger +// .get_unit_by_symbol(&balance_amount.unit_symbol) +// .ok_or("Unit not found")?; +// let value = total_amounts +// .get(&unit.get_id()) +// .map(|v| *v) +// .unwrap_or(dec!(0)); - let value = value.round_dp(max_scale); - let balance_value = balance_amount.value.round_dp(max_scale); +// // let value = amounts +// // .iter() +// // .find(|a| a.unit_id == unit.get_id()) +// // .map(|a| a.value) +// // .unwrap_or(dec!(0)); +// let max_scale = max(value.scale(), balance_amount.value.scale()); - if value != balance_value { - return Err(format!( - "Balance amount for \"{}\" on {} does not match. Expected {} but got {}", - balance.account, balance.date, balance_value, value - ) - .into()); - } - } +// let value = value.round_dp(max_scale); +// let balance_value = balance_amount.value.round_dp(max_scale); - Ok(()) -} +// if value != balance_value { +// return Err(format!( +// "Balance amount for \"{}\" on {} does not match. Expected {} but got {}", +// balance.account, balance.date, balance_value, value +// ) +// .into()); +// } +// } + +// Ok(()) +// } struct IncompletePosting { account_id: u32, @@ -253,7 +265,7 @@ struct IncompletePosting { price: Option, } -fn create_amount(ledger: &mut Ledger, amount: &DirectiveAmount) -> Result { +fn create_amount(ledger: &mut Ledger, amount: &RawAmount) -> Result { let unit_id = get_or_create_unit(ledger, &amount.unit_symbol, amount.is_unit_prefix)?; Ok(Amount { value: amount.value, unit_id }) diff --git a/src/document/mod.rs b/src/document/mod.rs index 77c598b..276b360 100644 --- a/src/document/mod.rs +++ b/src/document/mod.rs @@ -1,12 +1,10 @@ mod directives; mod ledger; -mod parser; pub use directives::*; use ledger::{add_transaction, check_balance2}; -use parser::parse_directives; -use crate::core::{CoreError, Ledger, Unit}; +use crate::{core::{CoreError, Ledger, Unit}, parser::parse_directives}; use std::{path::Path, time::Instant}; #[derive(Debug)] diff --git a/src/document/parser/amounts.rs b/src/document/parser/amounts.rs deleted file mode 100644 index d3030b3..0000000 --- a/src/document/parser/amounts.rs +++ /dev/null @@ -1,156 +0,0 @@ -use nom::{ - branch::alt, - character::complete::{char, none_of, one_of, space0}, - combinator::{opt, recognize}, - error::{Error, ErrorKind}, - multi::{many0, many1}, - sequence::{preceded, terminated, tuple}, - Err, IResult, InputTakeAtPosition, Parser, -}; -use rust_decimal::Decimal; -use rust_decimal_macros::dec; - -use crate::document::DirectiveAmount; - -pub fn account(input: &str) -> IResult<&str, &str> { - input.split_at_position1_complete(|item| item == ' ' || item == '\t', ErrorKind::AlphaNumeric) -} - -pub fn amount(input: &str) -> IResult<&str, DirectiveAmount> { - alt((suffix_amount, prefix_amount)).parse(input) -} - -pub fn decimal(input: &str) -> IResult<&str, Decimal> { - let (new_input, decimal_str) = recognize(tuple(( - opt(one_of("+-")), - opt(number_int), - opt(char('.')), - opt(number_int), - ))) - .parse(input)?; - - if decimal_str.contains(',') { - match Decimal::from_str_exact(&decimal_str.replace(",", "")) { - Ok(decimal) => Ok((new_input, decimal)), - Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), - } - } else { - match Decimal::from_str_exact(decimal_str) { - Ok(decimal) => Ok((new_input, decimal)), - Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), - } - } -} - -/////////////// -// Private // -/////////////// - -fn prefix_amount(input: &str) -> IResult<&str, DirectiveAmount> { - tuple(( - opt(one_of("+-")), - unit, - preceded(space0, decimal), - )) - .map(|(sign, unit_symbol, mut value)| { - if let Some(s) = sign { - if s == '-' { - value = value * dec!(-1); - } - } - DirectiveAmount { - value, - unit_symbol: unit_symbol.to_string(), - is_unit_prefix: true, - } - }) - .parse(input) -} - -fn suffix_amount(input: &str) -> IResult<&str, DirectiveAmount> { - tuple(( - decimal, - preceded(space0, unit), - )) - .map(|(value, unit_symbol)| DirectiveAmount { - value, - unit_symbol: unit_symbol.to_string(), - is_unit_prefix: false, - }) - .parse(input) -} - -fn unit(input: &str) -> IResult<&str, &str> { - recognize(many1(none_of("0123456789,+-_()*/.{} \t"))).parse(input) -} - -fn number_int(input: &str) -> IResult<&str, &str> { - recognize(many1(terminated(one_of("0123456789"), many0(one_of("_,")))))(input) -} - - -#[cfg(test)] -mod tests { - use rust_decimal_macros::dec; - - use super::*; - - #[test] - fn parse_decimal_good() { - assert_eq!(decimal("1").unwrap().1, dec!(1)); - assert_eq!(decimal("+10").unwrap().1, dec!(10)); - assert_eq!(decimal("-10").unwrap().1, dec!(-10)); - assert_eq!(decimal("10.1").unwrap().1, dec!(10.1)); - assert_eq!(decimal("100_000.01").unwrap().1, dec!(100000.01)); - assert_eq!(decimal(".1").unwrap().1, dec!(0.1)); - assert_eq!(decimal("-.1").unwrap().1, dec!(-0.1)); - assert_eq!(decimal("2.").unwrap().1, dec!(2.)); - assert_eq!(decimal("1,000").unwrap().1, dec!(1000)); - } - - - #[test] - fn amount_good() { - assert_eq!( - amount("$10").unwrap().1, - DirectiveAmount { - value: dec!(10), - unit_symbol: "$".into(), - is_unit_prefix: true - } - ); - assert_eq!( - amount("10 USD").unwrap().1, - DirectiveAmount { - value: dec!(10), - unit_symbol: "USD".into(), - is_unit_prefix: false - } - ); - assert_eq!( - amount("-$10.01").unwrap().1, - DirectiveAmount { - value: dec!(-10.01), - unit_symbol: "$".into(), - is_unit_prefix: true - } - ); - assert_eq!( - amount("-10€").unwrap().1, - DirectiveAmount { - value: dec!(-10), - unit_symbol: "€".into(), - is_unit_prefix: false - } - ); - assert_eq!( - amount("-€10").unwrap().1, - DirectiveAmount { - value: dec!(-10), - unit_symbol: "€".into(), - is_unit_prefix: true - } - ); - } - -} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 343b291..8869030 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,12 +38,14 @@ // struct Quantity {} pub mod core; -pub mod queries; +// pub mod queries; // pub mod parser; // pub mod create_ledger; pub mod document; pub mod output; -mod parser; +pub mod parser; + +pub mod query; // pub struct Account { // // TODO diff --git a/src/main.rs b/src/main.rs index 0372a4a..de46942 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,9 +7,11 @@ use std::{ use accounting_rust::{ document::Document, output::cli::{format_balance, tui_to_ansi::text_to_ansi}, - queries::{self, Query}, + parser::{self, query}, + query::{self, PostingField}, + // queries::{self, base::{self, DataValue, Query}, functions::{ComparisonFunction, LogicalFunction, SubAccountFunction}, transaction::{AccountField, PostingField, TransactionField}}, }; -use chrono::NaiveDate; +use chrono::{NaiveDate, Utc}; use ratatui::{ crossterm::{self, style::PrintStyledContent}, layout::Rect, @@ -105,11 +107,46 @@ pub fn main() -> Result<(), Box> { // &[], // ); - let balance = queries::balance2( - &ledger, - NaiveDate::from_ymd_opt(2100, 01, 01).unwrap(), - Some("$") - ); + // let balance = queries::balance2( + // &ledger, + // NaiveDate::from_ymd_opt(2100, 01, 01).unwrap(), + // Some("$") + // ); + + // let balance_query = "transaction.date < 2100-01-01"; + let balance_query = "account.name ~ 'Assets' OR account.name ~ 'Liabilities'"; + + let parsed_query = parser::query::(balance_query).unwrap(); + if parsed_query.0.trim().len() != 0 { + panic!("Full string not consumed") + } + let balance_query = parsed_query.1; + + let current_date = Utc::now().date_naive(); + + let balance = query::balance(&ledger, Some(&balance_query), Some(("$", current_date))); + + // let date_query = ComparisonFunction::new( + // "<=", + // Query::from_field(PostingField::Transaction(TransactionField::Date)), + // Query::from(DataValue::from(NaiveDate::from_ymd_opt(2100, 01, 01).unwrap())), + // ).unwrap(); + // let account_query = SubAccountFunction::new( + // "Assets".into(), + // base::Query::from_field(PostingField::Account(AccountField::Name)), + // ); + + // let total_query = LogicalFunction::new( + // "and", + // base::Query::from_fn(date_query), + // base::Query::from_fn(account_query), + // ).unwrap(); + + // let balance = queries::balance3( + // &ledger, + // &base::Query::from_fn(total_query), + // ); + let t4 = Instant::now(); diff --git a/src/parser/amount.rs b/src/parser/amount.rs new file mode 100644 index 0000000..16494e3 --- /dev/null +++ b/src/parser/amount.rs @@ -0,0 +1,100 @@ +use nom::{ + branch::alt, + character::complete::{none_of, one_of, space0}, + combinator::{opt, recognize}, + multi::many1, + sequence::{preceded, tuple}, IResult, Parser, +}; +use rust_decimal_macros::dec; + +use crate::core::RawAmount; +use super::decimal; + +pub fn amount(input: &str) -> IResult<&str, RawAmount> { + alt((suffix_amount, prefix_amount)).parse(input) +} + +/////////////// +// Private // +/////////////// + +fn prefix_amount(input: &str) -> IResult<&str, RawAmount> { + tuple((opt(one_of("+-")), unit, preceded(space0, decimal))) + .map(|(sign, unit_symbol, mut value)| { + if let Some(s) = sign { + if s == '-' { + value = value * dec!(-1); + } + } + RawAmount { + value, + unit_symbol: unit_symbol.to_string(), + is_unit_prefix: true, + } + }) + .parse(input) +} + +fn suffix_amount(input: &str) -> IResult<&str, RawAmount> { + tuple((decimal, preceded(space0, unit))) + .map(|(value, unit_symbol)| RawAmount { + value, + unit_symbol: unit_symbol.to_string(), + is_unit_prefix: false, + }) + .parse(input) +} + +fn unit(input: &str) -> IResult<&str, &str> { + recognize(many1(none_of("0123456789,+-_()*/.{} \t"))).parse(input) +} + +#[cfg(test)] +mod tests { + use super::*; + use rust_decimal_macros::dec; + + #[test] + fn amount_good() { + assert_eq!( + amount("$10").unwrap().1, + RawAmount { + value: dec!(10), + unit_symbol: "$".into(), + is_unit_prefix: true + } + ); + assert_eq!( + amount("10 USD").unwrap().1, + RawAmount { + value: dec!(10), + unit_symbol: "USD".into(), + is_unit_prefix: false + } + ); + assert_eq!( + amount("-$10.01").unwrap().1, + RawAmount { + value: dec!(-10.01), + unit_symbol: "$".into(), + is_unit_prefix: true + } + ); + assert_eq!( + amount("-10€").unwrap().1, + RawAmount { + value: dec!(-10), + unit_symbol: "€".into(), + is_unit_prefix: false + } + ); + assert_eq!( + amount("-€10").unwrap().1, + RawAmount { + value: dec!(-10), + unit_symbol: "€".into(), + is_unit_prefix: true + } + ); + } +} diff --git a/src/parser/core.rs b/src/parser/core.rs new file mode 100644 index 0000000..3929e92 --- /dev/null +++ b/src/parser/core.rs @@ -0,0 +1,143 @@ +use chrono::NaiveDate; +use nom::{ + branch::alt, + bytes::complete::{escaped, tag, take_while_m_n}, + character::complete::{char, none_of, one_of, space0}, + combinator::{opt, recognize}, + error::{Error, ErrorKind}, + multi::{many0, many1}, + sequence::{delimited, terminated, tuple}, + AsChar, Err, IResult, Parser, +}; +use rust_decimal::Decimal; + +pub fn decimal(input: &str) -> IResult<&str, Decimal> { + let (new_input, decimal_str) = recognize(tuple(( + opt(one_of("+-")), + opt(number_int), + opt(char('.')), + opt(number_int), + ))) + .parse(input)?; + + if decimal_str.contains(',') { + match Decimal::from_str_exact(&decimal_str.replace(",", "")) { + Ok(decimal) => Ok((new_input, decimal)), + Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), + } + } else { + match Decimal::from_str_exact(decimal_str) { + Ok(decimal) => Ok((new_input, decimal)), + Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), + } + } +} + +pub fn number_int(input: &str) -> IResult<&str, &str> { + recognize(many1(terminated(one_of("0123456789"), many0(one_of("_,")))))(input) +} + +pub fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { + let (new_input, (year, _, month, _, day)) = + tuple((date_year, tag("-"), date_month, tag("-"), date_day)).parse(input)?; + + match NaiveDate::from_ymd_opt(year, month, day) { + Some(date) => Ok((new_input, date)), + None => Err(nom::Err::Error(Error::new(input, ErrorKind::Eof))), + } +} + +pub fn quoted_string(input: &str) -> IResult<&str, &str> { + alt(( + delimited( + tag("\""), + escaped(none_of("\\\""), '\\', tag("\"")), + tag("\""), + ), + delimited( + tag("'"), + escaped(none_of("\\\'"), '\\', tag("\'")), + tag("'"), + ), + )) + .parse(input) +} + +pub fn ws<'a, F: 'a, O>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O> +where + F: FnMut(&'a str) -> IResult<&'a str, O>, +{ + delimited(space0, inner, space0) +} + +/////////////// +// Private // +/////////////// + +fn take_n_digits(i: &str, n: usize) -> IResult<&str, u32> { + let (i, digits) = take_while_m_n(n, n, AsChar::is_dec_digit)(i)?; + + let res = digits.parse().expect("Invalid ASCII number"); + + Ok((i, res)) +} + +fn date_year(input: &str) -> IResult<&str, i32> { + take_n_digits(input, 4).map(|(str, year)| (str, i32::try_from(year).unwrap())) +} + +fn date_month(input: &str) -> IResult<&str, u32> { + take_n_digits(input, 2) +} + +fn date_day(input: &str) -> IResult<&str, u32> { + take_n_digits(input, 2) +} + +#[cfg(test)] +mod tests { + use super::*; + use rust_decimal_macros::dec; + + #[test] + fn parse_decimal_good() { + assert_eq!(decimal("1").unwrap().1, dec!(1)); + assert_eq!(decimal("+10").unwrap().1, dec!(10)); + assert_eq!(decimal("-10").unwrap().1, dec!(-10)); + assert_eq!(decimal("10.1").unwrap().1, dec!(10.1)); + assert_eq!(decimal("100_000.01").unwrap().1, dec!(100000.01)); + assert_eq!(decimal(".1").unwrap().1, dec!(0.1)); + assert_eq!(decimal("-.1").unwrap().1, dec!(-0.1)); + assert_eq!(decimal("2.").unwrap().1, dec!(2.)); + assert_eq!(decimal("1,000").unwrap().1, dec!(1000)); + } + + #[test] + fn correct_date() { + assert_eq!( + parse_iso_date("2000-01-01"), + Ok(("", NaiveDate::from_ymd_opt(2000, 01, 01).unwrap())) + ); + // assert_eq!( + // parse_iso_date("20000101"), + // Ok(("", NaiveDate::from_ymd_opt(2000, 01, 01).unwrap())) + // ); + } + + #[test] + fn incomplete_date() { + assert_eq!(parse_iso_date("200-01-01").is_err(), true); + } + + #[test] + fn invalid_date() { + assert_eq!(parse_iso_date("2000-02-30").is_err(), true); + } + + #[test] + fn parse_string() { + assert_eq!(quoted_string("\"test\"").unwrap().1, "test"); + assert_eq!(quoted_string("'test'").unwrap().1, "test"); + assert_eq!(quoted_string("\"te\\\"st\"").unwrap().1, "te\\\"st"); + } +} diff --git a/src/document/parser/base_directive.rs b/src/parser/document/base_directive.rs similarity index 97% rename from src/document/parser/base_directive.rs rename to src/parser/document/base_directive.rs index f56fa6c..d268f68 100644 --- a/src/document/parser/base_directive.rs +++ b/src/parser/document/base_directive.rs @@ -61,9 +61,9 @@ pub fn empty_lines(input: &str) -> IResult<&str, ()> { pub fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { let (new_input, (year, _, month, _, day)) = tuple(( date_year, - opt(tag("-")), + tag("-"), date_month, - opt(tag("-")), + tag("-"), date_day, )) .parse(input)?; @@ -136,10 +136,10 @@ mod tests { parse_iso_date("2000-01-01"), Ok(("", NaiveDate::from_ymd_opt(2000, 01, 01).unwrap())) ); - assert_eq!( - parse_iso_date("20000101"), - Ok(("", NaiveDate::from_ymd_opt(2000, 01, 01).unwrap())) - ); + // assert_eq!( + // parse_iso_date("20000101"), + // Ok(("", NaiveDate::from_ymd_opt(2000, 01, 01).unwrap())) + // ); } #[test] diff --git a/src/document/parser/directives.rs b/src/parser/document/directives.rs similarity index 84% rename from src/document/parser/directives.rs rename to src/parser/document/directives.rs index 97d1410..f68839b 100644 --- a/src/document/parser/directives.rs +++ b/src/parser/document/directives.rs @@ -2,29 +2,20 @@ use std::path::PathBuf; use nom::{ bytes::complete::{is_not, tag}, - character::complete::{none_of, space1}, + character::complete::space1, combinator::{opt, rest}, - error::{Error, ErrorKind, ParseError}, + error::{Error, ErrorKind}, sequence::{preceded, terminated, tuple}, IResult, Parser, }; -use crate::{core::UnitSymbol, document::{ - BalanceDirective, CommodityDirective, IncludeDirective, TransactionDirective, -}}; - -use super::{ - amounts::{account, amount}, - base_directive::BaseDirective, - transaction::transaction, +use crate::{ + core::UnitSymbol, + document::{BalanceDirective, CommodityDirective, IncludeDirective, TransactionDirective}, + parser::amount, }; -// use super::{ -// base::ParsedBaseDirective, -// shared::{parse_account, parse_amount}, -// transaction::parse_transaction, -// types::{ParseError, ParsedBalanceDirective, ParsedDirectives, ParsedIncludeDirective}, -// }; +use super::{base_directive::BaseDirective, shared::account, transaction::transaction}; ////////////// // Public // @@ -115,13 +106,15 @@ fn commodity_directive(directive: BaseDirective) -> IResult symbols.push(UnitSymbol {symbol: value.into(), is_prefix: true}), - "symbol" => symbols.push(UnitSymbol {symbol: value.into(), is_prefix: false}), + "symbol_prefix" => symbols.push(UnitSymbol { symbol: value.into(), is_prefix: true }), + "symbol" => symbols.push(UnitSymbol { symbol: value.into(), is_prefix: false }), "precision" => precision = Some(value.trim().parse::().unwrap()), // TODO: unwrap - _ => return Err(nom::Err::Failure(Error { - input: directive, - code: ErrorKind::Fail, - })), + _ => { + return Err(nom::Err::Failure(Error { + input: directive, + code: ErrorKind::Fail, + })) + } } } diff --git a/src/document/parser/mod.rs b/src/parser/document/mod.rs similarity index 94% rename from src/document/parser/mod.rs rename to src/parser/document/mod.rs index d16def4..f7df86c 100644 --- a/src/document/parser/mod.rs +++ b/src/parser/document/mod.rs @@ -1,4 +1,3 @@ -mod amounts; mod base_directive; mod directives; mod transaction; @@ -14,9 +13,7 @@ use nom::{ Parser, }; -use crate::core::CoreError; - -use super::Directives; +use crate::{core::CoreError, document::Directives}; pub fn parse_directives(input: &str) -> Result { let parsed_directives = terminated( diff --git a/src/document/parser/shared.rs b/src/parser/document/shared.rs similarity index 88% rename from src/document/parser/shared.rs rename to src/parser/document/shared.rs index 82252c3..207cdca 100644 --- a/src/document/parser/shared.rs +++ b/src/parser/document/shared.rs @@ -6,6 +6,10 @@ use nom::{ IResult, InputTakeAtPosition, Parser, }; +pub fn account(input: &str) -> IResult<&str, &str> { + input.split_at_position1_complete(|item| item == ' ' || item == '\t', ErrorKind::AlphaNumeric) +} + pub fn metadatum(input: &str) -> IResult<&str, (&str, &str)> { tuple(( delimited(tag("-"), delimited(space0, key, space0), tag(":")), diff --git a/src/document/parser/transaction.rs b/src/parser/document/transaction.rs similarity index 87% rename from src/document/parser/transaction.rs rename to src/parser/document/transaction.rs index e1017c0..b4e228c 100644 --- a/src/document/parser/transaction.rs +++ b/src/parser/document/transaction.rs @@ -1,25 +1,24 @@ use nom::{ branch::alt, bytes::complete::{is_not, tag}, - character::complete::{space0, space1}, + character::complete::space1, combinator::{eof, opt, rest}, error::{Error, ErrorKind}, sequence::{delimited, preceded, terminated, tuple}, - Err, IResult, InputTakeAtPosition, Parser, + Err, IResult, Parser, }; -use crate::document::{DirectiveAmount, DirectivePosting, TransactionDirective}; +use crate::{ + core::RawAmount, + document::{DirectivePosting, TransactionDirective}, + parser::amount, +}; use super::{ - amounts::{account, amount}, base_directive::{parse_iso_date, BaseDirective}, - shared::metadatum, + shared::{account, metadatum}, }; -// use super::{ -// base::ParsedBaseDirective, directives::BaseDirective, shared::{parse_account, amount}, types::{ParseError, DirectiveAmount, DirectivePosting, ParsedTransactionDirective} -// }; - ////////////// // Public // ////////////// @@ -56,7 +55,7 @@ pub fn transaction<'a>( let mut metadata = Vec::new(); for &line in directive.lines.get(1..).unwrap_or(&[]) { if let Ok(m) = metadatum(line) { - metadata.push((m.1.0.to_string(), m.1.1.to_string())); + metadata.push((m.1 .0.to_string(), m.1 .1.to_string())); continue; } let posting = if let Ok(v) = posting(line) { @@ -118,7 +117,7 @@ fn posting(input: &str) -> IResult<&str, DirectivePosting> { amount = Some(v.0); if let Some(c) = v.1 { if c.1 { - cost = Some(DirectiveAmount { + cost = Some(RawAmount { value: c.0.value / amount.as_ref().unwrap().value.abs(), unit_symbol: c.0.unit_symbol, is_unit_prefix: c.0.is_unit_prefix, @@ -129,7 +128,7 @@ fn posting(input: &str) -> IResult<&str, DirectivePosting> { } if let Some(p) = v.2 { if p.1 { - price = Some(DirectiveAmount { + price = Some(RawAmount { value: p.0.value / amount.as_ref().unwrap().value.abs(), unit_symbol: p.0.unit_symbol, is_unit_prefix: p.0.is_unit_prefix, @@ -150,7 +149,7 @@ fn posting(input: &str) -> IResult<&str, DirectivePosting> { .parse(input) } -fn parse_cost(input: &str) -> IResult<&str, (DirectiveAmount, bool)> { +fn parse_cost(input: &str) -> IResult<&str, (RawAmount, bool)> { alt(( delimited(tag("{"), amount, tag("}")).map(|amount| (amount, false)), delimited(tag("{{"), amount, tag("}}")).map(|amount| (amount, true)), @@ -158,7 +157,7 @@ fn parse_cost(input: &str) -> IResult<&str, (DirectiveAmount, bool)> { .parse(input) } -fn parse_price(input: &str) -> IResult<&str, (DirectiveAmount, bool)> { +fn parse_price(input: &str) -> IResult<&str, (RawAmount, bool)> { alt(( preceded(tuple((tag("@"), space1)), amount).map(|amount| (amount, false)), preceded(tuple((tag("@@"), space1)), amount).map(|amount| (amount, true)), @@ -210,12 +209,12 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "$".into(), is_unit_prefix: true @@ -228,12 +227,12 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "USD".into(), is_unit_prefix: false @@ -250,13 +249,13 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), cost: None, - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(100), unit_symbol: "$".into(), is_unit_prefix: true @@ -268,13 +267,13 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), cost: None, - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(100), unit_symbol: "USD".into(), is_unit_prefix: false @@ -290,17 +289,17 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "$".into(), is_unit_prefix: true }), - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(110), unit_symbol: "$".into(), is_unit_prefix: true @@ -314,17 +313,17 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "USD".into(), is_unit_prefix: false }), - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(110), unit_symbol: "USD".into(), is_unit_prefix: false @@ -338,17 +337,17 @@ mod tests { DirectivePosting { date: None, account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "USD".into(), is_unit_prefix: false }), - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(110), unit_symbol: "USD".into(), is_unit_prefix: false @@ -366,17 +365,17 @@ mod tests { DirectivePosting { date: Some(NaiveDate::from_ymd_opt(2000, 01, 01).unwrap()), account: "Account1".into(), - amount: Some(DirectiveAmount { + amount: Some(RawAmount { value: dec!(10), unit_symbol: "SHARE".into(), is_unit_prefix: false }), - cost: Some(DirectiveAmount { + cost: Some(RawAmount { value: dec!(100), unit_symbol: "$".into(), is_unit_prefix: true }), - price: Some(DirectiveAmount { + price: Some(RawAmount { value: dec!(110), unit_symbol: "$".into(), is_unit_prefix: true @@ -402,7 +401,7 @@ mod tests { assert_eq!(transaction.postings[0].account, "Account1:Account2"); assert_eq!( transaction.postings[0].amount, - Some(DirectiveAmount { + Some(RawAmount { value: dec!(10.01), unit_symbol: "$".into(), is_unit_prefix: true diff --git a/src/parser/fields.rs b/src/parser/fields.rs new file mode 100644 index 0000000..0266800 --- /dev/null +++ b/src/parser/fields.rs @@ -0,0 +1,48 @@ +use std::fmt::Debug; + +use crate::query::{AccountField, PostingField, TransactionField}; + +pub trait ParseField: Debug + Sized + Clone { + fn parse(input: &str) -> Option; +} + +impl ParseField for TransactionField { + fn parse(input: &str) -> Option { + match input.to_lowercase().as_str() { + "date" => Some(TransactionField::Date), + "flag" => Some(TransactionField::Flag), + "payee" => Some(TransactionField::Payee), + "narration" => Some(TransactionField::Narration), + _ => None, + } + } +} + +impl ParseField for AccountField { + fn parse(input: &str) -> Option { + match input.to_lowercase().as_str() { + "name" => Some(AccountField::Name), + "open" => Some(AccountField::OpenDate), + "close" => Some(AccountField::CloseDate), + _ => None, + } + } +} + +impl ParseField for PostingField { + fn parse(input: &str) -> Option { + match input + .to_lowercase() + .split('.') + .collect::>() + .as_slice() + { + ["amount"] => Some(PostingField::Amount), + ["cost"] => Some(PostingField::Cost), + ["price"] => Some(PostingField::Price), + ["transaction", t] => TransactionField::parse(t).map(|v| PostingField::Transaction(v)), + ["account", t] => AccountField::parse(t).map(|v| PostingField::Account(v)), + _ => None, + } + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index beb28e7..87dbeb6 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,99 +1,14 @@ -use chrono::NaiveDate; -use nom::{ - bytes::complete::{escaped, tag, take_while_m_n}, - character::complete::{char, none_of, one_of}, - combinator::{opt, recognize}, - error::{Error, ErrorKind}, - multi::{many0, many1}, - sequence::{delimited, terminated, tuple}, - AsChar, Err, IResult, Parser, -}; -use rust_decimal::Decimal; -pub fn decimal(input: &str) -> IResult<&str, Decimal> { - let (new_input, decimal_str) = recognize(tuple(( - opt(one_of("+-")), - opt(number_int), - opt(char('.')), - opt(number_int), - ))) - .parse(input)?; +mod core; +mod amount; +mod value; +mod query; +mod document; +mod fields; - if decimal_str.contains(',') { - match Decimal::from_str_exact(&decimal_str.replace(",", "")) { - Ok(decimal) => Ok((new_input, decimal)), - Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), - } - } else { - match Decimal::from_str_exact(decimal_str) { - Ok(decimal) => Ok((new_input, decimal)), - Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), - } - } -} - -pub fn number_int(input: &str) -> IResult<&str, &str> { - recognize(many1(terminated(one_of("0123456789"), many0(one_of("_,")))))(input) -} - -pub fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { - let (new_input, (year, _, month, _, day)) = tuple(( - date_year, - opt(tag("-")), - date_month, - opt(tag("-")), - date_day, - )) - .parse(input)?; - - match NaiveDate::from_ymd_opt(year, month, day) { - Some(date) => Ok((new_input, date)), - None => Err(nom::Err::Error(Error::new(input, ErrorKind::Eof))), - } -} - -pub fn quoted_string(input: &str) -> IResult<&str, &str> { - delimited( - tag("\""), - escaped(none_of("\\\""), '\\', tag("\"")), - tag("\""), - ) - .parse(input) -} - -/////////////// -// Private // -/////////////// - -fn take_n_digits(i: &str, n: usize) -> IResult<&str, u32> { - let (i, digits) = take_while_m_n(n, n, AsChar::is_dec_digit)(i)?; - - let res = digits.parse().expect("Invalid ASCII number"); - - Ok((i, res)) -} - -fn date_year(input: &str) -> IResult<&str, i32> { - take_n_digits(input, 4).map(|(str, year)| (str, i32::try_from(year).unwrap())) -} - -fn date_month(input: &str) -> IResult<&str, u32> { - take_n_digits(input, 2) -} - -fn date_day(input: &str) -> IResult<&str, u32> { - take_n_digits(input, 2) -} - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_string() { - assert_eq!(quoted_string("\"test\"").unwrap().1, "test"); - assert_eq!(quoted_string("\"te\\\"st\"").unwrap().1, "te\\\"st"); - } - -} \ No newline at end of file +pub use core::*; +pub use amount::*; +pub use value::*; +pub use query::*; +pub use document::*; +pub use fields::*; \ No newline at end of file diff --git a/src/parser/query.rs b/src/parser/query.rs new file mode 100644 index 0000000..b10ee73 --- /dev/null +++ b/src/parser/query.rs @@ -0,0 +1,246 @@ +use nom::{ + branch::alt, + bytes::complete::{tag, tag_no_case}, + error::{Error, ErrorKind}, + multi::fold_many0, + sequence::{delimited, preceded, tuple}, + AsChar, Err, IResult, InputTakeAtPosition, Parser, +}; + +use super::{quoted_string, value, ws, ParseField}; +use crate::query::{ + ComparisonFunction, ComparisonOperator, LogicalFunction, LogicalOperator, NotFunction, Query, + RegexFunction, +}; + +pub fn query<'a, Field: ParseField>(input: &'a str) -> IResult<&'a str, Query> { + term_or.parse(input) +} + +///////////// +// Private // +///////////// + +fn expression<'a, Field: ParseField>(input: &'a str) -> IResult<&'a str, Query> { + term_or.parse(input) +} + +fn term_or<'a, Field: ParseField>(input: &'a str) -> IResult<&'a str, Query> { + let (input, init) = term_and(input)?; + + fold_many0( + tuple((tag_no_case("OR"), term_and)), + move || init.clone(), + |acc, (_, val)| Query::from_fn(LogicalFunction::new_op(LogicalOperator::OR, acc, val)), + ) + .parse(input) +} + +fn term_and<'a, Field: ParseField>(input: &'a str) -> IResult<&'a str, Query> { + let (input, init) = term(input)?; + + fold_many0( + tuple((tag_no_case("AND"), term)), + move || init.clone(), + |acc, (_, val)| Query::from_fn(LogicalFunction::new_op(LogicalOperator::AND, acc, val)), + ) + .parse(input) +} + +fn term<'a, Field: ParseField + 'static>(input: &'a str) -> IResult<&'a str, Query> { + alt((function_regex, function_comparison, factor)).parse(input) +} + +fn function_regex<'a, Field: ParseField>(input: &'a str) -> IResult<&'a str, Query> { + let (new_input, result) = tuple((factor, tag("~"), ws(quoted_string))) + .map(|(left, _, right)| RegexFunction::new(left, right, true)) // TODO: case sensitive? + .parse(input)?; + match result { + Ok(regex_function) => Ok((new_input, Query::from_fn(regex_function))), + Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), + } +} + +fn function_comparison<'a, Field: ParseField + 'static>( + input: &'a str, +) -> IResult<&'a str, Query> { + let op = alt(( + tag("=").map(|_| ComparisonOperator::EQ), + tag("!=").map(|_| ComparisonOperator::NEQ), + tag(">").map(|_| ComparisonOperator::GT), + tag("<").map(|_| ComparisonOperator::LT), + tag(">=").map(|_| ComparisonOperator::GTE), + tag("<=").map(|_| ComparisonOperator::LTE), + )); + + tuple((factor, op, factor)) + .map(|(left, op, right)| Query::from_fn(ComparisonFunction::new_op(op, left, right))) + .parse(input) +} + +fn factor<'a, Field: ParseField + 'static>(input: &'a str) -> IResult<&'a str, Query> { + ws(alt(( + function_unary, + value.map(|v| Query::Value(v)), + field.map(|f| Query::Field(f)), + parenthesis, + ))) + .parse(input) +} + +fn function_unary<'a, Field: ParseField + 'static>( + input: &'a str, +) -> IResult<&'a str, Query> { + preceded(ws(tag("!")), factor) + .map(|f| Query::from_fn(NotFunction::new(f))) + .parse(input) +} + +fn field<'a, Field: ParseField>(input: &str) -> IResult<&str, Field> { + input + .split_at_position1_complete( + |item| !item.is_alphanum() && item != '.', + ErrorKind::AlphaNumeric, + ) + .and_then(|v| { + Field::parse(v.1) + .map(|f| (v.0, f)) + .ok_or(nom::Err::Error(Error::new(input, ErrorKind::Eof))) + }) +} + +fn parenthesis<'a, Field: ParseField + 'static>(input: &'a str) -> IResult<&'a str, Query> { + delimited(ws(tag("(")), expression, ws(tag(")"))).parse(input) +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use chrono::NaiveDate; + use rust_decimal::Decimal; + + use crate::{ + core::{CoreError, DataValue}, + query::Data, + }; + + use super::*; + + #[derive(Clone, Debug, PartialEq)] + pub enum Field1 { + A, + AB, + } + + impl ParseField for Field1 { + fn parse(input: &str) -> Option { + match input { + "A" => Some(Field1::A), + "A.B" => Some(Field1::AB), + _ => None, + } + } + } + + pub struct FieldData {} + + impl Data for FieldData { + fn get_field(&self, field: &Field1) -> Result { + if *field == Field1::A { + return Ok(Decimal::new(30, 0).into()); + } + return Ok(DataValue::Null); + } + } + + #[test] + fn parse_field() { + assert_eq!(field::("A").unwrap().1, Field1::A); + assert_eq!(field::("A.B").unwrap().1, Field1::AB); + assert!(field::("C").is_err()); + } + + fn evaluate_query(input: &str) -> Result { + let parsed = query::(input).unwrap(); + if parsed.0.trim().len() != 0 { + panic!("Full string not consumed") + } + + print!("{:?}", parsed.1); + let context = FieldData {}; + parsed.1.evaluate(&context) + } + + #[test] + fn query_value() { + assert_eq!(evaluate_query("10").unwrap(), Decimal::new(10, 0).into()); + assert_eq!(evaluate_query("1.2").unwrap(), Decimal::new(12, 1).into()); + assert_eq!(evaluate_query("null").unwrap(), DataValue::Null); + assert_eq!(evaluate_query("true").unwrap(), true.into()); + assert_eq!(evaluate_query("\"abc\"").unwrap(), "abc".into()); + assert_eq!( + evaluate_query("2000-01-01").unwrap(), + NaiveDate::from_ymd_opt(2000, 01, 01).unwrap().into() + ); + } + + #[test] + fn query_field() { + assert_eq!(evaluate_query("A").unwrap(), Decimal::new(30, 0).into()); + assert_eq!(evaluate_query("A.B").unwrap(), DataValue::Null); + } + + #[test] + fn query_comparison() { + assert_eq!(evaluate_query("10 > 4").unwrap(), true.into()); + assert_eq!(evaluate_query("10 < 4").unwrap(), false.into()); + assert_eq!(evaluate_query("\"ab\" < \"cd\"").unwrap(), true.into()); + assert_eq!( + evaluate_query("2000-02-01 > 2000-01-01").unwrap(), + true.into() + ); + } + + #[test] + fn query_regex() { + assert_eq!(evaluate_query("'abc' ~ 'abc'").unwrap(), true.into()); + assert_eq!(evaluate_query("'abcd' ~ 'abc'").unwrap(), true.into()); + assert_eq!(evaluate_query("'abcd' ~ '^abc$'").unwrap(), false.into()); + assert_eq!( + evaluate_query("'Account:Sub' ~ 'Account'").unwrap(), + true.into() + ); + assert_eq!( + evaluate_query("'Account:Sub' ~ 'Account:Sub'").unwrap(), + true.into() + ); + } + + #[test] + fn query_logical() { + assert_eq!(evaluate_query("true anD false").unwrap(), false.into()); + assert_eq!(evaluate_query("true Or false").unwrap(), true.into()); + assert_eq!(evaluate_query(" ! true").unwrap(), false.into()); + assert_eq!(evaluate_query("!!true").unwrap(), true.into()); + assert_eq!( + evaluate_query("true and true and false").unwrap(), + false.into() + ); + assert_eq!( + evaluate_query("true and false or true and true").unwrap(), + true.into() + ); + } + + #[test] + fn query_combined() { + assert_eq!(evaluate_query("10 > 4 and 4 > 10").unwrap(), false.into()); + assert_eq!( + evaluate_query("true and (false or true)").unwrap(), + true.into() + ); + } +} diff --git a/src/parser/value.rs b/src/parser/value.rs new file mode 100644 index 0000000..b41f761 --- /dev/null +++ b/src/parser/value.rs @@ -0,0 +1,47 @@ +use nom::{ + branch::alt, + bytes::complete::tag_no_case, + IResult, Parser, +}; +use super::{decimal, parse_iso_date, quoted_string}; +use crate::core::DataValue; + +pub fn value<'a>(input: &'a str) -> IResult<&'a str, DataValue> { + alt(( + tag_no_case("null").map(|_| DataValue::Null), + tag_no_case("true").map(|_| DataValue::Boolean(true)), + tag_no_case("false").map(|_| DataValue::Boolean(false)), + parse_iso_date.map(|v| DataValue::Date(v)), + decimal.map(|v| DataValue::Decimal(v)), + quoted_string.map(|v| v.into()), + // TODO: list, map + )) + .parse(input) +} + +#[cfg(test)] +mod tests { + use chrono::NaiveDate; + + use super::*; + + #[test] + fn parse_value() { + assert_eq!(value("nuLl").unwrap().1, DataValue::Null); + assert_eq!(value("TruE").unwrap().1, DataValue::Boolean(true)); + assert_eq!(value("falSe").unwrap().1, DataValue::Boolean(false)); + assert_eq!( + value("2000-01-01").unwrap().1, + DataValue::Date(NaiveDate::from_ymd_opt(2000, 01, 01).unwrap()) + ); + assert_eq!(value("10").unwrap().1, DataValue::Decimal(10.into())); + assert_eq!( + value("20000101").unwrap().1, + DataValue::Decimal(20000101.into()) + ); + assert_eq!( + value("\"abc\"").unwrap().1, + DataValue::String("abc".into()) + ); + } +} diff --git a/src/queries/balance.rs b/src/queries/balance.rs deleted file mode 100644 index 9e544bd..0000000 --- a/src/queries/balance.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::{collections::HashMap, time::Instant}; - -use crate::core::{Amount, Ledger}; -use chrono::NaiveDate; -use rust_decimal_macros::dec; - -use super::{ - base::{self, DataValue, Function}, - functions::ComparisonFunction, - transaction::{PostingData, PostingField, TransactionField}, -}; - -pub enum Query { - StartDate(NaiveDate), - EndDate(NaiveDate), -} - -pub fn balance2( - ledger: &Ledger, - end_date: NaiveDate, - convert_to_unit: Option<&str>, -) -> HashMap> { - let q = ComparisonFunction::new( - "<=", - base::Query::from_field(PostingField::Transaction(TransactionField::Date)), - base::Query::from(DataValue::from(end_date)), - ) - .unwrap(); - - let convert_to_unit = convert_to_unit.map(|u| ledger.get_unit_by_symbol(u).unwrap()); - - let postings = ledger - .get_transactions() - .iter() - .map(|t| { - t.get_postings().iter().map(|p| PostingData { - ledger, - posting: p, - parent_transaction: t, - }) - }) - .flatten(); - - let filtered_postings = - postings.filter(|data| q.evaluate(data).map(|v| bool::from(v)).unwrap_or(false)); - - let mut accounts = HashMap::new(); - for posting_data in filtered_postings { - let posting = posting_data.posting; - let mut amount = *posting.get_amount(); - if let Some(new_unit) = convert_to_unit { - if amount.unit_id != new_unit.get_id() { - let price = ledger.get_price_on_date(end_date, amount.unit_id, new_unit.get_id()); - if let Some(price) = price { - amount = Amount { - value: amount.value * price, - unit_id: new_unit.get_id(), - }; - } - } - } - let account_vals = accounts - .entry(posting.get_account_id()) - .or_insert(HashMap::new()); - let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); - *a += amount.value; - } - - accounts - .iter() - .map(|(&k, v)| { - ( - k, - v.into_iter() - .map(|(&unit_id, &value)| Amount { value, unit_id }) - .collect(), - ) - }) - .collect() -} - -pub fn balance3(ledger: &Ledger, query: &base::Query) -> HashMap> { - // let q = ComparisonFunction::new( - // "<=", - // base::Query::from_field(PostingField::Transaction(TransactionField::Date)), - // base::Query::from(DataValue::from(end_date)), - // ) - // .unwrap(); - - let t0 = Instant::now(); - - let postings = ledger - .get_transactions() - .iter() - .map(|t| { - t.get_postings().iter().map(|p| PostingData { - ledger, - posting: p, - parent_transaction: t, - }) - }) - .flatten(); - - let t1 = Instant::now(); - - let filtered_postings = - postings.filter(|data| query.evaluate(data).map(|v| bool::from(v)).unwrap_or(false)); - - let t2 = Instant::now(); - - // println!("{:?} {:?}", t1-t0, t2-t1); - - let mut accounts = HashMap::new(); - for posting_data in filtered_postings { - let posting = posting_data.posting; - let amount = posting.get_amount(); - let account_vals = accounts - .entry(posting.get_account_id()) - .or_insert(HashMap::new()); - let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); - *a += amount.value; - } - - accounts - .iter() - .map(|(&k, v)| { - ( - k, - v.into_iter() - .map(|(&unit_id, &value)| Amount { value, unit_id }) - .collect(), - ) - }) - .collect() -} - -pub fn balance(ledger: &Ledger, query: &[Query]) -> HashMap> { - let relevant_transactions = ledger.get_transactions().iter().filter(|txn| { - query.iter().all(|q| match q { - Query::StartDate(date) => txn.get_date() >= *date, - Query::EndDate(date) => txn.get_date() <= *date, - }) - }); - - let mut accounts = HashMap::new(); - - for txn in relevant_transactions.clone() { - for posting in txn.get_postings() { - let amount = posting.get_amount(); - let account_vals = accounts - .entry(posting.get_account_id()) - .or_insert(HashMap::new()); - let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); - *a += amount.value; - } - } - - accounts - .iter() - .map(|(&k, v)| { - ( - k, - v.into_iter() - .map(|(&unit_id, &value)| Amount { value, unit_id }) - .collect(), - ) - }) - .collect() -} diff --git a/src/queries/functions.rs b/src/queries/functions.rs deleted file mode 100644 index 138d1ba..0000000 --- a/src/queries/functions.rs +++ /dev/null @@ -1,294 +0,0 @@ -use std::time::{Duration, Instant}; - -use crate::core::CoreError; -use regex::Regex; - -use super::base::{Data, DataValue, Function, Query, StringData}; - -#[derive(PartialEq, Debug)] -pub enum ComparisonOperator { - EQ, - NEQ, - GT, - LT, - GTE, - LTE, -} - -pub struct ComparisonFunction<'a, Field> { - op: ComparisonOperator, - left: Query<'a, Field>, - right: Query<'a, Field>, -} - -pub enum StringComparisonOperator<'a> { - Func(&'a (dyn Fn(&str) -> bool + 'a)), - Regex(Regex), -} - -pub struct StringComparisonFunction<'a, Field> { - op: StringComparisonOperator<'a>, - val: Query<'a, Field>, -} - -pub struct SubAccountFunction<'a, Field> { - account_name: StringData<'a>, - val: Query<'a, Field>, -} - -pub struct RegexFunction<'a, Field> { - left: Query<'a, Field>, - regex: Regex, -} - -#[derive(PartialEq, Debug)] -pub enum LogicalOperator { - AND, - OR, -} - -pub struct LogicalFunction<'a, Field> { - op: LogicalOperator, - left: Query<'a, Field>, - right: Query<'a, Field>, -} - -pub struct NotFunction<'a, Field> { - value: Query<'a, Field>, -} - -impl<'a, Field> ComparisonFunction<'a, Field> { - pub fn new( - op: &str, - left: Query<'a, Field>, - right: Query<'a, Field>, - ) -> Result { - let op = match op { - "==" => ComparisonOperator::EQ, - "!=" => ComparisonOperator::NEQ, - ">" => ComparisonOperator::GT, - "<" => ComparisonOperator::LT, - ">=" => ComparisonOperator::GTE, - "<=" => ComparisonOperator::LTE, - _ => return Err("Invalid Operator".into()), - }; - Ok(ComparisonFunction { op, left, right }) - } - - pub fn new_op( - op: ComparisonOperator, - left: Query<'a, Field>, - right: Query<'a, Field>, - ) -> Self { - ComparisonFunction { op, left, right } - } -} - -impl<'a, Field> SubAccountFunction<'a, Field> { - pub fn new(account: StringData<'a>, val: Query<'a, Field>) -> Self { - SubAccountFunction { account_name: account, val } - } -} - -impl<'a, Field> StringComparisonFunction<'a, Field> { - pub fn new_func( - val: Query<'a, Field>, - func: &'a (impl Fn(&str) -> bool + 'a), - ) -> Result { - Ok(StringComparisonFunction { val, op: StringComparisonOperator::Func(func) }) - } - - pub fn new_regex(val: Query<'a, Field>, regex: &str) -> Result { - let regex = Regex::new(regex).map_err(|_| CoreError::from("Unable to parse regex"))?; - Ok(StringComparisonFunction { val, op: StringComparisonOperator::Regex(regex) }) - } -} - -impl<'a, Field> RegexFunction<'a, Field> { - pub fn new(left: Query<'a, Field>, regex: &str) -> Result { - let regex = Regex::new(regex).map_err(|_| CoreError::from("Unable to parse regex"))?; - Ok(RegexFunction { left, regex }) - } -} - -impl<'a, Field> Function for ComparisonFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - let left = self.left.evaluate(context)?; - let right = self.right.evaluate(context)?; - - match self.op { - ComparisonOperator::EQ => Ok(DataValue::Boolean(left == right)), - ComparisonOperator::NEQ => Ok(DataValue::Boolean(left != right)), - ComparisonOperator::GT => Ok(DataValue::Boolean(left > right)), - ComparisonOperator::LT => Ok(DataValue::Boolean(left < right)), - ComparisonOperator::GTE => Ok(DataValue::Boolean(left >= right)), - ComparisonOperator::LTE => Ok(DataValue::Boolean(left <= right)), - } - } -} - -impl<'a, Field> Function for StringComparisonFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - let val = self.val.evaluate(context)?; - - if let DataValue::String(val) = val { - match &self.op { - StringComparisonOperator::Func(func) => Ok(DataValue::Boolean(func(val.as_ref()))), - StringComparisonOperator::Regex(regex) => { - Ok(DataValue::Boolean(regex.is_match(val.as_ref()))) - } - } - // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) - } else { - Err("Cannot use REGEX operation on non string types".into()) - } - - // let left = self.left.evaluate(context)?; - - // if let DataValue::String(left) = left { - // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) - // } else { - // Err("Cannot use REGEX operation on non string types".into()) - // } - } -} - -impl<'a, Field> Function for SubAccountFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - let val = self.val.evaluate(context)?; - - if let DataValue::String(val) = val { - Ok(DataValue::Boolean( - val.as_ref() - .strip_prefix(self.account_name.as_ref()) - .map(|n| n.is_empty() || n.starts_with(":")) - .unwrap_or(false), - )) - } else { - Err("Cannot compare account name on non string types".into()) - } - } -} - -impl<'a, Field> Function for RegexFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - let left = self.left.evaluate(context)?; - - if let DataValue::String(left) = left { - Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) - } else { - Err("Cannot use REGEX operation on non string types".into()) - } - } -} - -impl<'a, Field> LogicalFunction<'a, Field> { - pub fn new( - op: &str, - left: Query<'a, Field>, - right: Query<'a, Field>, - ) -> Result { - if op.eq_ignore_ascii_case("and") { - Ok(LogicalFunction { op: LogicalOperator::AND, left, right }) - } else if op.eq_ignore_ascii_case("or") { - Ok(LogicalFunction { op: LogicalOperator::OR, left, right }) - } else { - Err("Invalid logical operator".into()) - } - } -} - -impl<'a, Field> NotFunction<'a, Field> { - pub fn new(value: Query<'a, Field>) -> Self { - NotFunction { value } - } -} - -impl<'a, Field> Function for LogicalFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - // More verbose to try and avoid doing right side computation - let value: bool = match self.op { - LogicalOperator::AND => { - if !bool::from(self.left.evaluate(context)?) { - false - } else { - self.right.evaluate(context)?.into() - } - } - LogicalOperator::OR => { - if bool::from(self.left.evaluate(context)?) { - true - } else { - self.right.evaluate(context)?.into() - } - } - }; - - Ok(DataValue::Boolean(value)) - - // let left = self.left.evaluate(context)?.into(); - // let right = self.right.evaluate(context)?.into(); - - // match self.op { - // LogicalOperator::AND => Ok(DataValue::Boolean(left && right)), - // LogicalOperator::OR => Ok(DataValue::Boolean(left || right)), - // } - } -} - -impl<'a, Field> Function for NotFunction<'a, Field> { - fn evaluate(&self, context: &dyn Data) -> Result { - let value: bool = self.value.evaluate(context)?.into(); - - Ok(DataValue::Boolean(!value)) - } -} - -// Tests section -#[cfg(test)] -mod tests { - use super::*; - - enum TestField {} - - struct TestData {} - - impl Data for TestData { - fn get_field(&self, _: &TestField) -> Result { - Err("".into()) - } - } - - #[test] - fn comparison_function_evaluate() { - let value1: DataValue = 5.into(); - let value2: DataValue = 10.into(); - let context = TestData {}; - - let comparison_function = - ComparisonFunction::::new(">", value1.into(), value2.into()).unwrap(); - assert_eq!(comparison_function.op, ComparisonOperator::GT); - assert_eq!(comparison_function.evaluate(&context), Ok(false.into())); - - let value = Query::from_fn(comparison_function); - assert_eq!(value.evaluate(&context), Ok(false.into())); - } - - #[test] - fn logical_function_evaluate() { - let value1: DataValue = 1.into(); - let value2: DataValue = false.into(); - let context = TestData {}; - - let logical_function_and = - LogicalFunction::::new("and", value1.clone().into(), value2.clone().into()) - .unwrap(); - assert_eq!(logical_function_and.op, LogicalOperator::AND); - assert_eq!(logical_function_and.evaluate(&context), Ok(false.into())); - - let logical_function_or = - LogicalFunction::::new("or", value1.into(), value2.into()).unwrap(); - assert_eq!(logical_function_or.op, LogicalOperator::OR); - assert_eq!(logical_function_or.evaluate(&context), Ok(true.into())); - } -} diff --git a/src/queries/mod.rs b/src/queries/mod.rs deleted file mode 100644 index 42069e5..0000000 --- a/src/queries/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod balance; -mod postings; -pub mod base; -pub mod functions; -pub mod parser; -pub mod transaction; - -pub use balance::*; -pub use postings::*; diff --git a/src/queries/parser/functions.rs b/src/queries/parser/functions.rs deleted file mode 100644 index 51bf1e0..0000000 --- a/src/queries/parser/functions.rs +++ /dev/null @@ -1,142 +0,0 @@ -use nom::{ - branch::alt, bytes::complete::{tag, tag_no_case, take_until}, character::complete::space0, error::{Error, ErrorKind}, multi::fold_many0, sequence::{delimited, preceded, tuple}, AsChar, IResult, InputTakeAtPosition, Parser -}; - -use crate::{ - parser::{decimal, parse_iso_date, quoted_string}, - queries::{ - base::{DataValue, Query}, - functions::{ComparisonFunction, ComparisonOperator, LogicalFunction}, - }, -}; - -pub trait ParseField: Sized { - fn parse(input: &str) -> Option; -} - -// fn query<'a, Field: ParseField + 'static>( -// input: &'a str, -// ) -> IResult<&'a str, Query<'a, Field>> { -// // TODO: uncommented, is this right? -// delimited( -// space0, -// alt(( -// parenthesis, -// value.map(|v| v.into()), -// field.map(|v| Query::from_field(v)), -// comparison_function::.map(|v| Query::from_fn(v)), -// logical_function::.map(|v| Query::from_fn(v)), -// )), -// space0, -// ) -// .parse(input) -// } - -fn value<'a>(input: &'a str) -> IResult<&'a str, DataValue<'a>> { - alt(( - tag_no_case("null").map(|_| DataValue::Null), - tag_no_case("true").map(|_| DataValue::Boolean(true)), - tag_no_case("false").map(|_| DataValue::Boolean(false)), - parse_iso_date.map(|v| v.into()), - decimal.map(|v| v.into()), - quoted_string.map(|v| v.into()), - )) - .parse(input) -} - -fn field<'a, Field: ParseField>(input: &str) -> IResult<&str, Field> { - input - .split_at_position1_complete( - |item| !item.is_alphanum() || item != '.', - ErrorKind::AlphaNumeric, - ) - .and_then(|v| { - Field::parse(v.1) - .map(|f| (v.0, f)) - .ok_or(nom::Err::Error(Error::new(input, ErrorKind::Eof))) - }) -} - -// fn parenthesis<'a, Field: ParseField + 'static>( -// input: &'a str, -// ) -> IResult<&'a str, Query<'a, Field>> { -// delimited(tag("("), query, tag(")")).parse(input) -// } - -// fn query_factor<'a, Field: ParseField + 'static>(input: &'a str) -> IResult<&'a str, Query<'a, Field>> { -// delimited( -// space0, -// alt(( -// value.map(|v| v.into()), -// field.map(|v| Query::from_field(v)), -// parenthesis, -// )), -// space0, -// ) -// .parse(input) -// } - -// fn query_comparison<'a, Field: ParseField + 'static>( -// input: &'a str, -// ) -> IResult<&'a str, Query<'a, Field>> { -// let (input, initial) = query_factor(input)?; - -// let op = alt(( -// tag("=").map(|_| ComparisonOperator::EQ), -// tag("!=").map(|_| ComparisonOperator::NEQ), -// tag(">").map(|_| ComparisonOperator::GT), -// tag("<").map(|_| ComparisonOperator::LT), -// tag(">=").map(|_| ComparisonOperator::GTE), -// tag("<=").map(|_| ComparisonOperator::LTE), -// )); - -// loop { - -// } - - -// fold_many0(tuple((op, query_factor)), move || initial, |acc, (op, val)| { -// Query::from_fn(ComparisonFunction::new_op(op, acc, val)) -// }).parse(input) -// } - -// fn query_logical_and<'a, Field: ParseField>( -// input: &str, -// ) -> IResult<&'a str, Query<'a, Field>> { -// let (input_next, lhs) = query_factor(input)?; - -// fold_many0(preceded(tag_no_case("and"), query_factor), init, g) - -// // loop { -// // let rhs_result = tuple(( -// // alt((tag_no_case("and"), tag_no_case("or"))), -// // query_factor, -// // )).parse(input_next); -// // if let Ok(rhs_result) = rhs_result { -// // input_next = rhs_result.0; -// // } else { -// // break; -// // } -// // }; - - -// // tuple(take_until(alt((tag(""))))) -// } - -// fn logical_function<'a, Field: ParseField>( -// input: &str, -// ) -> IResult<&str, LogicalFunction<'a, Field>> { -// let lhs = - -// } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_string() { - assert_eq!(quoted_string("\"test\"").unwrap().1, "test"); - assert_eq!(quoted_string("\"te\\\"st\"").unwrap().1, "te\\\"st"); - } -} diff --git a/src/queries/parser/mod.rs b/src/queries/parser/mod.rs deleted file mode 100644 index 87ccb64..0000000 --- a/src/queries/parser/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod functions; \ No newline at end of file diff --git a/src/queries/postings.rs b/src/queries/postings.rs deleted file mode 100644 index ef14537..0000000 --- a/src/queries/postings.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::core::Ledger; - - -pub struct PostingQuery { - -} - - -pub fn query_postings(ledger: &Ledger, query: PostingQuery) { - -} \ No newline at end of file diff --git a/src/query/balance.rs b/src/query/balance.rs new file mode 100644 index 0000000..e292be3 --- /dev/null +++ b/src/query/balance.rs @@ -0,0 +1,75 @@ +use chrono::NaiveDate; +use rust_decimal_macros::dec; +use std::collections::HashMap; + +use super::{ + transaction::{PostingData, PostingField}, + Query, +}; +use crate::core::{Amount, Ledger}; + +pub fn balance( + ledger: &Ledger, + filter: Option<&Query>, + convert_to_unit: Option<(&str, NaiveDate)>, +) -> HashMap> { + let convert_to_unit = convert_to_unit.map(|u| (ledger.get_unit_by_symbol(u.0).unwrap(), u.1)); + + let postings = ledger + .get_transactions() + .iter() + .map(|t| { + t.get_postings().iter().map(|p| PostingData { + ledger, + posting: p, + parent_transaction: t, + }) + }) + .flatten(); + + let filter = match filter { + Some(filter) => filter, + None => &Query::Value(true.into()), + }; + + let filtered_postings = postings.filter(|data| { + filter + .evaluate(data) + .map(|v| bool::from(v)) + .unwrap_or(false) + }); + + let mut accounts = HashMap::new(); + for posting_data in filtered_postings { + let posting = posting_data.posting; + let mut amount = *posting.get_amount(); + if let Some((new_unit, unit_date)) = convert_to_unit { + if amount.unit_id != new_unit.get_id() { + let price = ledger.get_price_on_date(unit_date, amount.unit_id, new_unit.get_id()); + if let Some(price) = price { + amount = Amount { + value: amount.value * price, + unit_id: new_unit.get_id(), + }; + } + } + } + let account_vals = accounts + .entry(posting.get_account_id()) + .or_insert(HashMap::new()); + let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); + *a += amount.value; + } + + accounts + .iter() + .map(|(&k, v)| { + ( + k, + v.into_iter() + .map(|(&unit_id, &value)| Amount { value, unit_id }) + .collect(), + ) + }) + .collect() +} diff --git a/src/query/functions_comparison.rs b/src/query/functions_comparison.rs new file mode 100644 index 0000000..08d6414 --- /dev/null +++ b/src/query/functions_comparison.rs @@ -0,0 +1,202 @@ +use regex::{Regex, RegexBuilder}; +use std::fmt::Debug; + +use super::{Data, Function, Query}; +use crate::core::{CoreError, DataValue, StringData}; + +#[derive(PartialEq, Debug, Clone)] +pub enum ComparisonOperator { + EQ, + NEQ, + GT, + LT, + GTE, + LTE, +} + +#[derive(Clone, Debug)] +pub struct ComparisonFunction { + op: ComparisonOperator, + left: Query, + right: Query, +} + +// #[derive(Clone, Debug)] +// pub enum StringComparisonOperator<'a> { +// Func(&'a (dyn Fn(&str) -> bool + 'a)), +// Regex(Regex), +// } + +// #[derive(Clone)] +// pub struct StringComparisonFunction<'a, Field: Clone + 'static> { +// op: StringComparisonOperator<'a>, +// val: Query, +// } + +#[derive(Clone, Debug)] +pub struct SubAccountFunction { + account_name: StringData<'static>, + val: Query, +} + +#[derive(Clone, Debug)] +pub struct RegexFunction { + left: Query, + regex: Regex, +} + +///////////////////// +// Implementations // +///////////////////// + +impl<'a, Field: Clone + Debug> ComparisonFunction { + pub fn new(op: &str, left: Query, right: Query) -> Result { + let op = match op { + "==" => ComparisonOperator::EQ, + "!=" => ComparisonOperator::NEQ, + ">" => ComparisonOperator::GT, + "<" => ComparisonOperator::LT, + ">=" => ComparisonOperator::GTE, + "<=" => ComparisonOperator::LTE, + _ => return Err("Invalid Operator".into()), + }; + Ok(ComparisonFunction { op, left, right }) + } + + pub fn new_op(op: ComparisonOperator, left: Query, right: Query) -> Self { + ComparisonFunction { op, left, right } + } +} + +impl Function for ComparisonFunction { + fn evaluate(&self, context: &dyn Data) -> Result { + let left = self.left.evaluate(context)?; + let right = self.right.evaluate(context)?; + + match self.op { + ComparisonOperator::EQ => Ok(DataValue::Boolean(left == right)), + ComparisonOperator::NEQ => Ok(DataValue::Boolean(left != right)), + ComparisonOperator::GT => Ok(DataValue::Boolean(left > right)), + ComparisonOperator::LT => Ok(DataValue::Boolean(left < right)), + ComparisonOperator::GTE => Ok(DataValue::Boolean(left >= right)), + ComparisonOperator::LTE => Ok(DataValue::Boolean(left <= right)), + } + } +} + +// impl<'a, Field: Clone> StringComparisonFunction<'a, Field> { +// pub fn new_func( +// val: Query, +// func: &'a (impl Fn(&str) -> bool + 'a), +// ) -> Result { +// Ok(StringComparisonFunction { val, op: StringComparisonOperator::Func(func) }) +// } + +// pub fn new_regex(val: Query, regex: &str) -> Result { +// let regex = Regex::new(regex).map_err(|_| CoreError::from("Unable to parse regex"))?; +// Ok(StringComparisonFunction { val, op: StringComparisonOperator::Regex(regex) }) +// } +// } + +// impl<'a, Field: Clone> Function for StringComparisonFunction<'a, Field> { +// fn evaluate(&self, context: &dyn Data) -> Result { +// let val = self.val.evaluate(context)?; + +// if let DataValue::String(val) = val { +// match &self.op { +// StringComparisonOperator::Func(func) => Ok(DataValue::Boolean(func(val.as_ref()))), +// StringComparisonOperator::Regex(regex) => { +// Ok(DataValue::Boolean(regex.is_match(val.as_ref()))) +// } +// } +// // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) +// } else { +// Err("Cannot use REGEX operation on non string types".into()) +// } + +// // let left = self.left.evaluate(context)?; + +// // if let DataValue::String(left) = left { +// // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) +// // } else { +// // Err("Cannot use REGEX operation on non string types".into()) +// // } +// } +// } + +impl<'a, Field: Clone + Debug> SubAccountFunction { + pub fn new(account: StringData<'a>, val: Query) -> Self { + SubAccountFunction { account_name: account.into_owned(), val } + } +} + +impl Function for SubAccountFunction { + fn evaluate(&self, context: &dyn Data) -> Result { + let val = self.val.evaluate(context)?; + + if let DataValue::String(val) = val { + Ok(DataValue::Boolean( + val.as_ref() + .strip_prefix(self.account_name.as_ref()) + .map(|n| n.is_empty() || n.starts_with(":")) + .unwrap_or(false), + )) + } else { + Err("Cannot compare account name on non string types".into()) + } + } +} + +impl RegexFunction { + pub fn new(left: Query, regex: &str, case_insensitive: bool) -> Result { + let regex = RegexBuilder::new(regex).case_insensitive(case_insensitive).build().map_err(|_| CoreError::from("Unable to parse regex"))?; + Ok(RegexFunction { left, regex }) + } +} + +impl Function for RegexFunction { + fn evaluate(&self, context: &dyn Data) -> Result { + let left = self.left.evaluate(context)?; + + if let DataValue::String(left) = left { + Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) + } else { + Err("Cannot use REGEX operation on non string types".into()) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone)] + enum TestField {} + + struct TestData {} + + impl Data for TestData { + fn get_field(&self, _: &TestField) -> Result { + Err("".into()) + } + } + + #[test] + fn comparison_function_evaluate() { + let value1: DataValue = 5.into(); + let value2: DataValue = 10.into(); + let context = TestData {}; + + let comparison_function = + ComparisonFunction::::new(">", value1.into(), value2.into()).unwrap(); + assert_eq!(comparison_function.op, ComparisonOperator::GT); + assert_eq!(comparison_function.evaluate(&context), Ok(false.into())); + + let value = Query::from_fn(comparison_function); + assert_eq!(value.evaluate(&context), Ok(false.into())); + } +} diff --git a/src/query/functions_logical.rs b/src/query/functions_logical.rs new file mode 100644 index 0000000..c5552ad --- /dev/null +++ b/src/query/functions_logical.rs @@ -0,0 +1,125 @@ +use super::{Data, Function, Query}; +use crate::core::{CoreError, DataValue}; +use std::fmt::Debug; + +#[derive(PartialEq, Debug, Clone)] +pub enum LogicalOperator { + AND, + OR, +} + +#[derive(Debug, Clone)] +pub struct LogicalFunction { + op: LogicalOperator, + left: Query, + right: Query, +} + +#[derive(Debug, Clone)] +pub struct NotFunction { + value: Query, +} + +///////////////////// +// Implementations // +///////////////////// + +impl LogicalFunction { + pub fn new(op: &str, left: Query, right: Query) -> Result { + if op.eq_ignore_ascii_case("and") { + Ok(LogicalFunction { op: LogicalOperator::AND, left, right }) + } else if op.eq_ignore_ascii_case("or") { + Ok(LogicalFunction { op: LogicalOperator::OR, left, right }) + } else { + Err("Invalid logical operator".into()) + } + } + + pub fn new_op(op: LogicalOperator, left: Query, right: Query) -> Self { + LogicalFunction { op, left, right } + } +} + +impl Function for LogicalFunction { + fn evaluate(&self, context: &dyn Data) -> Result { + // More verbose to try and avoid doing right side computation + let value: bool = match self.op { + LogicalOperator::AND => { + if !bool::from(self.left.evaluate(context)?) { + false + } else { + self.right.evaluate(context)?.into() + } + } + LogicalOperator::OR => { + if bool::from(self.left.evaluate(context)?) { + true + } else { + self.right.evaluate(context)?.into() + } + } + }; + + Ok(DataValue::Boolean(value)) + + // let left = self.left.evaluate(context)?.into(); + // let right = self.right.evaluate(context)?.into(); + + // match self.op { + // LogicalOperator::AND => Ok(DataValue::Boolean(left && right)), + // LogicalOperator::OR => Ok(DataValue::Boolean(left || right)), + // } + } +} + +impl NotFunction { + pub fn new(value: Query) -> Self { + NotFunction { value } + } +} + +impl Function for NotFunction { + fn evaluate(&self, context: &dyn Data) -> Result { + let value: bool = self.value.evaluate(context)?.into(); + + Ok(DataValue::Boolean(!value)) + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone)] + enum TestField {} + + struct TestData {} + + impl Data for TestData { + fn get_field(&self, _: &TestField) -> Result { + Err("".into()) + } + } + + #[test] + fn logical_function_evaluate() { + let value1: DataValue = 1.into(); + let value2: DataValue = false.into(); + let context = TestData {}; + + let logical_function_and = + LogicalFunction::::new("and", value1.clone().into(), value2.clone().into()) + .unwrap(); + assert_eq!(logical_function_and.op, LogicalOperator::AND); + assert_eq!(logical_function_and.evaluate(&context), Ok(false.into())); + + let logical_function_or = + LogicalFunction::::new("or", value1.into(), value2.into()).unwrap(); + assert_eq!(logical_function_or.op, LogicalOperator::OR); + assert_eq!(logical_function_or.evaluate(&context), Ok(true.into())); + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs new file mode 100644 index 0000000..ff46574 --- /dev/null +++ b/src/query/mod.rs @@ -0,0 +1,11 @@ +mod query; +mod functions_comparison; +mod functions_logical; +mod transaction; +mod balance; + +pub use query::*; +pub use functions_comparison::*; +pub use functions_logical::*; +pub use transaction::*; +pub use balance::*; \ No newline at end of file diff --git a/src/query/query.rs b/src/query/query.rs new file mode 100644 index 0000000..fe0430f --- /dev/null +++ b/src/query/query.rs @@ -0,0 +1,68 @@ +use std::fmt::Debug; + +use crate::core::{CoreError, DataValue}; + +#[derive(Debug)] +pub enum Query { + Field(T), + Value(DataValue), + Function(Box>), +} + +pub trait Data { + fn get_field(&self, field: &T) -> Result; +} + +pub trait Function: FunctionClone + Debug { + fn evaluate(&self, context: &dyn Data) -> Result; +} + +pub trait FunctionClone { + fn clone_box(&self) -> Box>; +} + +impl FunctionClone for F +where + F: Function + Clone + 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } +} + +impl Clone for Query +where + T: Clone + Debug, +{ + fn clone(&self) -> Self { + match self { + Query::Field(f) => Query::Field(f.clone()), + Query::Value(v) => Query::Value(v.clone()), + Query::Function(f) => Query::Function(f.clone_box()), + } + } +} + +impl Query { + pub fn evaluate(&self, context: &dyn Data) -> Result { + match self { + Query::Field(field) => context.get_field(field), + Query::Value(constant) => Ok(constant.clone()), + Query::Function(function) => function.evaluate(context), + } + } + + pub fn from_field(field: T) -> Self { + Query::Field(field) + } + + pub fn from_fn + Sized + 'static>(function: F) -> Self { + Query::Function(Box::new(function)) + } +} + +impl<'a, T: Clone + Debug> From for Query { + fn from(constant: DataValue) -> Self { + Query::Value(constant) + } +} diff --git a/src/queries/transaction.rs b/src/query/transaction.rs similarity index 85% rename from src/queries/transaction.rs rename to src/query/transaction.rs index 6225001..b61e9d3 100644 --- a/src/queries/transaction.rs +++ b/src/query/transaction.rs @@ -1,13 +1,14 @@ -use crate::core::{CoreError, Ledger, Posting, Transaction}; - -use super::base::{Data, DataValue}; +use super::Data; +use crate::core::{CoreError, DataValue, Ledger, Posting, Transaction}; +#[derive(Debug, Clone)] pub enum AccountField { Name, OpenDate, CloseDate, } +#[derive(Debug, Clone)] pub enum PostingField { Transaction(TransactionField), Account(AccountField), @@ -16,6 +17,7 @@ pub enum PostingField { Price, } +#[derive(Debug, Clone)] pub enum TransactionField { Date, Flag, @@ -32,7 +34,9 @@ pub struct PostingData<'a> { impl<'a> Data for PostingData<'a> { fn get_field(&self, field: &PostingField) -> Result { match field { - PostingField::Transaction(transaction_field) => get_transaction_value(transaction_field, &self.parent_transaction), + PostingField::Transaction(transaction_field) => { + get_transaction_value(transaction_field, &self.parent_transaction) + } PostingField::Account(account_field) => { let account = self .ledger @@ -60,7 +64,7 @@ impl<'a> Data for PostingData<'a> { fn get_transaction_value<'a>( field: &TransactionField, transaction: &Transaction, -) -> Result, CoreError> { +) -> Result { match field { TransactionField::Date => Ok(transaction.get_date().into()), TransactionField::Flag => Ok(char::from(transaction.get_flag()).to_string().into()),