Source code for ezcord.sql.postgresql

from __future__ import annotations

import contextlib
import json
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

import asyncpg


def _process_args(args) -> tuple:
    """If SQL query parameters are passed as a tuple instead of single values,
    the tuple will be unpacked.
    """
    if len(args) == 1 and isinstance(args, tuple):
        if isinstance(args[0], tuple):
            args = args[0]

    # convert dict to str
    return tuple(json.dumps(arg) if isinstance(arg, dict) else arg for arg in args)


def _process_one_result(row: asyncpg.Record, default: Any):
    row = row[0] if row is not None and len(row) == 1 else row
    return row if row is not None else default


def _process_exec_status(status: str) -> QueryStatus:
    """Convert a command status string into a :class:`QueryStatus`.

    The status is split on spaces. The first token plus every directly
    following non-numeric token form the query type (``"CREATE TABLE"``,
    ``"CREATE MATERIALIZED VIEW"``, ...). The remaining tokens are the counts:

    - one count   -> ``rowcount``
    - two counts  -> ``rowcount`` and ``inserts`` (in that order)
    - otherwise   -> only the query type is set

    Parameters
    ----------
    status:
        The status string returned by the ``execute`` call of the connection.

    Returns
    -------
    The parsed :class:`QueryStatus`.
    """
    tokens = status.split(" ")

    # The first token always belongs to the command name. Keep extending the
    # name while the following tokens are not numeric, so that command names
    # with three or more words do not end up being parsed as a count.
    name_len = 1
    while name_len < len(tokens) and not tokens[name_len].isdigit():
        name_len += 1

    query_type = " ".join(tokens[:name_len])

    try:
        counts = [int(token) for token in tokens[name_len:]]
    except ValueError:
        # Unexpected trailing text: report the type without any counts
        # instead of failing the whole query call.
        return QueryStatus(type=query_type)

    if len(counts) == 1:
        return QueryStatus(type=query_type, rowcount=counts[0])
    if len(counts) == 2:
        return QueryStatus(type=query_type, rowcount=counts[0], inserts=counts[1])
    return QueryStatus(type=query_type)


@dataclass
class QueryStatus:
    """A class to access the status of a :meth:`PGHandler.exec` call.

    Attributes
    ----------
    type:
        The command name taken from the status string, for example
        ``INSERT`` or ``CREATE TABLE``.
    rowcount:
        The first number reported in the status string. ``0`` if the status
        contains no number.
    inserts:
        The second number reported in the status string. ``0`` if the status
        contains fewer than two numbers.
    """

    type: str
    rowcount: int = 0
    inserts: int = 0
class EzConnection(asyncpg.Connection):
    """A subclass of :class:`asyncpg.Connection` that adds aliases
    to be compatible with the sqlite handler.

    All query methods run their positional parameters through
    ``_process_args`` first, so parameters may be passed as a single tuple
    and ``dict`` values are sent as JSON strings.
    """

    async def one(self, sql: str, *args, default=None, **kwargs):
        """Fetch the first row of the query.

        A row with a single column is reduced to that column's value.
        ``default`` is returned when there is no row or that value is ``None``.
        """
        row = await super().fetchrow(sql, *_process_args(args), **kwargs)
        return _process_one_result(row, default)

    async def all(self, sql: str, *args, **kwargs) -> list:
        """Fetch all rows of the query.

        If the rows consist of a single column, a flat list with the column
        values is returned instead of a list of records.
        """
        result = await super().fetch(sql, *_process_args(args), **kwargs)
        if result and len(result[0]) == 1:
            return [row[0] for row in result]
        return result

    async def fetchval(self, sql: str, *args, default=None, **kwargs):
        """Fetch a single value.

        ``default`` is returned only when the fetched value is ``None``.
        """
        value = await super().fetchval(sql, *_process_args(args), **kwargs)
        # Compare against None explicitly: falsy results such as 0, False or
        # an empty string are valid values and must not be replaced.
        return default if value is None else value

    async def exec(self, sql: str, *args, **kwargs) -> QueryStatus:
        """Execute a query and return its parsed status."""
        status = await super().execute(sql, *_process_args(args), **kwargs)
        return _process_exec_status(status)

    async def execute(self, *args, **kwargs) -> QueryStatus:
        """Alias for :meth:`exec`."""
        return await self.exec(*args, **kwargs)
