feat(v54): add strategies table, migration script, 9 CRUD API endpoints

This commit is contained in:
root 2026-03-11 15:11:44 +00:00
parent b1ed55382c
commit 7be7b5b4c0
2 changed files with 869 additions and 0 deletions

View File

@ -2179,3 +2179,545 @@ async def strategy_plaza_trades(
strategy_id, limit strategy_id, limit
) )
return {"trades": [dict(r) for r in rows]} return {"trades": [dict(r) for r in rows]}
# ─────────────────────────────────────────────────────────────────────────────
# V5.4 Strategy Factory API
# ─────────────────────────────────────────────────────────────────────────────
import uuid as _uuid
from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
# ── Pydantic Models ──────────────────────────────────────────────────────────
class StrategyCreateRequest(BaseModel):
display_name: str = Field(..., min_length=1, max_length=50)
symbol: str
direction: str = "both"
initial_balance: float = 10000.0
cvd_fast_window: str = "30m"
cvd_slow_window: str = "4h"
weight_direction: int = 55
weight_env: int = 25
weight_aux: int = 15
weight_momentum: int = 5
entry_score: int = 75
gate_obi_enabled: bool = True
obi_threshold: float = 0.3
gate_whale_enabled: bool = True
whale_cvd_threshold: float = 0.0
gate_vol_enabled: bool = True
atr_percentile_min: int = 20
gate_spot_perp_enabled: bool = False
spot_perp_threshold: float = 0.002
sl_atr_multiplier: float = 1.5
tp1_ratio: float = 0.75
tp2_ratio: float = 1.5
timeout_minutes: int = 240
flip_threshold: int = 80
description: Optional[str] = None
@field_validator("symbol")
@classmethod
def validate_symbol(cls, v):
allowed = {"BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT"}
if v not in allowed:
raise ValueError(f"symbol must be one of {allowed}")
return v
@field_validator("direction")
@classmethod
def validate_direction(cls, v):
if v not in {"long_only", "short_only", "both"}:
raise ValueError("direction must be long_only, short_only, or both")
return v
@field_validator("cvd_fast_window")
@classmethod
def validate_cvd_fast(cls, v):
if v not in {"5m", "15m", "30m"}:
raise ValueError("cvd_fast_window must be 5m, 15m, or 30m")
return v
@field_validator("cvd_slow_window")
@classmethod
def validate_cvd_slow(cls, v):
if v not in {"30m", "1h", "4h"}:
raise ValueError("cvd_slow_window must be 30m, 1h, or 4h")
return v
@field_validator("weight_direction")
@classmethod
def validate_w_dir(cls, v):
if not 10 <= v <= 80:
raise ValueError("weight_direction must be 10-80")
return v
@field_validator("weight_env")
@classmethod
def validate_w_env(cls, v):
if not 5 <= v <= 60:
raise ValueError("weight_env must be 5-60")
return v
@field_validator("weight_aux")
@classmethod
def validate_w_aux(cls, v):
if not 0 <= v <= 40:
raise ValueError("weight_aux must be 0-40")
return v
@field_validator("weight_momentum")
@classmethod
def validate_w_mom(cls, v):
if not 0 <= v <= 20:
raise ValueError("weight_momentum must be 0-20")
return v
@model_validator(mode="after")
def validate_weights_sum(self):
total = self.weight_direction + self.weight_env + self.weight_aux + self.weight_momentum
if total != 100:
raise ValueError(f"Weights must sum to 100, got {total}")
return self
@field_validator("entry_score")
@classmethod
def validate_entry_score(cls, v):
if not 60 <= v <= 95:
raise ValueError("entry_score must be 60-95")
return v
@field_validator("obi_threshold")
@classmethod
def validate_obi(cls, v):
if not 0.1 <= v <= 0.9:
raise ValueError("obi_threshold must be 0.1-0.9")
return v
@field_validator("whale_cvd_threshold")
@classmethod
def validate_whale(cls, v):
if not -1.0 <= v <= 1.0:
raise ValueError("whale_cvd_threshold must be -1.0 to 1.0")
return v
@field_validator("atr_percentile_min")
@classmethod
def validate_atr_pct(cls, v):
if not 5 <= v <= 80:
raise ValueError("atr_percentile_min must be 5-80")
return v
@field_validator("spot_perp_threshold")
@classmethod
def validate_spot_perp(cls, v):
if not 0.0005 <= v <= 0.01:
raise ValueError("spot_perp_threshold must be 0.0005-0.01")
return v
@field_validator("sl_atr_multiplier")
@classmethod
def validate_sl(cls, v):
if not 0.5 <= v <= 3.0:
raise ValueError("sl_atr_multiplier must be 0.5-3.0")
return v
@field_validator("tp1_ratio")
@classmethod
def validate_tp1(cls, v):
if not 0.3 <= v <= 2.0:
raise ValueError("tp1_ratio must be 0.3-2.0")
return v
@field_validator("tp2_ratio")
@classmethod
def validate_tp2(cls, v):
if not 0.5 <= v <= 4.0:
raise ValueError("tp2_ratio must be 0.5-4.0")
return v
@field_validator("timeout_minutes")
@classmethod
def validate_timeout(cls, v):
if not 30 <= v <= 1440:
raise ValueError("timeout_minutes must be 30-1440")
return v
@field_validator("flip_threshold")
@classmethod
def validate_flip(cls, v):
if not 60 <= v <= 95:
raise ValueError("flip_threshold must be 60-95")
return v
@field_validator("initial_balance")
@classmethod
def validate_balance(cls, v):
if v < 1000:
raise ValueError("initial_balance must be >= 1000")
return v
class StrategyUpdateRequest(BaseModel):
"""Partial update - all fields optional"""
display_name: Optional[str] = Field(None, min_length=1, max_length=50)
direction: Optional[str] = None
cvd_fast_window: Optional[str] = None
cvd_slow_window: Optional[str] = None
weight_direction: Optional[int] = None
weight_env: Optional[int] = None
weight_aux: Optional[int] = None
weight_momentum: Optional[int] = None
entry_score: Optional[int] = None
gate_obi_enabled: Optional[bool] = None
obi_threshold: Optional[float] = None
gate_whale_enabled: Optional[bool] = None
whale_cvd_threshold: Optional[float] = None
gate_vol_enabled: Optional[bool] = None
atr_percentile_min: Optional[int] = None
gate_spot_perp_enabled: Optional[bool] = None
spot_perp_threshold: Optional[float] = None
sl_atr_multiplier: Optional[float] = None
tp1_ratio: Optional[float] = None
tp2_ratio: Optional[float] = None
timeout_minutes: Optional[int] = None
flip_threshold: Optional[int] = None
description: Optional[str] = None
class AddBalanceRequest(BaseModel):
amount: float = Field(..., gt=0)
class DeprecateRequest(BaseModel):
confirm: bool
# ── Helper ──────────────────────────────────────────────────────────────────
async def _get_strategy_or_404(strategy_id: str) -> dict:
row = await async_fetchrow(
"SELECT * FROM strategies WHERE strategy_id=$1",
strategy_id
)
if not row:
raise HTTPException(status_code=404, detail="Strategy not found")
return dict(row)
def _strategy_row_to_card(row: dict) -> dict:
"""Convert a strategies row to a card-level response (no config params)"""
return {
"strategy_id": str(row["strategy_id"]),
"display_name": row["display_name"],
"status": row["status"],
"symbol": row["symbol"],
"direction": row["direction"],
"started_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0,
"initial_balance": row["initial_balance"],
"current_balance": row["current_balance"],
"net_usdt": round(row["current_balance"] - row["initial_balance"], 2),
"deprecated_at": int(row["deprecated_at"].timestamp() * 1000) if row.get("deprecated_at") else None,
"last_run_at": int(row["last_run_at"].timestamp() * 1000) if row.get("last_run_at") else None,
"schema_version": row["schema_version"],
}
def _strategy_row_to_detail(row: dict) -> dict:
"""Full detail including all config params"""
base = _strategy_row_to_card(row)
base.update({
"cvd_fast_window": row["cvd_fast_window"],
"cvd_slow_window": row["cvd_slow_window"],
"weight_direction": row["weight_direction"],
"weight_env": row["weight_env"],
"weight_aux": row["weight_aux"],
"weight_momentum": row["weight_momentum"],
"entry_score": row["entry_score"],
"gate_obi_enabled": row["gate_obi_enabled"],
"obi_threshold": row["obi_threshold"],
"gate_whale_enabled": row["gate_whale_enabled"],
"whale_cvd_threshold": row["whale_cvd_threshold"],
"gate_vol_enabled": row["gate_vol_enabled"],
"atr_percentile_min": row["atr_percentile_min"],
"gate_spot_perp_enabled": row["gate_spot_perp_enabled"],
"spot_perp_threshold": row["spot_perp_threshold"],
"sl_atr_multiplier": row["sl_atr_multiplier"],
"tp1_ratio": row["tp1_ratio"],
"tp2_ratio": row["tp2_ratio"],
"timeout_minutes": row["timeout_minutes"],
"flip_threshold": row["flip_threshold"],
"description": row.get("description"),
"created_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0,
"updated_at": int(row["updated_at"].timestamp() * 1000) if row.get("updated_at") else 0,
})
return base
async def _get_strategy_trade_stats(strategy_id: str) -> dict:
"""Fetch trade statistics for a strategy by strategy_id"""
rows = await async_fetch(
"""SELECT status, pnl_r, tp1_hit, entry_ts, exit_ts
FROM paper_trades
WHERE strategy_id=$1 AND status != 'active'
ORDER BY entry_ts DESC""",
strategy_id
)
if not rows:
return {
"trade_count": 0, "win_rate": 0.0,
"avg_win_r": 0.0, "avg_loss_r": 0.0,
"open_positions": 0,
"pnl_usdt_24h": 0.0, "pnl_r_24h": 0.0,
"last_trade_at": None,
}
total = len(rows)
wins = [r for r in rows if (r["pnl_r"] or 0) > 0]
losses = [r for r in rows if (r["pnl_r"] or 0) < 0]
win_rate = round(len(wins) / total * 100, 1) if total else 0.0
avg_win = round(sum(r["pnl_r"] for r in wins) / len(wins), 3) if wins else 0.0
avg_loss = round(sum(r["pnl_r"] for r in losses) / len(losses), 3) if losses else 0.0
last_trade_at = rows[0]["exit_ts"] if rows else None
# 24h stats
cutoff_ms = int((datetime.utcnow() - timedelta(hours=24)).timestamp() * 1000)
rows_24h = [r for r in rows if (r["exit_ts"] or 0) >= cutoff_ms]
pnl_r_24h = round(sum(r["pnl_r"] or 0 for r in rows_24h), 3)
pnl_usdt_24h = round(pnl_r_24h * 200, 2)
# Open positions
open_rows = await async_fetch(
"SELECT COUNT(*) as cnt FROM paper_trades WHERE strategy_id=$1 AND status='active'",
strategy_id
)
open_positions = open_rows[0]["cnt"] if open_rows else 0
return {
"trade_count": total,
"win_rate": win_rate,
"avg_win_r": avg_win,
"avg_loss_r": avg_loss,
"open_positions": open_positions,
"pnl_usdt_24h": pnl_usdt_24h,
"pnl_r_24h": pnl_r_24h,
"last_trade_at": last_trade_at,
}
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.post("/api/strategies")
async def create_strategy(body: StrategyCreateRequest, user: dict = Depends(get_current_user)):
"""创建新策略实例"""
new_id = str(_uuid.uuid4())
await async_execute(
"""INSERT INTO strategies (
strategy_id, display_name, schema_version, status,
symbol, direction,
cvd_fast_window, cvd_slow_window,
weight_direction, weight_env, weight_aux, weight_momentum,
entry_score,
gate_obi_enabled, obi_threshold,
gate_whale_enabled, whale_cvd_threshold,
gate_vol_enabled, atr_percentile_min,
gate_spot_perp_enabled, spot_perp_threshold,
sl_atr_multiplier, tp1_ratio, tp2_ratio,
timeout_minutes, flip_threshold,
initial_balance, current_balance,
description
) VALUES (
$1,$2,1,'running',
$3,$4,$5,$6,
$7,$8,$9,$10,
$11,
$12,$13,$14,$15,$16,$17,$18,$19,
$20,$21,$22,$23,$24,
$25,$25,$26
)""",
new_id, body.display_name,
body.symbol, body.direction, body.cvd_fast_window, body.cvd_slow_window,
body.weight_direction, body.weight_env, body.weight_aux, body.weight_momentum,
body.entry_score,
body.gate_obi_enabled, body.obi_threshold,
body.gate_whale_enabled, body.whale_cvd_threshold,
body.gate_vol_enabled, body.atr_percentile_min,
body.gate_spot_perp_enabled, body.spot_perp_threshold,
body.sl_atr_multiplier, body.tp1_ratio, body.tp2_ratio,
body.timeout_minutes, body.flip_threshold,
body.initial_balance, body.description
)
row = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", new_id)
return {"ok": True, "strategy": _strategy_row_to_detail(dict(row))}
@app.get("/api/strategies")
async def list_strategies(
include_deprecated: bool = False,
user: dict = Depends(get_current_user)
):
"""获取策略列表"""
if include_deprecated:
rows = await async_fetch("SELECT * FROM strategies ORDER BY created_at ASC")
else:
rows = await async_fetch(
"SELECT * FROM strategies WHERE status != 'deprecated' ORDER BY created_at ASC"
)
result = []
for row in rows:
d = _strategy_row_to_card(dict(row))
stats = await _get_strategy_trade_stats(str(row["strategy_id"]))
d.update(stats)
result.append(d)
return {"strategies": result}
@app.get("/api/strategies/{sid}")
async def get_strategy(sid: str, user: dict = Depends(get_current_user)):
"""获取单个策略详情(含完整参数配置)"""
row = await _get_strategy_or_404(sid)
detail = _strategy_row_to_detail(row)
stats = await _get_strategy_trade_stats(sid)
detail.update(stats)
return {"strategy": detail}
@app.patch("/api/strategies/{sid}")
async def update_strategy(sid: str, body: StrategyUpdateRequest, user: dict = Depends(get_current_user)):
"""更新策略参数Partial Update"""
row = await _get_strategy_or_404(sid)
if row["status"] == "deprecated":
raise HTTPException(status_code=403, detail="Cannot modify a deprecated strategy")
# Build SET clause dynamically from non-None fields
updates = body.model_dump(exclude_none=True)
if not updates:
raise HTTPException(status_code=400, detail="No fields to update")
# Validate weights sum if any weight is being changed
weight_fields = {"weight_direction", "weight_env", "weight_aux", "weight_momentum"}
if weight_fields & set(updates.keys()):
w_dir = updates.get("weight_direction", row["weight_direction"])
w_env = updates.get("weight_env", row["weight_env"])
w_aux = updates.get("weight_aux", row["weight_aux"])
w_mom = updates.get("weight_momentum", row["weight_momentum"])
if w_dir + w_env + w_aux + w_mom != 100:
raise HTTPException(status_code=400, detail=f"Weights must sum to 100, got {w_dir+w_env+w_aux+w_mom}")
# Validate individual field ranges
validators = {
"direction": lambda v: v in {"long_only", "short_only", "both"},
"cvd_fast_window": lambda v: v in {"5m", "15m", "30m"},
"cvd_slow_window": lambda v: v in {"30m", "1h", "4h"},
"weight_direction": lambda v: 10 <= v <= 80,
"weight_env": lambda v: 5 <= v <= 60,
"weight_aux": lambda v: 0 <= v <= 40,
"weight_momentum": lambda v: 0 <= v <= 20,
"entry_score": lambda v: 60 <= v <= 95,
"obi_threshold": lambda v: 0.1 <= v <= 0.9,
"whale_cvd_threshold": lambda v: -1.0 <= v <= 1.0,
"atr_percentile_min": lambda v: 5 <= v <= 80,
"spot_perp_threshold": lambda v: 0.0005 <= v <= 0.01,
"sl_atr_multiplier": lambda v: 0.5 <= v <= 3.0,
"tp1_ratio": lambda v: 0.3 <= v <= 2.0,
"tp2_ratio": lambda v: 0.5 <= v <= 4.0,
"timeout_minutes": lambda v: 30 <= v <= 1440,
"flip_threshold": lambda v: 60 <= v <= 95,
}
for field, val in updates.items():
if field in validators and not validators[field](val):
raise HTTPException(status_code=400, detail=f"Invalid value for {field}: {val}")
# Execute update
set_parts = [f"{k}=${i+2}" for i, k in enumerate(updates.keys())]
set_parts.append(f"updated_at=NOW()")
sql = f"UPDATE strategies SET {', '.join(set_parts)} WHERE strategy_id=$1"
await async_execute(sql, sid, *updates.values())
updated = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", sid)
return {"ok": True, "strategy": _strategy_row_to_detail(dict(updated))}
@app.post("/api/strategies/{sid}/pause")
async def pause_strategy(sid: str, user: dict = Depends(get_current_user)):
"""暂停策略(停止开新仓,不影响现有持仓)"""
row = await _get_strategy_or_404(sid)
if row["status"] == "deprecated":
raise HTTPException(status_code=403, detail="Cannot pause a deprecated strategy")
if row["status"] == "paused":
return {"ok": True, "message": "Already paused"}
await async_execute(
"UPDATE strategies SET status='paused', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1",
sid
)
return {"ok": True, "status": "paused"}
@app.post("/api/strategies/{sid}/resume")
async def resume_strategy(sid: str, user: dict = Depends(get_current_user)):
"""恢复策略"""
row = await _get_strategy_or_404(sid)
if row["status"] == "running":
return {"ok": True, "message": "Already running"}
await async_execute(
"UPDATE strategies SET status='running', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1",
sid
)
return {"ok": True, "status": "running"}
@app.post("/api/strategies/{sid}/deprecate")
async def deprecate_strategy(sid: str, body: DeprecateRequest, user: dict = Depends(get_current_user)):
"""废弃策略(数据永久保留,可重新启用)"""
if not body.confirm:
raise HTTPException(status_code=400, detail="Must set confirm=true to deprecate")
row = await _get_strategy_or_404(sid)
if row["status"] == "deprecated":
return {"ok": True, "message": "Already deprecated"}
await async_execute(
"""UPDATE strategies
SET status='deprecated', deprecated_at=NOW(),
status_changed_at=NOW(), updated_at=NOW()
WHERE strategy_id=$1""",
sid
)
return {"ok": True, "status": "deprecated"}
@app.post("/api/strategies/{sid}/restore")
async def restore_strategy(sid: str, user: dict = Depends(get_current_user)):
"""重新启用废弃策略(继续原有余额和历史数据)"""
row = await _get_strategy_or_404(sid)
if row["status"] != "deprecated":
raise HTTPException(status_code=400, detail="Strategy is not deprecated")
await async_execute(
"""UPDATE strategies
SET status='running', deprecated_at=NULL,
status_changed_at=NOW(), updated_at=NOW()
WHERE strategy_id=$1""",
sid
)
return {"ok": True, "status": "running"}
@app.post("/api/strategies/{sid}/add-balance")
async def add_balance(sid: str, body: AddBalanceRequest, user: dict = Depends(get_current_user)):
"""追加余额initial_balance 和 current_balance 同步增加)"""
row = await _get_strategy_or_404(sid)
if row["status"] == "deprecated":
raise HTTPException(status_code=403, detail="Cannot add balance to a deprecated strategy")
new_initial = round(row["initial_balance"] + body.amount, 2)
new_current = round(row["current_balance"] + body.amount, 2)
await async_execute(
"""UPDATE strategies
SET initial_balance=$2, current_balance=$3, updated_at=NOW()
WHERE strategy_id=$1""",
sid, new_initial, new_current
)
return {
"ok": True,
"initial_balance": new_initial,
"current_balance": new_current,
"added": body.amount,
}

