Coverage for chatgpt_proxy / db / queries.py: 97%
112 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 16:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 16:19 +0000
1# MIT License
2#
3# Copyright (c) 2025 Tuomo Kriikkula
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23"""Database query helpers."""
25import datetime
26import ipaddress
28from asyncpg import Connection
29from asyncpg import Record
30from pypika import Order
31from pypika import Table
32from pypika.queries import QueryBuilder
34from chatgpt_proxy.db import models
36_default_conn_timeout = 15.0
class Ignored:
    """Sentinel type meaning "argument not supplied, leave the column alone".

    A dedicated sentinel (instead of ``None``) lets the update helpers tell
    "caller did not pass this argument" apart from any explicit value the
    caller did pass. Compare against the module-level ``IGNORED`` instance
    with ``is``.
    """

    def __repr__(self) -> str:
        # A readable repr keeps function signatures / tracebacks that show
        # the default value legible instead of "<... object at 0x...>".
        return "IGNORED"


# The single sentinel instance used as a default argument value.
IGNORED = Ignored()
46# TODO: add caching layer!
async def insert_game(
    conn: Connection,
    game_id: str,
    level: str,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    start_time: datetime.datetime,
    stop_time: datetime.datetime | None = None,
    openai_previous_response_id: str | None = None,
    timeout: float | None = _default_conn_timeout,
):
    """Insert one new row into the ``game`` table.

    ``stop_time`` and ``openai_previous_response_id`` default to ``None``
    (stored as SQL NULL). ``timeout`` is forwarded to ``conn.execute``.
    """
    sql = """
        INSERT INTO "game"
        (id, level, start_time, stop_time, game_server_address,
         game_server_port, openai_previous_response_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7);
        """
    # Bind values in placeholder order ($1..$7). This is the column order,
    # which differs from the parameter order of this function.
    bind_values = (
        game_id,
        level,
        start_time,
        stop_time,
        game_server_address,
        game_server_port,
        openai_previous_response_id,
    )
    await conn.execute(sql, *bind_values, timeout=timeout)
def build_update_game_query(
    game_id: str,
    stop_time: datetime.datetime | Ignored = IGNORED,
    openai_previous_response_id: str | Ignored = IGNORED,
) -> QueryBuilder:
    """Build (without executing) an UPDATE for the ``game`` row ``game_id``.

    Each column argument that is not the ``IGNORED`` sentinel contributes one
    SET assignment, in the order ``stop_time`` then
    ``openai_previous_response_id``.

    NOTE(review): when both arguments are ``IGNORED`` no SET assignment is
    added; confirm how pypika renders that case before relying on it.
    """
    game = Table(name="game")
    candidates = (
        (game.stop_time, stop_time),
        (game.openai_previous_response_id, openai_previous_response_id),
    )
    assignments = [
        (column, value)
        for column, value in candidates
        if value is not IGNORED
    ]

    query = game.update()
    for column, value in assignments:
        query = query.set(column, value)
    return query.where(game.id == game_id)
async def update_game(
    conn: Connection,
    game_id: str,
    stop_time: datetime.datetime | Ignored = IGNORED,
    openai_previous_response_id: str | Ignored = IGNORED,
    timeout: float | None = _default_conn_timeout,
):
    """Update selected columns of the ``game`` row identified by ``game_id``.

    Arguments left at the ``IGNORED`` sentinel are not written. The SQL text
    comes from ``build_update_game_query``.
    """
    update_sql = str(
        build_update_game_query(
            game_id=game_id,
            stop_time=stop_time,
            openai_previous_response_id=openai_previous_response_id,
        )
    )
    # No bind arguments are passed: the values are rendered into the SQL
    # text when the pypika query is converted with str().
    await conn.execute(update_sql, timeout=timeout)
