"""Fuel receipt specialization with validation and cross-checking."""

import logging
import math
from dataclasses import dataclass
from typing import Any, Optional

from app.extractors.receipt_extractor import (
    ExtractedField,
    ReceiptExtractionResult,
    receipt_extractor,
)

logger = logging.getLogger(__name__)


@dataclass
class FuelReceiptValidation:
    """Validation result for fuel receipt extraction."""

    # True when no validation issue was recorded.
    is_valid: bool
    # Human-readable description of every failed check.
    issues: list[str]
    # Product of the penalty multipliers of all failed checks (1.0 = clean).
    confidence_score: float


class FuelReceiptExtractor:
    """Specialized fuel receipt extractor with cross-validation.

    Provides additional validation and confidence scoring specific to fuel
    receipts by cross-checking extracted values.
    """

    # Expected fields for a complete fuel receipt
    REQUIRED_FIELDS = ["totalAmount"]
    OPTIONAL_FIELDS = [
        "merchantName",
        "transactionDate",
        "fuelQuantity",
        "pricePerUnit",
        "fuelGrade",
    ]

    # Relative tolerance allowed between the printed total and
    # quantity * price, to absorb rounding of the printed values.
    TOTAL_TOLERANCE = 0.10
    # Plausible quantity range for a single fill-up; 40 gallons is already a
    # very large tank.
    MIN_QUANTITY = 0.5
    MAX_QUANTITY = 40.0
    # Plausible price-per-unit range for the current US market.
    MIN_PRICE = 1.50
    MAX_PRICE = 7.00
    # Recognised grades, compared after upper-casing and stripping whitespace.
    VALID_GRADES = frozenset({"87", "89", "91", "93", "DIESEL", "E85"})
    # Confidence multiplier applied to all fields when validation passes
    # (results are capped at 1.0).
    VALID_BOOST = 1.1
    # Penalty for a numeric field whose value cannot be read as a number.
    NON_NUMERIC_PENALTY = 0.7

    def extract(
        self,
        image_bytes: bytes,
        content_type: Optional[str] = None,
    ) -> ReceiptExtractionResult:
        """
        Extract fuel receipt data with validation.

        Runs the shared receipt extractor with a ``"fuel"`` hint, then
        cross-checks the extracted fields and rescales their confidences.
        Unsuccessful extractions are returned unchanged.

        Args:
            image_bytes: Raw image bytes
            content_type: MIME type

        Returns:
            ReceiptExtractionResult with fuel-specific extraction
        """
        result = receipt_extractor.extract(
            image_bytes=image_bytes,
            content_type=content_type,
            receipt_type="fuel",
        )

        if not result.success:
            return result

        validation = self._validate_fuel_receipt(result.extracted_fields)

        if validation.issues:
            # Lazy %-formatting: the message is only built if WARNING is on.
            logger.warning(
                "Fuel receipt validation issues: %s", validation.issues
            )

        # Update overall confidence based on validation
        result.extracted_fields = self._adjust_confidences(
            result.extracted_fields, validation
        )

        return result

    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Return *value* as a finite float, or None if it is not numeric.

        Accepts anything ``float()`` accepts (ints, floats, Decimals, numeric
        strings). Returns None for None, booleans, NaN/infinity, and values
        ``float()`` rejects.
        """
        if value is None or isinstance(value, bool):
            return None
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        # NaN would make every range comparison False and silently pass.
        return number if math.isfinite(number) else None

    @staticmethod
    def _normalize_grade(grade: Any) -> str:
        """Normalise a fuel grade for comparison (e.g. ' diesel' -> 'DIESEL')."""
        return str(grade).strip().upper()

    def _validate_fuel_receipt(
        self, fields: dict[str, ExtractedField]
    ) -> FuelReceiptValidation:
        """
        Validate extracted fuel receipt fields.

        Cross-checks:
        - total = quantity * price per unit (within TOTAL_TOLERANCE)
        - quantity is reasonable for a single fill-up
        - price per unit is within market range
        - fuel grade is one of VALID_GRADES

        Numeric fields whose value cannot be read as a number are reported
        as issues and excluded from the numeric checks, rather than raising.

        Args:
            fields: Extracted fields

        Returns:
            FuelReceiptValidation with issues and confidence
        """
        issues: list[str] = []
        confidence_score = 1.0

        # Check required fields
        for field_name in self.REQUIRED_FIELDS:
            if field_name not in fields:
                issues.append(f"Missing required field: {field_name}")
                confidence_score *= 0.5

        # Coerce each present numeric field once; None marks absent or
        # unreadable so the checks below can skip it.
        numeric: dict[str, Optional[float]] = {}
        for name in ("totalAmount", "fuelQuantity", "pricePerUnit"):
            if name not in fields:
                continue
            raw_value = fields[name].value
            number = self._to_float(raw_value)
            if number is None:
                issues.append(f"Non-numeric value for {name}: {raw_value!r}")
                confidence_score *= self.NON_NUMERIC_PENALTY
            numeric[name] = number

        total = numeric.get("totalAmount")
        quantity = numeric.get("fuelQuantity")
        price = numeric.get("pricePerUnit")

        # Cross-validate total = quantity * price
        if total is not None and quantity is not None and price is not None:
            calculated_total = quantity * price
            # abs(total) so a negative total (e.g. a refund) is not always
            # flagged.
            if abs(total - calculated_total) > abs(total) * self.TOTAL_TOLERANCE:
                issues.append(
                    f"Total ({fields['totalAmount'].value}) doesn't match "
                    f"quantity ({fields['fuelQuantity'].value}) * "
                    f"price ({fields['pricePerUnit'].value}) = "
                    f"{calculated_total:.2f}"
                )
                confidence_score *= 0.7

        # Validate quantity is reasonable
        if quantity is not None:
            shown = fields["fuelQuantity"].value
            if quantity < self.MIN_QUANTITY:
                issues.append(f"Fuel quantity too small: {shown}")
                confidence_score *= 0.6
            elif quantity > self.MAX_QUANTITY:
                issues.append(f"Fuel quantity unusually large: {shown}")
                confidence_score *= 0.8

        # Validate price is reasonable (current US market range)
        if price is not None:
            shown = fields["pricePerUnit"].value
            if price < self.MIN_PRICE:
                issues.append(f"Price per unit too low: ${shown}")
                confidence_score *= 0.7
            elif price > self.MAX_PRICE:
                issues.append(f"Price per unit unusually high: ${shown}")
                confidence_score *= 0.8

        # Validate fuel grade (case/whitespace-insensitive, accepts int 87)
        if "fuelGrade" in fields:
            grade = fields["fuelGrade"].value
            if self._normalize_grade(grade) not in self.VALID_GRADES:
                issues.append(f"Unknown fuel grade: {grade}")
                confidence_score *= 0.9

        return FuelReceiptValidation(
            is_valid=not issues,
            issues=issues,
            confidence_score=confidence_score,
        )

    def _adjust_confidences(
        self,
        fields: dict[str, ExtractedField],
        validation: FuelReceiptValidation,
    ) -> dict[str, ExtractedField]:
        """
        Adjust field confidences based on validation.

        When validation passes, every confidence is multiplied by VALID_BOOST.
        Otherwise it is multiplied by ``validation.confidence_score``.
        Results are capped at 1.0. New ExtractedField objects are returned;
        the input objects are not mutated.

        Args:
            fields: Extracted fields
            validation: Validation result

        Returns:
            Fields with adjusted confidences
        """
        if validation.is_valid:
            # Boost confidences when cross-validation passes
            factor = self.VALID_BOOST
        else:
            # Reduce confidences when there are issues
            factor = validation.confidence_score

        # NOTE(review): only value/confidence are carried over; if
        # ExtractedField gains other attributes they would be dropped here —
        # confirm against its definition.
        return {
            name: ExtractedField(
                value=field.value,
                confidence=min(1.0, field.confidence * factor),
            )
            for name, field in fields.items()
        }


# Singleton instance
fuel_receipt_extractor = FuelReceiptExtractor()