feat: V2.0 auth system - JWT access/refresh, invite codes, route protection, admin CLI, auth gate blur overlay

This commit is contained in:
root 2026-02-27 11:08:57 +00:00
parent 052e5a0541
commit 1ab228286c
11 changed files with 784 additions and 214 deletions

115
backend/admin_cli.py Normal file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""Admin CLI for Arbitrage Engine"""
import sys, os, sqlite3, secrets, json
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
def get_conn():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def gen_invite(count=1, max_uses=1):
conn = get_conn()
codes = []
for _ in range(count):
code = secrets.token_urlsafe(6)[:8].upper()
conn.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)",
(code, max_uses)
)
codes.append(code)
conn.commit()
conn.close()
for c in codes:
print(f" {c}")
print(f"\nGenerated {len(codes)} invite code(s)")
def list_invites():
conn = get_conn()
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall()
conn.close()
if not rows:
print("No invite codes found")
return
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
print("-" * 60)
for r in rows:
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}")
def disable_invite(code):
conn = get_conn()
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,))
conn.commit()
conn.close()
print(f"Disabled invite code: {code}")
def list_users():
conn = get_conn()
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall()
conn.close()
if not rows:
print("No users found")
return
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
print("-" * 72)
for r in rows:
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}")
def ban_user(user_id):
conn = get_conn()
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,))
conn.commit()
conn.close()
print(f"Banned user ID: {user_id}")
def unban_user(user_id):
conn = get_conn()
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,))
conn.commit()
conn.close()
print(f"Unbanned user ID: {user_id}")
def set_admin(user_id):
conn = get_conn()
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,))
conn.commit()
conn.close()
print(f"Set user {user_id} as admin")
def usage():
print("""Usage: python3 admin_cli.py <command> [args]
Commands:
gen-invite [count] [max_uses] Generate invite codes (default: 1 code, 1 use)
list-invites List all invite codes
disable-invite <CODE> Disable an invite code
list-users List all users
ban-user <USER_ID> Ban a user
unban-user <USER_ID> Unban a user
set-admin <USER_ID> Set user as admin
""")
if __name__ == "__main__":
if len(sys.argv) < 2:
usage()
sys.exit(1)
cmd = sys.argv[1]
if cmd == "gen-invite":
count = int(sys.argv[2]) if len(sys.argv) > 2 else 1
max_uses = int(sys.argv[3]) if len(sys.argv) > 3 else 1
gen_invite(count, max_uses)
elif cmd == "list-invites":
list_invites()
elif cmd == "disable-invite":
disable_invite(sys.argv[2])
elif cmd == "list-users":
list_users()
elif cmd == "ban-user":
ban_user(int(sys.argv[2]))
elif cmd == "unban-user":
unban_user(int(sys.argv[2]))
elif cmd == "set-admin":
set_admin(int(sys.argv[2]))
else:
usage()

View File

@ -7,17 +7,21 @@ import base64
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from functools import wraps
from fastapi import APIRouter, HTTPException, Header from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me") JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
JWT_EXPIRE_HOURS = 24 * 7 ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7
router = APIRouter(prefix="/api", tags=["auth"]) router = APIRouter(prefix="/api", tags=["auth"])
# ─── DB helpers ───────────────────────────────────────────────
def get_conn(): def get_conn():
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
@ -27,29 +31,26 @@ def get_conn():
def ensure_tables(): def ensure_tables():
conn = get_conn() conn = get_conn()
cur = conn.cursor() cur = conn.cursor()
cur.execute( cur.execute("""
"""
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT UNIQUE NOT NULL, email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
discord_id TEXT, discord_id TEXT,
role TEXT NOT NULL DEFAULT 'user',
banned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL created_at TEXT NOT NULL
) )
""" """)
) cur.execute("""
cur.execute(
"""
CREATE TABLE IF NOT EXISTS subscriptions ( CREATE TABLE IF NOT EXISTS subscriptions (
user_id INTEGER PRIMARY KEY, user_id INTEGER PRIMARY KEY,
tier TEXT NOT NULL DEFAULT 'free', tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT, expires_at TEXT,
FOREIGN KEY(user_id) REFERENCES users(id) FOREIGN KEY(user_id) REFERENCES users(id)
) )
""" """)
) cur.execute("""
cur.execute(
"""
CREATE TABLE IF NOT EXISTS signal_logs ( CREATE TABLE IF NOT EXISTS signal_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL, symbol TEXT NOT NULL,
@ -58,12 +59,55 @@ def ensure_tables():
sent_at TEXT NOT NULL, sent_at TEXT NOT NULL,
message TEXT NOT NULL message TEXT NOT NULL
) )
""" """)
) cur.execute("""
CREATE TABLE IF NOT EXISTS invite_codes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT UNIQUE NOT NULL,
created_by INTEGER,
max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
expires_at TEXT,
created_at TEXT DEFAULT (datetime('now'))
)
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS invite_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
used_at TEXT DEFAULT (datetime('now')),
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT (datetime('now')),
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
# migrate: add role/banned columns if missing
try:
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'")
except:
pass
try:
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
except:
pass
conn.commit() conn.commit()
conn.close() conn.close()
# ─── Password utils ──────────────────────────────────────────
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
salt = secrets.token_hex(16) salt = secrets.token_hex(16)
digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex() digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
@ -79,19 +123,37 @@ def verify_password(password: str, stored: str) -> bool:
return False return False
# ─── JWT utils ───────────────────────────────────────────────
def b64url(data: bytes) -> str: def b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode() return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
def create_token(user_id: int, email: str) -> str: def create_access_token(user_id: int, email: str, role: str) -> str:
header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode()) header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp()) exp = int((datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_HOURS)).timestamp())
payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode()) payload = b64url(json.dumps({
"sub": user_id, "email": email, "role": role,
"exp": exp, "type": "access"
}, separators=(",", ":")).encode())
sign_input = f"{header}.{payload}".encode() sign_input = f"{header}.{payload}".encode()
signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest() signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
return f"{header}.{payload}.{b64url(signature)}" return f"{header}.{payload}.{b64url(signature)}"
def create_refresh_token(user_id: int) -> str:
token = secrets.token_urlsafe(48)
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
conn = get_conn()
conn.execute(
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
(user_id, token, expires_at)
)
conn.commit()
conn.close()
return token
def parse_token(token: str) -> Optional[dict]: def parse_token(token: str) -> Optional[dict]:
try: try:
header_b64, payload_b64, sig_b64 = token.split(".") header_b64, payload_b64, sig_b64 = token.split(".")
@ -108,24 +170,41 @@ def parse_token(token: str) -> Optional[dict]:
return None return None
def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row: # ─── Auth dependency ─────────────────────────────────────────
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
"""FastAPI dependency: returns user dict or raises 401"""
if not authorization or not authorization.startswith("Bearer "): if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="missing token") raise HTTPException(status_code=401, detail="missing token")
token = authorization.split(" ", 1)[1].strip() token = authorization.split(" ", 1)[1].strip()
payload = parse_token(token) payload = parse_token(token)
if not payload: if not payload:
raise HTTPException(status_code=401, detail="invalid token") raise HTTPException(status_code=401, detail="invalid or expired token")
if payload.get("type") != "access":
raise HTTPException(status_code=401, detail="invalid token type")
conn = get_conn() conn = get_conn()
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone() user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
conn.close() conn.close()
if not user: if not user:
raise HTTPException(status_code=401, detail="user not found") raise HTTPException(status_code=401, detail="user not found")
if user["banned"]:
raise HTTPException(status_code=403, detail="account banned")
return dict(user)
def require_admin(user: dict = Depends(get_current_user)) -> dict:
"""FastAPI dependency: requires admin role"""
if user.get("role") != "admin":
raise HTTPException(status_code=403, detail="admin required")
return user return user
# ─── Request models ──────────────────────────────────────────
class RegisterReq(BaseModel): class RegisterReq(BaseModel):
email: EmailStr email: EmailStr
password: str password: str
invite_code: str
class LoginReq(BaseModel): class LoginReq(BaseModel):
@ -133,18 +212,47 @@ class LoginReq(BaseModel):
password: str password: str
class BindDiscordReq(BaseModel): class RefreshReq(BaseModel):
discord_id: str refresh_token: str
class GenInviteReq(BaseModel):
count: int = 1
max_uses: int = 1
class BanUserReq(BaseModel):
banned: bool = True
# ─── Auth routes ─────────────────────────────────────────────
@router.post("/auth/register") @router.post("/auth/register")
def register(body: RegisterReq): def register(body: RegisterReq):
ensure_tables() ensure_tables()
conn = get_conn() conn = get_conn()
# verify invite code
invite = conn.execute(
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
).fetchone()
if not invite:
conn.close()
raise HTTPException(status_code=400, detail="invalid invite code")
if invite["status"] != "active":
conn.close()
raise HTTPException(status_code=400, detail="invite code disabled")
if invite["used_count"] >= invite["max_uses"]:
conn.close()
raise HTTPException(status_code=400, detail="invite code exhausted")
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=400, detail="invite code expired")
try: try:
pwd_hash = hash_password(body.password) pwd_hash = hash_password(body.password)
cur = conn.execute( cur = conn.execute(
"INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)", "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)",
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
) )
user_id = cur.lastrowid user_id = cur.lastrowid
@ -152,14 +260,39 @@ def register(body: RegisterReq):
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)", "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
(user_id,), (user_id,),
) )
# record invite usage
conn.execute(
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
(invite["id"], user_id),
)
conn.execute(
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?",
(invite["id"],),
)
# auto-exhaust if max reached
new_count = invite["used_count"] + 1
if new_count >= invite["max_uses"]:
conn.execute(
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?",
(invite["id"],),
)
conn.commit() conn.commit()
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
conn.close() conn.close()
raise HTTPException(status_code=400, detail="email already registered") raise HTTPException(status_code=400, detail="email already registered")
user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone() # issue tokens
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
conn.close() conn.close()
return dict(user) access = create_access_token(user["id"], user["email"], user["role"])
refresh = create_refresh_token(user["id"])
return {
"access_token": access,
"refresh_token": refresh,
"token_type": "bearer",
"expires_in": ACCESS_TOKEN_HOURS * 3600,
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
}
@router.post("/auth/login") @router.post("/auth/login")
@ -170,27 +303,113 @@ def login(body: LoginReq):
conn.close() conn.close()
if not user or not verify_password(body.password, user["password_hash"]): if not user or not verify_password(body.password, user["password_hash"]):
raise HTTPException(status_code=401, detail="invalid credentials") raise HTTPException(status_code=401, detail="invalid credentials")
token = create_token(user["id"], user["email"]) if user["banned"]:
return {"token": token, "token_type": "bearer"} raise HTTPException(status_code=403, detail="account banned")
access = create_access_token(user["id"], user["email"], user["role"])
refresh = create_refresh_token(user["id"])
return {
"access_token": access,
"refresh_token": refresh,
"token_type": "bearer",
"expires_in": ACCESS_TOKEN_HOURS * 3600,
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
}
@router.post("/user/bind-discord") @router.post("/auth/refresh")
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)): def refresh_token(body: RefreshReq):
user = get_user_from_auth(authorization)
conn = get_conn() conn = get_conn()
conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"])) row = conn.execute(
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,)
).fetchone()
if not row:
conn.close()
raise HTTPException(status_code=401, detail="invalid refresh token")
if row["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=401, detail="refresh token expired")
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone()
if not user or user["banned"]:
conn.close()
raise HTTPException(status_code=403, detail="account unavailable")
# revoke old, issue new
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
conn.commit() conn.commit()
updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone()
conn.close() conn.close()
return dict(updated) access = create_access_token(user["id"], user["email"], user["role"])
new_refresh = create_refresh_token(user["id"])
return {
"access_token": access,
"refresh_token": new_refresh,
"token_type": "bearer",
"expires_in": ACCESS_TOKEN_HOURS * 3600,
}
@router.get("/user/me") @router.get("/auth/me")
def me(authorization: Optional[str] = Header(default=None)): def me(user: dict = Depends(get_current_user)):
user = get_user_from_auth(authorization)
conn = get_conn() conn = get_conn()
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone() sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
conn.close() conn.close()
out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]} return {
out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None} "id": user["id"], "email": user["email"], "role": user["role"],
return out "discord_id": user.get("discord_id"),
"created_at": user["created_at"],
"subscription": dict(sub) if sub else {"tier": "free", "expires_at": None},
}
# ─── Admin routes ────────────────────────────────────────────
@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
conn = get_conn()
codes = []
for _ in range(body.count):
code = secrets.token_urlsafe(6)[:8].upper()
conn.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
(code, admin["id"], body.max_uses),
)
codes.append(code)
conn.commit()
conn.close()
return {"codes": codes}
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
conn = get_conn()
rows = conn.execute(
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
).fetchall()
conn.close()
return {"items": [dict(r) for r in rows]}
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
conn = get_conn()
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
conn.commit()
conn.close()
return {"ok": True}
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
conn = get_conn()
rows = conn.execute(
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
).fetchall()
conn.close()
return {"items": [dict(r) for r in rows]}
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
conn = get_conn()
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
conn.commit()
conn.close()
return {"ok": True, "user_id": user_id, "banned": body.banned}

