Coverage for chatgpt_proxy / auth / auth.py: 94%
116 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 16:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 16:19 +0000
1# MIT License
2#
3# Copyright (c) 2025 Tuomo Kriikkula
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23"""Provides authentication and authorization for the chatgpt_proxy API server."""
25import datetime
26import hashlib
27import ipaddress
28import os
29from functools import wraps
30from http import HTTPStatus
31from inspect import isawaitable
32from secrets import compare_digest
33from typing import Callable
35import asyncpg
36import httpx
37import jwt
38import sanic
40from chatgpt_proxy.db import pool_acquire
41from chatgpt_proxy.db import queries
42from chatgpt_proxy.log import logger
43from chatgpt_proxy.steam import steam
44from chatgpt_proxy.types import Request
45from chatgpt_proxy.utils import get_remote_addr
47jwt_issuer = "ChatGPTProxy"
48jwt_audience = "ChatGPTProxy"
50_steam_web_api_key: str | None = None
def load_config() -> None:
    """(Re)load this module's configuration from the process environment.

    Sets the module-level ``_steam_web_api_key`` to the value of the
    ``STEAM_WEB_API_KEY`` environment variable, or to ``None`` when that
    variable is not set.
    """
    global _steam_web_api_key
    env = os.environ
    _steam_web_api_key = env.get("STEAM_WEB_API_KEY")


# Read the configuration once, at import time.
load_config()
# Cache lifetimes in seconds (floats), referenced by the currently
# commented-out aiocache decorators on is_real_game_server and
# validate_db_token.
ttl_is_real_game_server = datetime.timedelta(hours=1).total_seconds()
ttl_validate_db_token = datetime.timedelta(minutes=5).total_seconds()
def is_real_game_server_key_builder(*args, **kwargs) -> str:
    """Build a cache key for a call to ``is_real_game_server``.

    NOTE: this function is specific to is_real_game_server and has to be
    kept in sync manually with its real signature!

    The ``pg_pool`` and ``http_client`` keyword arguments are resource
    handles (a DB pool and an HTTP client), so they are excluded from the
    key. The remaining keyword arguments are sorted by name, so the key does
    not depend on the order they were passed in.
    """
    excluded = ("pg_pool", "http_client")
    key_kwargs = sorted(
        (name, value) for name, value in kwargs.items() if name not in excluded
    )
    module_name = is_real_game_server.__module__ or ""
    return f"{module_name}{is_real_game_server.__name__}{args!s}{key_kwargs!s}"
# TODO: waiting for updated aiocache + valkey-glide support on Windows!
#   - In the meanwhile, only use memory cache!
#   - See pyproject.toml for more details!
# TODO: FIXME: IMPORTANT:
#   ENABLING THIS BREAKS TESTING! CACHE SHOULD BE CLEARED IN CERTAIN TESTS!
#   FOR THAT, WE NEED TO USE app_cache, BUT WE CANNOT DO THAT UNTIL
#   aiocache IS PATCHED! SEE THE ABOVE NOTE!
# @aiocache.cached(
#     # cache=app_cache,
#     cache=aiocache.Cache.MEMORY,
#     ttl=ttl_is_real_game_server,
#     key_builder=is_real_game_server_key_builder,
#     # NOTE: Only cache the result if the server was successfully verified.
#     skip_cache_func=lambda x: x is False,
# )
async def is_real_game_server(
    *,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    pg_pool: asyncpg.Pool,
    http_client: httpx.AsyncClient,
) -> bool:
    """Check via the Steam Web API whether an RS2 server runs at the given address.

    Queries ``steam.server_list_url`` with a ``\\gamedir\\rs2\\gameaddr\\IP:PORT``
    filter and the module's Steam Web API key.

    Args:
        game_server_address: IPv4 address of the game server.
        game_server_port: Port of the game server.
        pg_pool: Database pool, passed through to ``steam.web_api_request``.
        http_client: HTTP client used for the Steam Web API request.

    Returns:
        True if the response lists at least one server, otherwise False.
        Verification is best-effort: any ``Exception`` (HTTP error status,
        network failure, unexpected JSON shape, ...) is logged at debug level
        and reported as False instead of being raised.
    """
    try:
        logger.debug("checking if {}:{} is a real RS2 server", game_server_address, game_server_port)

        resp = await steam.web_api_request(
            http_client=http_client,
            pg_pool=pg_pool,
            url=steam.server_list_url,
            params={
                "key": _steam_web_api_key,
                "filter": f"\\gamedir\\rs2\\gameaddr\\{game_server_address}:{game_server_port}",
                # TODO: limit param would speed up things, or would it?
            },
        )
        resp.raise_for_status()

        # Steam may omit "servers" entirely when nothing matches the filter.
        # Treat a missing key like an empty list, so this is reported as
        # "no servers" instead of surfacing as a KeyError in the handler below.
        servers = resp.json()["response"].get("servers")

        if not servers:
            logger.debug("Steam Web API returned no servers for {}:{}",
                         game_server_address, game_server_port)
            return False

        return True
    except Exception as e:
        # Deliberately broad: failing to verify must never crash the caller,
        # it simply means the server is not considered verified.
        logger.debug("unable to verify {}:{} is a real RS2 server: {}: {}",
                     game_server_address, game_server_port, type(e).__name__, e)
        return False
