feat: Receipt OCR Pipeline (#69) #77

Merged
egullickson merged 1 commits from issue-69-receipt-ocr-pipeline into main 2026-02-02 02:47:52 +00:00
16 changed files with 2845 additions and 2 deletions

View File

@@ -1,10 +1,23 @@
"""Extractors package for domain-specific OCR extraction."""
from app.extractors.base import BaseExtractor, ExtractionResult
from app.extractors.vin_extractor import VinExtractor, vin_extractor
from app.extractors.receipt_extractor import (
ReceiptExtractor,
receipt_extractor,
ReceiptExtractionResult,
ExtractedField,
)
from app.extractors.fuel_receipt import FuelReceiptExtractor, fuel_receipt_extractor
__all__ = [
"BaseExtractor",
"ExtractionResult",
"VinExtractor",
"vin_extractor",
"ReceiptExtractor",
"receipt_extractor",
"ReceiptExtractionResult",
"ExtractedField",
"FuelReceiptExtractor",
"fuel_receipt_extractor",
]

View File

@@ -0,0 +1,193 @@
"""Fuel receipt specialization with validation and cross-checking."""
import logging
from dataclasses import dataclass
from typing import Optional
from app.extractors.receipt_extractor import (
ExtractedField,
ReceiptExtractionResult,
receipt_extractor,
)
logger = logging.getLogger(__name__)
@dataclass
class FuelReceiptValidation:
"""Validation result for fuel receipt extraction."""
is_valid: bool
issues: list[str]
confidence_score: float
class FuelReceiptExtractor:
"""Specialized fuel receipt extractor with cross-validation.
Provides additional validation and confidence scoring specific
to fuel receipts by cross-checking extracted values.
"""
# Expected fields for a complete fuel receipt
REQUIRED_FIELDS = ["totalAmount"]
OPTIONAL_FIELDS = [
"merchantName",
"transactionDate",
"fuelQuantity",
"pricePerUnit",
"fuelGrade",
]
def extract(
self,
image_bytes: bytes,
content_type: Optional[str] = None,
) -> ReceiptExtractionResult:
"""
Extract fuel receipt data with validation.
Args:
image_bytes: Raw image bytes
content_type: MIME type
Returns:
ReceiptExtractionResult with fuel-specific extraction
"""
# Use base receipt extractor with fuel hint
result = receipt_extractor.extract(
image_bytes=image_bytes,
content_type=content_type,
receipt_type="fuel",
)
if not result.success:
return result
# Validate and cross-check fuel fields
validation = self._validate_fuel_receipt(result.extracted_fields)
if validation.issues:
logger.warning(
f"Fuel receipt validation issues: {validation.issues}"
)
# Update overall confidence based on validation
result.extracted_fields = self._adjust_confidences(
result.extracted_fields, validation
)
return result
def _validate_fuel_receipt(
self, fields: dict[str, ExtractedField]
) -> FuelReceiptValidation:
"""
Validate extracted fuel receipt fields.
Cross-checks:
- total = quantity * price per unit (within tolerance)
- quantity is reasonable for a single fill-up
- price per unit is within market range
Args:
fields: Extracted fields
Returns:
FuelReceiptValidation with issues and confidence
"""
issues = []
confidence_score = 1.0
# Check required fields
for field_name in self.REQUIRED_FIELDS:
if field_name not in fields:
issues.append(f"Missing required field: {field_name}")
confidence_score *= 0.5
# Cross-validate total = quantity * price
if all(
f in fields for f in ["totalAmount", "fuelQuantity", "pricePerUnit"]
):
total = fields["totalAmount"].value
quantity = fields["fuelQuantity"].value
price = fields["pricePerUnit"].value
calculated_total = quantity * price
tolerance = 0.10 # Allow 10% tolerance for rounding
if abs(total - calculated_total) > total * tolerance:
issues.append(
f"Total ({total}) doesn't match quantity ({quantity}) * "
f"price ({price}) = {calculated_total:.2f}"
)
confidence_score *= 0.7
# Validate quantity is reasonable
if "fuelQuantity" in fields:
quantity = fields["fuelQuantity"].value
if quantity < 0.5:
issues.append(f"Fuel quantity too small: {quantity}")
confidence_score *= 0.6
elif quantity > 40: # 40 gallons is very large tank
issues.append(f"Fuel quantity unusually large: {quantity}")
confidence_score *= 0.8
# Validate price is reasonable (current US market range)
if "pricePerUnit" in fields:
price = fields["pricePerUnit"].value
if price < 1.50:
issues.append(f"Price per unit too low: ${price}")
confidence_score *= 0.7
elif price > 7.00:
issues.append(f"Price per unit unusually high: ${price}")
confidence_score *= 0.8
# Validate fuel grade
if "fuelGrade" in fields:
grade = fields["fuelGrade"].value
valid_grades = ["87", "89", "91", "93", "DIESEL", "E85"]
if grade not in valid_grades:
issues.append(f"Unknown fuel grade: {grade}")
confidence_score *= 0.9
is_valid = len(issues) == 0
return FuelReceiptValidation(
is_valid=is_valid,
issues=issues,
confidence_score=confidence_score,
)
def _adjust_confidences(
self,
fields: dict[str, ExtractedField],
validation: FuelReceiptValidation,
) -> dict[str, ExtractedField]:
"""
Adjust field confidences based on validation.
Args:
fields: Extracted fields
validation: Validation result
Returns:
Fields with adjusted confidences
"""
if validation.is_valid:
# Boost confidences when cross-validation passes
boost = 1.1
else:
# Reduce confidences when there are issues
boost = validation.confidence_score
adjusted = {}
for name, field in fields.items():
adjusted[name] = ExtractedField(
value=field.value,
confidence=min(1.0, field.confidence * boost),
)
return adjusted
# Singleton instance
fuel_receipt_extractor = FuelReceiptExtractor()

View File

@@ -0,0 +1,345 @@
"""Receipt-specific OCR extractor with field extraction."""
import io
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Optional
import magic
import pytesseract
from PIL import Image
from pillow_heif import register_heif_opener
from app.config import settings
from app.extractors.base import BaseExtractor
from app.preprocessors.receipt_preprocessor import receipt_preprocessor
from app.patterns import currency_matcher, date_matcher, fuel_matcher
# Register HEIF/HEIC opener
register_heif_opener()
logger = logging.getLogger(__name__)
@dataclass
class ExtractedField:
"""A single extracted field with confidence."""
value: Any
confidence: float
@dataclass
class ReceiptExtractionResult:
"""Result of receipt extraction."""
success: bool
receipt_type: str = "unknown"
extracted_fields: dict[str, ExtractedField] = field(default_factory=dict)
raw_text: str = ""
processing_time_ms: int = 0
error: Optional[str] = None
class ReceiptExtractor(BaseExtractor):
"""Receipt-specific OCR extractor for fuel and general receipts."""
# Supported MIME types
SUPPORTED_TYPES = {
"image/jpeg",
"image/png",
"image/heic",
"image/heif",
}
def __init__(self) -> None:
"""Initialize receipt extractor."""
pytesseract.pytesseract.tesseract_cmd = settings.tesseract_cmd
def extract(
self,
image_bytes: bytes,
content_type: Optional[str] = None,
receipt_type: Optional[str] = None,
) -> ReceiptExtractionResult:
"""
Extract data from a receipt image.
Args:
image_bytes: Raw image bytes (HEIC, JPEG, PNG)
content_type: MIME type (auto-detected if not provided)
receipt_type: Hint for receipt type ("fuel" for specialized extraction)
Returns:
ReceiptExtractionResult with extracted fields
"""
start_time = time.time()
# Detect content type if not provided
if not content_type:
content_type = self._detect_mime_type(image_bytes)
# Validate content type
if content_type not in self.SUPPORTED_TYPES:
return ReceiptExtractionResult(
success=False,
error=f"Unsupported file type: {content_type}",
processing_time_ms=int((time.time() - start_time) * 1000),
)
try:
# Apply receipt-optimized preprocessing
preprocessing_result = receipt_preprocessor.preprocess(image_bytes)
preprocessed_bytes = preprocessing_result.image_bytes
# Perform OCR
raw_text = self._perform_ocr(preprocessed_bytes)
if not raw_text.strip():
# Try with less aggressive preprocessing
preprocessing_result = receipt_preprocessor.preprocess(
image_bytes,
apply_threshold=False,
)
preprocessed_bytes = preprocessing_result.image_bytes
raw_text = self._perform_ocr(preprocessed_bytes)
if not raw_text.strip():
return ReceiptExtractionResult(
success=False,
error="No text found in image",
processing_time_ms=int((time.time() - start_time) * 1000),
)
# Detect receipt type if not specified
detected_type = receipt_type or self._detect_receipt_type(raw_text)
# Extract fields based on receipt type
if detected_type == "fuel":
extracted_fields = self._extract_fuel_fields(raw_text)
else:
extracted_fields = self._extract_generic_fields(raw_text)
processing_time_ms = int((time.time() - start_time) * 1000)
logger.info(
f"Receipt extraction: type={detected_type}, "
f"fields={len(extracted_fields)}, "
f"time={processing_time_ms}ms"
)
return ReceiptExtractionResult(
success=True,
receipt_type=detected_type,
extracted_fields=extracted_fields,
raw_text=raw_text,
processing_time_ms=processing_time_ms,
)
except Exception as e:
logger.error(f"Receipt extraction failed: {e}", exc_info=True)
return ReceiptExtractionResult(
success=False,
error=str(e),
processing_time_ms=int((time.time() - start_time) * 1000),
)
def _detect_mime_type(self, file_bytes: bytes) -> str:
"""Detect MIME type using python-magic."""
mime = magic.Magic(mime=True)
detected = mime.from_buffer(file_bytes)
return detected or "application/octet-stream"
def _perform_ocr(self, image_bytes: bytes, psm: int = 6) -> str:
"""
Perform OCR on preprocessed image.
Args:
image_bytes: Preprocessed image bytes
psm: Tesseract page segmentation mode
4 = Assume single column of text
6 = Uniform block of text
Returns:
Raw OCR text
"""
image = Image.open(io.BytesIO(image_bytes))
# Configure Tesseract for receipt OCR
# PSM 4 works well for columnar receipt text
config = f"--psm {psm}"
return pytesseract.image_to_string(image, config=config)
def _detect_receipt_type(self, text: str) -> str:
"""
Detect receipt type based on content.
Args:
text: OCR text
Returns:
Receipt type: "fuel", "retail", or "unknown"
"""
text_upper = text.upper()
# Fuel receipt indicators
fuel_keywords = [
"GALLON", "GAL", "FUEL", "GAS", "DIESEL", "UNLEADED",
"REGULAR", "PREMIUM", "OCTANE", "PPG", "PUMP",
]
fuel_score = sum(1 for kw in fuel_keywords if kw in text_upper)
# Check for known gas stations
if fuel_matcher.extract_merchant_name(text):
merchant, _ = fuel_matcher.extract_merchant_name(text)
if any(
station in merchant.upper()
for station in fuel_matcher.STATION_NAMES
):
fuel_score += 3
if fuel_score >= 2:
return "fuel"
return "unknown"
def _extract_fuel_fields(self, text: str) -> dict[str, ExtractedField]:
"""
Extract fuel-specific fields from receipt text.
Args:
text: OCR text
Returns:
Dictionary of extracted fields
"""
fields: dict[str, ExtractedField] = {}
# Extract merchant name
merchant_result = fuel_matcher.extract_merchant_name(text)
if merchant_result:
merchant_name, confidence = merchant_result
fields["merchantName"] = ExtractedField(
value=merchant_name,
confidence=confidence,
)
# Extract transaction date
date_match = date_matcher.extract_best_date(text)
if date_match:
fields["transactionDate"] = ExtractedField(
value=date_match.value,
confidence=date_match.confidence,
)
# Extract total amount
total_match = currency_matcher.extract_total(text)
if total_match:
fields["totalAmount"] = ExtractedField(
value=total_match.value,
confidence=total_match.confidence,
)
# Extract fuel quantity
quantity_match = fuel_matcher.extract_quantity(text)
if quantity_match:
fields["fuelQuantity"] = ExtractedField(
value=quantity_match.value,
confidence=quantity_match.confidence,
)
# Extract price per unit
price_match = fuel_matcher.extract_price_per_unit(text)
if price_match:
fields["pricePerUnit"] = ExtractedField(
value=price_match.value,
confidence=price_match.confidence,
)
# Extract fuel grade
grade_match = fuel_matcher.extract_grade(text)
if grade_match:
fields["fuelGrade"] = ExtractedField(
value=grade_match.value,
confidence=grade_match.confidence,
)
# Calculate derived values if we have enough data
if "totalAmount" in fields and "fuelQuantity" in fields:
if "pricePerUnit" not in fields:
# Calculate price per unit from total and quantity
calculated_price = (
fields["totalAmount"].value / fields["fuelQuantity"].value
)
# Only use if reasonable
if 1.0 <= calculated_price <= 10.0:
fields["pricePerUnit"] = ExtractedField(
value=round(calculated_price, 3),
confidence=min(
fields["totalAmount"].confidence,
fields["fuelQuantity"].confidence,
)
* 0.8, # Lower confidence for calculated value
)
return fields
def _extract_generic_fields(self, text: str) -> dict[str, ExtractedField]:
"""
Extract generic fields from receipt text.
Args:
text: OCR text
Returns:
Dictionary of extracted fields
"""
fields: dict[str, ExtractedField] = {}
# Extract date
date_match = date_matcher.extract_best_date(text)
if date_match:
fields["transactionDate"] = ExtractedField(
value=date_match.value,
confidence=date_match.confidence,
)
# Extract total amount
total_match = currency_matcher.extract_total(text)
if total_match:
fields["totalAmount"] = ExtractedField(
value=total_match.value,
confidence=total_match.confidence,
)
# Try to get merchant from first line
lines = [l.strip() for l in text.split("\n") if l.strip()]
if lines:
fields["merchantName"] = ExtractedField(
value=lines[0][:50],
confidence=0.40,
)
return fields
def validate(self, data: Any) -> bool:
"""
Validate extracted receipt data.
Args:
data: Extracted data to validate
Returns:
True if data has minimum required fields
"""
if not isinstance(data, dict):
return False
# Minimum: must have at least total amount or date
return "totalAmount" in data or "transactionDate" in data
# Singleton instance
receipt_extractor = ReceiptExtractor()

View File

@@ -7,6 +7,8 @@ from .schemas import (
JobStatus,
JobSubmitRequest,
OcrResponse,
ReceiptExtractedField,
ReceiptExtractionResponse,
VinAlternative,
VinExtractionResponse,
)
@@ -19,6 +21,8 @@ __all__ = [
"JobStatus",
"JobSubmitRequest",
"OcrResponse",
"ReceiptExtractedField",
"ReceiptExtractionResponse",
"VinAlternative",
"VinExtractionResponse",
]

View File

@@ -93,3 +93,25 @@ class JobSubmitRequest(BaseModel):
callback_url: Optional[str] = Field(default=None, alias="callbackUrl")
model_config = {"populate_by_name": True}
class ReceiptExtractedField(BaseModel):
"""A single extracted field from a receipt with confidence."""
value: str | float
confidence: float = Field(ge=0.0, le=1.0)
class ReceiptExtractionResponse(BaseModel):
"""Response from receipt extraction endpoint."""
success: bool
receipt_type: str = Field(alias="receiptType")
extracted_fields: dict[str, ReceiptExtractedField] = Field(
default_factory=dict, alias="extractedFields"
)
raw_text: str = Field(alias="rawText")
processing_time_ms: int = Field(alias="processingTimeMs")
error: Optional[str] = None
model_config = {"populate_by_name": True}

View File

@@ -0,0 +1,13 @@
"""Pattern matching modules for receipt field extraction."""
from app.patterns.date_patterns import DatePatternMatcher, date_matcher
from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher
from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher
__all__ = [
"DatePatternMatcher",
"date_matcher",
"CurrencyPatternMatcher",
"currency_matcher",
"FuelPatternMatcher",
"fuel_matcher",
]

View File

@@ -0,0 +1,227 @@
"""Currency and amount pattern matching for receipt extraction."""
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Optional
@dataclass
class AmountMatch:
"""Result of currency/amount pattern matching."""
value: float
raw_match: str
confidence: float
pattern_name: str
label: Optional[str] = None # e.g., "TOTAL", "SUBTOTAL"
class CurrencyPatternMatcher:
"""Extract and normalize currency amounts from receipt text."""
# Total amount patterns (prioritized)
TOTAL_PATTERNS = [
# TOTAL $XX.XX or TOTAL: $XX.XX
(
r"(?:^|\s)TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
"total_explicit",
0.98,
),
# AMOUNT DUE $XX.XX
(
r"AMOUNT\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
"amount_due",
0.95,
),
# SALE $XX.XX
(
r"(?:^|\s)SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
"sale_explicit",
0.92,
),
# GRAND TOTAL $XX.XX
(
r"GRAND\s*TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
"grand_total",
0.97,
),
# TOTAL SALE $XX.XX
(
r"TOTAL\s*SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
"total_sale",
0.96,
),
# BALANCE DUE $XX.XX
(
r"BALANCE\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
"balance_due",
0.94,
),
# PURCHASE $XX.XX
(
r"(?:^|\s)PURCHASE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
"purchase",
0.88,
),
]
# Generic amount patterns (lower priority)
AMOUNT_PATTERNS = [
# $XX.XX (standalone dollar amount)
(
r"\$\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
"dollar_amount",
0.60,
),
# XX.XX (standalone decimal amount)
(
r"(?<![.$\d])(\d{1,6}[.,]\d{2})(?![.\d])",
"decimal_amount",
0.40,
),
]
def extract_total(self, text: str) -> Optional[AmountMatch]:
"""
Extract the total amount from receipt text.
Prioritizes explicit total patterns over generic amounts.
Args:
text: Receipt text to search
Returns:
AmountMatch for total or None if not found
"""
text_upper = text.upper()
# Try total-specific patterns first
for pattern, name, confidence in self.TOTAL_PATTERNS:
match = re.search(pattern, text_upper, re.MULTILINE)
if match:
amount = self._parse_amount(match.group(1))
if amount is not None and self._is_reasonable_total(amount):
return AmountMatch(
value=amount,
raw_match=match.group(0).strip(),
confidence=confidence,
pattern_name=name,
label=self._extract_label(name),
)
# Fall back to finding the largest reasonable amount
all_amounts = self.extract_all_amounts(text)
reasonable = [a for a in all_amounts if self._is_reasonable_total(a.value)]
if reasonable:
# Assume largest amount is the total
reasonable.sort(key=lambda x: x.value, reverse=True)
best = reasonable[0]
# Lower confidence since we're guessing
return AmountMatch(
value=best.value,
raw_match=best.raw_match,
confidence=min(0.60, best.confidence),
pattern_name="inferred_total",
label="TOTAL",
)
return None
def extract_all_amounts(self, text: str) -> list[AmountMatch]:
"""
Extract all currency amounts from text.
Args:
text: Receipt text to search
Returns:
List of AmountMatch objects
"""
matches = []
text_upper = text.upper()
# Check total patterns
for pattern, name, confidence in self.TOTAL_PATTERNS:
for match in re.finditer(pattern, text_upper, re.MULTILINE):
amount = self._parse_amount(match.group(1))
if amount is not None:
matches.append(
AmountMatch(
value=amount,
raw_match=match.group(0).strip(),
confidence=confidence,
pattern_name=name,
label=self._extract_label(name),
)
)
# Check generic amount patterns
for pattern, name, confidence in self.AMOUNT_PATTERNS:
for match in re.finditer(pattern, text_upper):
amount = self._parse_amount(match.group(1))
if amount is not None:
# Skip if already found by a more specific pattern
if not any(abs(m.value - amount) < 0.01 for m in matches):
matches.append(
AmountMatch(
value=amount,
raw_match=match.group(0).strip(),
confidence=confidence,
pattern_name=name,
)
)
return matches
def _parse_amount(self, amount_str: str) -> Optional[float]:
"""Parse amount string to float, handling various formats."""
# Remove any spaces
cleaned = amount_str.strip().replace(" ", "")
# Handle European format (1.234,56) vs US format (1,234.56)
# For US receipts, assume comma is thousands separator
if "," in cleaned and "." in cleaned:
# Determine which is decimal separator (last one)
if cleaned.rfind(",") > cleaned.rfind("."):
# European format
cleaned = cleaned.replace(".", "").replace(",", ".")
else:
# US format
cleaned = cleaned.replace(",", "")
elif "," in cleaned:
# Could be thousands separator or decimal
parts = cleaned.split(",")
if len(parts) == 2 and len(parts[1]) == 2:
# Likely decimal separator
cleaned = cleaned.replace(",", ".")
else:
# Likely thousands separator
cleaned = cleaned.replace(",", "")
try:
amount = float(Decimal(cleaned))
return amount if amount >= 0 else None
except (InvalidOperation, ValueError):
return None
def _is_reasonable_total(self, amount: float) -> bool:
"""Check if amount is a reasonable total for a fuel receipt."""
# Reasonable range: $1 to $500 for typical fuel purchases
return 1.0 <= amount <= 500.0
def _extract_label(self, pattern_name: str) -> str:
"""Extract display label from pattern name."""
labels = {
"total_explicit": "TOTAL",
"amount_due": "AMOUNT DUE",
"sale_explicit": "SALE",
"grand_total": "GRAND TOTAL",
"total_sale": "TOTAL SALE",
"balance_due": "BALANCE DUE",
"purchase": "PURCHASE",
}
return labels.get(pattern_name, "TOTAL")
# Singleton instance
currency_matcher = CurrencyPatternMatcher()

View File

@@ -0,0 +1,186 @@
"""Date pattern matching for receipt extraction."""
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class DateMatch:
"""Result of date pattern matching."""
value: str # ISO format YYYY-MM-DD
raw_match: str # Original text matched
confidence: float
pattern_name: str
class DatePatternMatcher:
"""Extract and normalize dates from receipt text."""
# Pattern definitions with named groups and confidence weights
PATTERNS = [
# MM/DD/YYYY or MM/DD/YY (most common US format)
(
r"(?P<month>\d{1,2})/(?P<day>\d{1,2})/(?P<year>\d{2,4})",
"mm_dd_yyyy",
0.95,
),
# MM-DD-YYYY or MM-DD-YY
(
r"(?P<month>\d{1,2})-(?P<day>\d{1,2})-(?P<year>\d{2,4})",
"mm_dd_yyyy_dash",
0.90,
),
# YYYY-MM-DD (ISO format)
(
r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})",
"iso_date",
0.98,
),
# Mon DD, YYYY (e.g., Jan 15, 2024)
(
r"(?P<month_name>[A-Za-z]{3})\s+(?P<day>\d{1,2}),?\s+(?P<year>\d{4})",
"month_name_long",
0.85,
),
# DD Mon YYYY (e.g., 15 Jan 2024)
(
r"(?P<day>\d{1,2})\s+(?P<month_name>[A-Za-z]{3})\s+(?P<year>\d{4})",
"day_month_year",
0.85,
),
# MMDDYYYY or MMDDYY (no separators, common in some POS systems)
(
r"(?<!\d)(?P<month>\d{2})(?P<day>\d{2})(?P<year>\d{2,4})(?!\d)",
"compact_date",
0.70,
),
]
MONTH_NAMES = {
"jan": 1, "january": 1,
"feb": 2, "february": 2,
"mar": 3, "march": 3,
"apr": 4, "april": 4,
"may": 5,
"jun": 6, "june": 6,
"jul": 7, "july": 7,
"aug": 8, "august": 8,
"sep": 9, "sept": 9, "september": 9,
"oct": 10, "october": 10,
"nov": 11, "november": 11,
"dec": 12, "december": 12,
}
def extract_dates(self, text: str) -> list[DateMatch]:
"""
Extract all date patterns from text.
Args:
text: Receipt text to search
Returns:
List of DateMatch objects sorted by confidence
"""
matches = []
text_upper = text.upper()
for pattern, name, base_confidence in self.PATTERNS:
for match in re.finditer(pattern, text, re.IGNORECASE):
parsed = self._parse_match(match, name)
if parsed:
year, month, day = parsed
if self._is_valid_date(year, month, day):
# Adjust confidence based on context
confidence = self._adjust_confidence(
base_confidence, text_upper, match.start()
)
matches.append(
DateMatch(
value=f"{year:04d}-{month:02d}-{day:02d}",
raw_match=match.group(0),
confidence=confidence,
pattern_name=name,
)
)
# Sort by confidence, deduplicate by value
matches.sort(key=lambda x: x.confidence, reverse=True)
seen = set()
unique_matches = []
for match in matches:
if match.value not in seen:
seen.add(match.value)
unique_matches.append(match)
return unique_matches
def extract_best_date(self, text: str) -> Optional[DateMatch]:
"""
Extract the most likely transaction date.
Args:
text: Receipt text to search
Returns:
Best DateMatch or None if no date found
"""
matches = self.extract_dates(text)
return matches[0] if matches else None
def _parse_match(
self, match: re.Match, pattern_name: str
) -> Optional[tuple[int, int, int]]:
"""Parse regex match into year, month, day tuple."""
groups = match.groupdict()
# Handle month name patterns
if "month_name" in groups:
month_str = groups["month_name"].lower()
month = self.MONTH_NAMES.get(month_str)
if not month:
return None
else:
month = int(groups["month"])
day = int(groups["day"])
year = int(groups["year"])
# Normalize 2-digit years
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
return year, month, day
def _is_valid_date(self, year: int, month: int, day: int) -> bool:
"""Check if date components form a valid date."""
try:
datetime(year=year, month=month, day=day)
# Reasonable year range for receipts
return 2000 <= year <= 2100
except ValueError:
return False
def _adjust_confidence(
self, base_confidence: float, text: str, position: int
) -> float:
"""
Adjust confidence based on context clues.
Boost confidence if date appears near date-related keywords.
"""
# Look for nearby date keywords
context_start = max(0, position - 50)
context = text[context_start:position + 50]
date_keywords = ["DATE", "TIME", "TRANS", "SALE"]
for keyword in date_keywords:
if keyword in context:
return min(1.0, base_confidence + 0.05)
return base_confidence
# Singleton instance
date_matcher = DatePatternMatcher()

