Skip to content

Commit 2ff5c46

Browse files
authored
♻️Director v2: remove aiopg usage (#7576)
1 parent e653a4a commit 2ff5c46

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+744
-712
lines changed

.github/copilot-instructions.md

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ This document provides guidelines and best practices for using GitHub Copilot in
2222
- ensure we use `pydantic` >2 compatible code.
2323
- ensure we use `fastapi` >0.100 compatible code
2424
- use f-string formatting
25+
- Only add comments in function if strictly necessary
2526

2627

2728
### Json serialization

packages/common-library/src/common_library/async_tools.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import functools
33
from collections.abc import Awaitable, Callable
44
from concurrent.futures import Executor
5-
from typing import ParamSpec, TypeVar
5+
from inspect import isawaitable
6+
from typing import ParamSpec, TypeVar, overload
67

78
R = TypeVar("R")
89
P = ParamSpec("P")
@@ -22,3 +23,42 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
2223
return wrapper
2324

2425
return decorator
26+
27+
28+
_AwaitableResult = TypeVar("_AwaitableResult")
29+
30+
31+
@overload
32+
async def maybe_await(obj: Awaitable[_AwaitableResult]) -> _AwaitableResult: ...
33+
34+
35+
@overload
36+
async def maybe_await(obj: _AwaitableResult) -> _AwaitableResult: ...
37+
38+
39+
async def maybe_await(
40+
obj: Awaitable[_AwaitableResult] | _AwaitableResult,
41+
) -> _AwaitableResult:
42+
"""Helper function to handle both async and sync database results.
43+
44+
This function allows code to work with both aiopg (async) and asyncpg (sync) result methods
45+
by automatically detecting and handling both cases.
46+
47+
Args:
48+
obj: Either an awaitable coroutine or direct result value
49+
50+
Returns:
51+
The result value, after awaiting if necessary
52+
53+
Example:
54+
```python
55+
result = await conn.execute(query)
56+
# Works with both aiopg and asyncpg
57+
row = await maybe_await(result.fetchone())
58+
```
59+
"""
60+
if isawaitable(obj):
61+
assert isawaitable(obj) # nosec
62+
return await obj
63+
assert not isawaitable(obj) # nosec
64+
return obj

packages/common-library/tests/test_async_tools.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
from concurrent.futures import ThreadPoolExecutor
3+
from typing import Any
34

45
import pytest
5-
from common_library.async_tools import make_async
6+
from common_library.async_tools import make_async, maybe_await
67

78

89
@make_async()
@@ -43,3 +44,52 @@ def heavy_computation(x: int) -> int:
4344

4445
result = await heavy_computation(4)
4546
assert result == 16, "Function should return 16"
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_maybe_await_with_coroutine():
51+
"""Test maybe_await with an async function"""
52+
53+
async def async_value():
54+
return 42
55+
56+
result = await maybe_await(async_value())
57+
assert result == 42
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_maybe_await_with_direct_value():
62+
"""Test maybe_await with a direct value"""
63+
value = 42
64+
result = await maybe_await(value)
65+
assert result == 42
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_maybe_await_with_none():
70+
"""Test maybe_await with None value"""
71+
result = await maybe_await(None)
72+
assert result is None
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_maybe_await_with_result_proxy():
77+
"""Test maybe_await with both async and sync ResultProxy implementations"""
78+
79+
class AsyncResultProxy:
80+
"""Mock async result proxy (aiopg style)"""
81+
82+
async def fetchone(self) -> Any: # pylint: disable=no-self-use
83+
return {"id": 1, "name": "test"}
84+
85+
class SyncResultProxy:
86+
"""Mock sync result proxy (asyncpg style)"""
87+
88+
def fetchone(self) -> Any: # pylint: disable=no-self-use
89+
return {"id": 2, "name": "test2"}
90+
91+
async_result = await maybe_await(AsyncResultProxy().fetchone())
92+
assert async_result == {"id": 1, "name": "test"}
93+
94+
sync_result = await maybe_await(SyncResultProxy().fetchone())
95+
assert sync_result == {"id": 2, "name": "test2"}
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,49 @@
11
"""Common protocols to annotate equivalent connections:
2-
- sqlalchemy.ext.asyncio.AsyncConnection
3-
- aiopg.sa.connection.SAConnection
2+
- sqlalchemy.ext.asyncio.AsyncConnection
3+
- aiopg.sa.connection.SAConnection
44
5-
6-
Purpose: to reduce dependency wit aiopg (expected full migration to asyncpg)
5+
Purpose: to reduce dependency with aiopg (expected full migration to asyncpg)
76
"""
87

9-
from typing import Protocol
8+
from collections.abc import Awaitable
9+
from typing import Any, Protocol, TypeAlias, TypeVar
1010

11+
from sqlalchemy.sql import Executable
1112

12-
class DBConnection(Protocol):
13-
# Prototype to account for aiopg and asyncio connection classes, i.e.
14-
# from aiopg.sa.connection import SAConnection
15-
# from sqlalchemy.ext.asyncio import AsyncConnection
16-
async def scalar(self, *args, **kwargs):
17-
...
13+
# Type for query results
14+
Result = TypeVar("Result")
1815

19-
async def execute(self, *args, **kwargs):
20-
...
16+
# Type alias for methods that can be either async or sync
17+
MaybeCoro: TypeAlias = Awaitable[Result] | Result
2118

22-
async def begin(self):
23-
...
2419

20+
class ResultProxy(Protocol):
21+
"""Protocol for query result objects from both engines
2522
26-
class AiopgConnection(Protocol):
27-
# Prototype to account for aiopg-only (this protocol avoids import <-> installation)
28-
async def scalar(self, *args, **kwargs):
29-
...
23+
Handles both aiopg's async methods and SQLAlchemy asyncpg's sync methods.
24+
This is temporary until we fully migrate to asyncpg.
25+
"""
3026

31-
async def execute(self, *args, **kwargs):
32-
...
27+
def fetchall(self) -> MaybeCoro[list[Any]]: ...
28+
def fetchone(self) -> MaybeCoro[Any | None]: ...
29+
def first(self) -> MaybeCoro[Any | None]: ...
3330

34-
async def begin(self):
35-
...
31+
32+
class DBConnection(Protocol):
33+
"""Protocol to account for both aiopg and SQLAlchemy async connections"""
34+
35+
async def scalar(
36+
self,
37+
statement: Executable,
38+
parameters: dict[str, Any] | None = None,
39+
*,
40+
execution_options: dict[str, Any] | None = None,
41+
) -> Any: ...
42+
43+
async def execute(
44+
self,
45+
statement: Executable,
46+
parameters: dict[str, Any] | None = None,
47+
*,
48+
execution_options: dict[str, Any] | None = None,
49+
) -> ResultProxy: ...

packages/postgres-database/src/simcore_postgres_database/utils_aiosqlalchemy.py

+51
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
from typing import Any, TypeAlias, TypeVar
2+
13
import sqlalchemy as sa
4+
import sqlalchemy.exc as sql_exc
5+
from common_library.errors_classes import OsparcErrorMixin
6+
from sqlalchemy.dialects.postgresql.asyncpg import AsyncAdapt_asyncpg_dbapi
27
from sqlalchemy.ext.asyncio import AsyncEngine
38

49
from .utils_migration import get_current_head
@@ -29,3 +34,49 @@ async def raise_if_migration_not_ready(engine: AsyncEngine) -> None:
2934
if version_num != head_version_num:
3035
msg = f"Migration is incomplete, expected {head_version_num} but got {version_num}"
3136
raise DBMigrationError(msg)
37+
38+
39+
AsyncpgSQLState: TypeAlias = str
40+
ErrorT = TypeVar("ErrorT", bound=OsparcErrorMixin)
41+
ErrorKwars: TypeAlias = dict[str, Any]
42+
43+
44+
def map_db_exception(
45+
exception: Exception,
46+
exception_map: dict[AsyncpgSQLState, tuple[type[ErrorT], ErrorKwars]],
47+
default_exception: type[ErrorT] | None = None,
48+
) -> ErrorT | Exception:
49+
"""Maps SQLAlchemy database exceptions to domain-specific exceptions.
50+
51+
This function inspects SQLAlchemy and asyncpg exceptions to identify the error type
52+
by checking pgcodes or error messages, and converts them to appropriate domain exceptions.
53+
54+
Args:
55+
exception: The original exception from SQLAlchemy or the database driver
56+
exception_map: Dictionary mapping pgcode
57+
default_exception: Exception class to use if no matching error is found
58+
59+
Returns:
60+
Domain-specific exception instance or the original exception if no mapping found
61+
and no default_exception provided
62+
"""
63+
pgcode = None
64+
65+
# Handle SQLAlchemy wrapped exceptions
66+
if isinstance(exception, sql_exc.IntegrityError) and hasattr(exception, "orig"):
67+
orig_error = exception.orig
68+
# Handle asyncpg adapter exceptions
69+
if isinstance(orig_error, AsyncAdapt_asyncpg_dbapi.IntegrityError) and hasattr(
70+
orig_error, "pgcode"
71+
):
72+
assert hasattr(orig_error, "pgcode") # nosec
73+
pgcode = orig_error.pgcode
74+
75+
# Match by pgcode if available
76+
if pgcode:
77+
for key, (exc_class, params) in exception_map.items():
78+
if key == pgcode:
79+
return exc_class(**params)
80+
81+
# If no match found, return default exception or original
82+
return default_exception() if default_exception else exception

packages/postgres-database/src/simcore_postgres_database/utils_groups_extra_properties.py

+10-50
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import datetime
22
import logging
33
import warnings
4+
from collections.abc import Callable
45
from dataclasses import dataclass, fields
5-
from typing import Any, Callable
6+
from typing import Any
67

78
import sqlalchemy as sa
8-
from aiopg.sa.connection import SAConnection
9-
from aiopg.sa.result import RowProxy
9+
from common_library.async_tools import maybe_await
1010
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
1111

12+
from ._protocols import DBConnection
1213
from .models.groups import GroupType, groups, user_to_groups
1314
from .models.groups_extra_properties import groups_extra_properties
1415
from .utils_models import FromRowMixin
@@ -22,12 +23,10 @@
2223
)
2324

