Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dispatch/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy import engine_from_config, pool

from dispatch.config import SQLALCHEMY_DATABASE_URI
from dispatch.database import Base
from dispatch.database.core import Base

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/alembic/versions/b6da2dad0396_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from alembic import op
import sqlalchemy as sa

from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.incident_cost_type.models import IncidentCostTypeUpdate
from dispatch.incident_cost_type import service as incident_cost_type_service

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlalchemy import Column, String, Binary, Integer
from sqlalchemy_utils import TSVectorType

from dispatch.database import Base
from dispatch.database.core import Base
from dispatch.models import TimeStampMixin, DispatchBase
from dispatch.enums import UserRoles

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from dispatch.database import get_db
from dispatch.database.core import get_db

from dispatch.plugins.base import plugins
from dispatch.config import (
Expand Down
3 changes: 2 additions & 1 deletion src/dispatch/auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from sqlalchemy.orm import Session
from dispatch.auth.permissions import AdminPermission, PermissionsDependency

from dispatch.database import get_db, search_filter_sort_paginate
from dispatch.database.core import get_db
from dispatch.database.service import search_filter_sort_paginate

from .models import (
UserLogin,
Expand Down
12 changes: 6 additions & 6 deletions src/dispatch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dispatch.enums import UserRoles

from .main import * # noqa
from .database import Base, engine
from .database.core import Base, engine
from .exceptions import DispatchException
from .plugins.base import plugins
from .scheduler import scheduler
Expand Down Expand Up @@ -47,7 +47,7 @@ def plugins_group():
@plugins_group.command("list")
def list_plugins():
"""Shows all available plugins"""
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.plugin import service as plugin_service

db_session = SessionLocal()
Expand Down Expand Up @@ -99,7 +99,7 @@ def list_plugins():
)
def install_plugins(force):
"""Installs all plugins, or only one."""
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.plugin import service as plugin_service
from dispatch.plugin.models import Plugin

Expand Down Expand Up @@ -141,7 +141,7 @@ def install_plugins(force):
@click.argument("plugins", nargs=-1)
def uninstall_plugins(plugins):
"""Uninstalls all plugins, or only one."""
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.plugin import service as plugin_service

db_session = SessionLocal()
Expand Down Expand Up @@ -198,7 +198,7 @@ def dispatch_user():
)
def update_user(email: str, role: str):
"""Updates a user's roles."""
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.auth import service as user_service
from dispatch.auth.models import UserUpdate

Expand All @@ -217,7 +217,7 @@ def update_user(email: str, role: str):
@click.password_option()
def reset_user_password(email: str, password: str):
"""Resets a user's password."""
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.auth import service as user_service
from dispatch.auth.models import UserUpdate

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/conference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import validator
from sqlalchemy import Column, Integer, String, ForeignKey

from dispatch.database import Base
from dispatch.database.core import Base
from dispatch.messaging.strings import INCIDENT_CONFERENCE_DESCRIPTION
from dispatch.models import DispatchBase, ResourceMixin

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/conversation/messaging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dispatch.database import SessionLocal
from dispatch.database.core import SessionLocal
from dispatch.plugin import service as plugin_service


Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/conversation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from sqlalchemy import Column, String, Integer, ForeignKey

from dispatch.database import Base
from dispatch.database.core import Base
from dispatch.messaging.strings import INCIDENT_CONVERSATION_DESCRIPTION
from dispatch.models import DispatchBase, ResourceMixin

Expand Down
66 changes: 66 additions & 0 deletions src/dispatch/database/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import re
import functools
from typing import Any

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker
from sqlalchemy_searchable import make_searchable
from starlette.requests import Request

from dispatch.config import SQLALCHEMY_DATABASE_URI


engine = create_engine(str(SQLALCHEMY_DATABASE_URI))
SessionLocal = sessionmaker(bind=engine)


def resolve_table_name(name):
"""Resolves table names to their mapped names."""
names = re.split("(?=[A-Z])", name) # noqa
return "_".join([x.lower() for x in names if x])


raise_attribute_error = object()


def resolve_attr(obj, attr, default=None):
"""Attempts to access attr via dotted notation, returns none if attr does not exist."""
try:
return functools.reduce(getattr, attr.split("."), obj)
except AttributeError:
return default


class CustomBase:
@declared_attr
def __tablename__(self):
return resolve_table_name(self.__name__)


Base = declarative_base(cls=CustomBase)
make_searchable(Base.metadata)


def get_db(request: Request):
return request.state.db


def get_model_name_by_tablename(table_fullname: str) -> str:
"""Returns the model name of a given table."""
return get_class_by_tablename(table_fullname=table_fullname).__name__


def get_class_by_tablename(table_fullname: str) -> Any:
"""Return class reference mapped to table."""
mapped_name = resolve_table_name(table_fullname)
for c in Base._decl_class_registry.values():
if hasattr(c, "__table__"):
if c.__table__.fullname.lower() == mapped_name.lower():
return c
raise Exception(f"Incorrect tablename '{mapped_name}'. Check the name of your model.")


def get_table_name_by_class_instance(class_instance: Base) -> str:
"""Returns the name of the table for a given class instance."""
return class_instance._sa_instance_state.mapper.mapped_table.name
118 changes: 34 additions & 84 deletions src/dispatch/database.py → src/dispatch/database/service.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,26 @@
import re
import logging
import json
from typing import Any, List
from itertools import groupby
import functools

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import Query, sessionmaker
from typing import List

from sqlalchemy import and_, not_
from sqlalchemy.orm import Query
from sqlalchemy_filters import apply_pagination, apply_sort, apply_filters
from sqlalchemy_searchable import make_searchable
from sqlalchemy_searchable import search as search_db
from starlette.requests import Request