327
backend/migrate_v54.py Normal file
View File

@ -0,0 +1,327 @@
#!/usr/bin/env python3
"""
V5.4 Strategy Factory DB Migration Script
- Creates `strategies` table
- Adds strategy_id + strategy_name_snapshot to paper_trades, signal_indicators
- Inserts existing 3 strategies with fixed UUIDs
- Backfills strategy_id + strategy_name_snapshot for all existing records
"""
import os
import sys
import psycopg2
from psycopg2.extras import execute_values
PG_HOST = os.environ.get("PG_HOST", "10.106.0.3")
PG_PASS = os.environ.get("PG_PASS", "arb_engine_2026")
PG_USER = "arb"
PG_DB = "arb_engine"
# Fixed UUIDs for existing strategies (deterministic, easy to recognize)
LEGACY_STRATEGY_MAP = {
"v53": ("00000000-0000-0000-0000-000000000053", "V5.3 Standard"),
"v53_middle": ("00000000-0000-0000-0000-000000000054", "V5.3 Middle"),
"v53_fast": ("00000000-0000-0000-0000-000000000055", "V5.3 Fast"),
}
# Default config values per strategy (from strategy JSON files)
LEGACY_CONFIGS = {
"v53": {
"symbol": "BTCUSDT", # multi-symbol, use BTC as representative
"cvd_fast_window": "30m",
"cvd_slow_window": "4h",
"weight_direction": 55,
"weight_env": 25,
"weight_aux": 15,
"weight_momentum": 5,
"entry_score": 75,
"sl_atr_multiplier": 1.0,
"tp1_ratio": 0.75,
"tp2_ratio": 1.5,
"timeout_minutes": 60,
"flip_threshold": 75,
"status": "running",
"initial_balance": 10000.0,
},
"v53_middle": {
"symbol": "BTCUSDT",
"cvd_fast_window": "15m",
"cvd_slow_window": "1h",
"weight_direction": 55,
"weight_env": 25,
"weight_aux": 15,
"weight_momentum": 5,
"entry_score": 75,
"sl_atr_multiplier": 1.0,
"tp1_ratio": 0.75,
"tp2_ratio": 1.5,
"timeout_minutes": 60,
"flip_threshold": 75,
"status": "running",
"initial_balance": 10000.0,
},
"v53_fast": {
"symbol": "BTCUSDT",
"cvd_fast_window": "5m",
"cvd_slow_window": "30m",
"weight_direction": 55,
"weight_env": 25,
"weight_aux": 15,
"weight_momentum": 5,
"entry_score": 75,
"sl_atr_multiplier": 1.0,
"tp1_ratio": 0.75,
"tp2_ratio": 1.5,
"timeout_minutes": 60,
"flip_threshold": 75,
"status": "running",
"initial_balance": 10000.0,
},
}
def get_conn():
return psycopg2.connect(
host=PG_HOST, user=PG_USER, password=PG_PASS, dbname=PG_DB
)
def step1_create_strategies_table(cur):
print("[Step 1] Creating strategies table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS strategies (
strategy_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
display_name TEXT NOT NULL,
schema_version INT NOT NULL DEFAULT 1,
status TEXT NOT NULL DEFAULT 'running'
CHECK (status IN ('running', 'paused', 'deprecated')),
status_changed_at TIMESTAMP,
last_run_at TIMESTAMP,
deprecated_at TIMESTAMP,
symbol TEXT NOT NULL
CHECK (symbol IN ('BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'XRPUSDT')),
direction TEXT NOT NULL DEFAULT 'both'
CHECK (direction IN ('long_only', 'short_only', 'both')),
cvd_fast_window TEXT NOT NULL DEFAULT '30m'
CHECK (cvd_fast_window IN ('5m', '15m', '30m')),
cvd_slow_window TEXT NOT NULL DEFAULT '4h'
CHECK (cvd_slow_window IN ('30m', '1h', '4h')),
weight_direction INT NOT NULL DEFAULT 55,
weight_env INT NOT NULL DEFAULT 25,
weight_aux INT NOT NULL DEFAULT 15,
weight_momentum INT NOT NULL DEFAULT 5,
entry_score INT NOT NULL DEFAULT 75,
gate_obi_enabled BOOL NOT NULL DEFAULT TRUE,
obi_threshold FLOAT NOT NULL DEFAULT 0.3,
gate_whale_enabled BOOL NOT NULL DEFAULT TRUE,
whale_cvd_threshold FLOAT NOT NULL DEFAULT 0.0,
gate_vol_enabled BOOL NOT NULL DEFAULT TRUE,
atr_percentile_min INT NOT NULL DEFAULT 20,
gate_spot_perp_enabled BOOL NOT NULL DEFAULT FALSE,
spot_perp_threshold FLOAT NOT NULL DEFAULT 0.002,
sl_atr_multiplier FLOAT NOT NULL DEFAULT 1.5,
tp1_ratio FLOAT NOT NULL DEFAULT 0.75,
tp2_ratio FLOAT NOT NULL DEFAULT 1.5,
timeout_minutes INT NOT NULL DEFAULT 240,
flip_threshold INT NOT NULL DEFAULT 80,
initial_balance FLOAT NOT NULL DEFAULT 10000.0,
current_balance FLOAT NOT NULL DEFAULT 10000.0,
description TEXT,
tags TEXT[],
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
)
""")
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_status ON strategies(status)")
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_symbol ON strategies(symbol)")
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_last_run ON strategies(last_run_at)")
print("[Step 1] Done.")
def step2_add_columns(cur):
print("[Step 2] Adding strategy_id + strategy_name_snapshot columns...")
# paper_trades
for col, col_type in [
("strategy_id", "UUID REFERENCES strategies(strategy_id)"),
("strategy_name_snapshot", "TEXT"),
]:
cur.execute(f"""
ALTER TABLE paper_trades
ADD COLUMN IF NOT EXISTS {col} {col_type}
""")
# signal_indicators
for col, col_type in [
("strategy_id", "UUID REFERENCES strategies(strategy_id)"),
("strategy_name_snapshot", "TEXT"),
]:
cur.execute(f"""
ALTER TABLE signal_indicators
ADD COLUMN IF NOT EXISTS {col} {col_type}
""")
# Indexes
cur.execute("CREATE INDEX IF NOT EXISTS idx_paper_trades_strategy_id ON paper_trades(strategy_id)")
cur.execute("CREATE INDEX IF NOT EXISTS idx_si_strategy_id ON signal_indicators(strategy_id)")
print("[Step 2] Done.")
def step3_insert_legacy_strategies(cur):
print("[Step 3] Inserting legacy strategies into strategies table...")
for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items():
cfg = LEGACY_CONFIGS[strategy_name]
# Compute current_balance from actual paper trades
cur.execute("""
SELECT
COALESCE(SUM(pnl_r) * 200, 0) as total_pnl_usdt
FROM paper_trades
WHERE strategy = %s AND status != 'active'
""", (strategy_name,))
row = cur.fetchone()
pnl_usdt = row[0] if row else 0
current_balance = round(cfg["initial_balance"] + pnl_usdt, 2)
cur.execute("""
INSERT INTO strategies (
strategy_id, display_name, schema_version, status,
symbol, direction,
cvd_fast_window, cvd_slow_window,
weight_direction, weight_env, weight_aux, weight_momentum,
entry_score,
gate_obi_enabled, obi_threshold,
gate_whale_enabled, whale_cvd_threshold,
gate_vol_enabled, atr_percentile_min,
gate_spot_perp_enabled, spot_perp_threshold,
sl_atr_multiplier, tp1_ratio, tp2_ratio,
timeout_minutes, flip_threshold,
initial_balance, current_balance,
description
) VALUES (
%s, %s, 1, %s,
%s, 'both',
%s, %s,
%s, %s, %s, %s,
%s,
TRUE, 0.3,
TRUE, 0.0,
TRUE, 20,
FALSE, 0.002,
%s, %s, %s,
%s, %s,
%s, %s,
%s
)
ON CONFLICT (strategy_id) DO NOTHING
""", (
uuid, display_name, cfg["status"],
cfg["symbol"], cfg["cvd_fast_window"], cfg["cvd_slow_window"],
cfg["weight_direction"], cfg["weight_env"], cfg["weight_aux"], cfg["weight_momentum"],
cfg["entry_score"],
cfg["sl_atr_multiplier"], cfg["tp1_ratio"], cfg["tp2_ratio"],
cfg["timeout_minutes"], cfg["flip_threshold"],
cfg["initial_balance"], current_balance,
f"Migrated from V5.3 legacy strategy: {strategy_name}"
))
print(f" Inserted {strategy_name}{uuid} (balance: {current_balance})")
print("[Step 3] Done.")
def step4_backfill(cur):
print("[Step 4] Backfilling strategy_id + strategy_name_snapshot...")
for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items():
# paper_trades
cur.execute("""
UPDATE paper_trades
SET strategy_id = %s::uuid,
strategy_name_snapshot = %s
WHERE strategy = %s AND strategy_id IS NULL
""", (uuid, display_name, strategy_name))
count = cur.rowcount
print(f" paper_trades [{strategy_name}]: {count} rows updated")
# signal_indicators
cur.execute("""
UPDATE signal_indicators
SET strategy_id = %s::uuid,
strategy_name_snapshot = %s
WHERE strategy = %s AND strategy_id IS NULL
""", (uuid, display_name, strategy_name))
count = cur.rowcount
print(f" signal_indicators [{strategy_name}]: {count} rows updated")
print("[Step 4] Done.")
def step5_verify(cur):
print("[Step 5] Verifying migration completeness...")
# Check strategies table
cur.execute("SELECT COUNT(*) FROM strategies")
n = cur.fetchone()[0]
print(f" strategies table: {n} rows")
# Check NULL strategy_id in paper_trades (for known strategies)
cur.execute("""
SELECT strategy, COUNT(*) as cnt
FROM paper_trades
WHERE strategy IN ('v53', 'v53_middle', 'v53_fast')
AND strategy_id IS NULL
GROUP BY strategy
""")
rows = cur.fetchall()
if rows:
print(f" WARNING: NULL strategy_id found in paper_trades:")
for r in rows:
print(f" {r[0]}: {r[1]} rows")
else:
print(" paper_trades: all known strategies backfilled ✅")
# Check NULL in signal_indicators
cur.execute("""
SELECT strategy, COUNT(*) as cnt
FROM signal_indicators
WHERE strategy IN ('v53', 'v53_middle', 'v53_fast')
AND strategy_id IS NULL
GROUP BY strategy
""")
rows = cur.fetchall()
if rows:
print(f" WARNING: NULL strategy_id found in signal_indicators:")
for r in rows:
print(f" {r[0]}: {r[1]} rows")
else:
print(" signal_indicators: all known strategies backfilled ✅")
print("[Step 5] Done.")
def main():
dry_run = "--dry-run" in sys.argv
if dry_run:
print("=== DRY RUN MODE (no changes will be committed) ===")
conn = get_conn()
conn.autocommit = False
cur = conn.cursor()
try:
step1_create_strategies_table(cur)
step2_add_columns(cur)
step3_insert_legacy_strategies(cur)
step4_backfill(cur)
step5_verify(cur)
if dry_run:
conn.rollback()
print("\n=== DRY RUN: rolled back all changes ===")
else:
conn.commit()
print("\n=== Migration completed successfully ✅ ===")
except Exception as e:
conn.rollback()
print(f"\n=== ERROR: {e} ===")
raise
finally:
cur.close()
conn.close()
if __name__ == "__main__":
main()