diff --git a/app/rate_limit.py b/app/rate_limit.py new file mode 100644 index 0000000..ff74bd3 --- /dev/null +++ b/app/rate_limit.py @@ -0,0 +1,72 @@ +"""Tiny in-memory token-bucket rate limiter. + +Used for `/admin/login` only. The student endpoints intentionally have +no IP-based throttling because a campus deployment puts ~40 students +behind one or a few NAT IPs; rate-limiting at the IP level would +false-positive the entire class. + +For the admin login endpoint, IP-based limiting is appropriate: the +instructor logs in from a single device, and brute-force attempts +generally come from a few attacker IPs. Per-IP token bucket of +10 attempts / minute is generous for the legitimate user, hostile +to a guesser. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Optional + +from fastapi import Request + + +@dataclass(slots=True) +class _Bucket: + tokens: float + last_ts: float + + +class TokenBucket: + """Per-key (e.g., per-IP) token bucket. + + `capacity` tokens accrue at `rate_per_sec`. Each call to `take()` + consumes one token; if the bucket is empty, returns False. + + State is process-local. An app restart resets all buckets, which + is acceptable for the threat model (slows attackers; doesn't + permanently lock anyone out). + """ + + def __init__(self, capacity: int, refill_per_minute: float) -> None: + self.capacity = float(capacity) + self.rate_per_sec = refill_per_minute / 60.0 + self.buckets: dict[str, _Bucket] = {} + + def take(self, key: str) -> bool: + now = time.monotonic() + b = self.buckets.get(key) + if b is None: + b = _Bucket(tokens=self.capacity, last_ts=now) + self.buckets[key] = b + elapsed = now - b.last_ts + b.tokens = min(self.capacity, b.tokens + elapsed * self.rate_per_sec) + b.last_ts = now + if b.tokens < 1.0: + return False + b.tokens -= 1.0 + return True + + +def client_ip(request: Request) -> str: + """Best-effort client IP extraction. + + Caddy puts the real client in `X-Forwarded-For`; uvicorn behind a + 127.0.0.1-only proxy will see `request.client.host == "127.0.0.1"` + for every request, so trusting X-F-F is necessary for any per-client + behaviour at all. + """ + xff = request.headers.get("x-forwarded-for") + if xff: + return xff.split(",")[0].strip() + return request.client.host if request.client else "unknown" diff --git a/app/routes_admin.py b/app/routes_admin.py index 40a4407..7d28f10 100644 --- a/app/routes_admin.py +++ b/app/routes_admin.py @@ -22,17 +22,30 @@ from app import auth from app.config import Settings from app.csv_export import export_session_csv from app.models import AdminLoginRequest +from app.rate_limit import TokenBucket, client_ip from app.room import RoomManager def router(settings: Settings, rooms: RoomManager) -> APIRouter: api = APIRouter() + # Per-app instance so test apps get fresh state. + # 10 attempts/minute/IP — generous for the instructor, hostile to brute + # force without locking out the campus network on student endpoints + # (which are not rate-limited at all, see rate_limit.py). + login_bucket = TokenBucket(capacity=10, refill_per_minute=10) + def require_admin(request: Request) -> None: auth.require_admin_request(settings, request) @api.post("/admin/login") - async def login(body: AdminLoginRequest, response: Response): + async def login(body: AdminLoginRequest, request: Request, response: Response): + ip = client_ip(request) + if not login_bucket.take(ip): + raise HTTPException( + status_code=429, + detail="Too many login attempts; try again in a minute.", + ) if not auth.verify_admin_password(settings, body.password): raise HTTPException(status_code=401, detail="Invalid admin password") auth.set_admin_cookie(settings, response, auth.sign_admin(settings)) diff --git a/tests/stress/live_accuracy.mjs b/tests/stress/live_accuracy.mjs new file mode 100644 index 0000000..01b95ee --- /dev/null +++ b/tests/stress/live_accuracy.mjs @@ -0,0 +1,298 @@ +// Live-target accuracy + latency stress test. +// +// Drives a real classroom-sized run against an already-deployed server +// (single-session app, sid=main), via the public HTTPS endpoint, and +// measures three things: +// 1. Stress: N concurrent student WS connections + one instructor WS, +// driving the full quiz lifecycle. +// 2. Accuracy: every submitted answer that matches the correct option +// (revealed after question_closed) MUST score > 0; every other +// submission MUST score == 0. +// 3. Latency: per-submit round-trip time from `ws.send(submit)` to the +// receipt of the matching `submit_ack`. Reports p50 / p95 / p99. +// +// Each simulated student is a SEPARATE WebSocket with its own cookie; +// "batching" only refers to how the opening handshakes are staggered +// (groups of 8, 250ms apart) so the source IP doesn't ETIMEDOUT under +// 50-simultaneous-handshake pressure. Once open, all 50 connections +// stay simultaneously connected through the whole quiz. +// +// Usage: +// node live_accuracy.mjs [num_students=50] [correct_pct=0.6] + +import WebSocket from "ws"; + +const baseUrl = (process.argv[2] || "https://quiz.ahkhan.me").replace(/\/$/, ""); +const adminPassword = process.argv[3]; +const N = parseInt(process.argv[4] || "50", 10); +const CORRECT_PCT = parseFloat(process.argv[5] || "0.6"); +const SID = process.env.QUIZ_SID || "main"; + +if (!adminPassword) { + console.error("Usage: node live_accuracy.mjs [N] [correct_pct]"); + process.exit(2); +} + +const wsBase = baseUrl.replace(/^http/, "ws"); +const sleep = (ms) => new Promise((r) => setTimeout(r, ms)); + +// -- HTTP / cookie helpers ------------------------------------------------ + +function parseSetCookie(headerVal) { + if (!headerVal) return null; + const m = headerVal.match(/(qz_(?:admin|student))=[^;,]+/); + return m ? m[0] : null; +} + +async function httpJson(method, path, body, cookie) { + const headers = { Accept: "application/json" }; + if (body !== undefined) headers["Content-Type"] = "application/json"; + if (cookie) headers["Cookie"] = cookie; + const res = await fetch(`${baseUrl}${path}`, { + method, + headers, + body: body !== undefined ? JSON.stringify(body) : undefined, + }); + const setCookie = res.headers.get("set-cookie"); + let json = null; + try { json = await res.json(); } catch {} + return { status: res.status, body: json, cookie: parseSetCookie(setCookie) }; +} + +async function adminLogin() { + const r = await httpJson("POST", "/admin/login", { password: adminPassword }); + if (r.status !== 200) throw new Error(`admin login: ${r.status}`); + if (!r.cookie) throw new Error(`admin login: no Set-Cookie`); + return r.cookie; +} + +async function adminReset(adminCookie) { + const r = await httpJson("POST", "/admin/api/reset", undefined, adminCookie); + if (r.status !== 200) throw new Error(`reset: ${r.status}`); +} + +async function adminState(adminCookie) { + const r = await httpJson("GET", "/admin/api/state", undefined, adminCookie); + if (r.status !== 200) throw new Error(`state: ${r.status}`); + return r.body; +} + +async function joinStudent(sid, studentId, name) { + const r = await httpJson("POST", `/api/session/${sid}/join`, { student_id: studentId, name }); + if (r.status !== 200) throw new Error(`join ${studentId}: ${r.status}`); + if (!r.cookie) throw new Error(`join ${studentId}: no Set-Cookie`); + return r.cookie; +} + +// -- WS bookkeeping -------------------------------------------------------- + +// Build a Student object: opens the WS, attaches the message listener +// IMMEDIATELY (before connection establishes), so no incoming frame is +// ever lost to a listener-attach race. Returns a Promise that resolves +// to the bookkeeping struct once the lobby snapshot has arrived. +function makeStudent(sid, cookie, idx) { + const studentId = `S${String(idx).padStart(3, "0")}`; + const ws = new WebSocket(`${wsBase}/ws/student/${SID}`, { + headers: { Cookie: cookie }, + perMessageDeflate: false, + }); + const state = { + studentId, + ws, + submits: new Map(), + inLobby: false, + lastQuestionOpen: null, + closedSeen: new Map(), + ended: null, + closed: false, + }; + let resolveLobby; + const lobbyP = new Promise((r) => { resolveLobby = r; }); + ws.on("error", () => {}); + ws.on("close", () => { state.closed = true; }); + ws.on("message", (raw) => { + let m; + try { m = JSON.parse(raw.toString()); } catch { return; } + switch (m.type) { + case "state": + if (m.state === "lobby") { + state.inLobby = true; + resolveLobby(); + } + break; + case "question_open": + state.lastQuestionOpen = m; + break; + case "submit_ack": { + const sub = state.submits.get(m.question_idx); + if (sub) { sub.ackTs = performance.now(); sub.score = m.score; } + break; + } + case "question_closed": + state.closedSeen.set(m.question_idx, { + correct: m.correct, + your_answer: m.your_answer, + your_score: m.your_score, + }); + break; + case "session_ended": + state.ended = m; + break; + } + }); + return { state, lobbyP }; +} + +function openInstructorWS(adminCookie) { + const ws = new WebSocket(`${wsBase}/ws/instructor/${SID}`, { + headers: { Cookie: adminCookie }, + perMessageDeflate: false, + }); + const ev = { ws, lastQuestionOpen: null }; + let resolveOpen; + const openP = new Promise((r) => { resolveOpen = r; }); + ws.on("open", () => resolveOpen()); + ws.on("error", () => {}); + ws.on("message", (raw) => { + let m; try { m = JSON.parse(raw.toString()); } catch { return; } + if (m.type === "question_open") ev.lastQuestionOpen = m; + }); + return { ev, openP }; +} + +// -- Driver --------------------------------------------------------------- + +async function main() { + console.log(`[live_accuracy] target=${baseUrl} sid=${SID} N=${N} correct_pct=${CORRECT_PCT}`); + + console.log(`[stage 1] admin login + reset`); + const adminCookie = await adminLogin(); + await adminReset(adminCookie); + const initialState = await adminState(adminCookie); + const totalQs = initialState.pool_meta.question_count; + console.log(`[stage 1] ok — pool="${initialState.title}" Qs=${totalQs} score_fn=${initialState.pool_meta.score_fn}`); + + console.log(`[stage 2] joining ${N} students (HTTP /join, serial)`); + const cookies = []; + for (let i = 0; i < N; i++) { + cookies.push(await joinStudent(SID, `S${String(i).padStart(3, "0")}`, `Student ${i}`)); + if ((i + 1) % 10 === 0) process.stdout.write(` joined ${i + 1}/${N}\n`); + } + + console.log(`[stage 3] opening 1 admin + ${N} student WSs (batched)`); + const inst = openInstructorWS(adminCookie); + await inst.openP; + + // Open student WSs in batches of 8, 250ms apart. + const students = []; + const BATCH = 8, GAP_MS = 250; + for (let i = 0; i < cookies.length; i += BATCH) { + const slice = cookies.slice(i, i + BATCH); + const wave = slice.map((c, j) => makeStudent(SID, c, i + j)); + await Promise.all(wave.map((s) => s.lobbyP)); + students.push(...wave.map((s) => s.state)); + if (i + BATCH < cookies.length) await sleep(GAP_MS); + } + console.log(`[stage 3] ok — all ${students.length} students saw the lobby snapshot`); + + // -- Drive each question --- + console.log(`[stage 4] driving ${totalQs} questions via admin "next"`); + const correctByIdx = new Map(); + const allLatencies = []; + let totalSubmits = 0; + let accuracyOk = 0; + const accuracyMismatches = []; + + for (let qIdx = 0; qIdx < totalQs; qIdx++) { + // Trigger the question via admin + const beforeIdx = inst.ev.lastQuestionOpen?.question_idx ?? -1; + inst.ev.ws.send(JSON.stringify({ type: "next" })); + // Wait for the admin WS to see the new question_open; that confirms + // the broadcast went out. + const broadcastDeadline = Date.now() + 5000; + while ( + (inst.ev.lastQuestionOpen?.question_idx ?? -1) === beforeIdx && + Date.now() < broadcastDeadline + ) { + await sleep(20); + } + const opened = inst.ev.lastQuestionOpen; + if (!opened || opened.question_idx !== qIdx) { + throw new Error(`question_open for q=${qIdx} not received within 5s`); + } + const optionKeys = Object.keys(opened.options); + + // Each student picks an answer with random delay 50-1500ms + await Promise.all(students.map(async (s) => { + const answer = optionKeys[Math.floor(Math.random() * optionKeys.length)]; + const delay = 50 + Math.random() * 1450; + await sleep(delay); + const sub = { picked: answer, sentTs: performance.now() }; + s.submits.set(qIdx, sub); + try { s.ws.send(JSON.stringify({ type: "submit", question_idx: qIdx, answer })); } + catch (e) { sub.sendError = String(e); } + })); + + // Wait long enough for acks to arrive (latency p99 well under 1s on a healthy box) + await sleep(1500); + console.log(` q=${qIdx} sent; waiting for next loop`); + } + + // Final advance closes last question + ends session + console.log(`[stage 5] advancing past final → session_ended`); + inst.ev.ws.send(JSON.stringify({ type: "next" })); + // Give the broadcast a moment + collect closed snapshots + await sleep(2000); + + // Collect correct-answer map from any student who saw question_closed for each idx + for (let i = 0; i < totalQs; i++) { + for (const s of students) { + const c = s.closedSeen.get(i); + if (c) { correctByIdx.set(i, c.correct); break; } + } + } + + // -- Aggregate --- + for (const s of students) { + for (const [qidx, sub] of s.submits.entries()) { + totalSubmits++; + const correct = correctByIdx.get(qidx); + const wasCorrect = correct !== undefined && sub.picked === correct; + const scoreNonZero = sub.score !== undefined && sub.score > 0; + const scoreZero = sub.score !== undefined && sub.score === 0; + const accurate = (wasCorrect && scoreNonZero) || (!wasCorrect && scoreZero); + if (accurate) accuracyOk++; + else accuracyMismatches.push({ + student: s.studentId, qidx, + picked: sub.picked, correct, score: sub.score, + }); + if (sub.ackTs !== undefined) allLatencies.push(sub.ackTs - sub.sentTs); + } + } + + allLatencies.sort((a, b) => a - b); + const pct = (p) => allLatencies.length + ? allLatencies[Math.min(allLatencies.length - 1, Math.floor(p / 100 * allLatencies.length))] + : 0; + const mean = allLatencies.length + ? allLatencies.reduce((a, b) => a + b, 0) / allLatencies.length + : 0; + + console.log(`\n=== Results ===`); + console.log(`Submits : ${totalSubmits}`); + console.log(`Acks received : ${allLatencies.length} / ${totalSubmits} (${(100 * allLatencies.length / Math.max(1, totalSubmits)).toFixed(2)}%)`); + console.log(`Accuracy ok : ${accuracyOk} / ${totalSubmits} (${(100 * accuracyOk / Math.max(1, totalSubmits)).toFixed(2)}%)`); + console.log(`Accuracy fail : ${accuracyMismatches.length}`); + if (accuracyMismatches.length) { + console.log(`First few mismatches:`); + accuracyMismatches.slice(0, 5).forEach((d) => console.log(` `, d)); + } + console.log(`Latency (ms) : mean=${mean.toFixed(1)} p50=${pct(50).toFixed(1)} p95=${pct(95).toFixed(1)} p99=${pct(99).toFixed(1)} max=${(allLatencies[allLatencies.length-1] ?? 0).toFixed(1)}`); + console.log(`Correct answers : ${[...correctByIdx.entries()].map(([i, c]) => `Q${i+1}=${c}`).join(", ")}`); + + inst.ev.ws.close(); + for (const s of students) { try { s.ws.close(); } catch {} } + process.exit(accuracyMismatches.length === 0 ? 0 : 1); +} + +main().catch((err) => { console.error(err); process.exit(1); }); diff --git a/tests/stress/live_loop.sh b/tests/stress/live_loop.sh new file mode 100644 index 0000000..1757409 --- /dev/null +++ b/tests/stress/live_loop.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Long-running live accuracy + latency loop. +# Each cycle resets the live session, runs the full live_accuracy.mjs +# test, parses the summary, and appends a JSON line to runs/live_summary.jsonl. +# +# Run: +# ADMIN_PW=$(cat /tmp/quiz-admin-pw.txt) tmux new -d -s quiz_live \ +# 'cd /home/ameer/RD/Projects/Apps/quiz/tests/stress && \ +# ADMIN_PW="$ADMIN_PW" bash live_loop.sh' +# Stop: +# tmux send -t quiz_live C-c # graceful, then tmux kill-session -t quiz_live +# +# Tunables: +# BASE_URL - default https://quiz.ahkhan.me +# N - default 50 students +# GAP_S - seconds between cycles (default 60) +# ADMIN_PW - required, the live admin password + +set -uo pipefail +cd "$(dirname "$0")" + +BASE_URL="${BASE_URL:-https://quiz.ahkhan.me}" +N="${N:-50}" +GAP_S="${GAP_S:-60}" + +if [ -z "${ADMIN_PW:-}" ]; then + echo "ADMIN_PW must be set in env" >&2 + exit 2 +fi + +mkdir -p runs +SUM="runs/live_summary.jsonl" +LOG="runs/live-$(date -u +%Y%m%dT%H%M%SZ).log" + +echo "{\"event\":\"loop_start\",\"ts\":\"$(date -u +%FT%TZ)\",\"target\":\"$BASE_URL\",\"N\":$N,\"gap_s\":$GAP_S,\"log\":\"$LOG\"}" | tee -a "$SUM" + +cycle=0 +total_pass=0 +total_fail=0 +total_acks=0 +total_submits=0 + +trap 'echo "{\"event\":\"loop_stop\",\"ts\":\"$(date -u +%FT%TZ)\",\"cycles\":'$cycle'}" | tee -a "$SUM"; exit 0' INT TERM + +while true; do + cycle=$((cycle + 1)) + ts=$(date -u +%FT%TZ) + printf '\n----- live cycle %d (%s) -----\n' "$cycle" "$ts" | tee -a "$LOG" + + out=$(timeout 180 node live_accuracy.mjs "$BASE_URL" "$ADMIN_PW" "$N" 2>&1) + ec=$? + echo "$out" | tee -a "$LOG" >/dev/null + + pass=$(echo "$out" | sed -n 's/.*Accuracy ok *: \([0-9]*\) \/ \([0-9]*\).*/\1/p') + total=$(echo "$out" | sed -n 's/.*Accuracy ok *: \([0-9]*\) \/ \([0-9]*\).*/\2/p') + fail=$(echo "$out" | sed -n 's/.*Accuracy fail *: \([0-9]*\)/\1/p') + acks=$(echo "$out" | sed -n 's/.*Acks received *: \([0-9]*\) \/.*/\1/p') + p50=$(echo "$out" | sed -n 's/.*p50=\([0-9.]*\) .*/\1/p' | tail -1) + p95=$(echo "$out" | sed -n 's/.*p95=\([0-9.]*\) .*/\1/p' | tail -1) + p99=$(echo "$out" | sed -n 's/.*p99=\([0-9.]*\) .*/\1/p' | tail -1) + max=$(echo "$out" | sed -n 's/.*max=\([0-9.]*\)$/\1/p' | tail -1) + mean=$(echo "$out" | sed -n 's/.*mean=\([0-9.]*\) .*/\1/p' | tail -1) + + total_pass=$((total_pass + ${pass:-0})) + total_fail=$((total_fail + ${fail:-0})) + total_acks=$((total_acks + ${acks:-0})) + total_submits=$((total_submits + ${total:-0})) + + echo "{\"event\":\"cycle\",\"ts\":\"$ts\",\"cycle\":$cycle,\"exit\":$ec,\"submits\":${total:-0},\"acc_ok\":${pass:-0},\"acc_fail\":${fail:-0},\"acks\":${acks:-0},\"mean_ms\":${mean:-0},\"p50\":${p50:-0},\"p95\":${p95:-0},\"p99\":${p99:-0},\"max\":${max:-0},\"running_pass\":$total_pass,\"running_fail\":$total_fail}" | tee -a "$SUM" + + sleep "$GAP_S" +done diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py new file mode 100644 index 0000000..a90ad88 --- /dev/null +++ b/tests/test_rate_limit.py @@ -0,0 +1,27 @@ +from app.rate_limit import TokenBucket +from conftest import admin_login + + +def test_token_bucket_allows_then_denies_then_refills(): + bucket = TokenBucket(capacity=3, refill_per_minute=60) # 1 token/sec + assert bucket.take("ip1") is True + assert bucket.take("ip1") is True + assert bucket.take("ip1") is True + assert bucket.take("ip1") is False # exhausted + # Different key has its own bucket + assert bucket.take("ip2") is True + + +def test_admin_login_rate_limits_after_burst(client): + # Default config: 10 attempts/min/IP. Eleventh attempt should 429. + # Exhaust on wrong-password attempts so the test doesn't depend on + # the right password being unknown. + for _ in range(10): + response = client.post("/admin/login", json={"password": "wrong"}) + assert response.status_code == 401 + # Eleventh attempt: throttled + response = client.post("/admin/login", json={"password": "wrong"}) + assert response.status_code == 429 + # Even a correct password is throttled until the bucket refills. + response = client.post("/admin/login", json={"password": "admin-pass"}) + assert response.status_code == 429