diff --git a/ocr/app/extractors/__init__.py b/ocr/app/extractors/__init__.py index 9ae8f51..d0468f1 100644 --- a/ocr/app/extractors/__init__.py +++ b/ocr/app/extractors/__init__.py @@ -1,10 +1,23 @@ """Extractors package for domain-specific OCR extraction.""" from app.extractors.base import BaseExtractor, ExtractionResult from app.extractors.vin_extractor import VinExtractor, vin_extractor +from app.extractors.receipt_extractor import ( + ReceiptExtractor, + receipt_extractor, + ReceiptExtractionResult, + ExtractedField, +) +from app.extractors.fuel_receipt import FuelReceiptExtractor, fuel_receipt_extractor __all__ = [ "BaseExtractor", "ExtractionResult", "VinExtractor", "vin_extractor", + "ReceiptExtractor", + "receipt_extractor", + "ReceiptExtractionResult", + "ExtractedField", + "FuelReceiptExtractor", + "fuel_receipt_extractor", ] diff --git a/ocr/app/extractors/fuel_receipt.py b/ocr/app/extractors/fuel_receipt.py new file mode 100644 index 0000000..cb4a62f --- /dev/null +++ b/ocr/app/extractors/fuel_receipt.py @@ -0,0 +1,193 @@ +"""Fuel receipt specialization with validation and cross-checking.""" +import logging +from dataclasses import dataclass +from typing import Optional + +from app.extractors.receipt_extractor import ( + ExtractedField, + ReceiptExtractionResult, + receipt_extractor, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class FuelReceiptValidation: + """Validation result for fuel receipt extraction.""" + + is_valid: bool + issues: list[str] + confidence_score: float + + +class FuelReceiptExtractor: + """Specialized fuel receipt extractor with cross-validation. + + Provides additional validation and confidence scoring specific + to fuel receipts by cross-checking extracted values. + """ + + # Expected fields for a complete fuel receipt + REQUIRED_FIELDS = ["totalAmount"] + OPTIONAL_FIELDS = [ + "merchantName", + "transactionDate", + "fuelQuantity", + "pricePerUnit", + "fuelGrade", + ] + + def extract( + self, + image_bytes: bytes, + content_type: Optional[str] = None, + ) -> ReceiptExtractionResult: + """ + Extract fuel receipt data with validation. + + Args: + image_bytes: Raw image bytes + content_type: MIME type + + Returns: + ReceiptExtractionResult with fuel-specific extraction + """ + # Use base receipt extractor with fuel hint + result = receipt_extractor.extract( + image_bytes=image_bytes, + content_type=content_type, + receipt_type="fuel", + ) + + if not result.success: + return result + + # Validate and cross-check fuel fields + validation = self._validate_fuel_receipt(result.extracted_fields) + + if validation.issues: + logger.warning( + f"Fuel receipt validation issues: {validation.issues}" + ) + + # Update overall confidence based on validation + result.extracted_fields = self._adjust_confidences( + result.extracted_fields, validation + ) + + return result + + def _validate_fuel_receipt( + self, fields: dict[str, ExtractedField] + ) -> FuelReceiptValidation: + """ + Validate extracted fuel receipt fields. + + Cross-checks: + - total = quantity * price per unit (within tolerance) + - quantity is reasonable for a single fill-up + - price per unit is within market range + + Args: + fields: Extracted fields + + Returns: + FuelReceiptValidation with issues and confidence + """ + issues = [] + confidence_score = 1.0 + + # Check required fields + for field_name in self.REQUIRED_FIELDS: + if field_name not in fields: + issues.append(f"Missing required field: {field_name}") + confidence_score *= 0.5 + + # Cross-validate total = quantity * price + if all( + f in fields for f in ["totalAmount", "fuelQuantity", "pricePerUnit"] + ): + total = fields["totalAmount"].value + quantity = fields["fuelQuantity"].value + price = fields["pricePerUnit"].value + + calculated_total = quantity * price + tolerance = 0.10 # Allow 10% tolerance for rounding + + if abs(total - calculated_total) > total * tolerance: + issues.append( + f"Total ({total}) doesn't match quantity ({quantity}) * " + f"price ({price}) = {calculated_total:.2f}" + ) + confidence_score *= 0.7 + + # Validate quantity is reasonable + if "fuelQuantity" in fields: + quantity = fields["fuelQuantity"].value + if quantity < 0.5: + issues.append(f"Fuel quantity too small: {quantity}") + confidence_score *= 0.6 + elif quantity > 40: # 40 gallons is very large tank + issues.append(f"Fuel quantity unusually large: {quantity}") + confidence_score *= 0.8 + + # Validate price is reasonable (current US market range) + if "pricePerUnit" in fields: + price = fields["pricePerUnit"].value + if price < 1.50: + issues.append(f"Price per unit too low: ${price}") + confidence_score *= 0.7 + elif price > 7.00: + issues.append(f"Price per unit unusually high: ${price}") + confidence_score *= 0.8 + + # Validate fuel grade + if "fuelGrade" in fields: + grade = fields["fuelGrade"].value + valid_grades = ["87", "89", "91", "93", "DIESEL", "E85"] + if grade not in valid_grades: + issues.append(f"Unknown fuel grade: {grade}") + confidence_score *= 0.9 + + is_valid = len(issues) == 0 + return FuelReceiptValidation( + is_valid=is_valid, + issues=issues, + confidence_score=confidence_score, + ) + + def _adjust_confidences( + self, + fields: dict[str, ExtractedField], + validation: FuelReceiptValidation, + ) -> dict[str, ExtractedField]: + """ + Adjust field confidences based on validation. + + Args: + fields: Extracted fields + validation: Validation result + + Returns: + Fields with adjusted confidences + """ + if validation.is_valid: + # Boost confidences when cross-validation passes + boost = 1.1 + else: + # Reduce confidences when there are issues + boost = validation.confidence_score + + adjusted = {} + for name, field in fields.items(): + adjusted[name] = ExtractedField( + value=field.value, + confidence=min(1.0, field.confidence * boost), + ) + + return adjusted + + +# Singleton instance +fuel_receipt_extractor = FuelReceiptExtractor() diff --git a/ocr/app/extractors/receipt_extractor.py b/ocr/app/extractors/receipt_extractor.py new file mode 100644 index 0000000..6134988 --- /dev/null +++ b/ocr/app/extractors/receipt_extractor.py @@ -0,0 +1,345 @@ +"""Receipt-specific OCR extractor with field extraction.""" +import io +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import magic +import pytesseract +from PIL import Image +from pillow_heif import register_heif_opener + +from app.config import settings +from app.extractors.base import BaseExtractor +from app.preprocessors.receipt_preprocessor import receipt_preprocessor +from app.patterns import currency_matcher, date_matcher, fuel_matcher + +# Register HEIF/HEIC opener +register_heif_opener() + +logger = logging.getLogger(__name__) + + +@dataclass +class ExtractedField: + """A single extracted field with confidence.""" + + value: Any + confidence: float + + +@dataclass +class ReceiptExtractionResult: + """Result of receipt extraction.""" + + success: bool + receipt_type: str = "unknown" + extracted_fields: dict[str, ExtractedField] = field(default_factory=dict) + raw_text: str = "" + processing_time_ms: int = 0 + error: Optional[str] = None + + +class ReceiptExtractor(BaseExtractor): + """Receipt-specific OCR extractor for fuel and general receipts.""" + + # Supported MIME types + SUPPORTED_TYPES = { + "image/jpeg", + "image/png", + "image/heic", + "image/heif", + } + + def __init__(self) -> None: + """Initialize receipt extractor.""" + pytesseract.pytesseract.tesseract_cmd = settings.tesseract_cmd + + def extract( + self, + image_bytes: bytes, + content_type: Optional[str] = None, + receipt_type: Optional[str] = None, + ) -> ReceiptExtractionResult: + """ + Extract data from a receipt image. + + Args: + image_bytes: Raw image bytes (HEIC, JPEG, PNG) + content_type: MIME type (auto-detected if not provided) + receipt_type: Hint for receipt type ("fuel" for specialized extraction) + + Returns: + ReceiptExtractionResult with extracted fields + """ + start_time = time.time() + + # Detect content type if not provided + if not content_type: + content_type = self._detect_mime_type(image_bytes) + + # Validate content type + if content_type not in self.SUPPORTED_TYPES: + return ReceiptExtractionResult( + success=False, + error=f"Unsupported file type: {content_type}", + processing_time_ms=int((time.time() - start_time) * 1000), + ) + + try: + # Apply receipt-optimized preprocessing + preprocessing_result = receipt_preprocessor.preprocess(image_bytes) + preprocessed_bytes = preprocessing_result.image_bytes + + # Perform OCR + raw_text = self._perform_ocr(preprocessed_bytes) + + if not raw_text.strip(): + # Try with less aggressive preprocessing + preprocessing_result = receipt_preprocessor.preprocess( + image_bytes, + apply_threshold=False, + ) + preprocessed_bytes = preprocessing_result.image_bytes + raw_text = self._perform_ocr(preprocessed_bytes) + + if not raw_text.strip(): + return ReceiptExtractionResult( + success=False, + error="No text found in image", + processing_time_ms=int((time.time() - start_time) * 1000), + ) + + # Detect receipt type if not specified + detected_type = receipt_type or self._detect_receipt_type(raw_text) + + # Extract fields based on receipt type + if detected_type == "fuel": + extracted_fields = self._extract_fuel_fields(raw_text) + else: + extracted_fields = self._extract_generic_fields(raw_text) + + processing_time_ms = int((time.time() - start_time) * 1000) + + logger.info( + f"Receipt extraction: type={detected_type}, " + f"fields={len(extracted_fields)}, " + f"time={processing_time_ms}ms" + ) + + return ReceiptExtractionResult( + success=True, + receipt_type=detected_type, + extracted_fields=extracted_fields, + raw_text=raw_text, + processing_time_ms=processing_time_ms, + ) + + except Exception as e: + logger.error(f"Receipt extraction failed: {e}", exc_info=True) + return ReceiptExtractionResult( + success=False, + error=str(e), + processing_time_ms=int((time.time() - start_time) * 1000), + ) + + def _detect_mime_type(self, file_bytes: bytes) -> str: + """Detect MIME type using python-magic.""" + mime = magic.Magic(mime=True) + detected = mime.from_buffer(file_bytes) + return detected or "application/octet-stream" + + def _perform_ocr(self, image_bytes: bytes, psm: int = 6) -> str: + """ + Perform OCR on preprocessed image. + + Args: + image_bytes: Preprocessed image bytes + psm: Tesseract page segmentation mode + 4 = Assume single column of text + 6 = Uniform block of text + + Returns: + Raw OCR text + """ + image = Image.open(io.BytesIO(image_bytes)) + + # Configure Tesseract for receipt OCR + # PSM 4 works well for columnar receipt text + config = f"--psm {psm}" + + return pytesseract.image_to_string(image, config=config) + + def _detect_receipt_type(self, text: str) -> str: + """ + Detect receipt type based on content. + + Args: + text: OCR text + + Returns: + Receipt type: "fuel", "retail", or "unknown" + """ + text_upper = text.upper() + + # Fuel receipt indicators + fuel_keywords = [ + "GALLON", "GAL", "FUEL", "GAS", "DIESEL", "UNLEADED", + "REGULAR", "PREMIUM", "OCTANE", "PPG", "PUMP", + ] + + fuel_score = sum(1 for kw in fuel_keywords if kw in text_upper) + + # Check for known gas stations + if fuel_matcher.extract_merchant_name(text): + merchant, _ = fuel_matcher.extract_merchant_name(text) + if any( + station in merchant.upper() + for station in fuel_matcher.STATION_NAMES + ): + fuel_score += 3 + + if fuel_score >= 2: + return "fuel" + + return "unknown" + + def _extract_fuel_fields(self, text: str) -> dict[str, ExtractedField]: + """ + Extract fuel-specific fields from receipt text. + + Args: + text: OCR text + + Returns: + Dictionary of extracted fields + """ + fields: dict[str, ExtractedField] = {} + + # Extract merchant name + merchant_result = fuel_matcher.extract_merchant_name(text) + if merchant_result: + merchant_name, confidence = merchant_result + fields["merchantName"] = ExtractedField( + value=merchant_name, + confidence=confidence, + ) + + # Extract transaction date + date_match = date_matcher.extract_best_date(text) + if date_match: + fields["transactionDate"] = ExtractedField( + value=date_match.value, + confidence=date_match.confidence, + ) + + # Extract total amount + total_match = currency_matcher.extract_total(text) + if total_match: + fields["totalAmount"] = ExtractedField( + value=total_match.value, + confidence=total_match.confidence, + ) + + # Extract fuel quantity + quantity_match = fuel_matcher.extract_quantity(text) + if quantity_match: + fields["fuelQuantity"] = ExtractedField( + value=quantity_match.value, + confidence=quantity_match.confidence, + ) + + # Extract price per unit + price_match = fuel_matcher.extract_price_per_unit(text) + if price_match: + fields["pricePerUnit"] = ExtractedField( + value=price_match.value, + confidence=price_match.confidence, + ) + + # Extract fuel grade + grade_match = fuel_matcher.extract_grade(text) + if grade_match: + fields["fuelGrade"] = ExtractedField( + value=grade_match.value, + confidence=grade_match.confidence, + ) + + # Calculate derived values if we have enough data + if "totalAmount" in fields and "fuelQuantity" in fields: + if "pricePerUnit" not in fields: + # Calculate price per unit from total and quantity + calculated_price = ( + fields["totalAmount"].value / fields["fuelQuantity"].value + ) + # Only use if reasonable + if 1.0 <= calculated_price <= 10.0: + fields["pricePerUnit"] = ExtractedField( + value=round(calculated_price, 3), + confidence=min( + fields["totalAmount"].confidence, + fields["fuelQuantity"].confidence, + ) + * 0.8, # Lower confidence for calculated value + ) + + return fields + + def _extract_generic_fields(self, text: str) -> dict[str, ExtractedField]: + """ + Extract generic fields from receipt text. + + Args: + text: OCR text + + Returns: + Dictionary of extracted fields + """ + fields: dict[str, ExtractedField] = {} + + # Extract date + date_match = date_matcher.extract_best_date(text) + if date_match: + fields["transactionDate"] = ExtractedField( + value=date_match.value, + confidence=date_match.confidence, + ) + + # Extract total amount + total_match = currency_matcher.extract_total(text) + if total_match: + fields["totalAmount"] = ExtractedField( + value=total_match.value, + confidence=total_match.confidence, + ) + + # Try to get merchant from first line + lines = [l.strip() for l in text.split("\n") if l.strip()] + if lines: + fields["merchantName"] = ExtractedField( + value=lines[0][:50], + confidence=0.40, + ) + + return fields + + def validate(self, data: Any) -> bool: + """ + Validate extracted receipt data. + + Args: + data: Extracted data to validate + + Returns: + True if data has minimum required fields + """ + if not isinstance(data, dict): + return False + + # Minimum: must have at least total amount or date + return "totalAmount" in data or "transactionDate" in data + + +# Singleton instance +receipt_extractor = ReceiptExtractor() diff --git a/ocr/app/models/__init__.py b/ocr/app/models/__init__.py index 9063882..eecbf23 100644 --- a/ocr/app/models/__init__.py +++ b/ocr/app/models/__init__.py @@ -7,6 +7,8 @@ from .schemas import ( JobStatus, JobSubmitRequest, OcrResponse, + ReceiptExtractedField, + ReceiptExtractionResponse, VinAlternative, VinExtractionResponse, ) @@ -19,6 +21,8 @@ __all__ = [ "JobStatus", "JobSubmitRequest", "OcrResponse", + "ReceiptExtractedField", + "ReceiptExtractionResponse", "VinAlternative", "VinExtractionResponse", ] diff --git a/ocr/app/models/schemas.py b/ocr/app/models/schemas.py index ff34d94..d1c9536 100644 --- a/ocr/app/models/schemas.py +++ b/ocr/app/models/schemas.py @@ -93,3 +93,25 @@ class JobSubmitRequest(BaseModel): callback_url: Optional[str] = Field(default=None, alias="callbackUrl") model_config = {"populate_by_name": True} + + +class ReceiptExtractedField(BaseModel): + """A single extracted field from a receipt with confidence.""" + + value: str | float + confidence: float = Field(ge=0.0, le=1.0) + + +class ReceiptExtractionResponse(BaseModel): + """Response from receipt extraction endpoint.""" + + success: bool + receipt_type: str = Field(alias="receiptType") + extracted_fields: dict[str, ReceiptExtractedField] = Field( + default_factory=dict, alias="extractedFields" + ) + raw_text: str = Field(alias="rawText") + processing_time_ms: int = Field(alias="processingTimeMs") + error: Optional[str] = None + + model_config = {"populate_by_name": True} diff --git a/ocr/app/patterns/__init__.py b/ocr/app/patterns/__init__.py new file mode 100644 index 0000000..e4d94c3 --- /dev/null +++ b/ocr/app/patterns/__init__.py @@ -0,0 +1,13 @@ +"""Pattern matching modules for receipt field extraction.""" +from app.patterns.date_patterns import DatePatternMatcher, date_matcher +from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher +from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher + +__all__ = [ + "DatePatternMatcher", + "date_matcher", + "CurrencyPatternMatcher", + "currency_matcher", + "FuelPatternMatcher", + "fuel_matcher", +] diff --git a/ocr/app/patterns/currency_patterns.py b/ocr/app/patterns/currency_patterns.py new file mode 100644 index 0000000..0b249cc --- /dev/null +++ b/ocr/app/patterns/currency_patterns.py @@ -0,0 +1,227 @@ +"""Currency and amount pattern matching for receipt extraction.""" +import re +from dataclasses import dataclass +from decimal import Decimal, InvalidOperation +from typing import Optional + + +@dataclass +class AmountMatch: + """Result of currency/amount pattern matching.""" + + value: float + raw_match: str + confidence: float + pattern_name: str + label: Optional[str] = None # e.g., "TOTAL", "SUBTOTAL" + + +class CurrencyPatternMatcher: + """Extract and normalize currency amounts from receipt text.""" + + # Total amount patterns (prioritized) + TOTAL_PATTERNS = [ + # TOTAL $XX.XX or TOTAL: $XX.XX + ( + r"(?:^|\s)TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)", + "total_explicit", + 0.98, + ), + # AMOUNT DUE $XX.XX + ( + r"AMOUNT\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})", + "amount_due", + 0.95, + ), + # SALE $XX.XX + ( + r"(?:^|\s)SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)", + "sale_explicit", + 0.92, + ), + # GRAND TOTAL $XX.XX + ( + r"GRAND\s*TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})", + "grand_total", + 0.97, + ), + # TOTAL SALE $XX.XX + ( + r"TOTAL\s*SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})", + "total_sale", + 0.96, + ), + # BALANCE DUE $XX.XX + ( + r"BALANCE\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})", + "balance_due", + 0.94, + ), + # PURCHASE $XX.XX + ( + r"(?:^|\s)PURCHASE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)", + "purchase", + 0.88, + ), + ] + + # Generic amount patterns (lower priority) + AMOUNT_PATTERNS = [ + # $XX.XX (standalone dollar amount) + ( + r"\$\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})", + "dollar_amount", + 0.60, + ), + # XX.XX (standalone decimal amount) + ( + r"(? Optional[AmountMatch]: + """ + Extract the total amount from receipt text. + + Prioritizes explicit total patterns over generic amounts. + + Args: + text: Receipt text to search + + Returns: + AmountMatch for total or None if not found + """ + text_upper = text.upper() + + # Try total-specific patterns first + for pattern, name, confidence in self.TOTAL_PATTERNS: + match = re.search(pattern, text_upper, re.MULTILINE) + if match: + amount = self._parse_amount(match.group(1)) + if amount is not None and self._is_reasonable_total(amount): + return AmountMatch( + value=amount, + raw_match=match.group(0).strip(), + confidence=confidence, + pattern_name=name, + label=self._extract_label(name), + ) + + # Fall back to finding the largest reasonable amount + all_amounts = self.extract_all_amounts(text) + reasonable = [a for a in all_amounts if self._is_reasonable_total(a.value)] + if reasonable: + # Assume largest amount is the total + reasonable.sort(key=lambda x: x.value, reverse=True) + best = reasonable[0] + # Lower confidence since we're guessing + return AmountMatch( + value=best.value, + raw_match=best.raw_match, + confidence=min(0.60, best.confidence), + pattern_name="inferred_total", + label="TOTAL", + ) + + return None + + def extract_all_amounts(self, text: str) -> list[AmountMatch]: + """ + Extract all currency amounts from text. + + Args: + text: Receipt text to search + + Returns: + List of AmountMatch objects + """ + matches = [] + text_upper = text.upper() + + # Check total patterns + for pattern, name, confidence in self.TOTAL_PATTERNS: + for match in re.finditer(pattern, text_upper, re.MULTILINE): + amount = self._parse_amount(match.group(1)) + if amount is not None: + matches.append( + AmountMatch( + value=amount, + raw_match=match.group(0).strip(), + confidence=confidence, + pattern_name=name, + label=self._extract_label(name), + ) + ) + + # Check generic amount patterns + for pattern, name, confidence in self.AMOUNT_PATTERNS: + for match in re.finditer(pattern, text_upper): + amount = self._parse_amount(match.group(1)) + if amount is not None: + # Skip if already found by a more specific pattern + if not any(abs(m.value - amount) < 0.01 for m in matches): + matches.append( + AmountMatch( + value=amount, + raw_match=match.group(0).strip(), + confidence=confidence, + pattern_name=name, + ) + ) + + return matches + + def _parse_amount(self, amount_str: str) -> Optional[float]: + """Parse amount string to float, handling various formats.""" + # Remove any spaces + cleaned = amount_str.strip().replace(" ", "") + + # Handle European format (1.234,56) vs US format (1,234.56) + # For US receipts, assume comma is thousands separator + if "," in cleaned and "." in cleaned: + # Determine which is decimal separator (last one) + if cleaned.rfind(",") > cleaned.rfind("."): + # European format + cleaned = cleaned.replace(".", "").replace(",", ".") + else: + # US format + cleaned = cleaned.replace(",", "") + elif "," in cleaned: + # Could be thousands separator or decimal + parts = cleaned.split(",") + if len(parts) == 2 and len(parts[1]) == 2: + # Likely decimal separator + cleaned = cleaned.replace(",", ".") + else: + # Likely thousands separator + cleaned = cleaned.replace(",", "") + + try: + amount = float(Decimal(cleaned)) + return amount if amount >= 0 else None + except (InvalidOperation, ValueError): + return None + + def _is_reasonable_total(self, amount: float) -> bool: + """Check if amount is a reasonable total for a fuel receipt.""" + # Reasonable range: $1 to $500 for typical fuel purchases + return 1.0 <= amount <= 500.0 + + def _extract_label(self, pattern_name: str) -> str: + """Extract display label from pattern name.""" + labels = { + "total_explicit": "TOTAL", + "amount_due": "AMOUNT DUE", + "sale_explicit": "SALE", + "grand_total": "GRAND TOTAL", + "total_sale": "TOTAL SALE", + "balance_due": "BALANCE DUE", + "purchase": "PURCHASE", + } + return labels.get(pattern_name, "TOTAL") + + +# Singleton instance +currency_matcher = CurrencyPatternMatcher() diff --git a/ocr/app/patterns/date_patterns.py b/ocr/app/patterns/date_patterns.py new file mode 100644 index 0000000..4dd9c38 --- /dev/null +++ b/ocr/app/patterns/date_patterns.py @@ -0,0 +1,186 @@ +"""Date pattern matching for receipt extraction.""" +import re +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class DateMatch: + """Result of date pattern matching.""" + + value: str # ISO format YYYY-MM-DD + raw_match: str # Original text matched + confidence: float + pattern_name: str + + +class DatePatternMatcher: + """Extract and normalize dates from receipt text.""" + + # Pattern definitions with named groups and confidence weights + PATTERNS = [ + # MM/DD/YYYY or MM/DD/YY (most common US format) + ( + r"(?P\d{1,2})/(?P\d{1,2})/(?P\d{2,4})", + "mm_dd_yyyy", + 0.95, + ), + # MM-DD-YYYY or MM-DD-YY + ( + r"(?P\d{1,2})-(?P\d{1,2})-(?P\d{2,4})", + "mm_dd_yyyy_dash", + 0.90, + ), + # YYYY-MM-DD (ISO format) + ( + r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})", + "iso_date", + 0.98, + ), + # Mon DD, YYYY (e.g., Jan 15, 2024) + ( + r"(?P[A-Za-z]{3})\s+(?P\d{1,2}),?\s+(?P\d{4})", + "month_name_long", + 0.85, + ), + # DD Mon YYYY (e.g., 15 Jan 2024) + ( + r"(?P\d{1,2})\s+(?P[A-Za-z]{3})\s+(?P\d{4})", + "day_month_year", + 0.85, + ), + # MMDDYYYY or MMDDYY (no separators, common in some POS systems) + ( + r"(?\d{2})(?P\d{2})(?P\d{2,4})(?!\d)", + "compact_date", + 0.70, + ), + ] + + MONTH_NAMES = { + "jan": 1, "january": 1, + "feb": 2, "february": 2, + "mar": 3, "march": 3, + "apr": 4, "april": 4, + "may": 5, + "jun": 6, "june": 6, + "jul": 7, "july": 7, + "aug": 8, "august": 8, + "sep": 9, "sept": 9, "september": 9, + "oct": 10, "october": 10, + "nov": 11, "november": 11, + "dec": 12, "december": 12, + } + + def extract_dates(self, text: str) -> list[DateMatch]: + """ + Extract all date patterns from text. + + Args: + text: Receipt text to search + + Returns: + List of DateMatch objects sorted by confidence + """ + matches = [] + text_upper = text.upper() + + for pattern, name, base_confidence in self.PATTERNS: + for match in re.finditer(pattern, text, re.IGNORECASE): + parsed = self._parse_match(match, name) + if parsed: + year, month, day = parsed + if self._is_valid_date(year, month, day): + # Adjust confidence based on context + confidence = self._adjust_confidence( + base_confidence, text_upper, match.start() + ) + matches.append( + DateMatch( + value=f"{year:04d}-{month:02d}-{day:02d}", + raw_match=match.group(0), + confidence=confidence, + pattern_name=name, + ) + ) + + # Sort by confidence, deduplicate by value + matches.sort(key=lambda x: x.confidence, reverse=True) + seen = set() + unique_matches = [] + for match in matches: + if match.value not in seen: + seen.add(match.value) + unique_matches.append(match) + + return unique_matches + + def extract_best_date(self, text: str) -> Optional[DateMatch]: + """ + Extract the most likely transaction date. + + Args: + text: Receipt text to search + + Returns: + Best DateMatch or None if no date found + """ + matches = self.extract_dates(text) + return matches[0] if matches else None + + def _parse_match( + self, match: re.Match, pattern_name: str + ) -> Optional[tuple[int, int, int]]: + """Parse regex match into year, month, day tuple.""" + groups = match.groupdict() + + # Handle month name patterns + if "month_name" in groups: + month_str = groups["month_name"].lower() + month = self.MONTH_NAMES.get(month_str) + if not month: + return None + else: + month = int(groups["month"]) + + day = int(groups["day"]) + year = int(groups["year"]) + + # Normalize 2-digit years + if year < 100: + year = 2000 + year if year < 50 else 1900 + year + + return year, month, day + + def _is_valid_date(self, year: int, month: int, day: int) -> bool: + """Check if date components form a valid date.""" + try: + datetime(year=year, month=month, day=day) + # Reasonable year range for receipts + return 2000 <= year <= 2100 + except ValueError: + return False + + def _adjust_confidence( + self, base_confidence: float, text: str, position: int + ) -> float: + """ + Adjust confidence based on context clues. + + Boost confidence if date appears near date-related keywords. + """ + # Look for nearby date keywords + context_start = max(0, position - 50) + context = text[context_start:position + 50] + + date_keywords = ["DATE", "TIME", "TRANS", "SALE"] + for keyword in date_keywords: + if keyword in context: + return min(1.0, base_confidence + 0.05) + + return base_confidence + + +# Singleton instance +date_matcher = DatePatternMatcher() diff --git a/ocr/app/patterns/fuel_patterns.py b/ocr/app/patterns/fuel_patterns.py new file mode 100644 index 0000000..c8c91f3 --- /dev/null +++ b/ocr/app/patterns/fuel_patterns.py @@ -0,0 +1,364 @@ +"""Fuel-specific pattern matching for receipt extraction.""" +import re +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class FuelQuantityMatch: + """Result of fuel quantity pattern matching.""" + + value: float # Gallons or liters + unit: str # "GAL" or "L" + raw_match: str + confidence: float + pattern_name: str + + +@dataclass +class FuelPriceMatch: + """Result of fuel price per unit pattern matching.""" + + value: float + unit: str # "GAL" or "L" + raw_match: str + confidence: float + pattern_name: str + + +@dataclass +class FuelGradeMatch: + """Result of fuel grade pattern matching.""" + + value: str # e.g., "87", "89", "93", "DIESEL" + display_name: str # e.g., "Regular 87", "Premium 93" + raw_match: str + confidence: float + + +class FuelPatternMatcher: + """Extract fuel-specific data from receipt text.""" + + # Gallons patterns + GALLONS_PATTERNS = [ + # XX.XXX GAL or XX.XXX GALLONS + ( + r"(\d{1,3}\.\d{1,3})\s*(?:GAL(?:LON)?S?)", + "gallons_suffix", + 0.95, + ), + # GALLONS: XX.XXX or GAL: XX.XXX + ( + r"(?:GAL(?:LON)?S?)[:\s]+(\d{1,3}\.\d{1,3})", + "gallons_prefix", + 0.93, + ), + # VOLUME XX.XXX + ( + r"VOLUME[:\s]+(\d{1,3}\.\d{1,3})", + "volume", + 0.85, + ), + # QTY XX.XXX (near fuel context) + ( + r"QTY[:\s]+(\d{1,3}\.\d{1,3})", + "qty", + 0.70, + ), + ] + + # Liters patterns (for international receipts) + LITERS_PATTERNS = [ + # XX.XX L or XX.XX LITERS + ( + r"(\d{1,3}\.\d{1,3})\s*(?:L(?:ITERS?)?)", + "liters_suffix", + 0.95, + ), + # LITERS: XX.XX + ( + r"(?:L(?:ITERS?)?)[:\s]+(\d{1,3}\.\d{1,3})", + "liters_prefix", + 0.93, + ), + ] + + # Price per gallon patterns + PRICE_PER_UNIT_PATTERNS = [ + # $X.XXX/GAL or $X.XX/GAL + ( + r"\$?\s*(\d{1,2}\.\d{2,3})\s*/\s*GAL", + "price_per_gal", + 0.98, + ), + # PRICE/GAL $X.XXX + ( + r"PRICE\s*/\s*GAL[:\s]*\$?\s*(\d{1,2}\.\d{2,3})", + "labeled_price_gal", + 0.96, + ), + # UNIT PRICE $X.XXX + ( + r"UNIT\s*PRICE[:\s]*\$?\s*(\d{1,2}\.\d{2,3})", + "unit_price", + 0.90, + ), + # @ $X.XXX (per unit implied) + ( + r"@\s*\$?\s*(\d{1,2}\.\d{2,3})", + "at_price", + 0.85, + ), + # PPG $X.XXX (price per gallon) + ( + r"PPG[:\s]*\$?\s*(\d{1,2}\.\d{2,3})", + "ppg", + 0.92, + ), + ] + + # Fuel grade patterns + GRADE_PATTERNS = [ + # REGULAR 87, REG 87 + (r"(?:REGULAR|REG)\s*(\d{2})", "regular", 0.95), + # UNLEADED 87 + (r"UNLEADED\s*(\d{2})", "unleaded", 0.93), + # PLUS 89, MID 89, MIDGRADE 89 + (r"(?:PLUS|MID(?:GRADE)?)\s*(\d{2})", "plus", 0.95), + # PREMIUM 91/93, PREM 91/93, SUPER 91/93 + (r"(?:PREMIUM|PREM|SUPER)\s*(\d{2})", "premium", 0.95), + # Just the octane number near fuel context (87, 89, 91, 93) + (r"(? Optional[FuelQuantityMatch]: + """ + Extract fuel quantity in gallons. + + Args: + text: Receipt text to search + + Returns: + FuelQuantityMatch or None + """ + text_upper = text.upper() + + for pattern, name, confidence in self.GALLONS_PATTERNS: + match = re.search(pattern, text_upper) + if match: + quantity = float(match.group(1)) + if self._is_reasonable_quantity(quantity): + return FuelQuantityMatch( + value=quantity, + unit="GAL", + raw_match=match.group(0), + confidence=confidence, + pattern_name=name, + ) + + return None + + def extract_liters(self, text: str) -> Optional[FuelQuantityMatch]: + """ + Extract fuel quantity in liters. + + Args: + text: Receipt text to search + + Returns: + FuelQuantityMatch or None + """ + text_upper = text.upper() + + for pattern, name, confidence in self.LITERS_PATTERNS: + match = re.search(pattern, text_upper) + if match: + quantity = float(match.group(1)) + if self._is_reasonable_quantity(quantity, is_liters=True): + return FuelQuantityMatch( + value=quantity, + unit="L", + raw_match=match.group(0), + confidence=confidence, + pattern_name=name, + ) + + return None + + def extract_quantity(self, text: str) -> Optional[FuelQuantityMatch]: + """ + Extract fuel quantity (gallons or liters). + + Prefers gallons for US receipts. + + Args: + text: Receipt text to search + + Returns: + FuelQuantityMatch or None + """ + # Try gallons first (more common in US) + gallons = self.extract_gallons(text) + if gallons: + return gallons + + # Fall back to liters + return self.extract_liters(text) + + def extract_price_per_unit(self, text: str) -> Optional[FuelPriceMatch]: + """ + Extract price per gallon/liter. + + Args: + text: Receipt text to search + + Returns: + FuelPriceMatch or None + """ + text_upper = text.upper() + + for pattern, name, confidence in self.PRICE_PER_UNIT_PATTERNS: + match = re.search(pattern, text_upper) + if match: + price = float(match.group(1)) + if self._is_reasonable_price(price): + return FuelPriceMatch( + value=price, + unit="GAL", # Default to gallons for US + raw_match=match.group(0), + confidence=confidence, + pattern_name=name, + ) + + return None + + def extract_grade(self, text: str) -> Optional[FuelGradeMatch]: + """ + Extract fuel grade (octane rating or diesel). + + Args: + text: Receipt text to search + + Returns: + FuelGradeMatch or None + """ + text_upper = text.upper() + + for pattern, name, confidence in self.GRADE_PATTERNS: + match = re.search(pattern, text_upper) + if match: + if name == "diesel": + return FuelGradeMatch( + value="DIESEL", + display_name="Diesel", + raw_match=match.group(0), + confidence=confidence, + ) + elif name == "e85": + return FuelGradeMatch( + value="E85", + display_name="E85 Ethanol", + raw_match=match.group(0), + confidence=confidence, + ) + else: + octane = match.group(1) + display = self._get_grade_display_name(octane, name) + return FuelGradeMatch( + value=octane, + display_name=display, + raw_match=match.group(0), + confidence=confidence, + ) + + return None + + def extract_merchant_name(self, text: str) -> Optional[tuple[str, float]]: + """ + Extract gas station/merchant name. + + Args: + text: Receipt text to search + + Returns: + Tuple of (merchant_name, confidence) or None + """ + text_upper = text.upper() + + # Check for known station names + for station in self.STATION_NAMES: + if station in text_upper: + # Try to get the full line for context + for line in text.split("\n"): + if station in line.upper(): + # Clean up the line + cleaned = line.strip() + if len(cleaned) <= 50: # Reasonable length + return (cleaned, 0.90) + return (station.title(), 0.85) + + # Fall back to first non-empty line (often the merchant) + lines = [l.strip() for l in text.split("\n") if l.strip()] + if lines: + first_line = lines[0] + # Skip if it looks like a date or number + if not re.match(r"^\d+[/\-.]", first_line): + return (first_line[:50], 0.50) # Low confidence + + return None + + def _is_reasonable_quantity( + self, quantity: float, is_liters: bool = False + ) -> bool: + """Check if fuel quantity is reasonable.""" + if is_liters: + # Typical fill: 20-100 liters + return 0.5 <= quantity <= 150.0 + else: + # Typical fill: 5-30 gallons + return 0.1 <= quantity <= 50.0 + + def _is_reasonable_price(self, price: float) -> bool: + """Check if price per unit is reasonable.""" + # US gas prices: $1.50 - $8.00 per gallon (allowing for fluctuation) + return 1.00 <= price <= 10.00 + + def _get_grade_display_name(self, octane: str, pattern_name: str) -> str: + """Get display name for fuel grade.""" + grade_names = { + "87": "Regular 87", + "89": "Plus 89", + "91": "Premium 91", + "93": "Premium 93", + } + + if octane in grade_names: + return grade_names[octane] + + # Use pattern hint + if pattern_name == "premium": + return f"Premium {octane}" + elif pattern_name == "plus": + return f"Plus {octane}" + else: + return f"Unleaded {octane}" + + +# Singleton instance +fuel_matcher = FuelPatternMatcher() diff --git a/ocr/app/preprocessors/__init__.py b/ocr/app/preprocessors/__init__.py index ff54eee..50e04ed 100644 --- a/ocr/app/preprocessors/__init__.py +++ b/ocr/app/preprocessors/__init__.py @@ -1,10 +1,16 @@ """Image preprocessors for OCR optimization.""" from app.services.preprocessor import ImagePreprocessor, preprocessor from app.preprocessors.vin_preprocessor import VinPreprocessor, vin_preprocessor +from app.preprocessors.receipt_preprocessor import ( + ReceiptPreprocessor, + receipt_preprocessor, +) __all__ = [ "ImagePreprocessor", "preprocessor", "VinPreprocessor", "vin_preprocessor", + "ReceiptPreprocessor", + "receipt_preprocessor", ] diff --git a/ocr/app/preprocessors/receipt_preprocessor.py b/ocr/app/preprocessors/receipt_preprocessor.py new file mode 100644 index 0000000..55ebdb0 --- /dev/null +++ b/ocr/app/preprocessors/receipt_preprocessor.py @@ -0,0 +1,340 @@ +"""Receipt-optimized image preprocessing pipeline.""" +import io +import logging +from dataclasses import dataclass +from typing import Optional + +import cv2 +import numpy as np +from PIL import Image +from pillow_heif import register_heif_opener + +# Register HEIF/HEIC opener +register_heif_opener() + +logger = logging.getLogger(__name__) + + +@dataclass +class ReceiptPreprocessingResult: + """Result of receipt preprocessing.""" + + image_bytes: bytes + preprocessing_applied: list[str] + original_width: int + original_height: int + + +class ReceiptPreprocessor: + """Receipt-optimized image preprocessing for improved OCR accuracy. + + Thermal receipts typically have: + - Low contrast (faded ink) + - Uneven illumination + - Paper curl/skew + - Variable font weights + + This preprocessor addresses these issues with targeted enhancements. + """ + + # Optimal width for receipt OCR (narrow receipts work better) + TARGET_WIDTH = 800 + + def preprocess( + self, + image_bytes: bytes, + apply_contrast: bool = True, + apply_deskew: bool = True, + apply_denoise: bool = True, + apply_threshold: bool = True, + apply_sharpen: bool = True, + ) -> ReceiptPreprocessingResult: + """ + Apply receipt-optimized preprocessing pipeline. + + Pipeline optimized for thermal receipts: + 1. HEIC conversion (if needed) + 2. Grayscale conversion + 3. Resize to optimal width + 4. Deskew (correct rotation) + 5. High contrast enhancement (CLAHE + histogram stretch) + 6. Adaptive sharpening + 7. Noise reduction + 8. Adaptive thresholding (receipt-optimized) + + Args: + image_bytes: Raw image bytes (HEIC, JPEG, PNG) + apply_contrast: Apply contrast enhancement + apply_deskew: Apply deskew correction + apply_denoise: Apply noise reduction + apply_threshold: Apply adaptive thresholding + apply_sharpen: Apply sharpening + + Returns: + ReceiptPreprocessingResult with processed image bytes + """ + steps_applied = [] + + # Load image with PIL (handles HEIC via pillow-heif) + pil_image = Image.open(io.BytesIO(image_bytes)) + original_width, original_height = pil_image.size + steps_applied.append("loaded") + + # Handle EXIF rotation + pil_image = self._fix_orientation(pil_image) + + # Convert to RGB if needed + if pil_image.mode not in ("RGB", "L"): + pil_image = pil_image.convert("RGB") + steps_applied.append("convert_rgb") + + # Convert to OpenCV format + cv_image = np.array(pil_image) + if len(cv_image.shape) == 3: + cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR) + + # Convert to grayscale + if len(cv_image.shape) == 3: + gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY) + else: + gray = cv_image + steps_applied.append("grayscale") + + # Resize to optimal width while maintaining aspect ratio + gray = self._resize_optimal(gray) + steps_applied.append("resize") + + # Apply deskew + if apply_deskew: + gray = self._deskew(gray) + steps_applied.append("deskew") + + # Apply high contrast enhancement (critical for thermal receipts) + if apply_contrast: + gray = self._enhance_contrast(gray) + steps_applied.append("contrast") + + # Apply sharpening + if apply_sharpen: + gray = self._sharpen(gray) + steps_applied.append("sharpen") + + # Apply denoising + if apply_denoise: + gray = self._denoise(gray) + steps_applied.append("denoise") + + # Apply adaptive thresholding (receipt-optimized parameters) + if apply_threshold: + gray = self._adaptive_threshold_receipt(gray) + steps_applied.append("threshold") + + # Convert back to PNG bytes + result_image = Image.fromarray(gray) + buffer = io.BytesIO() + result_image.save(buffer, format="PNG") + + logger.debug(f"Receipt preprocessing applied: {steps_applied}") + + return ReceiptPreprocessingResult( + image_bytes=buffer.getvalue(), + preprocessing_applied=steps_applied, + original_width=original_width, + original_height=original_height, + ) + + def _fix_orientation(self, image: Image.Image) -> Image.Image: + """Fix image orientation based on EXIF data.""" + try: + exif = image.getexif() + if exif: + orientation = exif.get(274) # Orientation tag + if orientation: + rotate_values = { + 3: 180, + 6: 270, + 8: 90, + } + if orientation in rotate_values: + return image.rotate( + rotate_values[orientation], expand=True + ) + except Exception as e: + logger.debug(f"Could not read EXIF orientation: {e}") + return image + + def _resize_optimal(self, image: np.ndarray) -> np.ndarray: + """Resize image to optimal width for OCR.""" + height, width = image.shape[:2] + + if width <= self.TARGET_WIDTH: + return image + + scale = self.TARGET_WIDTH / width + new_height = int(height * scale) + + return cv2.resize( + image, + (self.TARGET_WIDTH, new_height), + interpolation=cv2.INTER_AREA, + ) + + def _deskew(self, image: np.ndarray) -> np.ndarray: + """ + Correct image rotation using projection profile. + + Receipts often have slight rotation from scanning/photography. + Uses projection profile method optimized for text documents. + """ + try: + # Create binary image for angle detection + _, binary = cv2.threshold( + image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + + # Find all non-zero points + coords = np.column_stack(np.where(binary > 0)) + if len(coords) < 100: + return image + + # Use minimum area rectangle to find angle + rect = cv2.minAreaRect(coords) + angle = rect[-1] + + # Normalize angle + if angle < -45: + angle = 90 + angle + elif angle > 45: + angle = angle - 90 + + # Only correct if angle is significant but not extreme + if abs(angle) < 0.5 or abs(angle) > 15: + return image + + # Rotate to correct skew + height, width = image.shape[:2] + center = (width // 2, height // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + + rotated = cv2.warpAffine( + image, + rotation_matrix, + (width, height), + borderMode=cv2.BORDER_REPLICATE, + ) + + logger.debug(f"Receipt deskewed by {angle:.2f} degrees") + return rotated + + except Exception as e: + logger.warning(f"Deskew failed: {e}") + return image + + def _enhance_contrast(self, image: np.ndarray) -> np.ndarray: + """ + Apply aggressive contrast enhancement for faded receipts. + + Combines: + 1. Histogram stretching + 2. CLAHE (Contrast Limited Adaptive Histogram Equalization) + """ + try: + # First, stretch histogram to use full dynamic range + p2, p98 = np.percentile(image, (2, 98)) + stretched = np.clip( + (image - p2) * 255.0 / (p98 - p2), 0, 255 + ).astype(np.uint8) + + # Apply CLAHE with parameters optimized for receipts + # Higher clipLimit for faded thermal receipts + clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(stretched) + + return enhanced + + except Exception as e: + logger.warning(f"Contrast enhancement failed: {e}") + return image + + def _sharpen(self, image: np.ndarray) -> np.ndarray: + """ + Apply unsharp masking for clearer text edges. + + Light sharpening improves OCR on slightly blurry images. + """ + try: + # Gaussian blur for unsharp mask + blurred = cv2.GaussianBlur(image, (0, 0), 2.0) + + # Unsharp mask: original + alpha * (original - blurred) + sharpened = cv2.addWeighted(image, 1.5, blurred, -0.5, 0) + + return sharpened + + except Exception as e: + logger.warning(f"Sharpening failed: {e}") + return image + + def _denoise(self, image: np.ndarray) -> np.ndarray: + """ + Apply light denoising optimized for text. + + Uses bilateral filter to preserve edges while reducing noise. + """ + try: + # Bilateral filter preserves edges better than Gaussian + # Light denoising - don't want to blur text + return cv2.bilateralFilter(image, 5, 50, 50) + + except Exception as e: + logger.warning(f"Denoising failed: {e}") + return image + + def _adaptive_threshold_receipt(self, image: np.ndarray) -> np.ndarray: + """ + Apply adaptive thresholding optimized for receipt text. + + Uses parameters tuned for: + - Variable font sizes (small print + headers) + - Faded thermal printing + - Uneven paper illumination + """ + try: + # Use Gaussian adaptive threshold + # Larger block size (31) handles uneven illumination + # Moderate C value (8) for faded receipts + binary = cv2.adaptiveThreshold( + image, + 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, + blockSize=31, + C=8, + ) + + return binary + + except Exception as e: + logger.warning(f"Adaptive threshold failed: {e}") + return image + + def preprocess_for_low_quality( + self, image_bytes: bytes + ) -> ReceiptPreprocessingResult: + """ + Apply aggressive preprocessing for very low quality receipts. + + Use this when standard preprocessing fails to produce readable text. + """ + return self.preprocess( + image_bytes, + apply_contrast=True, + apply_deskew=True, + apply_denoise=True, + apply_threshold=True, + apply_sharpen=True, + ) + + +# Singleton instance +receipt_preprocessor = ReceiptPreprocessor() diff --git a/ocr/app/routers/extract.py b/ocr/app/routers/extract.py index 23f8483..32bbdb3 100644 --- a/ocr/app/routers/extract.py +++ b/ocr/app/routers/extract.py @@ -1,10 +1,19 @@ """OCR extraction endpoints.""" import logging +from typing import Optional -from fastapi import APIRouter, File, HTTPException, Query, UploadFile +from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile from app.extractors.vin_extractor import vin_extractor -from app.models import BoundingBox, OcrResponse, VinAlternative, VinExtractionResponse +from app.extractors.receipt_extractor import receipt_extractor +from app.models import ( + BoundingBox, + OcrResponse, + ReceiptExtractedField, + ReceiptExtractionResponse, + VinAlternative, + VinExtractionResponse, +) from app.services import ocr_service logger = logging.getLogger(__name__) @@ -154,3 +163,97 @@ async def extract_vin( processingTimeMs=result.processing_time_ms, error=result.error, ) + + +@router.post("/receipt", response_model=ReceiptExtractionResponse) +async def extract_receipt( + file: UploadFile = File(..., description="Receipt image file"), + receipt_type: Optional[str] = Form( + default=None, + description="Receipt type hint: 'fuel' for specialized extraction", + ), +) -> ReceiptExtractionResponse: + """ + Extract data from a receipt image using OCR. + + Optimized for fuel receipts with pattern-based field extraction: + - HEIC conversion (if needed) + - Grayscale conversion + - High contrast enhancement (for thermal receipts) + - Adaptive thresholding + - Pattern matching for dates, amounts, fuel quantities + + Supports HEIC, JPEG, PNG formats. + Processing time target: <3 seconds. + + - **file**: Receipt image file (max 10MB) + - **receipt_type**: Optional hint ("fuel" for gas station receipts) + + Returns: + - **receiptType**: Detected type ("fuel" or "unknown") + - **extractedFields**: Dictionary of extracted fields with confidence scores + - merchantName: Gas station or store name + - transactionDate: Date in YYYY-MM-DD format + - totalAmount: Total purchase amount + - fuelQuantity: Gallons/liters purchased (fuel receipts) + - pricePerUnit: Price per gallon/liter (fuel receipts) + - fuelGrade: Octane rating or fuel type (fuel receipts) + - **rawText**: Full OCR text + - **processingTimeMs**: Processing time in milliseconds + """ + # Validate file presence + if not file.filename: + raise HTTPException(status_code=400, detail="No file provided") + + # Read file content + content = await file.read() + file_size = len(content) + + # Validate file size + if file_size > MAX_SYNC_SIZE: + raise HTTPException( + status_code=413, + detail=f"File too large. Max: {MAX_SYNC_SIZE // (1024*1024)}MB", + ) + + if file_size == 0: + raise HTTPException(status_code=400, detail="Empty file provided") + + logger.info( + f"Receipt extraction: {file.filename}, " + f"size: {file_size} bytes, " + f"content_type: {file.content_type}, " + f"receipt_type: {receipt_type}" + ) + + # Perform receipt extraction + result = receipt_extractor.extract( + image_bytes=content, + content_type=file.content_type, + receipt_type=receipt_type, + ) + + if not result.success: + logger.warning(f"Receipt extraction failed for {file.filename}: {result.error}") + raise HTTPException( + status_code=422, + detail=result.error or "Failed to extract data from receipt image", + ) + + # Convert internal fields to API response format + extracted_fields = { + name: ReceiptExtractedField( + value=field.value, + confidence=field.confidence, + ) + for name, field in result.extracted_fields.items() + } + + return ReceiptExtractionResponse( + success=result.success, + receiptType=result.receipt_type, + extractedFields=extracted_fields, + rawText=result.raw_text, + processingTimeMs=result.processing_time_ms, + error=result.error, + ) diff --git a/ocr/tests/test_currency_patterns.py b/ocr/tests/test_currency_patterns.py new file mode 100644 index 0000000..b33b398 --- /dev/null +++ b/ocr/tests/test_currency_patterns.py @@ -0,0 +1,198 @@ +"""Tests for currency pattern matching.""" +import pytest + +from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher + + +class TestCurrencyPatternMatcher: + """Test currency and amount extraction.""" + + def test_total_explicit(self) -> None: + """Test 'TOTAL $XX.XX' pattern.""" + text = "TOTAL $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + assert result.confidence > 0.9 + assert result.label == "TOTAL" + + def test_total_with_colon(self) -> None: + """Test 'TOTAL: $XX.XX' pattern.""" + text = "TOTAL: $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_total_without_dollar_sign(self) -> None: + """Test 'TOTAL 45.67' pattern.""" + text = "TOTAL 45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_amount_due(self) -> None: + """Test 'AMOUNT DUE' pattern.""" + text = "AMOUNT DUE: $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + assert result.label == "AMOUNT DUE" + + def test_sale_pattern(self) -> None: + """Test 'SALE $XX.XX' pattern.""" + text = "SALE $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_grand_total(self) -> None: + """Test 'GRAND TOTAL' pattern.""" + text = "GRAND TOTAL $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + assert result.label == "GRAND TOTAL" + + def test_total_sale(self) -> None: + """Test 'TOTAL SALE' pattern.""" + text = "TOTAL SALE: $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_balance_due(self) -> None: + """Test 'BALANCE DUE' pattern.""" + text = "BALANCE DUE $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_multiple_amounts_picks_total(self) -> None: + """Test that labeled total is preferred over generic amounts.""" + text = """ + REGULAR 87 + 10.500 GAL @ $3.67 + SUBTOTAL $38.54 + TAX $0.00 + TOTAL $38.54 + """ + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 38.54 + assert result.pattern_name == "total_explicit" + + def test_all_amounts(self) -> None: + """Test extracting all amounts from receipt.""" + text = """ + SUBTOTAL $35.00 + TAX $3.54 + TOTAL $38.54 + """ + results = currency_matcher.extract_all_amounts(text) + + # Should find TOTAL and possibly others + assert len(results) >= 1 + assert any(r.value == 38.54 for r in results) + + def test_comma_thousand_separator(self) -> None: + """Test amounts with thousand separators.""" + text = "TOTAL $1,234.56" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 1234.56 + + def test_reasonable_total_range(self) -> None: + """Test that unreasonable totals are filtered.""" + # Very small amount + text = "TOTAL $0.05" + result = currency_matcher.extract_total(text) + assert result is None # Too small for fuel receipt + + # Reasonable amount + text = "TOTAL $45.67" + result = currency_matcher.extract_total(text) + assert result is not None + + def test_receipt_context_extraction(self) -> None: + """Test extraction from realistic receipt text.""" + text = """ + SHELL + 123 MAIN ST + DATE: 01/15/2024 + + UNLEADED 87 + 10.500 GAL + @ $3.679/GAL + + FUEL TOTAL $38.63 + TAX $0.00 + TOTAL $38.63 + + DEBIT CARD + ************1234 + """ + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 38.63 + + def test_no_total_returns_largest(self) -> None: + """Test fallback to largest amount when no labeled total.""" + text = """ + $10.50 + $5.00 + $45.67 + """ + result = currency_matcher.extract_total(text) + + # Should infer largest reasonable amount as total + assert result is not None + assert result.value == 45.67 + assert result.confidence < 0.7 # Lower confidence for inferred + + def test_no_amounts_returns_none(self) -> None: + """Test that text without amounts returns None.""" + text = "SHELL STATION\nPUMP 5" + result = currency_matcher.extract_total(text) + + assert result is None + + +class TestEdgeCases: + """Test edge cases in currency parsing.""" + + def test_european_format(self) -> None: + """Test European format (comma as decimal).""" + # European: 45,67 means 45.67 + text = "TOTAL 45,67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_spaces_in_amount(self) -> None: + """Test handling of spaces around amounts.""" + text = "TOTAL $ 45.67" + result = currency_matcher.extract_total(text) + + assert result is not None + assert result.value == 45.67 + + def test_case_insensitive(self) -> None: + """Test case insensitive matching.""" + for label in ["TOTAL", "Total", "total"]: + text = f"{label} $45.67" + result = currency_matcher.extract_total(text) + + assert result is not None, f"Failed for {label}" + assert result.value == 45.67 diff --git a/ocr/tests/test_date_patterns.py b/ocr/tests/test_date_patterns.py new file mode 100644 index 0000000..9fed908 --- /dev/null +++ b/ocr/tests/test_date_patterns.py @@ -0,0 +1,163 @@ +"""Tests for date pattern matching.""" +import pytest + +from app.patterns.date_patterns import DatePatternMatcher, date_matcher + + +class TestDatePatternMatcher: + """Test date pattern extraction.""" + + def test_mm_dd_yyyy_slash(self) -> None: + """Test MM/DD/YYYY format.""" + text = "DATE: 01/15/2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + assert result.confidence > 0.9 + + def test_mm_dd_yy_slash(self) -> None: + """Test MM/DD/YY format with 2-digit year.""" + text = "01/15/24" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_mm_dd_yyyy_dash(self) -> None: + """Test MM-DD-YYYY format.""" + text = "01-15-2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_iso_format(self) -> None: + """Test ISO YYYY-MM-DD format.""" + text = "2024-01-15" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + assert result.confidence > 0.95 + + def test_month_name_format(self) -> None: + """Test 'Jan 15, 2024' format.""" + text = "Jan 15, 2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_month_name_no_comma(self) -> None: + """Test 'Jan 15 2024' format without comma.""" + text = "Jan 15 2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_day_month_year_format(self) -> None: + """Test '15 Jan 2024' format.""" + text = "15 Jan 2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_full_month_name(self) -> None: + """Test full month name like 'January'.""" + text = "January 15, 2024" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_multiple_dates_returns_best(self) -> None: + """Test that multiple dates returns highest confidence.""" + text = "Date: 01/15/2024\nExpires: 01/15/2025" + results = date_matcher.extract_dates(text) + + assert len(results) == 2 + # Both should be valid + assert all(r.confidence > 0.5 for r in results) + + def test_invalid_date_rejected(self) -> None: + """Test that invalid dates are rejected.""" + text = "13/45/2024" # Invalid month/day + result = date_matcher.extract_best_date(text) + + assert result is None + + def test_receipt_context_text(self) -> None: + """Test date extraction from realistic receipt text.""" + text = """ + SHELL STATION + 123 MAIN ST + DATE: 01/15/2024 + TIME: 14:32 + PUMP #5 + REGULAR 87 + 10.500 GAL + TOTAL $38.50 + """ + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-01-15" + + def test_no_date_returns_none(self) -> None: + """Test that text without dates returns None.""" + text = "SHELL STATION\nTOTAL $38.50" + result = date_matcher.extract_best_date(text) + + assert result is None + + def test_confidence_boost_near_keyword(self) -> None: + """Test confidence boost when date is near DATE keyword.""" + text_with_keyword = "DATE: 01/15/2024" + text_without = "01/15/2024" + + result_with = date_matcher.extract_best_date(text_with_keyword) + result_without = date_matcher.extract_best_date(text_without) + + assert result_with is not None + assert result_without is not None + # Keyword proximity should boost confidence + assert result_with.confidence >= result_without.confidence + + +class TestEdgeCases: + """Test edge cases in date parsing.""" + + def test_year_2000(self) -> None: + """Test 2-digit year 00 is parsed as 2000.""" + text = "01/15/00" + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2000-01-15" + + def test_leap_year_date(self) -> None: + """Test Feb 29 on leap year.""" + text = "02/29/2024" # 2024 is a leap year + result = date_matcher.extract_best_date(text) + + assert result is not None + assert result.value == "2024-02-29" + + def test_leap_year_invalid(self) -> None: + """Test Feb 29 on non-leap year is rejected.""" + text = "02/29/2023" # 2023 is not a leap year + result = date_matcher.extract_best_date(text) + + assert result is None + + def test_september_abbrev(self) -> None: + """Test September abbreviation (Sept vs Sep).""" + for abbrev in ["Sep", "Sept", "September"]: + text = f"{abbrev} 15, 2024" + result = date_matcher.extract_best_date(text) + + assert result is not None, f"Failed for {abbrev}" + assert result.value == "2024-09-15" diff --git a/ocr/tests/test_fuel_patterns.py b/ocr/tests/test_fuel_patterns.py new file mode 100644 index 0000000..7bec6ed --- /dev/null +++ b/ocr/tests/test_fuel_patterns.py @@ -0,0 +1,327 @@ +"""Tests for fuel-specific pattern matching.""" +import pytest + +from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher + + +class TestFuelQuantityExtraction: + """Test fuel quantity (gallons/liters) extraction.""" + + def test_gallons_suffix(self) -> None: + """Test 'XX.XXX GAL' pattern.""" + text = "10.500 GAL" + result = fuel_matcher.extract_gallons(text) + + assert result is not None + assert result.value == 10.5 + assert result.unit == "GAL" + assert result.confidence > 0.9 + + def test_gallons_full_word(self) -> None: + """Test 'XX.XXX GALLONS' pattern.""" + text = "10.500 GALLONS" + result = fuel_matcher.extract_gallons(text) + + assert result is not None + assert result.value == 10.5 + + def test_gallons_prefix(self) -> None: + """Test 'GALLONS: XX.XXX' pattern.""" + text = "GALLONS: 10.500" + result = fuel_matcher.extract_gallons(text) + + assert result is not None + assert result.value == 10.5 + + def test_gal_prefix(self) -> None: + """Test 'GAL: XX.XXX' pattern.""" + text = "GAL: 10.500" + result = fuel_matcher.extract_gallons(text) + + assert result is not None + assert result.value == 10.5 + + def test_volume_label(self) -> None: + """Test 'VOLUME XX.XXX' pattern.""" + text = "VOLUME: 10.500" + result = fuel_matcher.extract_gallons(text) + + assert result is not None + assert result.value == 10.5 + + def test_liters_suffix(self) -> None: + """Test 'XX.XX L' pattern.""" + text = "40.5 L" + result = fuel_matcher.extract_liters(text) + + assert result is not None + assert result.value == 40.5 + assert result.unit == "L" + + def test_liters_full_word(self) -> None: + """Test 'XX.XX LITERS' pattern.""" + text = "40.5 LITERS" + result = fuel_matcher.extract_liters(text) + + assert result is not None + assert result.value == 40.5 + + def test_quantity_prefers_gallons(self) -> None: + """Test extract_quantity prefers gallons for US receipts.""" + text = "10.500 GAL" + result = fuel_matcher.extract_quantity(text) + + assert result is not None + assert result.unit == "GAL" + + def test_reasonable_quantity_filter(self) -> None: + """Test unreasonable quantities are filtered.""" + # Too small + text = "0.001 GAL" + result = fuel_matcher.extract_gallons(text) + assert result is None + + # Too large + text = "100.0 GAL" + result = fuel_matcher.extract_gallons(text) + assert result is None + + +class TestFuelPriceExtraction: + """Test price per unit extraction.""" + + def test_price_per_gal_dollar_sign(self) -> None: + """Test '$X.XXX/GAL' pattern.""" + text = "$3.679/GAL" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + assert result.unit == "GAL" + assert result.confidence > 0.95 + + def test_price_per_gal_no_dollar(self) -> None: + """Test 'X.XXX/GAL' pattern.""" + text = "3.679/GAL" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + + def test_labeled_price_gal(self) -> None: + """Test 'PRICE/GAL $X.XXX' pattern.""" + text = "PRICE/GAL $3.679" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + + def test_unit_price(self) -> None: + """Test 'UNIT PRICE $X.XXX' pattern.""" + text = "UNIT PRICE: $3.679" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + + def test_at_price(self) -> None: + """Test '@ $X.XXX' pattern.""" + text = "10.500 GAL @ $3.679" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + + def test_ppg_pattern(self) -> None: + """Test 'PPG $X.XXX' pattern.""" + text = "PPG: $3.679" + result = fuel_matcher.extract_price_per_unit(text) + + assert result is not None + assert result.value == 3.679 + + def test_reasonable_price_filter(self) -> None: + """Test unreasonable prices are filtered.""" + # Too low + text = "$0.50/GAL" + result = fuel_matcher.extract_price_per_unit(text) + assert result is None + + # Too high + text = "$15.00/GAL" + result = fuel_matcher.extract_price_per_unit(text) + assert result is None + + +class TestFuelGradeExtraction: + """Test fuel grade/octane extraction.""" + + def test_regular_87(self) -> None: + """Test 'REGULAR 87' pattern.""" + text = "REGULAR 87" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "87" + assert "Regular" in result.display_name + + def test_reg_87(self) -> None: + """Test 'REG 87' pattern.""" + text = "REG 87" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "87" + + def test_unleaded_87(self) -> None: + """Test 'UNLEADED 87' pattern.""" + text = "UNLEADED 87" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "87" + + def test_plus_89(self) -> None: + """Test 'PLUS 89' pattern.""" + text = "PLUS 89" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "89" + assert "Plus" in result.display_name + + def test_midgrade_89(self) -> None: + """Test 'MIDGRADE 89' pattern.""" + text = "MIDGRADE 89" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "89" + + def test_premium_93(self) -> None: + """Test 'PREMIUM 93' pattern.""" + text = "PREMIUM 93" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "93" + assert "Premium" in result.display_name + + def test_super_93(self) -> None: + """Test 'SUPER 93' pattern.""" + text = "SUPER 93" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "93" + + def test_diesel(self) -> None: + """Test 'DIESEL' pattern.""" + text = "DIESEL #2" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "DIESEL" + assert "Diesel" in result.display_name + + def test_e85(self) -> None: + """Test 'E85' ethanol pattern.""" + text = "E85" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "E85" + + def test_octane_only(self) -> None: + """Test standalone octane number.""" + text = "87 OCTANE" + result = fuel_matcher.extract_grade(text) + + assert result is not None + assert result.value == "87" + + +class TestMerchantExtraction: + """Test gas station name extraction.""" + + def test_shell_station(self) -> None: + """Test Shell station detection.""" + text = "SHELL\n123 MAIN ST" + result = fuel_matcher.extract_merchant_name(text) + + assert result is not None + merchant, confidence = result + assert "SHELL" in merchant.upper() + assert confidence > 0.8 + + def test_chevron_station(self) -> None: + """Test Chevron station detection.""" + text = "CHEVRON #12345\n456 OAK AVE" + result = fuel_matcher.extract_merchant_name(text) + + assert result is not None + merchant, confidence = result + assert "CHEVRON" in merchant.upper() + + def test_costco_gas(self) -> None: + """Test Costco gas detection.""" + text = "COSTCO GASOLINE\n789 WAREHOUSE BLVD" + result = fuel_matcher.extract_merchant_name(text) + + assert result is not None + merchant, confidence = result + assert "COSTCO" in merchant.upper() + + def test_unknown_station_fallback(self) -> None: + """Test fallback to first line for unknown stations.""" + text = "JOE'S GAS\n123 MAIN ST" + result = fuel_matcher.extract_merchant_name(text) + + assert result is not None + merchant, confidence = result + assert "JOE'S GAS" in merchant + assert confidence < 0.7 # Lower confidence for unknown + + +class TestReceiptContextExtraction: + """Test extraction from realistic receipt text.""" + + def test_full_receipt_extraction(self) -> None: + """Test all fields from complete receipt text.""" + text = """ + SHELL + 123 MAIN STREET + ANYTOWN, USA 12345 + + DATE: 01/15/2024 + TIME: 14:32 + PUMP #5 + + REGULAR 87 + 10.500 GAL @ $3.679/GAL + + FUEL TOTAL $38.63 + TAX $0.00 + TOTAL $38.63 + + DEBIT CARD + ************1234 + APPROVED + """ + + # Test all extractors on this text + quantity = fuel_matcher.extract_quantity(text) + assert quantity is not None + assert quantity.value == 10.5 + + price = fuel_matcher.extract_price_per_unit(text) + assert price is not None + assert price.value == 3.679 + + grade = fuel_matcher.extract_grade(text) + assert grade is not None + assert grade.value == "87" + + merchant = fuel_matcher.extract_merchant_name(text) + assert merchant is not None + assert "SHELL" in merchant[0].upper() diff --git a/ocr/tests/test_receipt_extraction.py b/ocr/tests/test_receipt_extraction.py new file mode 100644 index 0000000..6442a43 --- /dev/null +++ b/ocr/tests/test_receipt_extraction.py @@ -0,0 +1,339 @@ +"""Tests for receipt extraction pipeline.""" +import io +import pytest +from unittest.mock import MagicMock, patch + +from app.extractors.receipt_extractor import ( + ReceiptExtractor, + ReceiptExtractionResult, + receipt_extractor, +) +from app.extractors.fuel_receipt import ( + FuelReceiptExtractor, + FuelReceiptValidation, + fuel_receipt_extractor, +) + + +class TestReceiptExtractor: + """Test the receipt extraction pipeline.""" + + def test_detect_receipt_type_fuel(self) -> None: + """Test fuel receipt type detection.""" + text = """ + SHELL STATION + REGULAR 87 + 10.500 GAL + TOTAL $38.50 + """ + extractor = ReceiptExtractor() + receipt_type = extractor._detect_receipt_type(text) + + assert receipt_type == "fuel" + + def test_detect_receipt_type_unknown(self) -> None: + """Test unknown receipt type detection.""" + text = """ + WALMART + GROCERIES + MILK $3.99 + BREAD $2.50 + TOTAL $6.49 + """ + extractor = ReceiptExtractor() + receipt_type = extractor._detect_receipt_type(text) + + assert receipt_type == "unknown" + + def test_extract_fuel_fields(self) -> None: + """Test fuel field extraction from OCR text.""" + text = """ + SHELL + 123 MAIN ST + DATE: 01/15/2024 + REGULAR 87 + 10.500 GAL @ $3.679 + TOTAL $38.63 + """ + extractor = ReceiptExtractor() + fields = extractor._extract_fuel_fields(text) + + assert "merchantName" in fields + assert "transactionDate" in fields + assert "totalAmount" in fields + assert "fuelQuantity" in fields + assert "pricePerUnit" in fields + assert "fuelGrade" in fields + + assert fields["totalAmount"].value == 38.63 + assert fields["fuelQuantity"].value == 10.5 + assert fields["fuelGrade"].value == "87" + + def test_extract_generic_fields(self) -> None: + """Test generic field extraction.""" + text = """ + WALMART + 01/15/2024 + TOTAL $25.99 + """ + extractor = ReceiptExtractor() + fields = extractor._extract_generic_fields(text) + + assert "transactionDate" in fields + assert "totalAmount" in fields + assert fields["totalAmount"].value == 25.99 + + def test_calculated_price_per_unit(self) -> None: + """Test price per unit calculation when not explicitly stated.""" + text = """ + SHELL + DATE: 01/15/2024 + 10.000 GAL + TOTAL $35.00 + """ + extractor = ReceiptExtractor() + fields = extractor._extract_fuel_fields(text) + + assert "pricePerUnit" in fields + # 35.00 / 10.000 = 3.50 + assert abs(fields["pricePerUnit"].value - 3.50) < 0.01 + # Calculated values should have lower confidence + assert fields["pricePerUnit"].confidence < 0.9 + + def test_validate_valid_data(self) -> None: + """Test validation of valid receipt data.""" + extractor = ReceiptExtractor() + + data = {"totalAmount": 38.63, "transactionDate": "2024-01-15"} + assert extractor.validate(data) is True + + def test_validate_invalid_data(self) -> None: + """Test validation of invalid receipt data.""" + extractor = ReceiptExtractor() + + # Empty dict + assert extractor.validate({}) is False + # Not a dict + assert extractor.validate("invalid") is False + + def test_unsupported_file_type(self) -> None: + """Test handling of unsupported file types.""" + extractor = ReceiptExtractor() + + with patch.object(extractor, "_detect_mime_type", return_value="application/pdf"): + result = extractor.extract(b"fake pdf content") + + assert result.success is False + assert "Unsupported file type" in result.error + + +class TestFuelReceiptExtractor: + """Test fuel receipt specialized extractor.""" + + def test_validation_success(self) -> None: + """Test validation passes for consistent data.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=38.63, confidence=0.95), + "fuelQuantity": ExtractedField(value=10.5, confidence=0.90), + "pricePerUnit": ExtractedField(value=3.679, confidence=0.92), + "fuelGrade": ExtractedField(value="87", confidence=0.88), + } + + validation = extractor._validate_fuel_receipt(fields) + + assert validation.is_valid is True + assert len(validation.issues) == 0 + assert validation.confidence_score == 1.0 + + def test_validation_math_mismatch(self) -> None: + """Test validation catches total != quantity * price.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=50.00, confidence=0.95), + "fuelQuantity": ExtractedField(value=10.0, confidence=0.90), + "pricePerUnit": ExtractedField(value=3.00, confidence=0.92), # Should be $30 + } + + validation = extractor._validate_fuel_receipt(fields) + + assert validation.is_valid is False + assert any("doesn't match" in issue for issue in validation.issues) + + def test_validation_quantity_too_small(self) -> None: + """Test validation catches too-small quantity.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=1.00, confidence=0.95), + "fuelQuantity": ExtractedField(value=0.1, confidence=0.90), + } + + validation = extractor._validate_fuel_receipt(fields) + + assert any("too small" in issue for issue in validation.issues) + + def test_validation_quantity_too_large(self) -> None: + """Test validation warns on very large quantity.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=150.00, confidence=0.95), + "fuelQuantity": ExtractedField(value=45.0, confidence=0.90), + } + + validation = extractor._validate_fuel_receipt(fields) + + assert any("unusually large" in issue for issue in validation.issues) + + def test_validation_price_too_low(self) -> None: + """Test validation catches too-low price.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=10.00, confidence=0.95), + "pricePerUnit": ExtractedField(value=1.00, confidence=0.90), + } + + validation = extractor._validate_fuel_receipt(fields) + + assert any("too low" in issue for issue in validation.issues) + + def test_validation_unknown_grade(self) -> None: + """Test validation catches unknown fuel grade.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=38.00, confidence=0.95), + "fuelGrade": ExtractedField(value="95", confidence=0.70), # Not valid US grade + } + + validation = extractor._validate_fuel_receipt(fields) + + assert any("Unknown fuel grade" in issue for issue in validation.issues) + + def test_confidence_adjustment_boost(self) -> None: + """Test confidence boost when validation passes.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=38.63, confidence=0.80), + } + validation = FuelReceiptValidation( + is_valid=True, issues=[], confidence_score=1.0 + ) + + adjusted = extractor._adjust_confidences(fields, validation) + + # Should be boosted by 1.1 + assert adjusted["totalAmount"].confidence == min(1.0, 0.80 * 1.1) + + def test_confidence_adjustment_reduce(self) -> None: + """Test confidence reduction when validation fails.""" + from app.extractors.receipt_extractor import ExtractedField + + extractor = FuelReceiptExtractor() + fields = { + "totalAmount": ExtractedField(value=50.00, confidence=0.90), + } + validation = FuelReceiptValidation( + is_valid=False, issues=["Math mismatch"], confidence_score=0.7 + ) + + adjusted = extractor._adjust_confidences(fields, validation) + + # Should be reduced by 0.7 + assert adjusted["totalAmount"].confidence == 0.90 * 0.7 + + +class TestReceiptPreprocessing: + """Test receipt preprocessing integration.""" + + def test_preprocessing_result_structure(self) -> None: + """Test preprocessing returns expected structure.""" + from app.preprocessors.receipt_preprocessor import ( + receipt_preprocessor, + ReceiptPreprocessingResult, + ) + + # Create a simple test image (1x1 white pixel PNG) + from PIL import Image + + img = Image.new("RGB", (100, 100), color="white") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + image_bytes = buffer.getvalue() + + result = receipt_preprocessor.preprocess(image_bytes) + + assert isinstance(result, ReceiptPreprocessingResult) + assert len(result.image_bytes) > 0 + assert "loaded" in result.preprocessing_applied + assert "grayscale" in result.preprocessing_applied + assert result.original_width == 100 + assert result.original_height == 100 + + def test_preprocessing_steps_applied(self) -> None: + """Test all preprocessing steps are applied.""" + from app.preprocessors.receipt_preprocessor import receipt_preprocessor + from PIL import Image + + img = Image.new("RGB", (100, 100), color="white") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + image_bytes = buffer.getvalue() + + result = receipt_preprocessor.preprocess( + image_bytes, + apply_contrast=True, + apply_deskew=True, + apply_denoise=True, + apply_threshold=True, + apply_sharpen=True, + ) + + # Check that expected steps are in the applied list + assert "contrast" in result.preprocessing_applied + assert "denoise" in result.preprocessing_applied + assert "threshold" in result.preprocessing_applied + + +class TestEndpointIntegration: + """Test receipt extraction endpoint integration.""" + + @pytest.fixture + def test_client(self): + """Create test client for FastAPI app.""" + from fastapi.testclient import TestClient + from app.main import app + + return TestClient(app) + + def test_receipt_endpoint_exists(self, test_client) -> None: + """Test that /extract/receipt endpoint exists.""" + # Should get 422 (no file) not 404 (not found) + response = test_client.post("/extract/receipt") + assert response.status_code == 422 # Unprocessable Entity (missing file) + + def test_receipt_endpoint_no_file(self, test_client) -> None: + """Test endpoint returns error when no file provided.""" + response = test_client.post("/extract/receipt") + assert response.status_code == 422 + + def test_receipt_endpoint_empty_file(self, test_client) -> None: + """Test endpoint returns error for empty file.""" + response = test_client.post( + "/extract/receipt", + files={"file": ("receipt.jpg", b"", "image/jpeg")}, + ) + assert response.status_code == 400 + assert "Empty file" in response.json()["detail"]