Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.

Commit da8b145

Browse files
authored
Consolidates restricted visibility-based filtering (#995)
* Consolidates restricted visibility-based filtering * Improvements * Renames base.py to core.py * Removes unused imports * Adds restricted filtering for incident types
1 parent b8b876c commit da8b145

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+223
-197
lines changed

src/dispatch/alembic/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlalchemy import engine_from_config, pool
55

66
from dispatch.config import SQLALCHEMY_DATABASE_URI
7-
from dispatch.database import Base
7+
from dispatch.database.core import Base
88

99
# this is the Alembic Config object, which provides
1010
# access to the values within the .ini file in use.

src/dispatch/alembic/versions/b6da2dad0396_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from alembic import op
99
import sqlalchemy as sa
1010

11-
from dispatch.database import SessionLocal
11+
from dispatch.database.core import SessionLocal
1212
from dispatch.incident_cost_type.models import IncidentCostTypeUpdate
1313
from dispatch.incident_cost_type import service as incident_cost_type_service
1414

src/dispatch/auth/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlalchemy import Column, String, Binary, Integer
1111
from sqlalchemy_utils import TSVectorType
1212

13-
from dispatch.database import Base
13+
from dispatch.database.core import Base
1414
from dispatch.models import TimeStampMixin, DispatchBase
1515
from dispatch.enums import UserRoles
1616

src/dispatch/auth/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from sqlalchemy.orm import Session
1616
from sqlalchemy.exc import IntegrityError
17-
from dispatch.database import get_db
17+
from dispatch.database.core import get_db
1818

1919
from dispatch.plugins.base import plugins
2020
from dispatch.config import (

src/dispatch/auth/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from sqlalchemy.orm import Session
55
from dispatch.auth.permissions import AdminPermission, PermissionsDependency
66

7-
from dispatch.database import get_db, search_filter_sort_paginate
7+
from dispatch.database.core import get_db
8+
from dispatch.database.service import search_filter_sort_paginate
89

910
from .models import (
1011
UserLogin,

src/dispatch/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dispatch.enums import UserRoles
1414

1515
from .main import * # noqa
16-
from .database import Base, engine
16+
from .database.core import Base, engine
1717
from .exceptions import DispatchException
1818
from .plugins.base import plugins
1919
from .scheduler import scheduler
@@ -47,7 +47,7 @@ def plugins_group():
4747
@plugins_group.command("list")
4848
def list_plugins():
4949
"""Shows all available plugins"""
50-
from dispatch.database import SessionLocal
50+
from dispatch.database.core import SessionLocal
5151
from dispatch.plugin import service as plugin_service
5252

5353
db_session = SessionLocal()
@@ -99,7 +99,7 @@ def list_plugins():
9999
)
100100
def install_plugins(force):
101101
"""Installs all plugins, or only one."""
102-
from dispatch.database import SessionLocal
102+
from dispatch.database.core import SessionLocal
103103
from dispatch.plugin import service as plugin_service
104104
from dispatch.plugin.models import Plugin
105105

@@ -141,7 +141,7 @@ def install_plugins(force):
141141
@click.argument("plugins", nargs=-1)
142142
def uninstall_plugins(plugins):
143143
"""Uninstalls all plugins, or only one."""
144-
from dispatch.database import SessionLocal
144+
from dispatch.database.core import SessionLocal
145145
from dispatch.plugin import service as plugin_service
146146

147147
db_session = SessionLocal()
@@ -198,7 +198,7 @@ def dispatch_user():
198198
)
199199
def update_user(email: str, role: str):
200200
"""Updates a user's roles."""
201-
from dispatch.database import SessionLocal
201+
from dispatch.database.core import SessionLocal
202202
from dispatch.auth import service as user_service
203203
from dispatch.auth.models import UserUpdate
204204

@@ -217,7 +217,7 @@ def update_user(email: str, role: str):
217217
@click.password_option()
218218
def reset_user_password(email: str, password: str):
219219
"""Resets a user's password."""
220-
from dispatch.database import SessionLocal
220+
from dispatch.database.core import SessionLocal
221221
from dispatch.auth import service as user_service
222222
from dispatch.auth.models import UserUpdate
223223

src/dispatch/conference/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pydantic import validator
55
from sqlalchemy import Column, Integer, String, ForeignKey
66

7-
from dispatch.database import Base
7+
from dispatch.database.core import Base
88
from dispatch.messaging.strings import INCIDENT_CONFERENCE_DESCRIPTION
99
from dispatch.models import DispatchBase, ResourceMixin
1010

src/dispatch/conversation/messaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dispatch.database import SessionLocal
1+
from dispatch.database.core import SessionLocal
22
from dispatch.plugin import service as plugin_service
33

44

src/dispatch/conversation/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from sqlalchemy import Column, String, Integer, ForeignKey
66

7-
from dispatch.database import Base
7+
from dispatch.database.core import Base
88
from dispatch.messaging.strings import INCIDENT_CONVERSATION_DESCRIPTION
99
from dispatch.models import DispatchBase, ResourceMixin
1010

src/dispatch/database/core.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import re
2+
import functools
3+
from typing import Any
4+
5+
from sqlalchemy import create_engine
6+
from sqlalchemy.ext.declarative import declarative_base, declared_attr
7+
from sqlalchemy.orm import sessionmaker
8+
from sqlalchemy_searchable import make_searchable
9+
from starlette.requests import Request
10+
11+
from dispatch.config import SQLALCHEMY_DATABASE_URI
12+
13+
14+
engine = create_engine(str(SQLALCHEMY_DATABASE_URI))
15+
SessionLocal = sessionmaker(bind=engine)
16+
17+
18+
def resolve_table_name(name):
19+
"""Resolves table names to their mapped names."""
20+
names = re.split("(?=[A-Z])", name) # noqa
21+
return "_".join([x.lower() for x in names if x])
22+
23+
24+
raise_attribute_error = object()
25+
26+
27+
def resolve_attr(obj, attr, default=None):
28+
"""Attempts to access attr via dotted notation, returns none if attr does not exist."""
29+
try:
30+
return functools.reduce(getattr, attr.split("."), obj)
31+
except AttributeError:
32+
return default
33+
34+
35+
class CustomBase:
36+
@declared_attr
37+
def __tablename__(self):
38+
return resolve_table_name(self.__name__)
39+
40+
41+
Base = declarative_base(cls=CustomBase)
42+
make_searchable(Base.metadata)
43+
44+
45+
def get_db(request: Request):
46+
return request.state.db
47+
48+
49+
def get_model_name_by_tablename(table_fullname: str) -> str:
50+
"""Returns the model name of a given table."""
51+
return get_class_by_tablename(table_fullname=table_fullname).__name__
52+
53+
54+
def get_class_by_tablename(table_fullname: str) -> Any:
55+
"""Return class reference mapped to table."""
56+
mapped_name = resolve_table_name(table_fullname)
57+
for c in Base._decl_class_registry.values():
58+
if hasattr(c, "__table__"):
59+
if c.__table__.fullname.lower() == mapped_name.lower():
60+
return c
61+
raise Exception(f"Incorrect tablename '{mapped_name}'. Check the name of your model.")
62+
63+
64+
def get_table_name_by_class_instance(class_instance: Base) -> str:
65+
"""Returns the name of the table for a given class instance."""
66+
return class_instance._sa_instance_state.mapper.mapped_table.name

0 commit comments

Comments
 (0)