feat: add Vision monthly cap, WIF auth, and cloud-primary hybrid engine (refs #127)

- Add VISION_MONTHLY_LIMIT config setting (default 1000)
- Update CloudEngine to use WIF credential config via ADC
- Rewrite HybridEngine to support cloud-primary with Redis counter
- Pass monthly_limit through engine factory

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Eric Gullickson
2026-02-09 20:50:02 -06:00
parent 4412700e12
commit 4abd7d8d5b
4 changed files with 225 additions and 29 deletions

View File

@@ -21,7 +21,12 @@ class Settings:
os.getenv("OCR_FALLBACK_THRESHOLD", "0.6") os.getenv("OCR_FALLBACK_THRESHOLD", "0.6")
) )
self.google_vision_key_path: str = os.getenv( self.google_vision_key_path: str = os.getenv(
"GOOGLE_VISION_KEY_PATH", "/run/secrets/google-vision-key.json" "GOOGLE_VISION_KEY_PATH", "/run/secrets/google-wif-config.json"
)
# Google Vision monthly usage cap (requests per calendar month)
self.vision_monthly_limit: int = int(
os.getenv("VISION_MONTHLY_LIMIT", "1000")
) )
# Redis configuration for job queue # Redis configuration for job queue

View File

@@ -15,8 +15,8 @@ from app.engines.base_engine import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Default path for Google Vision service account key (Docker secret mount) # Default path for Google WIF credential config (Docker secret mount)
_DEFAULT_KEY_PATH = "/run/secrets/google-vision-key.json" _DEFAULT_KEY_PATH = "/run/secrets/google-wif-config.json"
class CloudEngine(OcrEngine): class CloudEngine(OcrEngine):
@@ -42,25 +42,33 @@ class CloudEngine(OcrEngine):
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _get_client(self) -> Any: def _get_client(self) -> Any:
"""Create the Vision client on first use.""" """Create the Vision client on first use.
Uses Application Default Credentials (ADC) pointed at a WIF
credential config file. The WIF config references an executable
that fetches an Auth0 M2M JWT.
"""
if self._client is not None: if self._client is not None:
return self._client return self._client
# Verify credentials file exists # Verify credentials config exists
if not os.path.isfile(self._key_path): if not os.path.isfile(self._key_path):
raise EngineUnavailableError( raise EngineUnavailableError(
f"Google Vision key not found at {self._key_path}. " f"Google Vision credential config not found at {self._key_path}. "
"Set GOOGLE_VISION_KEY_PATH or mount the secret." "Set GOOGLE_VISION_KEY_PATH or mount the secret."
) )
try: try:
from google.cloud import vision # type: ignore[import-untyped] from google.cloud import vision # type: ignore[import-untyped]
# Point the SDK at the service account key # Point ADC at the WIF credential config
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._key_path os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._key_path
# Required for executable-sourced credentials
os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"
self._client = vision.ImageAnnotatorClient() self._client = vision.ImageAnnotatorClient()
logger.info( logger.info(
"Google Vision client initialized (key: %s)", self._key_path "Google Vision client initialized via WIF (config: %s)",
self._key_path,
) )
return self._client return self._client
except ImportError as exc: except ImportError as exc:

View File

@@ -76,11 +76,18 @@ def create_engine(engine_name: str | None = None) -> OcrEngine:
from app.engines.hybrid_engine import HybridEngine from app.engines.hybrid_engine import HybridEngine
threshold = settings.ocr_fallback_threshold threshold = settings.ocr_fallback_threshold
hybrid = HybridEngine(primary=primary, fallback=fallback, threshold=threshold) monthly_limit = settings.vision_monthly_limit
hybrid = HybridEngine(
primary=primary,
fallback=fallback,
threshold=threshold,
monthly_limit=monthly_limit,
)
logger.info( logger.info(
"Created hybrid engine: primary=%s, fallback=%s, threshold=%.2f", "Created hybrid engine: primary=%s, fallback=%s, threshold=%.2f, vision_limit=%d",
name, name,
fallback_name, fallback_name,
threshold, threshold,
monthly_limit,
) )
return hybrid return hybrid

View File