View File

@ -1,9 +1,11 @@
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import httpx import httpx
from datetime import datetime, timedelta from datetime import datetime, timedelta
import asyncio, time, sqlite3, os import asyncio, time, sqlite3, os
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
app = FastAPI(title="Arbitrage Engine API") app = FastAPI(title="Arbitrage Engine API")
app.add_middleware( app.add_middleware(
@ -13,6 +15,8 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(auth_router)
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"] SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
@ -101,6 +105,7 @@ async def background_snapshot_loop():
@app.on_event("startup") @app.on_event("startup")
async def startup(): async def startup():
init_db() init_db()
ensure_auth_tables()
asyncio.create_task(background_snapshot_loop()) asyncio.create_task(background_snapshot_loop())
@ -137,7 +142,7 @@ async def get_rates():
@app.get("/api/snapshots") @app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000): async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
"""查询本地落库的实时快照数据""" """查询本地落库的实时快照数据"""
since = int(time.time()) - hours * 3600 since = int(time.time()) - hours * 3600
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
@ -155,7 +160,7 @@ async def get_snapshots(hours: int = 24, limit: int = 5000):
@app.get("/api/kline") @app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500): async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
""" """
rate_snapshots 聚合K线数据 rate_snapshots 聚合K线数据
symbol: BTC | ETH symbol: BTC | ETH
@ -212,7 +217,7 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500)
@app.get("/api/stats/ytd") @app.get("/api/stats/ytd")
async def get_stats_ytd(): async def get_stats_ytd(user: dict = Depends(get_current_user)):
"""今年以来(YTD)资金费率年化统计""" """今年以来(YTD)资金费率年化统计"""
cached = get_cache("stats_ytd", 3600) cached = get_cache("stats_ytd", 3600)
if cached: return cached if cached: return cached
@ -245,7 +250,7 @@ async def get_stats_ytd():
@app.get("/api/signals/history") @app.get("/api/signals/history")
async def get_signals_history(limit: int = 100): async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
"""查询信号推送历史""" """查询信号推送历史"""
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
@ -261,7 +266,7 @@ async def get_signals_history(limit: int = 100):
@app.get("/api/history") @app.get("/api/history")
async def get_history(): async def get_history(user: dict = Depends(get_current_user)):
cached = get_cache("history", 60) cached = get_cache("history", 60)
if cached: return cached if cached: return cached
end_time = int(datetime.utcnow().timestamp() * 1000) end_time = int(datetime.utcnow().timestamp() * 1000)
@ -288,7 +293,7 @@ async def get_history():
@app.get("/api/stats") @app.get("/api/stats")
async def get_stats(): async def get_stats(user: dict = Depends(get_current_user)):
cached = get_cache("stats", 60) cached = get_cache("stats", 60)
if cached: return cached if cached: return cached
end_time = int(datetime.utcnow().timestamp() * 1000) end_time = int(datetime.utcnow().timestamp() * 1000)