class PGHandler:
    """A class that provides helper methods for PostgreSQL databases.

    .. note::
        It's recommended to set the database connection parameters in the ``.env`` file.

        - Reference: https://www.postgresql.org/docs/current/libpq-envars.html

    Parameters
    ----------
    custom_pool:
        Override the default connection pool with a key. Each custom pool has a unique key.
        Defaults to ``None``.
    auto_setup:
        Whether to call :meth:`setup` when the first instance of this class is created.
        Defaults to ``True``.
    **kwargs:
        Keyword arguments for :func:`asyncpg.create_pool`.
    """

    # Default pool, shared by every handler that does not use a custom pool.
    pool: asyncpg.Pool | None = None
    # Custom pools by key. A key is registered with ``None`` until its pool is created.
    _pools: dict[str, asyncpg.Pool | None] = {}
    # Instances created with ``auto_setup=True``.
    _auto_setup: list[PGHandler] = []
    # Instances created with a custom pool key.
    _auto_pool: list[PGHandler] = []

    def __init__(
        self,
        *,
        custom_pool: str | None = None,
        auto_setup: bool = True,
        **kwargs,
    ):
        self.kwargs = kwargs
        self.custom_pool = custom_pool

        if auto_setup and self not in self._auto_setup:
            PGHandler._auto_setup.append(self)

        if custom_pool:
            if self not in self._auto_pool:
                PGHandler._auto_pool.append(self)
            # Reserve the key; the pool itself is created on first use.
            if custom_pool not in PGHandler._pools:
                PGHandler._pools[custom_pool] = None

    async def _check_pool(self) -> asyncpg.Pool:
        """Create a new connection pool or returns an existing one.

        Custom pools are stored in :attr:`_pools`. If a custom pool for a specified key
        already exists, it will be returned and set as the pool for the current class instance.
        """
        if self.custom_pool:
            existing = self._pools[self.custom_pool]
            if existing:
                self.pool = existing
                return existing
        elif PGHandler.pool is not None:
            return PGHandler.pool

        # NOTE(review): no lock is held while the pool is created, so concurrent
        # first calls may each create a pool - confirm whether that can happen.
        pool = await asyncpg.create_pool(connection_class=EzConnection, **self.kwargs)

        if self.custom_pool:
            PGHandler._pools[self.custom_pool] = pool
            self.pool = pool
        else:
            PGHandler.pool = pool
        return pool

    @contextlib.asynccontextmanager
    async def transaction(self):
        """Async context manager that provides a connection with an active transaction.

        Yields
        ------
        :class:`EzConnection`
            A connection object with an active transaction.

        Example
        -------
        ::

            async with self.transaction() as con:
                await con.exec("INSERT INTO users VALUES (...)")
                await con.exec("UPDATE users SET ...")
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            async with con.transaction():
                yield con

    async def one(self, sql: str, *args, default=None, **kwargs):
        """Returns one result record. If no record is found, ``None`` is returned.

        Parameters
        ----------
        sql:
            The SQL query to execute.
        *args:
            Arguments for the query.
        default:
            When the query returns no results, this value will be returned
            instead of ``None``.

        Returns
        -------
        The result record or ``None``.
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            return await con.one(sql, *args, default=default, **kwargs)

    async def all(self, sql: str, *args, **kwargs) -> list:
        """Returns all result records.

        Parameters
        ----------
        sql:
            The SQL query to execute.
        *args:
            Arguments for the query.

        Returns
        -------
        A list of result records.
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            return await con.all(sql, *args, **kwargs)

    async def fetchval(self, sql: str, *args, default=None, **kwargs):
        """Returns one value.

        Parameters
        ----------
        sql:
            The SQL query to execute.
        *args:
            Arguments for the query.
        default:
            When the query returns no results, this value will be returned
            instead of ``None``.

        Returns
        -------
        The value or ``None``.
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            return await con.fetchval(sql, *args, default=default, **kwargs)

    async def exec(self, sql: str, *args, **kwargs) -> QueryStatus:
        """Executes a SQL query.

        Parameters
        ----------
        sql:
            The SQL query to execute.
        *args:
            Arguments for the query.

        Returns
        -------
        :class:`QueryStatus`
            The status of the executed query.
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            return await con.exec(sql, *args, **kwargs)

    async def execute(self, sql: str, *args, **kwargs) -> QueryStatus:
        """Alias for :meth:`exec`."""
        return await self.exec(sql, *args, **kwargs)

    async def executemany(self, sql: str, args: Iterable[Iterable[Any]], **kwargs) -> None:
        """Executes a SQL multiquery.

        Parameters
        ----------
        sql:
            The multiquery to execute.
        args:
            An iterable that contains one sequence of arguments per execution.
        **kwargs:
            Keyword arguments that are passed on to the ``executemany`` call
            of the connection.
        """
        pool = await self._check_pool()
        async with pool.acquire() as con:
            # The connection's result is passed through unchanged.
            return await con.executemany(sql, args, **kwargs)