108# TODO: what's the best way to handle this? If we make this too dynamic
109# it's going to cross into ORM territory quickly.
110# For now, assume we only ever want to select by game_id and
111# return all columns even if it is wasteful.
async def select_game(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> models.Game | None:
    """Fetch the ``game`` row whose ``id`` equals ``game_id``.

    All columns are selected and passed as keyword arguments to
    ``models.Game``. Returns ``None`` when no row is found.
    """
    row = await conn.fetchrow(
        """
        SELECT *
        FROM "game"
        WHERE id = $1;
        """,
        game_id,
        timeout=timeout,
    )
    if not row:
        return None
    return models.Game(**row)
133# TODO: select * or do we need a way to specify columns?
async def select_games(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> list[models.Game]:
    """Fetch every row of the ``game`` table as ``models.Game`` objects."""
    rows = await conn.fetch(
        """
        SELECT *
        FROM "game";
        """,
        timeout=timeout,
    )
    games: list[models.Game] = []
    for row in rows:
        # Falsy (empty) records are skipped.
        if not row:
            continue
        games.append(models.Game(**row))
    return games
async def upsert_game_objective_state(
    conn: Connection,
    state: models.GameObjectiveState,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Insert or overwrite the stored objectives for ``state.game_id``.

    Each objective is stored in its ``wire_format()`` representation. Returns
    ``bool`` of the ``inserted`` column, i.e. whether ``xmax = 0`` held for
    the returned row. This is the usual PostgreSQL idiom for "freshly
    inserted rather than updated by ON CONFLICT".
    """
    # TODO: re-consider what's the best data format for this in DB?
    serialized = [objective.wire_format() for objective in state.objectives]

    upsert_sql = """
        INSERT INTO "game_objective_state" (game_id, objectives)
        VALUES ($1, $2)
        ON CONFLICT (game_id) DO UPDATE
            SET objectives = excluded.objectives
        RETURNING (xmax = 0) as inserted;
        """
    was_inserted = await conn.fetchval(
        upsert_sql,
        state.game_id,
        serialized,
        timeout=timeout,
    )
    return bool(was_inserted)
async def delete_completed_games(
    conn: Connection,
    game_expiration: datetime.timedelta,
    timeout: float | None = _default_conn_timeout,
) -> str:
    """Delete games that stopped more than ``game_expiration`` ago.

    Games whose ``stop_time`` is NULL (not stopped) are never matched.

    :param conn: asyncpg connection to run the statement on.
    :param game_expiration: how long a stopped game is kept before deletion.
    :param timeout: forwarded to ``conn.execute``.
    :returns: the command status string returned by ``conn.execute``.
    """
    # Both conditions must hold. With OR, ``stop_time IS NOT NULL`` alone
    # matched every stopped game immediately. The second condition could
    # never add rows, because ``stop_time + $1`` is NULL when stop_time is
    # NULL. That made ``game_expiration`` a no-op.
    return await conn.execute(
        """
        DELETE
        FROM "game"
        WHERE stop_time IS NOT NULL
          AND NOW() > (stop_time + $1);
        """,
        game_expiration,
        timeout=timeout,
    )
191# TODO: add the rest of cols here if needed?
async def select_game_server_api_key(
    conn: Connection,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    timeout: float | None = _default_conn_timeout,
) -> Record | None:
    """Fetch the API-key row for the server at ``address:port``.

    Returns the raw asyncpg record with all columns, or ``None`` when no row
    matches.
    """
    lookup_sql = """
        SELECT *
        FROM "game_server_api_key"
        WHERE game_server_address = $1
          AND game_server_port = $2;
        """
    return await conn.fetchrow(
        lookup_sql,
        game_server_address,
        game_server_port,
        timeout=timeout,
    )
async def select_game_server_api_keys(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> list[Record]:
    """Fetch every row of ``game_server_api_key`` as raw asyncpg records."""
    all_keys_sql = """
        SELECT *
        FROM "game_server_api_key";
        """
    return await conn.fetch(all_keys_sql, timeout=timeout)
async def insert_game_server_api_key(
    conn: Connection,
    issued_at: datetime.datetime,
    expires_at: datetime.datetime,
    token_hash: bytes,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    name: str | None = None,
    timeout: float | None = _default_conn_timeout,
):
    """Store a hashed API key for a game server.

    ``issued_at`` is written to the ``created_at`` column and ``token_hash``
    to ``api_key_hash``. This helper only receives the hash, never the token.
    ``name`` defaults to ``None`` (SQL NULL).
    """
    insert_sql = """
        INSERT INTO "game_server_api_key"
        (created_at, expires_at, api_key_hash, game_server_address, game_server_port, name)
        VALUES ($1, $2, $3, $4, $5, $6);
        """
    row_values = (
        issued_at,
        expires_at,
        token_hash,
        game_server_address,
        game_server_port,
        name,
    )
    await conn.execute(insert_sql, *row_values, timeout=timeout)
async def game_exists(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Return ``True`` if a ``game`` row with ``id == game_id`` exists."""
    hit = await conn.fetchval(
        """
        SELECT 1
        FROM "game"
        WHERE id = $1;
        """,
        game_id,
        timeout=timeout,
    )
    # fetchval yields None when the query returns no row.
    return hit is not None
async def delete_old_api_keys(
    conn: Connection,
    leeway: datetime.timedelta,
    timeout: float | None = _default_conn_timeout,
) -> str:
    """Delete API keys whose ``expires_at`` lies more than ``leeway`` in the past.

    Returns the command status string produced by ``conn.execute``.
    """
    purge_sql = """
        DELETE
        FROM "game_server_api_key"
        WHERE NOW() > (expires_at + $1);
        """
    return await conn.execute(purge_sql, leeway, timeout=timeout)
async def select_openai_query(
    conn: Connection,
    openai_response_id: str,
    timeout: float | None = _default_conn_timeout,
) -> models.OpenAIQuery | None:
    """Fetch the ``openai_query`` row for ``openai_response_id``.

    NOTE: selection is only supported by ``openai_response_id`` for now.
    Returns ``None`` when no row matches.
    """
    row = await conn.fetchrow(
        """
        SELECT *
        FROM "openai_query"
        WHERE openai_response_id = $1;
        """,
        openai_response_id,
        timeout=timeout,
    )
    return models.OpenAIQuery(**row) if row else None
async def insert_openai_query(
    conn: Connection,
    game_id: str,
    time: datetime.datetime,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    request_length: int,
    response_length: int,
    openai_response_id: str,
    timeout: float | None = _default_conn_timeout,
) -> None:
    """Record one OpenAI request/response exchange in ``openai_query``."""
    insert_sql = """
        INSERT INTO "openai_query"
        (game_id, time, game_server_address,
         game_server_port, request_length, response_length, openai_response_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7);
        """
    row_values = (
        game_id,
        time,
        game_server_address,
        game_server_port,
        request_length,
        response_length,
        openai_response_id,
    )
    await conn.execute(insert_sql, *row_values, timeout=timeout)
async def insert_game_chat_message(
    conn: Connection,
    game_id: str,
    message: str,
    send_time: datetime.datetime,
    sender_name: str,
    sender_team: models.Team,
    channel: models.SayType,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Insert a chat message into ``game_chat_message`` and return its ``id``.

    ``sender_team`` and ``channel`` are stored as ``int(...)`` of the enum
    values.
    """
    team_value, channel_value = int(sender_team), int(channel)
    insert_sql = """
        INSERT INTO "game_chat_message"
        (message, game_id, send_time, sender_name, sender_team, channel)
        VALUES ($1, $2, $3, $4, $5, $6)
        RETURNING id;
        """
    new_id = await conn.fetchval(
        insert_sql,
        message,
        game_id,
        send_time,
        sender_name,
        team_value,
        channel_value,
        timeout=timeout,
    )
    return new_id
async def insert_game_kill(
    conn: Connection,
    game_id: str,
    kill_time: datetime.datetime,
    killer_name: str,
    victim_name: str,
    killer_team: models.Team,
    victim_team: models.Team,
    damage_type: str,
    kill_distance_m: float,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Insert a kill event into ``game_kill`` and return the new row's ``id``.

    Both team arguments are stored as ``int(...)`` of the enum values.
    """
    killer_team_value, victim_team_value = int(killer_team), int(victim_team)
    insert_sql = """
        INSERT INTO "game_kill"
        (game_id, kill_time, killer_name, victim_name, killer_team,
         victim_team, damage_type, kill_distance_m)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        RETURNING id;
        """
    row_values = (
        game_id,
        kill_time,
        killer_name,
        victim_name,
        killer_team_value,
        victim_team_value,
        damage_type,
        kill_distance_m,
    )
    return await conn.fetchval(insert_sql, *row_values, timeout=timeout)
async def delete_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
):
    """Delete the ``game_player`` row keyed by (``game_id``, ``player_id``)."""
    delete_sql = """
        DELETE
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """
    await conn.execute(delete_sql, game_id, player_id, timeout=timeout)
async def upsert_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    name: str,
    team_index: int,
    score: float,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Insert a player into ``game_player`` or update it on key conflict.

    On a (``game_id``, ``id``) conflict the name, team and score are
    overwritten. Returns ``bool`` of the ``inserted`` column (``xmax = 0``),
    the usual PostgreSQL idiom for "row was freshly inserted".
    """
    upsert_sql = """
        INSERT INTO "game_player" (game_id, id, name, team, score)
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (game_id, id) DO UPDATE
            SET game_id = excluded.game_id,
                name    = excluded.name,
                team    = excluded.team,
                score   = excluded.score
        RETURNING (xmax = 0) as inserted;
        """
    row_values = (game_id, player_id, name, team_index, score)
    was_inserted = await conn.fetchval(upsert_sql, *row_values, timeout=timeout)
    return bool(was_inserted)
async def select_game_players(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GamePlayer]:
    """Fetch all ``game_player`` rows of ``game_id`` as ``models.GamePlayer``."""
    players_sql = """
        SELECT *
        FROM "game_player"
        WHERE game_id = $1;
        """
    rows = await conn.fetch(players_sql, game_id, timeout=timeout)
    return [models.GamePlayer(**row) for row in rows]
async def select_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
) -> models.GamePlayer | None:
    """Fetch one player of ``game_id`` by ``player_id``.

    Returns ``None`` when no matching row exists.
    """
    row = await conn.fetchrow(
        """
        SELECT *
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """,
        game_id,
        player_id,
        timeout=timeout,
    )
    if not row:
        return None

    # NOTE(review): the team enum is built from the *string* form of the
    # stored value; confirm this matches models.Team's value type.
    team = models.Team(str(row["team"]))
    return models.GamePlayer(
        game_id=row["game_id"],
        id=row["id"],
        name=row["name"],
        team=team,
        score=row["score"],
    )
async def game_player_exists(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Return ``True`` if ``game_player`` has a row for (game_id, player_id)."""
    exists_sql = """
        SELECT 1
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """
    hit = await conn.fetchval(exists_sql, game_id, player_id, timeout=timeout)
    # fetchval yields None when no row matched.
    return hit is not None
async def select_game_kills(
    conn: Connection,
    game_id: str | None = None,
    kill_time_from: datetime.datetime | None = None,
    limit: int | None = None,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GameKill]:
    """Fetch ``game_kill`` rows ordered by ascending ``id``.

    Optional filters: ``game_id`` (exact match) and ``kill_time_from``
    (``kill_time >= kill_time_from``). ``limit`` caps the row count. Filter
    values are rendered into the SQL text by pypika, not bound as parameters.
    """
    table = Table(name="game_kill")

    filters = []
    if game_id is not None:
        filters.append(table.game_id == game_id)
    if kill_time_from is not None:
        filters.append(table.kill_time >= kill_time_from)

    query = table.select("*")
    for criterion in filters:
        query = query.where(criterion)
    if limit is not None:
        query = query.limit(limit)
    query = query.orderby("id", order=Order.asc)

    rows = await conn.fetch(str(query), timeout=timeout)
    return [models.GameKill(**row) for row in rows]
async def select_game_chat_messages(
    conn: Connection,
    game_id: str | None = None,
    send_time_from: datetime.datetime | None = None,
    limit: int | None = None,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GameChatMessage]:
    """Fetch ``game_chat_message`` rows ordered by ascending ``id``.

    Optional filters: ``game_id`` (exact match) and ``send_time_from``
    (``send_time >= send_time_from``). ``limit`` caps the row count. Filter
    values are rendered into the SQL text by pypika, not bound as parameters.
    """
    table = Table(name="game_chat_message")

    filters = []
    if game_id is not None:
        filters.append(table.game_id == game_id)
    if send_time_from is not None:
        filters.append(table.send_time >= send_time_from)

    query = table.select("*")
    for criterion in filters:
        query = query.where(criterion)
    if limit is not None:
        query = query.limit(limit)
    query = query.orderby("id", order=Order.asc)

    rows = await conn.fetch(str(query), timeout=timeout)
    return [models.GameChatMessage(**row) for row in rows]
async def increment_steam_web_api_queries(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> None:
    """Bump the Steam Web API query counter and stamp the last-query time.

    The UPDATE has no WHERE clause, so it applies to every row of
    ``query_statistics``. Presumably that table holds a single row; confirm
    against the schema.
    """
    bump_sql = """
        UPDATE "query_statistics"
        SET steam_web_api_queries    = steam_web_api_queries + 1,
            last_steam_web_api_query = NOW();
        """
    transaction = conn.transaction()
    async with transaction:
        await conn.execute(bump_sql, timeout=timeout)
async def select_steam_web_api_queries(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Return the ``steam_web_api_queries`` counter, or 0 if no row exists."""
    counter_sql = """
        SELECT steam_web_api_queries
        FROM "query_statistics";
        """
    row = await conn.fetchrow(counter_sql, timeout=timeout)
    return row["steam_web_api_queries"] if row else 0