feat: add maintenance receipt extraction pipeline with Gemini + regex (refs #150)
- New MaintenanceReceiptExtractor: Gemini-primary extraction with regex cross-validation for dates, amounts, and odometer readings
- New maintenance_receipt_validation.py: cross-validation patterns for structured field confidence adjustment
- New POST /extract/maintenance-receipt endpoint reusing ReceiptExtractionResponse model
- Per-field confidence scores (0.0-1.0) with Gemini base 0.85, boosted/reduced by regex agreement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
312
ocr/app/extractors/maintenance_receipt_extractor.py
Normal file
312
ocr/app/extractors/maintenance_receipt_extractor.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Maintenance receipt extraction with Gemini-primary and regex cross-validation.
|
||||
|
||||
Flow:
|
||||
1. Preprocess image and OCR via receipt_extractor (PaddleOCR)
|
||||
2. Send OCR text to Gemini text API for semantic field extraction
|
||||
3. Cross-validate structured fields (date, cost, odometer) with regex
|
||||
4. Return ReceiptExtractionResult with per-field confidence scores
|
||||
"""
|
||||
|
||||
import json
import logging
import math
import os
import time
from typing import Any, Optional

from app.config import settings
from app.extractors.receipt_extractor import (
    ExtractedField,
    ReceiptExtractionResult,
    receipt_extractor,
)
from app.patterns.maintenance_receipt_validation import (
    MaintenanceReceiptValidation,
    maintenance_receipt_validator,
)
|
||||
|
||||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Base confidence assigned to every field Gemini returns. Fields that have a
# regex cross-validation result get this value scaled afterwards; fields
# without one keep it unchanged.
DEFAULT_GEMINI_CONFIDENCE: float = 0.85
|
||||
|
||||
# Gemini prompt for maintenance receipt field extraction
|
||||
_RECEIPT_EXTRACTION_PROMPT = """\
|
||||
Extract maintenance service receipt fields from the following OCR text.
|
||||
|
||||
For each field, extract the value if present. Return null for fields not found.
|
||||
|
||||
Fields to extract:
|
||||
- serviceName: The maintenance service performed (e.g., "Oil Change", "Brake Pad Replacement", "Tire Rotation")
|
||||
- serviceDate: Date of service in YYYY-MM-DD format
|
||||
- totalCost: Total cost as a number (e.g., 89.95)
|
||||
- shopName: Name of the shop or business
|
||||
- laborCost: Labor cost as a number, or null if not broken out
|
||||
- partsCost: Parts cost as a number, or null if not broken out
|
||||
- odometerReading: Odometer/mileage reading as a number, or null if not present
|
||||
- vehicleInfo: Vehicle description if present (e.g., "2022 Toyota Camry"), or null
|
||||
|
||||
Return a JSON object with these field names and their extracted values.
|
||||
|
||||
OCR Text:
|
||||
---
|
||||
{ocr_text}
|
||||
---\
|
||||
"""
|
||||
|
||||
_RECEIPT_RESPONSE_SCHEMA: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"serviceName": {"type": "string", "nullable": True},
|
||||
"serviceDate": {"type": "string", "nullable": True},
|
||||
"totalCost": {"type": "number", "nullable": True},
|
||||
"shopName": {"type": "string", "nullable": True},
|
||||
"laborCost": {"type": "number", "nullable": True},
|
||||
"partsCost": {"type": "number", "nullable": True},
|
||||
"odometerReading": {"type": "number", "nullable": True},
|
||||
"vehicleInfo": {"type": "string", "nullable": True},
|
||||
},
|
||||
"required": [
|
||||
"serviceName",
|
||||
"serviceDate",
|
||||
"totalCost",
|
||||
"shopName",
|
||||
"laborCost",
|
||||
"partsCost",
|
||||
"odometerReading",
|
||||
"vehicleInfo",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class MaintenanceReceiptExtractor:
    """Maintenance receipt extractor using Gemini for semantic extraction.

    Delegates image handling and OCR to ``receipt_extractor``, sends the
    resulting raw text to Gemini for field extraction, then scales the
    confidence of each field that has a cross-validation result from
    ``maintenance_receipt_validator``.

    The Gemini model is created lazily on first use, so constructing this
    class performs no I/O.
    """

    # Fields coerced to a float rounded to two decimals.
    _COST_FIELDS = ("totalCost", "laborCost", "partsCost")
    # Field coerced to an integer.
    _ODOMETER_FIELD = "odometerReading"

    def __init__(self) -> None:
        """Create an extractor with no Gemini model loaded yet."""
        self._model: Any | None = None
        self._generation_config: Any | None = None

    def extract(
        self,
        image_bytes: bytes,
        content_type: Optional[str] = None,
    ) -> "ReceiptExtractionResult":
        """Extract maintenance receipt fields from an image.

        Args:
            image_bytes: Raw image bytes (HEIC, JPEG, PNG).
            content_type: MIME type (auto-detected if not provided).

        Returns:
            ReceiptExtractionResult with maintenance-specific fields. When the
            OCR step itself fails, the OCR result is returned unchanged. When
            no text or no usable field is found, a ``success=False`` result
            with an error message is returned.
        """
        # Monotonic clock: elapsed time must not be affected by wall-clock
        # adjustments.
        started = time.monotonic()

        # Step 1: OCR the image via receipt_extractor
        ocr_result = receipt_extractor.extract(
            image_bytes=image_bytes,
            content_type=content_type,
        )
        if not ocr_result.success:
            return ocr_result

        # Treat a missing raw_text the same as an empty one.
        raw_text = ocr_result.raw_text or ""
        if not raw_text.strip():
            return ReceiptExtractionResult(
                success=False,
                error="No text found in image",
                processing_time_ms=self._elapsed_ms(started),
            )

        # Step 2: Extract fields with Gemini. This is a deliberate best-effort
        # boundary: any failure is logged and handled as "no fields found"
        # rather than propagated to the caller.
        try:
            gemini_fields = self._extract_with_gemini(raw_text)
        except Exception as e:
            logger.warning(
                "Gemini extraction failed, falling back to OCR-only: %s", e
            )
            gemini_fields = {}

        # Step 3: Build extracted fields with base confidence
        extracted_fields = self._build_fields(gemini_fields)
        if not extracted_fields:
            return ReceiptExtractionResult(
                success=False,
                receipt_type="maintenance",
                error="No maintenance receipt fields could be extracted",
                raw_text=raw_text,
                processing_time_ms=self._elapsed_ms(started),
            )

        # Step 4: Cross-validate structured fields with regex
        validation = maintenance_receipt_validator.validate(gemini_fields, raw_text)
        if validation.issues:
            logger.info(
                "Maintenance receipt validation issues: %s", validation.issues
            )

        # Step 5: Adjust confidences based on cross-validation
        adjusted_fields = self._adjust_confidences(extracted_fields, validation)

        processing_time_ms = self._elapsed_ms(started)
        logger.info(
            "Maintenance receipt extraction: fields=%d, validated=%s, time=%dms",
            len(adjusted_fields),
            validation.is_valid,
            processing_time_ms,
        )

        return ReceiptExtractionResult(
            success=True,
            receipt_type="maintenance",
            extracted_fields=adjusted_fields,
            raw_text=raw_text,
            processing_time_ms=processing_time_ms,
        )

    @staticmethod
    def _elapsed_ms(started: float) -> int:
        """Return whole milliseconds elapsed since a ``time.monotonic()`` reading."""
        return int((time.monotonic() - started) * 1000)

    def _get_model(self) -> Any:
        """Lazy-initialize the Vertex AI Gemini model and its generation config.

        Intended to mirror the authentication setup used by GeminiEngine
        (defined elsewhere).

        Returns:
            The cached ``GenerativeModel`` instance.

        Raises:
            RuntimeError: If the credential file configured in
                ``settings.google_vision_key_path`` does not exist.
        """
        if self._model is not None:
            return self._model

        key_path = settings.google_vision_key_path
        if not os.path.isfile(key_path):
            raise RuntimeError(
                f"Google credential config not found at {key_path}. "
                "Set GOOGLE_VISION_KEY_PATH or mount the secret."
            )

        # Imported here so the Vertex AI SDK is only loaded when actually used.
        from google.cloud import aiplatform  # type: ignore[import-untyped]
        from vertexai.generative_models import (  # type: ignore[import-untyped]
            GenerationConfig,
            GenerativeModel,
        )

        # NOTE: these are process-wide environment changes.
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path
        os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"

        aiplatform.init(
            project=settings.vertex_ai_project,
            location=settings.vertex_ai_location,
        )

        model_name = settings.gemini_model
        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=_RECEIPT_RESPONSE_SCHEMA,
        )
        model = GenerativeModel(model_name)

        # Publish both attributes only after both were built, so a constructor
        # failure cannot leave a cached model paired with a missing config.
        self._generation_config = generation_config
        self._model = model

        logger.info(
            "Maintenance receipt Gemini model initialized (model=%s)",
            model_name,
        )
        return self._model

    def _extract_with_gemini(self, ocr_text: str) -> dict:
        """Send OCR text to Gemini for semantic field extraction.

        Args:
            ocr_text: Raw OCR text from receipt image.

        Returns:
            Dictionary of field_name -> extracted_value from Gemini.

        Raises:
            ValueError: If the response is not valid JSON
                (``json.JSONDecodeError``) or is valid JSON but not an object.
            RuntimeError: If the credential file is missing (see ``_get_model``).
        """
        model = self._get_model()
        prompt = _RECEIPT_EXTRACTION_PROMPT.format(ocr_text=ocr_text)

        response = model.generate_content(
            [prompt],
            generation_config=self._generation_config,
        )

        parsed = json.loads(response.text)
        if not isinstance(parsed, dict):
            # Fail with a clear message instead of an AttributeError below.
            raise ValueError(
                f"Expected a JSON object from Gemini, got {type(parsed).__name__}"
            )

        logger.info(
            "Gemini extracted maintenance fields: %s",
            [name for name, value in parsed.items() if value is not None],
        )
        return parsed

    @staticmethod
    def _to_finite_float(value: Any) -> Optional[float]:
        """Convert a value to a finite float, or return None if that is impossible.

        Rejects values that cannot be converted, integers too large for a
        float (``OverflowError``), and NaN/infinity, which ``json.loads``
        accepts by default.
        """
        try:
            number = float(value)
        except (ValueError, TypeError, OverflowError):
            return None
        return number if math.isfinite(number) else None

    def _build_fields(self, gemini_fields: dict) -> "dict[str, ExtractedField]":
        """Convert Gemini response to ExtractedField dict with base confidence.

        Null values, blank strings, and numeric fields that do not hold a
        finite number are dropped. Cost fields are rounded to two decimals and
        the odometer reading is truncated to an integer.

        Args:
            gemini_fields: Raw Gemini response dict.

        Returns:
            Dictionary of field_name -> ExtractedField.
        """
        fields: dict[str, ExtractedField] = {}

        for field_name, value in gemini_fields.items():
            if value is None:
                continue

            if field_name in self._COST_FIELDS:
                number = self._to_finite_float(value)
                if number is None:
                    continue
                value = round(number, 2)
            elif field_name == self._ODOMETER_FIELD:
                number = self._to_finite_float(value)
                if number is None:
                    continue
                value = int(number)
            elif isinstance(value, str) and not value.strip():
                continue

            fields[field_name] = ExtractedField(
                value=value,
                confidence=DEFAULT_GEMINI_CONFIDENCE,
            )

        return fields

    def _adjust_confidences(
        self,
        fields: "dict[str, ExtractedField]",
        validation: "MaintenanceReceiptValidation",
    ) -> "dict[str, ExtractedField]":
        """Adjust field confidences based on cross-validation results.

        A field with an entry in ``validation.field_validations`` has its
        confidence multiplied by that entry's ``confidence_adjustment``; other
        fields keep their confidence. The result is clamped to [0.0, 1.0] and
        rounded to three decimals.

        Args:
            fields: Extracted fields with base confidence.
            validation: Cross-validation results.

        Returns:
            New dictionary of fields with adjusted confidences; the input
            fields are not modified.
        """
        field_validations = validation.field_validations
        adjusted: dict[str, ExtractedField] = {}

        for name, extracted_field in fields.items():
            confidence = extracted_field.confidence
            if name in field_validations:
                confidence *= field_validations[name].confidence_adjustment

            # Keep the score inside the documented 0.0-1.0 range in both
            # directions, whatever adjustment factor the validator supplied.
            confidence = max(0.0, min(1.0, confidence))

            adjusted[name] = ExtractedField(
                value=extracted_field.value,
                confidence=round(confidence, 3),
            )

        return adjusted
|
||||
|
||||
|
||||
# Singleton instance shared by importers of this module. Construction only
# sets the model and generation config to None; the Gemini model is built on
# first use.
maintenance_receipt_extractor = MaintenanceReceiptExtractor()
|
||||
Reference in New Issue
Block a user