feat: add receipt OCR pipeline (refs #69)
All checks were successful
Deploy to Staging / Build Images (pull_request) Successful in 32s
Deploy to Staging / Deploy to Staging (pull_request) Successful in 31s
Deploy to Staging / Verify Staging (pull_request) Successful in 2m20s
Deploy to Staging / Notify Staging Ready (pull_request) Successful in 8s
Deploy to Staging / Notify Staging Failure (pull_request) Has been skipped
All checks were successful
Deploy to Staging / Build Images (pull_request) Successful in 32s
Deploy to Staging / Deploy to Staging (pull_request) Successful in 31s
Deploy to Staging / Verify Staging (pull_request) Successful in 2m20s
Deploy to Staging / Notify Staging Ready (pull_request) Successful in 8s
Deploy to Staging / Notify Staging Failure (pull_request) Has been skipped
Implement receipt-specific OCR extraction for fuel receipts: - Pattern matching modules for date, currency, and fuel data extraction - Receipt-optimized image preprocessing for thermal receipts - POST /extract/receipt endpoint with field extraction - Confidence scoring per extracted field - Cross-validation of fuel receipt data - Unit tests for all pattern matchers Extracted fields: merchantName, transactionDate, totalAmount, fuelQuantity, pricePerUnit, fuelGrade Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,23 @@
|
||||
"""Extractors package for domain-specific OCR extraction."""
|
||||
from app.extractors.base import BaseExtractor, ExtractionResult
|
||||
from app.extractors.vin_extractor import VinExtractor, vin_extractor
|
||||
from app.extractors.receipt_extractor import (
|
||||
ReceiptExtractor,
|
||||
receipt_extractor,
|
||||
ReceiptExtractionResult,
|
||||
ExtractedField,
|
||||
)
|
||||
from app.extractors.fuel_receipt import FuelReceiptExtractor, fuel_receipt_extractor
|
||||
|
||||
__all__ = [
|
||||
"BaseExtractor",
|
||||
"ExtractionResult",
|
||||
"VinExtractor",
|
||||
"vin_extractor",
|
||||
"ReceiptExtractor",
|
||||
"receipt_extractor",
|
||||
"ReceiptExtractionResult",
|
||||
"ExtractedField",
|
||||
"FuelReceiptExtractor",
|
||||
"fuel_receipt_extractor",
|
||||
]
|
||||
|
||||
193
ocr/app/extractors/fuel_receipt.py
Normal file
193
ocr/app/extractors/fuel_receipt.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Fuel receipt specialization with validation and cross-checking."""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from app.extractors.receipt_extractor import (
|
||||
ExtractedField,
|
||||
ReceiptExtractionResult,
|
||||
receipt_extractor,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuelReceiptValidation:
|
||||
"""Validation result for fuel receipt extraction."""
|
||||
|
||||
is_valid: bool
|
||||
issues: list[str]
|
||||
confidence_score: float
|
||||
|
||||
|
||||
class FuelReceiptExtractor:
|
||||
"""Specialized fuel receipt extractor with cross-validation.
|
||||
|
||||
Provides additional validation and confidence scoring specific
|
||||
to fuel receipts by cross-checking extracted values.
|
||||
"""
|
||||
|
||||
# Expected fields for a complete fuel receipt
|
||||
REQUIRED_FIELDS = ["totalAmount"]
|
||||
OPTIONAL_FIELDS = [
|
||||
"merchantName",
|
||||
"transactionDate",
|
||||
"fuelQuantity",
|
||||
"pricePerUnit",
|
||||
"fuelGrade",
|
||||
]
|
||||
|
||||
def extract(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
content_type: Optional[str] = None,
|
||||
) -> ReceiptExtractionResult:
|
||||
"""
|
||||
Extract fuel receipt data with validation.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
content_type: MIME type
|
||||
|
||||
Returns:
|
||||
ReceiptExtractionResult with fuel-specific extraction
|
||||
"""
|
||||
# Use base receipt extractor with fuel hint
|
||||
result = receipt_extractor.extract(
|
||||
image_bytes=image_bytes,
|
||||
content_type=content_type,
|
||||
receipt_type="fuel",
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
# Validate and cross-check fuel fields
|
||||
validation = self._validate_fuel_receipt(result.extracted_fields)
|
||||
|
||||
if validation.issues:
|
||||
logger.warning(
|
||||
f"Fuel receipt validation issues: {validation.issues}"
|
||||
)
|
||||
|
||||
# Update overall confidence based on validation
|
||||
result.extracted_fields = self._adjust_confidences(
|
||||
result.extracted_fields, validation
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _validate_fuel_receipt(
|
||||
self, fields: dict[str, ExtractedField]
|
||||
) -> FuelReceiptValidation:
|
||||
"""
|
||||
Validate extracted fuel receipt fields.
|
||||
|
||||
Cross-checks:
|
||||
- total = quantity * price per unit (within tolerance)
|
||||
- quantity is reasonable for a single fill-up
|
||||
- price per unit is within market range
|
||||
|
||||
Args:
|
||||
fields: Extracted fields
|
||||
|
||||
Returns:
|
||||
FuelReceiptValidation with issues and confidence
|
||||
"""
|
||||
issues = []
|
||||
confidence_score = 1.0
|
||||
|
||||
# Check required fields
|
||||
for field_name in self.REQUIRED_FIELDS:
|
||||
if field_name not in fields:
|
||||
issues.append(f"Missing required field: {field_name}")
|
||||
confidence_score *= 0.5
|
||||
|
||||
# Cross-validate total = quantity * price
|
||||
if all(
|
||||
f in fields for f in ["totalAmount", "fuelQuantity", "pricePerUnit"]
|
||||
):
|
||||
total = fields["totalAmount"].value
|
||||
quantity = fields["fuelQuantity"].value
|
||||
price = fields["pricePerUnit"].value
|
||||
|
||||
calculated_total = quantity * price
|
||||
tolerance = 0.10 # Allow 10% tolerance for rounding
|
||||
|
||||
if abs(total - calculated_total) > total * tolerance:
|
||||
issues.append(
|
||||
f"Total ({total}) doesn't match quantity ({quantity}) * "
|
||||
f"price ({price}) = {calculated_total:.2f}"
|
||||
)
|
||||
confidence_score *= 0.7
|
||||
|
||||
# Validate quantity is reasonable
|
||||
if "fuelQuantity" in fields:
|
||||
quantity = fields["fuelQuantity"].value
|
||||
if quantity < 0.5:
|
||||
issues.append(f"Fuel quantity too small: {quantity}")
|
||||
confidence_score *= 0.6
|
||||
elif quantity > 40: # 40 gallons is very large tank
|
||||
issues.append(f"Fuel quantity unusually large: {quantity}")
|
||||
confidence_score *= 0.8
|
||||
|
||||
# Validate price is reasonable (current US market range)
|
||||
if "pricePerUnit" in fields:
|
||||
price = fields["pricePerUnit"].value
|
||||
if price < 1.50:
|
||||
issues.append(f"Price per unit too low: ${price}")
|
||||
confidence_score *= 0.7
|
||||
elif price > 7.00:
|
||||
issues.append(f"Price per unit unusually high: ${price}")
|
||||
confidence_score *= 0.8
|
||||
|
||||
# Validate fuel grade
|
||||
if "fuelGrade" in fields:
|
||||
grade = fields["fuelGrade"].value
|
||||
valid_grades = ["87", "89", "91", "93", "DIESEL", "E85"]
|
||||
if grade not in valid_grades:
|
||||
issues.append(f"Unknown fuel grade: {grade}")
|
||||
confidence_score *= 0.9
|
||||
|
||||
is_valid = len(issues) == 0
|
||||
return FuelReceiptValidation(
|
||||
is_valid=is_valid,
|
||||
issues=issues,
|
||||
confidence_score=confidence_score,
|
||||
)
|
||||
|
||||
def _adjust_confidences(
|
||||
self,
|
||||
fields: dict[str, ExtractedField],
|
||||
validation: FuelReceiptValidation,
|
||||
) -> dict[str, ExtractedField]:
|
||||
"""
|
||||
Adjust field confidences based on validation.
|
||||
|
||||
Args:
|
||||
fields: Extracted fields
|
||||
validation: Validation result
|
||||
|
||||
Returns:
|
||||
Fields with adjusted confidences
|
||||
"""
|
||||
if validation.is_valid:
|
||||
# Boost confidences when cross-validation passes
|
||||
boost = 1.1
|
||||
else:
|
||||
# Reduce confidences when there are issues
|
||||
boost = validation.confidence_score
|
||||
|
||||
adjusted = {}
|
||||
for name, field in fields.items():
|
||||
adjusted[name] = ExtractedField(
|
||||
value=field.value,
|
||||
confidence=min(1.0, field.confidence * boost),
|
||||
)
|
||||
|
||||
return adjusted
|
||||
|
||||
|
||||
# Singleton instance
|
||||
fuel_receipt_extractor = FuelReceiptExtractor()
|
||||
345
ocr/app/extractors/receipt_extractor.py
Normal file
345
ocr/app/extractors/receipt_extractor.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Receipt-specific OCR extractor with field extraction."""
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import magic
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
from pillow_heif import register_heif_opener
|
||||
|
||||
from app.config import settings
|
||||
from app.extractors.base import BaseExtractor
|
||||
from app.preprocessors.receipt_preprocessor import receipt_preprocessor
|
||||
from app.patterns import currency_matcher, date_matcher, fuel_matcher
|
||||
|
||||
# Register HEIF/HEIC opener
|
||||
register_heif_opener()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedField:
|
||||
"""A single extracted field with confidence."""
|
||||
|
||||
value: Any
|
||||
confidence: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReceiptExtractionResult:
|
||||
"""Result of receipt extraction."""
|
||||
|
||||
success: bool
|
||||
receipt_type: str = "unknown"
|
||||
extracted_fields: dict[str, ExtractedField] = field(default_factory=dict)
|
||||
raw_text: str = ""
|
||||
processing_time_ms: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ReceiptExtractor(BaseExtractor):
|
||||
"""Receipt-specific OCR extractor for fuel and general receipts."""
|
||||
|
||||
# Supported MIME types
|
||||
SUPPORTED_TYPES = {
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/heic",
|
||||
"image/heif",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize receipt extractor."""
|
||||
pytesseract.pytesseract.tesseract_cmd = settings.tesseract_cmd
|
||||
|
||||
def extract(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
content_type: Optional[str] = None,
|
||||
receipt_type: Optional[str] = None,
|
||||
) -> ReceiptExtractionResult:
|
||||
"""
|
||||
Extract data from a receipt image.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes (HEIC, JPEG, PNG)
|
||||
content_type: MIME type (auto-detected if not provided)
|
||||
receipt_type: Hint for receipt type ("fuel" for specialized extraction)
|
||||
|
||||
Returns:
|
||||
ReceiptExtractionResult with extracted fields
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Detect content type if not provided
|
||||
if not content_type:
|
||||
content_type = self._detect_mime_type(image_bytes)
|
||||
|
||||
# Validate content type
|
||||
if content_type not in self.SUPPORTED_TYPES:
|
||||
return ReceiptExtractionResult(
|
||||
success=False,
|
||||
error=f"Unsupported file type: {content_type}",
|
||||
processing_time_ms=int((time.time() - start_time) * 1000),
|
||||
)
|
||||
|
||||
try:
|
||||
# Apply receipt-optimized preprocessing
|
||||
preprocessing_result = receipt_preprocessor.preprocess(image_bytes)
|
||||
preprocessed_bytes = preprocessing_result.image_bytes
|
||||
|
||||
# Perform OCR
|
||||
raw_text = self._perform_ocr(preprocessed_bytes)
|
||||
|
||||
if not raw_text.strip():
|
||||
# Try with less aggressive preprocessing
|
||||
preprocessing_result = receipt_preprocessor.preprocess(
|
||||
image_bytes,
|
||||
apply_threshold=False,
|
||||
)
|
||||
preprocessed_bytes = preprocessing_result.image_bytes
|
||||
raw_text = self._perform_ocr(preprocessed_bytes)
|
||||
|
||||
if not raw_text.strip():
|
||||
return ReceiptExtractionResult(
|
||||
success=False,
|
||||
error="No text found in image",
|
||||
processing_time_ms=int((time.time() - start_time) * 1000),
|
||||
)
|
||||
|
||||
# Detect receipt type if not specified
|
||||
detected_type = receipt_type or self._detect_receipt_type(raw_text)
|
||||
|
||||
# Extract fields based on receipt type
|
||||
if detected_type == "fuel":
|
||||
extracted_fields = self._extract_fuel_fields(raw_text)
|
||||
else:
|
||||
extracted_fields = self._extract_generic_fields(raw_text)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"Receipt extraction: type={detected_type}, "
|
||||
f"fields={len(extracted_fields)}, "
|
||||
f"time={processing_time_ms}ms"
|
||||
)
|
||||
|
||||
return ReceiptExtractionResult(
|
||||
success=True,
|
||||
receipt_type=detected_type,
|
||||
extracted_fields=extracted_fields,
|
||||
raw_text=raw_text,
|
||||
processing_time_ms=processing_time_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Receipt extraction failed: {e}", exc_info=True)
|
||||
return ReceiptExtractionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
processing_time_ms=int((time.time() - start_time) * 1000),
|
||||
)
|
||||
|
||||
def _detect_mime_type(self, file_bytes: bytes) -> str:
|
||||
"""Detect MIME type using python-magic."""
|
||||
mime = magic.Magic(mime=True)
|
||||
detected = mime.from_buffer(file_bytes)
|
||||
return detected or "application/octet-stream"
|
||||
|
||||
def _perform_ocr(self, image_bytes: bytes, psm: int = 6) -> str:
|
||||
"""
|
||||
Perform OCR on preprocessed image.
|
||||
|
||||
Args:
|
||||
image_bytes: Preprocessed image bytes
|
||||
psm: Tesseract page segmentation mode
|
||||
4 = Assume single column of text
|
||||
6 = Uniform block of text
|
||||
|
||||
Returns:
|
||||
Raw OCR text
|
||||
"""
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
# Configure Tesseract for receipt OCR
|
||||
# PSM 4 works well for columnar receipt text
|
||||
config = f"--psm {psm}"
|
||||
|
||||
return pytesseract.image_to_string(image, config=config)
|
||||
|
||||
def _detect_receipt_type(self, text: str) -> str:
|
||||
"""
|
||||
Detect receipt type based on content.
|
||||
|
||||
Args:
|
||||
text: OCR text
|
||||
|
||||
Returns:
|
||||
Receipt type: "fuel", "retail", or "unknown"
|
||||
"""
|
||||
text_upper = text.upper()
|
||||
|
||||
# Fuel receipt indicators
|
||||
fuel_keywords = [
|
||||
"GALLON", "GAL", "FUEL", "GAS", "DIESEL", "UNLEADED",
|
||||
"REGULAR", "PREMIUM", "OCTANE", "PPG", "PUMP",
|
||||
]
|
||||
|
||||
fuel_score = sum(1 for kw in fuel_keywords if kw in text_upper)
|
||||
|
||||
# Check for known gas stations
|
||||
if fuel_matcher.extract_merchant_name(text):
|
||||
merchant, _ = fuel_matcher.extract_merchant_name(text)
|
||||
if any(
|
||||
station in merchant.upper()
|
||||
for station in fuel_matcher.STATION_NAMES
|
||||
):
|
||||
fuel_score += 3
|
||||
|
||||
if fuel_score >= 2:
|
||||
return "fuel"
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _extract_fuel_fields(self, text: str) -> dict[str, ExtractedField]:
|
||||
"""
|
||||
Extract fuel-specific fields from receipt text.
|
||||
|
||||
Args:
|
||||
text: OCR text
|
||||
|
||||
Returns:
|
||||
Dictionary of extracted fields
|
||||
"""
|
||||
fields: dict[str, ExtractedField] = {}
|
||||
|
||||
# Extract merchant name
|
||||
merchant_result = fuel_matcher.extract_merchant_name(text)
|
||||
if merchant_result:
|
||||
merchant_name, confidence = merchant_result
|
||||
fields["merchantName"] = ExtractedField(
|
||||
value=merchant_name,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
# Extract transaction date
|
||||
date_match = date_matcher.extract_best_date(text)
|
||||
if date_match:
|
||||
fields["transactionDate"] = ExtractedField(
|
||||
value=date_match.value,
|
||||
confidence=date_match.confidence,
|
||||
)
|
||||
|
||||
# Extract total amount
|
||||
total_match = currency_matcher.extract_total(text)
|
||||
if total_match:
|
||||
fields["totalAmount"] = ExtractedField(
|
||||
value=total_match.value,
|
||||
confidence=total_match.confidence,
|
||||
)
|
||||
|
||||
# Extract fuel quantity
|
||||
quantity_match = fuel_matcher.extract_quantity(text)
|
||||
if quantity_match:
|
||||
fields["fuelQuantity"] = ExtractedField(
|
||||
value=quantity_match.value,
|
||||
confidence=quantity_match.confidence,
|
||||
)
|
||||
|
||||
# Extract price per unit
|
||||
price_match = fuel_matcher.extract_price_per_unit(text)
|
||||
if price_match:
|
||||
fields["pricePerUnit"] = ExtractedField(
|
||||
value=price_match.value,
|
||||
confidence=price_match.confidence,
|
||||
)
|
||||
|
||||
# Extract fuel grade
|
||||
grade_match = fuel_matcher.extract_grade(text)
|
||||
if grade_match:
|
||||
fields["fuelGrade"] = ExtractedField(
|
||||
value=grade_match.value,
|
||||
confidence=grade_match.confidence,
|
||||
)
|
||||
|
||||
# Calculate derived values if we have enough data
|
||||
if "totalAmount" in fields and "fuelQuantity" in fields:
|
||||
if "pricePerUnit" not in fields:
|
||||
# Calculate price per unit from total and quantity
|
||||
calculated_price = (
|
||||
fields["totalAmount"].value / fields["fuelQuantity"].value
|
||||
)
|
||||
# Only use if reasonable
|
||||
if 1.0 <= calculated_price <= 10.0:
|
||||
fields["pricePerUnit"] = ExtractedField(
|
||||
value=round(calculated_price, 3),
|
||||
confidence=min(
|
||||
fields["totalAmount"].confidence,
|
||||
fields["fuelQuantity"].confidence,
|
||||
)
|
||||
* 0.8, # Lower confidence for calculated value
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
def _extract_generic_fields(self, text: str) -> dict[str, ExtractedField]:
|
||||
"""
|
||||
Extract generic fields from receipt text.
|
||||
|
||||
Args:
|
||||
text: OCR text
|
||||
|
||||
Returns:
|
||||
Dictionary of extracted fields
|
||||
"""
|
||||
fields: dict[str, ExtractedField] = {}
|
||||
|
||||
# Extract date
|
||||
date_match = date_matcher.extract_best_date(text)
|
||||
if date_match:
|
||||
fields["transactionDate"] = ExtractedField(
|
||||
value=date_match.value,
|
||||
confidence=date_match.confidence,
|
||||
)
|
||||
|
||||
# Extract total amount
|
||||
total_match = currency_matcher.extract_total(text)
|
||||
if total_match:
|
||||
fields["totalAmount"] = ExtractedField(
|
||||
value=total_match.value,
|
||||
confidence=total_match.confidence,
|
||||
)
|
||||
|
||||
# Try to get merchant from first line
|
||||
lines = [l.strip() for l in text.split("\n") if l.strip()]
|
||||
if lines:
|
||||
fields["merchantName"] = ExtractedField(
|
||||
value=lines[0][:50],
|
||||
confidence=0.40,
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
def validate(self, data: Any) -> bool:
|
||||
"""
|
||||
Validate extracted receipt data.
|
||||
|
||||
Args:
|
||||
data: Extracted data to validate
|
||||
|
||||
Returns:
|
||||
True if data has minimum required fields
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
# Minimum: must have at least total amount or date
|
||||
return "totalAmount" in data or "transactionDate" in data
|
||||
|
||||
|
||||
# Singleton instance
|
||||
receipt_extractor = ReceiptExtractor()
|
||||
Reference in New Issue
Block a user