Coverage for chatgpt_proxy / auth / auth.py: 94%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-12 16:19 +0000

1# MIT License 

2# 

3# Copyright (c) 2025 Tuomo Kriikkula 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23"""Provides authentication and authorization for the chatgpt_proxy API server.""" 

24 

25import datetime 

26import hashlib 

27import ipaddress 

28import os 

29from functools import wraps 

30from http import HTTPStatus 

31from inspect import isawaitable 

32from secrets import compare_digest 

33from typing import Callable 

34 

35import asyncpg 

36import httpx 

37import jwt 

38import sanic 

39 

40from chatgpt_proxy.db import pool_acquire 

41from chatgpt_proxy.db import queries 

42from chatgpt_proxy.log import logger 

43from chatgpt_proxy.steam import steam 

44from chatgpt_proxy.types import Request 

45from chatgpt_proxy.utils import get_remote_addr 

46 

47jwt_issuer = "ChatGPTProxy" 

48jwt_audience = "ChatGPTProxy" 

49 

50_steam_web_api_key: str | None = None 

51 

52 

def load_config() -> None:
    """(Re)load this module's configuration from the process environment.

    Sets the module-global ``_steam_web_api_key`` to the value of the
    ``STEAM_WEB_API_KEY`` environment variable, or to ``None`` when the
    variable is not present.
    """
    global _steam_web_api_key
    _steam_web_api_key = os.environ.get("STEAM_WEB_API_KEY")


# Read the configuration once, at import time.
load_config()

59 

# Cache entry lifetimes in seconds, as floats (timedelta.total_seconds()).
# They are referenced by the commented-out aiocache decorators on
# is_real_game_server and validate_db_token.
ttl_is_real_game_server = datetime.timedelta(hours=1).total_seconds()
ttl_validate_db_token = datetime.timedelta(seconds=300).total_seconds()

62 

63 

def is_real_game_server_key_builder(*args, **kwargs) -> str:
    """Build a cache key for a call to ``is_real_game_server``.

    NOTE: this is specific to is_real_game_server! The excluded keyword names
    below have to be kept in sync manually with its real signature.

    Returns:
        ``<module><function name><args tuple><sorted kwargs items>`` as a
        string, with the ``pg_pool`` and ``http_client`` keyword arguments
        left out of the key.
    """
    excluded = ("pg_pool", "http_client")
    key_kwargs = sorted(
        (name, value) for name, value in kwargs.items() if name not in excluded
    )
    module = is_real_game_server.__module__ or ""
    return f"{module}{is_real_game_server.__name__}{args}{key_kwargs}"

79 

80 

81# TODO: waiting for updated aiocache + valkey-glide support on Windows! 

82# - In the meanwhile, only use memory cache! 

83# - See pyproject.toml for more details! 

84# TODO: FIXME: IMPORTANT: 

85# ENABLING THIS BREAKS TESTING! CACHE SHOULD BE CLEARED IN CERTAIN TESTS! 

86# FOR THAT, WE NEED TO USE app_cache, BUT WE CANNOT DO THAT UNTIL 

87# aiocache IS PATCHED! SEE THE ABOVE NOTE! 

88# @aiocache.cached( 

89# # cache=app_cache, 

90# cache=aiocache.Cache.MEMORY, 

91# ttl=ttl_is_real_game_server, 

92# key_builder=is_real_game_server_key_builder, 

93# # NOTE: Only cache the result if the server was successfully verified. 

94# skip_cache_func=lambda x: x is False, 

95# ) 

async def is_real_game_server(
    *,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    pg_pool: asyncpg.Pool,
    http_client: httpx.AsyncClient,
) -> bool:
    """Check via the Steam Web API whether ``address:port`` is a listed RS2 server.

    Requests ``steam.server_list_url`` with the module's Steam Web API key and a
    filter on game directory ``rs2`` and the given game address/port.

    Args:
        game_server_address: IPv4 address of the game server to check.
        game_server_port: Port of the game server to check.
        pg_pool: Database pool, forwarded to ``steam.web_api_request``.
        http_client: HTTP client used for the Steam Web API request.

    Returns:
        True if Steam returned at least one server for the filter, otherwise
        False. Any ``Exception`` raised while querying or parsing (HTTP error
        status, transport error, unexpected JSON, ...) is logged at debug level
        and reported as False.
    """
    try:
        logger.debug("checking if {}:{} is a real RS2 server", game_server_address, game_server_port)

        resp = await steam.web_api_request(
            http_client=http_client,
            pg_pool=pg_pool,
            url=steam.server_list_url,
            params={
                "key": _steam_web_api_key,
                "filter": f"\\gamedir\\rs2\\gameaddr\\{game_server_address}:{game_server_port}",
                # TODO: limit param would speed up things, or would it?
            },
        )
        resp.raise_for_status()
        # NOTE(review): Steam appears to omit the "servers" key altogether when
        # nothing matches the filter. Treat a missing key as "no servers" so it
        # is logged as such rather than as a KeyError verification failure.
        servers = resp.json()["response"].get("servers")

        if not servers:
            logger.debug("Steam Web API returned no servers for {}:{}",
                         game_server_address, game_server_port)
            return False

        return True
    except Exception as e:
        # Best-effort check: any failure means the server could not be verified.
        logger.debug("unable to verify {}:{} is a real RS2 server: {}: {}",
                     game_server_address, game_server_port, type(e).__name__, e)
        return False

