feat: V2.0 auth system - JWT access/refresh, invite codes, route protection, admin CLI, auth gate blur overlay
This commit is contained in:
parent
052e5a0541
commit
1ab228286c
115
backend/admin_cli.py
Normal file
115
backend/admin_cli.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Admin CLI for Arbitrage Engine"""
|
||||||
|
import sys, os, sqlite3, secrets, json
|
||||||
|
|
||||||
|
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||||
|
|
||||||
|
def get_conn():
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def gen_invite(count=1, max_uses=1):
|
||||||
|
conn = get_conn()
|
||||||
|
codes = []
|
||||||
|
for _ in range(count):
|
||||||
|
code = secrets.token_urlsafe(6)[:8].upper()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)",
|
||||||
|
(code, max_uses)
|
||||||
|
)
|
||||||
|
codes.append(code)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
for c in codes:
|
||||||
|
print(f" {c}")
|
||||||
|
print(f"\nGenerated {len(codes)} invite code(s)")
|
||||||
|
|
||||||
|
def list_invites():
|
||||||
|
conn = get_conn()
|
||||||
|
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall()
|
||||||
|
conn.close()
|
||||||
|
if not rows:
|
||||||
|
print("No invite codes found")
|
||||||
|
return
|
||||||
|
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
|
||||||
|
print("-" * 60)
|
||||||
|
for r in rows:
|
||||||
|
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}")
|
||||||
|
|
||||||
|
def disable_invite(code):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print(f"Disabled invite code: {code}")
|
||||||
|
|
||||||
|
def list_users():
|
||||||
|
conn = get_conn()
|
||||||
|
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall()
|
||||||
|
conn.close()
|
||||||
|
if not rows:
|
||||||
|
print("No users found")
|
||||||
|
return
|
||||||
|
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
|
||||||
|
print("-" * 72)
|
||||||
|
for r in rows:
|
||||||
|
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}")
|
||||||
|
|
||||||
|
def ban_user(user_id):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print(f"Banned user ID: {user_id}")
|
||||||
|
|
||||||
|
def unban_user(user_id):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print(f"Unbanned user ID: {user_id}")
|
||||||
|
|
||||||
|
def set_admin(user_id):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print(f"Set user {user_id} as admin")
|
||||||
|
|
||||||
|
def usage():
|
||||||
|
print("""Usage: python3 admin_cli.py <command> [args]
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
gen-invite [count] [max_uses] Generate invite codes (default: 1 code, 1 use)
|
||||||
|
list-invites List all invite codes
|
||||||
|
disable-invite <CODE> Disable an invite code
|
||||||
|
list-users List all users
|
||||||
|
ban-user <USER_ID> Ban a user
|
||||||
|
unban-user <USER_ID> Unban a user
|
||||||
|
set-admin <USER_ID> Set user as admin
|
||||||
|
""")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
usage()
|
||||||
|
sys.exit(1)
|
||||||
|
cmd = sys.argv[1]
|
||||||
|
if cmd == "gen-invite":
|
||||||
|
count = int(sys.argv[2]) if len(sys.argv) > 2 else 1
|
||||||
|
max_uses = int(sys.argv[3]) if len(sys.argv) > 3 else 1
|
||||||
|
gen_invite(count, max_uses)
|
||||||
|
elif cmd == "list-invites":
|
||||||
|
list_invites()
|
||||||
|
elif cmd == "disable-invite":
|
||||||
|
disable_invite(sys.argv[2])
|
||||||
|
elif cmd == "list-users":
|
||||||
|
list_users()
|
||||||
|
elif cmd == "ban-user":
|
||||||
|
ban_user(int(sys.argv[2]))
|
||||||
|
elif cmd == "unban-user":
|
||||||
|
unban_user(int(sys.argv[2]))
|
||||||
|
elif cmd == "set-admin":
|
||||||
|
set_admin(int(sys.argv[2]))
|
||||||
|
else:
|
||||||
|
usage()
|
||||||
297
backend/auth.py
297
backend/auth.py
@ -7,17 +7,21 @@ import base64
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header
|
from fastapi import APIRouter, HTTPException, Header, Depends, Request
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
||||||
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me")
|
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
|
||||||
JWT_EXPIRE_HOURS = 24 * 7
|
ACCESS_TOKEN_HOURS = 24
|
||||||
|
REFRESH_TOKEN_DAYS = 7
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["auth"])
|
router = APIRouter(prefix="/api", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ─── DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
def get_conn():
|
def get_conn():
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
@ -27,29 +31,26 @@ def get_conn():
|
|||||||
def ensure_tables():
|
def ensure_tables():
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(
|
cur.execute("""
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
email TEXT UNIQUE NOT NULL,
|
email TEXT UNIQUE NOT NULL,
|
||||||
password_hash TEXT NOT NULL,
|
password_hash TEXT NOT NULL,
|
||||||
discord_id TEXT,
|
discord_id TEXT,
|
||||||
|
role TEXT NOT NULL DEFAULT 'user',
|
||||||
|
banned INTEGER NOT NULL DEFAULT 0,
|
||||||
created_at TEXT NOT NULL
|
created_at TEXT NOT NULL
|
||||||
)
|
)
|
||||||
"""
|
""")
|
||||||
)
|
cur.execute("""
|
||||||
cur.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||||
user_id INTEGER PRIMARY KEY,
|
user_id INTEGER PRIMARY KEY,
|
||||||
tier TEXT NOT NULL DEFAULT 'free',
|
tier TEXT NOT NULL DEFAULT 'free',
|
||||||
expires_at TEXT,
|
expires_at TEXT,
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||||
)
|
)
|
||||||
"""
|
""")
|
||||||
)
|
cur.execute("""
|
||||||
cur.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS signal_logs (
|
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
symbol TEXT NOT NULL,
|
symbol TEXT NOT NULL,
|
||||||
@ -58,12 +59,55 @@ def ensure_tables():
|
|||||||
sent_at TEXT NOT NULL,
|
sent_at TEXT NOT NULL,
|
||||||
message TEXT NOT NULL
|
message TEXT NOT NULL
|
||||||
)
|
)
|
||||||
"""
|
""")
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
code TEXT UNIQUE NOT NULL,
|
||||||
|
created_by INTEGER,
|
||||||
|
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||||
|
used_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
expires_at TEXT,
|
||||||
|
created_at TEXT DEFAULT (datetime('now'))
|
||||||
)
|
)
|
||||||
|
""")
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_usage (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
code_id INTEGER NOT NULL,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
used_at TEXT DEFAULT (datetime('now')),
|
||||||
|
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
token TEXT UNIQUE NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
revoked INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT DEFAULT (datetime('now')),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
# migrate: add role/banned columns if missing
|
||||||
|
try:
|
||||||
|
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Password utils ──────────────────────────────────────────
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
salt = secrets.token_hex(16)
|
salt = secrets.token_hex(16)
|
||||||
digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
|
digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
|
||||||
@ -79,19 +123,37 @@ def verify_password(password: str, stored: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ─── JWT utils ───────────────────────────────────────────────
|
||||||
|
|
||||||
def b64url(data: bytes) -> str:
|
def b64url(data: bytes) -> str:
|
||||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||||||
|
|
||||||
|
|
||||||
def create_token(user_id: int, email: str) -> str:
|
def create_access_token(user_id: int, email: str, role: str) -> str:
|
||||||
header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
|
header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
|
||||||
exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp())
|
exp = int((datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_HOURS)).timestamp())
|
||||||
payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode())
|
payload = b64url(json.dumps({
|
||||||
|
"sub": user_id, "email": email, "role": role,
|
||||||
|
"exp": exp, "type": "access"
|
||||||
|
}, separators=(",", ":")).encode())
|
||||||
sign_input = f"{header}.{payload}".encode()
|
sign_input = f"{header}.{payload}".encode()
|
||||||
signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
|
signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
|
||||||
return f"{header}.{payload}.{b64url(signature)}"
|
return f"{header}.{payload}.{b64url(signature)}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(user_id: int) -> str:
|
||||||
|
token = secrets.token_urlsafe(48)
|
||||||
|
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
|
||||||
|
(user_id, token, expires_at)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
def parse_token(token: str) -> Optional[dict]:
|
def parse_token(token: str) -> Optional[dict]:
|
||||||
try:
|
try:
|
||||||
header_b64, payload_b64, sig_b64 = token.split(".")
|
header_b64, payload_b64, sig_b64 = token.split(".")
|
||||||
@ -108,24 +170,41 @@ def parse_token(token: str) -> Optional[dict]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row:
|
# ─── Auth dependency ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
|
||||||
|
"""FastAPI dependency: returns user dict or raises 401"""
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="missing token")
|
raise HTTPException(status_code=401, detail="missing token")
|
||||||
token = authorization.split(" ", 1)[1].strip()
|
token = authorization.split(" ", 1)[1].strip()
|
||||||
payload = parse_token(token)
|
payload = parse_token(token)
|
||||||
if not payload:
|
if not payload:
|
||||||
raise HTTPException(status_code=401, detail="invalid token")
|
raise HTTPException(status_code=401, detail="invalid or expired token")
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise HTTPException(status_code=401, detail="invalid token type")
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
|
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="user not found")
|
raise HTTPException(status_code=401, detail="user not found")
|
||||||
|
if user["banned"]:
|
||||||
|
raise HTTPException(status_code=403, detail="account banned")
|
||||||
|
return dict(user)
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(user: dict = Depends(get_current_user)) -> dict:
|
||||||
|
"""FastAPI dependency: requires admin role"""
|
||||||
|
if user.get("role") != "admin":
|
||||||
|
raise HTTPException(status_code=403, detail="admin required")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Request models ──────────────────────────────────────────
|
||||||
|
|
||||||
class RegisterReq(BaseModel):
|
class RegisterReq(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
invite_code: str
|
||||||
|
|
||||||
|
|
||||||
class LoginReq(BaseModel):
|
class LoginReq(BaseModel):
|
||||||
@ -133,18 +212,47 @@ class LoginReq(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class BindDiscordReq(BaseModel):
|
class RefreshReq(BaseModel):
|
||||||
discord_id: str
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class GenInviteReq(BaseModel):
|
||||||
|
count: int = 1
|
||||||
|
max_uses: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class BanUserReq(BaseModel):
|
||||||
|
banned: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Auth routes ─────────────────────────────────────────────
|
||||||
|
|
||||||
@router.post("/auth/register")
|
@router.post("/auth/register")
|
||||||
def register(body: RegisterReq):
|
def register(body: RegisterReq):
|
||||||
ensure_tables()
|
ensure_tables()
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
|
|
||||||
|
# verify invite code
|
||||||
|
invite = conn.execute(
|
||||||
|
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
|
||||||
|
).fetchone()
|
||||||
|
if not invite:
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=400, detail="invalid invite code")
|
||||||
|
if invite["status"] != "active":
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=400, detail="invite code disabled")
|
||||||
|
if invite["used_count"] >= invite["max_uses"]:
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=400, detail="invite code exhausted")
|
||||||
|
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=400, detail="invite code expired")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pwd_hash = hash_password(body.password)
|
pwd_hash = hash_password(body.password)
|
||||||
cur = conn.execute(
|
cur = conn.execute(
|
||||||
"INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)",
|
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)",
|
||||||
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
||||||
)
|
)
|
||||||
user_id = cur.lastrowid
|
user_id = cur.lastrowid
|
||||||
@ -152,14 +260,39 @@ def register(body: RegisterReq):
|
|||||||
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
|
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
|
# record invite usage
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
|
||||||
|
(invite["id"], user_id),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?",
|
||||||
|
(invite["id"],),
|
||||||
|
)
|
||||||
|
# auto-exhaust if max reached
|
||||||
|
new_count = invite["used_count"] + 1
|
||||||
|
if new_count >= invite["max_uses"]:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?",
|
||||||
|
(invite["id"],),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
conn.close()
|
conn.close()
|
||||||
raise HTTPException(status_code=400, detail="email already registered")
|
raise HTTPException(status_code=400, detail="email already registered")
|
||||||
|
|
||||||
user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone()
|
# issue tokens
|
||||||
|
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
return dict(user)
|
access = create_access_token(user["id"], user["email"], user["role"])
|
||||||
|
refresh = create_refresh_token(user["id"])
|
||||||
|
return {
|
||||||
|
"access_token": access,
|
||||||
|
"refresh_token": refresh,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||||
|
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/login")
|
@router.post("/auth/login")
|
||||||
@ -170,27 +303,113 @@ def login(body: LoginReq):
|
|||||||
conn.close()
|
conn.close()
|
||||||
if not user or not verify_password(body.password, user["password_hash"]):
|
if not user or not verify_password(body.password, user["password_hash"]):
|
||||||
raise HTTPException(status_code=401, detail="invalid credentials")
|
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||||
token = create_token(user["id"], user["email"])
|
if user["banned"]:
|
||||||
return {"token": token, "token_type": "bearer"}
|
raise HTTPException(status_code=403, detail="account banned")
|
||||||
|
access = create_access_token(user["id"], user["email"], user["role"])
|
||||||
|
refresh = create_refresh_token(user["id"])
|
||||||
|
return {
|
||||||
|
"access_token": access,
|
||||||
|
"refresh_token": refresh,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||||
|
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/user/bind-discord")
|
@router.post("/auth/refresh")
|
||||||
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)):
|
def refresh_token(body: RefreshReq):
|
||||||
user = get_user_from_auth(authorization)
|
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"]))
|
row = conn.execute(
|
||||||
conn.commit()
|
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,)
|
||||||
updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone()
|
).fetchone()
|
||||||
|
if not row:
|
||||||
conn.close()
|
conn.close()
|
||||||
return dict(updated)
|
raise HTTPException(status_code=401, detail="invalid refresh token")
|
||||||
|
if row["expires_at"] < datetime.utcnow().isoformat():
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=401, detail="refresh token expired")
|
||||||
|
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone()
|
||||||
|
if not user or user["banned"]:
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=403, detail="account unavailable")
|
||||||
|
# revoke old, issue new
|
||||||
|
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
access = create_access_token(user["id"], user["email"], user["role"])
|
||||||
|
new_refresh = create_refresh_token(user["id"])
|
||||||
|
return {
|
||||||
|
"access_token": access,
|
||||||
|
"refresh_token": new_refresh,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/me")
|
@router.get("/auth/me")
|
||||||
def me(authorization: Optional[str] = Header(default=None)):
|
def me(user: dict = Depends(get_current_user)):
|
||||||
user = get_user_from_auth(authorization)
|
|
||||||
conn = get_conn()
|
conn = get_conn()
|
||||||
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
|
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]}
|
return {
|
||||||
out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None}
|
"id": user["id"], "email": user["email"], "role": user["role"],
|
||||||
return out
|
"discord_id": user.get("discord_id"),
|
||||||
|
"created_at": user["created_at"],
|
||||||
|
"subscription": dict(sub) if sub else {"tier": "free", "expires_at": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Admin routes ────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/admin/invite-codes")
|
||||||
|
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
|
||||||
|
conn = get_conn()
|
||||||
|
codes = []
|
||||||
|
for _ in range(body.count):
|
||||||
|
code = secrets.token_urlsafe(6)[:8].upper()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
|
||||||
|
(code, admin["id"], body.max_uses),
|
||||||
|
)
|
||||||
|
codes.append(code)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return {"codes": codes}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/admin/invite-codes")
|
||||||
|
def list_invite_codes(admin: dict = Depends(require_admin)):
|
||||||
|
conn = get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
return {"items": [dict(r) for r in rows]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/admin/invite-codes/{code_id}")
|
||||||
|
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/admin/users")
|
||||||
|
def list_users(admin: dict = Depends(require_admin)):
|
||||||
|
conn = get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
return {"items": [dict(r) for r in rows]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/admin/users/{user_id}/ban")
|
||||||
|
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return {"ok": True, "user_id": user_id, "banned": body.banned}
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import httpx
|
import httpx
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import asyncio, time, sqlite3, os
|
import asyncio, time, sqlite3, os
|
||||||
|
|
||||||
|
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
||||||
|
|
||||||
app = FastAPI(title="Arbitrage Engine API")
|
app = FastAPI(title="Arbitrage Engine API")
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@ -13,6 +15,8 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|
||||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||||
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
||||||
@ -101,6 +105,7 @@ async def background_snapshot_loop():
|
|||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
init_db()
|
init_db()
|
||||||
|
ensure_auth_tables()
|
||||||
asyncio.create_task(background_snapshot_loop())
|
asyncio.create_task(background_snapshot_loop())
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +142,7 @@ async def get_rates():
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/snapshots")
|
@app.get("/api/snapshots")
|
||||||
async def get_snapshots(hours: int = 24, limit: int = 5000):
|
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
|
||||||
"""查询本地落库的实时快照数据"""
|
"""查询本地落库的实时快照数据"""
|
||||||
since = int(time.time()) - hours * 3600
|
since = int(time.time()) - hours * 3600
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
@ -155,7 +160,7 @@ async def get_snapshots(hours: int = 24, limit: int = 5000):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/kline")
|
@app.get("/api/kline")
|
||||||
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500):
|
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
|
||||||
"""
|
"""
|
||||||
从 rate_snapshots 聚合K线数据
|
从 rate_snapshots 聚合K线数据
|
||||||
symbol: BTC | ETH
|
symbol: BTC | ETH
|
||||||
@ -212,7 +217,7 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500)
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/stats/ytd")
|
@app.get("/api/stats/ytd")
|
||||||
async def get_stats_ytd():
|
async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
||||||
"""今年以来(YTD)资金费率年化统计"""
|
"""今年以来(YTD)资金费率年化统计"""
|
||||||
cached = get_cache("stats_ytd", 3600)
|
cached = get_cache("stats_ytd", 3600)
|
||||||
if cached: return cached
|
if cached: return cached
|
||||||
@ -245,7 +250,7 @@ async def get_stats_ytd():
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/signals/history")
|
@app.get("/api/signals/history")
|
||||||
async def get_signals_history(limit: int = 100):
|
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
|
||||||
"""查询信号推送历史"""
|
"""查询信号推送历史"""
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
@ -261,7 +266,7 @@ async def get_signals_history(limit: int = 100):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/history")
|
@app.get("/api/history")
|
||||||
async def get_history():
|
async def get_history(user: dict = Depends(get_current_user)):
|
||||||
cached = get_cache("history", 60)
|
cached = get_cache("history", 60)
|
||||||
if cached: return cached
|
if cached: return cached
|
||||||
end_time = int(datetime.utcnow().timestamp() * 1000)
|
end_time = int(datetime.utcnow().timestamp() * 1000)
|
||||||
@ -288,7 +293,7 @@ async def get_history():
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/stats")
|
@app.get("/api/stats")
|
||||||
async def get_stats():
|
async def get_stats(user: dict = Depends(get_current_user)):
|
||||||
cached = get_cache("stats", 60)
|
cached = get_cache("stats", 60)
|
||||||
if cached: return cached
|
if cached: return cached
|
||||||
end_time = int(datetime.utcnow().timestamp() * 1000)
|
end_time = int(datetime.utcnow().timestamp() * 1000)
|
||||||
|
|||||||
@ -2,7 +2,8 @@ import type { Metadata } from "next";
|
|||||||
import { Geist, Geist_Mono } from "next/font/google";
|
import { Geist, Geist_Mono } from "next/font/google";
|
||||||
import "./globals.css";
|
import "./globals.css";
|
||||||
import Sidebar from "@/components/Sidebar";
|
import Sidebar from "@/components/Sidebar";
|
||||||
import Link from "next/link";
|
import { AuthProvider } from "@/lib/auth";
|
||||||
|
import AuthHeader from "@/components/AuthHeader";
|
||||||
|
|
||||||
const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] });
|
const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] });
|
||||||
const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] });
|
const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] });
|
||||||
@ -16,25 +17,17 @@ export default function RootLayout({ children }: Readonly<{ children: React.Reac
|
|||||||
return (
|
return (
|
||||||
<html lang="zh">
|
<html lang="zh">
|
||||||
<body className={`${geistSans.variable} ${geistMono.variable} antialiased min-h-screen bg-slate-50 text-slate-900`}>
|
<body className={`${geistSans.variable} ${geistMono.variable} antialiased min-h-screen bg-slate-50 text-slate-900`}>
|
||||||
|
<AuthProvider>
|
||||||
<div className="flex min-h-screen">
|
<div className="flex min-h-screen">
|
||||||
<Sidebar />
|
<Sidebar />
|
||||||
<div className="flex-1 flex flex-col min-w-0">
|
<div className="flex-1 flex flex-col min-w-0">
|
||||||
{/* 桌面端顶栏:右上角登录注册 */}
|
<AuthHeader />
|
||||||
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3">
|
|
||||||
<Link href="/login"
|
|
||||||
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors">
|
|
||||||
登录
|
|
||||||
</Link>
|
|
||||||
<Link href="/register"
|
|
||||||
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
|
|
||||||
注册
|
|
||||||
</Link>
|
|
||||||
</header>
|
|
||||||
<main className="flex-1 p-4 md:p-6 pt-16 md:pt-6">
|
<main className="flex-1 p-4 md:p-6 pt-16 md:pt-6">
|
||||||
{children}
|
{children}
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</AuthProvider>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import { useState, Suspense } from "react";
|
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
|
||||||
|
|
||||||
function LoginForm() {
|
import { useState } from "react";
|
||||||
|
import { useAuth } from "@/lib/auth";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
import Link from "next/link";
|
||||||
|
|
||||||
|
export default function LoginPage() {
|
||||||
|
const { login } = useAuth();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const params = useSearchParams();
|
|
||||||
const [email, setEmail] = useState("");
|
const [email, setEmail] = useState("");
|
||||||
const [password, setPassword] = useState("");
|
const [password, setPassword] = useState("");
|
||||||
const [error, setError] = useState("");
|
const [error, setError] = useState("");
|
||||||
@ -12,74 +15,61 @@ function LoginForm() {
|
|||||||
|
|
||||||
const handleSubmit = async (e: React.FormEvent) => {
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setLoading(true);
|
|
||||||
setError("");
|
setError("");
|
||||||
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const form = new URLSearchParams();
|
await login(email, password);
|
||||||
form.append("username", email);
|
router.push("/");
|
||||||
form.append("password", password);
|
} catch (err: any) {
|
||||||
const r = await fetch("/api/auth/login", {
|
setError(err.message || "登录失败");
|
||||||
method: "POST",
|
|
||||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
|
||||||
body: form.toString(),
|
|
||||||
});
|
|
||||||
const data = await r.json();
|
|
||||||
if (!r.ok) { setError(data.detail || "登录失败"); return; }
|
|
||||||
localStorage.setItem("arb_token", data.access_token);
|
|
||||||
router.push("/dashboard");
|
|
||||||
} catch {
|
|
||||||
setError("网络错误,请重试");
|
|
||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="min-h-[70vh] flex items-center justify-center">
|
<div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
|
||||||
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6">
|
<div className="w-full max-w-md">
|
||||||
<div>
|
<div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
|
||||||
<h1 className="text-2xl font-bold text-slate-900">登录</h1>
|
<div className="text-center mb-8">
|
||||||
{params.get("registered") && (
|
<h1 className="text-2xl font-bold text-slate-800">⚡ Arbitrage Engine</h1>
|
||||||
<p className="text-emerald-400 text-sm mt-1">✅ 注册成功,请登录</p>
|
<p className="text-slate-500 text-sm mt-2">登录您的账户</p>
|
||||||
)}
|
|
||||||
<p className="text-slate-500 text-sm mt-1">登录后查看信号和账户信息</p>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<form onSubmit={handleSubmit} className="space-y-4">
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
<div>
|
<div>
|
||||||
<label className="block text-sm text-slate-700 mb-1">邮箱</label>
|
<label className="block text-sm font-medium text-slate-700 mb-1">邮箱</label>
|
||||||
<input
|
<input
|
||||||
type="email" required value={email} onChange={e => setEmail(e.target.value)}
|
type="email" value={email} onChange={e => setEmail(e.target.value)}
|
||||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||||
placeholder="your@email.com"
|
placeholder="you@example.com" required
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label className="block text-sm text-slate-700 mb-1">密码</label>
|
<label className="block text-sm font-medium text-slate-700 mb-1">密码</label>
|
||||||
<input
|
<input
|
||||||
type="password" required value={password} onChange={e => setPassword(e.target.value)}
|
type="password" value={password} onChange={e => setPassword(e.target.value)}
|
||||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||||
|
placeholder="••••••••" required
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{error && <p className="text-red-400 text-sm">{error}</p>}
|
|
||||||
|
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||||
|
|
||||||
<button
|
<button
|
||||||
type="submit" disabled={loading}
|
type="submit" disabled={loading}
|
||||||
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
|
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
|
||||||
>
|
>
|
||||||
{loading ? "登录中..." : "登录"}
|
{loading ? "登录中..." : "登录"}
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
<p className="text-center text-sm text-slate-500">
|
|
||||||
没有账号?<a href="/register" className="text-blue-600 hover:underline">注册</a>
|
<p className="text-center text-sm text-slate-500 mt-6">
|
||||||
|
没有账户?{" "}
|
||||||
|
<Link href="/register" className="text-blue-600 hover:underline">注册</Link>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
</div>
|
||||||
}
|
|
||||||
|
|
||||||
export default function LoginPage() {
|
|
||||||
return (
|
|
||||||
<Suspense>
|
|
||||||
<LoginForm />
|
|
||||||
</Suspense>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,8 +3,10 @@
|
|||||||
import { useEffect, useState, useCallback, useRef } from "react";
|
import { useEffect, useState, useCallback, useRef } from "react";
|
||||||
import { createChart, ColorType, CandlestickSeries } from "lightweight-charts";
|
import { createChart, ColorType, CandlestickSeries } from "lightweight-charts";
|
||||||
import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api";
|
import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api";
|
||||||
|
import { useAuth } from "@/lib/auth";
|
||||||
import RateCard from "@/components/RateCard";
|
import RateCard from "@/components/RateCard";
|
||||||
import StatsCard from "@/components/StatsCard";
|
import StatsCard from "@/components/StatsCard";
|
||||||
|
import Link from "next/link";
|
||||||
import {
|
import {
|
||||||
LineChart, Line, XAxis, YAxis, Tooltip, Legend,
|
LineChart, Line, XAxis, YAxis, Tooltip, Legend,
|
||||||
ResponsiveContainer, ReferenceLine
|
ResponsiveContainer, ReferenceLine
|
||||||
@ -91,8 +93,35 @@ function IntervalPicker({ value, onChange }: { value: string; onChange: (v: stri
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── 未登录遮挡组件 ──────────────────────────────────────────────
|
||||||
|
function AuthGate({ children }: { children: React.ReactNode }) {
|
||||||
|
const { isLoggedIn, loading } = useAuth();
|
||||||
|
if (loading) return <div className="text-center text-slate-400 py-12">加载中...</div>;
|
||||||
|
if (!isLoggedIn) {
|
||||||
|
return (
|
||||||
|
<div className="relative">
|
||||||
|
<div className="filter blur-sm pointer-events-none select-none opacity-60">
|
||||||
|
{children}
|
||||||
|
</div>
|
||||||
|
<div className="absolute inset-0 flex items-center justify-center bg-white/70 rounded-xl">
|
||||||
|
<div className="text-center space-y-3">
|
||||||
|
<div className="text-4xl">🔒</div>
|
||||||
|
<p className="text-slate-600 font-medium">登录后查看完整数据</p>
|
||||||
|
<div className="flex gap-2 justify-center">
|
||||||
|
<Link href="/login" className="text-sm bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg transition-colors">登录</Link>
|
||||||
|
<Link href="/register" className="text-sm border border-slate-300 text-slate-600 hover:border-blue-400 px-4 py-2 rounded-lg transition-colors">注册</Link>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return <>{children}</>;
|
||||||
|
}
|
||||||
|
|
||||||
// ─── 主仪表盘 ────────────────────────────────────────────────────
|
// ─── 主仪表盘 ────────────────────────────────────────────────────
|
||||||
export default function Dashboard() {
|
export default function Dashboard() {
|
||||||
|
const { isLoggedIn } = useAuth();
|
||||||
const [rates, setRates] = useState<RatesResponse | null>(null);
|
const [rates, setRates] = useState<RatesResponse | null>(null);
|
||||||
const [stats, setStats] = useState<StatsResponse | null>(null);
|
const [stats, setStats] = useState<StatsResponse | null>(null);
|
||||||
const [history, setHistory] = useState<HistoryResponse | null>(null);
|
const [history, setHistory] = useState<HistoryResponse | null>(null);
|
||||||
@ -110,11 +139,12 @@ export default function Dashboard() {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const fetchSlow = useCallback(async () => {
|
const fetchSlow = useCallback(async () => {
|
||||||
|
if (!isLoggedIn) return;
|
||||||
try {
|
try {
|
||||||
const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]);
|
const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]);
|
||||||
setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y);
|
setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y);
|
||||||
} catch {}
|
} catch {}
|
||||||
}, []);
|
}, [isLoggedIn]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
fetchRates(); fetchSlow();
|
fetchRates(); fetchSlow();
|
||||||
@ -154,6 +184,7 @@ export default function Dashboard() {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 统计卡片 */}
|
{/* 统计卡片 */}
|
||||||
|
<AuthGate>
|
||||||
{stats && (
|
{stats && (
|
||||||
<div className="grid grid-cols-3 gap-3">
|
<div className="grid grid-cols-3 gap-3">
|
||||||
<StatsCard title="BTC 套利" mean7d={stats.BTC.mean7d} annualized={stats.BTC.annualized} accent="blue" />
|
<StatsCard title="BTC 套利" mean7d={stats.BTC.mean7d} annualized={stats.BTC.annualized} accent="blue" />
|
||||||
@ -283,6 +314,7 @@ export default function Dashboard() {
|
|||||||
<span className="text-blue-600 font-medium">策略原理:</span>
|
<span className="text-blue-600 font-medium">策略原理:</span>
|
||||||
持有现货多头 + 永续空头,每8小时收取资金费率,赚取无方向风险的稳定收益。
|
持有现货多头 + 永续空头,每8小时收取资金费率,赚取无方向风险的稳定收益。
|
||||||
</div>
|
</div>
|
||||||
|
</AuthGate>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,80 +1,88 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
import { useAuth } from "@/lib/auth";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
|
import Link from "next/link";
|
||||||
|
|
||||||
export default function RegisterPage() {
|
export default function RegisterPage() {
|
||||||
|
const { register } = useAuth();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const [email, setEmail] = useState("");
|
const [email, setEmail] = useState("");
|
||||||
const [password, setPassword] = useState("");
|
const [password, setPassword] = useState("");
|
||||||
const [discordId, setDiscordId] = useState("");
|
const [inviteCode, setInviteCode] = useState("");
|
||||||
const [error, setError] = useState("");
|
const [error, setError] = useState("");
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
const handleSubmit = async (e: React.FormEvent) => {
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setLoading(true);
|
|
||||||
setError("");
|
setError("");
|
||||||
|
if (password.length < 6) {
|
||||||
|
setError("密码至少6位");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const r = await fetch("/api/auth/register", {
|
await register(email, password, inviteCode);
|
||||||
method: "POST",
|
router.push("/");
|
||||||
headers: { "Content-Type": "application/json" },
|
} catch (err: any) {
|
||||||
body: JSON.stringify({ email, password, discord_id: discordId || undefined }),
|
setError(err.message || "注册失败");
|
||||||
});
|
|
||||||
const data = await r.json();
|
|
||||||
if (!r.ok) { setError(data.detail || "注册失败"); return; }
|
|
||||||
router.push("/login?registered=1");
|
|
||||||
} catch {
|
|
||||||
setError("网络错误,请重试");
|
|
||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="min-h-[70vh] flex items-center justify-center">
|
<div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
|
||||||
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6">
|
<div className="w-full max-w-md">
|
||||||
<div>
|
<div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
|
||||||
<h1 className="text-2xl font-bold text-slate-900">注册账号</h1>
|
<div className="text-center mb-8">
|
||||||
<p className="text-slate-500 text-sm mt-1">注册后可接收套利信号推送</p>
|
<h1 className="text-2xl font-bold text-slate-800">⚡ Arbitrage Engine</h1>
|
||||||
|
<p className="text-slate-500 text-sm mt-2">注册新账户(需要邀请码)</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<form onSubmit={handleSubmit} className="space-y-4">
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
<div>
|
<div>
|
||||||
<label className="block text-sm text-slate-700 mb-1">邮箱</label>
|
<label className="block text-sm font-medium text-slate-700 mb-1">邀请码</label>
|
||||||
<input
|
<input
|
||||||
type="email" required value={email} onChange={e => setEmail(e.target.value)}
|
type="text" value={inviteCode} onChange={e => setInviteCode(e.target.value.toUpperCase())}
|
||||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent font-mono tracking-wider"
|
||||||
placeholder="your@email.com"
|
placeholder="输入邀请码" required
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label className="block text-sm text-slate-700 mb-1">密码</label>
|
<label className="block text-sm font-medium text-slate-700 mb-1">邮箱</label>
|
||||||
<input
|
<input
|
||||||
type="password" required value={password} onChange={e => setPassword(e.target.value)}
|
type="email" value={email} onChange={e => setEmail(e.target.value)}
|
||||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||||
placeholder="至少8位"
|
placeholder="you@example.com" required
|
||||||
minLength={8}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label className="block text-sm text-slate-700 mb-1">Discord ID <span className="text-slate-500">(选填,用于接收信号)</span></label>
|
<label className="block text-sm font-medium text-slate-700 mb-1">密码</label>
|
||||||
<input
|
<input
|
||||||
type="text" value={discordId} onChange={e => setDiscordId(e.target.value)}
|
type="password" value={password} onChange={e => setPassword(e.target.value)}
|
||||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||||
placeholder="例:123456789012345678"
|
placeholder="至少6位" required minLength={6}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{error && <p className="text-red-400 text-sm">{error}</p>}
|
|
||||||
|
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||||
|
|
||||||
<button
|
<button
|
||||||
type="submit" disabled={loading}
|
type="submit" disabled={loading}
|
||||||
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
|
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
|
||||||
>
|
>
|
||||||
{loading ? "注册中..." : "注册"}
|
{loading ? "注册中..." : "注册"}
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
<p className="text-center text-sm text-slate-500">
|
|
||||||
已有账号?<a href="/login" className="text-blue-600 hover:underline">登录</a>
|
<p className="text-center text-sm text-slate-500 mt-6">
|
||||||
|
已有账户?{" "}
|
||||||
|
<Link href="/login" className="text-blue-600 hover:underline">登录</Link>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
54
frontend/components/AuthHeader.tsx
Normal file
54
frontend/components/AuthHeader.tsx
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useAuth } from "@/lib/auth";
|
||||||
|
|
||||||
|
export default function AuthHeader() {
|
||||||
|
const { user, isLoggedIn, logout } = useAuth();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{/* 桌面端顶栏 */}
|
||||||
|
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3">
|
||||||
|
{isLoggedIn ? (
|
||||||
|
<>
|
||||||
|
<span className="text-sm text-slate-500">{user?.email}</span>
|
||||||
|
{user?.role === "admin" && (
|
||||||
|
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
|
||||||
|
)}
|
||||||
|
<button onClick={logout}
|
||||||
|
className="text-sm text-slate-600 hover:text-red-500 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-red-300 transition-colors">
|
||||||
|
退出
|
||||||
|
</button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Link href="/login"
|
||||||
|
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors">
|
||||||
|
登录
|
||||||
|
</Link>
|
||||||
|
<Link href="/register"
|
||||||
|
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
|
||||||
|
注册
|
||||||
|
</Link>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{/* 手机端顶栏 */}
|
||||||
|
<header className="md:hidden flex items-center justify-between px-4 py-2 bg-white border-b border-slate-200 fixed top-0 left-0 right-0 z-40">
|
||||||
|
<span className="font-bold text-blue-600 text-sm">⚡ ArbEngine</span>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{isLoggedIn ? (
|
||||||
|
<button onClick={logout} className="text-xs text-slate-500 hover:text-red-500">退出</button>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Link href="/login" className="text-xs text-slate-500">登录</Link>
|
||||||
|
<Link href="/register" className="text-xs bg-blue-600 text-white px-2 py-1 rounded">注册</Link>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -1,17 +1,11 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
import { useAuth } from "@/lib/auth";
|
||||||
const navLinks = [
|
|
||||||
{ href: "/", label: "仪表盘" },
|
|
||||||
{ href: "/kline", label: "K线" },
|
|
||||||
{ href: "/live", label: "实时" },
|
|
||||||
{ href: "/signals", label: "信号" },
|
|
||||||
{ href: "/about", label: "说明" },
|
|
||||||
];
|
|
||||||
|
|
||||||
export default function Navbar() {
|
export default function Navbar() {
|
||||||
const [open, setOpen] = useState(false);
|
const [open, setOpen] = useState(false);
|
||||||
|
const { user, isLoggedIn, logout } = useAuth();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<nav className="border-b border-slate-200 bg-white sticky top-0 z-50 shadow-sm">
|
<nav className="border-b border-slate-200 bg-white sticky top-0 z-50 shadow-sm">
|
||||||
@ -22,21 +16,36 @@ export default function Navbar() {
|
|||||||
|
|
||||||
{/* Desktop nav */}
|
{/* Desktop nav */}
|
||||||
<div className="hidden md:flex items-center gap-6 text-sm ml-8">
|
<div className="hidden md:flex items-center gap-6 text-sm ml-8">
|
||||||
{navLinks.map(l => (
|
<Link href="/" className="text-slate-600 hover:text-blue-600 transition-colors">仪表盘</Link>
|
||||||
<Link key={l.href} href={l.href} className="text-slate-600 hover:text-blue-600 transition-colors">
|
<Link href="/about" className="text-slate-600 hover:text-blue-600 transition-colors">说明</Link>
|
||||||
{l.label}
|
|
||||||
</Link>
|
|
||||||
))}
|
|
||||||
</div>
|
</div>
|
||||||
<div className="hidden md:flex items-center gap-3 text-sm ml-auto">
|
<div className="hidden md:flex items-center gap-3 text-sm ml-auto">
|
||||||
|
{isLoggedIn ? (
|
||||||
|
<>
|
||||||
|
<span className="text-slate-500 text-xs">{user?.email}</span>
|
||||||
|
{user?.role === "admin" && (
|
||||||
|
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
|
||||||
|
)}
|
||||||
|
<button onClick={logout} className="text-slate-500 hover:text-red-500 transition-colors">退出</button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
<Link href="/login" className="text-slate-500 hover:text-slate-800 transition-colors">登录</Link>
|
<Link href="/login" className="text-slate-500 hover:text-slate-800 transition-colors">登录</Link>
|
||||||
<Link href="/register" className="bg-blue-600 hover:bg-blue-700 text-white px-3 py-1.5 rounded-lg transition-colors">注册</Link>
|
<Link href="/register" className="bg-blue-600 hover:bg-blue-700 text-white px-3 py-1.5 rounded-lg transition-colors">注册</Link>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Mobile */}
|
{/* Mobile */}
|
||||||
<div className="md:hidden ml-auto flex items-center gap-2">
|
<div className="md:hidden ml-auto flex items-center gap-2">
|
||||||
|
{isLoggedIn ? (
|
||||||
|
<button onClick={logout} className="text-slate-500 text-sm">退出</button>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
<Link href="/login" className="text-slate-500 text-sm">登录</Link>
|
<Link href="/login" className="text-slate-500 text-sm">登录</Link>
|
||||||
<Link href="/register" className="bg-blue-600 text-white px-2 py-1 rounded text-sm">注册</Link>
|
<Link href="/register" className="bg-blue-600 text-white px-2 py-1 rounded text-sm">注册</Link>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<button onClick={() => setOpen(!open)} className="ml-1 p-2 text-slate-500" aria-label="菜单">
|
<button onClick={() => setOpen(!open)} className="ml-1 p-2 text-slate-500" aria-label="菜单">
|
||||||
{open ? (
|
{open ? (
|
||||||
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
@ -53,12 +62,8 @@ export default function Navbar() {
|
|||||||
|
|
||||||
{open && (
|
{open && (
|
||||||
<div className="md:hidden border-t border-slate-100 bg-white px-4 py-3 space-y-1">
|
<div className="md:hidden border-t border-slate-100 bg-white px-4 py-3 space-y-1">
|
||||||
{navLinks.map(l => (
|
<Link href="/" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm">仪表盘</Link>
|
||||||
<Link key={l.href} href={l.href} onClick={() => setOpen(false)}
|
<Link href="/about" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm">说明</Link>
|
||||||
className="block py-2 text-slate-600 hover:text-blue-600 transition-colors text-sm">
|
|
||||||
{l.label}
|
|
||||||
</Link>
|
|
||||||
))}
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</nav>
|
</nav>
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import { authFetch } from "./auth";
|
||||||
|
|
||||||
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
|
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
|
||||||
|
|
||||||
export interface RateData {
|
export interface RateData {
|
||||||
@ -84,21 +86,31 @@ export interface YtdStatsResponse {
|
|||||||
ETH: YtdStats;
|
ETH: YtdStats;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function fetchAPI<T>(path: string): Promise<T> {
|
// Public fetch (no auth needed)
|
||||||
|
async function fetchPublic<T>(path: string): Promise<T> {
|
||||||
const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" });
|
const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" });
|
||||||
if (!res.ok) throw new Error(`API error ${res.status}`);
|
if (!res.ok) throw new Error(`API error ${res.status}`);
|
||||||
return res.json();
|
return res.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Protected fetch (auth required, auto-refresh)
|
||||||
|
async function fetchProtected<T>(path: string): Promise<T> {
|
||||||
|
const res = await authFetch(path, { cache: "no-store" });
|
||||||
|
if (!res.ok) throw new Error(`API error ${res.status}`);
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
export const api = {
|
export const api = {
|
||||||
rates: () => fetchAPI<RatesResponse>("/api/rates"),
|
// Public
|
||||||
history: () => fetchAPI<HistoryResponse>("/api/history"),
|
rates: () => fetchPublic<RatesResponse>("/api/rates"),
|
||||||
stats: () => fetchAPI<StatsResponse>("/api/stats"),
|
health: () => fetchPublic<{ status: string }>("/api/health"),
|
||||||
health: () => fetchAPI<{ status: string }>("/api/health"),
|
// Protected
|
||||||
signalsHistory: () => fetchAPI<SignalsHistoryResponse>("/api/signals/history"),
|
history: () => fetchProtected<HistoryResponse>("/api/history"),
|
||||||
|
stats: () => fetchProtected<StatsResponse>("/api/stats"),
|
||||||
|
signalsHistory: () => fetchProtected<SignalsHistoryResponse>("/api/signals/history"),
|
||||||
snapshots: (hours = 24, limit = 5000) =>
|
snapshots: (hours = 24, limit = 5000) =>
|
||||||
fetchAPI<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`),
|
fetchProtected<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`),
|
||||||
kline: (symbol = "BTC", interval = "1h", limit = 500) =>
|
kline: (symbol = "BTC", interval = "1h", limit = 500) =>
|
||||||
fetchAPI<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`),
|
fetchProtected<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`),
|
||||||
statsYtd: () => fetchAPI<YtdStatsResponse>("/api/stats/ytd"),
|
statsYtd: () => fetchProtected<YtdStatsResponse>("/api/stats/ytd"),
|
||||||
};
|
};
|
||||||
|
|||||||
137
frontend/lib/auth.tsx
Normal file
137
frontend/lib/auth.tsx
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { createContext, useContext, useState, useEffect, ReactNode, useCallback } from "react";
|
||||||
|
|
||||||
|
interface User {
|
||||||
|
id: number;
|
||||||
|
email: string;
|
||||||
|
role: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AuthState {
|
||||||
|
user: User | null;
|
||||||
|
accessToken: string | null;
|
||||||
|
loading: boolean;
|
||||||
|
login: (email: string, password: string) => Promise<void>;
|
||||||
|
register: (email: string, password: string, inviteCode: string) => Promise<void>;
|
||||||
|
logout: () => void;
|
||||||
|
isLoggedIn: boolean;
|
||||||
|
isAdmin: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AuthContext = createContext<AuthState | undefined>(undefined);
|
||||||
|
|
||||||
|
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
|
||||||
|
|
||||||
|
export function AuthProvider({ children }: { children: ReactNode }) {
|
||||||
|
const [user, setUser] = useState<User | null>(null);
|
||||||
|
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
|
||||||
|
// init from localStorage
|
||||||
|
useEffect(() => {
|
||||||
|
const token = localStorage.getItem("access_token");
|
||||||
|
const saved = localStorage.getItem("user");
|
||||||
|
if (token && saved) {
|
||||||
|
try {
|
||||||
|
setAccessToken(token);
|
||||||
|
setUser(JSON.parse(saved));
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
setLoading(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const saveAuth = (data: { access_token: string; refresh_token: string; user: User }) => {
|
||||||
|
localStorage.setItem("access_token", data.access_token);
|
||||||
|
localStorage.setItem("refresh_token", data.refresh_token);
|
||||||
|
localStorage.setItem("user", JSON.stringify(data.user));
|
||||||
|
setAccessToken(data.access_token);
|
||||||
|
setUser(data.user);
|
||||||
|
};
|
||||||
|
|
||||||
|
const login = useCallback(async (email: string, password: string) => {
|
||||||
|
const res = await fetch(`${API_BASE}/api/auth/login`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ email, password }),
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
const err = await res.json().catch(() => ({}));
|
||||||
|
throw new Error(err.detail || "Login failed");
|
||||||
|
}
|
||||||
|
const data = await res.json();
|
||||||
|
saveAuth(data);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const register = useCallback(async (email: string, password: string, inviteCode: string) => {
|
||||||
|
const res = await fetch(`${API_BASE}/api/auth/register`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ email, password, invite_code: inviteCode }),
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
const err = await res.json().catch(() => ({}));
|
||||||
|
throw new Error(err.detail || "Registration failed");
|
||||||
|
}
|
||||||
|
const data = await res.json();
|
||||||
|
saveAuth(data);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const logout = useCallback(() => {
|
||||||
|
localStorage.removeItem("access_token");
|
||||||
|
localStorage.removeItem("refresh_token");
|
||||||
|
localStorage.removeItem("user");
|
||||||
|
setAccessToken(null);
|
||||||
|
setUser(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AuthContext.Provider value={{
|
||||||
|
user, accessToken, loading, login, register, logout,
|
||||||
|
isLoggedIn: !!user, isAdmin: user?.role === "admin",
|
||||||
|
}}>
|
||||||
|
{children}
|
||||||
|
</AuthContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useAuth() {
|
||||||
|
const ctx = useContext(AuthContext);
|
||||||
|
if (!ctx) throw new Error("useAuth must be used within AuthProvider");
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticated fetch helper
|
||||||
|
export async function authFetch(path: string, options: RequestInit = {}): Promise<Response> {
|
||||||
|
const token = localStorage.getItem("access_token");
|
||||||
|
const headers = new Headers(options.headers);
|
||||||
|
if (token) headers.set("Authorization", `Bearer ${token}`);
|
||||||
|
|
||||||
|
let res = await fetch(`${API_BASE}${path}`, { ...options, headers });
|
||||||
|
|
||||||
|
// try refresh on 401
|
||||||
|
if (res.status === 401) {
|
||||||
|
const refreshToken = localStorage.getItem("refresh_token");
|
||||||
|
if (refreshToken) {
|
||||||
|
const refreshRes = await fetch(`${API_BASE}/api/auth/refresh`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||||
|
});
|
||||||
|
if (refreshRes.ok) {
|
||||||
|
const data = await refreshRes.json();
|
||||||
|
localStorage.setItem("access_token", data.access_token);
|
||||||
|
localStorage.setItem("refresh_token", data.refresh_token);
|
||||||
|
headers.set("Authorization", `Bearer ${data.access_token}`);
|
||||||
|
res = await fetch(`${API_BASE}${path}`, { ...options, headers });
|
||||||
|
} else {
|
||||||
|
// refresh failed, clear auth
|
||||||
|
localStorage.removeItem("access_token");
|
||||||
|
localStorage.removeItem("refresh_token");
|
||||||
|
localStorage.removeItem("user");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user