diff --git a/backend/admin_cli.py b/backend/admin_cli.py new file mode 100644 index 0000000..5da6d5c --- /dev/null +++ b/backend/admin_cli.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Admin CLI for Arbitrage Engine""" +import sys, os, sqlite3, secrets, json + +DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") + +def get_conn(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + +def gen_invite(count=1, max_uses=1): + conn = get_conn() + codes = [] + for _ in range(count): + code = secrets.token_urlsafe(6)[:8].upper() + conn.execute( + "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)", + (code, max_uses) + ) + codes.append(code) + conn.commit() + conn.close() + for c in codes: + print(f" {c}") + print(f"\nGenerated {len(codes)} invite code(s)") + +def list_invites(): + conn = get_conn() + rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall() + conn.close() + if not rows: + print("No invite codes found") + return + print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}") + print("-" * 60) + for r in rows: + print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}") + +def disable_invite(code): + conn = get_conn() + conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,)) + conn.commit() + conn.close() + print(f"Disabled invite code: {code}") + +def list_users(): + conn = get_conn() + rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall() + conn.close() + if not rows: + print("No users found") + return + print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}") + print("-" * 72) + for r in rows: + print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}") + +def ban_user(user_id): + conn = get_conn() + conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,)) + conn.commit() + conn.close() + print(f"Banned user ID: {user_id}") + +def unban_user(user_id): + conn = get_conn() + conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,)) + conn.commit() + conn.close() + print(f"Unbanned user ID: {user_id}") + +def set_admin(user_id): + conn = get_conn() + conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,)) + conn.commit() + conn.close() + print(f"Set user {user_id} as admin") + +def usage(): + print("""Usage: python3 admin_cli.py [args] + +Commands: + gen-invite [count] [max_uses] Generate invite codes (default: 1 code, 1 use) + list-invites List all invite codes + disable-invite Disable an invite code + list-users List all users + ban-user Ban a user + unban-user Unban a user + set-admin Set user as admin +""") + +if __name__ == "__main__": + if len(sys.argv) < 2: + usage() + sys.exit(1) + cmd = sys.argv[1] + if cmd == "gen-invite": + count = int(sys.argv[2]) if len(sys.argv) > 2 else 1 + max_uses = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + gen_invite(count, max_uses) + elif cmd == "list-invites": + list_invites() + elif cmd == "disable-invite": + disable_invite(sys.argv[2]) + elif cmd == "list-users": + list_users() + elif cmd == "ban-user": + ban_user(int(sys.argv[2])) + elif cmd == "unban-user": + unban_user(int(sys.argv[2])) + elif cmd == "set-admin": + set_admin(int(sys.argv[2])) + else: + usage() diff --git a/backend/auth.py b/backend/auth.py index b405b48..651ea9e 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -7,17 +7,21 @@ import base64 import json from datetime import datetime, timedelta from typing import Optional +from functools import wraps -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Depends, Request from pydantic import BaseModel, EmailStr DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") -JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me") -JWT_EXPIRE_HOURS = 24 * 7 +JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026") +ACCESS_TOKEN_HOURS = 24 +REFRESH_TOKEN_DAYS = 7 router = APIRouter(prefix="/api", tags=["auth"]) +# ─── DB helpers ─────────────────────────────────────────────── + def get_conn(): conn = sqlite3.connect(DB_PATH) conn.row_factory = sqlite3.Row @@ -27,29 +31,26 @@ def get_conn(): def ensure_tables(): conn = get_conn() cur = conn.cursor() - cur.execute( - """ + cur.execute(""" CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT UNIQUE NOT NULL, password_hash TEXT NOT NULL, discord_id TEXT, + role TEXT NOT NULL DEFAULT 'user', + banned INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL ) - """ - ) - cur.execute( - """ + """) + cur.execute(""" CREATE TABLE IF NOT EXISTS subscriptions ( user_id INTEGER PRIMARY KEY, tier TEXT NOT NULL DEFAULT 'free', expires_at TEXT, FOREIGN KEY(user_id) REFERENCES users(id) ) - """ - ) - cur.execute( - """ + """) + cur.execute(""" CREATE TABLE IF NOT EXISTS signal_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, symbol TEXT NOT NULL, @@ -58,12 +59,55 @@ def ensure_tables(): sent_at TEXT NOT NULL, message TEXT NOT NULL ) - """ - ) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS invite_codes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + code TEXT UNIQUE NOT NULL, + created_by INTEGER, + max_uses INTEGER NOT NULL DEFAULT 1, + used_count INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + expires_at TEXT, + created_at TEXT DEFAULT (datetime('now')) + ) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS invite_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + code_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + used_at TEXT DEFAULT (datetime('now')), + FOREIGN KEY (code_id) REFERENCES invite_codes(id), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS refresh_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token TEXT UNIQUE NOT NULL, + expires_at TEXT NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_at TEXT DEFAULT (datetime('now')), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + # migrate: add role/banned columns if missing + try: + cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'") + except: + pass + try: + cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0") + except: + pass conn.commit() conn.close() +# ─── Password utils ────────────────────────────────────────── + def hash_password(password: str) -> str: salt = secrets.token_hex(16) digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex() @@ -79,19 +123,37 @@ def verify_password(password: str, stored: str) -> bool: return False +# ─── JWT utils ─────────────────────────────────────────────── + def b64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode() -def create_token(user_id: int, email: str) -> str: +def create_access_token(user_id: int, email: str, role: str) -> str: header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode()) - exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp()) - payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode()) + exp = int((datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_HOURS)).timestamp()) + payload = b64url(json.dumps({ + "sub": user_id, "email": email, "role": role, + "exp": exp, "type": "access" + }, separators=(",", ":")).encode()) sign_input = f"{header}.{payload}".encode() signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest() return f"{header}.{payload}.{b64url(signature)}" +def create_refresh_token(user_id: int) -> str: + token = secrets.token_urlsafe(48) + expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat() + conn = get_conn() + conn.execute( + "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)", + (user_id, token, expires_at) + ) + conn.commit() + conn.close() + return token + + def parse_token(token: str) -> Optional[dict]: try: header_b64, payload_b64, sig_b64 = token.split(".") @@ -108,24 +170,41 @@ def parse_token(token: str) -> Optional[dict]: return None -def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row: +# ─── Auth dependency ───────────────────────────────────────── + +def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict: + """FastAPI dependency: returns user dict or raises 401""" if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="missing token") token = authorization.split(" ", 1)[1].strip() payload = parse_token(token) if not payload: - raise HTTPException(status_code=401, detail="invalid token") + raise HTTPException(status_code=401, detail="invalid or expired token") + if payload.get("type") != "access": + raise HTTPException(status_code=401, detail="invalid token type") conn = get_conn() user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone() conn.close() if not user: raise HTTPException(status_code=401, detail="user not found") + if user["banned"]: + raise HTTPException(status_code=403, detail="account banned") + return dict(user) + + +def require_admin(user: dict = Depends(get_current_user)) -> dict: + """FastAPI dependency: requires admin role""" + if user.get("role") != "admin": + raise HTTPException(status_code=403, detail="admin required") return user +# ─── Request models ────────────────────────────────────────── + class RegisterReq(BaseModel): email: EmailStr password: str + invite_code: str class LoginReq(BaseModel): @@ -133,18 +212,47 @@ class LoginReq(BaseModel): password: str -class BindDiscordReq(BaseModel): - discord_id: str +class RefreshReq(BaseModel): + refresh_token: str +class GenInviteReq(BaseModel): + count: int = 1 + max_uses: int = 1 + + +class BanUserReq(BaseModel): + banned: bool = True + + +# ─── Auth routes ───────────────────────────────────────────── + @router.post("/auth/register") def register(body: RegisterReq): ensure_tables() conn = get_conn() + + # verify invite code + invite = conn.execute( + "SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,) + ).fetchone() + if not invite: + conn.close() + raise HTTPException(status_code=400, detail="invalid invite code") + if invite["status"] != "active": + conn.close() + raise HTTPException(status_code=400, detail="invite code disabled") + if invite["used_count"] >= invite["max_uses"]: + conn.close() + raise HTTPException(status_code=400, detail="invite code exhausted") + if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat(): + conn.close() + raise HTTPException(status_code=400, detail="invite code expired") + try: pwd_hash = hash_password(body.password) cur = conn.execute( - "INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)", + "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)", (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), ) user_id = cur.lastrowid @@ -152,14 +260,39 @@ def register(body: RegisterReq): "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)", (user_id,), ) + # record invite usage + conn.execute( + "INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)", + (invite["id"], user_id), + ) + conn.execute( + "UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?", + (invite["id"],), + ) + # auto-exhaust if max reached + new_count = invite["used_count"] + 1 + if new_count >= invite["max_uses"]: + conn.execute( + "UPDATE invite_codes SET status = 'exhausted' WHERE id = ?", + (invite["id"],), + ) conn.commit() except sqlite3.IntegrityError: conn.close() raise HTTPException(status_code=400, detail="email already registered") - user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone() + # issue tokens + user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone() conn.close() - return dict(user) + access = create_access_token(user["id"], user["email"], user["role"]) + refresh = create_refresh_token(user["id"]) + return { + "access_token": access, + "refresh_token": refresh, + "token_type": "bearer", + "expires_in": ACCESS_TOKEN_HOURS * 3600, + "user": {"id": user["id"], "email": user["email"], "role": user["role"]}, + } @router.post("/auth/login") @@ -170,27 +303,113 @@ def login(body: LoginReq): conn.close() if not user or not verify_password(body.password, user["password_hash"]): raise HTTPException(status_code=401, detail="invalid credentials") - token = create_token(user["id"], user["email"]) - return {"token": token, "token_type": "bearer"} + if user["banned"]: + raise HTTPException(status_code=403, detail="account banned") + access = create_access_token(user["id"], user["email"], user["role"]) + refresh = create_refresh_token(user["id"]) + return { + "access_token": access, + "refresh_token": refresh, + "token_type": "bearer", + "expires_in": ACCESS_TOKEN_HOURS * 3600, + "user": {"id": user["id"], "email": user["email"], "role": user["role"]}, + } -@router.post("/user/bind-discord") -def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)): - user = get_user_from_auth(authorization) +@router.post("/auth/refresh") +def refresh_token(body: RefreshReq): conn = get_conn() - conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"])) + row = conn.execute( + "SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,) + ).fetchone() + if not row: + conn.close() + raise HTTPException(status_code=401, detail="invalid refresh token") + if row["expires_at"] < datetime.utcnow().isoformat(): + conn.close() + raise HTTPException(status_code=401, detail="refresh token expired") + user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone() + if not user or user["banned"]: + conn.close() + raise HTTPException(status_code=403, detail="account unavailable") + # revoke old, issue new + conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],)) conn.commit() - updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone() conn.close() - return dict(updated) + access = create_access_token(user["id"], user["email"], user["role"]) + new_refresh = create_refresh_token(user["id"]) + return { + "access_token": access, + "refresh_token": new_refresh, + "token_type": "bearer", + "expires_in": ACCESS_TOKEN_HOURS * 3600, + } -@router.get("/user/me") -def me(authorization: Optional[str] = Header(default=None)): - user = get_user_from_auth(authorization) +@router.get("/auth/me") +def me(user: dict = Depends(get_current_user)): conn = get_conn() sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone() conn.close() - out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]} - out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None} - return out + return { + "id": user["id"], "email": user["email"], "role": user["role"], + "discord_id": user.get("discord_id"), + "created_at": user["created_at"], + "subscription": dict(sub) if sub else {"tier": "free", "expires_at": None}, + } + + +# ─── Admin routes ──────────────────────────────────────────── + +@router.post("/admin/invite-codes") +def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)): + conn = get_conn() + codes = [] + for _ in range(body.count): + code = secrets.token_urlsafe(6)[:8].upper() + conn.execute( + "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)", + (code, admin["id"], body.max_uses), + ) + codes.append(code) + conn.commit() + conn.close() + return {"codes": codes} + + +@router.get("/admin/invite-codes") +def list_invite_codes(admin: dict = Depends(require_admin)): + conn = get_conn() + rows = conn.execute( + "SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC" + ).fetchall() + conn.close() + return {"items": [dict(r) for r in rows]} + + +@router.delete("/admin/invite-codes/{code_id}") +def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)): + conn = get_conn() + conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,)) + conn.commit() + conn.close() + return {"ok": True} + + +@router.get("/admin/users") +def list_users(admin: dict = Depends(require_admin)): + conn = get_conn() + rows = conn.execute( + "SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC" + ).fetchall() + conn.close() + return {"items": [dict(r) for r in rows]} + + +@router.put("/admin/users/{user_id}/ban") +def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)): + conn = get_conn() + conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id)) + conn.commit() + conn.close() + return {"ok": True, "user_id": user_id, "banned": body.banned} diff --git a/backend/main.py b/backend/main.py index cd21767..342b2fb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,9 +1,11 @@ -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware import httpx from datetime import datetime, timedelta import asyncio, time, sqlite3, os +from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables + app = FastAPI(title="Arbitrage Engine API") app.add_middleware( @@ -13,6 +15,8 @@ app.add_middleware( allow_headers=["*"], ) +app.include_router(auth_router) + BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" SYMBOLS = ["BTCUSDT", "ETHUSDT"] HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} @@ -101,6 +105,7 @@ async def background_snapshot_loop(): @app.on_event("startup") async def startup(): init_db() + ensure_auth_tables() asyncio.create_task(background_snapshot_loop()) @@ -137,7 +142,7 @@ async def get_rates(): @app.get("/api/snapshots") -async def get_snapshots(hours: int = 24, limit: int = 5000): +async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)): """查询本地落库的实时快照数据""" since = int(time.time()) - hours * 3600 conn = sqlite3.connect(DB_PATH) @@ -155,7 +160,7 @@ async def get_snapshots(hours: int = 24, limit: int = 5000): @app.get("/api/kline") -async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500): +async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)): """ 从 rate_snapshots 聚合K线数据 symbol: BTC | ETH @@ -212,7 +217,7 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500) @app.get("/api/stats/ytd") -async def get_stats_ytd(): +async def get_stats_ytd(user: dict = Depends(get_current_user)): """今年以来(YTD)资金费率年化统计""" cached = get_cache("stats_ytd", 3600) if cached: return cached @@ -245,7 +250,7 @@ async def get_stats_ytd(): @app.get("/api/signals/history") -async def get_signals_history(limit: int = 100): +async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)): """查询信号推送历史""" try: conn = sqlite3.connect(DB_PATH) @@ -261,7 +266,7 @@ async def get_signals_history(limit: int = 100): @app.get("/api/history") -async def get_history(): +async def get_history(user: dict = Depends(get_current_user)): cached = get_cache("history", 60) if cached: return cached end_time = int(datetime.utcnow().timestamp() * 1000) @@ -288,7 +293,7 @@ async def get_history(): @app.get("/api/stats") -async def get_stats(): +async def get_stats(user: dict = Depends(get_current_user)): cached = get_cache("stats", 60) if cached: return cached end_time = int(datetime.utcnow().timestamp() * 1000) diff --git a/frontend/app/layout.tsx b/frontend/app/layout.tsx index e5ece9b..0747d89 100644 --- a/frontend/app/layout.tsx +++ b/frontend/app/layout.tsx @@ -2,7 +2,8 @@ import type { Metadata } from "next"; import { Geist, Geist_Mono } from "next/font/google"; import "./globals.css"; import Sidebar from "@/components/Sidebar"; -import Link from "next/link"; +import { AuthProvider } from "@/lib/auth"; +import AuthHeader from "@/components/AuthHeader"; const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] }); const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] }); @@ -16,25 +17,17 @@ export default function RootLayout({ children }: Readonly<{ children: React.Reac return ( -
- -
- {/* 桌面端顶栏:右上角登录注册 */} -
- - 登录 - - - 注册 - -
-
- {children} -
+ +
+ +
+ +
+ {children} +
+
-
+ ); diff --git a/frontend/app/login/page.tsx b/frontend/app/login/page.tsx index b44f52d..7ada48d 100644 --- a/frontend/app/login/page.tsx +++ b/frontend/app/login/page.tsx @@ -1,10 +1,13 @@ "use client"; -import { useState, Suspense } from "react"; -import { useRouter, useSearchParams } from "next/navigation"; -function LoginForm() { +import { useState } from "react"; +import { useAuth } from "@/lib/auth"; +import { useRouter } from "next/navigation"; +import Link from "next/link"; + +export default function LoginPage() { + const { login } = useAuth(); const router = useRouter(); - const params = useSearchParams(); const [email, setEmail] = useState(""); const [password, setPassword] = useState(""); const [error, setError] = useState(""); @@ -12,74 +15,61 @@ function LoginForm() { const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); - setLoading(true); setError(""); + setLoading(true); try { - const form = new URLSearchParams(); - form.append("username", email); - form.append("password", password); - const r = await fetch("/api/auth/login", { - method: "POST", - headers: { "Content-Type": "application/x-www-form-urlencoded" }, - body: form.toString(), - }); - const data = await r.json(); - if (!r.ok) { setError(data.detail || "登录失败"); return; } - localStorage.setItem("arb_token", data.access_token); - router.push("/dashboard"); - } catch { - setError("网络错误,请重试"); + await login(email, password); + router.push("/"); + } catch (err: any) { + setError(err.message || "登录失败"); } finally { setLoading(false); } }; return ( -
-
-
-

登录

- {params.get("registered") && ( -

✅ 注册成功,请登录

- )} -

登录后查看信号和账户信息

+
+
+
+
+

⚡ Arbitrage Engine

+

登录您的账户

+
+ +
+
+ + setEmail(e.target.value)} + className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent" + placeholder="you@example.com" required + /> +
+
+ + setPassword(e.target.value)} + className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent" + placeholder="••••••••" required + /> +
+ + {error &&

{error}

} + + +
+ +

