Skip to content

275 Remove async driver #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 124 additions & 84 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ packages = [{ include = "sbl_filing_api", from = "src" }]
python = ">=3.12,<4"
sqlalchemy = "^2.0.35"
psycopg2-binary = "^2.9.9"
asyncpg = "^0.30.0"
regtech-api-commons = {git = "https://github.com/cfpb/regtech-api-commons.git"}
regtech-data-validator = {git = "https://github.com/cfpb/regtech-data-validator.git"}
regtech-regex = {git = "https://github.com/cfpb/regtech-regex.git"}
Expand All @@ -27,7 +26,6 @@ pytest-mock = "^3.12.0"
pytest-env = "^1.1.5"
pytest-alembic = "^0.11.1"
pytest-asyncio = "^0.25.3"
aiosqlite = "^0.20.0"


[tool.poetry.group.linters.dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ DB_USER=
DB_PWD=
DB_HOST=
DB_SCHEMA=
# DB_SCHEME= can be used to override postgresql+asyncpg if needed
# DB_SCHEME= can be used to override postgresql+psycopg2 if needed
KC_URL=
KC_REALM=
KC_ADMIN_CLIENT_ID=
Expand Down
2 changes: 1 addition & 1 deletion src/sbl_filing_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Settings(BaseSettings):
db_user: str
db_pwd: str
db_host: str
db_scheme: str = "postgresql+asyncpg"
db_scheme: str = "postgresql+psycopg2"
db_logging: bool = False
conn: PostgresDsn | None = None

Expand Down
20 changes: 8 additions & 12 deletions src/sbl_filing_api/entities/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
async_scoped_session,
)
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.pool import NullPool
from asyncio import current_task
from sbl_filing_api.config import settings

engine = create_async_engine(
settings.conn.unicode_string(), echo=settings.db_logging, poolclass=NullPool
).execution_options(schema_translate_map={None: settings.db_schema})
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)
engine = create_engine(settings.conn.unicode_string(), echo=settings.db_logging, poolclass=NullPool).execution_options(
schema_translate_map={None: settings.db_schema}
)
SessionLocal = scoped_session(sessionmaker(engine, expire_on_commit=False))


async def get_session():
def get_session():
session = SessionLocal()
try:
yield session
finally:
await session.close()
session.close()
3 changes: 1 addition & 2 deletions src/sbl_filing_api/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from sqlalchemy import Enum as SAEnum, String, desc
from sqlalchemy import ForeignKey, func, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.types import JSON


class Base(AsyncAttrs, DeclarativeBase):
class Base(DeclarativeBase):
pass


Expand Down
136 changes: 67 additions & 69 deletions src/sbl_filing_api/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging

from sqlalchemy import select, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from typing import Any, List, TypeVar
from sbl_filing_api.entities.engine.engine import SessionLocal

from regtech_api_commons.models.auth import AuthenticatedUser

from async_lru import alru_cache
from functools import lru_cache

