All checks were successful
Deploy to Staging / Build Images (pull_request) Successful in 32s
Deploy to Staging / Deploy to Staging (pull_request) Successful in 31s
Deploy to Staging / Verify Staging (pull_request) Successful in 2m20s
Deploy to Staging / Notify Staging Ready (pull_request) Successful in 8s
Deploy to Staging / Notify Staging Failure (pull_request) Has been skipped
Implement receipt-specific OCR extraction for fuel receipts:

- Pattern matching modules for date, currency, and fuel data extraction
- Receipt-optimized image preprocessing for thermal receipts
- POST /extract/receipt endpoint with field extraction
- Confidence scoring per extracted field
- Cross-validation of fuel receipt data
- Unit tests for all pattern matchers

Extracted fields: merchantName, transactionDate, totalAmount, fuelQuantity, pricePerUnit, fuelGrade

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
340 lines
12 KiB
Python
340 lines
12 KiB
Python
"""Tests for receipt extraction pipeline."""
|
|
import io
from unittest.mock import MagicMock, patch

import pytest

from app.extractors.receipt_extractor import (
    ExtractedField,
    ReceiptExtractor,
    ReceiptExtractionResult,
    receipt_extractor,
)
from app.extractors.fuel_receipt import (
    FuelReceiptExtractor,
    FuelReceiptValidation,
    fuel_receipt_extractor,
)
|
|
|
|
|
|
class TestReceiptExtractor:
    """Test the receipt extraction pipeline."""

    @pytest.fixture
    def extractor(self):
        """Provide a fresh ReceiptExtractor for each test."""
        return ReceiptExtractor()

    def test_detect_receipt_type_fuel(self, extractor) -> None:
        """Test fuel receipt type detection."""
        text = """
        SHELL STATION
        REGULAR 87
        10.500 GAL
        TOTAL $38.50
        """
        assert extractor._detect_receipt_type(text) == "fuel"

    def test_detect_receipt_type_unknown(self, extractor) -> None:
        """Test unknown receipt type detection."""
        text = """
        WALMART
        GROCERIES
        MILK $3.99
        BREAD $2.50
        TOTAL $6.49
        """
        assert extractor._detect_receipt_type(text) == "unknown"

    def test_extract_fuel_fields(self, extractor) -> None:
        """Test fuel field extraction from OCR text."""
        text = """
        SHELL
        123 MAIN ST
        DATE: 01/15/2024
        REGULAR 87
        10.500 GAL @ $3.679
        TOTAL $38.63
        """
        fields = extractor._extract_fuel_fields(text)

        expected_keys = {
            "merchantName",
            "transactionDate",
            "totalAmount",
            "fuelQuantity",
            "pricePerUnit",
            "fuelGrade",
        }
        # A set difference reports every missing field at once instead of
        # stopping at the first absent key.
        assert expected_keys - set(fields) == set()

        assert fields["totalAmount"].value == 38.63
        assert fields["fuelQuantity"].value == 10.5
        assert fields["fuelGrade"].value == "87"

    def test_extract_generic_fields(self, extractor) -> None:
        """Test generic field extraction."""
        text = """
        WALMART
        01/15/2024
        TOTAL $25.99
        """
        fields = extractor._extract_generic_fields(text)

        assert "transactionDate" in fields
        assert "totalAmount" in fields
        assert fields["totalAmount"].value == 25.99

    def test_calculated_price_per_unit(self, extractor) -> None:
        """Test price per unit calculation when not explicitly stated."""
        text = """
        SHELL
        DATE: 01/15/2024
        10.000 GAL
        TOTAL $35.00
        """
        fields = extractor._extract_fuel_fields(text)

        assert "pricePerUnit" in fields
        # 35.00 / 10.000 = 3.50
        assert fields["pricePerUnit"].value == pytest.approx(3.50, abs=0.01)
        # Calculated values should have lower confidence
        assert fields["pricePerUnit"].confidence < 0.9

    def test_validate_valid_data(self, extractor) -> None:
        """Test validation of valid receipt data."""
        data = {"totalAmount": 38.63, "transactionDate": "2024-01-15"}
        assert extractor.validate(data) is True

    def test_validate_invalid_data(self, extractor) -> None:
        """Test validation of invalid receipt data."""
        # Empty dict
        assert extractor.validate({}) is False
        # Not a dict
        assert extractor.validate("invalid") is False

    def test_unsupported_file_type(self, extractor) -> None:
        """Test handling of unsupported file types."""
        with patch.object(extractor, "_detect_mime_type", return_value="application/pdf"):
            result = extractor.extract(b"fake pdf content")

        assert result.success is False
        # Check for a message first so a missing error fails as an assertion
        # rather than as a TypeError from `in None`.
        assert result.error is not None
        assert "Unsupported file type" in result.error