131def validate_db_token_key_builder(*args, **kwargs) -> str:
132 """NOTE: this function is specific to validate_db_token!"""
134 # NOTE: has to be kept in sync manually with
135 # the real signature of validate_db_token!
136 args = args[:-1]
138 ordered_kwargs = sorted(kwargs.items())
139 return (
140 f"{validate_db_token.__module__ or ""}"
141 f"{validate_db_token.__name__}"
142 f"{str(args)}"
143 f"{str(ordered_kwargs)}"
144 )
# @aiocache.cached(
#     # cache=app_cache,
#     cache=aiocache.Cache.MEMORY,
#     ttl=ttl_validate_db_token,
#     key_builder=validate_db_token_key_builder,
#     # NOTE: Only cache the result if the server was successfully verified.
#     skip_cache_func=lambda x: x is False,
# )
async def validate_db_token(
    request_token_hash: bytes,
    addr: ipaddress.IPv4Address,
    port: int,
    pg_pool: asyncpg.Pool,
) -> bool:
    """Compare a request token's hash with the API key hash stored for a server.

    Looks up the API key row for game server ``addr:port`` and compares its
    ``api_key_hash`` column with ``request_token_hash`` using
    ``secrets.compare_digest`` (constant-time comparison).

    Args:
        request_token_hash: Hash digest of the token presented by the client.
        addr: Game server IPv4 address.
        port: Game server port.
        pg_pool: Pool the lookup connection is acquired from.

    Returns:
        True when a row exists and the hashes match, otherwise False.
    """
    # TODO: we should cache this in memory (LRU) or diskcache.
    # TODO: if an API key is deleted from the database, the cache
    #   for said API key should also be cleared!

    async with pool_acquire(pg_pool) as conn:
        key_row = await queries.select_game_server_api_key(
            conn=conn,
            game_server_address=addr,
            game_server_port=port,
        )

    # NOTE(review): this logs the entire row, including the stored hash;
    # consider trimming what is logged here.
    logger.debug("api_key: {}", key_row)

    if not key_row:
        logger.debug("JWT validation failed: no API key for {}:{}", addr, port)
        return False

    stored_hash: bytes = key_row["api_key_hash"]
    matches = compare_digest(request_token_hash, stored_hash)
    if not matches:
        logger.debug("JWT validation failed: stored hash does not match token hash")
    return matches
def _parse_jwt_subject(sub: object) -> tuple[ipaddress.IPv4Address, int] | None:
    """Parse a JWT ``sub`` claim of the form ``"<IPv4 address>:<port>"``.

    Returns the ``(address, port)`` pair, or None when ``sub`` is not a
    string with exactly one colon, a valid IPv4 address and an integer port
    in the range 0-65535.
    """
    if not isinstance(sub, str):
        return None
    try:
        a, p = sub.split(":")
        addr = ipaddress.IPv4Address(a)
        port = int(p)
    except ValueError:
        # Covers a wrong number of ":" parts (unpacking), AddressValueError
        # (a ValueError subclass) and non-integer ports.
        return None
    if not 0 <= port <= 65535:
        return None
    return addr, port