@@ -1,8 +1,13 @@
"""Hybrid OCR engine: primary engine with optional cloud fallback.""" """Hybrid OCR engine: primary with fallback and monthly usage cap."""
import calendar
import datetime
import logging import logging
import time import time
import redis
from app.config import settings
from app.engines.base_engine import ( from app.engines.base_engine import (
EngineError, EngineError,
EngineProcessingError, EngineProcessingError,
@@ -16,15 +21,42 @@ logger = logging.getLogger(__name__)
# Maximum time (seconds) to wait for the cloud fallback # Maximum time (seconds) to wait for the cloud fallback
_CLOUD_TIMEOUT_SECONDS = 5.0 _CLOUD_TIMEOUT_SECONDS = 5.0
# Redis key prefix for monthly Vision API request counter
_VISION_COUNTER_PREFIX = "ocr:vision_requests"
def _vision_counter_key() -> str:
    """Build the Redis key that holds this month's Vision request count.

    The key is the counter prefix followed by ``:`` and the current UTC
    year and month formatted as ``YYYY-MM``, so every calendar month
    maps to its own key.
    """
    month_tag = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m")
    return f"{_VISION_COUNTER_PREFIX}:{month_tag}"
def _seconds_until_month_end() -> int:
"""Seconds from now until midnight UTC on the 1st of next month."""
now = datetime.datetime.now(datetime.timezone.utc)
_, days_in_month = calendar.monthrange(now.year, now.month)
first_of_next = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0
) + datetime.timedelta(days=days_in_month)
return max(int((first_of_next - now).total_seconds()), 1)
class HybridEngine(OcrEngine): class HybridEngine(OcrEngine):
"""Runs a primary engine and falls back to a cloud engine when """Runs a primary engine with an optional fallback engine and a
the primary result confidence is below the configured threshold. configurable monthly usage cap on cloud API requests.
If the fallback is ``None`` (default), this engine behaves identically **When the primary engine is a cloud engine** (e.g. ``google_vision``),
to the primary engine. Cloud failures are handled gracefully -- the the monthly cap is checked *before* calling the primary. Once the
primary result is returned whenever the fallback is unavailable, limit is reached the fallback becomes the sole engine for the rest
times out, or errors. of the calendar month.
**When the primary engine is local** (e.g. ``paddleocr``), the
original confidence-based fallback logic applies: if confidence is
below the threshold, the cloud fallback is tried (subject to the
same monthly cap).
Cloud failures are handled gracefully -- the local result is always
returned when the cloud engine is unavailable, times out, or errors.
""" """
def __init__( def __init__(
@@ -32,21 +64,143 @@ class HybridEngine(OcrEngine):
primary: OcrEngine, primary: OcrEngine,
fallback: OcrEngine | None = None, fallback: OcrEngine | None = None,
threshold: float = 0.6, threshold: float = 0.6,
monthly_limit: int = 1000,
) -> None: ) -> None:
self._primary = primary self._primary = primary
self._fallback = fallback self._fallback = fallback
self._threshold = threshold self._threshold = threshold
self._monthly_limit = monthly_limit
self._redis: redis.Redis | None = None
@property
def name(self) -> str:
    """Composite engine label of the form ``hybrid(<primary>+<fallback>)``.

    The fallback part reads ``none`` when no fallback engine is set.
    """
    if self._fallback:
        secondary = self._fallback.name
    else:
        secondary = "none"
    return f"hybrid({self._primary.name}+{secondary})"
# ------------------------------------------------------------------
# Redis helpers
# ------------------------------------------------------------------
def _get_redis(self) -> redis.Redis:
    """Return the cached synchronous Redis client, creating it on first use.

    The client is built from the ``redis_host``, ``redis_port`` and
    ``redis_db`` settings with ``decode_responses=True`` and stored on
    the instance so later calls reuse it.
    """
    if self._redis is None:
        self._redis = redis.Redis(
            host=settings.redis_host,
            port=settings.redis_port,
            db=settings.redis_db,
            decode_responses=True,
        )
    return self._redis
def _vision_limit_reached(self) -> bool:
    """Report whether this month's Vision request count has hit the cap.

    Reads the current month's counter from Redis; a missing or empty
    value counts as zero. Any error while checking is logged and
    treated as "limit not reached", so a Redis outage does not block
    cloud calls.
    """
    try:
        raw = self._get_redis().get(_vision_counter_key())
        used = int(raw) if raw else 0
        capped = used >= self._monthly_limit
        if capped:
            logger.info(
                "Vision monthly limit reached (%d/%d)",
                used,
                self._monthly_limit,
            )
        return capped
    except Exception as exc:
        logger.warning(
            "Redis counter check failed, assuming limit NOT reached: %s",
            exc,
        )
        return False
def _increment_vision_counter(self) -> None:
    """Bump this month's Vision counter and refresh its expiry.

    The increment and the expiry update are queued on one Redis
    pipeline and sent together. The expiry is set to the number of
    seconds left in the current UTC month. Failures are logged and
    swallowed so counting problems never break OCR.
    """
    try:
        counter_key = _vision_counter_key()
        ttl_seconds = _seconds_until_month_end()
        batch = self._get_redis().pipeline()
        batch.incr(counter_key)
        batch.expire(counter_key, ttl_seconds)
        batch.execute()
    except Exception as exc:
        logger.warning("Failed to increment Vision counter: %s", exc)
# ------------------------------------------------------------------
# Engine selection helpers
# ------------------------------------------------------------------
def _is_cloud_engine(self, engine: OcrEngine) -> bool:
"""Return True if this engine calls a cloud API."""
return engine.name == "google_vision"
def _run_cloud_with_cap(
    self, cloud: OcrEngine, image_bytes: bytes, config: OcrConfig
) -> OcrEngineResult | None:
    """Run a cloud engine if the monthly cap allows, else return None.

    Args:
        cloud: Engine to invoke via its ``recognize`` method.
        image_bytes: Raw image data passed through to the engine.
        config: OCR configuration passed through to the engine.

    Returns:
        The engine result, or ``None`` when the monthly limit is
        reached, the call raised, or the call took longer than
        ``_CLOUD_TIMEOUT_SECONDS`` (the late result is discarded).

    Every call that completes is counted against the monthly cap,
    including ones whose result is discarded for being too slow: the
    request was still issued, so skipping the increment would let real
    usage exceed the configured limit. Calls that raise are not counted
    because it is unknown whether a request was actually made.
    """
    if self._vision_limit_reached():
        return None

    try:
        start = time.monotonic()
        result = cloud.recognize(image_bytes, config)
        elapsed = time.monotonic() - start

        # Count before the latency check so slow-but-completed requests
        # still consume quota.
        self._increment_vision_counter()

        if elapsed > _CLOUD_TIMEOUT_SECONDS:
            logger.warning(
                "Cloud engine took %.1fs (> %.1fs limit), discarding result",
                elapsed,
                _CLOUD_TIMEOUT_SECONDS,
            )
            return None

        return result
    except EngineError as exc:
        logger.warning("Cloud engine failed: %s", exc)
        return None
    except Exception as exc:
        logger.warning("Unexpected cloud engine error: %s", exc)
        return None
# ------------------------------------------------------------------
# Main recognize
# ------------------------------------------------------------------
def recognize(self, image_bytes: bytes, config: OcrConfig) -> OcrEngineResult: def recognize(self, image_bytes: bytes, config: OcrConfig) -> OcrEngineResult:
"""Run primary OCR, optionally falling back to cloud engine.""" """Run OCR with monthly-capped cloud usage.
When primary is cloud: check cap -> run cloud or fall back.
When primary is local: run local -> if low confidence, try cloud
fallback (also subject to cap).
"""
# --- Cloud-primary path ---
if self._is_cloud_engine(self._primary):
cloud_result = self._run_cloud_with_cap(
self._primary, image_bytes, config
)
if cloud_result is not None:
logger.debug(
"Cloud primary returned confidence %.2f",
cloud_result.confidence,
)
return cloud_result
# Limit reached or cloud failed -- use fallback
if self._fallback is not None:
logger.info(
"Cloud primary unavailable/capped, using fallback (%s)",
self._fallback.name,
)
return self._fallback.recognize(image_bytes, config)
raise EngineProcessingError(
"Cloud primary unavailable and no fallback configured"
)
# --- Local-primary path (original confidence-based fallback) ---
primary_result = self._primary.recognize(image_bytes, config) primary_result = self._primary.recognize(image_bytes, config)
# Happy path: primary confidence meets threshold
if primary_result.confidence >= self._threshold: if primary_result.confidence >= self._threshold:
logger.debug( logger.debug(
"Primary engine confidence %.2f >= threshold %.2f, no fallback", "Primary engine confidence %.2f >= threshold %.2f, no fallback",
@@ -55,7 +209,6 @@ class HybridEngine(OcrEngine):
) )
return primary_result return primary_result
# No fallback configured -- return primary result as-is
if self._fallback is None: if self._fallback is None:
logger.debug( logger.debug(
"Primary confidence %.2f < threshold %.2f but no fallback configured", "Primary confidence %.2f < threshold %.2f but no fallback configured",
@@ -64,14 +217,39 @@ class HybridEngine(OcrEngine):
) )
return primary_result return primary_result
# Attempt cloud fallback with timeout guard # Only try cloud fallback if it is the fallback engine
if self._is_cloud_engine(self._fallback):
logger.info(
"Primary confidence %.2f < threshold %.2f, trying cloud fallback (%s)",
primary_result.confidence,
self._threshold,
self._fallback.name,
)
fallback_result = self._run_cloud_with_cap(
self._fallback, image_bytes, config
)
if fallback_result is not None:
if fallback_result.confidence > primary_result.confidence:
logger.info(
"Fallback confidence %.2f > primary %.2f, using fallback",
fallback_result.confidence,
primary_result.confidence,
)
return fallback_result
logger.info(
"Primary confidence %.2f >= fallback %.2f, keeping primary",
primary_result.confidence,
fallback_result.confidence,
)
return primary_result
# Non-cloud fallback (no cap needed)
logger.info( logger.info(
"Primary confidence %.2f < threshold %.2f, trying fallback (%s)", "Primary confidence %.2f < threshold %.2f, trying fallback (%s)",
primary_result.confidence, primary_result.confidence,
self._threshold, self._threshold,
self._fallback.name, self._fallback.name,
) )
try: try:
start = time.monotonic() start = time.monotonic()
fallback_result = self._fallback.recognize(image_bytes, config) fallback_result = self._fallback.recognize(image_bytes, config)
@@ -79,23 +257,22 @@ class HybridEngine(OcrEngine):
if elapsed > _CLOUD_TIMEOUT_SECONDS: if elapsed > _CLOUD_TIMEOUT_SECONDS:
logger.warning( logger.warning(
"Cloud fallback took %.1fs (> %.1fs limit), using primary result", "Fallback took %.1fs (> %.1fs limit), using primary result",
elapsed, elapsed,
_CLOUD_TIMEOUT_SECONDS, _CLOUD_TIMEOUT_SECONDS,
) )
return primary_result return primary_result
# Return whichever result has higher confidence
if fallback_result.confidence > primary_result.confidence: if fallback_result.confidence > primary_result.confidence:
logger.info( logger.info(
"Fallback confidence %.2f > primary %.2f, using fallback result", "Fallback confidence %.2f > primary %.2f, using fallback",
fallback_result.confidence, fallback_result.confidence,
primary_result.confidence, primary_result.confidence,
) )
return fallback_result return fallback_result
logger.info( logger.info(
"Primary confidence %.2f >= fallback %.2f, keeping primary result", "Primary confidence %.2f >= fallback %.2f, keeping primary",
primary_result.confidence, primary_result.confidence,
fallback_result.confidence, fallback_result.confidence,
) )
@@ -103,14 +280,13 @@ class HybridEngine(OcrEngine):
except EngineError as exc: except EngineError as exc:
logger.warning( logger.warning(
"Cloud fallback failed (%s), returning primary result: %s", "Fallback failed (%s), returning primary: %s",
self._fallback.name, self._fallback.name,
exc, exc,
) )
return primary_result return primary_result
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Unexpected cloud fallback error, returning primary result: %s", "Unexpected fallback error, returning primary: %s", exc
exc,
) )
return primary_result return primary_result