+ 没有账户?{" "} + 注册 +

-
-
- - setEmail(e.target.value)} - className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500" - placeholder="your@email.com" - /> -
-
- - setPassword(e.target.value)} - className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500" - /> -
- {error &&

{error}

} - -
-

- 没有账号?注册 -

); } - -export default function LoginPage() { - return ( - - - - ); -} diff --git a/frontend/app/page.tsx b/frontend/app/page.tsx index cfe08fd..c977ed4 100644 --- a/frontend/app/page.tsx +++ b/frontend/app/page.tsx @@ -3,8 +3,10 @@ import { useEffect, useState, useCallback, useRef } from "react"; import { createChart, ColorType, CandlestickSeries } from "lightweight-charts"; import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api"; +import { useAuth } from "@/lib/auth"; import RateCard from "@/components/RateCard"; import StatsCard from "@/components/StatsCard"; +import Link from "next/link"; import { LineChart, Line, XAxis, YAxis, Tooltip, Legend, ResponsiveContainer, ReferenceLine @@ -91,8 +93,35 @@ function IntervalPicker({ value, onChange }: { value: string; onChange: (v: stri ); } +// ─── 未登录遮挡组件 ────────────────────────────────────────────── +function AuthGate({ children }: { children: React.ReactNode }) { + const { isLoggedIn, loading } = useAuth(); + if (loading) return
加载中...
; + if (!isLoggedIn) { + return ( +
+
+ {children} +
+
+
+
🔒
+

登录后查看完整数据

+
+ 登录 + 注册 +
+
+
+
+ ); + } + return <>{children}; +} + // ─── 主仪表盘 ──────────────────────────────────────────────────── export default function Dashboard() { + const { isLoggedIn } = useAuth(); const [rates, setRates] = useState(null); const [stats, setStats] = useState(null); const [history, setHistory] = useState(null); @@ -110,11 +139,12 @@ export default function Dashboard() { }, []); const fetchSlow = useCallback(async () => { + if (!isLoggedIn) return; try { const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]); setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y); } catch {} - }, []); + }, [isLoggedIn]); useEffect(() => { fetchRates(); fetchSlow(); @@ -154,6 +184,7 @@ export default function Dashboard() {
{/* 统计卡片 */} + {stats && (
@@ -283,6 +314,7 @@ export default function Dashboard() { 策略原理: 持有现货多头 + 永续空头,每8小时收取资金费率,赚取无方向风险的稳定收益。
+
); } diff --git a/frontend/app/register/page.tsx b/frontend/app/register/page.tsx index 5c35748..48460a4 100644 --- a/frontend/app/register/page.tsx +++ b/frontend/app/register/page.tsx @@ -1,79 +1,87 @@ "use client"; + import { useState } from "react"; +import { useAuth } from "@/lib/auth"; import { useRouter } from "next/navigation"; +import Link from "next/link"; export default function RegisterPage() { + const { register } = useAuth(); const router = useRouter(); const [email, setEmail] = useState(""); const [password, setPassword] = useState(""); - const [discordId, setDiscordId] = useState(""); + const [inviteCode, setInviteCode] = useState(""); const [error, setError] = useState(""); const [loading, setLoading] = useState(false); const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); - setLoading(true); setError(""); + if (password.length < 6) { + setError("密码至少6位"); + return; + } + setLoading(true); try { - const r = await fetch("/api/auth/register", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ email, password, discord_id: discordId || undefined }), - }); - const data = await r.json(); - if (!r.ok) { setError(data.detail || "注册失败"); return; } - router.push("/login?registered=1"); - } catch { - setError("网络错误,请重试"); + await register(email, password, inviteCode); + router.push("/"); + } catch (err: any) { + setError(err.message || "注册失败"); } finally { setLoading(false); } }; return ( -
-
-
-

注册账号

-

注册后可接收套利信号推送

+
+
+
+
+

⚡ Arbitrage Engine

+

注册新账户(需要邀请码)

+
+ +
+
+ + setInviteCode(e.target.value.toUpperCase())} + className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent font-mono tracking-wider" + placeholder="输入邀请码" required + /> +
+
+ + setEmail(e.target.value)} + className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent" + placeholder="you@example.com" required + /> +
+
+ + setPassword(e.target.value)} + className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent" + placeholder="至少6位" required minLength={6} + /> +
+ + {error &&

{error}

} + + +
+ +