|
|
|
|
|
|
class TestFuelReceiptExtractor:
    """Test fuel receipt specialized extractor."""

    @pytest.fixture
    def extractor(self):
        """Provide a fresh FuelReceiptExtractor for each test."""
        return FuelReceiptExtractor()

    def test_validation_success(self, extractor) -> None:
        """Test validation passes for consistent data."""
        fields = {
            "totalAmount": ExtractedField(value=38.63, confidence=0.95),
            "fuelQuantity": ExtractedField(value=10.5, confidence=0.90),
            "pricePerUnit": ExtractedField(value=3.679, confidence=0.92),
            "fuelGrade": ExtractedField(value="87", confidence=0.88),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert validation.is_valid is True
        assert len(validation.issues) == 0
        assert validation.confidence_score == 1.0

    def test_validation_math_mismatch(self, extractor) -> None:
        """Test validation catches total != quantity * price."""
        fields = {
            "totalAmount": ExtractedField(value=50.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=10.0, confidence=0.90),
            "pricePerUnit": ExtractedField(value=3.00, confidence=0.92),  # Should be $30
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert validation.is_valid is False
        assert any("doesn't match" in issue for issue in validation.issues)

    def test_validation_quantity_too_small(self, extractor) -> None:
        """Test validation catches too-small quantity."""
        fields = {
            "totalAmount": ExtractedField(value=1.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=0.1, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("too small" in issue for issue in validation.issues)

    def test_validation_quantity_too_large(self, extractor) -> None:
        """Test validation warns on very large quantity."""
        fields = {
            "totalAmount": ExtractedField(value=150.00, confidence=0.95),
            "fuelQuantity": ExtractedField(value=45.0, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("unusually large" in issue for issue in validation.issues)

    def test_validation_price_too_low(self, extractor) -> None:
        """Test validation catches too-low price."""
        fields = {
            "totalAmount": ExtractedField(value=10.00, confidence=0.95),
            "pricePerUnit": ExtractedField(value=1.00, confidence=0.90),
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("too low" in issue for issue in validation.issues)

    def test_validation_unknown_grade(self, extractor) -> None:
        """Test validation catches unknown fuel grade."""
        fields = {
            "totalAmount": ExtractedField(value=38.00, confidence=0.95),
            "fuelGrade": ExtractedField(value="95", confidence=0.70),  # Not valid US grade
        }

        validation = extractor._validate_fuel_receipt(fields)

        assert any("Unknown fuel grade" in issue for issue in validation.issues)

    def test_confidence_adjustment_boost(self, extractor) -> None:
        """Test confidence boost when validation passes."""
        fields = {
            "totalAmount": ExtractedField(value=38.63, confidence=0.80),
        }
        validation = FuelReceiptValidation(is_valid=True, issues=[], confidence_score=1.0)

        adjusted = extractor._adjust_confidences(fields, validation)

        # Should be boosted by 1.1 (capped at 1.0). Compared approximately so an
        # arithmetically equivalent implementation is not rejected on float
        # rounding alone.
        assert adjusted["totalAmount"].confidence == pytest.approx(min(1.0, 0.80 * 1.1))

    def test_confidence_adjustment_reduce(self, extractor) -> None:
        """Test confidence reduction when validation fails."""
        fields = {
            "totalAmount": ExtractedField(value=50.00, confidence=0.90),
        }
        validation = FuelReceiptValidation(
            is_valid=False, issues=["Math mismatch"], confidence_score=0.7
        )

        adjusted = extractor._adjust_confidences(fields, validation)

        # Should be reduced by the validation's 0.7 confidence score.
        assert adjusted["totalAmount"].confidence == pytest.approx(0.90 * 0.7)
|
|
|
|
|
|
class TestReceiptPreprocessing:
    """Test receipt preprocessing integration."""

    @staticmethod
    def _make_png(width: int = 100, height: int = 100) -> bytes:
        """Return the encoded bytes of a solid-white RGB PNG of the given size."""
        from PIL import Image

        img = Image.new("RGB", (width, height), color="white")
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        return buffer.getvalue()

    def test_preprocessing_result_structure(self) -> None:
        """Test preprocessing returns expected structure."""
        from app.preprocessors.receipt_preprocessor import (
            receipt_preprocessor,
            ReceiptPreprocessingResult,
        )

        # A 100x100 white PNG: large enough to be a plausible image input,
        # and its known size lets us check the reported original dimensions.
        image_bytes = self._make_png(100, 100)

        result = receipt_preprocessor.preprocess(image_bytes)

        assert isinstance(result, ReceiptPreprocessingResult)
        assert len(result.image_bytes) > 0
        assert "loaded" in result.preprocessing_applied
        assert "grayscale" in result.preprocessing_applied
        assert result.original_width == 100
        assert result.original_height == 100

    def test_preprocessing_steps_applied(self) -> None:
        """Test that requested contrast, denoise and threshold steps are recorded.

        Deskew and sharpen are requested as well but are not asserted here.
        NOTE(review): presumably they may be skipped depending on image
        content (e.g. a blank image has no skew) — confirm and add assertions
        if they are unconditional.
        """
        from app.preprocessors.receipt_preprocessor import receipt_preprocessor

        image_bytes = self._make_png()

        result = receipt_preprocessor.preprocess(
            image_bytes,
            apply_contrast=True,
            apply_deskew=True,
            apply_denoise=True,
            apply_threshold=True,
            apply_sharpen=True,
        )

        # Check that expected steps are in the applied list
        assert "contrast" in result.preprocessing_applied
        assert "denoise" in result.preprocessing_applied
        assert "threshold" in result.preprocessing_applied
|
|
|
|
|
|
class TestEndpointIntegration:
    """Test receipt extraction endpoint integration."""

    @pytest.fixture
    def client(self):
        """Create a test client for the FastAPI app.

        Named ``client`` rather than ``test_client`` so it is not mistaken for
        a test function by readers.
        """
        from fastapi.testclient import TestClient

        from app.main import app

        return TestClient(app)

    def test_receipt_endpoint_exists(self, client) -> None:
        """Test that the /extract/receipt endpoint is routed for POST."""
        response = client.post("/extract/receipt")
        # 404 means no such route; 405 means the path exists but not for POST.
        # Any other status proves the endpoint is registered. The exact
        # validation status is pinned by test_receipt_endpoint_no_file.
        assert response.status_code not in (404, 405)

    def test_receipt_endpoint_no_file(self, client) -> None:
        """Test endpoint returns error when no file provided."""
        response = client.post("/extract/receipt")
        assert response.status_code == 422  # Unprocessable Entity (missing file)

    def test_receipt_endpoint_empty_file(self, client) -> None:
        """Test endpoint returns error for empty file."""
        response = client.post(
            "/extract/receipt",
            files={"file": ("receipt.jpg", b"", "image/jpeg")},
        )
        assert response.status_code == 400
        assert "Empty file" in response.json()["detail"]
|