129 

130 

def validate_db_token_key_builder(*args, **kwargs) -> str:
    """Build a cache key for a call to ``validate_db_token``.

    NOTE: this function is specific to validate_db_token! It has to be kept in
    sync manually with the real signature of validate_db_token.

    The ``pg_pool`` argument (the last parameter of validate_db_token) is left
    out of the key. When it is passed by keyword it is removed from the keyword
    arguments. Only when it is not passed by keyword is the trailing positional
    argument assumed to be the pool and dropped.

    Returns:
        ``<module><function name><args tuple><sorted kwargs items>`` as a string.
    """
    if "pg_pool" in kwargs:
        # Passed by keyword (as check_token does): every positional argument is
        # a real input here, so slicing args would drop e.g. `port` and let
        # different calls collide on the same key.
        del kwargs["pg_pool"]
    else:
        # Passed positionally: pg_pool is the last parameter.
        args = args[:-1]

    ordered_kwargs = sorted(kwargs.items())
    return (
        f"{validate_db_token.__module__ or ''}"
        f"{validate_db_token.__name__}"
        f"{str(args)}"
        f"{str(ordered_kwargs)}"
    )

145 

146 

147# @aiocache.cached( 

148# # cache=app_cache, 

149# cache=aiocache.Cache.MEMORY, 

150# ttl=ttl_validate_db_token, 

151# key_builder=validate_db_token_key_builder, 

152# # NOTE: Only cache the result if the token was successfully validated.

153# skip_cache_func=lambda x: x is False, 

154# ) 

async def validate_db_token(
    request_token_hash: bytes,
    addr: ipaddress.IPv4Address,
    port: int,
    pg_pool: asyncpg.Pool,
) -> bool:
    """Compare a request token's hash with the API key hash stored in the DB.

    Args:
        request_token_hash: Hash of the raw request token (check_token passes
            its SHA-256 digest).
        addr: Game server address to look up the stored API key for.
        port: Game server port to look up the stored API key for.
        pg_pool: Pool to acquire the database connection from.

    Returns:
        True only if an API key row exists for ``addr:port`` and its
        ``api_key_hash`` equals ``request_token_hash`` (compared with
        ``secrets.compare_digest``); False otherwise.
    """
    # TODO: we should cache this in memory (LRU) or diskcache.
    # TODO: if an API key is deleted from the database, the cache
    #   for said API key should also be cleared!

    async with pool_acquire(pg_pool) as conn:
        api_key = await queries.select_game_server_api_key(
            conn=conn,
            game_server_address=addr,
            game_server_port=port,
        )
    logger.debug("api_key: {}", api_key)

    if not api_key:
        logger.debug("JWT validation failed: no API key for {}:{}", addr, port)
        return False

    # Constant-time comparison so the check does not leak timing information.
    matches = compare_digest(request_token_hash, api_key["api_key_hash"])
    if not matches:
        logger.debug("JWT validation failed: stored hash does not match token hash")
    return matches

184 

185 

