feat: add receipt OCR pipeline (refs #69)
All checks were successful
Deploy to Staging / Build Images (pull_request) Successful in 32s
Deploy to Staging / Deploy to Staging (pull_request) Successful in 31s
Deploy to Staging / Verify Staging (pull_request) Successful in 2m20s
Deploy to Staging / Notify Staging Ready (pull_request) Successful in 8s
Deploy to Staging / Notify Staging Failure (pull_request) Has been skipped
All checks were successful
Deploy to Staging / Build Images (pull_request) Successful in 32s
Deploy to Staging / Deploy to Staging (pull_request) Successful in 31s
Deploy to Staging / Verify Staging (pull_request) Successful in 2m20s
Deploy to Staging / Notify Staging Ready (pull_request) Successful in 8s
Deploy to Staging / Notify Staging Failure (pull_request) Has been skipped
Implement receipt-specific OCR extraction for fuel receipts:

- Pattern matching modules for date, currency, and fuel data extraction
- Receipt-optimized image preprocessing for thermal receipts
- POST /extract/receipt endpoint with field extraction
- Confidence scoring per extracted field
- Cross-validation of fuel receipt data
- Unit tests for all pattern matchers

Extracted fields: merchantName, transactionDate, totalAmount, fuelQuantity, pricePerUnit, fuelGrade

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
198
ocr/tests/test_currency_patterns.py
Normal file
198
ocr/tests/test_currency_patterns.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Tests for currency pattern matching."""
|
||||
import pytest
|
||||
|
||||
from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher
|
||||
|
||||
|
||||
class TestCurrencyPatternMatcher:
    """Checks for total and amount extraction through ``currency_matcher``."""

    def test_total_explicit(self) -> None:
        """'TOTAL $XX.XX' yields the value, the label and confidence above 0.9."""
        total = currency_matcher.extract_total("TOTAL $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.confidence > 0.9
        assert total.label == "TOTAL"

    def test_total_with_colon(self) -> None:
        """A colon after the label ('TOTAL: $XX.XX') is accepted."""
        total = currency_matcher.extract_total("TOTAL: $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_total_without_dollar_sign(self) -> None:
        """The dollar sign is optional ('TOTAL 45.67')."""
        total = currency_matcher.extract_total("TOTAL 45.67")

        assert total is not None
        assert total.value == 45.67

    def test_amount_due(self) -> None:
        """'AMOUNT DUE' is recognized and reported as the label."""
        total = currency_matcher.extract_total("AMOUNT DUE: $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.label == "AMOUNT DUE"

    def test_sale_pattern(self) -> None:
        """'SALE $XX.XX' is treated as a total."""
        total = currency_matcher.extract_total("SALE $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_grand_total(self) -> None:
        """'GRAND TOTAL' is recognized and reported as the label."""
        total = currency_matcher.extract_total("GRAND TOTAL $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.label == "GRAND TOTAL"

    def test_total_sale(self) -> None:
        """'TOTAL SALE: $XX.XX' is treated as a total."""
        total = currency_matcher.extract_total("TOTAL SALE: $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_balance_due(self) -> None:
        """'BALANCE DUE $XX.XX' is treated as a total."""
        total = currency_matcher.extract_total("BALANCE DUE $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_multiple_amounts_picks_total(self) -> None:
        """With several amounts present, the explicit TOTAL line wins."""
        receipt_text = """
        REGULAR 87
        10.500 GAL @ $3.67
        SUBTOTAL $38.54
        TAX $0.00
        TOTAL $38.54
        """
        total = currency_matcher.extract_total(receipt_text)

        assert total is not None
        assert total.value == 38.54
        assert total.pattern_name == "total_explicit"

    def test_all_amounts(self) -> None:
        """``extract_all_amounts`` returns at least the TOTAL amount."""
        receipt_text = """
        SUBTOTAL $35.00
        TAX $3.54
        TOTAL $38.54
        """
        amounts = currency_matcher.extract_all_amounts(receipt_text)

        # Only the TOTAL is required; other amounts may or may not be reported.
        assert len(amounts) >= 1
        assert any(amount.value == 38.54 for amount in amounts)

    def test_comma_thousand_separator(self) -> None:
        """A comma thousands separator is stripped before parsing."""
        total = currency_matcher.extract_total("TOTAL $1,234.56")

        assert total is not None
        assert total.value == 1234.56

    def test_reasonable_total_range(self) -> None:
        """A five-cent total is rejected while an ordinary one is kept."""
        # Very small amount: too small for a fuel receipt.
        tiny_total = currency_matcher.extract_total("TOTAL $0.05")
        assert tiny_total is None

        # Reasonable amount.
        normal_total = currency_matcher.extract_total("TOTAL $45.67")
        assert normal_total is not None

    def test_receipt_context_extraction(self) -> None:
        """The total is found inside a realistic multi-line receipt."""
        receipt_text = """
        SHELL
        123 MAIN ST
        DATE: 01/15/2024

        UNLEADED 87
        10.500 GAL
        @ $3.679/GAL

        FUEL TOTAL $38.63
        TAX $0.00
        TOTAL $38.63

        DEBIT CARD
        ************1234
        """
        total = currency_matcher.extract_total(receipt_text)

        assert total is not None
        assert total.value == 38.63

    def test_no_total_returns_largest(self) -> None:
        """Without a labeled total, the largest amount is returned at lower confidence."""
        receipt_text = """
        $10.50
        $5.00
        $45.67
        """
        total = currency_matcher.extract_total(receipt_text)

        # The total is inferred, so confidence is expected to stay below 0.7.
        assert total is not None
        assert total.value == 45.67
        assert total.confidence < 0.7

    def test_no_amounts_returns_none(self) -> None:
        """Text with no monetary amounts produces ``None``."""
        total = currency_matcher.extract_total("SHELL STATION\nPUMP 5")

        assert total is None
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Edge cases for currency parsing."""

    def test_european_format(self) -> None:
        """A comma decimal separator ('45,67') is read as 45.67."""
        total = currency_matcher.extract_total("TOTAL 45,67")

        assert total is not None
        assert total.value == 45.67

    def test_spaces_in_amount(self) -> None:
        """A space between the dollar sign and the digits is tolerated."""
        total = currency_matcher.extract_total("TOTAL $ 45.67")

        assert total is not None
        assert total.value == 45.67

    def test_case_insensitive(self) -> None:
        """The TOTAL label matches in upper, title and lower case."""
        for label in ("TOTAL", "Total", "total"):
            total = currency_matcher.extract_total(f"{label} $45.67")

            assert total is not None, f"Failed for {label}"
            assert total.value == 45.67
|
||||
163
ocr/tests/test_date_patterns.py
Normal file
163
ocr/tests/test_date_patterns.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for date pattern matching."""
|
||||
import pytest
|
||||
|
||||
from app.patterns.date_patterns import DatePatternMatcher, date_matcher
|
||||
|
||||
|
||||
class TestDatePatternMatcher:
    """Checks for date extraction through ``date_matcher``."""

    def test_mm_dd_yyyy_slash(self) -> None:
        """MM/DD/YYYY after a DATE label is normalized to ISO with confidence above 0.9."""
        found = date_matcher.extract_best_date("DATE: 01/15/2024")

        assert found is not None
        assert found.value == "2024-01-15"
        assert found.confidence > 0.9

    def test_mm_dd_yy_slash(self) -> None:
        """MM/DD/YY with a two-digit year is expanded to a four-digit year."""
        found = date_matcher.extract_best_date("01/15/24")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_mm_dd_yyyy_dash(self) -> None:
        """MM-DD-YYYY with dashes is accepted."""
        found = date_matcher.extract_best_date("01-15-2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_iso_format(self) -> None:
        """ISO YYYY-MM-DD is returned unchanged with confidence above 0.95."""
        found = date_matcher.extract_best_date("2024-01-15")

        assert found is not None
        assert found.value == "2024-01-15"
        assert found.confidence > 0.95

    def test_month_name_format(self) -> None:
        """'Jan 15, 2024' is parsed."""
        found = date_matcher.extract_best_date("Jan 15, 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_month_name_no_comma(self) -> None:
        """'Jan 15 2024' (no comma) is parsed."""
        found = date_matcher.extract_best_date("Jan 15 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_day_month_year_format(self) -> None:
        """'15 Jan 2024' (day first) is parsed."""
        found = date_matcher.extract_best_date("15 Jan 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_full_month_name(self) -> None:
        """A full month name ('January 15, 2024') is parsed."""
        found = date_matcher.extract_best_date("January 15, 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_multiple_dates_returns_best(self) -> None:
        """``extract_dates`` reports both dates in the text, each with confidence above 0.5."""
        candidates = date_matcher.extract_dates(
            "Date: 01/15/2024\nExpires: 01/15/2025"
        )

        assert len(candidates) == 2
        # Both should be valid.
        assert all(candidate.confidence > 0.5 for candidate in candidates)

    def test_invalid_date_rejected(self) -> None:
        """An impossible month/day combination yields ``None``."""
        found = date_matcher.extract_best_date("13/45/2024")  # Invalid month/day

        assert found is None

    def test_receipt_context_text(self) -> None:
        """The date is found inside a realistic multi-line receipt."""
        receipt_text = """
        SHELL STATION
        123 MAIN ST
        DATE: 01/15/2024
        TIME: 14:32
        PUMP #5
        REGULAR 87
        10.500 GAL
        TOTAL $38.50
        """
        found = date_matcher.extract_best_date(receipt_text)

        assert found is not None
        assert found.value == "2024-01-15"

    def test_no_date_returns_none(self) -> None:
        """Text with no date produces ``None``."""
        found = date_matcher.extract_best_date("SHELL STATION\nTOTAL $38.50")

        assert found is None

    def test_confidence_boost_near_keyword(self) -> None:
        """A date next to the DATE keyword scores at least as high as a bare date."""
        labeled = date_matcher.extract_best_date("DATE: 01/15/2024")
        bare = date_matcher.extract_best_date("01/15/2024")

        assert labeled is not None
        assert bare is not None
        # Keyword proximity should boost (or at least not lower) confidence.
        assert labeled.confidence >= bare.confidence
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Edge cases for date parsing."""

    def test_year_2000(self) -> None:
        """A two-digit year of 00 is read as 2000."""
        found = date_matcher.extract_best_date("01/15/00")

        assert found is not None
        assert found.value == "2000-01-15"

    def test_leap_year_date(self) -> None:
        """Feb 29 is accepted in a leap year."""
        found = date_matcher.extract_best_date("02/29/2024")  # 2024 is a leap year

        assert found is not None
        assert found.value == "2024-02-29"

    def test_leap_year_invalid(self) -> None:
        """Feb 29 is rejected in a non-leap year."""
        found = date_matcher.extract_best_date("02/29/2023")  # 2023 is not a leap year

        assert found is None

    def test_september_abbrev(self) -> None:
        """'Sep', 'Sept' and 'September' all resolve to month 09."""
        for abbrev in ("Sep", "Sept", "September"):
            found = date_matcher.extract_best_date(f"{abbrev} 15, 2024")

            assert found is not None, f"Failed for {abbrev}"
            assert found.value == "2024-09-15"
|
||||
327
ocr/tests/test_fuel_patterns.py
Normal file
327
ocr/tests/test_fuel_patterns.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Tests for fuel-specific pattern matching."""
|
||||
import pytest
|
||||
|
||||
from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher
|
||||
|
||||
|
||||
class TestFuelQuantityExtraction:
    """Checks for gallon/liter quantity extraction through ``fuel_matcher``."""

    def test_gallons_suffix(self) -> None:
        """'XX.XXX GAL' yields the value, a GAL unit and confidence above 0.9."""
        quantity = fuel_matcher.extract_gallons("10.500 GAL")

        assert quantity is not None
        assert quantity.value == 10.5
        assert quantity.unit == "GAL"
        assert quantity.confidence > 0.9

    def test_gallons_full_word(self) -> None:
        """'XX.XXX GALLONS' (unit spelled out) is parsed."""
        quantity = fuel_matcher.extract_gallons("10.500 GALLONS")

        assert quantity is not None
        assert quantity.value == 10.5

    def test_gallons_prefix(self) -> None:
        """'GALLONS: XX.XXX' (label before the number) is parsed."""
        quantity = fuel_matcher.extract_gallons("GALLONS: 10.500")

        assert quantity is not None
        assert quantity.value == 10.5

    def test_gal_prefix(self) -> None:
        """'GAL: XX.XXX' (short label before the number) is parsed."""
        quantity = fuel_matcher.extract_gallons("GAL: 10.500")

        assert quantity is not None
        assert quantity.value == 10.5

    def test_volume_label(self) -> None:
        """'VOLUME: XX.XXX' is parsed as gallons."""
        quantity = fuel_matcher.extract_gallons("VOLUME: 10.500")

        assert quantity is not None
        assert quantity.value == 10.5

    def test_liters_suffix(self) -> None:
        """'XX.X L' yields the value and an L unit."""
        quantity = fuel_matcher.extract_liters("40.5 L")

        assert quantity is not None
        assert quantity.value == 40.5
        assert quantity.unit == "L"

    def test_liters_full_word(self) -> None:
        """'XX.X LITERS' (unit spelled out) is parsed."""
        quantity = fuel_matcher.extract_liters("40.5 LITERS")

        assert quantity is not None
        assert quantity.value == 40.5

    def test_quantity_prefers_gallons(self) -> None:
        """``extract_quantity`` reports a GAL unit for a gallons reading."""
        quantity = fuel_matcher.extract_quantity("10.500 GAL")

        assert quantity is not None
        assert quantity.unit == "GAL"

    def test_reasonable_quantity_filter(self) -> None:
        """Quantities of 0.001 and 100.0 gallons are both filtered out."""
        # Too small
        too_small = fuel_matcher.extract_gallons("0.001 GAL")
        assert too_small is None

        # Too large
        too_large = fuel_matcher.extract_gallons("100.0 GAL")
        assert too_large is None
|
||||
|
||||
|
||||
class TestFuelPriceExtraction:
    """Checks for price-per-unit extraction through ``fuel_matcher``."""

    def test_price_per_gal_dollar_sign(self) -> None:
        """'$X.XXX/GAL' yields the value, a GAL unit and confidence above 0.95."""
        unit_price = fuel_matcher.extract_price_per_unit("$3.679/GAL")

        assert unit_price is not None
        assert unit_price.value == 3.679
        assert unit_price.unit == "GAL"
        assert unit_price.confidence > 0.95

    def test_price_per_gal_no_dollar(self) -> None:
        """'X.XXX/GAL' without a dollar sign is parsed."""
        unit_price = fuel_matcher.extract_price_per_unit("3.679/GAL")

        assert unit_price is not None
        assert unit_price.value == 3.679

    def test_labeled_price_gal(self) -> None:
        """'PRICE/GAL $X.XXX' is parsed."""
        unit_price = fuel_matcher.extract_price_per_unit("PRICE/GAL $3.679")

        assert unit_price is not None
        assert unit_price.value == 3.679

    def test_unit_price(self) -> None:
        """'UNIT PRICE: $X.XXX' is parsed."""
        unit_price = fuel_matcher.extract_price_per_unit("UNIT PRICE: $3.679")

        assert unit_price is not None
        assert unit_price.value == 3.679

    def test_at_price(self) -> None:
        """The price following '@' on a quantity line is parsed."""
        unit_price = fuel_matcher.extract_price_per_unit("10.500 GAL @ $3.679")

        assert unit_price is not None
        assert unit_price.value == 3.679

    def test_ppg_pattern(self) -> None:
        """'PPG: $X.XXX' is parsed."""
        unit_price = fuel_matcher.extract_price_per_unit("PPG: $3.679")

        assert unit_price is not None
        assert unit_price.value == 3.679

    def test_reasonable_price_filter(self) -> None:
        """Prices of $0.50 and $15.00 per gallon are both filtered out."""
        # Too low
        too_low = fuel_matcher.extract_price_per_unit("$0.50/GAL")
        assert too_low is None

        # Too high
        too_high = fuel_matcher.extract_price_per_unit("$15.00/GAL")
        assert too_high is None
|
||||
|
||||
|
||||
class TestFuelGradeExtraction:
    """Checks for fuel grade / octane extraction through ``fuel_matcher``."""

    def test_regular_87(self) -> None:
        """'REGULAR 87' gives grade 87 with 'Regular' in the display name."""
        grade = fuel_matcher.extract_grade("REGULAR 87")

        assert grade is not None
        assert grade.value == "87"
        assert "Regular" in grade.display_name

    def test_reg_87(self) -> None:
        """'REG 87' gives grade 87."""
        grade = fuel_matcher.extract_grade("REG 87")

        assert grade is not None
        assert grade.value == "87"

    def test_unleaded_87(self) -> None:
        """'UNLEADED 87' gives grade 87."""
        grade = fuel_matcher.extract_grade("UNLEADED 87")

        assert grade is not None
        assert grade.value == "87"

    def test_plus_89(self) -> None:
        """'PLUS 89' gives grade 89 with 'Plus' in the display name."""
        grade = fuel_matcher.extract_grade("PLUS 89")

        assert grade is not None
        assert grade.value == "89"
        assert "Plus" in grade.display_name

    def test_midgrade_89(self) -> None:
        """'MIDGRADE 89' gives grade 89."""
        grade = fuel_matcher.extract_grade("MIDGRADE 89")

        assert grade is not None
        assert grade.value == "89"

    def test_premium_93(self) -> None:
        """'PREMIUM 93' gives grade 93 with 'Premium' in the display name."""
        grade = fuel_matcher.extract_grade("PREMIUM 93")

        assert grade is not None
        assert grade.value == "93"
        assert "Premium" in grade.display_name

    def test_super_93(self) -> None:
        """'SUPER 93' gives grade 93."""
        grade = fuel_matcher.extract_grade("SUPER 93")

        assert grade is not None
        assert grade.value == "93"

    def test_diesel(self) -> None:
        """'DIESEL #2' gives grade DIESEL with 'Diesel' in the display name."""
        grade = fuel_matcher.extract_grade("DIESEL #2")

        assert grade is not None
        assert grade.value == "DIESEL"
        assert "Diesel" in grade.display_name

    def test_e85(self) -> None:
        """'E85' gives grade E85."""
        grade = fuel_matcher.extract_grade("E85")

        assert grade is not None
        assert grade.value == "E85"

    def test_octane_only(self) -> None:
        """'87 OCTANE' (number without a grade name) gives grade 87."""
        grade = fuel_matcher.extract_grade("87 OCTANE")

        assert grade is not None
        assert grade.value == "87"
|
||||
|
||||
|
||||
class TestMerchantExtraction:
    """Checks for gas station name extraction through ``fuel_matcher``."""

    def test_shell_station(self) -> None:
        """A Shell header is detected with confidence above 0.8."""
        detected = fuel_matcher.extract_merchant_name("SHELL\n123 MAIN ST")

        assert detected is not None
        name, score = detected
        assert "SHELL" in name.upper()
        assert score > 0.8

    def test_chevron_station(self) -> None:
        """A Chevron header followed by a store number is detected."""
        detected = fuel_matcher.extract_merchant_name("CHEVRON #12345\n456 OAK AVE")

        assert detected is not None
        name, _score = detected
        assert "CHEVRON" in name.upper()

    def test_costco_gas(self) -> None:
        """A 'COSTCO GASOLINE' header is detected."""
        detected = fuel_matcher.extract_merchant_name(
            "COSTCO GASOLINE\n789 WAREHOUSE BLVD"
        )

        assert detected is not None
        name, _score = detected
        assert "COSTCO" in name.upper()

    def test_unknown_station_fallback(self) -> None:
        """An unknown station falls back to the first line at confidence below 0.7."""
        detected = fuel_matcher.extract_merchant_name("JOE'S GAS\n123 MAIN ST")

        assert detected is not None
        name, score = detected
        assert "JOE'S GAS" in name
        assert score < 0.7  # Lower confidence for unknown
|
||||
|
||||
|
||||
class TestReceiptContextExtraction:
    """Runs the fuel extractors against realistic receipt text."""

    def test_full_receipt_extraction(self) -> None:
        """Quantity, unit price, grade and merchant are all found in one receipt."""
        receipt_text = """
        SHELL
        123 MAIN STREET
        ANYTOWN, USA 12345

        DATE: 01/15/2024
        TIME: 14:32
        PUMP #5

        REGULAR 87
        10.500 GAL @ $3.679/GAL

        FUEL TOTAL $38.63
        TAX $0.00
        TOTAL $38.63

        DEBIT CARD
        ************1234
        APPROVED
        """

        # Each extractor is run independently on the same text.
        gallons = fuel_matcher.extract_quantity(receipt_text)
        assert gallons is not None
        assert gallons.value == 10.5

        unit_price = fuel_matcher.extract_price_per_unit(receipt_text)
        assert unit_price is not None
        assert unit_price.value == 3.679

        octane = fuel_matcher.extract_grade(receipt_text)
        assert octane is not None
        assert octane.value == "87"

        station = fuel_matcher.extract_merchant_name(receipt_text)
        assert station is not None
        assert "SHELL" in station[0].upper()
|
||||
339
ocr/tests/test_receipt_extraction.py
Normal file
339
ocr/tests/test_receipt_extraction.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Tests for receipt extraction pipeline."""
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.extractors.receipt_extractor import (
|
||||
ReceiptExtractor,
|
||||
ReceiptExtractionResult,
|
||||
receipt_extractor,
|
||||
)
|
||||
from app.extractors.fuel_receipt import (
|
||||
FuelReceiptExtractor,
|
||||
FuelReceiptValidation,
|
||||
fuel_receipt_extractor,
|
||||
)
|
||||
|
||||
|
||||
class TestReceiptExtractor:
    """Checks for the ``ReceiptExtractor`` pipeline and its private helpers."""

    def test_detect_receipt_type_fuel(self) -> None:
        """Text with a station, grade and gallons is classified as 'fuel'."""
        ocr_text = """
        SHELL STATION
        REGULAR 87
        10.500 GAL
        TOTAL $38.50
        """
        detected = ReceiptExtractor()._detect_receipt_type(ocr_text)

        assert detected == "fuel"

    def test_detect_receipt_type_unknown(self) -> None:
        """A grocery receipt is classified as 'unknown'."""
        ocr_text = """
        WALMART
        GROCERIES
        MILK $3.99
        BREAD $2.50
        TOTAL $6.49
        """
        detected = ReceiptExtractor()._detect_receipt_type(ocr_text)

        assert detected == "unknown"

    def test_extract_fuel_fields(self) -> None:
        """All six fuel fields are extracted, with the expected key values."""
        ocr_text = """
        SHELL
        123 MAIN ST
        DATE: 01/15/2024
        REGULAR 87
        10.500 GAL @ $3.679
        TOTAL $38.63
        """
        extracted = ReceiptExtractor()._extract_fuel_fields(ocr_text)

        assert "merchantName" in extracted
        assert "transactionDate" in extracted
        assert "totalAmount" in extracted
        assert "fuelQuantity" in extracted
        assert "pricePerUnit" in extracted
        assert "fuelGrade" in extracted

        assert extracted["totalAmount"].value == 38.63
        assert extracted["fuelQuantity"].value == 10.5
        assert extracted["fuelGrade"].value == "87"

    def test_extract_generic_fields(self) -> None:
        """Date and total are extracted from a non-fuel receipt."""
        ocr_text = """
        WALMART
        01/15/2024
        TOTAL $25.99
        """
        extracted = ReceiptExtractor()._extract_generic_fields(ocr_text)

        assert "transactionDate" in extracted
        assert "totalAmount" in extracted
        assert extracted["totalAmount"].value == 25.99

    def test_calculated_price_per_unit(self) -> None:
        """With no stated unit price, one is derived from total / quantity."""
        ocr_text = """
        SHELL
        DATE: 01/15/2024
        10.000 GAL
        TOTAL $35.00
        """
        extracted = ReceiptExtractor()._extract_fuel_fields(ocr_text)

        assert "pricePerUnit" in extracted
        derived_price = extracted["pricePerUnit"]
        # 35.00 / 10.000 = 3.50
        assert abs(derived_price.value - 3.50) < 0.01
        # Calculated values should have lower confidence
        assert derived_price.confidence < 0.9

    def test_validate_valid_data(self) -> None:
        """A dict carrying a total and a date validates as ``True``."""
        payload = {"totalAmount": 38.63, "transactionDate": "2024-01-15"}

        assert ReceiptExtractor().validate(payload) is True

    def test_validate_invalid_data(self) -> None:
        """An empty dict and a non-dict value both validate as ``False``."""
        extractor = ReceiptExtractor()

        # Empty dict
        assert extractor.validate({}) is False
        # Not a dict
        assert extractor.validate("invalid") is False

    def test_unsupported_file_type(self) -> None:
        """A PDF MIME type makes ``extract`` fail with an 'Unsupported file type' error."""
        extractor = ReceiptExtractor()

        # Force the MIME sniffing step to report a PDF.
        with patch.object(
            extractor, "_detect_mime_type", return_value="application/pdf"
        ):
            outcome = extractor.extract(b"fake pdf content")

        assert outcome.success is False
        assert "Unsupported file type" in outcome.error
|
||||
|
||||
|
||||
class TestFuelReceiptExtractor:
|
||||
"""Test fuel receipt specialized extractor."""
|
||||
|
||||
def test_validation_success(self) -> None:
|
||||
"""Test validation passes for consistent data."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=38.63, confidence=0.95),
|
||||
"fuelQuantity": ExtractedField(value=10.5, confidence=0.90),
|
||||
"pricePerUnit": ExtractedField(value=3.679, confidence=0.92),
|
||||
"fuelGrade": ExtractedField(value="87", confidence=0.88),
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert validation.is_valid is True
|
||||
assert len(validation.issues) == 0
|
||||
assert validation.confidence_score == 1.0
|
||||
|
||||
def test_validation_math_mismatch(self) -> None:
|
||||
"""Test validation catches total != quantity * price."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=50.00, confidence=0.95),
|
||||
"fuelQuantity": ExtractedField(value=10.0, confidence=0.90),
|
||||
"pricePerUnit": ExtractedField(value=3.00, confidence=0.92), # Should be $30
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert validation.is_valid is False
|
||||
assert any("doesn't match" in issue for issue in validation.issues)
|
||||
|
||||
def test_validation_quantity_too_small(self) -> None:
|
||||
"""Test validation catches too-small quantity."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=1.00, confidence=0.95),
|
||||
"fuelQuantity": ExtractedField(value=0.1, confidence=0.90),
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert any("too small" in issue for issue in validation.issues)
|
||||
|
||||
def test_validation_quantity_too_large(self) -> None:
|
||||
"""Test validation warns on very large quantity."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=150.00, confidence=0.95),
|
||||
"fuelQuantity": ExtractedField(value=45.0, confidence=0.90),
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert any("unusually large" in issue for issue in validation.issues)
|
||||
|
||||
def test_validation_price_too_low(self) -> None:
|
||||
"""Test validation catches too-low price."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=10.00, confidence=0.95),
|
||||
"pricePerUnit": ExtractedField(value=1.00, confidence=0.90),
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert any("too low" in issue for issue in validation.issues)
|
||||
|
||||
def test_validation_unknown_grade(self) -> None:
|
||||
"""Test validation catches unknown fuel grade."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=38.00, confidence=0.95),
|
||||
"fuelGrade": ExtractedField(value="95", confidence=0.70), # Not valid US grade
|
||||
}
|
||||
|
||||
validation = extractor._validate_fuel_receipt(fields)
|
||||
|
||||
assert any("Unknown fuel grade" in issue for issue in validation.issues)
|
||||
|
||||
def test_confidence_adjustment_boost(self) -> None:
|
||||
"""Test confidence boost when validation passes."""
|
||||
from app.extractors.receipt_extractor import ExtractedField
|
||||
|
||||
extractor = FuelReceiptExtractor()
|
||||
fields = {
|
||||
"totalAmount": ExtractedField(value=38.63, confidence=0.80),
|
||||
}
|
||||
validation = FuelReceiptValidation(
|
||||
is_valid=True, issues=[], confidence_score=1.0
|
||||
)
|
||||
|
||||
adjusted = extractor._adjust_confidences(fields, validation)
|
||||
|
||||
# Should be boosted by 1.1
|
||||
assert adjusted["totalAmount"].confidence == min(1.0, 0.80 * 1.1)
|
||||
|
||||
def test_confidence_adjustment_reduce(self) -> None:
    """Test confidence reduction when validation fails.

    With a failing validation (``is_valid=False``, one issue, score 0.7) the
    field confidence of 0.90 is expected to be multiplied by 0.7, which
    matches the validation's ``confidence_score`` in this test.
    """
    from app.extractors.receipt_extractor import ExtractedField

    extractor = FuelReceiptExtractor()
    fields = {
        "totalAmount": ExtractedField(value=50.00, confidence=0.90),
    }
    validation = FuelReceiptValidation(
        is_valid=False, issues=["Math mismatch"], confidence_score=0.7
    )

    adjusted = extractor._adjust_confidences(fields, validation)

    # Should be reduced by 0.7. Compare with a tolerance rather than exact
    # equality: the expected value is a float product, and an exact ``==``
    # would fail if the implementation rounds or orders the arithmetic
    # differently.
    expected = 0.90 * 0.7
    assert adjusted["totalAmount"].confidence == pytest.approx(expected)