from sbl_filing_api.entities.models.dao import (
SubmissionDAO,
Expand All @@ -31,73 +31,73 @@ class NoFilingPeriodException(Exception):
pass


async def get_submissions(session: AsyncSession, lei: str = None, filing_period: str = None) -> List[SubmissionDAO]:
def get_submissions(session: Session, lei: str = None, filing_period: str = None) -> List[SubmissionDAO]:
filing_id = None
if lei and filing_period:
filing = await get_filing(session, lei=lei, filing_period=filing_period)
filing = get_filing(session, lei=lei, filing_period=filing_period)
filing_id = filing.id
return await query_helper(session, SubmissionDAO, filing=filing_id)
return query_helper(session, SubmissionDAO, filing=filing_id)


async def get_latest_submission(session: AsyncSession, lei: str, filing_period: str) -> SubmissionDAO | None:
filing = await get_filing(session, lei=lei, filing_period=filing_period)
def get_latest_submission(session: Session, lei: str, filing_period: str) -> SubmissionDAO | None:
filing = get_filing(session, lei=lei, filing_period=filing_period)
stmt = select(SubmissionDAO).filter_by(filing=filing.id).order_by(desc(SubmissionDAO.submission_time)).limit(1)
return await session.scalar(stmt)
return session.scalar(stmt)


async def get_filing_periods(session: AsyncSession) -> List[FilingPeriodDAO]:
return await query_helper(session, FilingPeriodDAO)
def get_filing_periods(session: Session) -> List[FilingPeriodDAO]:
return query_helper(session, FilingPeriodDAO)


async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO:
result = await query_helper(session, SubmissionDAO, id=submission_id)
def get_submission(session: Session, submission_id: int) -> SubmissionDAO:
result = query_helper(session, SubmissionDAO, id=submission_id)
return result[0] if result else None


async def get_submission_by_counter(session: AsyncSession, lei: str, filing_period: str, counter: int) -> SubmissionDAO:
filing = await get_filing(session, lei=lei, filing_period=filing_period)
result = await query_helper(session, SubmissionDAO, filing=filing.id, counter=counter)
def get_submission_by_counter(session: Session, lei: str, filing_period: str, counter: int) -> SubmissionDAO:
filing = get_filing(session, lei=lei, filing_period=filing_period)
result = query_helper(session, SubmissionDAO, filing=filing.id, counter=counter)
return result[0] if result else None


async def get_filing(session: AsyncSession, lei: str, filing_period: str) -> FilingDAO:
result = await query_helper(session, FilingDAO, lei=lei, filing_period=filing_period)
def get_filing(session: Session, lei: str, filing_period: str) -> FilingDAO:
result = query_helper(session, FilingDAO, lei=lei, filing_period=filing_period)
return result[0] if result else None


async def get_filings(session: AsyncSession, leis: list[str], filing_period: str) -> list[FilingDAO]:
def get_filings(session: Session, leis: list[str], filing_period: str) -> list[FilingDAO]:
stmt = select(FilingDAO).filter(FilingDAO.lei.in_(leis), FilingDAO.filing_period == filing_period)
result = (await session.scalars(stmt)).all()
result = (session.scalars(stmt)).all()
return result if result else []


async def get_period_filings(session: AsyncSession, filing_period: str) -> List[FilingDAO]:
filings = await query_helper(session, FilingDAO, filing_period=filing_period)
def get_period_filings(session: Session, filing_period: str) -> List[FilingDAO]:
filings = query_helper(session, FilingDAO, filing_period=filing_period)
return filings


async def get_filing_period(session: AsyncSession, filing_period: str) -> FilingPeriodDAO:
result = await query_helper(session, FilingPeriodDAO, code=filing_period)
def get_filing_period(session: Session, filing_period: str) -> FilingPeriodDAO:
result = query_helper(session, FilingPeriodDAO, code=filing_period)
return result[0] if result else None


@alru_cache(maxsize=128)
async def get_filing_tasks(session: AsyncSession) -> List[FilingTaskDAO]:
return await query_helper(session, FilingTaskDAO)
@lru_cache(maxsize=128)
def get_filing_tasks(session: Session) -> List[FilingTaskDAO]:
return query_helper(session, FilingTaskDAO)


async def get_user_action(session: AsyncSession, id: int) -> UserActionDAO:
result = await query_helper(session, UserActionDAO, id=id)
def get_user_action(session: Session, id: int) -> UserActionDAO:
result = query_helper(session, UserActionDAO, id=id)
return result[0] if result else None


async def get_user_actions(session: AsyncSession) -> List[UserActionDAO]:
return await query_helper(session, UserActionDAO)
def get_user_actions(session: Session) -> List[UserActionDAO]:
return query_helper(session, UserActionDAO)


async def add_submission(session: AsyncSession, filing_id: int, filename: str, submitter_id: int) -> SubmissionDAO:
def add_submission(session: Session, filing_id: int, filename: str, submitter_id: int) -> SubmissionDAO:
stmt = select(SubmissionDAO).filter_by(filing=filing_id).order_by(desc(SubmissionDAO.counter)).limit(1)
last_sub = await session.scalar(stmt)
last_sub = session.scalar(stmt)
current_count = last_sub.counter if last_sub else 0
new_sub = SubmissionDAO(
filing=filing_id,
Expand All @@ -107,93 +107,91 @@ async def add_submission(session: AsyncSession, filing_id: int, filename: str, s
counter=(current_count + 1),
)
# this returns the attached object, most importantly with the new submission id
new_sub = await session.merge(new_sub)
await session.commit()
new_sub = session.merge(new_sub)
session.commit()
return new_sub


async def update_submission(session: AsyncSession, submission: SubmissionDAO) -> SubmissionDAO:
return await upsert_helper(session, submission, SubmissionDAO)
def update_submission(session: Session, submission: SubmissionDAO) -> SubmissionDAO:
return upsert_helper(session, submission, SubmissionDAO)


async def expire_submission(submission_id: int):
async with SessionLocal() as session:
submission = await get_submission(session, submission_id)
def expire_submission(submission_id: int):
with SessionLocal() as session:
submission = get_submission(session, submission_id)
submission.state = SubmissionState.VALIDATION_EXPIRED
await upsert_helper(session, submission, SubmissionDAO)
upsert_helper(session, submission, SubmissionDAO)


async def error_out_submission(submission_id: int):
async with SessionLocal() as session:
submission = await get_submission(session, submission_id)
def error_out_submission(submission_id: int):
with SessionLocal() as session:
submission = get_submission(session, submission_id)
submission.state = SubmissionState.VALIDATION_ERROR
await upsert_helper(session, submission, SubmissionDAO)
upsert_helper(session, submission, SubmissionDAO)


async def upsert_filing_period(session: AsyncSession, filing_period: FilingPeriodDTO) -> FilingPeriodDAO:
return await upsert_helper(session, filing_period, FilingPeriodDAO)
def upsert_filing_period(session: Session, filing_period: FilingPeriodDTO) -> FilingPeriodDAO:
return upsert_helper(session, filing_period, FilingPeriodDAO)


async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO:
return await upsert_helper(session, filing, FilingDAO)
def upsert_filing(session: Session, filing: FilingDTO) -> FilingDAO:
return upsert_helper(session, filing, FilingDAO)


async def create_new_filing(session: AsyncSession, lei: str, filing_period: str, creator_id: int) -> FilingDAO:
def create_new_filing(session: Session, lei: str, filing_period: str, creator_id: int) -> FilingDAO:
new_filing = FilingDAO(filing_period=filing_period, lei=lei, creator_id=creator_id)
return await upsert_helper(session, new_filing, FilingDAO)
return upsert_helper(session, new_filing, FilingDAO)


async def update_task_state(
session: AsyncSession, lei: str, filing_period: str, task_name: str, state: FilingTaskState, user: AuthenticatedUser
def update_task_state(
session: Session, lei: str, filing_period: str, task_name: str, state: FilingTaskState, user: AuthenticatedUser
):
filing = await get_filing(session, lei=lei, filing_period=filing_period)
found_task = await query_helper(session, FilingTaskProgressDAO, filing=filing.id, task_name=task_name)
filing = get_filing(session, lei=lei, filing_period=filing_period)
found_task = query_helper(session, FilingTaskProgressDAO, filing=filing.id, task_name=task_name)
if found_task:
task = found_task[0] # should only be one
task.state = state
task.user = user.username
else:
task = FilingTaskProgressDAO(filing=filing.id, state=state, task_name=task_name, user=user.username)
await upsert_helper(session, task, FilingTaskProgressDAO)
upsert_helper(session, task, FilingTaskProgressDAO)


async def update_contact_info(
session: AsyncSession, lei: str, filing_period: str, new_contact_info: ContactInfoDTO
) -> FilingDAO:
filing = await get_filing(session, lei=lei, filing_period=filing_period)
def update_contact_info(session: Session, lei: str, filing_period: str, new_contact_info: ContactInfoDTO) -> FilingDAO:
filing = get_filing(session, lei=lei, filing_period=filing_period)
if filing.contact_info:
for key, value in new_contact_info.__dict__.items():
if key != "id":
setattr(filing.contact_info, key, value)
else:
filing.contact_info = ContactInfoDAO(**new_contact_info.__dict__.copy(), filing=filing.id)
return await upsert_helper(session, filing, FilingDAO)
return upsert_helper(session, filing, FilingDAO)


async def add_user_action(
session: AsyncSession,
def add_user_action(
session: Session,
new_user_action: UserActionDTO,
) -> UserActionDAO:
return await upsert_helper(session, new_user_action, UserActionDAO)
return upsert_helper(session, new_user_action, UserActionDAO)


async def upsert_helper(session: AsyncSession, original_data: Any, table_obj: T) -> T:
def upsert_helper(session: Session, original_data: Any, table_obj: T) -> T:
copy_data = original_data.__dict__.copy()
# this is only for if a DAO is passed in
# Should be DTOs, but hey, it's python
if "_sa_instance_state" in copy_data:
del copy_data["_sa_instance_state"]
new_dao = table_obj(**copy_data)
new_dao = await session.merge(new_dao)
await session.commit()
await session.refresh(new_dao)
new_dao = session.merge(new_dao)
session.commit()
session.refresh(new_dao)
return new_dao


async def query_helper(session: AsyncSession, table_obj: T, **filter_args) -> List[T]:
def query_helper(session: Session, table_obj: T, **filter_args) -> List[T]:
stmt = select(table_obj)
# remove empty args
filter_args = {k: v for k, v in filter_args.items() if v is not None}
if filter_args:
stmt = stmt.filter_by(**filter_args)
return (await session.scalars(stmt)).all()
return (session.scalars(stmt)).all()
Loading
Loading