View File

@ -2,7 +2,8 @@ import type { Metadata } from "next";
import { Geist, Geist_Mono } from "next/font/google"; import { Geist, Geist_Mono } from "next/font/google";
import "./globals.css"; import "./globals.css";
import Sidebar from "@/components/Sidebar"; import Sidebar from "@/components/Sidebar";
import Link from "next/link"; import { AuthProvider } from "@/lib/auth";
import AuthHeader from "@/components/AuthHeader";
const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] }); const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] });
const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] }); const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] });
@ -16,25 +17,17 @@ export default function RootLayout({ children }: Readonly<{ children: React.Reac
return ( return (
<html lang="zh"> <html lang="zh">
<body className={`${geistSans.variable} ${geistMono.variable} antialiased min-h-screen bg-slate-50 text-slate-900`}> <body className={`${geistSans.variable} ${geistMono.variable} antialiased min-h-screen bg-slate-50 text-slate-900`}>
<div className="flex min-h-screen"> <AuthProvider>
<Sidebar /> <div className="flex min-h-screen">
<div className="flex-1 flex flex-col min-w-0"> <Sidebar />
{/* 桌面端顶栏:右上角登录注册 */} <div className="flex-1 flex flex-col min-w-0">
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3"> <AuthHeader />
<Link href="/login" <main className="flex-1 p-4 md:p-6 pt-16 md:pt-6">
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors"> {children}
</main>
</Link> </div>
<Link href="/register"
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
</Link>
</header>
<main className="flex-1 p-4 md:p-6 pt-16 md:pt-6">
{children}
</main>
</div> </div>
</div> </AuthProvider>
</body> </body>
</html> </html>
); );

View File

@ -1,10 +1,13 @@
"use client"; "use client";
import { useState, Suspense } from "react";
import { useRouter, useSearchParams } from "next/navigation";
function LoginForm() { import { useState } from "react";
import { useAuth } from "@/lib/auth";
import { useRouter } from "next/navigation";
import Link from "next/link";
export default function LoginPage() {
const { login } = useAuth();
const router = useRouter(); const router = useRouter();
const params = useSearchParams();
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
const [error, setError] = useState(""); const [error, setError] = useState("");
@ -12,74 +15,61 @@ function LoginForm() {
const handleSubmit = async (e: React.FormEvent) => { const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault(); e.preventDefault();
setLoading(true);
setError(""); setError("");
setLoading(true);
try { try {
const form = new URLSearchParams(); await login(email, password);
form.append("username", email); router.push("/");
form.append("password", password); } catch (err: any) {
const r = await fetch("/api/auth/login", { setError(err.message || "登录失败");
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: form.toString(),
});
const data = await r.json();
if (!r.ok) { setError(data.detail || "登录失败"); return; }
localStorage.setItem("arb_token", data.access_token);
router.push("/dashboard");
} catch {
setError("网络错误,请重试");
} finally { } finally {
setLoading(false); setLoading(false);
} }
}; };
return ( return (
<div className="min-h-[70vh] flex items-center justify-center"> <div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6"> <div className="w-full max-w-md">
<div> <div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
<h1 className="text-2xl font-bold text-slate-900"></h1> <div className="text-center mb-8">
{params.get("registered") && ( <h1 className="text-2xl font-bold text-slate-800"> Arbitrage Engine</h1>
<p className="text-emerald-400 text-sm mt-1"> </p> <p className="text-slate-500 text-sm mt-2"></p>
)} </div>
<p className="text-slate-500 text-sm mt-1"></p>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1"></label>
<input
type="email" value={email} onChange={e => setEmail(e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
placeholder="you@example.com" required
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1"></label>
<input
type="password" value={password} onChange={e => setPassword(e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
placeholder="••••••••" required
/>
</div>
{error && <p className="text-red-500 text-sm">{error}</p>}
<button
type="submit" disabled={loading}
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
>
{loading ? "登录中..." : "登录"}
</button>
</form>
<p className="text-center text-sm text-slate-500 mt-6">
{" "}
<Link href="/register" className="text-blue-600 hover:underline"></Link>
</p>
</div> </div>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label className="block text-sm text-slate-700 mb-1"></label>
<input
type="email" required value={email} onChange={e => setEmail(e.target.value)}
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
placeholder="your@email.com"
/>
</div>
<div>
<label className="block text-sm text-slate-700 mb-1"></label>
<input
type="password" required value={password} onChange={e => setPassword(e.target.value)}
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
/>
</div>
{error && <p className="text-red-400 text-sm">{error}</p>}
<button
type="submit" disabled={loading}
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
>
{loading ? "登录中..." : "登录"}
</button>
</form>
<p className="text-center text-sm text-slate-500">
<a href="/register" className="text-blue-600 hover:underline"></a>
</p>
</div> </div>
</div> </div>
); );
} }
export default function LoginPage() {
return (
<Suspense>
<LoginForm />
</Suspense>
);
}

View File

@ -3,8 +3,10 @@
import { useEffect, useState, useCallback, useRef } from "react"; import { useEffect, useState, useCallback, useRef } from "react";
import { createChart, ColorType, CandlestickSeries } from "lightweight-charts"; import { createChart, ColorType, CandlestickSeries } from "lightweight-charts";
import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api"; import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api";
import { useAuth } from "@/lib/auth";
import RateCard from "@/components/RateCard"; import RateCard from "@/components/RateCard";
import StatsCard from "@/components/StatsCard"; import StatsCard from "@/components/StatsCard";
import Link from "next/link";
import { import {
LineChart, Line, XAxis, YAxis, Tooltip, Legend, LineChart, Line, XAxis, YAxis, Tooltip, Legend,
ResponsiveContainer, ReferenceLine ResponsiveContainer, ReferenceLine
@ -91,8 +93,35 @@ function IntervalPicker({ value, onChange }: { value: string; onChange: (v: stri
); );
} }
// ─── 未登录遮挡组件 ──────────────────────────────────────────────
function AuthGate({ children }: { children: React.ReactNode }) {
const { isLoggedIn, loading } = useAuth();
if (loading) return <div className="text-center text-slate-400 py-12">...</div>;
if (!isLoggedIn) {
return (
<div className="relative">
<div className="filter blur-sm pointer-events-none select-none opacity-60">
{children}
</div>
<div className="absolute inset-0 flex items-center justify-center bg-white/70 rounded-xl">
<div className="text-center space-y-3">
<div className="text-4xl">🔒</div>
<p className="text-slate-600 font-medium"></p>
<div className="flex gap-2 justify-center">
<Link href="/login" className="text-sm bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg transition-colors"></Link>
<Link href="/register" className="text-sm border border-slate-300 text-slate-600 hover:border-blue-400 px-4 py-2 rounded-lg transition-colors"></Link>
</div>
</div>
</div>
</div>
);
}
return <>{children}</>;
}
// ─── 主仪表盘 ──────────────────────────────────────────────────── // ─── 主仪表盘 ────────────────────────────────────────────────────
export default function Dashboard() { export default function Dashboard() {
const { isLoggedIn } = useAuth();
const [rates, setRates] = useState<RatesResponse | null>(null); const [rates, setRates] = useState<RatesResponse | null>(null);
const [stats, setStats] = useState<StatsResponse | null>(null); const [stats, setStats] = useState<StatsResponse | null>(null);
const [history, setHistory] = useState<HistoryResponse | null>(null); const [history, setHistory] = useState<HistoryResponse | null>(null);
@ -110,11 +139,12 @@ export default function Dashboard() {
}, []); }, []);
const fetchSlow = useCallback(async () => { const fetchSlow = useCallback(async () => {
if (!isLoggedIn) return;
try { try {
const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]); const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]);
setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y); setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y);
} catch {} } catch {}
}, []); }, [isLoggedIn]);
useEffect(() => { useEffect(() => {
fetchRates(); fetchSlow(); fetchRates(); fetchSlow();
@ -154,6 +184,7 @@ export default function Dashboard() {
</div> </div>
{/* 统计卡片 */} {/* 统计卡片 */}
<AuthGate>
{stats && ( {stats && (
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
<StatsCard title="BTC 套利" mean7d={stats.BTC.mean7d} annualized={stats.BTC.annualized} accent="blue" /> <StatsCard title="BTC 套利" mean7d={stats.BTC.mean7d} annualized={stats.BTC.annualized} accent="blue" />
@ -283,6 +314,7 @@ export default function Dashboard() {
<span className="text-blue-600 font-medium"></span> <span className="text-blue-600 font-medium"></span>
+ 8 + 8
</div> </div>
</AuthGate>
</div> </div>
); );
} }

View File

@ -1,79 +1,87 @@
"use client"; "use client";
import { useState } from "react"; import { useState } from "react";
import { useAuth } from "@/lib/auth";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import Link from "next/link";
export default function RegisterPage() { export default function RegisterPage() {
const { register } = useAuth();
const router = useRouter(); const router = useRouter();
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
const [discordId, setDiscordId] = useState(""); const [inviteCode, setInviteCode] = useState("");
const [error, setError] = useState(""); const [error, setError] = useState("");
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const handleSubmit = async (e: React.FormEvent) => { const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault(); e.preventDefault();
setLoading(true);
setError(""); setError("");
if (password.length < 6) {
setError("密码至少6位");
return;
}
setLoading(true);
try { try {
const r = await fetch("/api/auth/register", { await register(email, password, inviteCode);
method: "POST", router.push("/");
headers: { "Content-Type": "application/json" }, } catch (err: any) {
body: JSON.stringify({ email, password, discord_id: discordId || undefined }), setError(err.message || "注册失败");
});
const data = await r.json();
if (!r.ok) { setError(data.detail || "注册失败"); return; }
router.push("/login?registered=1");
} catch {
setError("网络错误,请重试");
} finally { } finally {
setLoading(false); setLoading(false);
} }
}; };
return ( return (
<div className="min-h-[70vh] flex items-center justify-center"> <div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6"> <div className="w-full max-w-md">
<div> <div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
<h1 className="text-2xl font-bold text-slate-900"></h1> <div className="text-center mb-8">
<p className="text-slate-500 text-sm mt-1"></p> <h1 className="text-2xl font-bold text-slate-800"> Arbitrage Engine</h1>
<p className="text-slate-500 text-sm mt-2"></p>
</div>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1"></label>
<input
type="text" value={inviteCode} onChange={e => setInviteCode(e.target.value.toUpperCase())}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent font-mono tracking-wider"
placeholder="输入邀请码" required
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1"></label>
<input
type="email" value={email} onChange={e => setEmail(e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
placeholder="you@example.com" required
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1"></label>
<input
type="password" value={password} onChange={e => setPassword(e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
placeholder="至少6位" required minLength={6}
/>
</div>
{error && <p className="text-red-500 text-sm">{error}</p>}
<button
type="submit" disabled={loading}
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
>
{loading ? "注册中..." : "注册"}
</button>
</form>
<p className="text-center text-sm text-slate-500 mt-6">
{" "}
<Link href="/login" className="text-blue-600 hover:underline"></Link>
</p>
</div> </div>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label className="block text-sm text-slate-700 mb-1"></label>
<input
type="email" required value={email} onChange={e => setEmail(e.target.value)}
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
placeholder="your@email.com"
/>
</div>
<div>
<label className="block text-sm text-slate-700 mb-1"></label>
<input
type="password" required value={password} onChange={e => setPassword(e.target.value)}
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
placeholder="至少8位"
minLength={8}
/>
</div>
<div>
<label className="block text-sm text-slate-700 mb-1">Discord ID <span className="text-slate-500"></span></label>
<input
type="text" value={discordId} onChange={e => setDiscordId(e.target.value)}
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
placeholder="例123456789012345678"
/>
</div>
{error && <p className="text-red-400 text-sm">{error}</p>}
<button
type="submit" disabled={loading}
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
>
{loading ? "注册中..." : "注册"}
</button>
</form>
<p className="text-center text-sm text-slate-500">
<a href="/login" className="text-blue-600 hover:underline"></a>
</p>
</div> </div>
</div> </div>
); );

View File

@ -0,0 +1,54 @@
"use client";
import Link from "next/link";
import { useAuth } from "@/lib/auth";
export default function AuthHeader() {
const { user, isLoggedIn, logout } = useAuth();
return (
<>
{/* 桌面端顶栏 */}
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3">
{isLoggedIn ? (
<>
<span className="text-sm text-slate-500">{user?.email}</span>
{user?.role === "admin" && (
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
)}
<button onClick={logout}
className="text-sm text-slate-600 hover:text-red-500 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-red-300 transition-colors">
退
</button>
</>
) : (
<>
<Link href="/login"
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors">
</Link>
<Link href="/register"
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
</Link>
</>
)}
</header>
{/* 手机端顶栏 */}
<header className="md:hidden flex items-center justify-between px-4 py-2 bg-white border-b border-slate-200 fixed top-0 left-0 right-0 z-40">
<span className="font-bold text-blue-600 text-sm"> ArbEngine</span>
<div className="flex items-center gap-2">
{isLoggedIn ? (
<button onClick={logout} className="text-xs text-slate-500 hover:text-red-500">退</button>
) : (
<>
<Link href="/login" className="text-xs text-slate-500"></Link>
<Link href="/register" className="text-xs bg-blue-600 text-white px-2 py-1 rounded"></Link>
</>
)}
</div>
</header>
</>
);
}

View File

@ -1,17 +1,11 @@
"use client"; "use client";
import Link from "next/link"; import Link from "next/link";
import { useState } from "react"; import { useState } from "react";
import { useAuth } from "@/lib/auth";
const navLinks = [
{ href: "/", label: "仪表盘" },
{ href: "/kline", label: "K线" },
{ href: "/live", label: "实时" },
{ href: "/signals", label: "信号" },
{ href: "/about", label: "说明" },
];
export default function Navbar() { export default function Navbar() {
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
const { user, isLoggedIn, logout } = useAuth();
return ( return (
<nav className="border-b border-slate-200 bg-white sticky top-0 z-50 shadow-sm"> <nav className="border-b border-slate-200 bg-white sticky top-0 z-50 shadow-sm">
@ -22,21 +16,36 @@ export default function Navbar() {
{/* Desktop nav */} {/* Desktop nav */}
<div className="hidden md:flex items-center gap-6 text-sm ml-8"> <div className="hidden md:flex items-center gap-6 text-sm ml-8">
{navLinks.map(l => ( <Link href="/" className="text-slate-600 hover:text-blue-600 transition-colors"></Link>
<Link key={l.href} href={l.href} className="text-slate-600 hover:text-blue-600 transition-colors"> <Link href="/about" className="text-slate-600 hover:text-blue-600 transition-colors"></Link>
{l.label}
</Link>
))}
</div> </div>
<div className="hidden md:flex items-center gap-3 text-sm ml-auto"> <div className="hidden md:flex items-center gap-3 text-sm ml-auto">
<Link href="/login" className="text-slate-500 hover:text-slate-800 transition-colors"></Link> {isLoggedIn ? (
<Link href="/register" className="bg-blue-600 hover:bg-blue-700 text-white px-3 py-1.5 rounded-lg transition-colors"></Link> <>
<span className="text-slate-500 text-xs">{user?.email}</span>
{user?.role === "admin" && (
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
)}
<button onClick={logout} className="text-slate-500 hover:text-red-500 transition-colors">退</button>
</>
) : (
<>
<Link href="/login" className="text-slate-500 hover:text-slate-800 transition-colors"></Link>
<Link href="/register" className="bg-blue-600 hover:bg-blue-700 text-white px-3 py-1.5 rounded-lg transition-colors"></Link>
</>
)}
</div> </div>
{/* Mobile */} {/* Mobile */}
<div className="md:hidden ml-auto flex items-center gap-2"> <div className="md:hidden ml-auto flex items-center gap-2">
<Link href="/login" className="text-slate-500 text-sm"></Link> {isLoggedIn ? (
<Link href="/register" className="bg-blue-600 text-white px-2 py-1 rounded text-sm"></Link> <button onClick={logout} className="text-slate-500 text-sm">退</button>
) : (
<>
<Link href="/login" className="text-slate-500 text-sm"></Link>
<Link href="/register" className="bg-blue-600 text-white px-2 py-1 rounded text-sm"></Link>
</>
)}
<button onClick={() => setOpen(!open)} className="ml-1 p-2 text-slate-500" aria-label="菜单"> <button onClick={() => setOpen(!open)} className="ml-1 p-2 text-slate-500" aria-label="菜单">
{open ? ( {open ? (
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24"> <svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
@ -53,12 +62,8 @@ export default function Navbar() {
{open && ( {open && (
<div className="md:hidden border-t border-slate-100 bg-white px-4 py-3 space-y-1"> <div className="md:hidden border-t border-slate-100 bg-white px-4 py-3 space-y-1">
{navLinks.map(l => ( <Link href="/" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm"></Link>
<Link key={l.href} href={l.href} onClick={() => setOpen(false)} <Link href="/about" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm"></Link>
className="block py-2 text-slate-600 hover:text-blue-600 transition-colors text-sm">
{l.label}
</Link>
))}
</div> </div>
)} )}
</nav> </nav>

View File

@ -1,3 +1,5 @@
import { authFetch } from "./auth";
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? ""; const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
export interface RateData { export interface RateData {
@ -84,21 +86,31 @@ export interface YtdStatsResponse {
ETH: YtdStats; ETH: YtdStats;
} }
async function fetchAPI<T>(path: string): Promise<T> { // Public fetch (no auth needed)
async function fetchPublic<T>(path: string): Promise<T> {
const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" }); const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" });
if (!res.ok) throw new Error(`API error ${res.status}`); if (!res.ok) throw new Error(`API error ${res.status}`);
return res.json(); return res.json();
} }
// Protected fetch (auth required, auto-refresh)
async function fetchProtected<T>(path: string): Promise<T> {
const res = await authFetch(path, { cache: "no-store" });
if (!res.ok) throw new Error(`API error ${res.status}`);
return res.json();
}
export const api = { export const api = {
rates: () => fetchAPI<RatesResponse>("/api/rates"), // Public
history: () => fetchAPI<HistoryResponse>("/api/history"), rates: () => fetchPublic<RatesResponse>("/api/rates"),
stats: () => fetchAPI<StatsResponse>("/api/stats"), health: () => fetchPublic<{ status: string }>("/api/health"),
health: () => fetchAPI<{ status: string }>("/api/health"), // Protected
signalsHistory: () => fetchAPI<SignalsHistoryResponse>("/api/signals/history"), history: () => fetchProtected<HistoryResponse>("/api/history"),
stats: () => fetchProtected<StatsResponse>("/api/stats"),
signalsHistory: () => fetchProtected<SignalsHistoryResponse>("/api/signals/history"),
snapshots: (hours = 24, limit = 5000) => snapshots: (hours = 24, limit = 5000) =>
fetchAPI<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`), fetchProtected<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`),
kline: (symbol = "BTC", interval = "1h", limit = 500) => kline: (symbol = "BTC", interval = "1h", limit = 500) =>
fetchAPI<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`), fetchProtected<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`),
statsYtd: () => fetchAPI<YtdStatsResponse>("/api/stats/ytd"), statsYtd: () => fetchProtected<YtdStatsResponse>("/api/stats/ytd"),
}; };

137
frontend/lib/auth.tsx Normal file
View File

@ -0,0 +1,137 @@
"use client";
import { createContext, useContext, useState, useEffect, ReactNode, useCallback } from "react";
interface User {
id: number;
email: string;
role: string;
}
interface AuthState {
user: User | null;
accessToken: string | null;
loading: boolean;
login: (email: string, password: string) => Promise<void>;
register: (email: string, password: string, inviteCode: string) => Promise<void>;
logout: () => void;
isLoggedIn: boolean;
isAdmin: boolean;
}
const AuthContext = createContext<AuthState | undefined>(undefined);
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
export function AuthProvider({ children }: { children: ReactNode }) {
const [user, setUser] = useState<User | null>(null);
const [accessToken, setAccessToken] = useState<string | null>(null);
const [loading, setLoading] = useState(true);
// init from localStorage
useEffect(() => {
const token = localStorage.getItem("access_token");
const saved = localStorage.getItem("user");
if (token && saved) {
try {
setAccessToken(token);
setUser(JSON.parse(saved));
} catch {}
}
setLoading(false);
}, []);
const saveAuth = (data: { access_token: string; refresh_token: string; user: User }) => {
localStorage.setItem("access_token", data.access_token);
localStorage.setItem("refresh_token", data.refresh_token);
localStorage.setItem("user", JSON.stringify(data.user));
setAccessToken(data.access_token);
setUser(data.user);
};
const login = useCallback(async (email: string, password: string) => {
const res = await fetch(`${API_BASE}/api/auth/login`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ email, password }),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(err.detail || "Login failed");
}
const data = await res.json();
saveAuth(data);
}, []);
const register = useCallback(async (email: string, password: string, inviteCode: string) => {
const res = await fetch(`${API_BASE}/api/auth/register`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ email, password, invite_code: inviteCode }),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(err.detail || "Registration failed");
}
const data = await res.json();
saveAuth(data);
}, []);
const logout = useCallback(() => {
localStorage.removeItem("access_token");
localStorage.removeItem("refresh_token");
localStorage.removeItem("user");
setAccessToken(null);
setUser(null);
}, []);
return (
<AuthContext.Provider value={{
user, accessToken, loading, login, register, logout,
isLoggedIn: !!user, isAdmin: user?.role === "admin",
}}>
{children}
</AuthContext.Provider>
);
}
export function useAuth() {
const ctx = useContext(AuthContext);
if (!ctx) throw new Error("useAuth must be used within AuthProvider");
return ctx;
}
// Authenticated fetch helper
export async function authFetch(path: string, options: RequestInit = {}): Promise<Response> {
const token = localStorage.getItem("access_token");
const headers = new Headers(options.headers);
if (token) headers.set("Authorization", `Bearer ${token}`);
let res = await fetch(`${API_BASE}${path}`, { ...options, headers });
// try refresh on 401
if (res.status === 401) {
const refreshToken = localStorage.getItem("refresh_token");
if (refreshToken) {
const refreshRes = await fetch(`${API_BASE}/api/auth/refresh`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ refresh_token: refreshToken }),
});
if (refreshRes.ok) {
const data = await refreshRes.json();
localStorage.setItem("access_token", data.access_token);
localStorage.setItem("refresh_token", data.refresh_token);
headers.set("Authorization", `Bearer ${data.access_token}`);
res = await fetch(`${API_BASE}${path}`, { ...options, headers });
} else {
// refresh failed, clear auth
localStorage.removeItem("access_token");
localStorage.removeItem("refresh_token");
localStorage.removeItem("user");
}
}
}
return res;
}