View File

@@ -0,0 +1,364 @@
"""Fuel-specific pattern matching for receipt extraction."""
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class FuelQuantityMatch:
"""Result of fuel quantity pattern matching."""
value: float # Gallons or liters
unit: str # "GAL" or "L"
raw_match: str
confidence: float
pattern_name: str
@dataclass
class FuelPriceMatch:
"""Result of fuel price per unit pattern matching."""
value: float
unit: str # "GAL" or "L"
raw_match: str
confidence: float
pattern_name: str
@dataclass
class FuelGradeMatch:
"""Result of fuel grade pattern matching."""
value: str # e.g., "87", "89", "93", "DIESEL"
display_name: str # e.g., "Regular 87", "Premium 93"
raw_match: str
confidence: float
class FuelPatternMatcher:
"""Extract fuel-specific data from receipt text."""
# Gallons patterns
GALLONS_PATTERNS = [
# XX.XXX GAL or XX.XXX GALLONS
(
r"(\d{1,3}\.\d{1,3})\s*(?:GAL(?:LON)?S?)",
"gallons_suffix",
0.95,
),
# GALLONS: XX.XXX or GAL: XX.XXX
(
r"(?:GAL(?:LON)?S?)[:\s]+(\d{1,3}\.\d{1,3})",
"gallons_prefix",
0.93,
),
# VOLUME XX.XXX
(
r"VOLUME[:\s]+(\d{1,3}\.\d{1,3})",
"volume",
0.85,
),
# QTY XX.XXX (near fuel context)
(
r"QTY[:\s]+(\d{1,3}\.\d{1,3})",
"qty",
0.70,
),
]
# Liters patterns (for international receipts)
LITERS_PATTERNS = [
# XX.XX L or XX.XX LITERS
(
r"(\d{1,3}\.\d{1,3})\s*(?:L(?:ITERS?)?)",
"liters_suffix",
0.95,
),
# LITERS: XX.XX
(
r"(?:L(?:ITERS?)?)[:\s]+(\d{1,3}\.\d{1,3})",
"liters_prefix",
0.93,
),
]
# Price per gallon patterns
PRICE_PER_UNIT_PATTERNS = [
# $X.XXX/GAL or $X.XX/GAL
(
r"\$?\s*(\d{1,2}\.\d{2,3})\s*/\s*GAL",
"price_per_gal",
0.98,
),
# PRICE/GAL $X.XXX
(
r"PRICE\s*/\s*GAL[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
"labeled_price_gal",
0.96,
),
# UNIT PRICE $X.XXX
(
r"UNIT\s*PRICE[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
"unit_price",
0.90,
),
# @ $X.XXX (per unit implied)
(
r"@\s*\$?\s*(\d{1,2}\.\d{2,3})",
"at_price",
0.85,
),
# PPG $X.XXX (price per gallon)
(
r"PPG[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
"ppg",
0.92,
),
]
# Fuel grade patterns
GRADE_PATTERNS = [
# REGULAR 87, REG 87
(r"(?:REGULAR|REG)\s*(\d{2})", "regular", 0.95),
# UNLEADED 87
(r"UNLEADED\s*(\d{2})", "unleaded", 0.93),
# PLUS 89, MID 89, MIDGRADE 89
(r"(?:PLUS|MID(?:GRADE)?)\s*(\d{2})", "plus", 0.95),
# PREMIUM 91/93, PREM 91/93, SUPER 91/93
(r"(?:PREMIUM|PREM|SUPER)\s*(\d{2})", "premium", 0.95),
# Just the octane number near fuel context (87, 89, 91, 93)
(r"(?<!\d)\s*(87|89|91|93)\s*(?:OCT(?:ANE)?)?", "octane_only", 0.75),
# DIESEL (no octane)
(r"DIESEL(?:\s*#?\d)?", "diesel", 0.98),
# E85 (ethanol blend)
(r"E\s*85", "e85", 0.95),
]
# Common gas station names
STATION_NAMES = [
"SHELL", "CHEVRON", "EXXON", "MOBIL", "BP", "SUNOCO", "76",
"CIRCLE K", "SPEEDWAY", "WAWA", "SHEETZ", "CASEY", "PILOT",
"FLYING J", "LOVES", "TA", "PETRO", "MARATHON", "CITGO",
"VALERO", "MURPHY", "COSTCO", "SAMS CLUB", "SAM'S CLUB",
"KROGER", "QT", "QUIKTRIP", "RACETRAC", "KUM & GO",
"KWIK TRIP", "HOLIDAY", "SINCLAIR", "CONOCO", "PHILLIPS 66",
"ARCO", "AMPM", "AM/PM", "7-ELEVEN", "7 ELEVEN", "GETTY",
"GULF", "HESS", "TEXACO", "TURKEY HILL", "CUMBERLAND FARMS",
]
def extract_gallons(self, text: str) -> Optional[FuelQuantityMatch]:
"""
Extract fuel quantity in gallons.
Args:
text: Receipt text to search
Returns:
FuelQuantityMatch or None
"""
text_upper = text.upper()
for pattern, name, confidence in self.GALLONS_PATTERNS:
match = re.search(pattern, text_upper)
if match:
quantity = float(match.group(1))
if self._is_reasonable_quantity(quantity):
return FuelQuantityMatch(
value=quantity,
unit="GAL",
raw_match=match.group(0),
confidence=confidence,
pattern_name=name,
)
return None
def extract_liters(self, text: str) -> Optional[FuelQuantityMatch]:
"""
Extract fuel quantity in liters.
Args:
text: Receipt text to search
Returns:
FuelQuantityMatch or None
"""
text_upper = text.upper()
for pattern, name, confidence in self.LITERS_PATTERNS:
match = re.search(pattern, text_upper)
if match:
quantity = float(match.group(1))
if self._is_reasonable_quantity(quantity, is_liters=True):
return FuelQuantityMatch(
value=quantity,
unit="L",
raw_match=match.group(0),
confidence=confidence,
pattern_name=name,
)
return None
def extract_quantity(self, text: str) -> Optional[FuelQuantityMatch]:
"""
Extract fuel quantity (gallons or liters).
Prefers gallons for US receipts.
Args:
text: Receipt text to search
Returns:
FuelQuantityMatch or None
"""
# Try gallons first (more common in US)
gallons = self.extract_gallons(text)
if gallons:
return gallons
# Fall back to liters
return self.extract_liters(text)
def extract_price_per_unit(self, text: str) -> Optional[FuelPriceMatch]:
"""
Extract price per gallon/liter.
Args:
text: Receipt text to search
Returns:
FuelPriceMatch or None
"""
text_upper = text.upper()
for pattern, name, confidence in self.PRICE_PER_UNIT_PATTERNS:
match = re.search(pattern, text_upper)
if match:
price = float(match.group(1))
if self._is_reasonable_price(price):
return FuelPriceMatch(
value=price,
unit="GAL", # Default to gallons for US
raw_match=match.group(0),
confidence=confidence,
pattern_name=name,
)
return None
def extract_grade(self, text: str) -> Optional[FuelGradeMatch]:
"""
Extract fuel grade (octane rating or diesel).
Args:
text: Receipt text to search
Returns:
FuelGradeMatch or None
"""
text_upper = text.upper()
for pattern, name, confidence in self.GRADE_PATTERNS:
match = re.search(pattern, text_upper)
if match:
if name == "diesel":
return FuelGradeMatch(
value="DIESEL",
display_name="Diesel",
raw_match=match.group(0),
confidence=confidence,
)
elif name == "e85":
return FuelGradeMatch(
value="E85",
display_name="E85 Ethanol",
raw_match=match.group(0),
confidence=confidence,
)
else:
octane = match.group(1)
display = self._get_grade_display_name(octane, name)
return FuelGradeMatch(
value=octane,
display_name=display,
raw_match=match.group(0),
confidence=confidence,
)
return None
def extract_merchant_name(self, text: str) -> Optional[tuple[str, float]]:
"""
Extract gas station/merchant name.
Args:
text: Receipt text to search
Returns:
Tuple of (merchant_name, confidence) or None
"""
text_upper = text.upper()
# Check for known station names
for station in self.STATION_NAMES:
if station in text_upper:
# Try to get the full line for context
for line in text.split("\n"):
if station in line.upper():
# Clean up the line
cleaned = line.strip()
if len(cleaned) <= 50: # Reasonable length
return (cleaned, 0.90)
return (station.title(), 0.85)
# Fall back to first non-empty line (often the merchant)
lines = [l.strip() for l in text.split("\n") if l.strip()]
if lines:
first_line = lines[0]
# Skip if it looks like a date or number
if not re.match(r"^\d+[/\-.]", first_line):
return (first_line[:50], 0.50) # Low confidence
return None
def _is_reasonable_quantity(
self, quantity: float, is_liters: bool = False
) -> bool:
"""Check if fuel quantity is reasonable."""
if is_liters:
# Typical fill: 20-100 liters
return 0.5 <= quantity <= 150.0
else:
# Typical fill: 5-30 gallons
return 0.1 <= quantity <= 50.0
def _is_reasonable_price(self, price: float) -> bool:
"""Check if price per unit is reasonable."""
# US gas prices: $1.50 - $8.00 per gallon (allowing for fluctuation)
return 1.00 <= price <= 10.00
def _get_grade_display_name(self, octane: str, pattern_name: str) -> str:
"""Get display name for fuel grade."""
grade_names = {
"87": "Regular 87",
"89": "Plus 89",
"91": "Premium 91",
"93": "Premium 93",
}
if octane in grade_names:
return grade_names[octane]
# Use pattern hint
if pattern_name == "premium":
return f"Premium {octane}"
elif pattern_name == "plus":
return f"Plus {octane}"
else:
return f"Unleaded {octane}"
# Singleton instance
fuel_matcher = FuelPatternMatcher()

