mirror of
https://github.com/HKUDS/AI-Trader
synced 2026-04-21 13:37:41 +00:00
Migrate backend to PostgreSQL and harden compatibility (#175)
This commit is contained in:
parent
5f29e69ffc
commit
3b3169b756
3 changed files with 48 additions and 8 deletions
|
|
@ -42,6 +42,7 @@ _ALTER_ADD_COLUMN_PATTERN = re.compile(
|
|||
r"\bALTER\s+TABLE\s+([A-Za-z_][A-Za-z0-9_]*)\s+ADD\s+COLUMN\s+(?!IF\s+NOT\s+EXISTS)",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_POSTGRES_RETRYABLE_SQLSTATES = {"40001", "40P01", "55P03"}
|
||||
|
||||
|
||||
def using_postgres() -> bool:
|
||||
|
|
@ -52,6 +53,40 @@ def get_database_backend_name() -> str:
|
|||
return "postgresql" if using_postgres() else "sqlite"
|
||||
|
||||
|
||||
def begin_write_transaction(cursor: Any) -> None:
|
||||
"""Start a write transaction using syntax compatible with the active backend."""
|
||||
if using_postgres():
|
||||
cursor.execute("BEGIN")
|
||||
return
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
|
||||
|
||||
def is_retryable_db_error(exc: Exception) -> bool:
|
||||
"""Return True when the error is a transient write conflict worth retrying."""
|
||||
if isinstance(exc, sqlite3.OperationalError):
|
||||
message = str(exc).lower()
|
||||
return "database is locked" in message or "database is busy" in message
|
||||
|
||||
sqlstate = getattr(exc, "sqlstate", None)
|
||||
if not sqlstate:
|
||||
cause = getattr(exc, "__cause__", None)
|
||||
sqlstate = getattr(cause, "sqlstate", None)
|
||||
if sqlstate in _POSTGRES_RETRYABLE_SQLSTATES:
|
||||
return True
|
||||
|
||||
message = str(exc).lower()
|
||||
return any(
|
||||
fragment in message
|
||||
for fragment in (
|
||||
"could not serialize access",
|
||||
"deadlock detected",
|
||||
"lock not available",
|
||||
"database is locked",
|
||||
"database is busy",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _replace_unquoted_question_marks(sql: str) -> str:
|
||||
"""Translate sqlite-style placeholders to psycopg placeholders."""
|
||||
result: list[str] = []
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ def _enforce_content_rate_limit(agent_id: int, action: str, content: str, target
|
|||
}
|
||||
|
||||
from config import CORS_ORIGINS, SIGNAL_PUBLISH_REWARD, SIGNAL_ADOPT_REWARD, DISCUSSION_PUBLISH_REWARD, REPLY_PUBLISH_REWARD
|
||||
from database import get_db_connection
|
||||
from database import begin_write_transaction, get_db_connection
|
||||
from market_intel import (
|
||||
get_market_intel_overview,
|
||||
get_market_news_payload,
|
||||
|
|
@ -1113,7 +1113,7 @@ def create_app() -> FastAPI:
|
|||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
begin_write_transaction(cursor)
|
||||
signal_id = _reserve_signal_id(cursor)
|
||||
|
||||
if action_lower in ("sell", "cover"):
|
||||
|
|
@ -1209,7 +1209,7 @@ def create_app() -> FastAPI:
|
|||
# Get all followers of this agent
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
begin_write_transaction(cursor)
|
||||
cursor.execute("""
|
||||
SELECT follower_id FROM subscriptions
|
||||
WHERE leader_id = ? AND status = 'active'
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ Services Module
|
|||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List
|
||||
from database import get_db_connection
|
||||
from database import get_db_connection, is_retryable_db_error
|
||||
|
||||
|
||||
# ==================== Agent Services ====================
|
||||
|
|
@ -66,7 +67,7 @@ def _add_agent_points(agent_id: int, points: int, reason: str = "reward") -> boo
|
|||
if points <= 0:
|
||||
return False
|
||||
|
||||
# Retry logic for database locking
|
||||
# Retry transient write conflicts on both SQLite and PostgreSQL.
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
conn = get_db_connection()
|
||||
|
|
@ -78,9 +79,13 @@ def _add_agent_points(agent_id: int, points: int, reason: str = "reward") -> boo
|
|||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
if "database is locked" in str(e) and attempt < max_retries - 1:
|
||||
import time
|
||||
time.sleep(0.5 * (attempt + 1)) # Exponential backoff
|
||||
try:
|
||||
conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_retryable_db_error(e) and attempt < max_retries - 1:
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
continue
|
||||
print(f"[ERROR] Failed to add points to agent {agent_id}: {e}")
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue