Coverage for chatgpt_proxy / db / queries.py: 97%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-12 16:19 +0000

1# MIT License 

2# 

3# Copyright (c) 2025 Tuomo Kriikkula 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23"""Database query helpers.""" 

24 

25import datetime 

26import ipaddress 

27 

28from asyncpg import Connection 

29from asyncpg import Record 

30from pypika import Order 

31from pypika import Table 

32from pypika.queries import QueryBuilder 

33 

34from chatgpt_proxy.db import models 

35 

36_default_conn_timeout = 15.0 

37 

38 

class Ignored:
    """Sentinel type meaning "argument not provided; leave this column alone".

    Kept distinct from ``None`` so that "omitted" can be told apart from
    any real value a caller might pass. Compare against :data:`IGNORED`
    by identity (``is`` / ``is not``).
    """

    # The sentinel carries no state.
    __slots__ = ()

    def __repr__(self) -> str:
        # Show a readable name in signatures, help() output and tracebacks
        # instead of the default "<...Ignored object at 0x...>".
        return "IGNORED"


# Shared sentinel instance used as the default for optional update columns.
IGNORED = Ignored()

44 

45 

46# TODO: add caching layer! 

47 

48 

async def insert_game(
    conn: Connection,
    game_id: str,
    level: str,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    start_time: datetime.datetime,
    stop_time: datetime.datetime | None = None,
    openai_previous_response_id: str | None = None,
    timeout: float | None = _default_conn_timeout,
):
    """Insert one row into the ``game`` table.

    ``stop_time`` and ``openai_previous_response_id`` default to ``None``
    and are written as-is. ``timeout`` is forwarded to asyncpg.
    """
    sql = """
        INSERT INTO "game"
            (id, level, start_time, stop_time, game_server_address,
             game_server_port, openai_previous_response_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7);
        """
    # Positional bind values, in the same order as the column list above.
    bind_values = (
        game_id,
        level,
        start_time,
        stop_time,
        game_server_address,
        game_server_port,
        openai_previous_response_id,
    )
    await conn.execute(sql, *bind_values, timeout=timeout)

76 

77 

def build_update_game_query(
    game_id: str,
    stop_time: datetime.datetime | Ignored = IGNORED,
    openai_previous_response_id: str | Ignored = IGNORED,
) -> QueryBuilder:
    """Build an ``UPDATE "game" ... WHERE id = game_id`` pypika query.

    A column gets a ``SET`` assignment only when its argument is not the
    :data:`IGNORED` sentinel. Values are rendered into the SQL text by
    pypika rather than sent as bind parameters.

    NOTE(review): with every column argument IGNORED no assignments are
    added; callers should make sure at least one column is being set.
    """
    table = Table(name="game")

    # (column, value) pairs, applied in this order.
    candidate_assignments = (
        (table.stop_time, stop_time),
        (table.openai_previous_response_id, openai_previous_response_id),
    )

    builder = table.update()
    for column, value in candidate_assignments:
        if value is IGNORED:
            continue
        builder = builder.set(column, value)

    return builder.where(table.id == game_id)

91 

92 

async def update_game(
    conn: Connection,
    game_id: str,
    stop_time: datetime.datetime | Ignored = IGNORED,
    openai_previous_response_id: str | Ignored = IGNORED,
    timeout: float | None = _default_conn_timeout,
):
    """Update selected columns of the ``game`` row whose id is ``game_id``.

    Column arguments left as :data:`IGNORED` are not modified. If every
    column argument is ``IGNORED`` there is nothing to update, and the call
    is a no-op: no query is sent to the database.
    """
    if stop_time is IGNORED and openai_previous_response_id is IGNORED:
        # Without any .set() calls pypika renders "UPDATE "game" SET  WHERE ...",
        # which PostgreSQL rejects as a syntax error. Nothing to do.
        return

    query = build_update_game_query(
        game_id=game_id,
        stop_time=stop_time,
        openai_previous_response_id=openai_previous_response_id,
    )
    await conn.execute(str(query), timeout=timeout)

106 

107 

108# TODO: what's the best way to handle this? If we make this too dynamic 

109# it's going to cross into ORM territory quickly. 

110# For now, assume we only ever want to select by game_id and 

111# return all columns even if it is wasteful. 