View File

@@ -1,10 +1,16 @@
"""Image preprocessors for OCR optimization."""
from app.services.preprocessor import ImagePreprocessor, preprocessor
from app.preprocessors.vin_preprocessor import VinPreprocessor, vin_preprocessor
from app.preprocessors.receipt_preprocessor import (
ReceiptPreprocessor,
receipt_preprocessor,
)
__all__ = [
"ImagePreprocessor",
"preprocessor",
"VinPreprocessor",
"vin_preprocessor",
"ReceiptPreprocessor",
"receipt_preprocessor",
]

View File

@@ -0,0 +1,340 @@
"""Receipt-optimized image preprocessing pipeline."""
import io
import logging
from dataclasses import dataclass
from typing import Optional
import cv2
import numpy as np
from PIL import Image
from pillow_heif import register_heif_opener
# Register HEIF/HEIC opener
register_heif_opener()
logger = logging.getLogger(__name__)
@dataclass
class ReceiptPreprocessingResult:
"""Result of receipt preprocessing."""
image_bytes: bytes
preprocessing_applied: list[str]
original_width: int
original_height: int
class ReceiptPreprocessor:
"""Receipt-optimized image preprocessing for improved OCR accuracy.
Thermal receipts typically have:
- Low contrast (faded ink)
- Uneven illumination
- Paper curl/skew
- Variable font weights
This preprocessor addresses these issues with targeted enhancements.
"""
# Optimal width for receipt OCR (narrow receipts work better)
TARGET_WIDTH = 800
def preprocess(
self,
image_bytes: bytes,
apply_contrast: bool = True,
apply_deskew: bool = True,
apply_denoise: bool = True,
apply_threshold: bool = True,
apply_sharpen: bool = True,
) -> ReceiptPreprocessingResult:
"""
Apply receipt-optimized preprocessing pipeline.
Pipeline optimized for thermal receipts:
1. HEIC conversion (if needed)
2. Grayscale conversion
3. Resize to optimal width
4. Deskew (correct rotation)
5. High contrast enhancement (CLAHE + histogram stretch)
6. Adaptive sharpening
7. Noise reduction
8. Adaptive thresholding (receipt-optimized)
Args:
image_bytes: Raw image bytes (HEIC, JPEG, PNG)
apply_contrast: Apply contrast enhancement
apply_deskew: Apply deskew correction
apply_denoise: Apply noise reduction
apply_threshold: Apply adaptive thresholding
apply_sharpen: Apply sharpening
Returns:
ReceiptPreprocessingResult with processed image bytes
"""
steps_applied = []
# Load image with PIL (handles HEIC via pillow-heif)
pil_image = Image.open(io.BytesIO(image_bytes))
original_width, original_height = pil_image.size
steps_applied.append("loaded")
# Handle EXIF rotation
pil_image = self._fix_orientation(pil_image)
# Convert to RGB if needed
if pil_image.mode not in ("RGB", "L"):
pil_image = pil_image.convert("RGB")
steps_applied.append("convert_rgb")
# Convert to OpenCV format
cv_image = np.array(pil_image)
if len(cv_image.shape) == 3:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
# Convert to grayscale
if len(cv_image.shape) == 3:
gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
else:
gray = cv_image
steps_applied.append("grayscale")
# Resize to optimal width while maintaining aspect ratio
gray = self._resize_optimal(gray)
steps_applied.append("resize")
# Apply deskew
if apply_deskew:
gray = self._deskew(gray)
steps_applied.append("deskew")
# Apply high contrast enhancement (critical for thermal receipts)
if apply_contrast:
gray = self._enhance_contrast(gray)
steps_applied.append("contrast")
# Apply sharpening
if apply_sharpen:
gray = self._sharpen(gray)
steps_applied.append("sharpen")
# Apply denoising
if apply_denoise:
gray = self._denoise(gray)
steps_applied.append("denoise")
# Apply adaptive thresholding (receipt-optimized parameters)
if apply_threshold:
gray = self._adaptive_threshold_receipt(gray)
steps_applied.append("threshold")
# Convert back to PNG bytes
result_image = Image.fromarray(gray)
buffer = io.BytesIO()
result_image.save(buffer, format="PNG")
logger.debug(f"Receipt preprocessing applied: {steps_applied}")
return ReceiptPreprocessingResult(
image_bytes=buffer.getvalue(),
preprocessing_applied=steps_applied,
original_width=original_width,
original_height=original_height,
)
def _fix_orientation(self, image: Image.Image) -> Image.Image:
"""Fix image orientation based on EXIF data."""
try:
exif = image.getexif()
if exif:
orientation = exif.get(274) # Orientation tag
if orientation:
rotate_values = {
3: 180,
6: 270,
8: 90,
}
if orientation in rotate_values:
return image.rotate(
rotate_values[orientation], expand=True
)
except Exception as e:
logger.debug(f"Could not read EXIF orientation: {e}")
return image
def _resize_optimal(self, image: np.ndarray) -> np.ndarray:
"""Resize image to optimal width for OCR."""
height, width = image.shape[:2]
if width <= self.TARGET_WIDTH:
return image
scale = self.TARGET_WIDTH / width
new_height = int(height * scale)
return cv2.resize(
image,
(self.TARGET_WIDTH, new_height),
interpolation=cv2.INTER_AREA,
)
def _deskew(self, image: np.ndarray) -> np.ndarray:
"""
Correct image rotation using projection profile.
Receipts often have slight rotation from scanning/photography.
Uses projection profile method optimized for text documents.
"""
try:
# Create binary image for angle detection
_, binary = cv2.threshold(
image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
# Find all non-zero points
coords = np.column_stack(np.where(binary > 0))
if len(coords) < 100:
return image
# Use minimum area rectangle to find angle
rect = cv2.minAreaRect(coords)
angle = rect[-1]
# Normalize angle
if angle < -45:
angle = 90 + angle
elif angle > 45:
angle = angle - 90
# Only correct if angle is significant but not extreme
if abs(angle) < 0.5 or abs(angle) > 15:
return image
# Rotate to correct skew
height, width = image.shape[:2]
center = (width // 2, height // 2)
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(
image,
rotation_matrix,
(width, height),
borderMode=cv2.BORDER_REPLICATE,
)
logger.debug(f"Receipt deskewed by {angle:.2f} degrees")
return rotated
except Exception as e:
logger.warning(f"Deskew failed: {e}")
return image
def _enhance_contrast(self, image: np.ndarray) -> np.ndarray:
"""
Apply aggressive contrast enhancement for faded receipts.
Combines:
1. Histogram stretching
2. CLAHE (Contrast Limited Adaptive Histogram Equalization)
"""
try:
# First, stretch histogram to use full dynamic range
p2, p98 = np.percentile(image, (2, 98))
stretched = np.clip(
(image - p2) * 255.0 / (p98 - p2), 0, 255
).astype(np.uint8)
# Apply CLAHE with parameters optimized for receipts
# Higher clipLimit for faded thermal receipts
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(stretched)
return enhanced
except Exception as e:
logger.warning(f"Contrast enhancement failed: {e}")
return image
def _sharpen(self, image: np.ndarray) -> np.ndarray:
"""
Apply unsharp masking for clearer text edges.
Light sharpening improves OCR on slightly blurry images.
"""
try:
# Gaussian blur for unsharp mask
blurred = cv2.GaussianBlur(image, (0, 0), 2.0)
# Unsharp mask: original + alpha * (original - blurred)
sharpened = cv2.addWeighted(image, 1.5, blurred, -0.5, 0)
return sharpened
except Exception as e:
logger.warning(f"Sharpening failed: {e}")
return image
def _denoise(self, image: np.ndarray) -> np.ndarray:
"""
Apply light denoising optimized for text.
Uses bilateral filter to preserve edges while reducing noise.
"""
try:
# Bilateral filter preserves edges better than Gaussian
# Light denoising - don't want to blur text
return cv2.bilateralFilter(image, 5, 50, 50)
except Exception as e:
logger.warning(f"Denoising failed: {e}")
return image
def _adaptive_threshold_receipt(self, image: np.ndarray) -> np.ndarray:
"""
Apply adaptive thresholding optimized for receipt text.
Uses parameters tuned for:
- Variable font sizes (small print + headers)
- Faded thermal printing
- Uneven paper illumination
"""
try:
# Use Gaussian adaptive threshold
# Larger block size (31) handles uneven illumination
# Moderate C value (8) for faded receipts
binary = cv2.adaptiveThreshold(
image,
255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY,
blockSize=31,
C=8,
)
return binary
except Exception as e:
logger.warning(f"Adaptive threshold failed: {e}")
return image
def preprocess_for_low_quality(
self, image_bytes: bytes
) -> ReceiptPreprocessingResult:
"""
Apply aggressive preprocessing for very low quality receipts.
Use this when standard preprocessing fails to produce readable text.
"""
return self.preprocess(
image_bytes,
apply_contrast=True,
apply_deskew=True,
apply_denoise=True,
apply_threshold=True,
apply_sharpen=True,
)
# Singleton instance
receipt_preprocessor = ReceiptPreprocessor()

View File

@@ -1,10 +1,19 @@
"""OCR extraction endpoints."""
import logging
from typing import Optional
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from app.extractors.vin_extractor import vin_extractor
from app.models import BoundingBox, OcrResponse, VinAlternative, VinExtractionResponse
from app.extractors.receipt_extractor import receipt_extractor
from app.models import (
BoundingBox,
OcrResponse,
ReceiptExtractedField,
ReceiptExtractionResponse,
VinAlternative,
VinExtractionResponse,
)
from app.services import ocr_service
logger = logging.getLogger(__name__)
@@ -154,3 +163,97 @@ async def extract_vin(
processingTimeMs=result.processing_time_ms,
error=result.error,
)
@router.post("/receipt", response_model=ReceiptExtractionResponse)
async def extract_receipt(
file: UploadFile = File(..., description="Receipt image file"),
receipt_type: Optional[str] = Form(
default=None,
description="Receipt type hint: 'fuel' for specialized extraction",
),
) -> ReceiptExtractionResponse:
"""
Extract data from a receipt image using OCR.
Optimized for fuel receipts with pattern-based field extraction:
- HEIC conversion (if needed)
- Grayscale conversion
- High contrast enhancement (for thermal receipts)
- Adaptive thresholding
- Pattern matching for dates, amounts, fuel quantities
Supports HEIC, JPEG, PNG formats.
Processing time target: <3 seconds.
- **file**: Receipt image file (max 10MB)
- **receipt_type**: Optional hint ("fuel" for gas station receipts)
Returns:
- **receiptType**: Detected type ("fuel" or "unknown")
- **extractedFields**: Dictionary of extracted fields with confidence scores
- merchantName: Gas station or store name
- transactionDate: Date in YYYY-MM-DD format
- totalAmount: Total purchase amount
- fuelQuantity: Gallons/liters purchased (fuel receipts)
- pricePerUnit: Price per gallon/liter (fuel receipts)
- fuelGrade: Octane rating or fuel type (fuel receipts)
- **rawText**: Full OCR text
- **processingTimeMs**: Processing time in milliseconds
"""
# Validate file presence
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
# Read file content
content = await file.read()
file_size = len(content)
# Validate file size
if file_size > MAX_SYNC_SIZE:
raise HTTPException(
status_code=413,
detail=f"File too large. Max: {MAX_SYNC_SIZE // (1024*1024)}MB",
)
if file_size == 0:
raise HTTPException(status_code=400, detail="Empty file provided")
logger.info(
f"Receipt extraction: {file.filename}, "
f"size: {file_size} bytes, "
f"content_type: {file.content_type}, "
f"receipt_type: {receipt_type}"
)
# Perform receipt extraction
result = receipt_extractor.extract(
image_bytes=content,
content_type=file.content_type,
receipt_type=receipt_type,
)
if not result.success:
logger.warning(f"Receipt extraction failed for {file.filename}: {result.error}")
raise HTTPException(
status_code=422,
detail=result.error or "Failed to extract data from receipt image",
)
# Convert internal fields to API response format
extracted_fields = {
name: ReceiptExtractedField(
value=field.value,
confidence=field.confidence,
)
for name, field in result.extracted_fields.items()
}
return ReceiptExtractionResponse(
success=result.success,
receiptType=result.receipt_type,
extractedFields=extracted_fields,
rawText=result.raw_text,
processingTimeMs=result.processing_time_ms,
error=result.error,
)

View File

@@ -0,0 +1,198 @@
"""Tests for currency pattern matching."""
import pytest
from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher
class TestCurrencyPatternMatcher:
"""Test currency and amount extraction."""
def test_total_explicit(self) -> None:
"""Test 'TOTAL $XX.XX' pattern."""
text = "TOTAL $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
assert result.confidence > 0.9
assert result.label == "TOTAL"
def test_total_with_colon(self) -> None:
"""Test 'TOTAL: $XX.XX' pattern."""
text = "TOTAL: $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_total_without_dollar_sign(self) -> None:
"""Test 'TOTAL 45.67' pattern."""
text = "TOTAL 45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_amount_due(self) -> None:
"""Test 'AMOUNT DUE' pattern."""
text = "AMOUNT DUE: $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
assert result.label == "AMOUNT DUE"
def test_sale_pattern(self) -> None:
"""Test 'SALE $XX.XX' pattern."""
text = "SALE $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_grand_total(self) -> None:
"""Test 'GRAND TOTAL' pattern."""
text = "GRAND TOTAL $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
assert result.label == "GRAND TOTAL"
def test_total_sale(self) -> None:
"""Test 'TOTAL SALE' pattern."""
text = "TOTAL SALE: $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_balance_due(self) -> None:
"""Test 'BALANCE DUE' pattern."""
text = "BALANCE DUE $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_multiple_amounts_picks_total(self) -> None:
"""Test that labeled total is preferred over generic amounts."""
text = """
REGULAR 87
10.500 GAL @ $3.67
SUBTOTAL $38.54
TAX $0.00
TOTAL $38.54
"""
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 38.54
assert result.pattern_name == "total_explicit"
def test_all_amounts(self) -> None:
"""Test extracting all amounts from receipt."""
text = """
SUBTOTAL $35.00
TAX $3.54
TOTAL $38.54
"""
results = currency_matcher.extract_all_amounts(text)
# Should find TOTAL and possibly others
assert len(results) >= 1
assert any(r.value == 38.54 for r in results)
def test_comma_thousand_separator(self) -> None:
"""Test amounts with thousand separators."""
text = "TOTAL $1,234.56"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 1234.56
def test_reasonable_total_range(self) -> None:
"""Test that unreasonable totals are filtered."""
# Very small amount
text = "TOTAL $0.05"
result = currency_matcher.extract_total(text)
assert result is None # Too small for fuel receipt
# Reasonable amount
text = "TOTAL $45.67"
result = currency_matcher.extract_total(text)
assert result is not None
def test_receipt_context_extraction(self) -> None:
"""Test extraction from realistic receipt text."""
text = """
SHELL
123 MAIN ST
DATE: 01/15/2024
UNLEADED 87
10.500 GAL
@ $3.679/GAL
FUEL TOTAL $38.63
TAX $0.00
TOTAL $38.63
DEBIT CARD
************1234
"""
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 38.63
def test_no_total_returns_largest(self) -> None:
"""Test fallback to largest amount when no labeled total."""
text = """
$10.50
$5.00
$45.67
"""
result = currency_matcher.extract_total(text)
# Should infer largest reasonable amount as total
assert result is not None
assert result.value == 45.67
assert result.confidence < 0.7 # Lower confidence for inferred
def test_no_amounts_returns_none(self) -> None:
"""Test that text without amounts returns None."""
text = "SHELL STATION\nPUMP 5"
result = currency_matcher.extract_total(text)
assert result is None
class TestEdgeCases:
"""Test edge cases in currency parsing."""
def test_european_format(self) -> None:
"""Test European format (comma as decimal)."""
# European: 45,67 means 45.67
text = "TOTAL 45,67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_spaces_in_amount(self) -> None:
"""Test handling of spaces around amounts."""
text = "TOTAL $ 45.67"
result = currency_matcher.extract_total(text)
assert result is not None
assert result.value == 45.67
def test_case_insensitive(self) -> None:
"""Test case insensitive matching."""
for label in ["TOTAL", "Total", "total"]:
text = f"{label} $45.67"
result = currency_matcher.extract_total(text)
assert result is not None, f"Failed for {label}"
assert result.value == 45.67

View File

@@ -0,0 +1,163 @@
"""Tests for date pattern matching."""
import pytest
from app.patterns.date_patterns import DatePatternMatcher, date_matcher
class TestDatePatternMatcher:
"""Test date pattern extraction."""
def test_mm_dd_yyyy_slash(self) -> None:
"""Test MM/DD/YYYY format."""
text = "DATE: 01/15/2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
assert result.confidence > 0.9
def test_mm_dd_yy_slash(self) -> None:
"""Test MM/DD/YY format with 2-digit year."""
text = "01/15/24"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_mm_dd_yyyy_dash(self) -> None:
"""Test MM-DD-YYYY format."""
text = "01-15-2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_iso_format(self) -> None:
"""Test ISO YYYY-MM-DD format."""
text = "2024-01-15"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
assert result.confidence > 0.95
def test_month_name_format(self) -> None:
"""Test 'Jan 15, 2024' format."""
text = "Jan 15, 2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_month_name_no_comma(self) -> None:
"""Test 'Jan 15 2024' format without comma."""
text = "Jan 15 2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_day_month_year_format(self) -> None:
"""Test '15 Jan 2024' format."""
text = "15 Jan 2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_full_month_name(self) -> None:
"""Test full month name like 'January'."""
text = "January 15, 2024"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_multiple_dates_returns_best(self) -> None:
"""Test that multiple dates returns highest confidence."""
text = "Date: 01/15/2024\nExpires: 01/15/2025"
results = date_matcher.extract_dates(text)
assert len(results) == 2
# Both should be valid
assert all(r.confidence > 0.5 for r in results)
def test_invalid_date_rejected(self) -> None:
"""Test that invalid dates are rejected."""
text = "13/45/2024" # Invalid month/day
result = date_matcher.extract_best_date(text)
assert result is None
def test_receipt_context_text(self) -> None:
"""Test date extraction from realistic receipt text."""
text = """
SHELL STATION
123 MAIN ST
DATE: 01/15/2024
TIME: 14:32
PUMP #5
REGULAR 87
10.500 GAL
TOTAL $38.50
"""
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-01-15"
def test_no_date_returns_none(self) -> None:
"""Test that text without dates returns None."""
text = "SHELL STATION\nTOTAL $38.50"
result = date_matcher.extract_best_date(text)
assert result is None
def test_confidence_boost_near_keyword(self) -> None:
"""Test confidence boost when date is near DATE keyword."""
text_with_keyword = "DATE: 01/15/2024"
text_without = "01/15/2024"
result_with = date_matcher.extract_best_date(text_with_keyword)
result_without = date_matcher.extract_best_date(text_without)
assert result_with is not None
assert result_without is not None
# Keyword proximity should boost confidence
assert result_with.confidence >= result_without.confidence
class TestEdgeCases:
"""Test edge cases in date parsing."""
def test_year_2000(self) -> None:
"""Test 2-digit year 00 is parsed as 2000."""
text = "01/15/00"
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2000-01-15"
def test_leap_year_date(self) -> None:
"""Test Feb 29 on leap year."""
text = "02/29/2024" # 2024 is a leap year
result = date_matcher.extract_best_date(text)
assert result is not None
assert result.value == "2024-02-29"
def test_leap_year_invalid(self) -> None:
"""Test Feb 29 on non-leap year is rejected."""
text = "02/29/2023" # 2023 is not a leap year
result = date_matcher.extract_best_date(text)
assert result is None
def test_september_abbrev(self) -> None:
"""Test September abbreviation (Sept vs Sep)."""
for abbrev in ["Sep", "Sept", "September"]:
text = f"{abbrev} 15, 2024"
result = date_matcher.extract_best_date(text)
assert result is not None, f"Failed for {abbrev}"
assert result.value == "2024-09-15"

View File

@@ -0,0 +1,327 @@
"""Tests for fuel-specific pattern matching."""
import pytest
from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher
class TestFuelQuantityExtraction:
"""Test fuel quantity (gallons/liters) extraction."""
def test_gallons_suffix(self) -> None:
"""Test 'XX.XXX GAL' pattern."""
text = "10.500 GAL"
result = fuel_matcher.extract_gallons(text)
assert result is not None
assert result.value == 10.5
assert result.unit == "GAL"
assert result.confidence > 0.9
def test_gallons_full_word(self) -> None:
"""Test 'XX.XXX GALLONS' pattern."""
text = "10.500 GALLONS"
result = fuel_matcher.extract_gallons(text)
assert result is not None
assert result.value == 10.5
def test_gallons_prefix(self) -> None:
"""Test 'GALLONS: XX.XXX' pattern."""
text = "GALLONS: 10.500"
result = fuel_matcher.extract_gallons(text)
assert result is not None
assert result.value == 10.5
def test_gal_prefix(self) -> None:
"""Test 'GAL: XX.XXX' pattern."""
text = "GAL: 10.500"
result = fuel_matcher.extract_gallons(text)
assert result is not None
assert result.value == 10.5
def test_volume_label(self) -> None:
"""Test 'VOLUME XX.XXX' pattern."""
text = "VOLUME: 10.500"
result = fuel_matcher.extract_gallons(text)
assert result is not None
assert result.value == 10.5
def test_liters_suffix(self) -> None:
"""Test 'XX.XX L' pattern."""
text = "40.5 L"
result = fuel_matcher.extract_liters(text)
assert result is not None
assert result.value == 40.5
assert result.unit == "L"
def test_liters_full_word(self) -> None:
"""Test 'XX.XX LITERS' pattern."""
text = "40.5 LITERS"
result = fuel_matcher.extract_liters(text)
assert result is not None
assert result.value == 40.5
def test_quantity_prefers_gallons(self) -> None:
"""Test extract_quantity prefers gallons for US receipts."""
text = "10.500 GAL"
result = fuel_matcher.extract_quantity(text)
assert result is not None
assert result.unit == "GAL"
def test_reasonable_quantity_filter(self) -> None:
"""Test unreasonable quantities are filtered."""
# Too small
text = "0.001 GAL"
result = fuel_matcher.extract_gallons(text)
assert result is None
# Too large
text = "100.0 GAL"
result = fuel_matcher.extract_gallons(text)
assert result is None
class TestFuelPriceExtraction:
"""Test price per unit extraction."""
def test_price_per_gal_dollar_sign(self) -> None:
"""Test '$X.XXX/GAL' pattern."""
text = "$3.679/GAL"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
assert result.unit == "GAL"
assert result.confidence > 0.95
def test_price_per_gal_no_dollar(self) -> None:
"""Test 'X.XXX/GAL' pattern."""
text = "3.679/GAL"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
def test_labeled_price_gal(self) -> None:
"""Test 'PRICE/GAL $X.XXX' pattern."""
text = "PRICE/GAL $3.679"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
def test_unit_price(self) -> None:
"""Test 'UNIT PRICE $X.XXX' pattern."""
text = "UNIT PRICE: $3.679"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
def test_at_price(self) -> None:
"""Test '@ $X.XXX' pattern."""
text = "10.500 GAL @ $3.679"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
def test_ppg_pattern(self) -> None:
"""Test 'PPG $X.XXX' pattern."""
text = "PPG: $3.679"
result = fuel_matcher.extract_price_per_unit(text)
assert result is not None
assert result.value == 3.679
def test_reasonable_price_filter(self) -> None:
"""Test unreasonable prices are filtered."""
# Too low
text = "$0.50/GAL"
result = fuel_matcher.extract_price_per_unit(text)
assert result is None
# Too high
text = "$15.00/GAL"
result = fuel_matcher.extract_price_per_unit(text)
assert result is None
class TestFuelGradeExtraction:
"""Test fuel grade/octane extraction."""
def test_regular_87(self) -> None:
"""Test 'REGULAR 87' pattern."""
text = "REGULAR 87"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "87"
assert "Regular" in result.display_name
def test_reg_87(self) -> None:
"""Test 'REG 87' pattern."""
text = "REG 87"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "87"
def test_unleaded_87(self) -> None:
"""Test 'UNLEADED 87' pattern."""
text = "UNLEADED 87"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "87"
def test_plus_89(self) -> None:
"""Test 'PLUS 89' pattern."""
text = "PLUS 89"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "89"
assert "Plus" in result.display_name
def test_midgrade_89(self) -> None:
"""Test 'MIDGRADE 89' pattern."""
text = "MIDGRADE 89"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "89"
def test_premium_93(self) -> None:
"""Test 'PREMIUM 93' pattern."""
text = "PREMIUM 93"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "93"
assert "Premium" in result.display_name
def test_super_93(self) -> None:
"""Test 'SUPER 93' pattern."""
text = "SUPER 93"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "93"
def test_diesel(self) -> None:
"""Test 'DIESEL' pattern."""
text = "DIESEL #2"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "DIESEL"
assert "Diesel" in result.display_name
def test_e85(self) -> None:
"""Test 'E85' ethanol pattern."""
text = "E85"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "E85"
def test_octane_only(self) -> None:
"""Test standalone octane number."""
text = "87 OCTANE"
result = fuel_matcher.extract_grade(text)
assert result is not None
assert result.value == "87"
class TestMerchantExtraction:
"""Test gas station name extraction."""
def test_shell_station(self) -> None:
"""Test Shell station detection."""
text = "SHELL\n123 MAIN ST"
result = fuel_matcher.extract_merchant_name(text)
assert result is not None
merchant, confidence = result
assert "SHELL" in merchant.upper()
assert confidence > 0.8
def test_chevron_station(self) -> None:
"""Test Chevron station detection."""
text = "CHEVRON #12345\n456 OAK AVE"
result = fuel_matcher.extract_merchant_name(text)
assert result is not None
merchant, confidence = result
assert "CHEVRON" in merchant.upper()
def test_costco_gas(self) -> None:
"""Test Costco gas detection."""
text = "COSTCO GASOLINE\n789 WAREHOUSE BLVD"
result = fuel_matcher.extract_merchant_name(text)
assert result is not None
merchant, confidence = result
assert "COSTCO" in merchant.upper()
def test_unknown_station_fallback(self) -> None:
"""Test fallback to first line for unknown stations."""
text = "JOE'S GAS\n123 MAIN ST"
result = fuel_matcher.extract_merchant_name(text)
assert result is not None
merchant, confidence = result
assert "JOE'S GAS" in merchant
assert confidence < 0.7 # Lower confidence for unknown
class TestReceiptContextExtraction:
"""Test extraction from realistic receipt text."""
def test_full_receipt_extraction(self) -> None:
"""Test all fields from complete receipt text."""
text = """
SHELL
123 MAIN STREET
ANYTOWN, USA 12345
DATE: 01/15/2024
TIME: 14:32
PUMP #5
REGULAR 87
10.500 GAL @ $3.679/GAL
FUEL TOTAL $38.63
TAX $0.00
TOTAL $38.63
DEBIT CARD
************1234
APPROVED
"""
# Test all extractors on this text
quantity = fuel_matcher.extract_quantity(text)
assert quantity is not None
assert quantity.value == 10.5
price = fuel_matcher.extract_price_per_unit(text)
assert price is not None
assert price.value == 3.679
grade = fuel_matcher.extract_grade(text)
assert grade is not None
assert grade.value == "87"
merchant = fuel_matcher.extract_merchant_name(text)
assert merchant is not None
assert "SHELL" in merchant[0].upper()

View File

@@ -0,0 +1,339 @@
"""Tests for receipt extraction pipeline."""
import io
import pytest
from unittest.mock import MagicMock, patch
from app.extractors.receipt_extractor import (
ReceiptExtractor,
ReceiptExtractionResult,
receipt_extractor,
)
from app.extractors.fuel_receipt import (
FuelReceiptExtractor,
FuelReceiptValidation,
fuel_receipt_extractor,
)
class TestReceiptExtractor:
"""Test the receipt extraction pipeline."""
def test_detect_receipt_type_fuel(self) -> None:
"""Test fuel receipt type detection."""
text = """
SHELL STATION
REGULAR 87
10.500 GAL
TOTAL $38.50
"""
extractor = ReceiptExtractor()
receipt_type = extractor._detect_receipt_type(text)
assert receipt_type == "fuel"
def test_detect_receipt_type_unknown(self) -> None:
"""Test unknown receipt type detection."""
text = """
WALMART
GROCERIES
MILK $3.99
BREAD $2.50
TOTAL $6.49
"""
extractor = ReceiptExtractor()
receipt_type = extractor._detect_receipt_type(text)
assert receipt_type == "unknown"
def test_extract_fuel_fields(self) -> None:
"""Test fuel field extraction from OCR text."""
text = """
SHELL
123 MAIN ST
DATE: 01/15/2024
REGULAR 87
10.500 GAL @ $3.679
TOTAL $38.63
"""
extractor = ReceiptExtractor()
fields = extractor._extract_fuel_fields(text)
assert "merchantName" in fields
assert "transactionDate" in fields
assert "totalAmount" in fields
assert "fuelQuantity" in fields
assert "pricePerUnit" in fields
assert "fuelGrade" in fields
assert fields["totalAmount"].value == 38.63
assert fields["fuelQuantity"].value == 10.5
assert fields["fuelGrade"].value == "87"
def test_extract_generic_fields(self) -> None:
"""Test generic field extraction."""
text = """
WALMART
01/15/2024
TOTAL $25.99
"""
extractor = ReceiptExtractor()
fields = extractor._extract_generic_fields(text)
assert "transactionDate" in fields
assert "totalAmount" in fields
assert fields["totalAmount"].value == 25.99
def test_calculated_price_per_unit(self) -> None:
"""Test price per unit calculation when not explicitly stated."""
text = """
SHELL
DATE: 01/15/2024
10.000 GAL
TOTAL $35.00
"""
extractor = ReceiptExtractor()
fields = extractor._extract_fuel_fields(text)
assert "pricePerUnit" in fields
# 35.00 / 10.000 = 3.50
assert abs(fields["pricePerUnit"].value - 3.50) < 0.01
# Calculated values should have lower confidence
assert fields["pricePerUnit"].confidence < 0.9
def test_validate_valid_data(self) -> None:
"""Test validation of valid receipt data."""
extractor = ReceiptExtractor()
data = {"totalAmount": 38.63, "transactionDate": "2024-01-15"}
assert extractor.validate(data) is True
def test_validate_invalid_data(self) -> None:
"""Test validation of invalid receipt data."""
extractor = ReceiptExtractor()
# Empty dict
assert extractor.validate({}) is False
# Not a dict
assert extractor.validate("invalid") is False
def test_unsupported_file_type(self) -> None:
"""Test handling of unsupported file types."""
extractor = ReceiptExtractor()
with patch.object(extractor, "_detect_mime_type", return_value="application/pdf"):
result = extractor.extract(b"fake pdf content")
assert result.success is False
assert "Unsupported file type" in result.error
class TestFuelReceiptExtractor:
"""Test fuel receipt specialized extractor."""
def test_validation_success(self) -> None:
"""Test validation passes for consistent data."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=38.63, confidence=0.95),
"fuelQuantity": ExtractedField(value=10.5, confidence=0.90),
"pricePerUnit": ExtractedField(value=3.679, confidence=0.92),
"fuelGrade": ExtractedField(value="87", confidence=0.88),
}
validation = extractor._validate_fuel_receipt(fields)
assert validation.is_valid is True
assert len(validation.issues) == 0
assert validation.confidence_score == 1.0
def test_validation_math_mismatch(self) -> None:
"""Test validation catches total != quantity * price."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=50.00, confidence=0.95),
"fuelQuantity": ExtractedField(value=10.0, confidence=0.90),
"pricePerUnit": ExtractedField(value=3.00, confidence=0.92), # Should be $30
}
validation = extractor._validate_fuel_receipt(fields)
assert validation.is_valid is False
assert any("doesn't match" in issue for issue in validation.issues)
def test_validation_quantity_too_small(self) -> None:
"""Test validation catches too-small quantity."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=1.00, confidence=0.95),
"fuelQuantity": ExtractedField(value=0.1, confidence=0.90),
}
validation = extractor._validate_fuel_receipt(fields)
assert any("too small" in issue for issue in validation.issues)
def test_validation_quantity_too_large(self) -> None:
"""Test validation warns on very large quantity."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=150.00, confidence=0.95),
"fuelQuantity": ExtractedField(value=45.0, confidence=0.90),
}
validation = extractor._validate_fuel_receipt(fields)
assert any("unusually large" in issue for issue in validation.issues)
def test_validation_price_too_low(self) -> None:
"""Test validation catches too-low price."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=10.00, confidence=0.95),
"pricePerUnit": ExtractedField(value=1.00, confidence=0.90),
}
validation = extractor._validate_fuel_receipt(fields)
assert any("too low" in issue for issue in validation.issues)
def test_validation_unknown_grade(self) -> None:
"""Test validation catches unknown fuel grade."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=38.00, confidence=0.95),
"fuelGrade": ExtractedField(value="95", confidence=0.70), # Not valid US grade
}
validation = extractor._validate_fuel_receipt(fields)
assert any("Unknown fuel grade" in issue for issue in validation.issues)
def test_confidence_adjustment_boost(self) -> None:
"""Test confidence boost when validation passes."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=38.63, confidence=0.80),
}
validation = FuelReceiptValidation(
is_valid=True, issues=[], confidence_score=1.0
)
adjusted = extractor._adjust_confidences(fields, validation)
# Should be boosted by 1.1
assert adjusted["totalAmount"].confidence == min(1.0, 0.80 * 1.1)
def test_confidence_adjustment_reduce(self) -> None:
"""Test confidence reduction when validation fails."""
from app.extractors.receipt_extractor import ExtractedField
extractor = FuelReceiptExtractor()
fields = {
"totalAmount": ExtractedField(value=50.00, confidence=0.90),
}
validation = FuelReceiptValidation(
is_valid=False, issues=["Math mismatch"], confidence_score=0.7
)
adjusted = extractor._adjust_confidences(fields, validation)
# Should be reduced by 0.7
assert adjusted["totalAmount"].confidence == 0.90 * 0.7
class TestReceiptPreprocessing:
"""Test receipt preprocessing integration."""
def test_preprocessing_result_structure(self) -> None:
"""Test preprocessing returns expected structure."""
from app.preprocessors.receipt_preprocessor import (
receipt_preprocessor,
ReceiptPreprocessingResult,
)
# Create a simple test image (1x1 white pixel PNG)
from PIL import Image
img = Image.new("RGB", (100, 100), color="white")
buffer = io.BytesIO()
img.save(buffer, format="PNG")
image_bytes = buffer.getvalue()
result = receipt_preprocessor.preprocess(image_bytes)
assert isinstance(result, ReceiptPreprocessingResult)
assert len(result.image_bytes) > 0
assert "loaded" in result.preprocessing_applied
assert "grayscale" in result.preprocessing_applied
assert result.original_width == 100
assert result.original_height == 100
def test_preprocessing_steps_applied(self) -> None:
"""Test all preprocessing steps are applied."""
from app.preprocessors.receipt_preprocessor import receipt_preprocessor
from PIL import Image
img = Image.new("RGB", (100, 100), color="white")
buffer = io.BytesIO()
img.save(buffer, format="PNG")
image_bytes = buffer.getvalue()
result = receipt_preprocessor.preprocess(
image_bytes,
apply_contrast=True,
apply_deskew=True,
apply_denoise=True,
apply_threshold=True,
apply_sharpen=True,
)
# Check that expected steps are in the applied list
assert "contrast" in result.preprocessing_applied
assert "denoise" in result.preprocessing_applied
assert "threshold" in result.preprocessing_applied
class TestEndpointIntegration:
"""Test receipt extraction endpoint integration."""
@pytest.fixture
def test_client(self):
"""Create test client for FastAPI app."""
from fastapi.testclient import TestClient
from app.main import app
return TestClient(app)
def test_receipt_endpoint_exists(self, test_client) -> None:
"""Test that /extract/receipt endpoint exists."""
# Should get 422 (no file) not 404 (not found)
response = test_client.post("/extract/receipt")
assert response.status_code == 422 # Unprocessable Entity (missing file)
def test_receipt_endpoint_no_file(self, test_client) -> None:
"""Test endpoint returns error when no file provided."""
response = test_client.post("/extract/receipt")
assert response.status_code == 422
def test_receipt_endpoint_empty_file(self, test_client) -> None:
"""Test endpoint returns error for empty file."""
response = test_client.post(
"/extract/receipt",
files={"file": ("receipt.jpg", b"", "image/jpeg")},
)
assert response.status_code == 400
assert "Empty file" in response.json()["detail"]