diff --git a/.gitignore b/.gitignore index 6daac89..81cc89c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ data/ # Agent/planning docs docs/ +.pi/ diff --git a/scripts/query_logs.py b/scripts/query_logs.py new file mode 100755 index 0000000..e9820e0 --- /dev/null +++ b/scripts/query_logs.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Query Pydantic Logfire logs via SQL API. + +Usage: + python scripts/query_logs.py errors --minutes 30 + python scripts/query_logs.py warnings --limit 20 + python scripts/query_logs.py slow --threshold 5000 + python scripts/query_logs.py user --user-id 12345 + python scripts/query_logs.py group --group-id -1001234567 + python scripts/query_logs.py sql "SELECT * FROM records LIMIT 10" +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from datetime import UTC, datetime, timedelta + +import requests +from dotenv import load_dotenv + +DEFAULT_API_URL = "https://logfire-us.pydantic.dev/v2/query" +EU_API_URL = "https://logfire-eu.pydantic.dev/v2/query" +DEFAULT_LIMIT = 50 +DEFAULT_MINUTES = 30 + +QUERY_TEMPLATES: dict[str, str] = { + "errors": """\ +SELECT start_timestamp, duration, message, trace_id, is_exception, + exception_message, attributes +FROM records +WHERE is_exception = true +ORDER BY start_timestamp DESC +LIMIT {limit}""", + "warnings": """\ +SELECT start_timestamp, duration, message, trace_id, level, attributes +FROM records +WHERE level = 'warn' +ORDER BY start_timestamp DESC +LIMIT {limit}""", + "slow": """\ +SELECT start_timestamp, duration, message, trace_id, attributes +FROM records +WHERE duration > {threshold} +ORDER BY duration DESC +LIMIT {limit}""", + "user": """\ +SELECT start_timestamp, duration, message, trace_id, level, attributes +FROM records +WHERE attributes->>'user_id' = '{user_id}' +ORDER BY start_timestamp DESC +LIMIT {limit}""", + "group": """\ +SELECT start_timestamp, duration, message, trace_id, level, attributes +FROM records +WHERE attributes->>'group_id' = '{group_id}' +ORDER BY start_timestamp DESC +LIMIT {limit}""", +} + +def get_config() -> tuple[str, str]: + """Read config from environment variables. + + Loads .env file if present. Reads LOGFIRE_READ_TOKEN and LOGFIRE_API_URL. + + Returns: + Tuple of (api_url, read_token). + + Raises: + SystemExit: If LOGFIRE_READ_TOKEN is not set. + """ + token = os.environ.get("LOGFIRE_READ_TOKEN") or os.environ.get("LOGFIRE_TOKEN") + if not token: + print( + "Error: LOGFIRE_READ_TOKEN not set.\n" + "Create a read token at https://logfire.pydantic.dev → Project Settings → Read Tokens\n" + "Then add LOGFIRE_READ_TOKEN=your_token_here to .env", + file=sys.stderr, + ) + sys.exit(1) + + api_url = os.environ.get("LOGFIRE_API_URL") + if not api_url: + api_url = EU_API_URL if token.startswith("pylf_v1_eu") else DEFAULT_API_URL + return api_url, token + +def query_logfire(api_url: str, token: str, sql: str, min_timestamp: str | None = None) -> list[dict]: + """Execute a SQL query against Logfire API. + + Args: + api_url: Logfire query endpoint URL. + token: Read token for authentication. + sql: SQL query to execute. + min_timestamp: ISO format timestamp for time range filter. + + Returns: + List of record dicts. + + Raises: + SystemExit: On API error. + """ + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + payload: dict[str, str | dict] = {"sql": sql} + payload["min_timestamp"] = min_timestamp or "2020-01-01T00:00:00+00:00" + + try: + response = requests.post(api_url, headers=headers, json=payload, timeout=30) + except requests.RequestException as e: + print(f"Network error: {e}", file=sys.stderr) + sys.exit(1) + + if response.status_code != 200: + print(f"API error {response.status_code}: {response.text}", file=sys.stderr) + sys.exit(1) + + try: + data = response.json() + except Exception: + print(f"API returned non-JSON response: {response.text[:500]}", file=sys.stderr) + sys.exit(1) + if isinstance(data, dict) and "data" in data: + return data["data"] + if isinstance(data, list): + return data + return [] + +def format_text(records: list[dict]) -> str: + """Format records as human-readable text. + + Args: + records: List of record dicts from API. + + Returns: + Formatted text string. + """ + if not records: + return "No records found." + + lines: list[str] = [] + for r in records: + ts = r.get("start_timestamp", "unknown") + if isinstance(ts, str) and len(ts) > 19: + ts = ts[:19] + duration = r.get("duration", "") + if duration and isinstance(duration, (int, float)): + duration = f"{duration:.0f}ms" + level = r.get("level", "") + trace_id = r.get("trace_id", "") + message = r.get("message", "") + is_exception = r.get("is_exception", False) + exception_message = r.get("exception_message", "") + + prefix = "ERROR" if is_exception else (level.upper() if level else "INFO") + duration_str = f" ({duration})" if duration else "" + trace_str = f" trace:{trace_id[:8]}" if trace_id else "" + + lines.append(f"[{ts}] {prefix}{duration_str}{trace_str}") + if message: + lines.append(f" {message}") + if exception_message: + lines.append(f" Exception: {exception_message}") + lines.append("") + + return "\n".join(lines).rstrip() + +def build_query(command: str, args: argparse.Namespace) -> tuple[str, str | None]: + """Build SQL query and min_timestamp from command and arguments. + + Args: + command: Query command name (errors, warnings, slow, user, group, sql). + args: Parsed CLI arguments. + + Returns: + Tuple of (sql_query, min_timestamp_iso). + """ + min_ts = (datetime.now(tz=UTC) - timedelta(minutes=args.minutes)).isoformat() + + if command == "sql": + sql = args.query + if "limit" not in sql.lower(): + sql = f"{sql.rstrip(';')} LIMIT {args.limit}" + return sql, None + + template = QUERY_TEMPLATES[command] + params: dict[str, int | str] = { + "limit": args.limit, + } + + if command == "slow": + params["threshold"] = args.threshold + elif command == "user": + params["user_id"] = args.user_id + elif command == "group": + params["group_id"] = args.group_id + + return template.format(**params), min_ts + +def _add_common_args(parser: argparse.ArgumentParser) -> None: + """Add common flags (--json, --limit, --minutes) to a parser.""" + parser.add_argument( + "--json", + action="store_true", + dest="json_output", + help="Output as JSON instead of text", + ) + parser.add_argument( + "--limit", + type=int, + default=DEFAULT_LIMIT, + help=f"Max results (default: {DEFAULT_LIMIT})", + ) + parser.add_argument( + "--minutes", + type=int, + default=DEFAULT_MINUTES, + help=f"Time window in minutes (default: {DEFAULT_MINUTES})", + ) + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse CLI arguments. + + Args: + argv: Argument list (defaults to sys.argv[1:]). + + Returns: + Parsed namespace. + """ + parser = argparse.ArgumentParser( + description="Query Pydantic Logfire logs via SQL API", + ) + + subparsers = parser.add_subparsers(dest="command", required=True) + + errors_p = subparsers.add_parser("errors", help="Error/exception logs") + _add_common_args(errors_p) + + warnings_p = subparsers.add_parser("warnings", help="Warning-level logs") + _add_common_args(warnings_p) + + slow_parser = subparsers.add_parser("slow", help="Slow spans") + _add_common_args(slow_parser) + slow_parser.add_argument( + "--threshold", + type=int, + default=5000, + help="Duration threshold in ms (default: 5000)", + ) + + user_parser = subparsers.add_parser("user", help="Activity by user ID") + _add_common_args(user_parser) + user_parser.add_argument("--user-id", required=True, type=int, help="Telegram user ID") + + group_parser = subparsers.add_parser("group", help="Activity by group ID") + _add_common_args(group_parser) + group_parser.add_argument("--group-id", required=True, type=int, help="Telegram group ID") + + sql_parser = subparsers.add_parser("sql", help="Free-form SQL query") + _add_common_args(sql_parser) + sql_parser.add_argument("query", help="SQL query to execute") + + return parser.parse_args(argv) + +def main(argv: list[str] | None = None) -> None: + """Main entry point. + + Args: + argv: Argument list (defaults to sys.argv[1:]). + """ + load_dotenv() + args = parse_args(argv) + api_url, token = get_config() + sql, min_ts = build_query(args.command, args) + records = query_logfire(api_url, token, sql, min_ts) + + if args.json_output: + print(json.dumps(records, indent=2, default=str)) + else: + print(format_text(records)) + +if __name__ == "__main__": + main() diff --git a/tests/test_query_logs.py b/tests/test_query_logs.py new file mode 100644 index 0000000..e589332 --- /dev/null +++ b/tests/test_query_logs.py @@ -0,0 +1,268 @@ +"""Tests for the Logfire query CLI tool.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from scripts.query_logs import ( + build_query, + format_text, + get_config, + main, + parse_args, + query_logfire, +) + +class TestGetConfig: + """get_config reads from environment variables.""" + + def test_returns_token_and_url(self): + """Returns token and URL from env.""" + with patch.dict("os.environ", {"LOGFIRE_READ_TOKEN": "test_token"}): + url, token = get_config() + assert token == "test_token" + assert "logfire" in url + + def test_custom_url(self): + """Custom URL overrides default.""" + with patch.dict("os.environ", {"LOGFIRE_READ_TOKEN": "tok", "LOGFIRE_API_URL": "http://custom"}): + url, token = get_config() + assert url == "http://custom" + + def test_missing_token_exits(self): + """Exits with error when LOGFIRE_READ_TOKEN not set.""" + with patch.dict("os.environ", {}, clear=True): + with pytest.raises(SystemExit): + get_config() + +class TestBuildQuery: + """build_query generates correct SQL from commands.""" + + def test_errors_query(self): + """Errors command generates exception filter.""" + args = parse_args(["errors", "--minutes", "60", "--limit", "25"]) + sql, min_ts = build_query("errors", args) + assert "is_exception = true" in sql + assert "LIMIT 25" in sql + assert min_ts is not None + + def test_warnings_query(self): + """Warnings command generates level filter.""" + args = parse_args(["warnings"]) + sql, min_ts = build_query("warnings", args) + assert "level = 'warn'" in sql + assert min_ts is not None + + def test_slow_query_with_threshold(self): + """Slow command includes duration threshold.""" + args = parse_args(["slow", "--threshold", "3000"]) + sql, min_ts = build_query("slow", args) + assert "duration > 3000" in sql + assert min_ts is not None + + def test_user_query(self): + """User command filters by user_id.""" + args = parse_args(["user", "--user-id", "12345"]) + sql, min_ts = build_query("user", args) + assert "'12345'" in sql + assert "user_id" in sql + assert min_ts is not None + + def test_group_query(self): + """Group command filters by group_id.""" + args = parse_args(["group", "--group-id", "-100123"]) + sql, min_ts = build_query("group", args) + assert "'-100123'" in sql + assert "group_id" in sql + assert min_ts is not None + + def test_sql_query_passthrough(self): + """SQL command passes query through.""" + args = parse_args(["sql", "SELECT * FROM records LIMIT 10"]) + sql, min_ts = build_query("sql", args) + assert sql == "SELECT * FROM records LIMIT 10" + assert min_ts is None # sql command uses default 2020-01-01 + + def test_sql_query_adds_limit_if_missing(self): + """SQL command adds LIMIT if not in query.""" + args = parse_args(["sql", "SELECT * FROM records"]) + sql, min_ts = build_query("sql", args) + assert "LIMIT 50" in sql + assert min_ts is None # sql command uses default 2020-01-01 + + def test_sql_query_preserves_existing_limit(self): + """SQL command keeps existing LIMIT.""" + args = parse_args(["sql", "SELECT * FROM records LIMIT 10"]) + sql, min_ts = build_query("sql", args) + assert "LIMIT 10" in sql + assert sql.count("LIMIT") == 1 + assert min_ts is None # sql command uses default 2020-01-01 + +class TestFormatText: + """format_text produces readable output.""" + + def test_empty_records(self): + """Empty records returns 'No records found'.""" + assert format_text([]) == "No records found." + + def test_single_error_record(self): + """Error record shows ERROR prefix and exception.""" + records = [ + { + "start_timestamp": "2026-06-11T14:32:01Z", + "duration": 523, + "message": "Failed to restrict user", + "trace_id": "abc123def456", + "is_exception": True, + "exception_message": "BadRequest: User not found", + } + ] + output = format_text(records) + assert "[2026-06-11T14:32:01]" in output + assert "ERROR" in output + assert "(523ms)" in output + assert "trace:abc123d" in output + assert "Failed to restrict user" in output + assert "BadRequest: User not found" in output + + def test_warning_record(self): + """Warning record shows WARN prefix.""" + records = [{"level": "warn", "message": "Slow response"}] + output = format_text(records) + assert "WARN" in output + + def test_multiple_records(self): + """Multiple records are separated by blank lines.""" + records = [{"message": "a"}, {"message": "b"}] + output = format_text(records) + assert "a" in output + assert "b" in output + +class TestQueryLogfire: + """query_logfire makes API calls.""" + + @patch("scripts.query_logs.requests.post") + def test_successful_query(self, mock_post): + """Returns data from successful API response.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = [{"message": "test"}] + mock_post.return_value = mock_resp + + result = query_logfire("http://api", "token", "SELECT 1") + assert result == [{"message": "test"}] + + @patch("scripts.query_logs.requests.post") + def test_wrapped_response(self, mock_post): + """Handles response with 'data' wrapper.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"data": [{"message": "test"}]} + mock_post.return_value = mock_resp + + result = query_logfire("http://api", "token", "SELECT 1") + assert result == [{"message": "test"}] + + @patch("scripts.query_logs.requests.post") + def test_api_error_exits(self, mock_post): + """Exits on non-200 status.""" + mock_resp = MagicMock() + mock_resp.status_code = 401 + mock_resp.text = "Unauthorized" + mock_post.return_value = mock_resp + + with pytest.raises(SystemExit): + query_logfire("http://api", "bad_token", "SELECT 1") + + @patch("scripts.query_logs.requests.post") + def test_network_error_exits(self, mock_post): + """Exits on network error.""" + import requests as req + mock_post.side_effect = req.ConnectionError("timeout") + + with pytest.raises(SystemExit): + query_logfire("http://api", "token", "SELECT 1") + +class TestParseArgs: + """parse_args handles CLI arguments.""" + + def test_errors_default(self): + """Errors command with defaults.""" + args = parse_args(["errors"]) + assert args.command == "errors" + assert args.minutes == 30 + assert args.limit == 50 + + def test_custom_minutes_and_limit(self): + """Custom minutes and limit after subcommand.""" + args = parse_args(["errors", "--minutes", "120", "--limit", "100"]) + assert args.minutes == 120 + assert args.limit == 100 + + def test_json_flag(self): + """--json flag sets json_output.""" + args = parse_args(["errors", "--json"]) + assert args.json_output is True + + def test_slow_threshold(self): + """Slow command accepts threshold.""" + args = parse_args(["slow", "--threshold", "3000"]) + assert args.threshold == 3000 + + def test_user_requires_id(self): + """User command requires --user-id.""" + with pytest.raises(SystemExit): + parse_args(["user"]) + + def test_group_requires_id(self): + """Group command requires --group-id.""" + with pytest.raises(SystemExit): + parse_args(["group"]) + + def test_sql_requires_query(self): + """SQL command requires query argument.""" + with pytest.raises(SystemExit): + parse_args(["sql"]) + + def test_command_required(self): + """Command is required.""" + with pytest.raises(SystemExit): + parse_args([]) + + def test_user_id_rejects_non_int(self): + """--user-id rejects non-numeric input (SQL injection prevention).""" + with pytest.raises(SystemExit): + parse_args(["user", "--user-id", "abc"]) + + def test_group_id_rejects_non_int(self): + """--group-id rejects non-numeric input (SQL injection prevention).""" + with pytest.raises(SystemExit): + parse_args(["group", "--group-id", "abc"]) + + def test_threshold_rejects_non_int(self): + """--threshold rejects non-numeric input.""" + with pytest.raises(SystemExit): + parse_args(["slow", "--threshold", "abc"]) + +class TestMainIntegration: + """main() orchestrates query flow.""" + + @patch("scripts.query_logs.query_logfire") + @patch("scripts.query_logs.get_config", return_value=("http://api", "token")) + def test_main_text_output(self, mock_config, mock_query, capsys): + """main() prints formatted text.""" + mock_query.return_value = [{"message": "test error", "is_exception": True}] + main(["errors"]) + output = capsys.readouterr().out + assert "ERROR" in output + assert "test error" in output + + @patch("scripts.query_logs.query_logfire") + @patch("scripts.query_logs.get_config", return_value=("http://api", "token")) + def test_main_json_output(self, mock_config, mock_query, capsys): + """main() --json prints JSON.""" + mock_query.return_value = [{"message": "test"}] + main(["errors", "--json"]) + output = capsys.readouterr().out + data = __import__("json").loads(output) + assert data == [{"message": "test"}]