async def check_token(request: Request, pg_pool: asyncpg.Pool) -> bool:
    """Validate the request's JWT and bind it to a registered game server.

    The token (``request.token``) must:

    * decode as an HS256 JWT signed with ``app.config.SECRET`` that carries
      the ``exp``, ``iss``, ``sub`` and ``aud`` claims, with issuer and
      audience matching ``app.config.JWT_ISSUER`` / ``JWT_AUDIENCE``;
    * have a ``sub`` claim of the form ``"<IPv4 address>:<port>"`` whose
      address equals the request's remote address;
    * have a SHA-256 digest equal to the API key hash stored for that
      address and port;
    * if a Steam Web API key is configured, refer to a server that the Steam
      Web API lists as an RS2 server.

    On success the token's address and port are stored in
    ``request.ctx.jwt_game_server_address`` and
    ``request.ctx.jwt_game_server_port``.

    Returns:
        True if every check passes, otherwise False. Validation failures
        (including a malformed ``sub`` claim) are logged at debug level and
        reported as False rather than raised.
    """
    if not request.token:
        logger.debug("JWT validation failed: no token")
        return False

    try:
        token = jwt.decode(
            jwt=request.token,
            key=request.app.config.SECRET,
            options={"require": ["exp", "iss", "sub", "aud"]},
            algorithms=["HS256"],
            audience=request.app.config.JWT_AUDIENCE,
            issuer=request.app.config.JWT_ISSUER,
        )
    except jwt.exceptions.PyJWTError as e:
        logger.debug("JWT validation failed: {}: {}", type(e).__name__, e)
        return False

    # JWT subject should be IP:port. A malformed subject is a validation
    # failure, not an unhandled ValueError (which would surface as a 500).
    sub = token["sub"]
    parsed = _parse_jwt_subject(sub)
    if parsed is None:
        logger.debug("JWT validation failed: invalid sub claim: {}", sub)
        return False
    addr, port = parsed
    logger.debug("token addr:port: {}:{}", addr, port)

    # Small extra step of security since we can't use HTTPS.
    # In any case, this is not really secure, but better than nothing.
    client_addr = get_remote_addr(request)
    logger.debug("client_addr: {}", client_addr)
    if client_addr != addr:
        logger.debug("JWT validation failed: (client_addr != addr): {} != {}", client_addr, addr)
        return False

    req_token_hash = hashlib.sha256(request.token.encode("utf-8")).digest()
    if not await validate_db_token(
        request_token_hash=req_token_hash,
        addr=addr,
        port=port,
        pg_pool=pg_pool,
    ):
        return False

    if _steam_web_api_key is None:
        logger.warning("Steam Web API key is not set, "
                       "unable to verify server is a real RS2 server")
    else:
        ok = await is_real_game_server(
            game_server_address=addr,
            game_server_port=port,
            pg_pool=pg_pool,
            http_client=request.app.ctx.http_client,
        )
        if not ok:
            logger.debug("JWT validation failed: server is not a real RS2 server "
                         "according to Steam Web API")
            return False

    request.ctx.jwt_game_server_port = port
    request.ctx.jwt_game_server_address = addr

    return True
def check_and_inject_game(func: Callable) -> Callable:
    """Wrap a route handler so it only runs for the game's owning server.

    The wrapped handler is called as ``func(request, *args, game_id=game_id,
    **kwargs)``. Before calling it, the wrapper:

    * responds 401 unless ``request.ctx.jwt_game_server_address`` and
      ``request.ctx.jwt_game_server_port`` (set by ``check_token`` on
      success) are both present and not None;
    * loads the game ``game_id`` with ``queries.select_game``, responding
      404 if none is found;
    * responds 401 unless the game's stored server address and port equal
      the token's;
    * stores the game in ``request.ctx.game``.

    ``func`` may return a response directly or an awaitable of one.
    """

    @wraps(func)
    async def game_owner_checked_handler(
        request: Request,
        game_id: str,
        *args,
        **kwargs,
    ) -> sanic.HTTPResponse:
        # getattr() so a request whose ctx never received these attributes
        # (i.e. check_token did not succeed for it) is rejected with 401
        # instead of failing with AttributeError.
        jwt_addr = getattr(request.ctx, "jwt_game_server_address", None)
        jwt_port = getattr(request.ctx, "jwt_game_server_port", None)
        if jwt_addr is None or jwt_port is None:
            logger.debug(
                "cannot verify game owner: jwt_game_server_address={}, jwt_game_server_port={}",
                jwt_addr,
                jwt_port,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        async with pool_acquire(request.app.ctx.pg_pool) as conn:
            game = await queries.select_game(conn=conn, game_id=game_id)

        if not game:
            logger.debug("no game found for game_id: {}", game_id)
            return sanic.HTTPResponse(status=HTTPStatus.NOT_FOUND)

        if game.game_server_address != jwt_addr:
            logger.debug(
                "unauthorized: token address != DB address: {} != {}",
                game.game_server_address,
                jwt_addr,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        if game.game_server_port != jwt_port:
            logger.debug(
                "unauthorized: token port != DB port: {} != {}",
                game.game_server_port,
                jwt_port,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        request.ctx.game = game

        # Positional args first, then keywords: same binding as before, but
        # without star-unpacking after a keyword argument (bugbear B026).
        response = func(request, *args, game_id=game_id, **kwargs)
        if isawaitable(response):
            return await response
        return response  # pragma: no coverage

    return game_owner_checked_handler