|
||||
|
||||
|
||||
class TestReceiptPreprocessing:
    """Test receipt preprocessing integration."""

    @staticmethod
    def _white_png_bytes(width: int = 100, height: int = 100) -> bytes:
        """Return the PNG-encoded bytes of a solid white RGB image.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
        """
        from PIL import Image

        buffer = io.BytesIO()
        img = Image.new("RGB", (width, height), color="white")
        img.save(buffer, format="PNG")
        return buffer.getvalue()

    def test_preprocessing_result_structure(self) -> None:
        """Test preprocessing returns expected structure."""
        from app.preprocessors.receipt_preprocessor import (
            receipt_preprocessor,
            ReceiptPreprocessingResult,
        )

        # Create a simple test image (100x100 white PNG); the size is
        # asserted against original_width / original_height below.
        image_bytes = self._white_png_bytes(100, 100)

        result = receipt_preprocessor.preprocess(image_bytes)

        assert isinstance(result, ReceiptPreprocessingResult)
        assert len(result.image_bytes) > 0
        assert "loaded" in result.preprocessing_applied
        assert "grayscale" in result.preprocessing_applied
        assert result.original_width == 100
        assert result.original_height == 100

    def test_preprocessing_steps_applied(self) -> None:
        """Test that requested preprocessing steps are reported as applied.

        Every optional step is enabled, but only contrast, denoise and
        threshold are asserted.
        """
        from app.preprocessors.receipt_preprocessor import receipt_preprocessor

        image_bytes = self._white_png_bytes(100, 100)

        result = receipt_preprocessor.preprocess(
            image_bytes,
            apply_contrast=True,
            apply_deskew=True,
            apply_denoise=True,
            apply_threshold=True,
            apply_sharpen=True,
        )

        # Check that expected steps are in the applied list.
        # NOTE(review): deskew and sharpen are requested but not asserted;
        # presumably they may be skipped for a blank image -- confirm against
        # the preprocessor before adding assertions for them.
        assert "contrast" in result.preprocessing_applied
        assert "denoise" in result.preprocessing_applied
        assert "threshold" in result.preprocessing_applied
|
||||
|
||||
|
||||
class TestEndpointIntegration:
    """Exercise the /extract/receipt route through FastAPI's test client."""

    @pytest.fixture
    def test_client(self):
        """Provide a ``TestClient`` bound to the application instance."""
        from fastapi.testclient import TestClient
        from app.main import app

        client = TestClient(app)
        return client

    def test_receipt_endpoint_exists(self, test_client) -> None:
        """Confirm the /extract/receipt route is registered."""
        # A missing route would answer 404; a registered one rejects the
        # file-less request with 422 (Unprocessable Entity).
        status = test_client.post("/extract/receipt").status_code
        assert status == 422

    def test_receipt_endpoint_no_file(self, test_client) -> None:
        """Confirm a request without a file is rejected with 422."""
        reply = test_client.post("/extract/receipt")
        assert reply.status_code == 422

    def test_receipt_endpoint_empty_file(self, test_client) -> None:
        """Confirm a zero-byte upload is rejected with 400 and "Empty file"."""
        upload = {"file": ("receipt.jpg", b"", "image/jpeg")}
        reply = test_client.post("/extract/receipt", files=upload)

        assert reply.status_code == 400
        assert "Empty file" in reply.json()["detail"]
|
||||
Reference in New Issue
Block a user