feat(v54): add strategies table, migration script, 9 CRUD API endpoints
This commit is contained in:
parent
b1ed55382c
commit
7be7b5b4c0
542
backend/main.py
542
backend/main.py
@ -2179,3 +2179,545 @@ async def strategy_plaza_trades(
|
||||
strategy_id, limit
|
||||
)
|
||||
return {"trades": [dict(r) for r in rows]}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# V5.4 Strategy Factory API
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
import uuid as _uuid
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
# ── Pydantic Models ──────────────────────────────────────────────────────────
|
||||
|
||||
class StrategyCreateRequest(BaseModel):
|
||||
display_name: str = Field(..., min_length=1, max_length=50)
|
||||
symbol: str
|
||||
direction: str = "both"
|
||||
initial_balance: float = 10000.0
|
||||
cvd_fast_window: str = "30m"
|
||||
cvd_slow_window: str = "4h"
|
||||
weight_direction: int = 55
|
||||
weight_env: int = 25
|
||||
weight_aux: int = 15
|
||||
weight_momentum: int = 5
|
||||
entry_score: int = 75
|
||||
gate_obi_enabled: bool = True
|
||||
obi_threshold: float = 0.3
|
||||
gate_whale_enabled: bool = True
|
||||
whale_cvd_threshold: float = 0.0
|
||||
gate_vol_enabled: bool = True
|
||||
atr_percentile_min: int = 20
|
||||
gate_spot_perp_enabled: bool = False
|
||||
spot_perp_threshold: float = 0.002
|
||||
sl_atr_multiplier: float = 1.5
|
||||
tp1_ratio: float = 0.75
|
||||
tp2_ratio: float = 1.5
|
||||
timeout_minutes: int = 240
|
||||
flip_threshold: int = 80
|
||||
description: Optional[str] = None
|
||||
|
||||
@field_validator("symbol")
|
||||
@classmethod
|
||||
def validate_symbol(cls, v):
|
||||
allowed = {"BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"symbol must be one of {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator("direction")
|
||||
@classmethod
|
||||
def validate_direction(cls, v):
|
||||
if v not in {"long_only", "short_only", "both"}:
|
||||
raise ValueError("direction must be long_only, short_only, or both")
|
||||
return v
|
||||
|
||||
@field_validator("cvd_fast_window")
|
||||
@classmethod
|
||||
def validate_cvd_fast(cls, v):
|
||||
if v not in {"5m", "15m", "30m"}:
|
||||
raise ValueError("cvd_fast_window must be 5m, 15m, or 30m")
|
||||
return v
|
||||
|
||||
@field_validator("cvd_slow_window")
|
||||
@classmethod
|
||||
def validate_cvd_slow(cls, v):
|
||||
if v not in {"30m", "1h", "4h"}:
|
||||
raise ValueError("cvd_slow_window must be 30m, 1h, or 4h")
|
||||
return v
|
||||
|
||||
@field_validator("weight_direction")
|
||||
@classmethod
|
||||
def validate_w_dir(cls, v):
|
||||
if not 10 <= v <= 80:
|
||||
raise ValueError("weight_direction must be 10-80")
|
||||
return v
|
||||
|
||||
@field_validator("weight_env")
|
||||
@classmethod
|
||||
def validate_w_env(cls, v):
|
||||
if not 5 <= v <= 60:
|
||||
raise ValueError("weight_env must be 5-60")
|
||||
return v
|
||||
|
||||
@field_validator("weight_aux")
|
||||
@classmethod
|
||||
def validate_w_aux(cls, v):
|
||||
if not 0 <= v <= 40:
|
||||
raise ValueError("weight_aux must be 0-40")
|
||||
return v
|
||||
|
||||
@field_validator("weight_momentum")
|
||||
@classmethod
|
||||
def validate_w_mom(cls, v):
|
||||
if not 0 <= v <= 20:
|
||||
raise ValueError("weight_momentum must be 0-20")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_weights_sum(self):
|
||||
total = self.weight_direction + self.weight_env + self.weight_aux + self.weight_momentum
|
||||
if total != 100:
|
||||
raise ValueError(f"Weights must sum to 100, got {total}")
|
||||
return self
|
||||
|
||||
@field_validator("entry_score")
|
||||
@classmethod
|
||||
def validate_entry_score(cls, v):
|
||||
if not 60 <= v <= 95:
|
||||
raise ValueError("entry_score must be 60-95")
|
||||
return v
|
||||
|
||||
@field_validator("obi_threshold")
|
||||
@classmethod
|
||||
def validate_obi(cls, v):
|
||||
if not 0.1 <= v <= 0.9:
|
||||
raise ValueError("obi_threshold must be 0.1-0.9")
|
||||
return v
|
||||
|
||||
@field_validator("whale_cvd_threshold")
|
||||
@classmethod
|
||||
def validate_whale(cls, v):
|
||||
if not -1.0 <= v <= 1.0:
|
||||
raise ValueError("whale_cvd_threshold must be -1.0 to 1.0")
|
||||
return v
|
||||
|
||||
@field_validator("atr_percentile_min")
|
||||
@classmethod
|
||||
def validate_atr_pct(cls, v):
|
||||
if not 5 <= v <= 80:
|
||||
raise ValueError("atr_percentile_min must be 5-80")
|
||||
return v
|
||||
|
||||
@field_validator("spot_perp_threshold")
|
||||
@classmethod
|
||||
def validate_spot_perp(cls, v):
|
||||
if not 0.0005 <= v <= 0.01:
|
||||
raise ValueError("spot_perp_threshold must be 0.0005-0.01")
|
||||
return v
|
||||
|
||||
@field_validator("sl_atr_multiplier")
|
||||
@classmethod
|
||||
def validate_sl(cls, v):
|
||||
if not 0.5 <= v <= 3.0:
|
||||
raise ValueError("sl_atr_multiplier must be 0.5-3.0")
|
||||
return v
|
||||
|
||||
@field_validator("tp1_ratio")
|
||||
@classmethod
|
||||
def validate_tp1(cls, v):
|
||||
if not 0.3 <= v <= 2.0:
|
||||
raise ValueError("tp1_ratio must be 0.3-2.0")
|
||||
return v
|
||||
|
||||
@field_validator("tp2_ratio")
|
||||
@classmethod
|
||||
def validate_tp2(cls, v):
|
||||
if not 0.5 <= v <= 4.0:
|
||||
raise ValueError("tp2_ratio must be 0.5-4.0")
|
||||
return v
|
||||
|
||||
@field_validator("timeout_minutes")
|
||||
@classmethod
|
||||
def validate_timeout(cls, v):
|
||||
if not 30 <= v <= 1440:
|
||||
raise ValueError("timeout_minutes must be 30-1440")
|
||||
return v
|
||||
|
||||
@field_validator("flip_threshold")
|
||||
@classmethod
|
||||
def validate_flip(cls, v):
|
||||
if not 60 <= v <= 95:
|
||||
raise ValueError("flip_threshold must be 60-95")
|
||||
return v
|
||||
|
||||
@field_validator("initial_balance")
|
||||
@classmethod
|
||||
def validate_balance(cls, v):
|
||||
if v < 1000:
|
||||
raise ValueError("initial_balance must be >= 1000")
|
||||
return v
|
||||
|
||||
|
||||
class StrategyUpdateRequest(BaseModel):
|
||||
"""Partial update - all fields optional"""
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=50)
|
||||
direction: Optional[str] = None
|
||||
cvd_fast_window: Optional[str] = None
|
||||
cvd_slow_window: Optional[str] = None
|
||||
weight_direction: Optional[int] = None
|
||||
weight_env: Optional[int] = None
|
||||
weight_aux: Optional[int] = None
|
||||
weight_momentum: Optional[int] = None
|
||||
entry_score: Optional[int] = None
|
||||
gate_obi_enabled: Optional[bool] = None
|
||||
obi_threshold: Optional[float] = None
|
||||
gate_whale_enabled: Optional[bool] = None
|
||||
whale_cvd_threshold: Optional[float] = None
|
||||
gate_vol_enabled: Optional[bool] = None
|
||||
atr_percentile_min: Optional[int] = None
|
||||
gate_spot_perp_enabled: Optional[bool] = None
|
||||
spot_perp_threshold: Optional[float] = None
|
||||
sl_atr_multiplier: Optional[float] = None
|
||||
tp1_ratio: Optional[float] = None
|
||||
tp2_ratio: Optional[float] = None
|
||||
timeout_minutes: Optional[int] = None
|
||||
flip_threshold: Optional[int] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class AddBalanceRequest(BaseModel):
|
||||
amount: float = Field(..., gt=0)
|
||||
|
||||
|
||||
class DeprecateRequest(BaseModel):
|
||||
confirm: bool
|
||||
|
||||
|
||||
# ── Helper ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _get_strategy_or_404(strategy_id: str) -> dict:
|
||||
row = await async_fetchrow(
|
||||
"SELECT * FROM strategies WHERE strategy_id=$1",
|
||||
strategy_id
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Strategy not found")
|
||||
return dict(row)
|
||||
|
||||
|
||||
def _strategy_row_to_card(row: dict) -> dict:
|
||||
"""Convert a strategies row to a card-level response (no config params)"""
|
||||
return {
|
||||
"strategy_id": str(row["strategy_id"]),
|
||||
"display_name": row["display_name"],
|
||||
"status": row["status"],
|
||||
"symbol": row["symbol"],
|
||||
"direction": row["direction"],
|
||||
"started_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0,
|
||||
"initial_balance": row["initial_balance"],
|
||||
"current_balance": row["current_balance"],
|
||||
"net_usdt": round(row["current_balance"] - row["initial_balance"], 2),
|
||||
"deprecated_at": int(row["deprecated_at"].timestamp() * 1000) if row.get("deprecated_at") else None,
|
||||
"last_run_at": int(row["last_run_at"].timestamp() * 1000) if row.get("last_run_at") else None,
|
||||
"schema_version": row["schema_version"],
|
||||
}
|
||||
|
||||
|
||||
def _strategy_row_to_detail(row: dict) -> dict:
|
||||
"""Full detail including all config params"""
|
||||
base = _strategy_row_to_card(row)
|
||||
base.update({
|
||||
"cvd_fast_window": row["cvd_fast_window"],
|
||||
"cvd_slow_window": row["cvd_slow_window"],
|
||||
"weight_direction": row["weight_direction"],
|
||||
"weight_env": row["weight_env"],
|
||||
"weight_aux": row["weight_aux"],
|
||||
"weight_momentum": row["weight_momentum"],
|
||||
"entry_score": row["entry_score"],
|
||||
"gate_obi_enabled": row["gate_obi_enabled"],
|
||||
"obi_threshold": row["obi_threshold"],
|
||||
"gate_whale_enabled": row["gate_whale_enabled"],
|
||||
"whale_cvd_threshold": row["whale_cvd_threshold"],
|
||||
"gate_vol_enabled": row["gate_vol_enabled"],
|
||||
"atr_percentile_min": row["atr_percentile_min"],
|
||||
"gate_spot_perp_enabled": row["gate_spot_perp_enabled"],
|
||||
"spot_perp_threshold": row["spot_perp_threshold"],
|
||||
"sl_atr_multiplier": row["sl_atr_multiplier"],
|
||||
"tp1_ratio": row["tp1_ratio"],
|
||||
"tp2_ratio": row["tp2_ratio"],
|
||||
"timeout_minutes": row["timeout_minutes"],
|
||||
"flip_threshold": row["flip_threshold"],
|
||||
"description": row.get("description"),
|
||||
"created_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0,
|
||||
"updated_at": int(row["updated_at"].timestamp() * 1000) if row.get("updated_at") else 0,
|
||||
})
|
||||
return base
|
||||
|
||||
|
||||
async def _get_strategy_trade_stats(strategy_id: str) -> dict:
|
||||
"""Fetch trade statistics for a strategy by strategy_id"""
|
||||
rows = await async_fetch(
|
||||
"""SELECT status, pnl_r, tp1_hit, entry_ts, exit_ts
|
||||
FROM paper_trades
|
||||
WHERE strategy_id=$1 AND status != 'active'
|
||||
ORDER BY entry_ts DESC""",
|
||||
strategy_id
|
||||
)
|
||||
if not rows:
|
||||
return {
|
||||
"trade_count": 0, "win_rate": 0.0,
|
||||
"avg_win_r": 0.0, "avg_loss_r": 0.0,
|
||||
"open_positions": 0,
|
||||
"pnl_usdt_24h": 0.0, "pnl_r_24h": 0.0,
|
||||
"last_trade_at": None,
|
||||
}
|
||||
|
||||
total = len(rows)
|
||||
wins = [r for r in rows if (r["pnl_r"] or 0) > 0]
|
||||
losses = [r for r in rows if (r["pnl_r"] or 0) < 0]
|
||||
win_rate = round(len(wins) / total * 100, 1) if total else 0.0
|
||||
avg_win = round(sum(r["pnl_r"] for r in wins) / len(wins), 3) if wins else 0.0
|
||||
avg_loss = round(sum(r["pnl_r"] for r in losses) / len(losses), 3) if losses else 0.0
|
||||
last_trade_at = rows[0]["exit_ts"] if rows else None
|
||||
|
||||
# 24h stats
|
||||
cutoff_ms = int((datetime.utcnow() - timedelta(hours=24)).timestamp() * 1000)
|
||||
rows_24h = [r for r in rows if (r["exit_ts"] or 0) >= cutoff_ms]
|
||||
pnl_r_24h = round(sum(r["pnl_r"] or 0 for r in rows_24h), 3)
|
||||
pnl_usdt_24h = round(pnl_r_24h * 200, 2)
|
||||
|
||||
# Open positions
|
||||
open_rows = await async_fetch(
|
||||
"SELECT COUNT(*) as cnt FROM paper_trades WHERE strategy_id=$1 AND status='active'",
|
||||
strategy_id
|
||||
)
|
||||
open_positions = open_rows[0]["cnt"] if open_rows else 0
|
||||
|
||||
return {
|
||||
"trade_count": total,
|
||||
"win_rate": win_rate,
|
||||
"avg_win_r": avg_win,
|
||||
"avg_loss_r": avg_loss,
|
||||
"open_positions": open_positions,
|
||||
"pnl_usdt_24h": pnl_usdt_24h,
|
||||
"pnl_r_24h": pnl_r_24h,
|
||||
"last_trade_at": last_trade_at,
|
||||
}
|
||||
|
||||
|
||||
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.post("/api/strategies")
|
||||
async def create_strategy(body: StrategyCreateRequest, user: dict = Depends(get_current_user)):
|
||||
"""创建新策略实例"""
|
||||
new_id = str(_uuid.uuid4())
|
||||
await async_execute(
|
||||
"""INSERT INTO strategies (
|
||||
strategy_id, display_name, schema_version, status,
|
||||
symbol, direction,
|
||||
cvd_fast_window, cvd_slow_window,
|
||||
weight_direction, weight_env, weight_aux, weight_momentum,
|
||||
entry_score,
|
||||
gate_obi_enabled, obi_threshold,
|
||||
gate_whale_enabled, whale_cvd_threshold,
|
||||
gate_vol_enabled, atr_percentile_min,
|
||||
gate_spot_perp_enabled, spot_perp_threshold,
|
||||
sl_atr_multiplier, tp1_ratio, tp2_ratio,
|
||||
timeout_minutes, flip_threshold,
|
||||
initial_balance, current_balance,
|
||||
description
|
||||
) VALUES (
|
||||
$1,$2,1,'running',
|
||||
$3,$4,$5,$6,
|
||||
$7,$8,$9,$10,
|
||||
$11,
|
||||
$12,$13,$14,$15,$16,$17,$18,$19,
|
||||
$20,$21,$22,$23,$24,
|
||||
$25,$25,$26
|
||||
)""",
|
||||
new_id, body.display_name,
|
||||
body.symbol, body.direction, body.cvd_fast_window, body.cvd_slow_window,
|
||||
body.weight_direction, body.weight_env, body.weight_aux, body.weight_momentum,
|
||||
body.entry_score,
|
||||
body.gate_obi_enabled, body.obi_threshold,
|
||||
body.gate_whale_enabled, body.whale_cvd_threshold,
|
||||
body.gate_vol_enabled, body.atr_percentile_min,
|
||||
body.gate_spot_perp_enabled, body.spot_perp_threshold,
|
||||
body.sl_atr_multiplier, body.tp1_ratio, body.tp2_ratio,
|
||||
body.timeout_minutes, body.flip_threshold,
|
||||
body.initial_balance, body.description
|
||||
)
|
||||
row = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", new_id)
|
||||
return {"ok": True, "strategy": _strategy_row_to_detail(dict(row))}
|
||||
|
||||
|
||||
@app.get("/api/strategies")
|
||||
async def list_strategies(
|
||||
include_deprecated: bool = False,
|
||||
user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""获取策略列表"""
|
||||
if include_deprecated:
|
||||
rows = await async_fetch("SELECT * FROM strategies ORDER BY created_at ASC")
|
||||
else:
|
||||
rows = await async_fetch(
|
||||
"SELECT * FROM strategies WHERE status != 'deprecated' ORDER BY created_at ASC"
|
||||
)
|
||||
result = []
|
||||
for row in rows:
|
||||
d = _strategy_row_to_card(dict(row))
|
||||
stats = await _get_strategy_trade_stats(str(row["strategy_id"]))
|
||||
d.update(stats)
|
||||
result.append(d)
|
||||
return {"strategies": result}
|
||||
|
||||
|
||||
@app.get("/api/strategies/{sid}")
|
||||
async def get_strategy(sid: str, user: dict = Depends(get_current_user)):
|
||||
"""获取单个策略详情(含完整参数配置)"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
detail = _strategy_row_to_detail(row)
|
||||
stats = await _get_strategy_trade_stats(sid)
|
||||
detail.update(stats)
|
||||
return {"strategy": detail}
|
||||
|
||||
|
||||
@app.patch("/api/strategies/{sid}")
|
||||
async def update_strategy(sid: str, body: StrategyUpdateRequest, user: dict = Depends(get_current_user)):
|
||||
"""更新策略参数(Partial Update)"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] == "deprecated":
|
||||
raise HTTPException(status_code=403, detail="Cannot modify a deprecated strategy")
|
||||
|
||||
# Build SET clause dynamically from non-None fields
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
|
||||
# Validate weights sum if any weight is being changed
|
||||
weight_fields = {"weight_direction", "weight_env", "weight_aux", "weight_momentum"}
|
||||
if weight_fields & set(updates.keys()):
|
||||
w_dir = updates.get("weight_direction", row["weight_direction"])
|
||||
w_env = updates.get("weight_env", row["weight_env"])
|
||||
w_aux = updates.get("weight_aux", row["weight_aux"])
|
||||
w_mom = updates.get("weight_momentum", row["weight_momentum"])
|
||||
if w_dir + w_env + w_aux + w_mom != 100:
|
||||
raise HTTPException(status_code=400, detail=f"Weights must sum to 100, got {w_dir+w_env+w_aux+w_mom}")
|
||||
|
||||
# Validate individual field ranges
|
||||
validators = {
|
||||
"direction": lambda v: v in {"long_only", "short_only", "both"},
|
||||
"cvd_fast_window": lambda v: v in {"5m", "15m", "30m"},
|
||||
"cvd_slow_window": lambda v: v in {"30m", "1h", "4h"},
|
||||
"weight_direction": lambda v: 10 <= v <= 80,
|
||||
"weight_env": lambda v: 5 <= v <= 60,
|
||||
"weight_aux": lambda v: 0 <= v <= 40,
|
||||
"weight_momentum": lambda v: 0 <= v <= 20,
|
||||
"entry_score": lambda v: 60 <= v <= 95,
|
||||
"obi_threshold": lambda v: 0.1 <= v <= 0.9,
|
||||
"whale_cvd_threshold": lambda v: -1.0 <= v <= 1.0,
|
||||
"atr_percentile_min": lambda v: 5 <= v <= 80,
|
||||
"spot_perp_threshold": lambda v: 0.0005 <= v <= 0.01,
|
||||
"sl_atr_multiplier": lambda v: 0.5 <= v <= 3.0,
|
||||
"tp1_ratio": lambda v: 0.3 <= v <= 2.0,
|
||||
"tp2_ratio": lambda v: 0.5 <= v <= 4.0,
|
||||
"timeout_minutes": lambda v: 30 <= v <= 1440,
|
||||
"flip_threshold": lambda v: 60 <= v <= 95,
|
||||
}
|
||||
for field, val in updates.items():
|
||||
if field in validators and not validators[field](val):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid value for {field}: {val}")
|
||||
|
||||
# Execute update
|
||||
set_parts = [f"{k}=${i+2}" for i, k in enumerate(updates.keys())]
|
||||
set_parts.append(f"updated_at=NOW()")
|
||||
sql = f"UPDATE strategies SET {', '.join(set_parts)} WHERE strategy_id=$1"
|
||||
await async_execute(sql, sid, *updates.values())
|
||||
|
||||
updated = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", sid)
|
||||
return {"ok": True, "strategy": _strategy_row_to_detail(dict(updated))}
|
||||
|
||||
|
||||
@app.post("/api/strategies/{sid}/pause")
|
||||
async def pause_strategy(sid: str, user: dict = Depends(get_current_user)):
|
||||
"""暂停策略(停止开新仓,不影响现有持仓)"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] == "deprecated":
|
||||
raise HTTPException(status_code=403, detail="Cannot pause a deprecated strategy")
|
||||
if row["status"] == "paused":
|
||||
return {"ok": True, "message": "Already paused"}
|
||||
await async_execute(
|
||||
"UPDATE strategies SET status='paused', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1",
|
||||
sid
|
||||
)
|
||||
return {"ok": True, "status": "paused"}
|
||||
|
||||
|
||||
@app.post("/api/strategies/{sid}/resume")
|
||||
async def resume_strategy(sid: str, user: dict = Depends(get_current_user)):
|
||||
"""恢复策略"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] == "running":
|
||||
return {"ok": True, "message": "Already running"}
|
||||
await async_execute(
|
||||
"UPDATE strategies SET status='running', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1",
|
||||
sid
|
||||
)
|
||||
return {"ok": True, "status": "running"}
|
||||
|
||||
|
||||
@app.post("/api/strategies/{sid}/deprecate")
|
||||
async def deprecate_strategy(sid: str, body: DeprecateRequest, user: dict = Depends(get_current_user)):
|
||||
"""废弃策略(数据永久保留,可重新启用)"""
|
||||
if not body.confirm:
|
||||
raise HTTPException(status_code=400, detail="Must set confirm=true to deprecate")
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] == "deprecated":
|
||||
return {"ok": True, "message": "Already deprecated"}
|
||||
await async_execute(
|
||||
"""UPDATE strategies
|
||||
SET status='deprecated', deprecated_at=NOW(),
|
||||
status_changed_at=NOW(), updated_at=NOW()
|
||||
WHERE strategy_id=$1""",
|
||||
sid
|
||||
)
|
||||
return {"ok": True, "status": "deprecated"}
|
||||
|
||||
|
||||
@app.post("/api/strategies/{sid}/restore")
|
||||
async def restore_strategy(sid: str, user: dict = Depends(get_current_user)):
|
||||
"""重新启用废弃策略(继续原有余额和历史数据)"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] != "deprecated":
|
||||
raise HTTPException(status_code=400, detail="Strategy is not deprecated")
|
||||
await async_execute(
|
||||
"""UPDATE strategies
|
||||
SET status='running', deprecated_at=NULL,
|
||||
status_changed_at=NOW(), updated_at=NOW()
|
||||
WHERE strategy_id=$1""",
|
||||
sid
|
||||
)
|
||||
return {"ok": True, "status": "running"}
|
||||
|
||||
|
||||
@app.post("/api/strategies/{sid}/add-balance")
|
||||
async def add_balance(sid: str, body: AddBalanceRequest, user: dict = Depends(get_current_user)):
|
||||
"""追加余额(initial_balance 和 current_balance 同步增加)"""
|
||||
row = await _get_strategy_or_404(sid)
|
||||
if row["status"] == "deprecated":
|
||||
raise HTTPException(status_code=403, detail="Cannot add balance to a deprecated strategy")
|
||||
new_initial = round(row["initial_balance"] + body.amount, 2)
|
||||
new_current = round(row["current_balance"] + body.amount, 2)
|
||||
await async_execute(
|
||||
"""UPDATE strategies
|
||||
SET initial_balance=$2, current_balance=$3, updated_at=NOW()
|
||||
WHERE strategy_id=$1""",
|
||||
sid, new_initial, new_current
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"initial_balance": new_initial,
|
||||
"current_balance": new_current,
|
||||
"added": body.amount,
|
||||
}
|
||||
|
||||
327
backend/migrate_v54.py
Normal file
327
backend/migrate_v54.py
Normal file
@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
V5.4 Strategy Factory DB Migration Script
|
||||
- Creates `strategies` table
|
||||
- Adds strategy_id + strategy_name_snapshot to paper_trades, signal_indicators
|
||||
- Inserts existing 3 strategies with fixed UUIDs
|
||||
- Backfills strategy_id + strategy_name_snapshot for all existing records
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
PG_HOST = os.environ.get("PG_HOST", "10.106.0.3")
|
||||
PG_PASS = os.environ.get("PG_PASS", "arb_engine_2026")
|
||||
PG_USER = "arb"
|
||||
PG_DB = "arb_engine"
|
||||
|
||||
# Fixed UUIDs for existing strategies (deterministic, easy to recognize)
|
||||
LEGACY_STRATEGY_MAP = {
|
||||
"v53": ("00000000-0000-0000-0000-000000000053", "V5.3 Standard"),
|
||||
"v53_middle": ("00000000-0000-0000-0000-000000000054", "V5.3 Middle"),
|
||||
"v53_fast": ("00000000-0000-0000-0000-000000000055", "V5.3 Fast"),
|
||||
}
|
||||
|
||||
# Default config values per strategy (from strategy JSON files)
|
||||
LEGACY_CONFIGS = {
|
||||
"v53": {
|
||||
"symbol": "BTCUSDT", # multi-symbol, use BTC as representative
|
||||
"cvd_fast_window": "30m",
|
||||
"cvd_slow_window": "4h",
|
||||
"weight_direction": 55,
|
||||
"weight_env": 25,
|
||||
"weight_aux": 15,
|
||||
"weight_momentum": 5,
|
||||
"entry_score": 75,
|
||||
"sl_atr_multiplier": 1.0,
|
||||
"tp1_ratio": 0.75,
|
||||
"tp2_ratio": 1.5,
|
||||
"timeout_minutes": 60,
|
||||
"flip_threshold": 75,
|
||||
"status": "running",
|
||||
"initial_balance": 10000.0,
|
||||
},
|
||||
"v53_middle": {
|
||||
"symbol": "BTCUSDT",
|
||||
"cvd_fast_window": "15m",
|
||||
"cvd_slow_window": "1h",
|
||||
"weight_direction": 55,
|
||||
"weight_env": 25,
|
||||
"weight_aux": 15,
|
||||
"weight_momentum": 5,
|
||||
"entry_score": 75,
|
||||
"sl_atr_multiplier": 1.0,
|
||||
"tp1_ratio": 0.75,
|
||||
"tp2_ratio": 1.5,
|
||||
"timeout_minutes": 60,
|
||||
"flip_threshold": 75,
|
||||
"status": "running",
|
||||
"initial_balance": 10000.0,
|
||||
},
|
||||
"v53_fast": {
|
||||
"symbol": "BTCUSDT",
|
||||
"cvd_fast_window": "5m",
|
||||
"cvd_slow_window": "30m",
|
||||
"weight_direction": 55,
|
||||
"weight_env": 25,
|
||||
"weight_aux": 15,
|
||||
"weight_momentum": 5,
|
||||
"entry_score": 75,
|
||||
"sl_atr_multiplier": 1.0,
|
||||
"tp1_ratio": 0.75,
|
||||
"tp2_ratio": 1.5,
|
||||
"timeout_minutes": 60,
|
||||
"flip_threshold": 75,
|
||||
"status": "running",
|
||||
"initial_balance": 10000.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_conn():
|
||||
return psycopg2.connect(
|
||||
host=PG_HOST, user=PG_USER, password=PG_PASS, dbname=PG_DB
|
||||
)
|
||||
|
||||
|
||||
def step1_create_strategies_table(cur):
|
||||
print("[Step 1] Creating strategies table...")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS strategies (
|
||||
strategy_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
display_name TEXT NOT NULL,
|
||||
schema_version INT NOT NULL DEFAULT 1,
|
||||
status TEXT NOT NULL DEFAULT 'running'
|
||||
CHECK (status IN ('running', 'paused', 'deprecated')),
|
||||
status_changed_at TIMESTAMP,
|
||||
last_run_at TIMESTAMP,
|
||||
deprecated_at TIMESTAMP,
|
||||
symbol TEXT NOT NULL
|
||||
CHECK (symbol IN ('BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'XRPUSDT')),
|
||||
direction TEXT NOT NULL DEFAULT 'both'
|
||||
CHECK (direction IN ('long_only', 'short_only', 'both')),
|
||||
cvd_fast_window TEXT NOT NULL DEFAULT '30m'
|
||||
CHECK (cvd_fast_window IN ('5m', '15m', '30m')),
|
||||
cvd_slow_window TEXT NOT NULL DEFAULT '4h'
|
||||
CHECK (cvd_slow_window IN ('30m', '1h', '4h')),
|
||||
weight_direction INT NOT NULL DEFAULT 55,
|
||||
weight_env INT NOT NULL DEFAULT 25,
|
||||
weight_aux INT NOT NULL DEFAULT 15,
|
||||
weight_momentum INT NOT NULL DEFAULT 5,
|
||||
entry_score INT NOT NULL DEFAULT 75,
|
||||
gate_obi_enabled BOOL NOT NULL DEFAULT TRUE,
|
||||
obi_threshold FLOAT NOT NULL DEFAULT 0.3,
|
||||
gate_whale_enabled BOOL NOT NULL DEFAULT TRUE,
|
||||
whale_cvd_threshold FLOAT NOT NULL DEFAULT 0.0,
|
||||
gate_vol_enabled BOOL NOT NULL DEFAULT TRUE,
|
||||
atr_percentile_min INT NOT NULL DEFAULT 20,
|
||||
gate_spot_perp_enabled BOOL NOT NULL DEFAULT FALSE,
|
||||
spot_perp_threshold FLOAT NOT NULL DEFAULT 0.002,
|
||||
sl_atr_multiplier FLOAT NOT NULL DEFAULT 1.5,
|
||||
tp1_ratio FLOAT NOT NULL DEFAULT 0.75,
|
||||
tp2_ratio FLOAT NOT NULL DEFAULT 1.5,
|
||||
timeout_minutes INT NOT NULL DEFAULT 240,
|
||||
flip_threshold INT NOT NULL DEFAULT 80,
|
||||
initial_balance FLOAT NOT NULL DEFAULT 10000.0,
|
||||
current_balance FLOAT NOT NULL DEFAULT 10000.0,
|
||||
description TEXT,
|
||||
tags TEXT[],
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
)
|
||||
""")
|
||||
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_status ON strategies(status)")
|
||||
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_symbol ON strategies(symbol)")
|
||||
cur.execute("CREATE INDEX IF NOT EXISTS idx_strategies_last_run ON strategies(last_run_at)")
|
||||
print("[Step 1] Done.")
|
||||
|
||||
|
||||
def step2_add_columns(cur):
|
||||
print("[Step 2] Adding strategy_id + strategy_name_snapshot columns...")
|
||||
# paper_trades
|
||||
for col, col_type in [
|
||||
("strategy_id", "UUID REFERENCES strategies(strategy_id)"),
|
||||
("strategy_name_snapshot", "TEXT"),
|
||||
]:
|
||||
cur.execute(f"""
|
||||
ALTER TABLE paper_trades
|
||||
ADD COLUMN IF NOT EXISTS {col} {col_type}
|
||||
""")
|
||||
# signal_indicators
|
||||
for col, col_type in [
|
||||
("strategy_id", "UUID REFERENCES strategies(strategy_id)"),
|
||||
("strategy_name_snapshot", "TEXT"),
|
||||
]:
|
||||
cur.execute(f"""
|
||||
ALTER TABLE signal_indicators
|
||||
ADD COLUMN IF NOT EXISTS {col} {col_type}
|
||||
""")
|
||||
# Indexes
|
||||
cur.execute("CREATE INDEX IF NOT EXISTS idx_paper_trades_strategy_id ON paper_trades(strategy_id)")
|
||||
cur.execute("CREATE INDEX IF NOT EXISTS idx_si_strategy_id ON signal_indicators(strategy_id)")
|
||||
print("[Step 2] Done.")
|
||||
|
||||
|
||||
def step3_insert_legacy_strategies(cur):
|
||||
print("[Step 3] Inserting legacy strategies into strategies table...")
|
||||
for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items():
|
||||
cfg = LEGACY_CONFIGS[strategy_name]
|
||||
# Compute current_balance from actual paper trades
|
||||
cur.execute("""
|
||||
SELECT
|
||||
COALESCE(SUM(pnl_r) * 200, 0) as total_pnl_usdt
|
||||
FROM paper_trades
|
||||
WHERE strategy = %s AND status != 'active'
|
||||
""", (strategy_name,))
|
||||
row = cur.fetchone()
|
||||
pnl_usdt = row[0] if row else 0
|
||||
current_balance = round(cfg["initial_balance"] + pnl_usdt, 2)
|
||||
|
||||
cur.execute("""
|
||||
INSERT INTO strategies (
|
||||
strategy_id, display_name, schema_version, status,
|
||||
symbol, direction,
|
||||
cvd_fast_window, cvd_slow_window,
|
||||
weight_direction, weight_env, weight_aux, weight_momentum,
|
||||
entry_score,
|
||||
gate_obi_enabled, obi_threshold,
|
||||
gate_whale_enabled, whale_cvd_threshold,
|
||||
gate_vol_enabled, atr_percentile_min,
|
||||
gate_spot_perp_enabled, spot_perp_threshold,
|
||||
sl_atr_multiplier, tp1_ratio, tp2_ratio,
|
||||
timeout_minutes, flip_threshold,
|
||||
initial_balance, current_balance,
|
||||
description
|
||||
) VALUES (
|
||||
%s, %s, 1, %s,
|
||||
%s, 'both',
|
||||
%s, %s,
|
||||
%s, %s, %s, %s,
|
||||
%s,
|
||||
TRUE, 0.3,
|
||||
TRUE, 0.0,
|
||||
TRUE, 20,
|
||||
FALSE, 0.002,
|
||||
%s, %s, %s,
|
||||
%s, %s,
|
||||
%s, %s,
|
||||
%s
|
||||
)
|
||||
ON CONFLICT (strategy_id) DO NOTHING
|
||||
""", (
|
||||
uuid, display_name, cfg["status"],
|
||||
cfg["symbol"], cfg["cvd_fast_window"], cfg["cvd_slow_window"],
|
||||
cfg["weight_direction"], cfg["weight_env"], cfg["weight_aux"], cfg["weight_momentum"],
|
||||
cfg["entry_score"],
|
||||
cfg["sl_atr_multiplier"], cfg["tp1_ratio"], cfg["tp2_ratio"],
|
||||
cfg["timeout_minutes"], cfg["flip_threshold"],
|
||||
cfg["initial_balance"], current_balance,
|
||||
f"Migrated from V5.3 legacy strategy: {strategy_name}"
|
||||
))
|
||||
print(f" Inserted {strategy_name} → {uuid} (balance: {current_balance})")
|
||||
print("[Step 3] Done.")
|
||||
|
||||
|
||||
def step4_backfill(cur):
|
||||
print("[Step 4] Backfilling strategy_id + strategy_name_snapshot...")
|
||||
for strategy_name, (uuid, display_name) in LEGACY_STRATEGY_MAP.items():
|
||||
# paper_trades
|
||||
cur.execute("""
|
||||
UPDATE paper_trades
|
||||
SET strategy_id = %s::uuid,
|
||||
strategy_name_snapshot = %s
|
||||
WHERE strategy = %s AND strategy_id IS NULL
|
||||
""", (uuid, display_name, strategy_name))
|
||||
count = cur.rowcount
|
||||
print(f" paper_trades [{strategy_name}]: {count} rows updated")
|
||||
|
||||
# signal_indicators
|
||||
cur.execute("""
|
||||
UPDATE signal_indicators
|
||||
SET strategy_id = %s::uuid,
|
||||
strategy_name_snapshot = %s
|
||||
WHERE strategy = %s AND strategy_id IS NULL
|
||||
""", (uuid, display_name, strategy_name))
|
||||
count = cur.rowcount
|
||||
print(f" signal_indicators [{strategy_name}]: {count} rows updated")
|
||||
|
||||
print("[Step 4] Done.")
|
||||
|
||||
|
||||
def step5_verify(cur):
|
||||
print("[Step 5] Verifying migration completeness...")
|
||||
# Check strategies table
|
||||
cur.execute("SELECT COUNT(*) FROM strategies")
|
||||
n = cur.fetchone()[0]
|
||||
print(f" strategies table: {n} rows")
|
||||
|
||||
# Check NULL strategy_id in paper_trades (for known strategies)
|
||||
cur.execute("""
|
||||
SELECT strategy, COUNT(*) as cnt
|
||||
FROM paper_trades
|
||||
WHERE strategy IN ('v53', 'v53_middle', 'v53_fast')
|
||||
AND strategy_id IS NULL
|
||||
GROUP BY strategy
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
if rows:
|
||||
print(f" WARNING: NULL strategy_id found in paper_trades:")
|
||||
for r in rows:
|
||||
print(f" {r[0]}: {r[1]} rows")
|
||||
else:
|
||||
print(" paper_trades: all known strategies backfilled ✅")
|
||||
|
||||
# Check NULL in signal_indicators
|
||||
cur.execute("""
|
||||
SELECT strategy, COUNT(*) as cnt
|
||||
FROM signal_indicators
|
||||
WHERE strategy IN ('v53', 'v53_middle', 'v53_fast')
|
||||
AND strategy_id IS NULL
|
||||
GROUP BY strategy
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
if rows:
|
||||
print(f" WARNING: NULL strategy_id found in signal_indicators:")
|
||||
for r in rows:
|
||||
print(f" {r[0]}: {r[1]} rows")
|
||||
else:
|
||||
print(" signal_indicators: all known strategies backfilled ✅")
|
||||
|
||||
print("[Step 5] Done.")
|
||||
|
||||
|
||||
def main():
|
||||
dry_run = "--dry-run" in sys.argv
|
||||
if dry_run:
|
||||
print("=== DRY RUN MODE (no changes will be committed) ===")
|
||||
|
||||
conn = get_conn()
|
||||
conn.autocommit = False
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
step1_create_strategies_table(cur)
|
||||
step2_add_columns(cur)
|
||||
step3_insert_legacy_strategies(cur)
|
||||
step4_backfill(cur)
|
||||
step5_verify(cur)
|
||||
|
||||
if dry_run:
|
||||
conn.rollback()
|
||||
print("\n=== DRY RUN: rolled back all changes ===")
|
||||
else:
|
||||
conn.commit()
|
||||
print("\n=== Migration completed successfully ✅ ===")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
print(f"\n=== ERROR: {e} ===")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user