Files
motovaultpro/ocr/app/extractors/maintenance_receipt_extractor.py
Eric Gullickson 9f51e62b94 feat: migrate MaintenanceReceiptExtractor to google-genai SDK (refs #234)
Replace vertexai.generative_models with google.genai client pattern.
Fix pre-existing bug: raise GeminiUnavailableError instead of bare
RuntimeError for missing credentials. Add proper try/except blocks
matching GeminiEngine error handling pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-28 11:17:14 -06:00

325 lines
11 KiB
Python

"""Maintenance receipt extraction with Gemini-primary and regex cross-validation.
Flow:
1. Preprocess image and OCR via receipt_extractor (PaddleOCR)
2. Send OCR text to Gemini text API for semantic field extraction
3. Cross-validate structured fields (date, cost, odometer) with regex
4. Return ReceiptExtractionResult with per-field confidence scores
"""
import json
import logging
import math
import os
import time
from typing import Any, Optional

from app.config import settings
from app.engines.gemini_engine import GeminiUnavailableError
from app.extractors.receipt_extractor import (
    ExtractedField,
    ReceiptExtractionResult,
    receipt_extractor,
)
from app.patterns.maintenance_receipt_validation import (
    MaintenanceReceiptValidation,
    maintenance_receipt_validator,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default confidence for Gemini-extracted fields before cross-validation.
# Every field starts at this value; fields that have a regex validation
# result are then scaled by that result's confidence adjustment.
DEFAULT_GEMINI_CONFIDENCE: float = 0.85
# Gemini prompt for maintenance receipt field extraction.
# The template is rendered with str.format(ocr_text=...), so `{ocr_text}` is
# its only placeholder; any literal brace added to this text must be doubled.
# The field names listed here match the property keys of
# _RECEIPT_RESPONSE_SCHEMA.
# The backslash after the opening quotes and the one on the last line are
# line continuations: the rendered prompt has no leading or trailing newline.
_RECEIPT_EXTRACTION_PROMPT: str = """\
Extract maintenance service receipt fields from the following OCR text.
For each field, extract the value if present. Return null for fields not found.
Fields to extract:
- serviceName: The maintenance service performed (e.g., "Oil Change", "Brake Pad Replacement", "Tire Rotation")
- serviceDate: Date of service in YYYY-MM-DD format
- totalCost: Total cost as a number (e.g., 89.95)
- shopName: Name of the shop or business
- laborCost: Labor cost as a number, or null if not broken out
- partsCost: Parts cost as a number, or null if not broken out
- odometerReading: Odometer/mileage reading as a number, or null if not present
- vehicleInfo: Vehicle description if present (e.g., "2022 Toyota Camry"), or null
Return a JSON object with these field names and their extracted values.
OCR Text:
---
{ocr_text}
---\
"""
# Structured-output schema passed to Gemini as ``response_schema``.
# Each (field name, JSON type) pair below becomes one nullable property.
_RECEIPT_RESPONSE_SCHEMA: dict[str, Any] = {
    "type": "OBJECT",
    "properties": {
        field: {"type": json_type, "nullable": True}
        for field, json_type in (
            ("serviceName", "STRING"),
            ("serviceDate", "STRING"),
            ("totalCost", "NUMBER"),
            ("shopName", "STRING"),
            ("laborCost", "NUMBER"),
            ("partsCost", "NUMBER"),
            ("odometerReading", "NUMBER"),
            ("vehicleInfo", "STRING"),
        )
    },
}
# Every declared property is required (its value may still be null), so the
# list is derived from the property keys rather than repeated by hand.
_RECEIPT_RESPONSE_SCHEMA["required"] = list(_RECEIPT_RESPONSE_SCHEMA["properties"])
class MaintenanceReceiptExtractor:
    """Maintenance receipt extractor using Gemini for semantic extraction.

    Wraps ``receipt_extractor`` for OCR preprocessing, then sends the raw
    text to Gemini for field extraction. Structured fields (dates, amounts,
    odometer) are cross-validated against regex patterns for confidence
    adjustment.
    """

    def __init__(self) -> None:
        # Both attributes are filled in by _get_client() on first use;
        # constructing the extractor only sets these placeholders.
        self._client: Any | None = None
        self._model_name: str = ""

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start``.

        Args:
            start: A previous ``time.monotonic()`` reading.
        """
        return int((time.monotonic() - start) * 1000)

    def extract(
        self,
        image_bytes: bytes,
        content_type: Optional[str] = None,
    ) -> ReceiptExtractionResult:
        """Extract maintenance receipt fields from an image.

        Args:
            image_bytes: Raw image or PDF bytes (HEIC, JPEG, PNG, PDF).
            content_type: MIME type (auto-detected if not provided).

        Returns:
            ReceiptExtractionResult with maintenance-specific fields. If the
            OCR step fails, its result is returned unchanged. If no text or
            no usable field is found, a result with ``success=False`` and an
            ``error`` message is returned.
        """
        # Monotonic clock: the elapsed time must not be skewed by wall-clock
        # adjustments that happen while the request is being processed.
        start_time = time.monotonic()

        # Step 1: OCR the image via receipt_extractor
        ocr_result = receipt_extractor.extract(
            image_bytes=image_bytes,
            content_type=content_type,
        )
        if not ocr_result.success:
            return ocr_result

        raw_text = ocr_result.raw_text
        # A missing (None) raw_text is treated the same as blank text.
        if not raw_text or not raw_text.strip():
            return ReceiptExtractionResult(
                success=False,
                error="No text found in image",
                processing_time_ms=self._elapsed_ms(start_time),
            )

        # Step 2: Extract fields with Gemini.
        # Deliberate best-effort boundary: any Gemini failure is logged and
        # turned into "no fields", which is reported below as a failed result
        # that still carries the OCR text.
        try:
            gemini_fields = self._extract_with_gemini(raw_text)
        except Exception as e:
            logger.warning(
                "Gemini extraction failed, falling back to OCR-only: %s", e
            )
            gemini_fields = {}

        # Step 3: Build extracted fields with base confidence
        extracted_fields = self._build_fields(gemini_fields)
        if not extracted_fields:
            return ReceiptExtractionResult(
                success=False,
                receipt_type="maintenance",
                error="No maintenance receipt fields could be extracted",
                raw_text=raw_text,
                processing_time_ms=self._elapsed_ms(start_time),
            )

        # Step 4: Cross-validate structured fields with regex
        validation = maintenance_receipt_validator.validate(gemini_fields, raw_text)
        if validation.issues:
            logger.info(
                "Maintenance receipt validation issues: %s", validation.issues
            )

        # Step 5: Adjust confidences based on cross-validation
        adjusted_fields = self._adjust_confidences(extracted_fields, validation)

        processing_time_ms = self._elapsed_ms(start_time)
        logger.info(
            "Maintenance receipt extraction: fields=%s, validated=%s, time=%sms",
            len(adjusted_fields),
            validation.is_valid,
            processing_time_ms,
        )
        return ReceiptExtractionResult(
            success=True,
            receipt_type="maintenance",
            extracted_fields=adjusted_fields,
            raw_text=raw_text,
            processing_time_ms=processing_time_ms,
        )

    def _get_client(self) -> Any:
        """Lazy-initialize google-genai Gemini client.

        Uses the same authentication pattern as GeminiEngine. The client is
        created once and cached on the instance together with the model name.

        Returns:
            The cached ``genai.Client``.

        Raises:
            GeminiUnavailableError: If the credential path is unset or not a
                file, the google-genai SDK cannot be imported, or client
                construction fails.
        """
        if self._client is not None:
            return self._client

        key_path = settings.google_vision_key_path
        # An unset/empty path is reported the same way as a missing file,
        # instead of letting os.path.isfile(None) raise TypeError.
        if not key_path or not os.path.isfile(key_path):
            raise GeminiUnavailableError(
                f"Google credential config not found at {key_path}. "
                "Set GOOGLE_VISION_KEY_PATH or mount the secret."
            )

        try:
            from google import genai  # type: ignore[import-untyped]

            # Point ADC at the WIF credential config (must be set BEFORE Client construction)
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path
            os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"

            self._client = genai.Client(
                vertexai=True,
                project=settings.vertex_ai_project,
                location=settings.vertex_ai_location,
            )
            self._model_name = settings.gemini_model
            logger.info(
                "Maintenance receipt Gemini client initialized (model=%s)",
                self._model_name,
            )
            return self._client
        except ImportError as exc:
            logger.exception("google-genai SDK import failed")
            raise GeminiUnavailableError(
                "google-genai is not installed. "
                "Install with: pip install google-genai"
            ) from exc
        except Exception as exc:
            logger.exception("Gemini authentication failed: %s", type(exc).__name__)
            raise GeminiUnavailableError(
                f"Gemini authentication failed: {exc}"
            ) from exc

    def _extract_with_gemini(self, ocr_text: str) -> dict[str, Any]:
        """Send OCR text to Gemini for semantic field extraction.

        Args:
            ocr_text: Raw OCR text from receipt image.

        Returns:
            Dictionary of field_name -> extracted_value from Gemini.

        Raises:
            GeminiUnavailableError: If the client cannot be initialized.
            ValueError: If Gemini returns no text, invalid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or JSON that
                is not an object.
        """
        client = self._get_client()

        # Imported after _get_client() so a missing SDK surfaces there as
        # GeminiUnavailableError rather than as a bare ImportError here.
        from google.genai import types  # type: ignore[import-untyped]

        prompt = _RECEIPT_EXTRACTION_PROMPT.format(ocr_text=ocr_text)
        response = client.models.generate_content(
            model=self._model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=_RECEIPT_RESPONSE_SCHEMA,
            ),
        )

        # Fail with a clear message instead of an opaque TypeError from
        # json.loads(None) when the response carries no text.
        response_text = response.text
        if not response_text:
            raise ValueError("Gemini returned an empty response")

        raw = json.loads(response_text)
        if not isinstance(raw, dict):
            raise ValueError(
                f"Gemini returned {type(raw).__name__}, expected a JSON object"
            )

        logger.info(
            "Gemini extracted maintenance fields: %s",
            [k for k, v in raw.items() if v is not None],
        )
        return raw

    def _build_fields(self, gemini_fields: dict) -> dict[str, ExtractedField]:
        """Convert Gemini response to ExtractedField dict with base confidence.

        Null values, blank strings and numeric fields that cannot be parsed
        as a finite number are dropped. Cost fields are rounded to two
        decimals; the odometer reading is truncated to an integer.

        Args:
            gemini_fields: Raw Gemini response dict.

        Returns:
            Dictionary of field_name -> ExtractedField.
        """
        fields: dict[str, ExtractedField] = {}
        for field_name, value in gemini_fields.items():
            if value is None:
                continue

            # Convert numeric values to appropriate types
            if field_name in (
                "totalCost",
                "laborCost",
                "partsCost",
                "odometerReading",
            ):
                try:
                    number = float(value)
                except (ValueError, TypeError, OverflowError):
                    # OverflowError: an integer too large to be a float.
                    continue
                # json.loads accepts NaN/Infinity; neither is a usable amount
                # or mileage, and int(inf) would raise OverflowError.
                if not math.isfinite(number):
                    continue
                if field_name == "odometerReading":
                    value = int(number)
                else:
                    value = round(number, 2)
            elif isinstance(value, str) and not value.strip():
                continue

            fields[field_name] = ExtractedField(
                value=value,
                confidence=DEFAULT_GEMINI_CONFIDENCE,
            )
        return fields

    def _adjust_confidences(
        self,
        fields: dict[str, ExtractedField],
        validation: MaintenanceReceiptValidation,
    ) -> dict[str, ExtractedField]:
        """Adjust field confidences based on cross-validation results.

        Args:
            fields: Extracted fields with base confidence.
            validation: Cross-validation results.

        Returns:
            New dict of fields whose confidences are scaled by the matching
            validation adjustment, clamped to [0.0, 1.0] and rounded to three
            decimals. Fields without a validation entry keep their confidence.
        """
        adjusted: dict[str, ExtractedField] = {}
        for name, extracted_field in fields.items():
            if name in validation.field_validations:
                fv = validation.field_validations[name]
                scaled = extracted_field.confidence * fv.confidence_adjustment
                # Clamp on both sides so the result stays a valid confidence.
                new_confidence = max(0.0, min(1.0, scaled))
            else:
                # Semantic fields (no regex validation) keep base confidence
                new_confidence = extracted_field.confidence

            adjusted[name] = ExtractedField(
                value=extracted_field.value,
                confidence=round(new_confidence, 3),
            )
        return adjusted
# Singleton instance shared by importers of this module. Creating it only
# sets placeholder attributes; the Gemini client is built lazily by
# _get_client() the first time an extraction reaches the Gemini step.
maintenance_receipt_extractor = MaintenanceReceiptExtractor()