diff --git a/backend/main.py b/backend/main.py index c7c2c7a..3c5ea96 100644 --- a/backend/main.py +++ b/backend/main.py @@ -2179,3 +2179,545 @@ async def strategy_plaza_trades( strategy_id, limit ) return {"trades": [dict(r) for r in rows]} + + +# ───────────────────────────────────────────────────────────────────────────── +# V5.4 Strategy Factory API +# ───────────────────────────────────────────────────────────────────────────── +import uuid as _uuid +from typing import Optional +from pydantic import BaseModel, Field, field_validator, model_validator + +# ── Pydantic Models ────────────────────────────────────────────────────────── + +class StrategyCreateRequest(BaseModel): + display_name: str = Field(..., min_length=1, max_length=50) + symbol: str + direction: str = "both" + initial_balance: float = 10000.0 + cvd_fast_window: str = "30m" + cvd_slow_window: str = "4h" + weight_direction: int = 55 + weight_env: int = 25 + weight_aux: int = 15 + weight_momentum: int = 5 + entry_score: int = 75 + gate_obi_enabled: bool = True + obi_threshold: float = 0.3 + gate_whale_enabled: bool = True + whale_cvd_threshold: float = 0.0 + gate_vol_enabled: bool = True + atr_percentile_min: int = 20 + gate_spot_perp_enabled: bool = False + spot_perp_threshold: float = 0.002 + sl_atr_multiplier: float = 1.5 + tp1_ratio: float = 0.75 + tp2_ratio: float = 1.5 + timeout_minutes: int = 240 + flip_threshold: int = 80 + description: Optional[str] = None + + @field_validator("symbol") + @classmethod + def validate_symbol(cls, v): + allowed = {"BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT"} + if v not in allowed: + raise ValueError(f"symbol must be one of {allowed}") + return v + + @field_validator("direction") + @classmethod + def validate_direction(cls, v): + if v not in {"long_only", "short_only", "both"}: + raise ValueError("direction must be long_only, short_only, or both") + return v + + @field_validator("cvd_fast_window") + @classmethod + def validate_cvd_fast(cls, v): + if v not in {"5m", "15m", "30m"}: + raise ValueError("cvd_fast_window must be 5m, 15m, or 30m") + return v + + @field_validator("cvd_slow_window") + @classmethod + def validate_cvd_slow(cls, v): + if v not in {"30m", "1h", "4h"}: + raise ValueError("cvd_slow_window must be 30m, 1h, or 4h") + return v + + @field_validator("weight_direction") + @classmethod + def validate_w_dir(cls, v): + if not 10 <= v <= 80: + raise ValueError("weight_direction must be 10-80") + return v + + @field_validator("weight_env") + @classmethod + def validate_w_env(cls, v): + if not 5 <= v <= 60: + raise ValueError("weight_env must be 5-60") + return v + + @field_validator("weight_aux") + @classmethod + def validate_w_aux(cls, v): + if not 0 <= v <= 40: + raise ValueError("weight_aux must be 0-40") + return v + + @field_validator("weight_momentum") + @classmethod + def validate_w_mom(cls, v): + if not 0 <= v <= 20: + raise ValueError("weight_momentum must be 0-20") + return v + + @model_validator(mode="after") + def validate_weights_sum(self): + total = self.weight_direction + self.weight_env + self.weight_aux + self.weight_momentum + if total != 100: + raise ValueError(f"Weights must sum to 100, got {total}") + return self + + @field_validator("entry_score") + @classmethod + def validate_entry_score(cls, v): + if not 60 <= v <= 95: + raise ValueError("entry_score must be 60-95") + return v + + @field_validator("obi_threshold") + @classmethod + def validate_obi(cls, v): + if not 0.1 <= v <= 0.9: + raise ValueError("obi_threshold must be 0.1-0.9") + return v + + @field_validator("whale_cvd_threshold") + @classmethod + def validate_whale(cls, v): + if not -1.0 <= v <= 1.0: + raise ValueError("whale_cvd_threshold must be -1.0 to 1.0") + return v + + @field_validator("atr_percentile_min") + @classmethod + def validate_atr_pct(cls, v): + if not 5 <= v <= 80: + raise ValueError("atr_percentile_min must be 5-80") + return v + + @field_validator("spot_perp_threshold") + @classmethod + def validate_spot_perp(cls, v): + if not 0.0005 <= v <= 0.01: + raise ValueError("spot_perp_threshold must be 0.0005-0.01") + return v + + @field_validator("sl_atr_multiplier") + @classmethod + def validate_sl(cls, v): + if not 0.5 <= v <= 3.0: + raise ValueError("sl_atr_multiplier must be 0.5-3.0") + return v + + @field_validator("tp1_ratio") + @classmethod + def validate_tp1(cls, v): + if not 0.3 <= v <= 2.0: + raise ValueError("tp1_ratio must be 0.3-2.0") + return v + + @field_validator("tp2_ratio") + @classmethod + def validate_tp2(cls, v): + if not 0.5 <= v <= 4.0: + raise ValueError("tp2_ratio must be 0.5-4.0") + return v + + @field_validator("timeout_minutes") + @classmethod + def validate_timeout(cls, v): + if not 30 <= v <= 1440: + raise ValueError("timeout_minutes must be 30-1440") + return v + + @field_validator("flip_threshold") + @classmethod + def validate_flip(cls, v): + if not 60 <= v <= 95: + raise ValueError("flip_threshold must be 60-95") + return v + + @field_validator("initial_balance") + @classmethod + def validate_balance(cls, v): + if v < 1000: + raise ValueError("initial_balance must be >= 1000") + return v + + +class StrategyUpdateRequest(BaseModel): + """Partial update - all fields optional""" + display_name: Optional[str] = Field(None, min_length=1, max_length=50) + direction: Optional[str] = None + cvd_fast_window: Optional[str] = None + cvd_slow_window: Optional[str] = None + weight_direction: Optional[int] = None + weight_env: Optional[int] = None + weight_aux: Optional[int] = None + weight_momentum: Optional[int] = None + entry_score: Optional[int] = None + gate_obi_enabled: Optional[bool] = None + obi_threshold: Optional[float] = None + gate_whale_enabled: Optional[bool] = None + whale_cvd_threshold: Optional[float] = None + gate_vol_enabled: Optional[bool] = None + atr_percentile_min: Optional[int] = None + gate_spot_perp_enabled: Optional[bool] = None + spot_perp_threshold: Optional[float] = None + sl_atr_multiplier: Optional[float] = None + tp1_ratio: Optional[float] = None + tp2_ratio: Optional[float] = None + timeout_minutes: Optional[int] = None + flip_threshold: Optional[int] = None + description: Optional[str] = None + + +class AddBalanceRequest(BaseModel): + amount: float = Field(..., gt=0) + + +class DeprecateRequest(BaseModel): + confirm: bool + + +# ── Helper ────────────────────────────────────────────────────────────────── + +async def _get_strategy_or_404(strategy_id: str) -> dict: + row = await async_fetchrow( + "SELECT * FROM strategies WHERE strategy_id=$1", + strategy_id + ) + if not row: + raise HTTPException(status_code=404, detail="Strategy not found") + return dict(row) + + +def _strategy_row_to_card(row: dict) -> dict: + """Convert a strategies row to a card-level response (no config params)""" + return { + "strategy_id": str(row["strategy_id"]), + "display_name": row["display_name"], + "status": row["status"], + "symbol": row["symbol"], + "direction": row["direction"], + "started_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0, + "initial_balance": row["initial_balance"], + "current_balance": row["current_balance"], + "net_usdt": round(row["current_balance"] - row["initial_balance"], 2), + "deprecated_at": int(row["deprecated_at"].timestamp() * 1000) if row.get("deprecated_at") else None, + "last_run_at": int(row["last_run_at"].timestamp() * 1000) if row.get("last_run_at") else None, + "schema_version": row["schema_version"], + } + + +def _strategy_row_to_detail(row: dict) -> dict: + """Full detail including all config params""" + base = _strategy_row_to_card(row) + base.update({ + "cvd_fast_window": row["cvd_fast_window"], + "cvd_slow_window": row["cvd_slow_window"], + "weight_direction": row["weight_direction"], + "weight_env": row["weight_env"], + "weight_aux": row["weight_aux"], + "weight_momentum": row["weight_momentum"], + "entry_score": row["entry_score"], + "gate_obi_enabled": row["gate_obi_enabled"], + "obi_threshold": row["obi_threshold"], + "gate_whale_enabled": row["gate_whale_enabled"], + "whale_cvd_threshold": row["whale_cvd_threshold"], + "gate_vol_enabled": row["gate_vol_enabled"], + "atr_percentile_min": row["atr_percentile_min"], + "gate_spot_perp_enabled": row["gate_spot_perp_enabled"], + "spot_perp_threshold": row["spot_perp_threshold"], + "sl_atr_multiplier": row["sl_atr_multiplier"], + "tp1_ratio": row["tp1_ratio"], + "tp2_ratio": row["tp2_ratio"], + "timeout_minutes": row["timeout_minutes"], + "flip_threshold": row["flip_threshold"], + "description": row.get("description"), + "created_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0, + "updated_at": int(row["updated_at"].timestamp() * 1000) if row.get("updated_at") else 0, + }) + return base + + +async def _get_strategy_trade_stats(strategy_id: str) -> dict: + """Fetch trade statistics for a strategy by strategy_id""" + rows = await async_fetch( + """SELECT status, pnl_r, tp1_hit, entry_ts, exit_ts + FROM paper_trades + WHERE strategy_id=$1 AND status != 'active' + ORDER BY entry_ts DESC""", + strategy_id + ) + if not rows: + return { + "trade_count": 0, "win_rate": 0.0, + "avg_win_r": 0.0, "avg_loss_r": 0.0, + "open_positions": 0, + "pnl_usdt_24h": 0.0, "pnl_r_24h": 0.0, + "last_trade_at": None, + } + + total = len(rows) + wins = [r for r in rows if (r["pnl_r"] or 0) > 0] + losses = [r for r in rows if (r["pnl_r"] or 0) < 0] + win_rate = round(len(wins) / total * 100, 1) if total else 0.0 + avg_win = round(sum(r["pnl_r"] for r in wins) / len(wins), 3) if wins else 0.0 + avg_loss = round(sum(r["pnl_r"] for r in losses) / len(losses), 3) if losses else 0.0 + last_trade_at = rows[0]["exit_ts"] if rows else None + + # 24h stats + cutoff_ms = int((datetime.utcnow() - timedelta(hours=24)).timestamp() * 1000) + rows_24h = [r for r in rows if (r["exit_ts"] or 0) >= cutoff_ms] + pnl_r_24h = round(sum(r["pnl_r"] or 0 for r in rows_24h), 3) + pnl_usdt_24h = round(pnl_r_24h * 200, 2) + + # Open positions + open_rows = await async_fetch( + "SELECT COUNT(*) as cnt FROM paper_trades WHERE strategy_id=$1 AND status='active'", + strategy_id + ) + open_positions = open_rows[0]["cnt"] if open_rows else 0 + + return { + "trade_count": total, + "win_rate": win_rate, + "avg_win_r": avg_win, + "avg_loss_r": avg_loss, + "open_positions": open_positions, + "pnl_usdt_24h": pnl_usdt_24h, + "pnl_r_24h": pnl_r_24h, + "last_trade_at": last_trade_at, + } + + +# ── Endpoints ──────────────────────────────────────────────────────────────── + +@app.post("/api/strategies") +async def create_strategy(body: StrategyCreateRequest, user: dict = Depends(get_current_user)): + """创建新策略实例""" + new_id = str(_uuid.uuid4()) + await async_execute( + """INSERT INTO strategies ( + strategy_id, display_name, schema_version, status, + symbol, direction, + cvd_fast_window, cvd_slow_window, + weight_direction, weight_env, weight_aux, weight_momentum, + entry_score, + gate_obi_enabled, obi_threshold, + gate_whale_enabled, whale_cvd_threshold, + gate_vol_enabled, atr_percentile_min, + gate_spot_perp_enabled, spot_perp_threshold, + sl_atr_multiplier, tp1_ratio, tp2_ratio, + timeout_minutes, flip_threshold, + initial_balance, current_balance, + description + ) VALUES ( + $1,$2,1,'running', + $3,$4,$5,$6, + $7,$8,$9,$10, + $11, + $12,$13,$14,$15,$16,$17,$18,$19, + $20,$21,$22,$23,$24, + $25,$25,$26 + )""", + new_id, body.display_name, + body.symbol, body.direction, body.cvd_fast_window, body.cvd_slow_window, + body.weight_direction, body.weight_env, body.weight_aux, body.weight_momentum, + body.entry_score, + body.gate_obi_enabled, body.obi_threshold, + body.gate_whale_enabled, body.whale_cvd_threshold, + body.gate_vol_enabled, body.atr_percentile_min, + body.gate_spot_perp_enabled, body.spot_perp_threshold, + body.sl_atr_multiplier, body.tp1_ratio, body.tp2_ratio, + body.timeout_minutes, body.flip_threshold, + body.initial_balance, body.description + ) + row = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", new_id) + return {"ok": True, "strategy": _strategy_row_to_detail(dict(row))} + + +@app.get("/api/strategies") +async def list_strategies( + include_deprecated: bool = False, + user: dict = Depends(get_current_user) +): + """获取策略列表""" + if include_deprecated: + rows = await async_fetch("SELECT * FROM strategies ORDER BY created_at ASC") + else: + rows = await async_fetch( + "SELECT * FROM strategies WHERE status != 'deprecated' ORDER BY created_at ASC" + ) + result = [] + for row in rows: + d = _strategy_row_to_card(dict(row)) + stats = await _get_strategy_trade_stats(str(row["strategy_id"])) + d.update(stats) + result.append(d) + return {"strategies": result} + + +@app.get("/api/strategies/{sid}") +async def get_strategy(sid: str, user: dict = Depends(get_current_user)): + """获取单个策略详情(含完整参数配置)""" + row = await _get_strategy_or_404(sid) + detail = _strategy_row_to_detail(row) + stats = await _get_strategy_trade_stats(sid) + detail.update(stats) + return {"strategy": detail} + + +@app.patch("/api/strategies/{sid}") +async def update_strategy(sid: str, body: StrategyUpdateRequest, user: dict = Depends(get_current_user)): + """更新策略参数(Partial Update)""" + row = await _get_strategy_or_404(sid) + if row["status"] == "deprecated": + raise HTTPException(status_code=403, detail="Cannot modify a deprecated strategy") + + # Build SET clause dynamically from non-None fields + updates = body.model_dump(exclude_none=True) + if not updates: + raise HTTPException(status_code=400, detail="No fields to update") + + # Validate weights sum if any weight is being changed + weight_fields = {"weight_direction", "weight_env", "weight_aux", "weight_momentum"} + if weight_fields & set(updates.keys()): + w_dir = updates.get("weight_direction", row["weight_direction"]) + w_env = updates.get("weight_env", row["weight_env"]) + w_aux = updates.get("weight_aux", row["weight_aux"]) + w_mom = updates.get("weight_momentum", row["weight_momentum"]) + if w_dir + w_env + w_aux + w_mom != 100: + raise HTTPException(status_code=400, detail=f"Weights must sum to 100, got {w_dir+w_env+w_aux+w_mom}") + + # Validate individual field ranges + validators = { + "direction": lambda v: v in {"long_only", "short_only", "both"}, + "cvd_fast_window": lambda v: v in {"5m", "15m", "30m"}, + "cvd_slow_window": lambda v: v in {"30m", "1h", "4h"}, + "weight_direction": lambda v: 10 <= v <= 80, + "weight_env": lambda v: 5 <= v <= 60, + "weight_aux": lambda v: 0 <= v <= 40, + "weight_momentum": lambda v: 0 <= v <= 20, + "entry_score": lambda v: 60 <= v <= 95, + "obi_threshold": lambda v: 0.1 <= v <= 0.9, + "whale_cvd_threshold": lambda v: -1.0 <= v <= 1.0, + "atr_percentile_min": lambda v: 5 <= v <= 80, + "spot_perp_threshold": lambda v: 0.0005 <= v <= 0.01, + "sl_atr_multiplier": lambda v: 0.5 <= v <= 3.0, + "tp1_ratio": lambda v: 0.3 <= v <= 2.0, + "tp2_ratio": lambda v: 0.5 <= v <= 4.0, + "timeout_minutes": lambda v: 30 <= v <= 1440, + "flip_threshold": lambda v: 60 <= v <= 95, + } + for field, val in updates.items(): + if field in validators and not validators[field](val): + raise HTTPException(status_code=400, detail=f"Invalid value for {field}: {val}") + + # Execute update + set_parts = [f"{k}=${i+2}" for i, k in enumerate(updates.keys())] + set_parts.append(f"updated_at=NOW()") + sql = f"UPDATE strategies SET {', '.join(set_parts)} WHERE strategy_id=$1" + await async_execute(sql, sid, *updates.values()) + + updated = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", sid) + return {"ok": True, "strategy": _strategy_row_to_detail(dict(updated))} + + +@app.post("/api/strategies/{sid}/pause") +async def pause_strategy(sid: str, user: dict = Depends(get_current_user)): + """暂停策略(停止开新仓,不影响现有持仓)""" + row = await _get_strategy_or_404(sid) + if row["status"] == "deprecated": + raise HTTPException(status_code=403, detail="Cannot pause a deprecated strategy") + if row["status"] == "paused": + return {"ok": True, "message": "Already paused"} + await async_execute( + "UPDATE strategies SET status='paused', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1", + sid + ) + return {"ok": True, "status": "paused"} + + +@app.post("/api/strategies/{sid}/resume") +async def resume_strategy(sid: str, user: dict = Depends(get_current_user)): + """恢复策略""" + row = await _get_strategy_or_404(sid) + if row["status"] == "running": + return {"ok": True, "message": "Already running"} + await async_execute( + "UPDATE strategies SET status='running', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1", + sid + ) + return {"ok": True, "status": "running"} + + +@app.post("/api/strategies/{sid}/deprecate") +async def deprecate_strategy(sid: str, body: DeprecateRequest, user: dict = Depends(get_current_user)): + """废弃策略(数据永久保留,可重新启用)""" + if not body.confirm: + raise HTTPException(status_code=400, detail="Must set confirm=true to deprecate") + row = await _get_strategy_or_404(sid) + if row["status"] == "deprecated": + return {"ok": True, "message": "Already deprecated"} + await async_execute( + """UPDATE strategies + SET status='deprecated', deprecated_at=NOW(), + status_changed_at=NOW(), updated_at=NOW() + WHERE strategy_id=$1""", + sid + ) + return {"ok": True, "status": "deprecated"} + + +@app.post("/api/strategies/{sid}/restore") +async def restore_strategy(sid: str, user: dict = Depends(get_current_user)): + """重新启用废弃策略(继续原有余额和历史数据)""" + row = await _get_strategy_or_404(sid) + if row["status"] != "deprecated": + raise HTTPException(status_code=400, detail="Strategy is not deprecated") + await async_execute( + """UPDATE strategies + SET status='running', deprecated_at=NULL, + status_changed_at=NOW(), updated_at=NOW() + WHERE strategy_id=$1""", + sid + ) + return {"ok": True, "status": "running"} + + +@app.post("/api/strategies/{sid}/add-balance") +async def add_balance(sid: str, body: AddBalanceRequest, user: dict = Depends(get_current_user)): + """追加余额(initial_balance 和 current_balance 同步增加)""" + row = await _get_strategy_or_404(sid) + if row["status"] == "deprecated": + raise HTTPException(status_code=403, detail="Cannot add balance to a deprecated strategy") + new_initial = round(row["initial_balance"] + body.amount, 2) + new_current = round(row["current_balance"] + body.amount, 2) + await async_execute( + """UPDATE strategies + SET initial_balance=$2, current_balance=$3, updated_at=NOW() + WHERE strategy_id=$1""", + sid, new_initial, new_current + ) + return { + "ok": True, + "initial_balance": new_initial, + "current_balance": new_current, + "added": body.amount, + } diff --git a/backend/migrate_v54.py b/backend/migrate_v54.py new file mode 100644 index 0000000..8e8abac --- /dev/null +++ b/backend/migrate_v54.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +V5.4 Strategy Factory DB Migration Script +- Creates `strategies` table +- Adds strategy_id + strategy_name_snapshot to paper_trades, signal_indicators +- Inserts existing 3 strategies with fixed UUIDs +- Backfills strategy_id + strategy_name_snapshot for all existing records +""" + +import os +import sys +import psycopg2 +from psycopg2.extras import execute_values + +PG_HOST = os.environ.get("PG_HOST", "10.106.0.3") +PG_PASS = os.environ.get("PG_PASS", "arb_engine_2026") +PG_USER = "arb" +PG_DB = "arb_engine" + +# Fixed UUIDs for existing strategies (deterministic, easy to recognize) +LEGACY_STRATEGY_MAP = { + "v53": ("00000000-0000-0000-0000-000000000053", "V5.3 Standard"), + "v53_middle": ("00000000-0000-0000-0000-000000000054", "V5.3 Middle"), + "v53_fast": ("00000000-0000-0000-0000-000000000055", "V5.3 Fast"), +} + +# Default config values per strategy (from strategy JSON files) +LEGACY_CONFIGS = { + "v53": { + "symbol": "BTCUSDT", # multi-symbol, use BTC as representative + "cvd_fast_window": "30m", + "cvd_slow_window": "4h", + "weight_direction": 55, + "weight_env": 25, + "weight_aux": 15, + "weight_momentum": 5, + "entry_score": 75, + "sl_atr_multiplier": 1.0, + "tp1_ratio": 0.75, + "tp2_ratio": 1.5, + "timeout_minutes": 60, + "flip_threshold": 75, + "status": "running", + "initial_balance": 10000.0, + }, + "v53_middle": { + "symbol": "BTCUSDT", + "cvd_fast_window": "15m", + "cvd_slow_window": "1h", + "weight_direction": 55, + "weight_env": 25, + "weight_aux": 15, + "weight_momentum": 5, + "entry_score": 75, + "sl_atr_multiplier": 1.0, + "tp1_ratio": 0.75, + "tp2_ratio": 1.5, + "timeout_minutes": 60, + "flip_threshold": 75, + "status": "running", + "initial_balance": 10000.0, + }, + "v53_fast": { + "symbol": "BTCUSDT", + "cvd_fast_window": "5m", + "cvd_slow_window": "30m", + "weight_direction": 55, + "weight_env": 25, + "weight_aux": 15, + "weight_momentum": 5, + "entry_score": 75, + "sl_atr_multiplier": 1.0, + "tp1_ratio": 0.75, + "tp2_ratio": 1.5, + "timeout_minutes": 60, + "flip_threshold": 75, + "status": "running", + "initial_balance": 10000.0, + }, +} + + +def get_conn(): + return psycopg2.connect( + host=PG_HOST, user=PG_USER, password=PG_PASS, dbname=PG_DB + ) + + +def step1_create_strategies_table(cur): + print("[Step 1] Creating strategies table...") + cur.execute(""" + CREATE TABLE IF NOT EXISTS strategies ( + strategy_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + display_name TEXT NOT NULL, + schema_version INT NOT NULL DEFAULT 1, + status TEXT NOT NULL DEFAULT 'running' + CHECK (status IN ('running', 'paused', 'deprecated')), + status_changed_at TIMESTAMP, + last_run_at TIMESTAMP, + deprecated_at TIMESTAMP, + symbol TEXT NOT NULL + CHECK (symbol IN ('BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'XRPUSDT')), + direction TEXT NOT NULL DEFAULT 'both' + CHECK (direction IN ('long_only', 'short_only', 'both')), + cvd_fast_window TEXT NOT NULL DEFAULT '30m' + CHECK (cvd_fast_window IN ('5m', '15m', '30m')), + cvd_slow_window TEXT NOT NULL DEFAULT '4h' + CHECK (cvd_slow_window IN ('30m', '1h', '4h')), + weight_direction INT NOT NULL DEFAULT 55, + weight_env INT NOT NULL DEFAULT 25, + weight_aux INT NOT NULL DEFAULT 15, + weight_momentum INT NOT NULL DEFAULT 5, + entry_score INT NOT NULL DEFAULT 75, + gate_obi_enabled BOOL NOT NULL DEFAULT TRUE, + obi_threshold FLOAT NOT NULL DEFAULT 0.3, + gate_whale_enabled BOOL NOT NULL DEFAULT TRUE, + whale_cvd_threshold FLOAT NOT NULL DEFAULT 0.0, + gate_vol_enabled BOOL NOT NULL DEFAULT TRUE, + atr_percentile_min INT NOT NULL DEFAULT 20, + gate_spot_perp_enabled BOOL NOT NULL DEFAULT FALSE, + spot_perp_threshold FLOAT NOT NULL DEFAULT 0.002, + sl_atr_multiplier FLOAT NOT NULL DEFAULT 1.5, + tp1_ratio FLOAT NOT NULL DEFAULT 0.75, + tp2_ratio FLOAT NOT NULL DEFAULT 1.5, + timeout_minutes INT NOT NULL DEFAULT 240, + flip_threshold INT NOT NULL DEFAULT 80, + initial_balance FLOAT NOT NULL DEFAULT 10000.0, + current_balance FLOAT NOT NULL DEFAULT 10000.0, + description TEXT, + tags TEXT[], + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW() + ) + """) + cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_status ON strategies(status)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_symbol ON strategies(symbol)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_last_run ON strategies(last_run_at)") + print("[Step 1] Done.") + + +def step2_add_columns(cur): + print("[Step 2] Adding strategy_id + strategy_name_snapshot columns...") + # paper_trades + for col, col_type in [ + ("strategy_id", "UUID REFERENCES strategies(strategy_id)"), + ("strategy_name_snapshot", "TEXT"), + ]: + cur.execute(f""" + ALTER TABLE paper_trades + ADD COLUMN IF NOT EXISTS {col} {col_type} + """) + # signal_indicators + for col, col_type in [ + ("strategy_id", "UUID REFERENCES strategies(strategy_id)"), + ("strategy_name_snapshot", "TEXT"), + ]: + cur.execute(f""" + ALTER TABLE signal_indicators + ADD COLUMN IF NOT EXISTS {col} {col_type} + """) + # Indexes + cur.execute("CREATE INDEX IF NOT EXISTS idx_paper_trades_strategy_id ON paper_trades(strategy_id)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_si_strategy_id ON signal_indicators(strategy_id)") + print("[Step 2] Done.") + + +def step3_insert_legacy_strategies(cur): + print("[Step 3] Inserting legacy strategies into strategies table...") + for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items(): + cfg = LEGACY_CONFIGS[strategy_name] + # Compute current_balance from actual paper trades + cur.execute(""" + SELECT + COALESCE(SUM(pnl_r) * 200, 0) as total_pnl_usdt + FROM paper_trades + WHERE strategy = %s AND status != 'active' + """, (strategy_name,)) + row = cur.fetchone() + pnl_usdt = row[0] if row else 0 + current_balance = round(cfg["initial_balance"] + pnl_usdt, 2) + + cur.execute(""" + INSERT INTO strategies ( + strategy_id, display_name, schema_version, status, + symbol, direction, + cvd_fast_window, cvd_slow_window, + weight_direction, weight_env, weight_aux, weight_momentum, + entry_score, + gate_obi_enabled, obi_threshold, + gate_whale_enabled, whale_cvd_threshold, + gate_vol_enabled, atr_percentile_min, + gate_spot_perp_enabled, spot_perp_threshold, + sl_atr_multiplier, tp1_ratio, tp2_ratio, + timeout_minutes, flip_threshold, + initial_balance, current_balance, + description + ) VALUES ( + %s, %s, 1, %s, + %s, 'both', + %s, %s, + %s, %s, %s, %s, + %s, + TRUE, 0.3, + TRUE, 0.0, + TRUE, 20, + FALSE, 0.002, + %s, %s, %s, + %s, %s, + %s, %s, + %s + ) + ON CONFLICT (strategy_id) DO NOTHING + """, ( + uuid, display_name, cfg["status"], + cfg["symbol"], cfg["cvd_fast_window"], cfg["cvd_slow_window"], + cfg["weight_direction"], cfg["weight_env"], cfg["weight_aux"], cfg["weight_momentum"], + cfg["entry_score"], + cfg["sl_atr_multiplier"], cfg["tp1_ratio"], cfg["tp2_ratio"], + cfg["timeout_minutes"], cfg["flip_threshold"], + cfg["initial_balance"], current_balance, + f"Migrated from V5.3 legacy strategy: {strategy_name}" + )) + print(f" Inserted {strategy_name} → {uuid} (balance: {current_balance})") + print("[Step 3] Done.") + + +def step4_backfill(cur): + print("[Step 4] Backfilling strategy_id + strategy_name_snapshot...") + for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items(): + # paper_trades + cur.execute(""" + UPDATE paper_trades + SET strategy_id = %s::uuid, + strategy_name_snapshot = %s + WHERE strategy = %s AND strategy_id IS NULL + """, (uuid, display_name, strategy_name)) + count = cur.rowcount + print(f" paper_trades [{strategy_name}]: {count} rows updated") + + # signal_indicators + cur.execute(""" + UPDATE signal_indicators + SET strategy_id = %s::uuid, + strategy_name_snapshot = %s + WHERE strategy = %s AND strategy_id IS NULL + """, (uuid, display_name, strategy_name)) + count = cur.rowcount + print(f" signal_indicators [{strategy_name}]: {count} rows updated") + + print("[Step 4] Done.") + + +def step5_verify(cur): + print("[Step 5] Verifying migration completeness...") + # Check strategies table + cur.execute("SELECT COUNT(*) FROM strategies") + n = cur.fetchone()[0] + print(f" strategies table: {n} rows") + + # Check NULL strategy_id in paper_trades (for known strategies) + cur.execute(""" + SELECT strategy, COUNT(*) as cnt + FROM paper_trades + WHERE strategy IN ('v53', 'v53_middle', 'v53_fast') + AND strategy_id IS NULL + GROUP BY strategy + """) + rows = cur.fetchall() + if rows: + print(f" WARNING: NULL strategy_id found in paper_trades:") + for r in rows: + print(f" {r[0]}: {r[1]} rows") + else: + print(" paper_trades: all known strategies backfilled ✅") + + # Check NULL in signal_indicators + cur.execute(""" + SELECT strategy, COUNT(*) as cnt + FROM signal_indicators + WHERE strategy IN ('v53', 'v53_middle', 'v53_fast') + AND strategy_id IS NULL + GROUP BY strategy + """) + rows = cur.fetchall() + if rows: + print(f" WARNING: NULL strategy_id found in signal_indicators:") + for r in rows: + print(f" {r[0]}: {r[1]} rows") + else: + print(" signal_indicators: all known strategies backfilled ✅") + + print("[Step 5] Done.") + + +def main(): + dry_run = "--dry-run" in sys.argv + if dry_run: + print("=== DRY RUN MODE (no changes will be committed) ===") + + conn = get_conn() + conn.autocommit = False + cur = conn.cursor() + + try: + step1_create_strategies_table(cur) + step2_add_columns(cur) + step3_insert_legacy_strategies(cur) + step4_backfill(cur) + step5_verify(cur) + + if dry_run: + conn.rollback() + print("\n=== DRY RUN: rolled back all changes ===") + else: + conn.commit() + print("\n=== Migration completed successfully ✅ ===") + except Exception as e: + conn.rollback() + print(f"\n=== ERROR: {e} ===") + raise + finally: + cur.close() + conn.close() + + +if __name__ == "__main__": + main()