from dispatch.common.utils.composite_search import CompositeSearch
from dispatch.enums import Visibility, UserRoles
from dispatch.incident.models import Incident
from dispatch.incident_type.models import IncidentType
from dispatch.individual.models import IndividualContact
from dispatch.participant.models import Participant

from .core import Base, get_class_by_tablename, get_model_name_by_tablename

from .config import SQLALCHEMY_DATABASE_URI

log = logging.getLogger(__file__)

engine = create_engine(str(SQLALCHEMY_DATABASE_URI))
SessionLocal = sessionmaker(bind=engine)


def resolve_table_name(name):
"""Resolves table names to their mapped names."""
names = re.split("(?=[A-Z])", name) # noqa
return "_".join([x.lower() for x in names if x])


raise_attribute_error = object()


def resolve_attr(obj, attr, default=None):
"""Attempts to access attr via dotted notation, returns none if attr does not exist."""
try:
return functools.reduce(getattr, attr.split("."), obj)
except AttributeError:
return default


class CustomBase:
@declared_attr
def __tablename__(self):
return resolve_table_name(self.__name__)


Base = declarative_base(cls=CustomBase)

make_searchable(Base.metadata)


def get_db(request: Request):
return request.state.db


def get_model_name_by_tablename(table_fullname: str) -> str:
"""Returns the model name of a given table."""
return get_class_by_tablename(table_fullname=table_fullname).__name__


def get_class_by_tablename(table_fullname: str) -> Any:
"""Return class reference mapped to table."""
mapped_name = resolve_table_name(table_fullname)
for c in Base._decl_class_registry.values():
if hasattr(c, "__table__"):
if c.__table__.fullname.lower() == mapped_name.lower():
return c
raise Exception(f"Incorrect tablename '{mapped_name}'. Check the name of your model.")


def get_table_name_by_class_instance(class_instance: Base) -> str:
"""Returns the name of the table for a given class instance."""
return class_instance._sa_instance_state.mapper.mapped_table.name


def paginate(query: Query, page: int, items_per_page: int):
# Never pass a negative OFFSET value to SQL.
Expand All @@ -98,7 +43,7 @@ def search(*, db_session, query_str: str, model: str, sort=False):
return search_db(q, query_str, sort=sort)


def create_filter_spec(model, fields, ops, values, user_role):
def create_filter_spec(model, fields, ops, values):
"""Creates a filter spec."""
filters = []

Expand Down Expand Up @@ -133,23 +78,8 @@ def create_filter_spec(model, fields, ops, values, user_role):
else:
filter_spec.append({"or": filters})

# add admin only filter
if user_role != UserRoles.admin:
# add support for filtering restricted incidents
if model.lower() in ["incident", "task"]:
filter_spec.append(
{
"model": "Incident",
"field": "visibility",
"op": "!=",
"value": Visibility.restricted,
}
)

if filter_spec:
filter_spec = {"and": filter_spec}

log.debug(f"Filter Spec: {json.dumps(filter_spec, indent=2)}")

return filter_spec


Expand Down Expand Up @@ -210,9 +140,10 @@ def search_filter_sort_paginate(
ops: List[str] = None,
values: List[str] = None,
join_attrs: List[str] = None,
user_role: UserRoles = UserRoles.user,
user_role: UserRoles = UserRoles.user.value,
user_email: str = None,
):
"""Common functionality for searching, filtering and sorting"""
"""Common functionality for searching, filtering, sorting, and pagination."""
model_cls = get_class_by_tablename(model)
sort_spec = create_sort_spec(model, sort_by, descending)

Expand All @@ -222,10 +153,28 @@ def search_filter_sort_paginate(
else:
query = db_session.query(model_cls)

if user_role != UserRoles.admin.value:
if model.lower() == "incident":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do something similar for tasks for restricted incidents?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to add logic for it, but when I started testing tasks I realized they were already being filtered. I did some research to try to figure out the why and the conclusion that I reached was that when we load /tasks we perform a GET against the /api/v1/incidents/ endpoint, which returns a list of filtered incidents, and then sqlalchemy filters out the tasks based on that list.

# we filter restricted incidents based on incident participation
query = (
query.join(Participant)
.join(IndividualContact)
.filter(
not_(
and_(
Incident.visibility == Visibility.restricted.value,
IndividualContact.email != user_email,
)
)
)
)
if model.lower() == "incidenttype":
query = query.filter(IncidentType.visibility == Visibility.open.value)

query = join_required_attrs(query, model_cls, join_attrs, fields, sort_by)

if not filter_spec:
filter_spec = create_filter_spec(model, fields, ops, values, user_role)
filter_spec = create_filter_spec(model, fields, ops, values)

query = apply_filters(query, filter_spec)
query = apply_sort(query, sort_spec)
Expand All @@ -234,6 +183,7 @@ def search_filter_sort_paginate(
items_per_page = None

query, pagination = apply_pagination(query, page_number=page, page_size=items_per_page)

return {
"items": query.all(),
"itemsPerPage": pagination.page_size,
Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Any, List

from dispatch.metrics import provider as metrics_provider
from .database import SessionLocal

from .database.core import SessionLocal

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/definition/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy_utils import TSVectorType

from dispatch.database import Base
from dispatch.database.core import Base
from dispatch.models import (
DispatchBase,
TermNested,
Expand Down
3 changes: 2 additions & 1 deletion src/dispatch/definition/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from dispatch.database import get_db, paginate
from dispatch.database.core import get_db
from dispatch.database.service import paginate
from dispatch.search.service import search

from .models import (
Expand Down
Loading