async def select_game(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> models.Game | None:
    """Fetch the ``game`` row with id ``game_id``.

    Returns a :class:`models.Game` built from every column of the row, or
    ``None`` when no row is found.
    """
    row = await conn.fetchrow(
        """
        SELECT *
        FROM "game"
        WHERE id = $1;
        """,
        game_id,
        timeout=timeout,
    )
    return models.Game(**row) if row else None

131 

132 

133# TODO: select * or do we need a way to specify columns? 

async def select_games(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> list[models.Game]:
    """Fetch every row of the ``game`` table as :class:`models.Game` objects.

    Falsy records are skipped before conversion.
    """
    rows = await conn.fetch(
        """
        SELECT *
        FROM "game";
        """,
        timeout=timeout,
    )
    # filter(None, ...) keeps only truthy records.
    return [models.Game(**row) for row in filter(None, rows)]

149 

150 

async def upsert_game_objective_state(
    conn: Connection,
    state: models.GameObjectiveState,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Insert or replace the objectives stored for ``state.game_id``.

    Each objective is stored in the form returned by its ``wire_format()``.
    Returns ``True`` when a new row was inserted and ``False`` when an
    existing row was updated. This relies on the ``RETURNING (xmax = 0)``
    check.
    """
    # TODO: re-consider what's the best data format for this in DB?
    wire_objectives = [objective.wire_format() for objective in state.objectives]

    was_inserted = await conn.fetchval(
        """
        INSERT INTO "game_objective_state" (game_id, objectives)
        VALUES ($1, $2)
        ON CONFLICT (game_id) DO UPDATE
            SET objectives = excluded.objectives
        RETURNING (xmax = 0) as inserted;
        """,
        state.game_id,
        wire_objectives,
        timeout=timeout,
    )
    return bool(was_inserted)

172 

173 

async def delete_completed_games(
    conn: Connection,
    game_expiration: datetime.timedelta,
    timeout: float | None = _default_conn_timeout,
) -> str:
    """Delete stopped games whose ``stop_time`` is older than ``game_expiration``.

    A row is deleted when it has a ``stop_time`` and
    ``stop_time + game_expiration`` is earlier than ``NOW()``. Rows without
    a ``stop_time`` are kept.

    Returns the asyncpg status string of the ``DELETE`` command.
    """
    # The condition previously read "stop_time IS NOT NULL OR NOW() > ...".
    # Because the second operand is NULL whenever the first is false, that
    # was equivalent to "stop_time IS NOT NULL": every stopped game was
    # deleted immediately, and game_expiration had no effect.
    return await conn.execute(
        """
        DELETE
        FROM "game"
        WHERE stop_time IS NOT NULL
          AND NOW() > (stop_time + $1);
        """,
        game_expiration,
        timeout=timeout,
    )

189 

190 

191# TODO: add the rest of cols here if needed? 

async def select_game_server_api_key(
    conn: Connection,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    timeout: float | None = _default_conn_timeout,
) -> Record | None:
    """Fetch the ``game_server_api_key`` row for an address/port pair.

    Returns the raw asyncpg record (all columns), or ``None`` if no row
    matches.
    """
    sql = """
        SELECT *
        FROM "game_server_api_key"
        WHERE game_server_address = $1
          AND game_server_port = $2;
        """
    row = await conn.fetchrow(
        sql,
        game_server_address,
        game_server_port,
        timeout=timeout,
    )
    return row

209 

210 

async def select_game_server_api_keys(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> list[Record]:
    """Return every row of ``game_server_api_key`` as raw asyncpg records."""
    sql = """
        SELECT *
        FROM "game_server_api_key";
        """
    rows = await conn.fetch(sql, timeout=timeout)
    return rows

222 

223 

async def insert_game_server_api_key(
    conn: Connection,
    issued_at: datetime.datetime,
    expires_at: datetime.datetime,
    token_hash: bytes,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    name: str | None = None,
    timeout: float | None = _default_conn_timeout,
):
    """Insert a new row into ``game_server_api_key``.

    ``issued_at`` is stored in the ``created_at`` column, and
    ``token_hash`` in ``api_key_hash``.
    """
    sql = """
        INSERT INTO "game_server_api_key"
            (created_at, expires_at, api_key_hash, game_server_address, game_server_port, name)
        VALUES ($1, $2, $3, $4, $5, $6);
        """
    bind_values = (
        issued_at,
        expires_at,
        token_hash,
        game_server_address,
        game_server_port,
        name,
    )
    await conn.execute(sql, *bind_values, timeout=timeout)

248 

249 

async def game_exists(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Return ``True`` if a ``game`` row with id ``game_id`` exists."""
    found = await conn.fetchval(
        """
        SELECT 1
        FROM "game"
        WHERE id = $1;
        """,
        game_id,
        timeout=timeout,
    )
    return found is not None

264 

265 

async def delete_old_api_keys(
    conn: Connection,
    leeway: datetime.timedelta,
    timeout: float | None = _default_conn_timeout,
) -> str:
    """Delete API keys whose ``expires_at + leeway`` is in the past.

    Returns the asyncpg status string of the ``DELETE`` command.
    """
    sql = """
        DELETE
        FROM "game_server_api_key"
        WHERE NOW() > (expires_at + $1);
        """
    status = await conn.execute(sql, leeway, timeout=timeout)
    return status

280 

281 

async def select_openai_query(
    conn: Connection,
    openai_response_id: str,
    timeout: float | None = _default_conn_timeout,
) -> models.OpenAIQuery | None:
    """Fetch the ``openai_query`` row with the given ``openai_response_id``.

    NOTE: for now, rows are only ever looked up by ``openai_response_id``.
    Returns a :class:`models.OpenAIQuery`, or ``None`` if no row matches.
    """
    row = await conn.fetchrow(
        """
        SELECT *
        FROM "openai_query"
        WHERE openai_response_id = $1;
        """,
        openai_response_id,
        timeout=timeout,
    )
    return models.OpenAIQuery(**row) if row else None

302 

303 

async def insert_openai_query(
    conn: Connection,
    game_id: str,
    time: datetime.datetime,
    game_server_address: ipaddress.IPv4Address,
    game_server_port: int,
    request_length: int,
    response_length: int,
    openai_response_id: str,
    timeout: float | None = _default_conn_timeout,
) -> None:
    """Record one OpenAI request/response in the ``openai_query`` table."""
    sql = """
        INSERT INTO "openai_query"
            (game_id, time, game_server_address,
             game_server_port, request_length, response_length, openai_response_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7);
        """
    bind_values = (
        game_id,
        time,
        game_server_address,
        game_server_port,
        request_length,
        response_length,
        openai_response_id,
    )
    await conn.execute(sql, *bind_values, timeout=timeout)

331 

332 

async def insert_game_chat_message(
    conn: Connection,
    game_id: str,
    message: str,
    send_time: datetime.datetime,
    sender_name: str,
    sender_team: models.Team,
    channel: models.SayType,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Insert a chat message row and return its generated ``id``.

    ``sender_team`` and ``channel`` are stored as their ``int()`` values.
    """
    sql = """
        INSERT INTO "game_chat_message"
            (message, game_id, send_time, sender_name, sender_team, channel)
        VALUES ($1, $2, $3, $4, $5, $6)
        RETURNING id;
        """
    # The enum -> int conversions happen before the query is sent.
    new_id = await conn.fetchval(
        sql,
        message,
        game_id,
        send_time,
        sender_name,
        int(sender_team),
        int(channel),
        timeout=timeout,
    )
    return new_id

361 

362 

async def insert_game_kill(
    conn: Connection,
    game_id: str,
    kill_time: datetime.datetime,
    killer_name: str,
    victim_name: str,
    killer_team: models.Team,
    victim_team: models.Team,
    damage_type: str,
    kill_distance_m: float,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Insert a kill event row and return its generated ``id``.

    Both team arguments are stored as their ``int()`` values.
    """
    sql = """
        INSERT INTO "game_kill"
            (game_id, kill_time, killer_name, victim_name, killer_team,
             victim_team, damage_type, kill_distance_m)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        RETURNING id;
        """
    new_id = await conn.fetchval(
        sql,
        game_id,
        kill_time,
        killer_name,
        victim_name,
        int(killer_team),
        int(victim_team),
        damage_type,
        kill_distance_m,
        timeout=timeout,
    )
    return new_id

396 

397 

async def delete_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
):
    """Delete the ``game_player`` row identified by ``(game_id, player_id)``."""
    sql = """
        DELETE
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """
    await conn.execute(sql, game_id, player_id, timeout=timeout)

415 

416 

async def upsert_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    name: str,
    team_index: int,
    score: float,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Insert or update the ``game_player`` row for ``(game_id, player_id)``.

    ``team_index`` is written to the ``team`` column. Returns ``True`` when
    a new row was inserted and ``False`` when an existing row was updated,
    based on the ``RETURNING (xmax = 0)`` check.
    """
    # NOTE: "SET game_id = excluded.game_id" re-assigns a conflict-key column
    # to the value it already has; it is redundant but harmless.
    sql = """
        INSERT INTO "game_player" (game_id, id, name, team, score)
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (game_id, id) DO UPDATE
            SET game_id = excluded.game_id,
                name    = excluded.name,
                team    = excluded.team,
                score   = excluded.score
        RETURNING (xmax = 0) as inserted;
        """
    was_inserted = await conn.fetchval(
        sql,
        game_id,
        player_id,
        name,
        team_index,
        score,
        timeout=timeout,
    )
    return bool(was_inserted)

445 

446 

async def select_game_players(
    conn: Connection,
    game_id: str,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GamePlayer]:
    """Return all ``game_player`` rows of ``game_id`` as :class:`models.GamePlayer`.

    Each record's columns are passed to the model unchanged.
    """
    rows = await conn.fetch(
        """
        SELECT *
        FROM "game_player"
        WHERE game_id = $1;
        """,
        game_id,
        timeout=timeout,
    )
    players = [models.GamePlayer(**row) for row in rows]
    return players

466 

467 

async def select_game_player(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
) -> models.GamePlayer | None:
    """Fetch a single player of a game, or ``None`` if it does not exist.

    The ``team`` column is converted with ``models.Team(str(value))``.
    NOTE(review): select_game_players passes the raw ``team`` column to
    ``models.GamePlayer`` instead; confirm which conversion is intended.
    """
    sql = """
        SELECT *
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """
    row = await conn.fetchrow(sql, game_id, player_id, timeout=timeout)

    if not row:
        return None

    return models.GamePlayer(
        game_id=row["game_id"],
        id=row["id"],
        name=row["name"],
        team=models.Team(str(row["team"])),
        score=row["score"],
    )

496 

497 

async def game_player_exists(
    conn: Connection,
    game_id: str,
    player_id: int,
    timeout: float | None = _default_conn_timeout,
) -> bool:
    """Return ``True`` if a ``game_player`` row ``(game_id, player_id)`` exists."""
    sql = """
        SELECT 1
        FROM "game_player"
        WHERE game_id = $1
          AND id = $2;
        """
    found = await conn.fetchval(sql, game_id, player_id, timeout=timeout)
    return found is not None

515 

516 

async def select_game_kills(
    conn: Connection,
    game_id: str | None = None,
    kill_time_from: datetime.datetime | None = None,
    limit: int | None = None,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GameKill]:
    """Select ``game_kill`` rows ordered by ascending ``id``.

    Optional filters: rows of ``game_id`` only, and rows with
    ``kill_time >= kill_time_from``. ``limit`` caps the number of rows.
    Filter values are rendered into the SQL text by pypika.
    """
    kills_table = Table(name="game_kill")
    builder = kills_table.select("*")

    if game_id is not None:
        builder = builder.where(kills_table.game_id == game_id)
    if kill_time_from is not None:
        builder = builder.where(kills_table.kill_time >= kill_time_from)
    if limit is not None:
        builder = builder.limit(limit)

    builder = builder.orderby("id", order=Order.asc)

    rows = await conn.fetch(str(builder), timeout=timeout)
    return [models.GameKill(**row) for row in rows]

543 

544 

async def select_game_chat_messages(
    conn: Connection,
    game_id: str | None = None,
    send_time_from: datetime.datetime | None = None,
    limit: int | None = None,
    timeout: float | None = _default_conn_timeout,
) -> list[models.GameChatMessage]:
    """Select ``game_chat_message`` rows ordered by ascending ``id``.

    Optional filters: rows of ``game_id`` only, and rows with
    ``send_time >= send_time_from``. ``limit`` caps the number of rows.
    Filter values are rendered into the SQL text by pypika.
    """
    messages_table = Table(name="game_chat_message")
    builder = messages_table.select("*")

    if game_id is not None:
        builder = builder.where(messages_table.game_id == game_id)
    if send_time_from is not None:
        builder = builder.where(messages_table.send_time >= send_time_from)
    if limit is not None:
        builder = builder.limit(limit)

    builder = builder.orderby("id", order=Order.asc)

    rows = await conn.fetch(str(builder), timeout=timeout)
    return [models.GameChatMessage(**row) for row in rows]

571 

572 

async def increment_steam_web_api_queries(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> None:
    """Bump ``steam_web_api_queries`` by one and set ``last_steam_web_api_query``.

    The UPDATE has no WHERE clause, so it touches every row of
    ``query_statistics``. The table is presumably single-row (TODO confirm).
    The statement runs inside an explicit transaction.
    """
    sql = """
        UPDATE "query_statistics"
        SET steam_web_api_queries    = steam_web_api_queries + 1,
            last_steam_web_api_query = NOW();
        """
    tx = conn.transaction()
    async with tx:
        await conn.execute(sql, timeout=timeout)

586 

587 

async def select_steam_web_api_queries(
    conn: Connection,
    timeout: float | None = _default_conn_timeout,
) -> int:
    """Return ``steam_web_api_queries`` from ``query_statistics``.

    Returns ``0`` when the table has no row.
    """
    sql = """
        SELECT steam_web_api_queries
        FROM "query_statistics";
        """
    row = await conn.fetchrow(sql, timeout=timeout)
    return row["steam_web_api_queries"] if row else 0