async def check_token(request: Request, pg_pool: asyncpg.Pool) -> bool:
    """Authenticate a game server request by its JWT.

    The request passes only if all of the following hold:

    - ``request.token`` is present;
    - it decodes as an HS256 JWT signed with ``app.config.SECRET`` that has the
      ``exp``, ``iss``, ``sub`` and ``aud`` claims, with issuer and audience
      matching ``app.config.JWT_ISSUER`` / ``app.config.JWT_AUDIENCE``;
    - its subject is ``"<IPv4 address>:<port>"`` and the address equals the
      client's remote address;
    - the SHA-256 digest of the raw token matches the API key hash stored for
      that address/port;
    - when a Steam Web API key is configured, Steam lists the server as an RS2
      server. Without a key this check is skipped with a warning.

    On success ``request.ctx.jwt_game_server_address`` and
    ``request.ctx.jwt_game_server_port`` are set from the token subject.

    Args:
        request: The incoming request.
        pg_pool: Database pool used for the API key lookup and forwarded to the
            Steam verification.

    Returns:
        True if every check passed, False otherwise. A malformed subject is a
        failed check, not an exception.
    """
    if not request.token:
        logger.debug("JWT validation failed: no token")
        return False

    try:
        token = jwt.decode(
            jwt=request.token,
            key=request.app.config.SECRET,
            options={"require": ["exp", "iss", "sub", "aud"]},
            algorithms=["HS256"],
            audience=request.app.config.JWT_AUDIENCE,
            issuer=request.app.config.JWT_ISSUER,
        )
    except jwt.exceptions.PyJWTError as e:
        logger.debug("JWT validation failed: {}: {}", type(e).__name__, e)
        return False

    # JWT subject should be IP:port. Even for a correctly signed token, a
    # malformed subject must fail validation instead of escaping as an
    # unhandled exception.
    sub = token["sub"]
    if not isinstance(sub, str):
        logger.debug("JWT validation failed: subject is not a string: {!r}", sub)
        return False
    try:
        a, p = sub.split(":")
        addr = ipaddress.IPv4Address(a)
        port = int(p)
    except ValueError as e:
        # Covers a wrong number of ':'-separated parts, an invalid IPv4 address
        # (AddressValueError subclasses ValueError) and a non-integer port.
        logger.debug("JWT validation failed: invalid subject {!r}: {}: {}",
                     sub, type(e).__name__, e)
        return False
    logger.debug("token addr:port: {}:{}", addr, port)

    # Small extra step of security since we can't use HTTPS.
    # In any case, this is not really secure, but better than nothing.
    client_addr = get_remote_addr(request)
    logger.debug("client_addr: {}", client_addr)
    if client_addr != addr:
        logger.debug("JWT validation failed: (client_addr != addr): {} != {}", client_addr, addr)
        return False

    req_token_hash = hashlib.sha256(request.token.encode("utf-8")).digest()
    if not await validate_db_token(
        request_token_hash=req_token_hash,
        addr=addr,
        port=port,
        pg_pool=pg_pool,
    ):
        return False

    if _steam_web_api_key is None:
        logger.warning("Steam Web API key is not set, "
                       "unable to verify server is a real RS2 server")
    else:
        ok = await is_real_game_server(
            game_server_address=addr,
            game_server_port=port,
            pg_pool=pg_pool,
            http_client=request.app.ctx.http_client,
        )
        if not ok:
            logger.debug("JWT validation failed: server is not a real RS2 server "
                         "according to Steam Web API")
            return False

    request.ctx.jwt_game_server_port = port
    request.ctx.jwt_game_server_address = addr

    return True

247 

248 

def check_and_inject_game(func: Callable) -> Callable:
    """Decorate a game route handler with a game ownership check.

    The returned async wrapper is called as ``(request, game_id, *args,
    **kwargs)``. Before delegating to ``func`` it:

    - responds ``401 Unauthorized`` if ``request.ctx.jwt_game_server_address``
      or ``request.ctx.jwt_game_server_port`` is ``None``. check_token sets
      them on success; presumably they default to ``None`` on the request
      context otherwise — verify against the Request ctx type.
    - loads the game ``game_id`` from the database and responds
      ``404 Not Found`` if there is none;
    - responds ``401 Unauthorized`` if the game's stored server address or
      port differs from the token's;
    - otherwise stores the game on ``request.ctx.game`` and calls ``func``,
      awaiting the result if it is awaitable.

    Args:
        func: The handler to wrap; it receives ``game_id`` as a keyword argument.

    Returns:
        The wrapping handler, with metadata copied from ``func`` by
        ``functools.wraps``.
    """

    @wraps(func)
    async def game_owner_checked_handler(
        request: Request,
        game_id: str,
        *args,
        **kwargs,
    ) -> sanic.HTTPResponse:
        token_address = request.ctx.jwt_game_server_address
        token_port = request.ctx.jwt_game_server_port
        if token_address is None or token_port is None:
            logger.debug(
                "cannot verify game owner: jwt_game_server_address={}, jwt_game_server_port={}",
                token_address,
                token_port,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        async with pool_acquire(request.app.ctx.pg_pool) as conn:
            game = await queries.select_game(conn=conn, game_id=game_id)
        if not game:
            logger.debug("no game found for game_id: {}", game_id)
            return sanic.HTTPResponse(status=HTTPStatus.NOT_FOUND)

        if game.game_server_address != token_address:
            logger.debug(
                "unauthorized: token address != DB address: {} != {}",
                game.game_server_address,
                token_address,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        if game.game_server_port != token_port:
            logger.debug(
                "unauthorized: token port != DB port: {} != {}",
                game.game_server_port,
                token_port,
            )
            return sanic.HTTPResponse("Unauthorized.", status=HTTPStatus.UNAUTHORIZED)

        request.ctx.game = game

        # Positional *args are written before the keyword argument (ruff B026).
        # Python binds the call the same way either way; this form says so.
        response = func(request, *args, game_id=game_id, **kwargs)
        if isawaitable(response):
            return await response
        return response  # pragma: no coverage

    return game_owner_checked_handler