feat: Google Vision primary OCR with Auth0 WIF and monthly usage cap (#127) #128
@@ -21,7 +21,12 @@ class Settings:
|
||||
os.getenv("OCR_FALLBACK_THRESHOLD", "0.6")
|
||||
)
|
||||
self.google_vision_key_path: str = os.getenv(
|
||||
"GOOGLE_VISION_KEY_PATH", "/run/secrets/google-vision-key.json"
|
||||
"GOOGLE_VISION_KEY_PATH", "/run/secrets/google-wif-config.json"
|
||||
)
|
||||
|
||||
# Google Vision monthly usage cap (requests per calendar month)
|
||||
self.vision_monthly_limit: int = int(
|
||||
os.getenv("VISION_MONTHLY_LIMIT", "1000")
|
||||
)
|
||||
|
||||
# Redis configuration for job queue
|
||||
|
||||
@@ -15,8 +15,8 @@ from app.engines.base_engine import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default path for Google Vision service account key (Docker secret mount)
|
||||
_DEFAULT_KEY_PATH = "/run/secrets/google-vision-key.json"
|
||||
# Default path for Google WIF credential config (Docker secret mount)
|
||||
_DEFAULT_KEY_PATH = "/run/secrets/google-wif-config.json"
|
||||
|
||||
|
||||
class CloudEngine(OcrEngine):
|
||||
@@ -42,25 +42,33 @@ class CloudEngine(OcrEngine):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
"""Create the Vision client on first use."""
|
||||
"""Create the Vision client on first use.
|
||||
|
||||
Uses Application Default Credentials (ADC) pointed at a WIF
|
||||
credential config file. The WIF config references an executable
|
||||
that fetches an Auth0 M2M JWT.
|
||||
"""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
# Verify credentials file exists
|
||||
# Verify credentials config exists
|
||||
if not os.path.isfile(self._key_path):
|
||||
raise EngineUnavailableError(
|
||||
f"Google Vision key not found at {self._key_path}. "
|
||||
f"Google Vision credential config not found at {self._key_path}. "
|
||||
"Set GOOGLE_VISION_KEY_PATH or mount the secret."
|
||||
)
|
||||
|
||||
try:
|
||||
from google.cloud import vision # type: ignore[import-untyped]
|
||||
|
||||
# Point the SDK at the service account key
|
||||
# Point ADC at the WIF credential config
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._key_path
|
||||
# Required for executable-sourced credentials
|
||||
os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"
|
||||
self._client = vision.ImageAnnotatorClient()
|
||||
logger.info(
|
||||
"Google Vision client initialized (key: %s)", self._key_path
|
||||
"Google Vision client initialized via WIF (config: %s)",
|
||||
self._key_path,
|
||||
)
|
||||
return self._client
|
||||
except ImportError as exc:
|
||||
|
||||
@@ -76,11 +76,18 @@ def create_engine(engine_name: str | None = None) -> OcrEngine:
|
||||
from app.engines.hybrid_engine import HybridEngine
|
||||
|
||||
threshold = settings.ocr_fallback_threshold
|
||||
hybrid = HybridEngine(primary=primary, fallback=fallback, threshold=threshold)
|
||||
monthly_limit = settings.vision_monthly_limit
|
||||
hybrid = HybridEngine(
|
||||
primary=primary,
|
||||
fallback=fallback,
|
||||
threshold=threshold,
|
||||
monthly_limit=monthly_limit,
|
||||
)
|
||||
logger.info(
|
||||
"Created hybrid engine: primary=%s, fallback=%s, threshold=%.2f",
|
||||
"Created hybrid engine: primary=%s, fallback=%s, threshold=%.2f, vision_limit=%d",
|
||||
name,
|
||||
fallback_name,
|
||||
threshold,
|
||||
monthly_limit,
|
||||
)
|
||||
return hybrid
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Hybrid OCR engine: primary engine with optional cloud fallback."""
|
||||
"""Hybrid OCR engine: primary with fallback and monthly usage cap."""
|
||||
|
||||
import calendar
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
from app.config import settings
|
||||
from app.engines.base_engine import (
|
||||
EngineError,
|
||||
EngineProcessingError,
|
||||
@@ -16,15 +21,42 @@ logger = logging.getLogger(__name__)
|
||||
# Maximum time (seconds) to wait for the cloud fallback
|
||||
_CLOUD_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
# Redis key prefix for monthly Vision API request counter
|
||||
_VISION_COUNTER_PREFIX = "ocr:vision_requests"
|
||||
|
||||
|
||||
def _vision_counter_key() -> str:
    """Build the Redis key that holds this month's Vision request count.

    The key is the counter prefix followed by the current UTC year and
    month (``YYYY-MM``), so each calendar month gets its own counter.
    """
    today_utc = datetime.datetime.now(tz=datetime.timezone.utc)
    month_stamp = today_utc.strftime("%Y-%m")
    return ":".join((_VISION_COUNTER_PREFIX, month_stamp))
|
||||
|
||||
|
||||
def _seconds_until_month_end() -> int:
|
||||
"""Seconds from now until midnight UTC on the 1st of next month."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
_, days_in_month = calendar.monthrange(now.year, now.month)
|
||||
first_of_next = now.replace(
|
||||
day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
) + datetime.timedelta(days=days_in_month)
|
||||
return max(int((first_of_next - now).total_seconds()), 1)
|
||||
|
||||
|
||||
class HybridEngine(OcrEngine):
|
||||
"""Runs a primary engine and falls back to a cloud engine when
|
||||
the primary result confidence is below the configured threshold.
|
||||
"""Runs a primary engine with an optional fallback engine and a
|
||||
configurable monthly usage cap on cloud API requests.
|
||||
|
||||
If the fallback is ``None`` (default), this engine behaves identically
|
||||
to the primary engine. Cloud failures are handled gracefully -- the
|
||||
primary result is returned whenever the fallback is unavailable,
|
||||
times out, or errors.
|
||||
**When the primary engine is a cloud engine** (e.g. ``google_vision``),
|
||||
the monthly cap is checked *before* calling the primary. Once the
|
||||
limit is reached the fallback becomes the sole engine for the rest
|
||||
of the calendar month.
|
||||
|
||||
**When the primary engine is local** (e.g. ``paddleocr``), the
|
||||
original confidence-based fallback logic applies: if confidence is
|
||||
below the threshold, the cloud fallback is tried (subject to the
|
||||
same monthly cap).
|
||||
|
||||
Cloud failures are handled gracefully -- the local result is always
|
||||
returned when the cloud engine is unavailable, times out, or errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -32,21 +64,143 @@ class HybridEngine(OcrEngine):
|
||||
primary: OcrEngine,
|
||||
fallback: OcrEngine | None = None,
|
||||
threshold: float = 0.6,
|
||||
monthly_limit: int = 1000,
|
||||
) -> None:
|
||||
self._primary = primary
|
||||
self._fallback = fallback
|
||||
self._threshold = threshold
|
||||
self._monthly_limit = monthly_limit
|
||||
self._redis: redis.Redis | None = None
|
||||
|
||||
@property
def name(self) -> str:
    """Engine label of the form ``hybrid(<primary>+<fallback>)``.

    The fallback part reads ``none`` when no fallback engine is set.
    """
    if self._fallback:
        secondary = self._fallback.name
    else:
        secondary = "none"
    return f"hybrid({self._primary.name}+{secondary})"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Redis helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_redis(self) -> redis.Redis:
    """Return the synchronous Redis client, creating it on first use.

    Host, port and database index are read from the application
    settings; the client is cached on the instance for later calls.
    """
    if self._redis is None:
        # First call: build the client and cache it.
        self._redis = redis.Redis(
            host=settings.redis_host,
            port=settings.redis_port,
            db=settings.redis_db,
            decode_responses=True,
        )
    return self._redis
