diff --git a/Cargo.toml b/Cargo.toml index 9e2df5e..5347031 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,5 +11,12 @@ nom = "7.1.3" nom_locate = "4.2.0" rand = "0.8.5" ratatui = "0.29.0" +regex = "1.11.1" rust_decimal = "1.36.0" rust_decimal_macros = "1.36.0" + +[profile.release] +debug = 1 + +[rust] +debuginfo-level = 1 \ No newline at end of file diff --git a/src/core/amounts.rs b/src/core/amounts.rs index 124e3bf..5037fc3 100644 --- a/src/core/amounts.rs +++ b/src/core/amounts.rs @@ -90,4 +90,14 @@ pub fn combine_amounts(amounts: impl Iterator) -> Vec { } output_amounts.iter().map(|(&unit_id, &value)| Amount {value, unit_id}).collect() +} + +impl PartialOrd for Amount { + fn partial_cmp(&self, other: &Self) -> Option { + if self.unit_id != other.unit_id { + None + } else { + self.value.partial_cmp(&other.value) + } + } } \ No newline at end of file diff --git a/src/core/errors.rs b/src/core/errors.rs index 7b4a23a..29d1834 100644 --- a/src/core/errors.rs +++ b/src/core/errors.rs @@ -1,9 +1,11 @@ use core::fmt; +#[derive(PartialEq)] pub struct CoreError { text: StringData, } +#[derive(PartialEq)] enum StringData { Static(&'static str), Dynamic(String), diff --git a/src/core/ledger.rs b/src/core/ledger.rs index aa5b1b2..b5bff28 100644 --- a/src/core/ledger.rs +++ b/src/core/ledger.rs @@ -1,11 +1,14 @@ +use chrono::NaiveDate; +use rust_decimal::Decimal; use rust_decimal_macros::dec; -use super::{Account, Amount, CoreError, Transaction, Unit}; +use super::{Account, Amount, CoreError, Price, Transaction, Unit}; #[derive(Debug)] pub struct Ledger { accounts: Vec, units: Vec, + prices: Vec, transactions: Vec, } @@ -14,6 +17,7 @@ impl Ledger { Ledger { accounts: Vec::new(), units: Vec::new(), + prices: Vec::new(), transactions: Vec::new(), } } @@ -27,7 +31,9 @@ impl Ledger { } pub fn get_account_by_name(&self, name: &str) -> Option<&Account> { - self.accounts.iter().find(|account| account.get_name() == name) + self.accounts + .iter() + .find(|account| account.get_name() == name) } pub fn get_units(&self) -> &Vec { @@ -39,13 +45,38 @@ impl Ledger { } pub fn get_unit_by_symbol(&self, unit_symbol: &str) -> Option<&Unit> { - self.units.iter().find(|unit| unit.matches_symbol(unit_symbol)) + self.units + .iter() + .find(|unit| unit.matches_symbol(unit_symbol)) } pub fn get_transactions(&self) -> &Vec { &self.transactions } + // Assume prices are sorted by date already + // For now only trivial conversions, not multiple conversions + pub fn get_price_on_date( + &self, + date: NaiveDate, + original_unit_id: u32, + new_unit_id: u32, + ) -> Option { + let max_pos = self + .prices + .iter() + .position(|p| p.date > date) + .unwrap_or(self.prices.len()); + let valid_prices = &self.prices[..max_pos]; + + let price = valid_prices + .iter() + .rev() + .find(|p| p.unit_id == original_unit_id && p.amount.unit_id == new_unit_id); + + price.map(|p| p.amount.value) + } + pub fn round_amount(&self, amount: &Amount) -> Amount { let mut new_amount = *amount; let unit = self.get_unit(amount.unit_id); @@ -60,15 +91,15 @@ impl Ledger { } pub fn round_amounts(&self, amounts: &[Amount]) -> Vec { - amounts.iter().map(|a| self.round_amount(a)).filter(|a| a.value != dec!(0)).collect() + amounts + .iter() + .map(|a| self.round_amount(a)) + .filter(|a| a.value != dec!(0)) + .collect() } pub fn add_account(&mut self, account: Account) -> Result<(), CoreError> { - if self - .accounts - .iter() - .any(|existing_account| existing_account.get_name() == account.get_name()) - { + if self.get_account_by_name(&account.get_name()).is_some() { return Err("Account with the same name already exists".into()); } @@ -100,10 +131,37 @@ impl Ledger { println!("{:?}", balances); return Err("Transaction is not balanced".into()); } + + for posting in transaction.get_postings() { + if let Some(price_amount) = posting.get_price() { + self.add_price(Price { + amount: *price_amount, + date: transaction.get_date(), + unit_id: posting.get_amount().unit_id, + })?; + } else if let Some(cost_amount) = posting.get_cost() { + self.add_price(Price { + amount: *cost_amount, + date: transaction.get_date(), + unit_id: posting.get_amount().unit_id, + })?; + } + } + self.transactions.push(transaction); Ok(()) } + + pub fn add_price(&mut self, price: Price) -> Result<(), CoreError> { + self.prices.push(price); + + Ok(()) + } + + pub fn sort_prices(&mut self) { + self.prices.sort_by(|a, b| a.date.cmp(&b.date)); + } } #[cfg(test)] diff --git a/src/core/mod.rs b/src/core/mod.rs index 8e5e359..382fc23 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,12 +1,14 @@ mod account; mod amounts; +mod common; mod errors; mod ledger; +mod price; mod transaction; -mod common; pub use account::*; pub use amounts::*; pub use errors::*; pub use ledger::*; +pub use price::*; pub use transaction::*; diff --git a/src/core/price.rs b/src/core/price.rs new file mode 100644 index 0000000..dad2443 --- /dev/null +++ b/src/core/price.rs @@ -0,0 +1,10 @@ +use chrono::NaiveDate; + +use super::Amount; + +#[derive(Debug, Clone, Copy)] +pub struct Price { + pub unit_id: u32, + pub date: NaiveDate, + pub amount: Amount, +} diff --git a/src/document/directives.rs b/src/document/directives.rs index 36ee9e0..3e176d1 100644 --- a/src/document/directives.rs +++ b/src/document/directives.rs @@ -37,6 +37,7 @@ pub struct TransactionDirective { pub payee: Option, pub narration: Option, pub postings: Vec, + pub metadata: Vec<(String, String)>, } #[derive(Debug, Clone)] @@ -52,6 +53,7 @@ pub struct BalanceDirective { #[derive(Debug, PartialEq, Clone)] pub struct DirectivePosting { + pub date: Option, pub account: String, pub amount: Option, pub cost: Option, diff --git a/src/document/ledger.rs b/src/document/ledger.rs index f654689..801928a 100644 --- a/src/document/ledger.rs +++ b/src/document/ledger.rs @@ -6,7 +6,16 @@ use crate::{ core::{ Account, Amount, CoreError, Ledger, Posting, Transaction, TransactionFlag, Unit, UnitSymbol, }, - queries::{self, Query}, + queries::{ + self, + base::{self, DataValue}, + functions::{ + ComparisonFunction, LogicalFunction, RegexFunction, StringComparisonFunction, + SubAccountFunction, + }, + transaction::{AccountField, PostingField, TransactionField}, + Query, + }, }; use super::{BalanceDirective, DirectiveAmount, TransactionDirective}; @@ -63,17 +72,48 @@ pub fn add_transaction( Ok(()) } -pub fn check_balance(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { - let accounts = queries::balance(&ledger, &[Query::EndDate(balance.date)]); +pub fn check_balance2(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { + let date_query = ComparisonFunction::new( + "<=", + base::Query::from_field(PostingField::Transaction(TransactionField::Date)), + base::Query::from(DataValue::from(balance.date)), + ) + .unwrap(); + // let account_fn = |str: &str| { + // str.strip_prefix("abc") + // .map(|n| n.is_empty() || n.starts_with(":")) + // .unwrap_or(false) + // }; + // let account_query = StringComparisonFunction::new_func( + // base::Query::from_field(PostingField::Account(AccountField::Name)), + // &account_fn, + // )?; + // let account_regex = format!("^{}($|:.+)", balance.account); + // let account_query = RegexFunction::new( + // base::Query::from_field(PostingField::Account(AccountField::Name)), + // &account_regex, + // )?; + let account_query = SubAccountFunction::new( + balance.account.clone().into(), + base::Query::from_field(PostingField::Account(AccountField::Name)), + ); - let accounts = accounts.iter().filter(|(&account_id, val)| { - let account = ledger.get_account(account_id).unwrap(); - account.is_under_account(&balance.account) - }); + let start = Instant::now(); - if accounts.clone().count() == 0 { + let total_query = LogicalFunction::new( + "and", + base::Query::from_fn(date_query), + base::Query::from_fn(account_query), + ) + .unwrap(); - } + let t2 = Instant::now(); + + let accounts = queries::balance3(&ledger, &base::Query::from_fn(total_query)); + + let t3 = Instant::now(); + + // println!("{:?} {:?}", t2-start, t3-t2); let mut total_amounts = HashMap::new(); let mut account_count = 0; @@ -108,7 +148,81 @@ pub fn check_balance(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), let unit = ledger .get_unit_by_symbol(&balance_amount.unit_symbol) .ok_or("Unit not found")?; - let value = total_amounts.get(&unit.get_id()).map(|v| *v).unwrap_or(dec!(0)); + let value = total_amounts + .get(&unit.get_id()) + .map(|v| *v) + .unwrap_or(dec!(0)); + + // let value = amounts + // .iter() + // .find(|a| a.unit_id == unit.get_id()) + // .map(|a| a.value) + // .unwrap_or(dec!(0)); + let max_scale = max(value.scale(), balance_amount.value.scale()); + + let value = value.round_dp(max_scale); + let balance_value = balance_amount.value.round_dp(max_scale); + + if value != balance_value { + return Err(format!( + "Balance amount for \"{}\" on {} does not match. Expected {} but got {}", + balance.account, balance.date, balance_value, value + ) + .into()); + } + } + + Ok(()) +} + +pub fn check_balance(ledger: &Ledger, balance: &BalanceDirective) -> Result<(), CoreError> { + let accounts = queries::balance(&ledger, &[Query::EndDate(balance.date)]); + // let accounts = queries::balance2(&ledger, balance.date); + + let accounts = accounts.iter().filter(|(&account_id, val)| { + let account = ledger.get_account(account_id).unwrap(); + account.is_under_account(&balance.account) + }); + + if accounts.clone().count() == 0 {} + + let mut total_amounts = HashMap::new(); + let mut account_count = 0; + + for (_, amounts) in accounts { + account_count += 1; + for amount in amounts { + *total_amounts.entry(amount.unit_id).or_insert(dec!(0)) += amount.value; + } + } + + if account_count == 0 { + return Err("No accounts match balance account".into()); + } + + // let balance_account = ledger + // .get_account_by_name(&balance.account) + // .ok_or("Account not found")?; + + // let amounts = accounts + // .get(&balance_account.get_id()) + // .map(|v| v.as_slice()) + // .unwrap_or(&[]); + + // if amounts.len() > balance.amounts.len() { + // return Err("".into()); + // } else if amounts.len() < balance.amounts.len() { + // return Err("".into()); + // } + + for balance_amount in &balance.amounts { + let unit = ledger + .get_unit_by_symbol(&balance_amount.unit_symbol) + .ok_or("Unit not found")?; + let value = total_amounts + .get(&unit.get_id()) + .map(|v| *v) + .unwrap_or(dec!(0)); // let value = amounts // .iter() diff --git a/src/document/mod.rs b/src/document/mod.rs index 779b55d..77c598b 100644 --- a/src/document/mod.rs +++ b/src/document/mod.rs @@ -1,13 +1,13 @@ mod directives; -mod parser; mod ledger; +mod parser; pub use directives::*; -use ledger::{add_transaction, check_balance}; +use ledger::{add_transaction, check_balance2}; use parser::parse_directives; use crate::core::{CoreError, Ledger, Unit}; -use std::path::Path; +use std::{path::Path, time::Instant}; #[derive(Debug)] pub struct Document { @@ -40,9 +40,14 @@ impl Document { add_transaction(&mut ledger, transaction)?; } + ledger.sort_prices(); + + let start = Instant::now(); for balance in &self.directives.balances { - check_balance(&ledger, &balance)?; + check_balance2(&ledger, &balance)?; } + let end = Instant::now(); + println!("time to calculate balance: {:?}", end - start); Ok(ledger) // for balance in self.directives.balances { diff --git a/src/document/parser/base_directive.rs b/src/document/parser/base_directive.rs index cb18bb7..f56fa6c 100644 --- a/src/document/parser/base_directive.rs +++ b/src/document/parser/base_directive.rs @@ -58,6 +58,22 @@ pub fn empty_lines(input: &str) -> IResult<&str, ()> { .parse(input) } +pub fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { + let (new_input, (year, _, month, _, day)) = tuple(( + date_year, + opt(tag("-")), + date_month, + opt(tag("-")), + date_day, + )) + .parse(input)?; + + match NaiveDate::from_ymd_opt(year, month, day) { + Some(date) => Ok((new_input, date)), + None => Err(nom::Err::Error(Error::new(input, ErrorKind::Eof))), + } +} + /////////////// // Private // /////////////// @@ -82,22 +98,6 @@ fn date_day(input: &str) -> IResult<&str, u32> { take_n_digits(input, 2) } -fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { - let (new_input, (year, _, month, _, day)) = tuple(( - date_year, - opt(tag("-")), - date_month, - opt(tag("-")), - date_day, - )) - .parse(input)?; - - match NaiveDate::from_ymd_opt(year, month, day) { - Some(date) => Ok((new_input, date)), - None => Err(nom::Err::Error(Error::new(input, ErrorKind::Eof))), - } -} - const COMMENT_CHARS: &str = ";#"; fn parse_comment(input: &str) -> IResult<&str, &str> { diff --git a/src/document/parser/mod.rs b/src/document/parser/mod.rs index 4a6d13b..d16def4 100644 --- a/src/document/parser/mod.rs +++ b/src/document/parser/mod.rs @@ -2,6 +2,7 @@ mod amounts; mod base_directive; mod directives; mod transaction; +mod shared; use base_directive::{base_directive, empty_lines}; use directives::{specific_directive, Directive}; diff --git a/src/document/parser/shared.rs b/src/document/parser/shared.rs new file mode 100644 index 0000000..82252c3 --- /dev/null +++ b/src/document/parser/shared.rs @@ -0,0 +1,48 @@ +use nom::{ + bytes::complete::tag, + character::complete::space0, + error::ErrorKind, + sequence::{delimited, tuple}, + IResult, InputTakeAtPosition, Parser, +}; + +pub fn metadatum(input: &str) -> IResult<&str, (&str, &str)> { + tuple(( + delimited(tag("-"), delimited(space0, key, space0), tag(":")), + value, + )) + .map(|v| (v.0, v.1.trim())) + .parse(input) +} + +pub fn key(input: &str) -> IResult<&str, &str> { + input.split_at_position1_complete(|item| item == ':', ErrorKind::AlphaNumeric) +} + +pub fn value(input: &str) -> IResult<&str, &str> { + input.split_at_position1_complete(|_| false, ErrorKind::AlphaNumeric) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_metadatum() { + assert_eq!( + metadatum("- key: value").unwrap().1, + ("key".into(), "value".into()) + ); + assert_eq!( + metadatum("- key space: value space").unwrap().1, + ("key space".into(), "value space".into()) + ); + } + + #[test] + fn parse_metadatum_invalid() { + assert!(metadatum("- key:").is_err()); + assert!(metadatum("- : value").is_err()); + assert!(metadatum("- : ").is_err()); + } +} diff --git a/src/document/parser/transaction.rs b/src/document/parser/transaction.rs index d206a04..e1017c0 100644 --- a/src/document/parser/transaction.rs +++ b/src/document/parser/transaction.rs @@ -1,18 +1,19 @@ use nom::{ branch::alt, bytes::complete::{is_not, tag}, - character::complete::space1, + character::complete::{space0, space1}, combinator::{eof, opt, rest}, error::{Error, ErrorKind}, - sequence::{delimited, preceded, tuple}, - Err, IResult, Parser, + sequence::{delimited, preceded, terminated, tuple}, + Err, IResult, InputTakeAtPosition, Parser, }; use crate::document::{DirectiveAmount, DirectivePosting, TransactionDirective}; use super::{ amounts::{account, amount}, - base_directive::BaseDirective, + base_directive::{parse_iso_date, BaseDirective}, + shared::metadatum, }; // use super::{ @@ -52,7 +53,12 @@ pub fn transaction<'a>( }; let mut postings = Vec::with_capacity(directive.lines.len()); + let mut metadata = Vec::new(); for &line in directive.lines.get(1..).unwrap_or(&[]) { + if let Ok(m) = metadatum(line) { + metadata.push((m.1.0.to_string(), m.1.1.to_string())); + continue; + } let posting = if let Ok(v) = posting(line) { v } else { @@ -73,6 +79,7 @@ pub fn transaction<'a>( payee: payee.map(|p| p.to_string()), narration: narration.map(|n| n.to_string()), postings, + metadata, }, )) } @@ -94,6 +101,7 @@ fn payee_narration(input: &str) -> IResult<&str, (Option<&str>, Option<&str>)> { fn posting(input: &str) -> IResult<&str, DirectivePosting> { tuple(( + opt(terminated(parse_iso_date, space1)), account, opt(tuple(( preceded(space1, amount), @@ -102,7 +110,7 @@ fn posting(input: &str) -> IResult<&str, DirectivePosting> { ))), eof, )) - .map(|(account, value, _)| { + .map(|(date, account, value, _)| { let mut amount = None; let mut cost = None; let mut price = None; @@ -131,7 +139,13 @@ fn posting(input: &str) -> IResult<&str, DirectivePosting> { } } } - DirectivePosting { account: account.to_string(), amount, cost, price } + DirectivePosting { + date, + account: account.to_string(), + amount, + cost, + price, + } }) .parse(input) } @@ -194,6 +208,7 @@ mod tests { assert_eq!( posting("Account1 10 SHARE {$100}").unwrap().1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -211,6 +226,7 @@ mod tests { assert_eq!( posting("Account1 10 SHARE {{1000 USD}}").unwrap().1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -232,6 +248,7 @@ mod tests { assert_eq!( posting("Account1 10 SHARE @ $100").unwrap().1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -249,6 +266,7 @@ mod tests { assert_eq!( posting("Account1 10 SHARE @@ 1000 USD").unwrap().1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -270,6 +288,7 @@ mod tests { assert_eq!( posting("Account1 10 SHARE {$100} @ $110").unwrap().1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -293,6 +312,7 @@ mod tests { .unwrap() .1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -316,6 +336,7 @@ mod tests { .unwrap() .1, DirectivePosting { + date: None, account: "Account1".into(), amount: Some(DirectiveAmount { value: dec!(10), @@ -336,6 +357,34 @@ mod tests { ); } + #[test] + fn parse_posting_with_date() { + assert_eq!( + posting("2000-01-01 Account1 10 SHARE {$100} @ $110") + .unwrap() + .1, + DirectivePosting { + date: Some(NaiveDate::from_ymd_opt(2000, 01, 01).unwrap()), + account: "Account1".into(), + amount: Some(DirectiveAmount { + value: dec!(10), + unit_symbol: "SHARE".into(), + is_unit_prefix: false + }), + cost: Some(DirectiveAmount { + value: dec!(100), + unit_symbol: "$".into(), + is_unit_prefix: true + }), + price: Some(DirectiveAmount { + value: dec!(110), + unit_symbol: "$".into(), + is_unit_prefix: true + }), + } + ); + } + #[test] fn parse_transaction_postings() { let directive = BaseDirective { @@ -362,4 +411,25 @@ mod tests { assert_eq!(transaction.postings[1].account, "Account3"); assert_eq!(transaction.postings[1].amount, None); } + + #[test] + fn parse_transaction_postings_metadata() { + let directive = BaseDirective { + date: NaiveDate::from_ymd_opt(2000, 01, 01), + directive_name: "txn", + lines: vec![ + "payee | narration", + "- key: value", + "Account1:Account2 $10.01", + "Account3", + ], + }; + + let transaction = transaction(directive).unwrap().1; + assert_eq!(transaction.metadata.len(), 1); + assert_eq!(transaction.metadata[0], ("key".into(), "value".into())); + assert_eq!(transaction.postings.len(), 2); + assert_eq!(transaction.postings[0].account, "Account1:Account2"); + assert_eq!(transaction.postings[1].account, "Account3"); + } } diff --git a/src/lib.rs b/src/lib.rs index a4dd631..343b291 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ pub mod queries; // pub mod create_ledger; pub mod document; pub mod output; +mod parser; // pub struct Account { // // TODO diff --git a/src/main.rs b/src/main.rs index 7fe425c..0372a4a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,30 +41,30 @@ pub fn main() -> Result<(), Box> { // )) // ); - let stdout = io::stdout(); - let backend = CrosstermBackend::new(stdout); - let mut terminal = Terminal::new(backend)?; + // let stdout = io::stdout(); + // let backend = CrosstermBackend::new(stdout); + // let mut terminal = Terminal::new(backend)?; - let line = Line::from(vec![ - Span::raw("Hello "), - Span::styled("Hello ", Style::new().fg(Color::Rgb(100, 200, 150))), - Span::styled("World", Style::new().fg(Color::Green).bg(Color::White)), - ]) - .centered(); - let text = Text::from(line); + // let line = Line::from(vec![ + // Span::raw("Hello "), + // Span::styled("Hello ", Style::new().fg(Color::Rgb(100, 200, 150))), + // Span::styled("World", Style::new().fg(Color::Green).bg(Color::White)), + // ]) + // .centered(); + // let text = Text::from(line); - println!("{}", text_to_ansi(&text)); + // println!("{}", text_to_ansi(&text)); - // println!("{:?}", line.to_string()); + // // println!("{:?}", line.to_string()); - // terminal.dra + // // terminal.dra - // crossterm::terminal::enable_raw_mode()?; + // // crossterm::terminal::enable_raw_mode()?; - terminal.draw(|f| { - let area = f.area(); - f.render_widget(text, area); - })?; + // terminal.draw(|f| { + // let area = f.area(); + // f.render_widget(text, area); + // })?; // PrintStyledContent @@ -88,18 +88,36 @@ pub fn main() -> Result<(), Box> { // println!("{}", text.render(area, buf);); - return Ok(()); + // return Ok(()); - // let document = Document::new(Path::new("data/full/main.ledger")).unwrap(); + let t1 = Instant::now(); - // let ledger = document.generate_ledger().unwrap(); + let document = Document::new(Path::new("data/full/main.ledger")).unwrap(); + + let t2 = Instant::now(); + + let ledger = document.generate_ledger().unwrap(); + + let t3 = Instant::now(); // let balance = queries::balance( // &ledger, // &[], // ); - // format_balance(&ledger, &balance); + let balance = queries::balance2( + &ledger, + NaiveDate::from_ymd_opt(2100, 01, 01).unwrap(), + Some("$") + ); + + let t4 = Instant::now(); + + format_balance(&ledger, &balance); + + let t5 = Instant::now(); + + println!("{:?} - {:?} - {:?} - {:?}", t2-t1, t3-t2, t4-t3, t5-t4); // return; @@ -127,4 +145,6 @@ pub fn main() -> Result<(), Box> { // let ledger = create_ledger(&file_data).unwrap(); // println!("{:?}", val); + + return Ok(()); } diff --git a/src/output/amounts.rs b/src/output/amounts.rs index f4d811b..83657e8 100644 --- a/src/output/amounts.rs +++ b/src/output/amounts.rs @@ -1,18 +1,45 @@ use crate::core::{Amount, Ledger}; impl Ledger { - pub fn format_amount(&self, amount: &Amount) -> String { + /// Returns formatted string and position of decimal point in string + pub fn format_amount(&self, amount: &Amount) -> (String, usize) { let unit = self.get_unit(amount.unit_id).unwrap(); let default_symbol = unit.default_symbol(); let amount = self.round_amount(&amount); - if default_symbol.is_prefix { - format!("{}{}", default_symbol.symbol, amount.value) + let sign = if amount.value.is_sign_negative() { + "-" } else { + "" + }; + + let value = amount.value.abs().to_string(); + let mut split = value.split("."); + + let mut value = split.next().unwrap() + .as_bytes() + .rchunks(3) + .rev() + .map(std::str::from_utf8) + .collect::, _>>() + .unwrap() + .join(","); + let value_decimal_pos = value.len(); + + if let Some(decimal) = split.next() { + value += "."; + value += decimal; + } + + if default_symbol.is_prefix { + let decimal_pos = sign.len() + default_symbol.symbol.len() + value_decimal_pos; + (format!("{}{}{}", sign, default_symbol.symbol, value), decimal_pos) + } else { + let decimal_pos = sign.len() + value_decimal_pos; if default_symbol.symbol.len() == 1 { - format!("{}{}", amount.value, default_symbol.symbol) + (format!("{}{}{}", sign, value, default_symbol.symbol), decimal_pos) } else { - format!("{} {}", amount.value, default_symbol.symbol) + (format!("{}{} {}", sign, value, default_symbol.symbol), decimal_pos) } } } diff --git a/src/output/cli/balance.rs b/src/output/cli/balance.rs index 5e0c50a..c3063e8 100644 --- a/src/output/cli/balance.rs +++ b/src/output/cli/balance.rs @@ -3,7 +3,9 @@ use std::{ collections::{HashMap, HashSet}, }; -use crate::core::{combine_amounts, Account, Amount, Ledger}; +use ratatui::{style::{Color, Style}, text::{Line, Span, Text}}; + +use crate::{core::{combine_amounts, Account, Amount, Ledger}, output::cli::tui_to_ansi::text_to_ansi}; #[derive(Debug)] struct BalanceTree { @@ -12,6 +14,13 @@ struct BalanceTree { amounts: Option>, } +#[derive(Debug)] +struct BalanceTreeStr { + name: String, + children: Vec, + amounts: Option>, +} + struct AccountInfo<'a> { account_path: Vec<&'a str>, amounts: Vec, @@ -105,43 +114,130 @@ fn set_tree_totals(tree: &mut BalanceTree) { tree.amounts = Some(total_amounts); } -fn print_tree(tree: &BalanceTree, ledger: &Ledger, level: usize, amount_pos: usize) { - let relative_amount_pos = amount_pos - (level*2 + tree.name.len()); - let main_line = format!("{}{} {}", " ".repeat(level), tree.name, "─".repeat(relative_amount_pos)); - let tree_amounts = tree.amounts.as_ref().unwrap().iter().filter(|v| !ledger.round_amount(v).value.is_zero()); +const STYLE_LINE: Style = Style::new().fg(Color::LightBlue); +const STYLE_AMOUNT_LINE: Style = Style::new().fg(Color::DarkGray); +const STYLE_ACCOUNT: Style = Style::new().fg(Color::LightBlue); + +fn tree_to_text(tree: &BalanceTreeStr, ledger: &Ledger, base_amount_pos: usize, max_decimal_pos: usize) -> Text<'static> { + let mut text = Text::default(); + + // let tree_amounts = tree.amounts.as_ref().unwrap().iter().filter(|v| !ledger.round_amount(v).value.is_zero()); + let tree_amounts = tree.amounts.as_ref().unwrap().iter(); let tree_amounts_count = tree_amounts.clone().count(); - for (i, amount) in tree_amounts.enumerate() { - let mut line = String::new(); + for (i, (amount, decimal_pos)) in tree_amounts.enumerate() { + let mut line = Line::default(); + let amount_padding_count = max_decimal_pos - decimal_pos; if i == 0 { - line += &main_line; + let amount_pos = base_amount_pos - tree.name.chars().count(); + + line.push_span(Span::styled(format!("{} ", tree.name), STYLE_ACCOUNT)); + + let mut line_str = "─".repeat(amount_pos); if tree_amounts_count > 1 { - line += "┬" + line_str += "┬" } else { - line += "─" + line_str += "─" } + line_str += &"─".repeat(amount_padding_count); + line.push_span(Span::styled(line_str, STYLE_AMOUNT_LINE)); } else { - line += &" ".repeat(amount_pos); - if i == tree_amounts_count - 1 { - line += " └" + let line_str = if tree.children.len() > 0 { + " │" } else { - line += " │" + " " + }; + line.push_span(Span::styled(line_str, STYLE_LINE)); + + let mut line_str = String::new(); + line_str += &" ".repeat(base_amount_pos - 2); + if i == tree_amounts_count - 1 { + line_str += " └"; + line_str += &"─".repeat(amount_padding_count); + } else { + line_str += " │"; + line_str += &" ".repeat(amount_padding_count); } + line.push_span(Span::styled(line_str, STYLE_AMOUNT_LINE)); } + line.push_span(Span::raw(format!(" {}", amount))); - line += &format!(" {}", ledger.format_amount(amount)); - - println!("{}", line); + text.push_line(line); } - // println!("{}{} {} {:?}", " ".repeat(level), tree.name, "-".repeat(relative_amount_pos), tree.amounts); - let mut children: Vec<&BalanceTree> = tree.children.iter().collect(); + let mut children: Vec<&BalanceTreeStr> = tree.children.iter().collect(); + let children_len = children.len(); children.sort_by(|a, b| a.name.cmp(&b.name)); - for child in children { - print_tree(&child, ledger, level + 1, amount_pos); + for (i_c, child) in children.into_iter().enumerate() { + let mut child_text = tree_to_text(&child, ledger, base_amount_pos - 4, max_decimal_pos); + for (i, line) in child_text.lines.into_iter().enumerate() { + let mut whole_line = Line::default(); + + if i_c == children_len - 1 { + if i == 0 { + whole_line.push_span(Span::styled(" └─ ", STYLE_LINE)); + } else { + whole_line.push_span(Span::styled(" ", STYLE_LINE)); + } + } else { + if i == 0 { + whole_line.push_span(Span::styled(" ├─ ", STYLE_LINE)); + } else { + whole_line.push_span(Span::styled(" │ ", STYLE_LINE)); + } + } + + whole_line.extend(line); + text.push_line(whole_line); + } } + + text } -fn calculate_max_account_len(tree: &BalanceTree, indent_amount: usize, indent_level: usize) -> usize { +// fn print_tree(tree: &BalanceTree, ledger: &Ledger, level: usize, amount_pos: usize) { +// let relative_amount_pos = amount_pos - (level*2 + tree.name.len()); +// let main_line = format!("{}{} {}", " ".repeat(level), tree.name, "─".repeat(relative_amount_pos)); +// let tree_amounts = tree.amounts.as_ref().unwrap().iter().filter(|v| !ledger.round_amount(v).value.is_zero()); +// let tree_amounts_count = tree_amounts.clone().count(); +// for (i, amount) in tree_amounts.enumerate() { +// let mut line = String::new(); +// if i == 0 { +// line += &main_line; +// if tree_amounts_count > 1 { +// line += "┬" +// } else { +// line += "─" +// } +// } else { +// line += &" ".repeat(amount_pos); +// if i == tree_amounts_count - 1 { +// line += " └" +// } else { +// line += " │" +// } +// } + +// line += &format!(" {}", ledger.format_amount(amount)); + +// println!("{}", line); +// } + +// // println!("{}{} {} {:?}", " ".repeat(level), tree.name, "-".repeat(relative_amount_pos), tree.amounts); +// let mut children: Vec<&BalanceTree> = tree.children.iter().collect(); +// children.sort_by(|a, b| a.name.cmp(&b.name)); +// for child in children { +// print_tree(&child, ledger, level + 1, amount_pos); +// } +// } + +fn balance_tree_to_str_tree(tree: BalanceTree, ledger: &Ledger) -> BalanceTreeStr { + let amounts = tree.amounts.map(|v| v.iter().map(|a| ledger.format_amount(a)).collect()); + let children = tree.children.into_iter().map(|c| balance_tree_to_str_tree(c, ledger)).collect(); + + BalanceTreeStr{amounts, name: tree.name, children} +} + +fn calculate_max_account_len(tree: &BalanceTreeStr, indent_amount: usize, indent_level: usize) -> usize { let current_len = tree.name.len() + indent_amount * indent_level; let mut max_length = current_len; @@ -153,16 +249,39 @@ fn calculate_max_account_len(tree: &BalanceTree, indent_amount: usize, indent_le max_length } +fn calculate_max_decimal_pos(tree: &BalanceTreeStr) -> usize { + let mut max_decimal_pos = 0; + if let Some(amounts) = &tree.amounts { + for (_, decimal_pos) in amounts { + max_decimal_pos = max(max_decimal_pos, *decimal_pos); + } + }; + for child in &tree.children { + let child_max = calculate_max_decimal_pos(child); + max_decimal_pos = max(max_decimal_pos, child_max); + } + max_decimal_pos +} + pub fn format_balance(ledger: &Ledger, account_balances: &HashMap>) -> String { let mut output = String::new(); let mut tree = construct_tree(ledger, account_balances); set_tree_totals(&mut tree); - let max_account_len = calculate_max_account_len(&tree, 2, 0); + + let str_tree = balance_tree_to_str_tree(tree, &ledger); + + let max_account_len = calculate_max_account_len(&str_tree, 4, 0); + let max_decimal_pos = calculate_max_decimal_pos(&str_tree); - println!("{}", max_account_len); - print_tree(&tree, &ledger, 0, max_account_len + 5); + + let text = tree_to_text(&str_tree, &ledger, max_account_len, max_decimal_pos); + + println!("{}", text_to_ansi(&text)); + + // println!("{}", max_account_len); + // print_tree(&tree, &ledger, 0, max_account_len + 5); // println!("{:?}", tree); // let base_account_info: Vec = account_balances @@ -275,3 +394,228 @@ pub fn format_balance(ledger: &Ledger, account_balances: &HashMap String { .join("") }) .collect::>() - .join("/n") + .join("\n") } diff --git a/src/parser/mod.rs b/src/parser/mod.rs new file mode 100644 index 0000000..beb28e7 --- /dev/null +++ b/src/parser/mod.rs @@ -0,0 +1,99 @@ +use chrono::NaiveDate; +use nom::{ + bytes::complete::{escaped, tag, take_while_m_n}, + character::complete::{char, none_of, one_of}, + combinator::{opt, recognize}, + error::{Error, ErrorKind}, + multi::{many0, many1}, + sequence::{delimited, terminated, tuple}, + AsChar, Err, IResult, Parser, +}; +use rust_decimal::Decimal; + +pub fn decimal(input: &str) -> IResult<&str, Decimal> { + let (new_input, decimal_str) = recognize(tuple(( + opt(one_of("+-")), + opt(number_int), + opt(char('.')), + opt(number_int), + ))) + .parse(input)?; + + if decimal_str.contains(',') { + match Decimal::from_str_exact(&decimal_str.replace(",", "")) { + Ok(decimal) => Ok((new_input, decimal)), + Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), + } + } else { + match Decimal::from_str_exact(decimal_str) { + Ok(decimal) => Ok((new_input, decimal)), + Err(_) => Err(Err::Error(Error::new(input, ErrorKind::Eof))), + } + } +} + +pub fn number_int(input: &str) -> IResult<&str, &str> { + recognize(many1(terminated(one_of("0123456789"), many0(one_of("_,")))))(input) +} + +pub fn parse_iso_date(input: &str) -> IResult<&str, NaiveDate> { + let (new_input, (year, _, month, _, day)) = tuple(( + date_year, + opt(tag("-")), + date_month, + opt(tag("-")), + date_day, + )) + .parse(input)?; + + match NaiveDate::from_ymd_opt(year, month, day) { + Some(date) => Ok((new_input, date)), + None => Err(nom::Err::Error(Error::new(input, ErrorKind::Eof))), + } +} + +pub fn quoted_string(input: &str) -> IResult<&str, &str> { + delimited( + tag("\""), + escaped(none_of("\\\""), '\\', tag("\"")), + tag("\""), + ) + .parse(input) +} + +/////////////// +// Private // +/////////////// + +fn take_n_digits(i: &str, n: usize) -> IResult<&str, u32> { + let (i, digits) = take_while_m_n(n, n, AsChar::is_dec_digit)(i)?; + + let res = digits.parse().expect("Invalid ASCII number"); + + Ok((i, res)) +} + +fn date_year(input: &str) -> IResult<&str, i32> { + take_n_digits(input, 4).map(|(str, year)| (str, i32::try_from(year).unwrap())) +} + +fn date_month(input: &str) -> IResult<&str, u32> { + take_n_digits(input, 2) +} + +fn date_day(input: &str) -> IResult<&str, u32> { + take_n_digits(input, 2) +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_string() { + assert_eq!(quoted_string("\"test\"").unwrap().1, "test"); + assert_eq!(quoted_string("\"te\\\"st\"").unwrap().1, "te\\\"st"); + } + +} \ No newline at end of file diff --git a/src/queries/balance.rs b/src/queries/balance.rs index b5295ca..9e544bd 100644 --- a/src/queries/balance.rs +++ b/src/queries/balance.rs @@ -1,14 +1,139 @@ -use std::collections::HashMap; +use std::{collections::HashMap, time::Instant}; use crate::core::{Amount, Ledger}; use chrono::NaiveDate; use rust_decimal_macros::dec; +use super::{ + base::{self, DataValue, Function}, + functions::ComparisonFunction, + transaction::{PostingData, PostingField, TransactionField}, +}; + pub enum Query { StartDate(NaiveDate), EndDate(NaiveDate), } +pub fn balance2( + ledger: &Ledger, + end_date: NaiveDate, + convert_to_unit: Option<&str>, +) -> HashMap> { + let q = ComparisonFunction::new( + "<=", + base::Query::from_field(PostingField::Transaction(TransactionField::Date)), + base::Query::from(DataValue::from(end_date)), + ) + .unwrap(); + + let convert_to_unit = convert_to_unit.map(|u| ledger.get_unit_by_symbol(u).unwrap()); + + let postings = ledger + .get_transactions() + .iter() + .map(|t| { + t.get_postings().iter().map(|p| PostingData { + ledger, + posting: p, + parent_transaction: t, + }) + }) + .flatten(); + + let filtered_postings = + postings.filter(|data| q.evaluate(data).map(|v| bool::from(v)).unwrap_or(false)); + + let mut accounts = HashMap::new(); + for posting_data in filtered_postings { + let posting = posting_data.posting; + let mut amount = *posting.get_amount(); + if let Some(new_unit) = convert_to_unit { + if amount.unit_id != new_unit.get_id() { + let price = ledger.get_price_on_date(end_date, amount.unit_id, new_unit.get_id()); + if let Some(price) = price { + amount = Amount { + value: amount.value * price, + unit_id: new_unit.get_id(), + }; + } + } + } + let account_vals = accounts + .entry(posting.get_account_id()) + .or_insert(HashMap::new()); + let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); + *a += amount.value; + } + + accounts + .iter() + .map(|(&k, v)| { + ( + k, + v.into_iter() + .map(|(&unit_id, &value)| Amount { value, unit_id }) + .collect(), + ) + }) + .collect() +} + +pub fn balance3(ledger: &Ledger, query: &base::Query) -> HashMap> { + // let q = ComparisonFunction::new( + // "<=", + // base::Query::from_field(PostingField::Transaction(TransactionField::Date)), + // base::Query::from(DataValue::from(end_date)), + // ) + // .unwrap(); + + let t0 = Instant::now(); + + let postings = ledger + .get_transactions() + .iter() + .map(|t| { + t.get_postings().iter().map(|p| PostingData { + ledger, + posting: p, + parent_transaction: t, + }) + }) + .flatten(); + + let t1 = Instant::now(); + + let filtered_postings = + postings.filter(|data| query.evaluate(data).map(|v| bool::from(v)).unwrap_or(false)); + + let t2 = Instant::now(); + + // println!("{:?} {:?}", t1-t0, t2-t1); + + let mut accounts = HashMap::new(); + for posting_data in filtered_postings { + let posting = posting_data.posting; + let amount = posting.get_amount(); + let account_vals = accounts + .entry(posting.get_account_id()) + .or_insert(HashMap::new()); + let a = account_vals.entry(amount.unit_id).or_insert(dec!(0)); + *a += amount.value; + } + + accounts + .iter() + .map(|(&k, v)| { + ( + k, + v.into_iter() + .map(|(&unit_id, &value)| Amount { value, unit_id }) + .collect(), + ) + }) + .collect() +} + pub fn balance(ledger: &Ledger, query: &[Query]) -> HashMap> { let relevant_transactions = ledger.get_transactions().iter().filter(|txn| { query.iter().all(|q| match q { diff --git a/src/queries/base.rs b/src/queries/base.rs new file mode 100644 index 0000000..809972a --- /dev/null +++ b/src/queries/base.rs @@ -0,0 +1,218 @@ +use std::collections::HashMap; + +use chrono::NaiveDate; +use rust_decimal::{prelude::Zero, Decimal}; + +use crate::core::{Amount, CoreError}; + +#[derive(Debug, Clone)] +pub enum StringData<'a> { + Owned(String), + Reference(&'a str) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DataValue<'a> { + Null, + Integer(u32), + Decimal(Decimal), + Boolean(bool), + String(StringData<'a>), + Date(NaiveDate), + Amount(Amount), + List(Vec>), + Map(HashMap<&'static str, DataValue<'a>>), +} + +pub enum Query<'a, T> { + Field(T), + Value(DataValue<'a>), + Function(Box>), +} + +pub trait Data { + fn get_field(&self, field: &T) -> Result; +} + +pub trait Function { + fn evaluate(&self, context: &dyn Data) -> Result; +} + +// impl ConstantValue { +// pub fn to_bool(&self) -> bool { +// match self { +// ConstantValue::Integer(val) => !val.is_zero(), +// ConstantValue::Decimal(val) => !val.is_zero(), +// ConstantValue::Boolean(val) => *val, +// ConstantValue::String(val) => val.is_empty(), +// ConstantValue::Date(_) => true, +// ConstantValue::Amount(val) => !val.value.is_zero(), +// ConstantValue::List(list) => !list.is_empty(), +// ConstantValue::Map(map) => !map.is_empty(), +// } +// } +// } + +impl<'a> StringData<'a> { + pub fn as_ref(&'a self) -> &'a str { + match self { + StringData::Owned(val) => val.as_str(), + StringData::Reference(val) => val, + } + } +} + +impl<'a> PartialEq for StringData<'a> { + fn eq(&self, other: &Self) -> bool { + let str_self = match self { + StringData::Owned(val) => val.as_str(), + StringData::Reference(val) => val, + }; + let str_other = match other { + StringData::Owned(val) => val.as_str(), + StringData::Reference(val) => val, + }; + str_self.eq(str_other) + } +} + +impl<'a> PartialOrd for StringData<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + let str_self = match self { + StringData::Owned(val) => val.as_str(), + StringData::Reference(val) => val, + }; + let str_other = match other { + StringData::Owned(val) => val.as_str(), + StringData::Reference(val) => val, + }; + str_self.partial_cmp(str_other) + } +} + +impl<'a> PartialOrd for DataValue<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (DataValue::Null, DataValue::Null) => Some(std::cmp::Ordering::Equal), + (DataValue::Integer(val1), DataValue::Integer(val2)) => val1.partial_cmp(val2), + (DataValue::Decimal(val1), DataValue::Decimal(val2)) => val1.partial_cmp(val2), + (DataValue::Boolean(val1), DataValue::Boolean(val2)) => val1.partial_cmp(val2), + (DataValue::String(val1), DataValue::String(val2)) => val1.partial_cmp(val2), + (DataValue::Date(val1), DataValue::Date(val2)) => val1.partial_cmp(val2), + (DataValue::Amount(val1), DataValue::Amount(val2)) => val1.partial_cmp(val2), + (DataValue::List(val1), DataValue::List(val2)) => val1.partial_cmp(val2), + _ => None, + } + } +} + +impl<'a> From<&'a str> for StringData<'a> { + fn from(value: &'a str) -> Self { + StringData::Reference(value) + } +} + +impl<'a> From for StringData<'a> { + fn from(value: String) -> Self { + StringData::Owned(value) + } +} + +impl<'a> From for DataValue<'a> { + fn from(value: u32) -> Self { + DataValue::Integer(value) + } +} + +impl<'a> From for DataValue<'a> { + fn from(value: Decimal) -> Self { + DataValue::Decimal(value) + } +} + +impl<'a> From for DataValue<'a> { + fn from(value: bool) -> Self { + DataValue::Boolean(value) + } +} + +impl<'a> From<&'a str> for DataValue<'a> { + fn from(value: &'a str) -> Self { + DataValue::String(value.into()) + } +} + + +impl<'a> From for DataValue<'a> { + fn from(value: String) -> Self { + DataValue::String(value.into()) + } +} + +impl<'a> From for DataValue<'a> { + fn from(value: NaiveDate) -> Self { + DataValue::Date(value) + } +} + +impl<'a> From for DataValue<'a> { + fn from(value: Amount) -> Self { + DataValue::Amount(value) + } +} + +impl<'a> From> for bool { + fn from(value: DataValue) -> Self { + match value { + DataValue::Null => false, + DataValue::Integer(val) => !val.is_zero(), + DataValue::Decimal(val) => !val.is_zero(), + DataValue::Boolean(val) => val, + DataValue::String(val) => val.as_ref().is_empty(), + DataValue::Date(_) => true, + DataValue::Amount(val) => !val.value.is_zero(), + DataValue::List(list) => !list.is_empty(), + DataValue::Map(map) => !map.is_empty(), + } + } +} + +impl<'a, T> Query<'a, T> { + pub fn evaluate(&self, context: &'a dyn Data) -> Result { + match self { + Query::Field(field) => context.get_field(field), + Query::Value(constant) => Ok(constant.clone()), + Query::Function(function) => function.evaluate(context), + } + } + + pub fn from_field(field: T) -> Self { + Query::Field(field) + } + + pub fn from_fn + Sized + 'static>(function: F) -> Self { + Query::Function(Box::new(function)) + } +} + +impl<'a, T> From> for Query<'a, T> { + fn from(constant: DataValue<'a>) -> Self { + Query::Value(constant) + } +} + +// impl Function for T { +// fn to_value(self) -> Value { +// Value::Function(Box::new(self)) +// } +// } + +// impl> T { + +// } + +// impl + Sized + 'static> From for Value { +// fn from(function: T) -> Self { +// Value::Function(Box::new(function)) +// } +// } diff --git a/src/queries/functions.rs b/src/queries/functions.rs new file mode 100644 index 0000000..138d1ba --- /dev/null +++ b/src/queries/functions.rs @@ -0,0 +1,294 @@ +use std::time::{Duration, Instant}; + +use crate::core::CoreError; +use regex::Regex; + +use super::base::{Data, DataValue, Function, Query, StringData}; + +#[derive(PartialEq, Debug)] +pub enum ComparisonOperator { + EQ, + NEQ, + GT, + LT, + GTE, + LTE, +} + +pub struct ComparisonFunction<'a, Field> { + op: ComparisonOperator, + left: Query<'a, Field>, + right: Query<'a, Field>, +} + +pub enum StringComparisonOperator<'a> { + Func(&'a (dyn Fn(&str) -> bool + 'a)), + Regex(Regex), +} + +pub struct StringComparisonFunction<'a, Field> { + op: StringComparisonOperator<'a>, + val: Query<'a, Field>, +} + +pub struct SubAccountFunction<'a, Field> { + account_name: StringData<'a>, + val: Query<'a, Field>, +} + +pub struct RegexFunction<'a, Field> { + left: Query<'a, Field>, + regex: Regex, +} + +#[derive(PartialEq, Debug)] +pub enum LogicalOperator { + AND, + OR, +} + +pub struct LogicalFunction<'a, Field> { + op: LogicalOperator, + left: Query<'a, Field>, + right: Query<'a, Field>, +} + +pub struct NotFunction<'a, Field> { + value: Query<'a, Field>, +} + +impl<'a, Field> ComparisonFunction<'a, Field> { + pub fn new( + op: &str, + left: Query<'a, Field>, + right: Query<'a, Field>, + ) -> Result { + let op = match op { + "==" => ComparisonOperator::EQ, + "!=" => ComparisonOperator::NEQ, + ">" => ComparisonOperator::GT, + "<" => ComparisonOperator::LT, + ">=" => ComparisonOperator::GTE, + "<=" => ComparisonOperator::LTE, + _ => return Err("Invalid Operator".into()), + }; + Ok(ComparisonFunction { op, left, right }) + } + + pub fn new_op( + op: ComparisonOperator, + left: Query<'a, Field>, + right: Query<'a, Field>, + ) -> Self { + ComparisonFunction { op, left, right } + } +} + +impl<'a, Field> SubAccountFunction<'a, Field> { + pub fn new(account: StringData<'a>, val: Query<'a, Field>) -> Self { + SubAccountFunction { account_name: account, val } + } +} + +impl<'a, Field> StringComparisonFunction<'a, Field> { + pub fn new_func( + val: Query<'a, Field>, + func: &'a (impl Fn(&str) -> bool + 'a), + ) -> Result { + Ok(StringComparisonFunction { val, op: StringComparisonOperator::Func(func) }) + } + + pub fn new_regex(val: Query<'a, Field>, regex: &str) -> Result { + let regex = Regex::new(regex).map_err(|_| CoreError::from("Unable to parse regex"))?; + Ok(StringComparisonFunction { val, op: StringComparisonOperator::Regex(regex) }) + } +} + +impl<'a, Field> RegexFunction<'a, Field> { + pub fn new(left: Query<'a, Field>, regex: &str) -> Result { + let regex = Regex::new(regex).map_err(|_| CoreError::from("Unable to parse regex"))?; + Ok(RegexFunction { left, regex }) + } +} + +impl<'a, Field> Function for ComparisonFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + let left = self.left.evaluate(context)?; + let right = self.right.evaluate(context)?; + + match self.op { + ComparisonOperator::EQ => Ok(DataValue::Boolean(left == right)), + ComparisonOperator::NEQ => Ok(DataValue::Boolean(left != right)), + ComparisonOperator::GT => Ok(DataValue::Boolean(left > right)), + ComparisonOperator::LT => Ok(DataValue::Boolean(left < right)), + ComparisonOperator::GTE => Ok(DataValue::Boolean(left >= right)), + ComparisonOperator::LTE => Ok(DataValue::Boolean(left <= right)), + } + } +} + +impl<'a, Field> Function for StringComparisonFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + let val = self.val.evaluate(context)?; + + if let DataValue::String(val) = val { + match &self.op { + StringComparisonOperator::Func(func) => Ok(DataValue::Boolean(func(val.as_ref()))), + StringComparisonOperator::Regex(regex) => { + Ok(DataValue::Boolean(regex.is_match(val.as_ref()))) + } + } + // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) + } else { + Err("Cannot use REGEX operation on non string types".into()) + } + + // let left = self.left.evaluate(context)?; + + // if let DataValue::String(left) = left { + // Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) + // } else { + // Err("Cannot use REGEX operation on non string types".into()) + // } + } +} + +impl<'a, Field> Function for SubAccountFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + let val = self.val.evaluate(context)?; + + if let DataValue::String(val) = val { + Ok(DataValue::Boolean( + val.as_ref() + .strip_prefix(self.account_name.as_ref()) + .map(|n| n.is_empty() || n.starts_with(":")) + .unwrap_or(false), + )) + } else { + Err("Cannot compare account name on non string types".into()) + } + } +} + +impl<'a, Field> Function for RegexFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + let left = self.left.evaluate(context)?; + + if let DataValue::String(left) = left { + Ok(DataValue::Boolean(self.regex.is_match(left.as_ref()))) + } else { + Err("Cannot use REGEX operation on non string types".into()) + } + } +} + +impl<'a, Field> LogicalFunction<'a, Field> { + pub fn new( + op: &str, + left: Query<'a, Field>, + right: Query<'a, Field>, + ) -> Result { + if op.eq_ignore_ascii_case("and") { + Ok(LogicalFunction { op: LogicalOperator::AND, left, right }) + } else if op.eq_ignore_ascii_case("or") { + Ok(LogicalFunction { op: LogicalOperator::OR, left, right }) + } else { + Err("Invalid logical operator".into()) + } + } +} + +impl<'a, Field> NotFunction<'a, Field> { + pub fn new(value: Query<'a, Field>) -> Self { + NotFunction { value } + } +} + +impl<'a, Field> Function for LogicalFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + // More verbose to try and avoid doing right side computation + let value: bool = match self.op { + LogicalOperator::AND => { + if !bool::from(self.left.evaluate(context)?) { + false + } else { + self.right.evaluate(context)?.into() + } + } + LogicalOperator::OR => { + if bool::from(self.left.evaluate(context)?) { + true + } else { + self.right.evaluate(context)?.into() + } + } + }; + + Ok(DataValue::Boolean(value)) + + // let left = self.left.evaluate(context)?.into(); + // let right = self.right.evaluate(context)?.into(); + + // match self.op { + // LogicalOperator::AND => Ok(DataValue::Boolean(left && right)), + // LogicalOperator::OR => Ok(DataValue::Boolean(left || right)), + // } + } +} + +impl<'a, Field> Function for NotFunction<'a, Field> { + fn evaluate(&self, context: &dyn Data) -> Result { + let value: bool = self.value.evaluate(context)?.into(); + + Ok(DataValue::Boolean(!value)) + } +} + +// Tests section +#[cfg(test)] +mod tests { + use super::*; + + enum TestField {} + + struct TestData {} + + impl Data for TestData { + fn get_field(&self, _: &TestField) -> Result { + Err("".into()) + } + } + + #[test] + fn comparison_function_evaluate() { + let value1: DataValue = 5.into(); + let value2: DataValue = 10.into(); + let context = TestData {}; + + let comparison_function = + ComparisonFunction::::new(">", value1.into(), value2.into()).unwrap(); + assert_eq!(comparison_function.op, ComparisonOperator::GT); + assert_eq!(comparison_function.evaluate(&context), Ok(false.into())); + + let value = Query::from_fn(comparison_function); + assert_eq!(value.evaluate(&context), Ok(false.into())); + } + + #[test] + fn logical_function_evaluate() { + let value1: DataValue = 1.into(); + let value2: DataValue = false.into(); + let context = TestData {}; + + let logical_function_and = + LogicalFunction::::new("and", value1.clone().into(), value2.clone().into()) + .unwrap(); + assert_eq!(logical_function_and.op, LogicalOperator::AND); + assert_eq!(logical_function_and.evaluate(&context), Ok(false.into())); + + let logical_function_or = + LogicalFunction::::new("or", value1.into(), value2.into()).unwrap(); + assert_eq!(logical_function_or.op, LogicalOperator::OR); + assert_eq!(logical_function_or.evaluate(&context), Ok(true.into())); + } +} diff --git a/src/queries/mod.rs b/src/queries/mod.rs index ae8db71..42069e5 100644 --- a/src/queries/mod.rs +++ b/src/queries/mod.rs @@ -1,3 +1,9 @@ mod balance; +mod postings; +pub mod base; +pub mod functions; +pub mod parser; +pub mod transaction; pub use balance::*; +pub use postings::*; diff --git a/src/queries/parser/functions.rs b/src/queries/parser/functions.rs new file mode 100644 index 0000000..51bf1e0 --- /dev/null +++ b/src/queries/parser/functions.rs @@ -0,0 +1,142 @@ +use nom::{ + branch::alt, bytes::complete::{tag, tag_no_case, take_until}, character::complete::space0, error::{Error, ErrorKind}, multi::fold_many0, sequence::{delimited, preceded, tuple}, AsChar, IResult, InputTakeAtPosition, Parser +}; + +use crate::{ + parser::{decimal, parse_iso_date, quoted_string}, + queries::{ + base::{DataValue, Query}, + functions::{ComparisonFunction, ComparisonOperator, LogicalFunction}, + }, +}; + +pub trait ParseField: Sized { + fn parse(input: &str) -> Option; +} + +// fn query<'a, Field: ParseField + 'static>( +// input: &'a str, +// ) -> IResult<&'a str, Query<'a, Field>> { +// // TODO: uncommented, is this right? +// delimited( +// space0, +// alt(( +// parenthesis, +// value.map(|v| v.into()), +// field.map(|v| Query::from_field(v)), +// comparison_function::.map(|v| Query::from_fn(v)), +// logical_function::.map(|v| Query::from_fn(v)), +// )), +// space0, +// ) +// .parse(input) +// } + +fn value<'a>(input: &'a str) -> IResult<&'a str, DataValue<'a>> { + alt(( + tag_no_case("null").map(|_| DataValue::Null), + tag_no_case("true").map(|_| DataValue::Boolean(true)), + tag_no_case("false").map(|_| DataValue::Boolean(false)), + parse_iso_date.map(|v| v.into()), + decimal.map(|v| v.into()), + quoted_string.map(|v| v.into()), + )) + .parse(input) +} + +fn field<'a, Field: ParseField>(input: &str) -> IResult<&str, Field> { + input + .split_at_position1_complete( + |item| !item.is_alphanum() || item != '.', + ErrorKind::AlphaNumeric, + ) + .and_then(|v| { + Field::parse(v.1) + .map(|f| (v.0, f)) + .ok_or(nom::Err::Error(Error::new(input, ErrorKind::Eof))) + }) +} + +// fn parenthesis<'a, Field: ParseField + 'static>( +// input: &'a str, +// ) -> IResult<&'a str, Query<'a, Field>> { +// delimited(tag("("), query, tag(")")).parse(input) +// } + +// fn query_factor<'a, Field: ParseField + 'static>(input: &'a str) -> IResult<&'a str, Query<'a, Field>> { +// delimited( +// space0, +// alt(( +// value.map(|v| v.into()), +// field.map(|v| Query::from_field(v)), +// parenthesis, +// )), +// space0, +// ) +// .parse(input) +// } + +// fn query_comparison<'a, Field: ParseField + 'static>( +// input: &'a str, +// ) -> IResult<&'a str, Query<'a, Field>> { +// let (input, initial) = query_factor(input)?; + +// let op = alt(( +// tag("=").map(|_| ComparisonOperator::EQ), +// tag("!=").map(|_| ComparisonOperator::NEQ), +// tag(">").map(|_| ComparisonOperator::GT), +// tag("<").map(|_| ComparisonOperator::LT), +// tag(">=").map(|_| ComparisonOperator::GTE), +// tag("<=").map(|_| ComparisonOperator::LTE), +// )); + +// loop { + +// } + + +// fold_many0(tuple((op, query_factor)), move || initial, |acc, (op, val)| { +// Query::from_fn(ComparisonFunction::new_op(op, acc, val)) +// }).parse(input) +// } + +// fn query_logical_and<'a, Field: ParseField>( +// input: &str, +// ) -> IResult<&'a str, Query<'a, Field>> { +// let (input_next, lhs) = query_factor(input)?; + +// fold_many0(preceded(tag_no_case("and"), query_factor), init, g) + +// // loop { +// // let rhs_result = tuple(( +// // alt((tag_no_case("and"), tag_no_case("or"))), +// // query_factor, +// // )).parse(input_next); +// // if let Ok(rhs_result) = rhs_result { +// // input_next = rhs_result.0; +// // } else { +// // break; +// // } +// // }; + + +// // tuple(take_until(alt((tag(""))))) +// } + +// fn logical_function<'a, Field: ParseField>( +// input: &str, +// ) -> IResult<&str, LogicalFunction<'a, Field>> { +// let lhs = + +// } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_string() { + assert_eq!(quoted_string("\"test\"").unwrap().1, "test"); + assert_eq!(quoted_string("\"te\\\"st\"").unwrap().1, "te\\\"st"); + } +} diff --git a/src/queries/parser/mod.rs b/src/queries/parser/mod.rs new file mode 100644 index 0000000..87ccb64 --- /dev/null +++ b/src/queries/parser/mod.rs @@ -0,0 +1 @@ +mod functions; \ No newline at end of file diff --git a/src/queries/postings.rs b/src/queries/postings.rs new file mode 100644 index 0000000..ef14537 --- /dev/null +++ b/src/queries/postings.rs @@ -0,0 +1,11 @@ +use crate::core::Ledger; + + +pub struct PostingQuery { + +} + + +pub fn query_postings(ledger: &Ledger, query: PostingQuery) { + +} \ No newline at end of file diff --git a/src/queries/transaction.rs b/src/queries/transaction.rs new file mode 100644 index 0000000..6225001 --- /dev/null +++ b/src/queries/transaction.rs @@ -0,0 +1,78 @@ +use crate::core::{CoreError, Ledger, Posting, Transaction}; + +use super::base::{Data, DataValue}; + +pub enum AccountField { + Name, + OpenDate, + CloseDate, +} + +pub enum PostingField { + Transaction(TransactionField), + Account(AccountField), + Amount, + Cost, + Price, +} + +pub enum TransactionField { + Date, + Flag, + Payee, + Narration, +} + +pub struct PostingData<'a> { + pub posting: &'a Posting, + pub parent_transaction: &'a Transaction, + pub ledger: &'a Ledger, +} + +impl<'a> Data for PostingData<'a> { + fn get_field(&self, field: &PostingField) -> Result { + match field { + PostingField::Transaction(transaction_field) => get_transaction_value(transaction_field, &self.parent_transaction), + PostingField::Account(account_field) => { + let account = self + .ledger + .get_account(self.posting.get_account_id()) + .ok_or_else(|| CoreError::from("Unable to find account"))?; + match account_field { + AccountField::Name => Ok(account.get_name().as_str().into()), + AccountField::OpenDate => Ok(account + .get_open_date() + .map(|v| v.into()) + .unwrap_or(DataValue::Null)), + AccountField::CloseDate => Ok(account + .get_close_date() + .map(|v| v.into()) + .unwrap_or(DataValue::Null)), + } + } + PostingField::Amount => todo!(), + PostingField::Cost => todo!(), + PostingField::Price => todo!(), + } + } +} + +fn get_transaction_value<'a>( + field: &TransactionField, + transaction: &Transaction, +) -> Result, CoreError> { + match field { + TransactionField::Date => Ok(transaction.get_date().into()), + TransactionField::Flag => Ok(char::from(transaction.get_flag()).to_string().into()), + TransactionField::Payee => Ok(transaction + .get_payee() + .clone() + .map(|v| v.into()) + .unwrap_or(DataValue::Null)), + TransactionField::Narration => Ok(transaction + .get_narration() + .clone() + .map(|v| v.into()) + .unwrap_or(DataValue::Null)), + } +}