feat: Migrate Gemini SDK to google-genai (#231) #236
@@ -14,6 +14,7 @@ import time
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.engines.gemini_engine import GeminiUnavailableError
|
||||||
from app.extractors.receipt_extractor import (
|
from app.extractors.receipt_extractor import (
|
||||||
ExtractedField,
|
ExtractedField,
|
||||||
ReceiptExtractionResult,
|
ReceiptExtractionResult,
|
||||||
@@ -54,16 +55,16 @@ OCR Text:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_RECEIPT_RESPONSE_SCHEMA: dict[str, Any] = {
|
_RECEIPT_RESPONSE_SCHEMA: dict[str, Any] = {
|
||||||
"type": "object",
|
"type": "OBJECT",
|
||||||
"properties": {
|
"properties": {
|
||||||
"serviceName": {"type": "string", "nullable": True},
|
"serviceName": {"type": "STRING", "nullable": True},
|
||||||
"serviceDate": {"type": "string", "nullable": True},
|
"serviceDate": {"type": "STRING", "nullable": True},
|
||||||
"totalCost": {"type": "number", "nullable": True},
|
"totalCost": {"type": "NUMBER", "nullable": True},
|
||||||
"shopName": {"type": "string", "nullable": True},
|
"shopName": {"type": "STRING", "nullable": True},
|
||||||
"laborCost": {"type": "number", "nullable": True},
|
"laborCost": {"type": "NUMBER", "nullable": True},
|
||||||
"partsCost": {"type": "number", "nullable": True},
|
"partsCost": {"type": "NUMBER", "nullable": True},
|
||||||
"odometerReading": {"type": "number", "nullable": True},
|
"odometerReading": {"type": "NUMBER", "nullable": True},
|
||||||
"vehicleInfo": {"type": "string", "nullable": True},
|
"vehicleInfo": {"type": "STRING", "nullable": True},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"serviceName",
|
"serviceName",
|
||||||
@@ -87,8 +88,8 @@ class MaintenanceReceiptExtractor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._model: Any | None = None
|
self._client: Any | None = None
|
||||||
self._generation_config: Any | None = None
|
self._model_name: str = ""
|
||||||
|
|
||||||
def extract(
|
def extract(
|
||||||
self,
|
self,
|
||||||
@@ -169,47 +170,52 @@ class MaintenanceReceiptExtractor:
|
|||||||
processing_time_ms=processing_time_ms,
|
processing_time_ms=processing_time_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model(self) -> Any:
|
def _get_client(self) -> Any:
|
||||||
"""Lazy-initialize Vertex AI Gemini model.
|
"""Lazy-initialize google-genai Gemini client.
|
||||||
|
|
||||||
Uses the same authentication pattern as GeminiEngine.
|
Uses the same authentication pattern as GeminiEngine.
|
||||||
"""
|
"""
|
||||||
if self._model is not None:
|
if self._client is not None:
|
||||||
return self._model
|
return self._client
|
||||||
|
|
||||||
key_path = settings.google_vision_key_path
|
key_path = settings.google_vision_key_path
|
||||||
if not os.path.isfile(key_path):
|
if not os.path.isfile(key_path):
|
||||||
raise RuntimeError(
|
raise GeminiUnavailableError(
|
||||||
f"Google credential config not found at {key_path}. "
|
f"Google credential config not found at {key_path}. "
|
||||||
"Set GOOGLE_VISION_KEY_PATH or mount the secret."
|
"Set GOOGLE_VISION_KEY_PATH or mount the secret."
|
||||||
)
|
)
|
||||||
|
|
||||||
from google.cloud import aiplatform # type: ignore[import-untyped]
|
try:
|
||||||
from vertexai.generative_models import ( # type: ignore[import-untyped]
|
from google import genai # type: ignore[import-untyped]
|
||||||
GenerationConfig,
|
|
||||||
GenerativeModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path
|
# Point ADC at the WIF credential config (must be set BEFORE Client construction)
|
||||||
os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path
|
||||||
|
os.environ["GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES"] = "1"
|
||||||
|
|
||||||
aiplatform.init(
|
self._client = genai.Client(
|
||||||
project=settings.vertex_ai_project,
|
vertexai=True,
|
||||||
location=settings.vertex_ai_location,
|
project=settings.vertex_ai_project,
|
||||||
)
|
location=settings.vertex_ai_location,
|
||||||
|
)
|
||||||
|
self._model_name = settings.gemini_model
|
||||||
|
|
||||||
model_name = settings.gemini_model
|
logger.info(
|
||||||
self._model = GenerativeModel(model_name)
|
"Maintenance receipt Gemini client initialized (model=%s)",
|
||||||
self._generation_config = GenerationConfig(
|
self._model_name,
|
||||||
response_mime_type="application/json",
|
)
|
||||||
response_schema=_RECEIPT_RESPONSE_SCHEMA,
|
return self._client
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
except ImportError as exc:
|
||||||
"Maintenance receipt Gemini model initialized (model=%s)",
|
logger.exception("google-genai SDK import failed")
|
||||||
model_name,
|
raise GeminiUnavailableError(
|
||||||
)
|
"google-genai is not installed. "
|
||||||
return self._model
|
"Install with: pip install google-genai"
|
||||||
|
) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Gemini authentication failed: %s", type(exc).__name__)
|
||||||
|
raise GeminiUnavailableError(
|
||||||
|
f"Gemini authentication failed: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
def _extract_with_gemini(self, ocr_text: str) -> dict:
|
def _extract_with_gemini(self, ocr_text: str) -> dict:
|
||||||
"""Send OCR text to Gemini for semantic field extraction.
|
"""Send OCR text to Gemini for semantic field extraction.
|
||||||
@@ -220,13 +226,19 @@ class MaintenanceReceiptExtractor:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of field_name -> extracted_value from Gemini.
|
Dictionary of field_name -> extracted_value from Gemini.
|
||||||
"""
|
"""
|
||||||
model = self._get_model()
|
client = self._get_client()
|
||||||
|
|
||||||
|
from google.genai import types # type: ignore[import-untyped]
|
||||||
|
|
||||||
prompt = _RECEIPT_EXTRACTION_PROMPT.format(ocr_text=ocr_text)
|
prompt = _RECEIPT_EXTRACTION_PROMPT.format(ocr_text=ocr_text)
|
||||||
|
|
||||||
response = model.generate_content(
|
response = client.models.generate_content(
|
||||||
[prompt],
|
model=self._model_name,
|
||||||
generation_config=self._generation_config,
|
contents=[prompt],
|
||||||
|
config=types.GenerateContentConfig(
|
||||||
|
response_mime_type="application/json",
|
||||||
|
response_schema=_RECEIPT_RESPONSE_SCHEMA,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
raw = json.loads(response.text)
|
raw = json.loads(response.text)
|
||||||
|
|||||||
Reference in New Issue
Block a user