+ 已有账户?{" "} + 登录 +

-
-
- - setEmail(e.target.value)} - className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500" - placeholder="your@email.com" - /> -
-
- - setPassword(e.target.value)} - className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500" - placeholder="至少8位" - minLength={8} - /> -
-
- - setDiscordId(e.target.value)} - className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500" - placeholder="例:123456789012345678" - /> -
- {error &&

{error}

} - -
-

- 已有账号?登录 -

); diff --git a/frontend/components/AuthHeader.tsx b/frontend/components/AuthHeader.tsx new file mode 100644 index 0000000..da80d39 --- /dev/null +++ b/frontend/components/AuthHeader.tsx @@ -0,0 +1,54 @@ +"use client"; + +import Link from "next/link"; +import { useAuth } from "@/lib/auth"; + +export default function AuthHeader() { + const { user, isLoggedIn, logout } = useAuth(); + + return ( + <> + {/* 桌面端顶栏 */} +
+ {isLoggedIn ? ( + <> + {user?.email} + {user?.role === "admin" && ( + Admin + )} + + + ) : ( + <> + + 登录 + + + 注册 + + + )} +
+ + {/* 手机端顶栏 */} +
+ ⚡ ArbEngine +
+ {isLoggedIn ? ( + + ) : ( + <> + 登录 + 注册 + + )} +
+
+ + ); +} diff --git a/frontend/components/Navbar.tsx b/frontend/components/Navbar.tsx index bdc3ec1..0f6ac66 100644 --- a/frontend/components/Navbar.tsx +++ b/frontend/components/Navbar.tsx @@ -1,17 +1,11 @@ "use client"; import Link from "next/link"; import { useState } from "react"; - -const navLinks = [ - { href: "/", label: "仪表盘" }, - { href: "/kline", label: "K线" }, - { href: "/live", label: "实时" }, - { href: "/signals", label: "信号" }, - { href: "/about", label: "说明" }, -]; +import { useAuth } from "@/lib/auth"; export default function Navbar() { const [open, setOpen] = useState(false); + const { user, isLoggedIn, logout } = useAuth(); return ( diff --git a/frontend/lib/api.ts b/frontend/lib/api.ts index 36bc582..a20c5e9 100644 --- a/frontend/lib/api.ts +++ b/frontend/lib/api.ts @@ -1,3 +1,5 @@ +import { authFetch } from "./auth"; + const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? ""; export interface RateData { @@ -84,21 +86,31 @@ export interface YtdStatsResponse { ETH: YtdStats; } -async function fetchAPI(path: string): Promise { +// Public fetch (no auth needed) +async function fetchPublic(path: string): Promise { const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" }); if (!res.ok) throw new Error(`API error ${res.status}`); return res.json(); } +// Protected fetch (auth required, auto-refresh) +async function fetchProtected(path: string): Promise { + const res = await authFetch(path, { cache: "no-store" }); + if (!res.ok) throw new Error(`API error ${res.status}`); + return res.json(); +} + export const api = { - rates: () => fetchAPI("/api/rates"), - history: () => fetchAPI("/api/history"), - stats: () => fetchAPI("/api/stats"), - health: () => fetchAPI<{ status: string }>("/api/health"), - signalsHistory: () => fetchAPI("/api/signals/history"), + // Public + rates: () => fetchPublic("/api/rates"), + health: () => fetchPublic<{ status: string }>("/api/health"), + // Protected + history: () => fetchProtected("/api/history"), + stats: () => fetchProtected("/api/stats"), + signalsHistory: () => fetchProtected("/api/signals/history"), snapshots: (hours = 24, limit = 5000) => - fetchAPI(`/api/snapshots?hours=${hours}&limit=${limit}`), + fetchProtected(`/api/snapshots?hours=${hours}&limit=${limit}`), kline: (symbol = "BTC", interval = "1h", limit = 500) => - fetchAPI(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`), - statsYtd: () => fetchAPI("/api/stats/ytd"), + fetchProtected(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`), + statsYtd: () => fetchProtected("/api/stats/ytd"), }; diff --git a/frontend/lib/auth.tsx b/frontend/lib/auth.tsx new file mode 100644 index 0000000..bbd33b1 --- /dev/null +++ b/frontend/lib/auth.tsx @@ -0,0 +1,137 @@ +"use client"; + +import { createContext, useContext, useState, useEffect, ReactNode, useCallback } from "react"; + +interface User { + id: number; + email: string; + role: string; +} + +interface AuthState { + user: User | null; + accessToken: string | null; + loading: boolean; + login: (email: string, password: string) => Promise; + register: (email: string, password: string, inviteCode: string) => Promise; + logout: () => void; + isLoggedIn: boolean; + isAdmin: boolean; +} + +const AuthContext = createContext(undefined); + +const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? ""; + +export function AuthProvider({ children }: { children: ReactNode }) { + const [user, setUser] = useState(null); + const [accessToken, setAccessToken] = useState(null); + const [loading, setLoading] = useState(true); + + // init from localStorage + useEffect(() => { + const token = localStorage.getItem("access_token"); + const saved = localStorage.getItem("user"); + if (token && saved) { + try { + setAccessToken(token); + setUser(JSON.parse(saved)); + } catch {} + } + setLoading(false); + }, []); + + const saveAuth = (data: { access_token: string; refresh_token: string; user: User }) => { + localStorage.setItem("access_token", data.access_token); + localStorage.setItem("refresh_token", data.refresh_token); + localStorage.setItem("user", JSON.stringify(data.user)); + setAccessToken(data.access_token); + setUser(data.user); + }; + + const login = useCallback(async (email: string, password: string) => { + const res = await fetch(`${API_BASE}/api/auth/login`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ email, password }), + }); + if (!res.ok) { + const err = await res.json().catch(() => ({})); + throw new Error(err.detail || "Login failed"); + } + const data = await res.json(); + saveAuth(data); + }, []); + + const register = useCallback(async (email: string, password: string, inviteCode: string) => { + const res = await fetch(`${API_BASE}/api/auth/register`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ email, password, invite_code: inviteCode }), + }); + if (!res.ok) { + const err = await res.json().catch(() => ({})); + throw new Error(err.detail || "Registration failed"); + } + const data = await res.json(); + saveAuth(data); + }, []); + + const logout = useCallback(() => { + localStorage.removeItem("access_token"); + localStorage.removeItem("refresh_token"); + localStorage.removeItem("user"); + setAccessToken(null); + setUser(null); + }, []); + + return ( + + {children} + + ); +} + +export function useAuth() { + const ctx = useContext(AuthContext); + if (!ctx) throw new Error("useAuth must be used within AuthProvider"); + return ctx; +} + +// Authenticated fetch helper +export async function authFetch(path: string, options: RequestInit = {}): Promise { + const token = localStorage.getItem("access_token"); + const headers = new Headers(options.headers); + if (token) headers.set("Authorization", `Bearer ${token}`); + + let res = await fetch(`${API_BASE}${path}`, { ...options, headers }); + + // try refresh on 401 + if (res.status === 401) { + const refreshToken = localStorage.getItem("refresh_token"); + if (refreshToken) { + const refreshRes = await fetch(`${API_BASE}/api/auth/refresh`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ refresh_token: refreshToken }), + }); + if (refreshRes.ok) { + const data = await refreshRes.json(); + localStorage.setItem("access_token", data.access_token); + localStorage.setItem("refresh_token", data.refresh_token); + headers.set("Authorization", `Bearer ${data.access_token}`); + res = await fetch(`${API_BASE}${path}`, { ...options, headers }); + } else { + // refresh failed, clear auth + localStorage.removeItem("access_token"); + localStorage.removeItem("refresh_token"); + localStorage.removeItem("user"); + } + } + } + + return res; +}