|
||||
|
||||
def _vision_limit_reached(self) -> bool:
    """Tell whether this month's Vision request count has hit the cap.

    Returns ``False`` when the counter cannot be read, so a Redis
    failure does not block cloud OCR; the failure is logged as a warning.
    """
    try:
        stored = self._get_redis().get(_vision_counter_key())
        # A missing (or empty) counter value is treated as zero requests.
        used = int(stored or 0)
        capped = used >= self._monthly_limit
        if capped:
            logger.info(
                "Vision monthly limit reached (%d/%d)",
                used,
                self._monthly_limit,
            )
        return capped
    except Exception as exc:
        logger.warning(
            "Redis counter check failed, assuming limit NOT reached: %s",
            exc,
        )
        return False
|
||||
|
||||
def _increment_vision_counter(self) -> None:
    """Bump this month's Vision request counter and refresh its expiry.

    The increment and the expiry are queued on one Redis pipeline and
    sent together. Any failure is logged as a warning and swallowed so
    that a counting problem never fails the OCR request itself.
    """
    try:
        counter_key = _vision_counter_key()
        ttl_seconds = _seconds_until_month_end()
        batch = self._get_redis().pipeline()
        batch.incr(counter_key)
        # Expire the key at the next month boundary.
        batch.expire(counter_key, ttl_seconds)
        batch.execute()
    except Exception as exc:
        logger.warning("Failed to increment Vision counter: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Engine selection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_cloud_engine(self, engine: OcrEngine) -> bool:
|
||||
"""Return True if this engine calls a cloud API."""
|
||||
return engine.name == "google_vision"
|
||||
|
||||
def _run_cloud_with_cap(
|
||||
self, cloud: OcrEngine, image_bytes: bytes, config: OcrConfig
|
||||
) -> OcrEngineResult | None:
|
||||
"""Run a cloud engine if the monthly cap allows, else return None."""
|
||||
if self._vision_limit_reached():
|
||||
return None
|
||||
|
||||
try:
|
||||
start = time.monotonic()
|
||||
result = cloud.recognize(image_bytes, config)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
if elapsed > _CLOUD_TIMEOUT_SECONDS:
|
||||
logger.warning(
|
||||
"Cloud engine took %.1fs (> %.1fs limit), discarding result",
|
||||
elapsed,
|
||||
_CLOUD_TIMEOUT_SECONDS,
|
||||
)
|
||||
return None
|
||||
|
||||
self._increment_vision_counter()
|
||||
return result
|
||||
except EngineError as exc:
|
||||
logger.warning("Cloud engine failed: %s", exc)
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.warning("Unexpected cloud engine error: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main recognize
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def recognize(self, image_bytes: bytes, config: OcrConfig) -> OcrEngineResult:
|
||||
"""Run primary OCR, optionally falling back to cloud engine."""
|
||||
"""Run OCR with monthly-capped cloud usage.
|
||||
|
||||
When primary is cloud: check cap -> run cloud or fall back.
|
||||
When primary is local: run local -> if low confidence, try cloud
|
||||
fallback (also subject to cap).
|
||||
"""
|
||||
# --- Cloud-primary path ---
|
||||
if self._is_cloud_engine(self._primary):
|
||||
cloud_result = self._run_cloud_with_cap(
|
||||
self._primary, image_bytes, config
|
||||
)
|
||||
if cloud_result is not None:
|
||||
logger.debug(
|
||||
"Cloud primary returned confidence %.2f",
|
||||
cloud_result.confidence,
|
||||
)
|
||||
return cloud_result
|
||||
|
||||
# Limit reached or cloud failed -- use fallback
|
||||
if self._fallback is not None:
|
||||
logger.info(
|
||||
"Cloud primary unavailable/capped, using fallback (%s)",
|
||||
self._fallback.name,
|
||||
)
|
||||
return self._fallback.recognize(image_bytes, config)
|
||||
|
||||
raise EngineProcessingError(
|
||||
"Cloud primary unavailable and no fallback configured"
|
||||
)
|
||||
|
||||
# --- Local-primary path (original confidence-based fallback) ---
|
||||
primary_result = self._primary.recognize(image_bytes, config)
|
||||
|
||||
# Happy path: primary confidence meets threshold
|
||||
if primary_result.confidence >= self._threshold:
|
||||
logger.debug(
|
||||
"Primary engine confidence %.2f >= threshold %.2f, no fallback",
|
||||
@@ -55,7 +209,6 @@ class HybridEngine(OcrEngine):
|
||||
)
|
||||
return primary_result
|
||||
|
||||
# No fallback configured -- return primary result as-is
|
||||
if self._fallback is None:
|
||||
logger.debug(
|
||||
"Primary confidence %.2f < threshold %.2f but no fallback configured",
|
||||
@@ -64,14 +217,39 @@ class HybridEngine(OcrEngine):
|
||||
)
|
||||
return primary_result
|
||||
|
||||
# Attempt cloud fallback with timeout guard
|
||||
# Only try cloud fallback if it is the fallback engine
|
||||
if self._is_cloud_engine(self._fallback):
|
||||
logger.info(
|
||||
"Primary confidence %.2f < threshold %.2f, trying cloud fallback (%s)",
|
||||
primary_result.confidence,
|
||||
self._threshold,
|
||||
self._fallback.name,
|
||||
)
|
||||
fallback_result = self._run_cloud_with_cap(
|
||||
self._fallback, image_bytes, config
|
||||
)
|
||||
if fallback_result is not None:
|
||||
if fallback_result.confidence > primary_result.confidence:
|
||||
logger.info(
|
||||
"Fallback confidence %.2f > primary %.2f, using fallback",
|
||||
fallback_result.confidence,
|
||||
primary_result.confidence,
|
||||
)
|
||||
return fallback_result
|
||||
logger.info(
|
||||
"Primary confidence %.2f >= fallback %.2f, keeping primary",
|
||||
primary_result.confidence,
|
||||
fallback_result.confidence,
|
||||
)
|
||||
return primary_result
|
||||
|
||||
# Non-cloud fallback (no cap needed)
|
||||
logger.info(
|
||||
"Primary confidence %.2f < threshold %.2f, trying fallback (%s)",
|
||||
primary_result.confidence,
|
||||
self._threshold,
|
||||
self._fallback.name,
|
||||
)
|
||||
|
||||
try:
|
||||
start = time.monotonic()
|
||||
fallback_result = self._fallback.recognize(image_bytes, config)
|
||||
@@ -79,23 +257,22 @@ class HybridEngine(OcrEngine):
|
||||
|
||||
if elapsed > _CLOUD_TIMEOUT_SECONDS:
|
||||
logger.warning(
|
||||
"Cloud fallback took %.1fs (> %.1fs limit), using primary result",
|
||||
"Fallback took %.1fs (> %.1fs limit), using primary result",
|
||||
elapsed,
|
||||
_CLOUD_TIMEOUT_SECONDS,
|
||||
)
|
||||
return primary_result
|
||||
|
||||
# Return whichever result has higher confidence
|
||||
if fallback_result.confidence > primary_result.confidence:
|
||||
logger.info(
|
||||
"Fallback confidence %.2f > primary %.2f, using fallback result",
|
||||
"Fallback confidence %.2f > primary %.2f, using fallback",
|
||||
fallback_result.confidence,
|
||||
primary_result.confidence,
|
||||
)
|
||||
return fallback_result
|
||||
|
||||
logger.info(
|
||||
"Primary confidence %.2f >= fallback %.2f, keeping primary result",
|
||||
"Primary confidence %.2f >= fallback %.2f, keeping primary",
|
||||
primary_result.confidence,
|
||||
fallback_result.confidence,
|
||||
)
|
||||
@@ -103,14 +280,13 @@ class HybridEngine(OcrEngine):
|
||||
|
||||
except EngineError as exc:
|
||||
logger.warning(
|
||||
"Cloud fallback failed (%s), returning primary result: %s",
|
||||
"Fallback failed (%s), returning primary: %s",
|
||||
self._fallback.name,
|
||||
exc,
|
||||
)
|
||||
return primary_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Unexpected cloud fallback error, returning primary result: %s",
|
||||
exc,
|
||||
"Unexpected fallback error, returning primary: %s", exc
|
||||
)
|
||||
return primary_result
|
||||
|
||||
Reference in New Issue
Block a user