feat: Receipt OCR Pipeline (#69) #77
@@ -1,10 +1,23 @@
|
|||||||
"""Extractors package for domain-specific OCR extraction."""
|
"""Extractors package for domain-specific OCR extraction."""
|
||||||
from app.extractors.base import BaseExtractor, ExtractionResult
|
from app.extractors.base import BaseExtractor, ExtractionResult
|
||||||
from app.extractors.vin_extractor import VinExtractor, vin_extractor
|
from app.extractors.vin_extractor import VinExtractor, vin_extractor
|
||||||
|
from app.extractors.receipt_extractor import (
|
||||||
|
ReceiptExtractor,
|
||||||
|
receipt_extractor,
|
||||||
|
ReceiptExtractionResult,
|
||||||
|
ExtractedField,
|
||||||
|
)
|
||||||
|
from app.extractors.fuel_receipt import FuelReceiptExtractor, fuel_receipt_extractor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseExtractor",
|
"BaseExtractor",
|
||||||
"ExtractionResult",
|
"ExtractionResult",
|
||||||
"VinExtractor",
|
"VinExtractor",
|
||||||
"vin_extractor",
|
"vin_extractor",
|
||||||
|
"ReceiptExtractor",
|
||||||
|
"receipt_extractor",
|
||||||
|
"ReceiptExtractionResult",
|
||||||
|
"ExtractedField",
|
||||||
|
"FuelReceiptExtractor",
|
||||||
|
"fuel_receipt_extractor",
|
||||||
]
|
]
|
||||||
|
|||||||
193
ocr/app/extractors/fuel_receipt.py
Normal file
193
ocr/app/extractors/fuel_receipt.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Fuel receipt specialization with validation and cross-checking."""
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.extractors.receipt_extractor import (
|
||||||
|
ExtractedField,
|
||||||
|
ReceiptExtractionResult,
|
||||||
|
receipt_extractor,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FuelReceiptValidation:
    """Outcome of sanity-checking the fields pulled from a fuel receipt.

    Attributes:
        is_valid: True when validation recorded no issues.
        issues: One human-readable message per failed check.
        confidence_score: Multiplier accumulated during validation; it
            starts at 1.0 and is reduced for every issue found.
    """

    is_valid: bool
    issues: list[str]
    confidence_score: float
|
||||||
|
|
||||||
|
|
||||||
|
class FuelReceiptExtractor:
    """Fuel-receipt extractor that layers sanity checks on the base extractor.

    Delegates OCR and field extraction to ``receipt_extractor`` (forcing the
    ``"fuel"`` receipt type), then cross-checks the extracted numbers and
    rescales every field's confidence according to the outcome.
    """

    # Fields whose absence is reported as a validation issue.
    REQUIRED_FIELDS = ["totalAmount"]
    # Fields that are validated when present but may be missing.
    OPTIONAL_FIELDS = [
        "merchantName",
        "transactionDate",
        "fuelQuantity",
        "pricePerUnit",
        "fuelGrade",
    ]

    def extract(
        self,
        image_bytes: bytes,
        content_type: Optional[str] = None,
    ) -> ReceiptExtractionResult:
        """Run fuel-receipt extraction and apply validation-based confidences.

        Args:
            image_bytes: Raw image bytes.
            content_type: MIME type of the image, passed through unchanged.

        Returns:
            The result produced by the base extractor. On success its
            ``extracted_fields`` are replaced by confidence-adjusted copies;
            a failed result is returned untouched.
        """
        outcome = receipt_extractor.extract(
            image_bytes=image_bytes,
            content_type=content_type,
            receipt_type="fuel",
        )

        if outcome.success:
            report = self._validate_fuel_receipt(outcome.extracted_fields)

            if report.issues:
                logger.warning(
                    f"Fuel receipt validation issues: {report.issues}"
                )

            # Scale every field's confidence by the validation outcome.
            outcome.extracted_fields = self._adjust_confidences(
                outcome.extracted_fields, report
            )

        return outcome

    def _validate_fuel_receipt(
        self, fields: dict[str, ExtractedField]
    ) -> FuelReceiptValidation:
        """Check extracted fuel fields for completeness and plausibility.

        Checks performed, in order:
            - every name in ``REQUIRED_FIELDS`` is present
            - total is within 10% of quantity * price per unit
            - quantity lies between 0.5 and 40
            - price per unit lies between 1.50 and 7.00
            - fuel grade is one of the known grades

        Args:
            fields: Extracted fields keyed by field name.

        Returns:
            FuelReceiptValidation listing each failed check, with a
            confidence score multiplied down once per failure.
        """
        problems = []
        score = 1.0

        # Completeness: each missing required field halves the score.
        absent = [name for name in self.REQUIRED_FIELDS if name not in fields]
        for field_name in absent:
            problems.append(f"Missing required field: {field_name}")
            score *= 0.5

        # Arithmetic cross-check, only possible when all three values exist.
        cross_check_keys = ("totalAmount", "fuelQuantity", "pricePerUnit")
        if not any(key not in fields for key in cross_check_keys):
            total = fields["totalAmount"].value
            quantity = fields["fuelQuantity"].value
            price = fields["pricePerUnit"].value

            calculated_total = quantity * price
            tolerance = 0.10  # 10% slack for rounding

            deviation = abs(total - calculated_total)
            if deviation > total * tolerance:
                problems.append(
                    f"Total ({total}) doesn't match quantity ({quantity}) * "
                    f"price ({price}) = {calculated_total:.2f}"
                )
                score *= 0.7

        # Quantity plausibility.
        if "fuelQuantity" in fields:
            amount_pumped = fields["fuelQuantity"].value
            if amount_pumped < 0.5:
                problems.append(f"Fuel quantity too small: {amount_pumped}")
                score *= 0.6
            elif amount_pumped > 40:  # 40 gallons is a very large tank
                problems.append(f"Fuel quantity unusually large: {amount_pumped}")
                score *= 0.8

        # Unit-price plausibility (US market range).
        if "pricePerUnit" in fields:
            unit_price = fields["pricePerUnit"].value
            if unit_price < 1.50:
                problems.append(f"Price per unit too low: ${unit_price}")
                score *= 0.7
            elif unit_price > 7.00:
                problems.append(f"Price per unit unusually high: ${unit_price}")
                score *= 0.8

        # Grade must be one of the recognised labels.
        if "fuelGrade" in fields:
            grade = fields["fuelGrade"].value
            known_grades = ("87", "89", "91", "93", "DIESEL", "E85")
            if grade not in known_grades:
                problems.append(f"Unknown fuel grade: {grade}")
                score *= 0.9

        return FuelReceiptValidation(
            is_valid=not problems,
            issues=problems,
            confidence_score=score,
        )

    def _adjust_confidences(
        self,
        fields: dict[str, ExtractedField],
        validation: FuelReceiptValidation,
    ) -> dict[str, ExtractedField]:
        """Return copies of the fields with confidences scaled by validation.

        Args:
            fields: Extracted fields keyed by field name.
            validation: Result of ``_validate_fuel_receipt``.

        Returns:
            A new dict of new ExtractedField objects. Confidences are
            multiplied by 1.1 when validation found no issues, otherwise by
            the validation's confidence score, and are capped at 1.0.
        """
        factor = 1.1 if validation.is_valid else validation.confidence_score

        return {
            name: ExtractedField(
                value=item.value,
                confidence=min(1.0, item.confidence * factor),
            )
            for name, item in fields.items()
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
fuel_receipt_extractor = FuelReceiptExtractor()
|
||||||
345
ocr/app/extractors/receipt_extractor.py
Normal file
345
ocr/app/extractors/receipt_extractor.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""Receipt-specific OCR extractor with field extraction."""
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import magic
|
||||||
|
import pytesseract
|
||||||
|
from PIL import Image
|
||||||
|
from pillow_heif import register_heif_opener
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.extractors.base import BaseExtractor
|
||||||
|
from app.preprocessors.receipt_preprocessor import receipt_preprocessor
|
||||||
|
from app.patterns import currency_matcher, date_matcher, fuel_matcher
|
||||||
|
|
||||||
|
# Register HEIF/HEIC opener
|
||||||
|
register_heif_opener()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExtractedField:
    """One value extracted from a receipt, paired with its confidence.

    Attributes:
        value: The extracted value; its type depends on the field.
        confidence: Confidence assigned to the value by the extractor.
    """

    value: Any
    confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ReceiptExtractionResult:
    """Everything produced by one receipt extraction attempt.

    Attributes:
        success: Whether extraction completed and produced fields.
        receipt_type: Receipt classification; defaults to ``"unknown"``.
        extracted_fields: Extracted fields keyed by field name.
        raw_text: OCR text the fields were extracted from.
        processing_time_ms: Elapsed processing time in milliseconds.
        error: Failure description, or None when there was no error.
    """

    success: bool
    receipt_type: str = "unknown"
    # default_factory gives every result its own dict instead of a shared one.
    extracted_fields: dict[str, ExtractedField] = field(default_factory=dict)
    raw_text: str = ""
    processing_time_ms: int = 0
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptExtractor(BaseExtractor):
    """Receipt-specific OCR extractor for fuel and general receipts.

    Pipeline: check the MIME type, preprocess the image, run Tesseract,
    classify the receipt, then pull structured fields out of the OCR text
    using the pattern matchers.
    """

    # MIME types accepted by extract(); anything else is rejected up front.
    SUPPORTED_TYPES = {
        "image/jpeg",
        "image/png",
        "image/heic",
        "image/heif",
    }

    def __init__(self) -> None:
        """Point pytesseract at the Tesseract binary named in settings."""
        pytesseract.pytesseract.tesseract_cmd = settings.tesseract_cmd

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a perf_counter() reading."""
        return int((time.perf_counter() - start) * 1000)

    def extract(
        self,
        image_bytes: bytes,
        content_type: Optional[str] = None,
        receipt_type: Optional[str] = None,
    ) -> ReceiptExtractionResult:
        """Extract data from a receipt image.

        Args:
            image_bytes: Raw image bytes (HEIC, JPEG, PNG).
            content_type: MIME type (auto-detected if not provided).
            receipt_type: Hint for receipt type ("fuel" for specialized
                extraction). When omitted the type is detected from the text.

        Returns:
            ReceiptExtractionResult with extracted fields. Unsupported types,
            images without text, and any exception raised during
            preprocessing/OCR/extraction yield ``success=False`` with
            ``error`` set.
        """
        # perf_counter is monotonic, so the reported duration cannot go
        # negative or jump when the wall clock is adjusted.
        start_time = time.perf_counter()

        # Detect content type if not provided
        if not content_type:
            content_type = self._detect_mime_type(image_bytes)

        # Validate content type
        if content_type not in self.SUPPORTED_TYPES:
            return ReceiptExtractionResult(
                success=False,
                error=f"Unsupported file type: {content_type}",
                processing_time_ms=self._elapsed_ms(start_time),
            )

        try:
            # Apply receipt-optimized preprocessing
            preprocessing_result = receipt_preprocessor.preprocess(image_bytes)
            raw_text = self._perform_ocr(preprocessing_result.image_bytes)

            if not raw_text.strip():
                # Nothing recognised: retry with less aggressive preprocessing
                preprocessing_result = receipt_preprocessor.preprocess(
                    image_bytes,
                    apply_threshold=False,
                )
                raw_text = self._perform_ocr(preprocessing_result.image_bytes)

            if not raw_text.strip():
                return ReceiptExtractionResult(
                    success=False,
                    error="No text found in image",
                    processing_time_ms=self._elapsed_ms(start_time),
                )

            # An explicit hint wins; otherwise classify from the OCR text
            detected_type = receipt_type or self._detect_receipt_type(raw_text)

            # Extract fields based on receipt type
            if detected_type == "fuel":
                extracted_fields = self._extract_fuel_fields(raw_text)
            else:
                extracted_fields = self._extract_generic_fields(raw_text)

            processing_time_ms = self._elapsed_ms(start_time)

            logger.info(
                "Receipt extraction: type=%s, fields=%d, time=%dms",
                detected_type,
                len(extracted_fields),
                processing_time_ms,
            )

            return ReceiptExtractionResult(
                success=True,
                receipt_type=detected_type,
                extracted_fields=extracted_fields,
                raw_text=raw_text,
                processing_time_ms=processing_time_ms,
            )

        except Exception as e:
            # Boundary handler: report the failure in the result instead of
            # raising, and log the traceback.
            logger.exception("Receipt extraction failed: %s", e)
            return ReceiptExtractionResult(
                success=False,
                error=str(e),
                processing_time_ms=self._elapsed_ms(start_time),
            )

    def _detect_mime_type(self, file_bytes: bytes) -> str:
        """Detect MIME type using python-magic.

        Returns:
            The detected MIME type, or "application/octet-stream" when
            detection yields an empty result.
        """
        mime = magic.Magic(mime=True)
        detected = mime.from_buffer(file_bytes)
        return detected or "application/octet-stream"

    def _perform_ocr(self, image_bytes: bytes, psm: int = 6) -> str:
        """Perform OCR on a preprocessed image.

        Args:
            image_bytes: Preprocessed image bytes.
            psm: Tesseract page segmentation mode
                4 = Assume single column of text
                6 = Uniform block of text (default)

        Returns:
            Raw OCR text.
        """
        config = f"--psm {psm}"

        # Use the image as a context manager so its resources are released
        # as soon as OCR has finished.
        with Image.open(io.BytesIO(image_bytes)) as image:
            return pytesseract.image_to_string(image, config=config)

    def _detect_receipt_type(self, text: str) -> str:
        """Detect receipt type based on content.

        Scores one point per fuel keyword found in the upper-cased text and
        three more when the extracted merchant name contains a known station
        name.

        Args:
            text: OCR text.

        Returns:
            "fuel" when the score reaches 2, otherwise "unknown".
        """
        text_upper = text.upper()

        # Fuel receipt indicators
        # NOTE(review): matching is by substring, so "GAL" also counts inside
        # "GALLON" and keywords can match inside unrelated words.
        fuel_keywords = [
            "GALLON", "GAL", "FUEL", "GAS", "DIESEL", "UNLEADED",
            "REGULAR", "PREMIUM", "OCTANE", "PPG", "PUMP",
        ]

        fuel_score = sum(1 for kw in fuel_keywords if kw in text_upper)

        # Check for known gas stations (look the merchant up only once)
        merchant_result = fuel_matcher.extract_merchant_name(text)
        if merchant_result:
            merchant, _ = merchant_result
            merchant_upper = merchant.upper()
            if any(
                station in merchant_upper
                for station in fuel_matcher.STATION_NAMES
            ):
                fuel_score += 3

        if fuel_score >= 2:
            return "fuel"

        return "unknown"

    def _extract_fuel_fields(self, text: str) -> dict[str, ExtractedField]:
        """Extract fuel-specific fields from receipt text.

        Args:
            text: OCR text.

        Returns:
            Dictionary of extracted fields. Only fields that were found are
            present. ``pricePerUnit`` may be derived from total / quantity
            when it was not found directly.
        """
        fields: dict[str, ExtractedField] = {}

        # Extract merchant name
        merchant_result = fuel_matcher.extract_merchant_name(text)
        if merchant_result:
            merchant_name, confidence = merchant_result
            fields["merchantName"] = ExtractedField(
                value=merchant_name,
                confidence=confidence,
            )

        # Extract transaction date
        date_match = date_matcher.extract_best_date(text)
        if date_match:
            fields["transactionDate"] = ExtractedField(
                value=date_match.value,
                confidence=date_match.confidence,
            )

        # Extract total amount
        total_match = currency_matcher.extract_total(text)
        if total_match:
            fields["totalAmount"] = ExtractedField(
                value=total_match.value,
                confidence=total_match.confidence,
            )

        # Extract fuel quantity
        quantity_match = fuel_matcher.extract_quantity(text)
        if quantity_match:
            fields["fuelQuantity"] = ExtractedField(
                value=quantity_match.value,
                confidence=quantity_match.confidence,
            )

        # Extract price per unit
        price_match = fuel_matcher.extract_price_per_unit(text)
        if price_match:
            fields["pricePerUnit"] = ExtractedField(
                value=price_match.value,
                confidence=price_match.confidence,
            )

        # Extract fuel grade
        grade_match = fuel_matcher.extract_grade(text)
        if grade_match:
            fields["fuelGrade"] = ExtractedField(
                value=grade_match.value,
                confidence=grade_match.confidence,
            )

        # Derive price per unit from total and quantity when it is missing
        if (
            "totalAmount" in fields
            and "fuelQuantity" in fields
            and "pricePerUnit" not in fields
        ):
            total = fields["totalAmount"]
            quantity = fields["fuelQuantity"]
            # A zero quantity (e.g. an OCR misread) would raise
            # ZeroDivisionError and fail the whole extraction; skip the
            # derivation in that case.
            if quantity.value:
                calculated_price = total.value / quantity.value
                # Only use if reasonable
                if 1.0 <= calculated_price <= 10.0:
                    fields["pricePerUnit"] = ExtractedField(
                        value=round(calculated_price, 3),
                        # Lower confidence for calculated value
                        confidence=min(total.confidence, quantity.confidence)
                        * 0.8,
                    )

        return fields

    def _extract_generic_fields(self, text: str) -> dict[str, ExtractedField]:
        """Extract generic fields from receipt text.

        Args:
            text: OCR text.

        Returns:
            Dictionary of extracted fields: date and total when found, plus
            a low-confidence merchant name taken from the first non-blank
            line (truncated to 50 characters).
        """
        fields: dict[str, ExtractedField] = {}

        # Extract date
        date_match = date_matcher.extract_best_date(text)
        if date_match:
            fields["transactionDate"] = ExtractedField(
                value=date_match.value,
                confidence=date_match.confidence,
            )

        # Extract total amount
        total_match = currency_matcher.extract_total(text)
        if total_match:
            fields["totalAmount"] = ExtractedField(
                value=total_match.value,
                confidence=total_match.confidence,
            )

        # Try to get merchant from first non-blank line
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        if lines:
            fields["merchantName"] = ExtractedField(
                value=lines[0][:50],
                confidence=0.40,
            )

        return fields

    def validate(self, data: Any) -> bool:
        """Validate extracted receipt data.

        Args:
            data: Extracted data to validate.

        Returns:
            True if ``data`` is a dict containing at least "totalAmount" or
            "transactionDate"; False otherwise.
        """
        if not isinstance(data, dict):
            return False

        # Minimum: must have at least total amount or date
        return "totalAmount" in data or "transactionDate" in data
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
receipt_extractor = ReceiptExtractor()
|
||||||
@@ -7,6 +7,8 @@ from .schemas import (
|
|||||||
JobStatus,
|
JobStatus,
|
||||||
JobSubmitRequest,
|
JobSubmitRequest,
|
||||||
OcrResponse,
|
OcrResponse,
|
||||||
|
ReceiptExtractedField,
|
||||||
|
ReceiptExtractionResponse,
|
||||||
VinAlternative,
|
VinAlternative,
|
||||||
VinExtractionResponse,
|
VinExtractionResponse,
|
||||||
)
|
)
|
||||||
@@ -19,6 +21,8 @@ __all__ = [
|
|||||||
"JobStatus",
|
"JobStatus",
|
||||||
"JobSubmitRequest",
|
"JobSubmitRequest",
|
||||||
"OcrResponse",
|
"OcrResponse",
|
||||||
|
"ReceiptExtractedField",
|
||||||
|
"ReceiptExtractionResponse",
|
||||||
"VinAlternative",
|
"VinAlternative",
|
||||||
"VinExtractionResponse",
|
"VinExtractionResponse",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -93,3 +93,25 @@ class JobSubmitRequest(BaseModel):
|
|||||||
callback_url: Optional[str] = Field(default=None, alias="callbackUrl")
|
callback_url: Optional[str] = Field(default=None, alias="callbackUrl")
|
||||||
|
|
||||||
model_config = {"populate_by_name": True}
|
model_config = {"populate_by_name": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptExtractedField(BaseModel):
    """One receipt field in the API response, paired with its confidence.

    Attributes:
        value: Extracted value, either text or a number.
        confidence: Confidence in the value; constrained to 0.0-1.0.
    """

    value: str | float
    confidence: float = Field(le=1.0, ge=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptExtractionResponse(BaseModel):
    """Payload returned by the receipt extraction endpoint.

    Fields carry camelCase aliases; ``populate_by_name`` additionally allows
    them to be populated by their snake_case names.

    Attributes:
        success: Whether extraction succeeded.
        receipt_type: Receipt classification (alias ``receiptType``).
        extracted_fields: Extracted fields keyed by name (alias
            ``extractedFields``); defaults to an empty dict.
        raw_text: OCR text (alias ``rawText``).
        processing_time_ms: Processing time in milliseconds (alias
            ``processingTimeMs``).
        error: Error description, or None.
    """

    success: bool
    receipt_type: str = Field(alias="receiptType")
    extracted_fields: dict[str, ReceiptExtractedField] = Field(
        alias="extractedFields",
        default_factory=dict,
    )
    raw_text: str = Field(alias="rawText")
    processing_time_ms: int = Field(alias="processingTimeMs")
    error: Optional[str] = None

    model_config = {"populate_by_name": True}
|
||||||
|
|||||||
13
ocr/app/patterns/__init__.py
Normal file
13
ocr/app/patterns/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Pattern matching modules for receipt field extraction."""
|
||||||
|
from app.patterns.date_patterns import DatePatternMatcher, date_matcher
|
||||||
|
from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher
|
||||||
|
from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatePatternMatcher",
|
||||||
|
"date_matcher",
|
||||||
|
"CurrencyPatternMatcher",
|
||||||
|
"currency_matcher",
|
||||||
|
"FuelPatternMatcher",
|
||||||
|
"fuel_matcher",
|
||||||
|
]
|
||||||
227
ocr/app/patterns/currency_patterns.py
Normal file
227
ocr/app/patterns/currency_patterns.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""Currency and amount pattern matching for receipt extraction."""
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AmountMatch:
    """A currency amount located in receipt text.

    Attributes:
        value: Parsed numeric amount.
        raw_match: The text that matched the pattern.
        confidence: Confidence weight of the pattern that matched.
        pattern_name: Identifier of the pattern that matched.
        label: Display label such as "TOTAL" or "SUBTOTAL"; None when the
            match carries no label.
    """

    value: float
    raw_match: str
    confidence: float
    pattern_name: str
    label: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CurrencyPatternMatcher:
    """Extract and normalize currency amounts from receipt text.

    Patterns are matched against the upper-cased text. Each pattern entry is
    a ``(regex, pattern_name, confidence)`` tuple.
    """

    # Total amount patterns, tried in order; the first reasonable match wins.
    # Compound labels come first: the bare TOTAL pattern also matches the
    # tail of "GRAND TOTAL $X", and the bare SALE pattern matches the tail of
    # "TOTAL SALE $X", so listing the bare patterns earlier would shadow the
    # more specific ones and report the wrong label and confidence.
    TOTAL_PATTERNS = [
        # GRAND TOTAL $XX.XX
        (
            r"GRAND\s*TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
            "grand_total",
            0.97,
        ),
        # TOTAL SALE $XX.XX
        (
            r"TOTAL\s*SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
            "total_sale",
            0.96,
        ),
        # TOTAL $XX.XX or TOTAL: $XX.XX
        (
            r"(?:^|\s)TOTAL[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
            "total_explicit",
            0.98,
        ),
        # AMOUNT DUE $XX.XX
        (
            r"AMOUNT\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
            "amount_due",
            0.95,
        ),
        # SALE $XX.XX
        (
            r"(?:^|\s)SALE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
            "sale_explicit",
            0.92,
        ),
        # BALANCE DUE $XX.XX
        (
            r"BALANCE\s*DUE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
            "balance_due",
            0.94,
        ),
        # PURCHASE $XX.XX
        (
            r"(?:^|\s)PURCHASE[:\s]*\$?\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})(?:\s|$)",
            "purchase",
            0.88,
        ),
    ]

    # Generic amount patterns (lower priority)
    AMOUNT_PATTERNS = [
        # $XX.XX (standalone dollar amount)
        (
            r"\$\s*(\d{1,6}[,.]?\d{0,3}[.,]\d{2})",
            "dollar_amount",
            0.60,
        ),
        # XX.XX (standalone decimal amount)
        (
            r"(?<![.$\d])(\d{1,6}[.,]\d{2})(?![.\d])",
            "decimal_amount",
            0.40,
        ),
    ]

    # Display label for each total pattern name; built once, not per call.
    _LABELS = {
        "total_explicit": "TOTAL",
        "amount_due": "AMOUNT DUE",
        "sale_explicit": "SALE",
        "grand_total": "GRAND TOTAL",
        "total_sale": "TOTAL SALE",
        "balance_due": "BALANCE DUE",
        "purchase": "PURCHASE",
    }

    def extract_total(self, text: str) -> Optional[AmountMatch]:
        """Extract the total amount from receipt text.

        Prioritizes explicit total patterns over generic amounts. When no
        labelled total is found, the largest reasonable amount in the text is
        returned as an inferred total with confidence capped at 0.60.

        Args:
            text: Receipt text to search.

        Returns:
            AmountMatch for the total, or None if not found.
        """
        text_upper = text.upper()

        # Try total-specific patterns first
        for pattern, name, confidence in self.TOTAL_PATTERNS:
            match = re.search(pattern, text_upper, re.MULTILINE)
            if not match:
                continue
            amount = self._parse_amount(match.group(1))
            if amount is not None and self._is_reasonable_total(amount):
                return AmountMatch(
                    value=amount,
                    raw_match=match.group(0).strip(),
                    confidence=confidence,
                    pattern_name=name,
                    label=self._extract_label(name),
                )

        # Fall back to finding the largest reasonable amount
        all_amounts = self.extract_all_amounts(text)
        reasonable = [a for a in all_amounts if self._is_reasonable_total(a.value)]
        if reasonable:
            # Assume largest amount is the total
            reasonable.sort(key=lambda x: x.value, reverse=True)
            best = reasonable[0]
            # Lower confidence since we're guessing
            return AmountMatch(
                value=best.value,
                raw_match=best.raw_match,
                confidence=min(0.60, best.confidence),
                pattern_name="inferred_total",
                label="TOTAL",
            )

        return None

    def extract_all_amounts(self, text: str) -> list[AmountMatch]:
        """Extract all currency amounts from text.

        Labelled (total) patterns are collected first; generic amounts are
        added only when no already-collected match has the same value.

        Args:
            text: Receipt text to search.

        Returns:
            List of AmountMatch objects.
        """
        matches = []
        text_upper = text.upper()

        # Check total patterns
        for pattern, name, confidence in self.TOTAL_PATTERNS:
            for match in re.finditer(pattern, text_upper, re.MULTILINE):
                amount = self._parse_amount(match.group(1))
                if amount is not None:
                    matches.append(
                        AmountMatch(
                            value=amount,
                            raw_match=match.group(0).strip(),
                            confidence=confidence,
                            pattern_name=name,
                            label=self._extract_label(name),
                        )
                    )

        # Check generic amount patterns
        for pattern, name, confidence in self.AMOUNT_PATTERNS:
            for match in re.finditer(pattern, text_upper):
                amount = self._parse_amount(match.group(1))
                if amount is None:
                    continue
                # Skip if already found by a more specific pattern
                if not any(abs(m.value - amount) < 0.01 for m in matches):
                    matches.append(
                        AmountMatch(
                            value=amount,
                            raw_match=match.group(0).strip(),
                            confidence=confidence,
                            pattern_name=name,
                        )
                    )

        return matches

    def _parse_amount(self, amount_str: str) -> Optional[float]:
        """Parse amount string to float, handling various formats.

        Args:
            amount_str: Amount text such as "1,234.56", "1.234,56" or "12,50".

        Returns:
            The parsed non-negative amount, or None when the text cannot be
            parsed as a number.
        """
        # Remove any spaces
        cleaned = amount_str.strip().replace(" ", "")

        # Handle European format (1.234,56) vs US format (1,234.56)
        if "," in cleaned and "." in cleaned:
            # Whichever separator appears last is the decimal separator
            if cleaned.rfind(",") > cleaned.rfind("."):
                # European format
                cleaned = cleaned.replace(".", "").replace(",", ".")
            else:
                # US format
                cleaned = cleaned.replace(",", "")
        elif "," in cleaned:
            # Could be thousands separator or decimal
            parts = cleaned.split(",")
            if len(parts) == 2 and len(parts[1]) == 2:
                # Exactly two digits after a single comma: likely decimal
                cleaned = cleaned.replace(",", ".")
            else:
                # Likely thousands separator
                cleaned = cleaned.replace(",", "")

        try:
            amount = float(Decimal(cleaned))
        except (InvalidOperation, ValueError):
            return None
        return amount if amount >= 0 else None

    def _is_reasonable_total(self, amount: float) -> bool:
        """Check if amount is a reasonable total for a fuel receipt."""
        # Reasonable range: $1 to $500 for typical fuel purchases
        return 1.0 <= amount <= 500.0

    def _extract_label(self, pattern_name: str) -> str:
        """Return the display label for a pattern name ("TOTAL" if unknown)."""
        return self._LABELS.get(pattern_name, "TOTAL")
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
currency_matcher = CurrencyPatternMatcher()
|
||||||
186
ocr/app/patterns/date_patterns.py
Normal file
186
ocr/app/patterns/date_patterns.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Date pattern matching for receipt extraction."""
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DateMatch:
    """A date located in receipt text.

    Attributes:
        value: The date normalized to ISO format (YYYY-MM-DD).
        raw_match: The original text that matched.
        confidence: Confidence weight of the pattern that matched.
        pattern_name: Identifier of the pattern that matched.
    """

    value: str
    raw_match: str
    confidence: float
    pattern_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatePatternMatcher:
    """Extract and normalize dates from receipt text.

    ``PATTERNS`` holds ``(regex, pattern_name, base_confidence)`` tuples. Every
    regex exposes ``day`` and ``year`` groups plus either a numeric ``month``
    group or a textual ``month_name`` group.
    """

    # Pattern definitions with named groups and confidence weights
    PATTERNS = [
        # MM/DD/YYYY or MM/DD/YY (most common US format).
        # The leading (?<!\d) stops the pattern from starting in the middle
        # of a longer number, e.g. reading "12/11/10" out of "2012/11/10".
        (
            r"(?<!\d)(?P<month>\d{1,2})/(?P<day>\d{1,2})/(?P<year>\d{2,4})",
            "mm_dd_yyyy",
            0.95,
        ),
        # MM-DD-YYYY or MM-DD-YY.
        # Without the leading boundary an ISO date such as "2012-11-10" also
        # produced a bogus "12-11-10" (2010-11-12) match.
        (
            r"(?<!\d)(?P<month>\d{1,2})-(?P<day>\d{1,2})-(?P<year>\d{2,4})",
            "mm_dd_yyyy_dash",
            0.90,
        ),
        # YYYY-MM-DD (ISO format)
        (
            r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})",
            "iso_date",
            0.98,
        ),
        # Mon DD, YYYY (e.g., Jan 15, 2024 / Sept. 15 2024 / January 15, 2024)
        (
            r"(?P<month_name>[A-Za-z]{3,9})\.?\s+(?P<day>\d{1,2}),?\s+(?P<year>\d{4})",
            "month_name_long",
            0.85,
        ),
        # DD Mon YYYY (e.g., 15 Jan 2024 / 15 January 2024)
        (
            r"(?P<day>\d{1,2})\s+(?P<month_name>[A-Za-z]{3,9})\.?\s+(?P<year>\d{4})",
            "day_month_year",
            0.85,
        ),
        # MMDDYYYY or MMDDYY (no separators, common in some POS systems)
        (
            r"(?<!\d)(?P<month>\d{2})(?P<day>\d{2})(?P<year>\d{2,4})(?!\d)",
            "compact_date",
            0.70,
        ),
    ]

    MONTH_NAMES = {
        "jan": 1, "january": 1,
        "feb": 2, "february": 2,
        "mar": 3, "march": 3,
        "apr": 4, "april": 4,
        "may": 5,
        "jun": 6, "june": 6,
        "jul": 7, "july": 7,
        "aug": 8, "august": 8,
        "sep": 9, "sept": 9, "september": 9,
        "oct": 10, "october": 10,
        "nov": 11, "november": 11,
        "dec": 12, "december": 12,
    }

    def extract_dates(self, text: str) -> "list[DateMatch]":
        """
        Extract all date patterns from text.

        Args:
            text: Receipt text to search

        Returns:
            List of DateMatch objects sorted by descending confidence, with
            at most one entry per normalized date (the most confident one).
        """
        matches = []
        # Keyword lookup in _adjust_confidence is done against upper-cased text.
        text_upper = text.upper()

        for pattern, name, base_confidence in self.PATTERNS:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                parsed = self._parse_match(match, name)
                if not parsed:
                    continue
                year, month, day = parsed
                if not self._is_valid_date(year, month, day):
                    continue
                matches.append(
                    DateMatch(
                        value=f"{year:04d}-{month:02d}-{day:02d}",
                        raw_match=match.group(0),
                        # Adjust confidence based on context
                        confidence=self._adjust_confidence(
                            base_confidence, text_upper, match.start()
                        ),
                        pattern_name=name,
                    )
                )

        # Sort by confidence, then keep the first (best) match per date value.
        matches.sort(key=lambda found: found.confidence, reverse=True)
        seen = set()
        unique_matches = []
        for found in matches:
            if found.value not in seen:
                seen.add(found.value)
                unique_matches.append(found)

        return unique_matches

    def extract_best_date(self, text: str) -> "Optional[DateMatch]":
        """
        Extract the most likely transaction date.

        Args:
            text: Receipt text to search

        Returns:
            Best DateMatch or None if no date found
        """
        matches = self.extract_dates(text)
        return matches[0] if matches else None

    def _parse_match(
        self, match: re.Match, pattern_name: str
    ) -> Optional[tuple[int, int, int]]:
        """Parse regex match into year, month, day tuple.

        Args:
            match: Match produced by one of the ``PATTERNS`` regexes.
            pattern_name: Name of the pattern (kept for interface
                compatibility; the named groups drive the parsing).

        Returns:
            ``(year, month, day)`` or None when a textual month is not
            recognised. Two-digit years below 50 map to 20xx, others to 19xx.
        """
        groups = match.groupdict()

        month_name = groups.get("month_name")
        if month_name is not None:
            key = month_name.lower()
            # Try the whole word first ("january", "sept"); fall back to the
            # last three letters so text run together by OCR ("DATEJAN 15")
            # still resolves the way the three-letter pattern used to.
            month = self.MONTH_NAMES.get(key) or self.MONTH_NAMES.get(key[-3:])
            if not month:
                return None
        else:
            month = int(groups["month"])

        day = int(groups["day"])
        year = int(groups["year"])

        # Normalize 2-digit years
        if year < 100:
            year = 2000 + year if year < 50 else 1900 + year

        return year, month, day

    def _is_valid_date(self, year: int, month: int, day: int) -> bool:
        """Check if date components form a valid date in the years 2000-2100."""
        # Reasonable year range for receipts
        if not 2000 <= year <= 2100:
            return False
        try:
            datetime(year=year, month=month, day=day)
        except ValueError:
            return False
        return True

    def _adjust_confidence(
        self, base_confidence: float, text: str, position: int
    ) -> float:
        """
        Adjust confidence based on context clues.

        Boosts confidence by 0.05 (capped at 1.0) when one of the keywords
        DATE, TIME, TRANS or SALE appears within 50 characters either side
        of ``position``.

        Args:
            base_confidence: Confidence weight of the matching pattern.
            text: Upper-cased receipt text.
            position: Start offset of the match within ``text``.

        Returns:
            The possibly boosted confidence.
        """
        context_start = max(0, position - 50)
        context = text[context_start:position + 50]

        date_keywords = ("DATE", "TIME", "TRANS", "SALE")
        if any(keyword in context for keyword in date_keywords):
            return min(1.0, base_confidence + 0.05)

        return base_confidence
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
date_matcher = DatePatternMatcher()
|
||||||
364
ocr/app/patterns/fuel_patterns.py
Normal file
364
ocr/app/patterns/fuel_patterns.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""Fuel-specific pattern matching for receipt extraction."""
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FuelQuantityMatch:
    """A fuel volume recognised in receipt text."""

    # Volume dispensed, expressed in gallons or liters (see ``unit``).
    value: float
    # Unit of ``value``: "GAL" or "L".
    unit: str
    # The exact text the pattern matched.
    raw_match: str
    # Confidence weight of the pattern that matched.
    confidence: float
    # Name of the pattern that matched.
    pattern_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FuelPriceMatch:
    """A per-unit fuel price recognised in receipt text."""

    # Price charged for one unit of fuel.
    value: float
    # Unit the price refers to: "GAL" or "L".
    unit: str
    # The exact text the pattern matched.
    raw_match: str
    # Confidence weight of the pattern that matched.
    confidence: float
    # Name of the pattern that matched.
    pattern_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FuelGradeMatch:
    """A fuel grade recognised in receipt text."""

    # Canonical grade value, e.g. "87", "89", "93", "DIESEL".
    value: str
    # Human-readable grade, e.g. "Regular 87", "Premium 93".
    display_name: str
    # The exact text the pattern matched.
    raw_match: str
    # Confidence weight of the pattern that matched.
    confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
class FuelPatternMatcher:
    """Extract fuel-specific data from receipt text.

    All ``*_PATTERNS`` tables hold ``(regex, pattern_name, confidence)``
    tuples and are applied to the upper-cased receipt text, in table order;
    the first acceptable hit wins.
    """

    # Gallons patterns
    GALLONS_PATTERNS = [
        # XX.XXX GAL or XX.XXX GALLONS
        (
            r"(\d{1,3}\.\d{1,3})\s*(?:GAL(?:LON)?S?)",
            "gallons_suffix",
            0.95,
        ),
        # GALLONS: XX.XXX or GAL: XX.XXX
        (
            r"(?:GAL(?:LON)?S?)[:\s]+(\d{1,3}\.\d{1,3})",
            "gallons_prefix",
            0.93,
        ),
        # VOLUME XX.XXX
        (
            r"VOLUME[:\s]+(\d{1,3}\.\d{1,3})",
            "volume",
            0.85,
        ),
        # QTY XX.XXX (near fuel context)
        (
            r"QTY[:\s]+(\d{1,3}\.\d{1,3})",
            "qty",
            0.70,
        ),
    ]

    # Liters patterns (for international receipts).
    # The unit token must be a whole word: a bare "L" used to match the first
    # letter of any word after a number ("45.00 LOYALTY") and the last letter
    # of any word before one ("TOTAL 45.00"), turning totals into liters.
    LITERS_PATTERNS = [
        # XX.XX L, XX.XX LITERS/LITRES, XX.XX LTR
        (
            r"(\d{1,3}\.\d{1,3})\s*L(?:ITERS?|ITRES?|TRS?)?\b",
            "liters_suffix",
            0.95,
        ),
        # LITERS: XX.XX
        (
            r"\bL(?:ITERS?|ITRES?|TRS?)?[:\s]+(\d{1,3}\.\d{1,3})",
            "liters_prefix",
            0.93,
        ),
    ]

    # Price per gallon patterns
    PRICE_PER_UNIT_PATTERNS = [
        # $X.XXX/GAL or $X.XX/GAL
        (
            r"\$?\s*(\d{1,2}\.\d{2,3})\s*/\s*GAL",
            "price_per_gal",
            0.98,
        ),
        # PRICE/GAL $X.XXX
        (
            r"PRICE\s*/\s*GAL[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
            "labeled_price_gal",
            0.96,
        ),
        # UNIT PRICE $X.XXX
        (
            r"UNIT\s*PRICE[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
            "unit_price",
            0.90,
        ),
        # @ $X.XXX (per unit implied)
        (
            r"@\s*\$?\s*(\d{1,2}\.\d{2,3})",
            "at_price",
            0.85,
        ),
        # PPG $X.XXX (price per gallon)
        (
            r"PPG[:\s]*\$?\s*(\d{1,2}\.\d{2,3})",
            "ppg",
            0.92,
        ),
    ]

    # Fuel grade patterns. Explicit grade words come first; the bare octane
    # number is the weakest signal and is therefore tried last, so that a
    # diesel or E85 receipt is not mislabelled by a stray "93" elsewhere.
    GRADE_PATTERNS = [
        # REGULAR 87, REG 87
        (r"(?:REGULAR|REG)\s*(\d{2})", "regular", 0.95),
        # UNLEADED 87
        (r"UNLEADED\s*(\d{2})", "unleaded", 0.93),
        # PLUS 89, MID 89, MIDGRADE 89
        (r"(?:PLUS|MID(?:GRADE)?)\s*(\d{2})", "plus", 0.95),
        # PREMIUM 91/93, PREM 91/93, SUPER 91/93
        (r"(?:PREMIUM|PREM|SUPER)\s*(\d{2})", "premium", 0.95),
        # DIESEL (no octane)
        (r"DIESEL(?:\s*#?\d)?", "diesel", 0.98),
        # E85 (ethanol blend); \b keeps "SALE 85" from matching.
        (r"\bE[\s-]*85(?!\.?\d)", "e85", 0.95),
        # Just the octane number (87, 89, 91, 93), optionally followed by OCT.
        # The guards reject digits that are part of a larger number or a
        # price, such as the "93" in "$93.10" or the "87" in "45.87".
        (
            r"(?<![\d.$])(87|89|91|93)(?!\.?\d)\s*(?:OCT(?:ANE)?)?",
            "octane_only",
            0.75,
        ),
    ]

    # Common gas station names
    STATION_NAMES = [
        "SHELL", "CHEVRON", "EXXON", "MOBIL", "BP", "SUNOCO", "76",
        "CIRCLE K", "SPEEDWAY", "WAWA", "SHEETZ", "CASEY", "PILOT",
        "FLYING J", "LOVES", "TA", "PETRO", "MARATHON", "CITGO",
        "VALERO", "MURPHY", "COSTCO", "SAMS CLUB", "SAM'S CLUB",
        "KROGER", "QT", "QUIKTRIP", "RACETRAC", "KUM & GO",
        "KWIK TRIP", "HOLIDAY", "SINCLAIR", "CONOCO", "PHILLIPS 66",
        "ARCO", "AMPM", "AM/PM", "7-ELEVEN", "7 ELEVEN", "GETTY",
        "GULF", "HESS", "TEXACO", "TURKEY HILL", "CUMBERLAND FARMS",
    ]

    def extract_gallons(self, text: str) -> "Optional[FuelQuantityMatch]":
        """
        Extract fuel quantity in gallons.

        Args:
            text: Receipt text to search

        Returns:
            FuelQuantityMatch or None
        """
        return self._first_reasonable_quantity(
            text.upper(), self.GALLONS_PATTERNS, unit="GAL", is_liters=False
        )

    def extract_liters(self, text: str) -> "Optional[FuelQuantityMatch]":
        """
        Extract fuel quantity in liters.

        Args:
            text: Receipt text to search

        Returns:
            FuelQuantityMatch or None
        """
        return self._first_reasonable_quantity(
            text.upper(), self.LITERS_PATTERNS, unit="L", is_liters=True
        )

    def extract_quantity(self, text: str) -> "Optional[FuelQuantityMatch]":
        """
        Extract fuel quantity (gallons or liters).

        Prefers gallons for US receipts.

        Args:
            text: Receipt text to search

        Returns:
            FuelQuantityMatch or None
        """
        # Try gallons first (more common in US), then fall back to liters.
        gallons = self.extract_gallons(text)
        if gallons:
            return gallons
        return self.extract_liters(text)

    def extract_price_per_unit(self, text: str) -> "Optional[FuelPriceMatch]":
        """
        Extract price per gallon/liter.

        Args:
            text: Receipt text to search

        Returns:
            FuelPriceMatch (always reported with unit "GAL") or None
        """
        text_upper = text.upper()

        for pattern, name, confidence in self.PRICE_PER_UNIT_PATTERNS:
            # Check every hit: an implausible first hit must not hide a
            # plausible later one for the same pattern.
            for match in re.finditer(pattern, text_upper):
                price = float(match.group(1))
                if self._is_reasonable_price(price):
                    return FuelPriceMatch(
                        value=price,
                        unit="GAL",  # Default to gallons for US
                        raw_match=match.group(0),
                        confidence=confidence,
                        pattern_name=name,
                    )

        return None

    def extract_grade(self, text: str) -> "Optional[FuelGradeMatch]":
        """
        Extract fuel grade (octane rating or diesel).

        Args:
            text: Receipt text to search

        Returns:
            FuelGradeMatch or None
        """
        text_upper = text.upper()

        for pattern, name, confidence in self.GRADE_PATTERNS:
            match = re.search(pattern, text_upper)
            if not match:
                continue

            if name == "diesel":
                value, display = "DIESEL", "Diesel"
            elif name == "e85":
                value, display = "E85", "E85 Ethanol"
            else:
                value = match.group(1)
                display = self._get_grade_display_name(value, name)

            return FuelGradeMatch(
                value=value,
                display_name=display,
                raw_match=match.group(0),
                confidence=confidence,
            )

        return None

    def extract_merchant_name(self, text: str) -> Optional[tuple[str, float]]:
        """
        Extract gas station/merchant name.

        Known station names are matched as whole tokens, so "TA" does not
        match inside "TOTAL" and "76" does not match inside "$45.76".

        Args:
            text: Receipt text to search

        Returns:
            Tuple of (merchant_name, confidence) or None. Confidence is 0.90
            for the line containing a known station (when at most 50
            characters), 0.85 for the bare station name, and 0.50 for the
            first non-empty line used as a fallback.
        """
        raw_lines = text.split("\n")
        upper_lines = [line.upper() for line in raw_lines]

        # Check for known station names
        for station in self.STATION_NAMES:
            station_re = re.compile(self._station_pattern(station))
            hits = [
                raw
                for raw, upper in zip(raw_lines, upper_lines)
                if station_re.search(upper)
            ]
            if not hits:
                continue
            # Prefer the full line for context when it is of reasonable length.
            for raw in hits:
                cleaned = raw.strip()
                if len(cleaned) <= 50:
                    return (cleaned, 0.90)
            return (station.title(), 0.85)

        # Fall back to first non-empty line (often the merchant)
        non_empty = [line.strip() for line in raw_lines if line.strip()]
        if non_empty:
            first_line = non_empty[0]
            # Skip if it looks like a date or number
            if not re.match(r"^\d+[/\-.]", first_line):
                return (first_line[:50], 0.50)  # Low confidence

        return None

    @staticmethod
    def _station_pattern(station: str) -> str:
        """Build a whole-token regex for an upper-case station name."""
        escaped = re.escape(station)
        if station.isdigit():
            # Purely numeric brands ("76") must not be part of an amount,
            # a pump/store number or any longer number.
            return rf"(?<![A-Z0-9.,$#]){escaped}(?![A-Z0-9]|[.,]\d)"
        return rf"(?<![A-Z0-9]){escaped}(?![A-Z0-9])"

    def _first_reasonable_quantity(
        self, text_upper: str, patterns: list, unit: str, is_liters: bool
    ) -> "Optional[FuelQuantityMatch]":
        """Return the first in-range quantity found by ``patterns``, if any."""
        for pattern, name, confidence in patterns:
            for match in re.finditer(pattern, text_upper):
                quantity = float(match.group(1))
                if self._is_reasonable_quantity(quantity, is_liters=is_liters):
                    return FuelQuantityMatch(
                        value=quantity,
                        unit=unit,
                        raw_match=match.group(0),
                        confidence=confidence,
                        pattern_name=name,
                    )
        return None

    def _is_reasonable_quantity(
        self, quantity: float, is_liters: bool = False
    ) -> bool:
        """Check if fuel quantity is reasonable.

        Accepts 0.5-150.0 liters or 0.1-50.0 gallons (inclusive). The bounds
        are deliberately wider than a typical fill (roughly 20-100 liters or
        5-30 gallons).
        """
        if is_liters:
            return 0.5 <= quantity <= 150.0
        return 0.1 <= quantity <= 50.0

    def _is_reasonable_price(self, price: float) -> bool:
        """Check if price per unit is reasonable.

        Accepts $1.00-$10.00 inclusive, which leaves headroom around usual
        US per-gallon prices.
        """
        return 1.00 <= price <= 10.00

    def _get_grade_display_name(self, octane: str, pattern_name: str) -> str:
        """Get display name for fuel grade.

        Args:
            octane: Octane number as captured from the receipt.
            pattern_name: Name of the grade pattern that matched; used as a
                hint for octane numbers without a fixed display name.

        Returns:
            A fixed name for 87/89/91/93, otherwise "Premium N", "Plus N" or
            "Unleaded N" depending on ``pattern_name``.
        """
        grade_names = {
            "87": "Regular 87",
            "89": "Plus 89",
            "91": "Premium 91",
            "93": "Premium 93",
        }
        if octane in grade_names:
            return grade_names[octane]

        # Use pattern hint
        if pattern_name == "premium":
            return f"Premium {octane}"
        if pattern_name == "plus":
            return f"Plus {octane}"
        return f"Unleaded {octane}"
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
fuel_matcher = FuelPatternMatcher()
|
||||||
@@ -1,10 +1,16 @@
|
|||||||
"""Image preprocessors for OCR optimization."""
|
"""Image preprocessors for OCR optimization."""
|
||||||
from app.services.preprocessor import ImagePreprocessor, preprocessor
|
from app.services.preprocessor import ImagePreprocessor, preprocessor
|
||||||
from app.preprocessors.vin_preprocessor import VinPreprocessor, vin_preprocessor
|
from app.preprocessors.vin_preprocessor import VinPreprocessor, vin_preprocessor
|
||||||
|
from app.preprocessors.receipt_preprocessor import (
|
||||||
|
ReceiptPreprocessor,
|
||||||
|
receipt_preprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImagePreprocessor",
|
"ImagePreprocessor",
|
||||||
"preprocessor",
|
"preprocessor",
|
||||||
"VinPreprocessor",
|
"VinPreprocessor",
|
||||||
"vin_preprocessor",
|
"vin_preprocessor",
|
||||||
|
"ReceiptPreprocessor",
|
||||||
|
"receipt_preprocessor",
|
||||||
]
|
]
|
||||||
|
|||||||
340
ocr/app/preprocessors/receipt_preprocessor.py
Normal file
340
ocr/app/preprocessors/receipt_preprocessor.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""Receipt-optimized image preprocessing pipeline."""
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from pillow_heif import register_heif_opener
|
||||||
|
|
||||||
|
# Register HEIF/HEIC opener
|
||||||
|
register_heif_opener()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ReceiptPreprocessingResult:
    """Outcome of running a receipt image through the preprocessing pipeline."""

    # The processed image, serialised as bytes.
    image_bytes: bytes
    # Names of the pipeline steps recorded while processing, in order.
    preprocessing_applied: list[str]
    # Width of the image as loaded, before any processing.
    original_width: int
    # Height of the image as loaded, before any processing.
    original_height: int
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptPreprocessor:
    """Receipt-optimized image preprocessing for improved OCR accuracy.

    Thermal receipts typically have:
    - Low contrast (faded ink)
    - Uneven illumination
    - Paper curl/skew
    - Variable font weights

    This preprocessor addresses these issues with targeted enhancements.
    """

    # Optimal width for receipt OCR (narrow receipts work better)
    TARGET_WIDTH = 800

    def preprocess(
        self,
        image_bytes: bytes,
        apply_contrast: bool = True,
        apply_deskew: bool = True,
        apply_denoise: bool = True,
        apply_threshold: bool = True,
        apply_sharpen: bool = True,
    ) -> "ReceiptPreprocessingResult":
        """
        Apply receipt-optimized preprocessing pipeline.

        Pipeline optimized for thermal receipts:
        1. HEIC conversion (if needed)
        2. Grayscale conversion
        3. Resize to optimal width
        4. Deskew (correct rotation)
        5. High contrast enhancement (histogram stretch + CLAHE)
        6. Adaptive sharpening
        7. Noise reduction
        8. Adaptive thresholding (receipt-optimized)

        Args:
            image_bytes: Raw image bytes (HEIC, JPEG, PNG)
            apply_contrast: Apply contrast enhancement
            apply_deskew: Apply deskew correction
            apply_denoise: Apply noise reduction
            apply_threshold: Apply adaptive thresholding
            apply_sharpen: Apply sharpening

        Returns:
            ReceiptPreprocessingResult with the processed image encoded as
            PNG, the list of recorded steps and the original dimensions.
        """
        steps_applied = []

        # Load image with PIL (handles HEIC via pillow-heif)
        pil_image = Image.open(io.BytesIO(image_bytes))
        original_width, original_height = pil_image.size
        steps_applied.append("loaded")

        # Handle EXIF rotation
        pil_image = self._fix_orientation(pil_image)

        # Convert to RGB if needed
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
            steps_applied.append("convert_rgb")

        # Convert to OpenCV format, then to grayscale. A 3-dimensional array
        # is a colour image; a 2-dimensional one is already single-channel.
        cv_image = np.array(pil_image)
        if len(cv_image.shape) == 3:
            cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
            gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        else:
            gray = cv_image
        steps_applied.append("grayscale")

        # Resize to optimal width while maintaining aspect ratio
        gray = self._resize_optimal(gray)
        steps_applied.append("resize")

        # Apply deskew
        if apply_deskew:
            gray = self._deskew(gray)
            steps_applied.append("deskew")

        # Apply high contrast enhancement (critical for thermal receipts)
        if apply_contrast:
            gray = self._enhance_contrast(gray)
            steps_applied.append("contrast")

        # Apply sharpening
        if apply_sharpen:
            gray = self._sharpen(gray)
            steps_applied.append("sharpen")

        # Apply denoising
        if apply_denoise:
            gray = self._denoise(gray)
            steps_applied.append("denoise")

        # Apply adaptive thresholding (receipt-optimized parameters)
        if apply_threshold:
            gray = self._adaptive_threshold_receipt(gray)
            steps_applied.append("threshold")

        # Convert back to PNG bytes
        result_image = Image.fromarray(gray)
        buffer = io.BytesIO()
        result_image.save(buffer, format="PNG")

        logger.debug("Receipt preprocessing applied: %s", steps_applied)

        return ReceiptPreprocessingResult(
            image_bytes=buffer.getvalue(),
            preprocessing_applied=steps_applied,
            original_width=original_width,
            original_height=original_height,
        )

    def _fix_orientation(self, image: "Image.Image") -> "Image.Image":
        """Fix image orientation based on EXIF data.

        Only the pure-rotation orientation values (3, 6 and 8) are handled;
        any other value, missing EXIF data or an EXIF read error leaves the
        image unchanged.
        """
        try:
            exif = image.getexif()
            if exif:
                orientation = exif.get(274)  # Orientation tag
                if orientation:
                    rotate_values = {
                        3: 180,
                        6: 270,
                        8: 90,
                    }
                    if orientation in rotate_values:
                        return image.rotate(
                            rotate_values[orientation], expand=True
                        )
        except Exception as e:
            # Best effort: orientation metadata is optional.
            logger.debug("Could not read EXIF orientation: %s", e)
        return image

    def _resize_optimal(self, image: np.ndarray) -> np.ndarray:
        """Resize image to optimal width for OCR.

        Images no wider than ``TARGET_WIDTH`` are returned untouched; wider
        images are scaled down, keeping the aspect ratio.
        """
        height, width = image.shape[:2]

        if width <= self.TARGET_WIDTH:
            return image

        scale = self.TARGET_WIDTH / width
        # Never ask OpenCV for a zero-pixel-high image (extremely wide,
        # very short inputs would otherwise round down to 0).
        new_height = max(1, int(height * scale))

        return cv2.resize(
            image,
            (self.TARGET_WIDTH, new_height),
            interpolation=cv2.INTER_AREA,
        )

    def _deskew(self, image: np.ndarray) -> np.ndarray:
        """
        Correct small image rotation.

        Receipts often have slight rotation from scanning/photography.
        The skew angle is estimated from the minimum-area rectangle that
        encloses the foreground (Otsu-thresholded) pixels. Angles below
        0.5 degrees or above 15 degrees are left uncorrected, and any
        failure returns the input image unchanged.
        """
        try:
            # Create binary image for angle detection
            _, binary = cv2.threshold(
                image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
            )

            # Find all non-zero points
            coords = np.column_stack(np.where(binary > 0))
            if len(coords) < 100:
                return image

            # Use minimum area rectangle to find angle
            rect = cv2.minAreaRect(coords)
            angle = rect[-1]

            # Normalize angle into the [-45, 45] range.
            # NOTE(review): the sign/range of minAreaRect's angle differs
            # between OpenCV releases - confirm against the pinned version.
            if angle < -45:
                angle = 90 + angle
            elif angle > 45:
                angle = angle - 90

            # Only correct if angle is significant but not extreme
            if abs(angle) < 0.5 or abs(angle) > 15:
                return image

            # Rotate to correct skew
            height, width = image.shape[:2]
            center = (width // 2, height // 2)
            rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)

            rotated = cv2.warpAffine(
                image,
                rotation_matrix,
                (width, height),
                borderMode=cv2.BORDER_REPLICATE,
            )

            logger.debug("Receipt deskewed by %.2f degrees", angle)
            return rotated

        except Exception as e:
            logger.warning("Deskew failed: %s", e)
            return image

    def _stretch_histogram(self, image: np.ndarray) -> np.ndarray:
        """Stretch the 2nd-98th percentile intensity range to 0-255.

        Returns the input unchanged when the two percentiles coincide (for
        example a blank or single-tone image), because stretching would
        divide by zero and fill the image with undefined values.
        """
        p2, p98 = np.percentile(image, (2, 98))
        if p98 <= p2:
            return image
        stretched = (image.astype(np.float64) - p2) * 255.0 / (p98 - p2)
        return np.clip(stretched, 0, 255).astype(np.uint8)

    def _enhance_contrast(self, image: np.ndarray) -> np.ndarray:
        """
        Apply aggressive contrast enhancement for faded receipts.

        Combines:
        1. Histogram stretching
        2. CLAHE (Contrast Limited Adaptive Histogram Equalization)

        On failure the input image is returned unchanged.
        """
        try:
            # First, stretch histogram to use full dynamic range
            stretched = self._stretch_histogram(image)

            # Apply CLAHE with parameters optimized for receipts
            # Higher clipLimit for faded thermal receipts
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            return clahe.apply(stretched)

        except Exception as e:
            logger.warning("Contrast enhancement failed: %s", e)
            return image

    def _sharpen(self, image: np.ndarray) -> np.ndarray:
        """
        Apply unsharp masking for clearer text edges.

        Light sharpening improves OCR on slightly blurry images.
        On failure the input image is returned unchanged.
        """
        try:
            # Gaussian blur for unsharp mask
            blurred = cv2.GaussianBlur(image, (0, 0), 2.0)

            # Unsharp mask: 1.5 * original - 0.5 * blurred,
            # i.e. original + 0.5 * (original - blurred)
            return cv2.addWeighted(image, 1.5, blurred, -0.5, 0)

        except Exception as e:
            logger.warning("Sharpening failed: %s", e)
            return image

    def _denoise(self, image: np.ndarray) -> np.ndarray:
        """
        Apply light denoising optimized for text.

        Uses bilateral filter to preserve edges while reducing noise.
        On failure the input image is returned unchanged.
        """
        try:
            # Bilateral filter preserves edges better than Gaussian
            # Light denoising - don't want to blur text
            return cv2.bilateralFilter(image, 5, 50, 50)

        except Exception as e:
            logger.warning("Denoising failed: %s", e)
            return image

    def _adaptive_threshold_receipt(self, image: np.ndarray) -> np.ndarray:
        """
        Apply adaptive thresholding optimized for receipt text.

        Uses parameters tuned for:
        - Variable font sizes (small print + headers)
        - Faded thermal printing
        - Uneven paper illumination

        On failure the input image is returned unchanged.
        """
        try:
            # Use Gaussian adaptive threshold
            # Larger block size (31) handles uneven illumination
            # Moderate C value (8) for faded receipts
            return cv2.adaptiveThreshold(
                image,
                255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY,
                blockSize=31,
                C=8,
            )

        except Exception as e:
            logger.warning("Adaptive threshold failed: %s", e)
            return image

    def preprocess_for_low_quality(
        self, image_bytes: bytes
    ) -> "ReceiptPreprocessingResult":
        """
        Apply aggressive preprocessing for very low quality receipts.

        Use this when standard preprocessing fails to produce readable text.
        Runs ``preprocess`` with every optional step explicitly enabled.
        """
        return self.preprocess(
            image_bytes,
            apply_contrast=True,
            apply_deskew=True,
            apply_denoise=True,
            apply_threshold=True,
            apply_sharpen=True,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
receipt_preprocessor = ReceiptPreprocessor()
|
||||||
@@ -1,10 +1,19 @@
|
|||||||
"""OCR extraction endpoints."""
|
"""OCR extraction endpoints."""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||||
|
|
||||||
from app.extractors.vin_extractor import vin_extractor
|
from app.extractors.vin_extractor import vin_extractor
|
||||||
from app.models import BoundingBox, OcrResponse, VinAlternative, VinExtractionResponse
|
from app.extractors.receipt_extractor import receipt_extractor
|
||||||
|
from app.models import (
|
||||||
|
BoundingBox,
|
||||||
|
OcrResponse,
|
||||||
|
ReceiptExtractedField,
|
||||||
|
ReceiptExtractionResponse,
|
||||||
|
VinAlternative,
|
||||||
|
VinExtractionResponse,
|
||||||
|
)
|
||||||
from app.services import ocr_service
|
from app.services import ocr_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -154,3 +163,97 @@ async def extract_vin(
|
|||||||
processingTimeMs=result.processing_time_ms,
|
processingTimeMs=result.processing_time_ms,
|
||||||
error=result.error,
|
error=result.error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/receipt", response_model=ReceiptExtractionResponse)
async def extract_receipt(
    file: UploadFile = File(..., description="Receipt image file"),
    receipt_type: Optional[str] = Form(
        default=None,
        description="Receipt type hint: 'fuel' for specialized extraction",
    ),
) -> ReceiptExtractionResponse:
    """
    Extract data from a receipt image using OCR.

    Optimized for fuel receipts with pattern-based field extraction:
    - HEIC conversion (if needed)
    - Grayscale conversion
    - High contrast enhancement (for thermal receipts)
    - Adaptive thresholding
    - Pattern matching for dates, amounts, fuel quantities

    Supports HEIC, JPEG, PNG formats.
    Processing time target: <3 seconds.

    - **file**: Receipt image file (max 10MB)
    - **receipt_type**: Optional hint ("fuel" for gas station receipts)

    Returns:
    - **receiptType**: Detected type ("fuel" or "unknown")
    - **extractedFields**: Dictionary of extracted fields with confidence scores
      - merchantName: Gas station or store name
      - transactionDate: Date in YYYY-MM-DD format
      - totalAmount: Total purchase amount
      - fuelQuantity: Gallons/liters purchased (fuel receipts)
      - pricePerUnit: Price per gallon/liter (fuel receipts)
      - fuelGrade: Octane rating or fuel type (fuel receipts)
    - **rawText**: Full OCR text
    - **processingTimeMs**: Processing time in milliseconds
    """
    # NOTE: the docstring above is published by FastAPI as the endpoint
    # description, so it is kept as user-facing text; maintainer notes live
    # in comments instead.
    #
    # Raises HTTPException:
    #   400 - no filename on the upload, or the upload is empty
    #   413 - upload larger than MAX_SYNC_SIZE
    #   422 - the extractor reported failure

    # Validate file presence
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    # Read file content
    content = await file.read()
    file_size = len(content)

    # Validate file size
    if file_size > MAX_SYNC_SIZE:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max: {MAX_SYNC_SIZE // (1024*1024)}MB",
        )

    if file_size == 0:
        raise HTTPException(status_code=400, detail="Empty file provided")

    # filename, content_type and receipt_type are client-controlled. Logging
    # them with %r escapes embedded newlines/control characters so a crafted
    # value cannot forge extra log lines, and passing them as arguments defers
    # formatting until the record is actually emitted.
    logger.info(
        "Receipt extraction: %r, size: %d bytes, content_type: %r, receipt_type: %r",
        file.filename,
        file_size,
        file.content_type,
        receipt_type,
    )

    # Perform receipt extraction
    result = receipt_extractor.extract(
        image_bytes=content,
        content_type=file.content_type,
        receipt_type=receipt_type,
    )

    if not result.success:
        logger.warning(
            "Receipt extraction failed for %r: %s", file.filename, result.error
        )
        raise HTTPException(
            status_code=422,
            detail=result.error or "Failed to extract data from receipt image",
        )

    # Convert internal fields to API response format
    extracted_fields = {
        name: ReceiptExtractedField(
            value=field.value,
            confidence=field.confidence,
        )
        for name, field in result.extracted_fields.items()
    }

    return ReceiptExtractionResponse(
        success=result.success,
        receiptType=result.receipt_type,
        extractedFields=extracted_fields,
        rawText=result.raw_text,
        processingTimeMs=result.processing_time_ms,
        error=result.error,
    )
|
||||||
|
|||||||
198
ocr/tests/test_currency_patterns.py
Normal file
198
ocr/tests/test_currency_patterns.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Tests for currency pattern matching."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.patterns.currency_patterns import CurrencyPatternMatcher, currency_matcher
|
||||||
|
|
||||||
|
|
||||||
|
class TestCurrencyPatternMatcher:
    """Checks for total/amount extraction through ``currency_matcher``."""

    def test_total_explicit(self) -> None:
        """'TOTAL $XX.XX' gives a confident match labeled TOTAL."""
        total = currency_matcher.extract_total("TOTAL $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.confidence > 0.9
        assert total.label == "TOTAL"

    def test_total_with_colon(self) -> None:
        """'TOTAL: $XX.XX' (colon after the label) is recognized."""
        total = currency_matcher.extract_total("TOTAL: $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_total_without_dollar_sign(self) -> None:
        """'TOTAL 45.67' without a currency symbol is recognized."""
        total = currency_matcher.extract_total("TOTAL 45.67")

        assert total is not None
        assert total.value == 45.67

    def test_amount_due(self) -> None:
        """'AMOUNT DUE' is accepted as a total label."""
        total = currency_matcher.extract_total("AMOUNT DUE: $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.label == "AMOUNT DUE"

    def test_sale_pattern(self) -> None:
        """'SALE $XX.XX' is accepted as a total."""
        total = currency_matcher.extract_total("SALE $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_grand_total(self) -> None:
        """'GRAND TOTAL' is matched and reported with its own label."""
        total = currency_matcher.extract_total("GRAND TOTAL $45.67")

        assert total is not None
        assert total.value == 45.67
        assert total.label == "GRAND TOTAL"

    def test_total_sale(self) -> None:
        """'TOTAL SALE:' is accepted as a total."""
        total = currency_matcher.extract_total("TOTAL SALE: $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_balance_due(self) -> None:
        """'BALANCE DUE' is accepted as a total."""
        total = currency_matcher.extract_total("BALANCE DUE $45.67")

        assert total is not None
        assert total.value == 45.67

    def test_multiple_amounts_picks_total(self) -> None:
        """With several amounts present, the labeled TOTAL line wins."""
        receipt_text = """
        REGULAR 87
        10.500 GAL @ $3.67
        SUBTOTAL $38.54
        TAX $0.00
        TOTAL $38.54
        """
        total = currency_matcher.extract_total(receipt_text)

        assert total is not None
        assert total.value == 38.54
        assert total.pattern_name == "total_explicit"

    def test_all_amounts(self) -> None:
        """``extract_all_amounts`` includes the TOTAL value among its results."""
        receipt_text = """
        SUBTOTAL $35.00
        TAX $3.54
        TOTAL $38.54
        """
        amounts = currency_matcher.extract_all_amounts(receipt_text)

        # At minimum the TOTAL line must be found; other lines are optional.
        assert len(amounts) >= 1
        assert any(amount.value == 38.54 for amount in amounts)

    def test_comma_thousand_separator(self) -> None:
        """A comma thousands separator is dropped when parsing the value."""
        total = currency_matcher.extract_total("TOTAL $1,234.56")

        assert total is not None
        assert total.value == 1234.56

    def test_reasonable_total_range(self) -> None:
        """Implausibly small totals are rejected; ordinary ones are kept."""
        # Very small amount: too small for fuel receipt
        too_small = currency_matcher.extract_total("TOTAL $0.05")
        assert too_small is None

        # Reasonable amount
        plausible = currency_matcher.extract_total("TOTAL $45.67")
        assert plausible is not None

    def test_receipt_context_extraction(self) -> None:
        """The total is found inside a realistic multi-line receipt."""
        receipt_text = """
        SHELL
        123 MAIN ST
        DATE: 01/15/2024

        UNLEADED 87
        10.500 GAL
        @ $3.679/GAL

        FUEL TOTAL $38.63
        TAX $0.00
        TOTAL $38.63

        DEBIT CARD
        ************1234
        """
        total = currency_matcher.extract_total(receipt_text)

        assert total is not None
        assert total.value == 38.63

    def test_no_total_returns_largest(self) -> None:
        """Without a labeled total, the largest amount is used at low confidence."""
        receipt_text = """
        $10.50
        $5.00
        $45.67
        """
        inferred = currency_matcher.extract_total(receipt_text)

        # The largest reasonable amount is inferred to be the total.
        assert inferred is not None
        assert inferred.value == 45.67
        # Inferred totals carry less confidence than labeled ones.
        assert inferred.confidence < 0.7

    def test_no_amounts_returns_none(self) -> None:
        """Text containing no monetary amount produces no total."""
        total = currency_matcher.extract_total("SHELL STATION\nPUMP 5")

        assert total is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Less common input shapes for currency parsing."""

    def test_european_format(self) -> None:
        """A decimal comma is understood: '45,67' means 45.67."""
        total = currency_matcher.extract_total("TOTAL 45,67")

        assert total is not None
        assert total.value == 45.67

    def test_spaces_in_amount(self) -> None:
        """Whitespace between the currency symbol and digits is tolerated."""
        total = currency_matcher.extract_total("TOTAL $ 45.67")

        assert total is not None
        assert total.value == 45.67

    def test_case_insensitive(self) -> None:
        """The TOTAL label matches regardless of letter case."""
        for label in ["TOTAL", "Total", "total"]:
            total = currency_matcher.extract_total(f"{label} $45.67")

            assert total is not None, f"Failed for {label}"
            assert total.value == 45.67
|
||||||
163
ocr/tests/test_date_patterns.py
Normal file
163
ocr/tests/test_date_patterns.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Tests for date pattern matching."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.patterns.date_patterns import DatePatternMatcher, date_matcher
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatePatternMatcher:
    """Checks for date extraction through ``date_matcher``."""

    def test_mm_dd_yyyy_slash(self) -> None:
        """MM/DD/YYYY after a DATE label is normalized to ISO with high confidence."""
        found = date_matcher.extract_best_date("DATE: 01/15/2024")

        assert found is not None
        assert found.value == "2024-01-15"
        assert found.confidence > 0.9

    def test_mm_dd_yy_slash(self) -> None:
        """MM/DD/YY with a two-digit year is expanded to a four-digit year."""
        found = date_matcher.extract_best_date("01/15/24")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_mm_dd_yyyy_dash(self) -> None:
        """MM-DD-YYYY with dash separators is recognized."""
        found = date_matcher.extract_best_date("01-15-2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_iso_format(self) -> None:
        """ISO YYYY-MM-DD is returned unchanged with very high confidence."""
        found = date_matcher.extract_best_date("2024-01-15")

        assert found is not None
        assert found.value == "2024-01-15"
        assert found.confidence > 0.95

    def test_month_name_format(self) -> None:
        """'Jan 15, 2024' (abbreviated month, comma) is recognized."""
        found = date_matcher.extract_best_date("Jan 15, 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_month_name_no_comma(self) -> None:
        """'Jan 15 2024' without the comma is recognized."""
        found = date_matcher.extract_best_date("Jan 15 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_day_month_year_format(self) -> None:
        """'15 Jan 2024' (day first) is recognized."""
        found = date_matcher.extract_best_date("15 Jan 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_full_month_name(self) -> None:
        """A spelled-out month such as 'January' is recognized."""
        found = date_matcher.extract_best_date("January 15, 2024")

        assert found is not None
        assert found.value == "2024-01-15"

    def test_multiple_dates_returns_best(self) -> None:
        """``extract_dates`` reports both dates in the text, each with usable confidence."""
        candidates = date_matcher.extract_dates(
            "Date: 01/15/2024\nExpires: 01/15/2025"
        )

        assert len(candidates) == 2
        # Both candidates should be valid.
        assert all(candidate.confidence > 0.5 for candidate in candidates)

    def test_invalid_date_rejected(self) -> None:
        """An impossible month/day combination yields no date."""
        found = date_matcher.extract_best_date("13/45/2024")  # Invalid month/day

        assert found is None

    def test_receipt_context_text(self) -> None:
        """The date is found inside a realistic multi-line receipt."""
        receipt_text = """
        SHELL STATION
        123 MAIN ST
        DATE: 01/15/2024
        TIME: 14:32
        PUMP #5
        REGULAR 87
        10.500 GAL
        TOTAL $38.50
        """
        found = date_matcher.extract_best_date(receipt_text)

        assert found is not None
        assert found.value == "2024-01-15"

    def test_no_date_returns_none(self) -> None:
        """Text containing no date produces no result."""
        found = date_matcher.extract_best_date("SHELL STATION\nTOTAL $38.50")

        assert found is None

    def test_confidence_boost_near_keyword(self) -> None:
        """A date next to the DATE keyword scores at least as high as a bare date."""
        labeled = date_matcher.extract_best_date("DATE: 01/15/2024")
        bare = date_matcher.extract_best_date("01/15/2024")

        assert labeled is not None
        assert bare is not None
        # Keyword proximity should boost (or at least not reduce) confidence.
        assert labeled.confidence >= bare.confidence
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Boundary conditions for date parsing."""

    def test_year_2000(self) -> None:
        """A two-digit year of 00 is read as the year 2000."""
        found = date_matcher.extract_best_date("01/15/00")

        assert found is not None
        assert found.value == "2000-01-15"

    def test_leap_year_date(self) -> None:
        """Feb 29 is accepted in a leap year."""
        found = date_matcher.extract_best_date("02/29/2024")  # 2024 is a leap year

        assert found is not None
        assert found.value == "2024-02-29"

    def test_leap_year_invalid(self) -> None:
        """Feb 29 is rejected in a non-leap year."""
        found = date_matcher.extract_best_date("02/29/2023")  # 2023 is not a leap year

        assert found is None

    def test_september_abbrev(self) -> None:
        """'Sep', 'Sept' and 'September' all resolve to month 09."""
        for abbrev in ["Sep", "Sept", "September"]:
            found = date_matcher.extract_best_date(f"{abbrev} 15, 2024")

            assert found is not None, f"Failed for {abbrev}"
            assert found.value == "2024-09-15"
|
||||||
327
ocr/tests/test_fuel_patterns.py
Normal file
327
ocr/tests/test_fuel_patterns.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""Tests for fuel-specific pattern matching."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.patterns.fuel_patterns import FuelPatternMatcher, fuel_matcher
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuelQuantityExtraction:
    """Checks for gallon/liter quantity extraction through ``fuel_matcher``."""

    def test_gallons_suffix(self) -> None:
        """'XX.XXX GAL' yields a confident quantity in gallons."""
        qty = fuel_matcher.extract_gallons("10.500 GAL")

        assert qty is not None
        assert qty.value == 10.5
        assert qty.unit == "GAL"
        assert qty.confidence > 0.9

    def test_gallons_full_word(self) -> None:
        """'XX.XXX GALLONS' (unit spelled out) is recognized."""
        qty = fuel_matcher.extract_gallons("10.500 GALLONS")

        assert qty is not None
        assert qty.value == 10.5

    def test_gallons_prefix(self) -> None:
        """'GALLONS: XX.XXX' (label before the number) is recognized."""
        qty = fuel_matcher.extract_gallons("GALLONS: 10.500")

        assert qty is not None
        assert qty.value == 10.5

    def test_gal_prefix(self) -> None:
        """'GAL: XX.XXX' (abbreviated label before the number) is recognized."""
        qty = fuel_matcher.extract_gallons("GAL: 10.500")

        assert qty is not None
        assert qty.value == 10.5

    def test_volume_label(self) -> None:
        """A 'VOLUME:' label is treated as a gallon quantity."""
        qty = fuel_matcher.extract_gallons("VOLUME: 10.500")

        assert qty is not None
        assert qty.value == 10.5

    def test_liters_suffix(self) -> None:
        """'XX.XX L' yields a quantity in liters."""
        qty = fuel_matcher.extract_liters("40.5 L")

        assert qty is not None
        assert qty.value == 40.5
        assert qty.unit == "L"

    def test_liters_full_word(self) -> None:
        """'XX.XX LITERS' (unit spelled out) is recognized."""
        qty = fuel_matcher.extract_liters("40.5 LITERS")

        assert qty is not None
        assert qty.value == 40.5

    def test_quantity_prefers_gallons(self) -> None:
        """``extract_quantity`` reports gallons for a US-style receipt line."""
        qty = fuel_matcher.extract_quantity("10.500 GAL")

        assert qty is not None
        assert qty.unit == "GAL"

    def test_reasonable_quantity_filter(self) -> None:
        """Quantities outside a plausible range are discarded."""
        # Too small
        tiny = fuel_matcher.extract_gallons("0.001 GAL")
        assert tiny is None

        # Too large
        huge = fuel_matcher.extract_gallons("100.0 GAL")
        assert huge is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuelPriceExtraction:
    """Checks for price-per-unit extraction through ``fuel_matcher``."""

    def test_price_per_gal_dollar_sign(self) -> None:
        """'$X.XXX/GAL' yields a very confident per-gallon price."""
        price = fuel_matcher.extract_price_per_unit("$3.679/GAL")

        assert price is not None
        assert price.value == 3.679
        assert price.unit == "GAL"
        assert price.confidence > 0.95

    def test_price_per_gal_no_dollar(self) -> None:
        """'X.XXX/GAL' without a currency symbol is recognized."""
        price = fuel_matcher.extract_price_per_unit("3.679/GAL")

        assert price is not None
        assert price.value == 3.679

    def test_labeled_price_gal(self) -> None:
        """'PRICE/GAL $X.XXX' (label before the number) is recognized."""
        price = fuel_matcher.extract_price_per_unit("PRICE/GAL $3.679")

        assert price is not None
        assert price.value == 3.679

    def test_unit_price(self) -> None:
        """'UNIT PRICE: $X.XXX' is recognized."""
        price = fuel_matcher.extract_price_per_unit("UNIT PRICE: $3.679")

        assert price is not None
        assert price.value == 3.679

    def test_at_price(self) -> None:
        """The '@ $X.XXX' form following a quantity is recognized."""
        price = fuel_matcher.extract_price_per_unit("10.500 GAL @ $3.679")

        assert price is not None
        assert price.value == 3.679

    def test_ppg_pattern(self) -> None:
        """'PPG: $X.XXX' is recognized."""
        price = fuel_matcher.extract_price_per_unit("PPG: $3.679")

        assert price is not None
        assert price.value == 3.679

    def test_reasonable_price_filter(self) -> None:
        """Prices outside a plausible range are discarded."""
        # Too low
        too_cheap = fuel_matcher.extract_price_per_unit("$0.50/GAL")
        assert too_cheap is None

        # Too high
        too_costly = fuel_matcher.extract_price_per_unit("$15.00/GAL")
        assert too_costly is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuelGradeExtraction:
    """Checks for fuel grade / octane extraction through ``fuel_matcher``."""

    def test_regular_87(self) -> None:
        """'REGULAR 87' maps to grade 87 with a 'Regular' display name."""
        grade = fuel_matcher.extract_grade("REGULAR 87")

        assert grade is not None
        assert grade.value == "87"
        assert "Regular" in grade.display_name

    def test_reg_87(self) -> None:
        """'REG 87' (abbreviated) maps to grade 87."""
        grade = fuel_matcher.extract_grade("REG 87")

        assert grade is not None
        assert grade.value == "87"

    def test_unleaded_87(self) -> None:
        """'UNLEADED 87' maps to grade 87."""
        grade = fuel_matcher.extract_grade("UNLEADED 87")

        assert grade is not None
        assert grade.value == "87"

    def test_plus_89(self) -> None:
        """'PLUS 89' maps to grade 89 with a 'Plus' display name."""
        grade = fuel_matcher.extract_grade("PLUS 89")

        assert grade is not None
        assert grade.value == "89"
        assert "Plus" in grade.display_name

    def test_midgrade_89(self) -> None:
        """'MIDGRADE 89' maps to grade 89."""
        grade = fuel_matcher.extract_grade("MIDGRADE 89")

        assert grade is not None
        assert grade.value == "89"

    def test_premium_93(self) -> None:
        """'PREMIUM 93' maps to grade 93 with a 'Premium' display name."""
        grade = fuel_matcher.extract_grade("PREMIUM 93")

        assert grade is not None
        assert grade.value == "93"
        assert "Premium" in grade.display_name

    def test_super_93(self) -> None:
        """'SUPER 93' maps to grade 93."""
        grade = fuel_matcher.extract_grade("SUPER 93")

        assert grade is not None
        assert grade.value == "93"

    def test_diesel(self) -> None:
        """'DIESEL #2' maps to the DIESEL grade with a 'Diesel' display name."""
        grade = fuel_matcher.extract_grade("DIESEL #2")

        assert grade is not None
        assert grade.value == "DIESEL"
        assert "Diesel" in grade.display_name

    def test_e85(self) -> None:
        """The 'E85' ethanol blend is recognized as its own grade."""
        grade = fuel_matcher.extract_grade("E85")

        assert grade is not None
        assert grade.value == "E85"

    def test_octane_only(self) -> None:
        """A bare octane number followed by 'OCTANE' is recognized."""
        grade = fuel_matcher.extract_grade("87 OCTANE")

        assert grade is not None
        assert grade.value == "87"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMerchantExtraction:
    """Checks for gas station name extraction through ``fuel_matcher``."""

    def test_shell_station(self) -> None:
        """A Shell header line is identified with high confidence."""
        found = fuel_matcher.extract_merchant_name("SHELL\n123 MAIN ST")

        assert found is not None
        name, score = found
        assert "SHELL" in name.upper()
        assert score > 0.8

    def test_chevron_station(self) -> None:
        """A Chevron header with a store number is identified."""
        found = fuel_matcher.extract_merchant_name("CHEVRON #12345\n456 OAK AVE")

        assert found is not None
        name, score = found
        assert "CHEVRON" in name.upper()

    def test_costco_gas(self) -> None:
        """A Costco Gasoline header is identified."""
        found = fuel_matcher.extract_merchant_name(
            "COSTCO GASOLINE\n789 WAREHOUSE BLVD"
        )

        assert found is not None
        name, score = found
        assert "COSTCO" in name.upper()

    def test_unknown_station_fallback(self) -> None:
        """An unrecognized station falls back to the first line at lower confidence."""
        found = fuel_matcher.extract_merchant_name("JOE'S GAS\n123 MAIN ST")

        assert found is not None
        name, score = found
        assert "JOE'S GAS" in name
        # Unknown merchants are reported with reduced confidence.
        assert score < 0.7
|
||||||
|
|
||||||
|
|
||||||
|
class TestReceiptContextExtraction:
    """Run the fuel extractors against realistic, complete receipt text."""

    def test_full_receipt_extraction(self) -> None:
        """Quantity, price, grade and merchant are all found in one receipt."""
        receipt_text = """
        SHELL
        123 MAIN STREET
        ANYTOWN, USA 12345

        DATE: 01/15/2024
        TIME: 14:32
        PUMP #5

        REGULAR 87
        10.500 GAL @ $3.679/GAL

        FUEL TOTAL $38.63
        TAX $0.00
        TOTAL $38.63

        DEBIT CARD
        ************1234
        APPROVED
        """

        # Each extractor is exercised independently on the same text.
        qty = fuel_matcher.extract_quantity(receipt_text)
        assert qty is not None
        assert qty.value == 10.5

        unit_price = fuel_matcher.extract_price_per_unit(receipt_text)
        assert unit_price is not None
        assert unit_price.value == 3.679

        fuel_grade = fuel_matcher.extract_grade(receipt_text)
        assert fuel_grade is not None
        assert fuel_grade.value == "87"

        station = fuel_matcher.extract_merchant_name(receipt_text)
        assert station is not None
        assert "SHELL" in station[0].upper()
|
||||||
339
ocr/tests/test_receipt_extraction.py
Normal file
339
ocr/tests/test_receipt_extraction.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""Tests for receipt extraction pipeline."""
|
||||||
|
import io
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.extractors.receipt_extractor import (
|
||||||
|
ReceiptExtractor,
|
||||||
|
ReceiptExtractionResult,
|
||||||
|
receipt_extractor,
|
||||||
|
)
|
||||||
|
from app.extractors.fuel_receipt import (
|
||||||
|
FuelReceiptExtractor,
|
||||||
|
FuelReceiptValidation,
|
||||||
|
fuel_receipt_extractor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReceiptExtractor:
|
||||||
|
"""Test the receipt extraction pipeline."""
|
||||||
|
|
||||||
|
def test_detect_receipt_type_fuel(self) -> None:
|
||||||
|
"""Test fuel receipt type detection."""
|
||||||
|
text = """
|
||||||
|
SHELL STATION
|
||||||
|
REGULAR 87
|
||||||
|
10.500 GAL
|
||||||
|
TOTAL $38.50
|
||||||
|
"""
|
||||||
|
extractor = ReceiptExtractor()
|
||||||
|
receipt_type = extractor._detect_receipt_type(text)
|
||||||
|
|
||||||
|
assert receipt_type == "fuel"
|
||||||
|
|
||||||
|
def test_detect_receipt_type_unknown(self) -> None:
|
||||||
|
"""Test unknown receipt type detection."""
|
||||||
|
text = """
|
||||||
|
WALMART
|
||||||
|
GROCERIES
|
||||||
|
MILK $3.99
|
||||||
|
BREAD $2.50
|
||||||
|
TOTAL $6.49
|
||||||
|
"""
|
||||||
|
extractor = ReceiptExtractor()
|
||||||
|
receipt_type = extractor._detect_receipt_type(text)
|
||||||
|
|
||||||
|
assert receipt_type == "unknown"
|
||||||
|
|
||||||
|
def test_extract_fuel_fields(self) -> None:
|
||||||
|
"""Test fuel field extraction from OCR text."""
|
||||||
|
text = """
|
||||||
|
SHELL
|
||||||
|
123 MAIN ST
|
||||||
|
DATE: 01/15/2024
|
||||||
|
REGULAR 87
|
||||||
|
10.500 GAL @ $3.679
|
||||||
|
TOTAL $38.63
|
||||||
|
"""
|
||||||
|
extractor = ReceiptExtractor()
|
||||||
|
fields = extractor._extract_fuel_fields(text)
|
||||||
|
|
||||||
|
assert "merchantName" in fields
|
||||||
|
assert "transactionDate" in fields
|
||||||
|
assert "totalAmount" in fields
|
||||||
|
assert "fuelQuantity" in fields
|
||||||
|
assert "pricePerUnit" in fields
|
||||||
|
assert "fuelGrade" in fields
|
||||||
|
|
||||||
|
assert fields["totalAmount"].value == 38.63
|
||||||
|
assert fields["fuelQuantity"].value == 10.5
|
||||||
|
assert fields["fuelGrade"].value == "87"
|
||||||
|
|
||||||
|
def test_extract_generic_fields(self) -> None:
|
||||||
|
"""Test generic field extraction."""
|
||||||
|
text = """
|
||||||
|
WALMART
|
||||||
|
01/15/2024
|
||||||
|
TOTAL $25.99
|
||||||
|
"""
|
||||||
|
extractor = ReceiptExtractor()
|
||||||
|
fields = extractor._extract_generic_fields(text)
|
||||||
|
|
||||||
|
assert "transactionDate" in fields
|
||||||
|
assert "totalAmount" in fields
|
||||||
|
assert fields["totalAmount"].value == 25.99
|
||||||
|
|
||||||
|
def test_calculated_price_per_unit(self) -> None:
|
||||||
|
"""Test price per unit calculation when not explicitly stated."""
|
||||||
|
text = """
|
||||||
|
SHELL
|
||||||
|
DATE: 01/15/2024
|
||||||
|
10.000 GAL
|
||||||
|
TOTAL $35.00
|
||||||
|
"""
|
||||||
|
extractor = ReceiptExtractor()
|
||||||
|
fields = extractor._extract_fuel_fields(text)
|
||||||
|
|
||||||
|
assert "pricePerUnit" in fields
|
||||||
|
# 35.00 / 10.000 = 3.50
|
||||||
|
assert abs(fields["pricePerUnit"].value - 3.50) < 0.01
|
||||||
|
# Calculated values should have lower confidence
|
||||||
|
assert fields["pricePerUnit"].confidence < 0.9
|
||||||
|
|
||||||
|
def test_validate_valid_data(self) -> None:
    """``validate`` returns True for a dict with a total and a date."""
    payload = {"totalAmount": 38.63, "transactionDate": "2024-01-15"}

    verdict = ReceiptExtractor().validate(payload)

    assert verdict is True
|
||||||
|
|
||||||
|
def test_validate_invalid_data(self) -> None:
    """``validate`` returns False for an empty dict and for a non-dict."""
    validator = ReceiptExtractor()

    rejected_inputs = (
        {},         # empty dict
        "invalid",  # not a dict at all
    )
    for candidate in rejected_inputs:
        assert validator.validate(candidate) is False
|
||||||
|
|
||||||
|
def test_unsupported_file_type(self) -> None:
    """Extraction fails cleanly when the detected MIME type is a PDF.

    ``_detect_mime_type`` is patched to report ``application/pdf`` so the
    test does not depend on real content sniffing; the result must be
    unsuccessful and mention the unsupported type in its error.
    """
    subject = ReceiptExtractor()
    forced_pdf = patch.object(
        subject, "_detect_mime_type", return_value="application/pdf"
    )

    with forced_pdf:
        outcome = subject.extract(b"fake pdf content")

    assert outcome.success is False
    assert "Unsupported file type" in outcome.error
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuelReceiptExtractor:
    """Test fuel receipt specialized extractor.

    Covers the two private stages of ``FuelReceiptExtractor`` exercised
    here: ``_validate_fuel_receipt`` (cross-checks and range checks on the
    extracted fields) and ``_adjust_confidences`` (scaling of per-field
    confidence according to the validation outcome).
    """

    def test_validation_success(self) -> None:
        """Test validation passes for consistent data."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        # 10.5 gal * $3.679 = $38.6295, which matches the 38.63 total.
        fields = {
            "totalAmount": ExtractedField(value=38.63, confidence=0.95),
            "fuelQuantity": ExtractedField(value=10.5, confidence=0.90),
            "pricePerUnit": ExtractedField(value=3.679, confidence=0.92),
            "fuelGrade": ExtractedField(value="87", confidence=0.88),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert validation.is_valid is True
        assert len(validation.issues) == 0
        assert validation.confidence_score == 1.0

    def test_validation_math_mismatch(self) -> None:
        """Test validation catches total != quantity * price."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=50.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=10.0, confidence=0.90),
            "pricePerUnit": ExtractedField(value=3.00, confidence=0.92),  # Should be $30
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert validation.is_valid is False
        assert any("doesn't match" in issue for issue in validation.issues)

    def test_validation_quantity_too_small(self) -> None:
        """Test validation catches too-small quantity."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=1.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=0.1, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("too small" in issue for issue in validation.issues)

    def test_validation_quantity_too_large(self) -> None:
        """Test validation warns on very large quantity."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=150.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=45.0, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("unusually large" in issue for issue in validation.issues)

    def test_validation_price_too_low(self) -> None:
        """Test validation catches too-low price."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=10.00, confidence=0.95),
            "pricePerUnit": ExtractedField(value=1.00, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("too low" in issue for issue in validation.issues)

    def test_validation_unknown_grade(self) -> None:
        """Test validation catches unknown fuel grade."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=38.00, confidence=0.95),
            "fuelGrade": ExtractedField(value="95", confidence=0.70),  # Not valid US grade
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("Unknown fuel grade" in issue for issue in validation.issues)

    def test_confidence_adjustment_boost(self) -> None:
        """Test confidence boost when validation passes."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=38.63, confidence=0.80),
        }
        validation = FuelReceiptValidation(
            is_valid=True, issues=[], confidence_score=1.0
        )

        adjusted = extractor._adjust_confidences(fields, validation)

        # Should be boosted by 1.1 (capped at 1.0). Compared approximately
        # because the expected value is a float product; an exact == would
        # break on any equivalent reordering of the arithmetic.
        assert adjusted["totalAmount"].confidence == pytest.approx(
            min(1.0, 0.80 * 1.1)
        )

    def test_confidence_adjustment_reduce(self) -> None:
        """Test confidence reduction when validation fails."""
        from app.extractors.receipt_extractor import ExtractedField

        extractor = FuelReceiptExtractor()
        fields = {
            "totalAmount": ExtractedField(value=50.00, confidence=0.90),
        }
        validation = FuelReceiptValidation(
            is_valid=False, issues=["Math mismatch"], confidence_score=0.7
        )

        adjusted = extractor._adjust_confidences(fields, validation)

        # Should be reduced by 0.7 (the validation's confidence_score).
        # Approximate comparison for the same float-product reason as above.
        assert adjusted["totalAmount"].confidence == pytest.approx(0.90 * 0.7)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReceiptPreprocessing:
    """Test receipt preprocessing integration."""

    @staticmethod
    def _white_png_bytes(size: int = 100) -> bytes:
        """Return a PNG-encoded solid-white RGB image, ``size`` x ``size`` px.

        Shared by the tests below so the in-memory fixture image is built
        in exactly one place.
        """
        from PIL import Image

        buffer = io.BytesIO()
        Image.new("RGB", (size, size), color="white").save(buffer, format="PNG")
        return buffer.getvalue()

    def test_preprocessing_result_structure(self) -> None:
        """Test preprocessing returns expected structure."""
        from app.preprocessors.receipt_preprocessor import (
            receipt_preprocessor,
            ReceiptPreprocessingResult,
        )

        # Create a simple test image (100x100 white PNG)
        image_bytes = self._white_png_bytes()

        result = receipt_preprocessor.preprocess(image_bytes)

        assert isinstance(result, ReceiptPreprocessingResult)
        assert len(result.image_bytes) > 0
        assert "loaded" in result.preprocessing_applied
        assert "grayscale" in result.preprocessing_applied
        # The reported original size must match the image that was passed in.
        assert result.original_width == 100
        assert result.original_height == 100

    def test_preprocessing_steps_applied(self) -> None:
        """Test requested preprocessing steps are recorded as applied.

        Every optional step is switched on, but only contrast, denoise and
        threshold are asserted.
        """
        from app.preprocessors.receipt_preprocessor import receipt_preprocessor

        image_bytes = self._white_png_bytes()

        result = receipt_preprocessor.preprocess(
            image_bytes,
            apply_contrast=True,
            apply_deskew=True,
            apply_denoise=True,
            apply_threshold=True,
            apply_sharpen=True,
        )

        # Check that expected steps are in the applied list.
        # NOTE(review): deskew and sharpen are requested but not asserted --
        # presumably they may be skipped for a blank image; confirm against
        # the preprocessor before adding assertions for them.
        assert "contrast" in result.preprocessing_applied
        assert "denoise" in result.preprocessing_applied
        assert "threshold" in result.preprocessing_applied
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndpointIntegration:
    """HTTP-level checks for the ``/extract/receipt`` endpoint."""

    @pytest.fixture
    def test_client(self):
        """Provide a FastAPI ``TestClient`` wrapping the application."""
        from fastapi.testclient import TestClient

        from app.main import app

        client = TestClient(app)
        return client

    def test_receipt_endpoint_exists(self, test_client) -> None:
        """The route is registered: a bare POST is rejected, not unknown."""
        # An unregistered path would answer 404; 422 (Unprocessable Entity)
        # shows the route exists and refused the request for lacking a file.
        status = test_client.post("/extract/receipt").status_code
        assert status == 422

    def test_receipt_endpoint_no_file(self, test_client) -> None:
        """A POST without any upload is answered with 422."""
        reply = test_client.post("/extract/receipt")
        assert reply.status_code == 422

    def test_receipt_endpoint_empty_file(self, test_client) -> None:
        """A zero-byte upload is answered with 400 and an "Empty file" detail."""
        upload = {"file": ("receipt.jpg", b"", "image/jpeg")}

        reply = test_client.post("/extract/receipt", files=upload)

        assert reply.status_code == 400
        assert "Empty file" in reply.json()["detail"]
|
||||||
Reference in New Issue
Block a user