diff --git a/app/pool.py b/app/pool.py index 0f89a46..1ddd2c9 100644 --- a/app/pool.py +++ b/app/pool.py @@ -1 +1,120 @@ """Question pool validation.""" + +from __future__ import annotations + +import json +from typing import Any + +from app.scoring import SCORE_FNS + +OPTION_KEYS = {"A", "B", "C", "D"} + + +class PoolValidationError(ValueError): + pass + + +def parse_pool_json(pool_json: str | dict[str, Any]) -> dict[str, Any]: + if isinstance(pool_json, str): + try: + data = json.loads(pool_json) + except json.JSONDecodeError as exc: + raise PoolValidationError(f"Invalid JSON: {exc.msg}") from exc + else: + data = pool_json + return validate_pool(data) + + +def validate_pool(data: dict[str, Any]) -> dict[str, Any]: + if not isinstance(data, dict): + raise PoolValidationError("Pool must be a JSON object") + title = data.get("title") + if not isinstance(title, str) or not title.strip(): + raise PoolValidationError("Pool title is required") + score_fn = data.get("score_fn", "linear_decay") + if score_fn not in SCORE_FNS: + raise PoolValidationError(f"Unknown score function: {score_fn}") + time_limit_default = _positive_int(data.get("time_limit_default", 60), "time_limit_default") + questions = data.get("questions") + if not isinstance(questions, list) or not questions: + raise PoolValidationError("Pool must include at least one question") + + normalized_questions: list[dict[str, Any]] = [] + for index, question in enumerate(questions): + if not isinstance(question, dict): + raise PoolValidationError(f"Question {index} must be an object") + normalized_questions.append(_validate_question(question, index, time_limit_default)) + + return { + "title": title.strip(), + "score_fn": score_fn, + "time_limit_default": time_limit_default, + "questions": normalized_questions, + } + + +def question_count(pool: dict[str, Any]) -> int: + return len(pool["questions"]) + + +def get_question(pool: dict[str, Any], question_idx: int) -> dict[str, Any]: + try: + return pool["questions"][question_idx] + except IndexError as exc: + raise PoolValidationError("Question index out of range") from exc + + +def question_time_limit(pool: dict[str, Any], question_idx: int) -> int: + question = get_question(pool, question_idx) + return int(question.get("time_limit") or pool["time_limit_default"]) + + +def public_question_payload(pool: dict[str, Any], question_idx: int) -> dict[str, Any]: + question = get_question(pool, question_idx) + return { + "question_idx": question_idx, + "text": question["text"], + "options": question["options"], + "time_limit": question_time_limit(pool, question_idx), + } + + +def _validate_question(question: dict[str, Any], index: int, default_limit: int) -> dict[str, Any]: + qid = question.get("id") + if not isinstance(qid, str) or not qid.strip(): + raise PoolValidationError(f"Question {index} id is required") + text = question.get("text") + if not isinstance(text, str) or not text.strip(): + raise PoolValidationError(f"Question {index} text is required") + options = question.get("options") + if not isinstance(options, dict) or set(options) != OPTION_KEYS: + raise PoolValidationError(f"Question {index} options must be exactly A, B, C, D") + for key, value in options.items(): + if not isinstance(value, str) or not value.strip(): + raise PoolValidationError(f"Question {index} option {key} is required") + correct = question.get("correct") + if correct not in OPTION_KEYS: + raise PoolValidationError(f"Question {index} correct must be one of A, B, C, D") + + normalized = { + "id": qid.strip(), + "text": text.strip(), + "options": {key: options[key].strip() for key in sorted(OPTION_KEYS)}, + "correct": correct, + } + if "time_limit" in question and question["time_limit"] is not None: + normalized["time_limit"] = _positive_int(question["time_limit"], f"Question {index} time_limit") + else: + normalized["time_limit"] = default_limit + explanation = question.get("explanation") + if explanation is not None: + if not isinstance(explanation, str): + raise PoolValidationError(f"Question {index} explanation must be text") + normalized["explanation"] = explanation.strip() + return normalized + + +def _positive_int(value: Any, label: str) -> int: + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + raise PoolValidationError(f"{label} must be a positive integer") + return value diff --git a/app/scoring.py b/app/scoring.py index f4be998..eadac44 100644 --- a/app/scoring.py +++ b/app/scoring.py @@ -1 +1,40 @@ """Score functions.""" + +from __future__ import annotations + +from collections.abc import Callable + +ScoreFn = Callable[[bool, int, int], int] +SCORE_FNS: dict[str, ScoreFn] = {} + + +def register(name: str) -> Callable[[ScoreFn], ScoreFn]: + def decorator(func: ScoreFn) -> ScoreFn: + SCORE_FNS[name] = func + return func + + return decorator + + +@register("linear_decay") +def linear_decay(correct: bool, elapsed_ms: int, time_limit_ms: int) -> int: + if not correct: + return 0 + elapsed_ms = max(0, min(elapsed_ms, time_limit_ms)) + return round(1000 * (1 - 0.5 * elapsed_ms / time_limit_ms)) + + +@register("flat") +def flat(correct: bool, elapsed_ms: int, time_limit_ms: int) -> int: + return 1000 if correct else 0 + + +@register("exponential_decay") +def exponential_decay(correct: bool, elapsed_ms: int, time_limit_ms: int) -> int: + if not correct: + return 0 + import math + + elapsed_ms = max(0, min(elapsed_ms, time_limit_ms)) + decay = math.exp(-2 * elapsed_ms / time_limit_ms) + return round(1000 * (0.5 + 0.5 * decay))