"""Tests for receipt extraction pipeline.""" import io import pytest from unittest.mock import MagicMock, patch from app.extractors.receipt_extractor import ( ReceiptExtractor, ReceiptExtractionResult, receipt_extractor, ) from app.extractors.fuel_receipt import ( FuelReceiptExtractor, FuelReceiptValidation, fuel_receipt_extractor, ) class TestReceiptExtractor: """Test the receipt extraction pipeline.""" def test_detect_receipt_type_fuel(self) -> None: """Test fuel receipt type detection.""" text = """ SHELL STATION REGULAR 87 10.500 GAL TOTAL $38.50 """ extractor = ReceiptExtractor() receipt_type = extractor._detect_receipt_type(text) assert receipt_type == "fuel" def test_detect_receipt_type_unknown(self) -> None: """Test unknown receipt type detection.""" text = """ WALMART GROCERIES MILK $3.99 BREAD $2.50 TOTAL $6.49 """ extractor = ReceiptExtractor() receipt_type = extractor._detect_receipt_type(text) assert receipt_type == "unknown" def test_extract_fuel_fields(self) -> None: """Test fuel field extraction from OCR text.""" text = """ SHELL 123 MAIN ST DATE: 01/15/2024 REGULAR 87 10.500 GAL @ $3.679 TOTAL $38.63 """ extractor = ReceiptExtractor() fields = extractor._extract_fuel_fields(text) assert "merchantName" in fields assert "transactionDate" in fields assert "totalAmount" in fields assert "fuelQuantity" in fields assert "pricePerUnit" in fields assert "fuelGrade" in fields assert fields["totalAmount"].value == 38.63 assert fields["fuelQuantity"].value == 10.5 assert fields["fuelGrade"].value == "87" def test_extract_generic_fields(self) -> None: """Test generic field extraction.""" text = """ WALMART 01/15/2024 TOTAL $25.99 """ extractor = ReceiptExtractor() fields = extractor._extract_generic_fields(text) assert "transactionDate" in fields assert "totalAmount" in fields assert fields["totalAmount"].value == 25.99 def test_calculated_price_per_unit(self) -> None: """Test price per unit calculation when not explicitly stated.""" text = """ SHELL DATE: 01/15/2024 10.000 GAL TOTAL $35.00 """ extractor = ReceiptExtractor() fields = extractor._extract_fuel_fields(text) assert "pricePerUnit" in fields # 35.00 / 10.000 = 3.50 assert abs(fields["pricePerUnit"].value - 3.50) < 0.01 # Calculated values should have lower confidence assert fields["pricePerUnit"].confidence < 0.9 def test_validate_valid_data(self) -> None: """Test validation of valid receipt data.""" extractor = ReceiptExtractor() data = {"totalAmount": 38.63, "transactionDate": "2024-01-15"} assert extractor.validate(data) is True def test_validate_invalid_data(self) -> None: """Test validation of invalid receipt data.""" extractor = ReceiptExtractor() # Empty dict assert extractor.validate({}) is False # Not a dict assert extractor.validate("invalid") is False def test_unsupported_file_type(self) -> None: """Test handling of unsupported file types.""" extractor = ReceiptExtractor() with patch.object(extractor, "_detect_mime_type", return_value="application/pdf"): result = extractor.extract(b"fake pdf content") assert result.success is False assert "Unsupported file type" in result.error class TestFuelReceiptExtractor: """Test fuel receipt specialized extractor.""" def test_validation_success(self) -> None: """Test validation passes for consistent data.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=38.63, confidence=0.95), "fuelQuantity": ExtractedField(value=10.5, confidence=0.90), "pricePerUnit": ExtractedField(value=3.679, confidence=0.92), "fuelGrade": ExtractedField(value="87", confidence=0.88), } validation = extractor._validate_fuel_receipt(fields) assert validation.is_valid is True assert len(validation.issues) == 0 assert validation.confidence_score == 1.0 def test_validation_math_mismatch(self) -> None: """Test validation catches total != quantity * price.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=50.00, confidence=0.95), "fuelQuantity": ExtractedField(value=10.0, confidence=0.90), "pricePerUnit": ExtractedField(value=3.00, confidence=0.92), # Should be $30 } validation = extractor._validate_fuel_receipt(fields) assert validation.is_valid is False assert any("doesn't match" in issue for issue in validation.issues) def test_validation_quantity_too_small(self) -> None: """Test validation catches too-small quantity.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=1.00, confidence=0.95), "fuelQuantity": ExtractedField(value=0.1, confidence=0.90), } validation = extractor._validate_fuel_receipt(fields) assert any("too small" in issue for issue in validation.issues) def test_validation_quantity_too_large(self) -> None: """Test validation warns on very large quantity.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=150.00, confidence=0.95), "fuelQuantity": ExtractedField(value=45.0, confidence=0.90), } validation = extractor._validate_fuel_receipt(fields) assert any("unusually large" in issue for issue in validation.issues) def test_validation_price_too_low(self) -> None: """Test validation catches too-low price.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=10.00, confidence=0.95), "pricePerUnit": ExtractedField(value=1.00, confidence=0.90), } validation = extractor._validate_fuel_receipt(fields) assert any("too low" in issue for issue in validation.issues) def test_validation_unknown_grade(self) -> None: """Test validation catches unknown fuel grade.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=38.00, confidence=0.95), "fuelGrade": ExtractedField(value="95", confidence=0.70), # Not valid US grade } validation = extractor._validate_fuel_receipt(fields) assert any("Unknown fuel grade" in issue for issue in validation.issues) def test_confidence_adjustment_boost(self) -> None: """Test confidence boost when validation passes.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=38.63, confidence=0.80), } validation = FuelReceiptValidation( is_valid=True, issues=[], confidence_score=1.0 ) adjusted = extractor._adjust_confidences(fields, validation) # Should be boosted by 1.1 assert adjusted["totalAmount"].confidence == min(1.0, 0.80 * 1.1) def test_confidence_adjustment_reduce(self) -> None: """Test confidence reduction when validation fails.""" from app.extractors.receipt_extractor import ExtractedField extractor = FuelReceiptExtractor() fields = { "totalAmount": ExtractedField(value=50.00, confidence=0.90), } validation = FuelReceiptValidation( is_valid=False, issues=["Math mismatch"], confidence_score=0.7 ) adjusted = extractor._adjust_confidences(fields, validation) # Should be reduced by 0.7 assert adjusted["totalAmount"].confidence == 0.90 * 0.7 class TestReceiptPreprocessing: """Test receipt preprocessing integration.""" def test_preprocessing_result_structure(self) -> None: """Test preprocessing returns expected structure.""" from app.preprocessors.receipt_preprocessor import ( receipt_preprocessor, ReceiptPreprocessingResult, ) # Create a simple test image (1x1 white pixel PNG) from PIL import Image img = Image.new("RGB", (100, 100), color="white") buffer = io.BytesIO() img.save(buffer, format="PNG") image_bytes = buffer.getvalue() result = receipt_preprocessor.preprocess(image_bytes) assert isinstance(result, ReceiptPreprocessingResult) assert len(result.image_bytes) > 0 assert "loaded" in result.preprocessing_applied assert "grayscale" in result.preprocessing_applied assert result.original_width == 100 assert result.original_height == 100 def test_preprocessing_steps_applied(self) -> None: """Test all preprocessing steps are applied.""" from app.preprocessors.receipt_preprocessor import receipt_preprocessor from PIL import Image img = Image.new("RGB", (100, 100), color="white") buffer = io.BytesIO() img.save(buffer, format="PNG") image_bytes = buffer.getvalue() result = receipt_preprocessor.preprocess( image_bytes, apply_contrast=True, apply_deskew=True, apply_denoise=True, apply_threshold=True, apply_sharpen=True, ) # Check that expected steps are in the applied list assert "contrast" in result.preprocessing_applied assert "denoise" in result.preprocessing_applied assert "threshold" in result.preprocessing_applied class TestEndpointIntegration: """Test receipt extraction endpoint integration.""" @pytest.fixture def test_client(self): """Create test client for FastAPI app.""" from fastapi.testclient import TestClient from app.main import app return TestClient(app) def test_receipt_endpoint_exists(self, test_client) -> None: """Test that /extract/receipt endpoint exists.""" # Should get 422 (no file) not 404 (not found) response = test_client.post("/extract/receipt") assert response.status_code == 422 # Unprocessable Entity (missing file) def test_receipt_endpoint_no_file(self, test_client) -> None: """Test endpoint returns error when no file provided.""" response = test_client.post("/extract/receipt") assert response.status_code == 422 def test_receipt_endpoint_empty_file(self, test_client) -> None: """Test endpoint returns error for empty file.""" response = test_client.post( "/extract/receipt", files={"file": ("receipt.jpg", b"", "image/jpeg")}, ) assert response.status_code == 400 assert "Empty file" in response.json()["detail"]