2425

25-
class GroupExtraPropertiesError(Exception):
26-
...
26+
class GroupExtraPropertiesError(Exception): ...
2727

2828

29-
class GroupExtraPropertiesNotFoundError(GroupExtraPropertiesError):
30-
...
29+
class GroupExtraPropertiesNotFoundError(GroupExtraPropertiesError): ...
3130

3231

3332
@dataclass(frozen=True, slots=True, kw_only=True)
@@ -99,24 +98,6 @@ def _get_stmt(gid: int, product_name: str):
9998
& (groups_extra_properties.c.product_name == product_name)
10099
)
101100

102-
@staticmethod
103-
async def get(
104-
connection: SAConnection, *, gid: int, product_name: str
105-
) -> GroupExtraProperties:
106-
warnings.warn(
107-
_WARNING_FMSG.format("get", "get_v2"),
108-
DeprecationWarning,
109-
stacklevel=1,
110-
)
111-
112-
query = GroupExtraPropertiesRepo._get_stmt(gid, product_name)
113-
result = await connection.execute(query)
114-
assert result # nosec
115-
if row := await result.first():
116-
return GroupExtraProperties.from_row_proxy(row)
117-
msg = f"Properties for group {gid} not found"
118-
raise GroupExtraPropertiesNotFoundError(msg)
119-
120101
@staticmethod
121102
async def get_v2(
122103
engine: AsyncEngine,
@@ -174,7 +155,7 @@ def _aggregate(
174155

175156
@staticmethod
176157
async def get_aggregated_properties_for_user(
177-
connection: SAConnection,
158+
connection: DBConnection,
178159
*,
179160
user_id: int,
180161
product_name: str,
@@ -197,30 +178,9 @@ async def get_aggregated_properties_for_user(
197178
)
198179
assert result # nosec
199180

200-
rows: list[RowProxy] | None = await result.fetchall()
201-
assert rows is not None # nosec
181+
rows = await maybe_await(result.fetchall())
182+
assert isinstance(rows, list) # nosec
202183

203184
return GroupExtraPropertiesRepo._aggregate(
204-
rows, user_id, product_name, GroupExtraProperties.from_row_proxy
185+
rows, user_id, product_name, GroupExtraProperties.from_row
205186
)
206-
207-
@staticmethod
208-
async def get_aggregated_properties_for_user_v2(
209-
engine: AsyncEngine,
210-
connection: AsyncConnection | None = None,
211-
*,
212-
user_id: int,
213-
product_name: str,
214-
) -> GroupExtraProperties:
215-
async with pass_or_acquire_connection(engine, connection) as conn:
216-
217-
list_stmt = _list_table_entries_ordered_by_group_type_stmt(
218-
user_id=user_id, product_name=product_name
219-
)
220-
result = await conn.stream(
221-
sa.select(list_stmt).order_by(list_stmt.c.type_order)
222-
)
223-
rows = [row async for row in result]
224-
return GroupExtraPropertiesRepo._aggregate(
225-
rows, user_id, product_name, GroupExtraProperties.from_row
226-
)

0 commit comments

Comments
 (0)