Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions scripts/configure-guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
import textwrap
from collections import defaultdict
from typing import Annotated, Literal
from typing import Annotated, Any, Literal

import discord
from discord import VerificationLevel
Expand Down Expand Up @@ -100,6 +100,14 @@
"kick_members",
"ban_members",
"administrator",
"connect",
"speak",
"stream",
"use_soundboard",
"use_voice_activation",
"priority_speaker",
"deafen_members",
"mute_members",
]


Expand Down Expand Up @@ -138,9 +146,18 @@ class TextChannel(BaseModel):
channel_messages: list[MultilineString] = Field(default_factory=list)


class VoiceChannel(BaseModel):
type: Literal["voice"] = "voice"

name: str
permission_overwrites: list[PermissionOverwrite] = Field(default_factory=list)


class Category(BaseModel):
name: str
channels: list[Annotated[TextChannel | ForumChannel, Field(discriminator="type")]]
channels: list[
Annotated[TextChannel | ForumChannel | VoiceChannel, Field(discriminator="type")]
]
permission_overwrites: list[PermissionOverwrite] = Field(default_factory=list)


Expand All @@ -153,7 +170,7 @@ class GuildConfig(BaseModel):

@model_validator(mode="after")
def verify_system_channel_names(self) -> Self:
channel_names = []
channel_names: list[str] = []
for category in self.categories:
channel_names.extend(channel.name for channel in category.channels)

Expand Down Expand Up @@ -217,7 +234,13 @@ def verify_permission_roles(self) -> Self:
color=DARK_ORANGE,
hoist=True,
mentionable=True,
permissions=["kick_members", "ban_members"],
permissions=[
"kick_members",
"ban_members",
"priority_speaker",
"deafen_members",
"mute_members",
],
),
Role(
name=ROLE_MODERATORS,
Expand All @@ -229,6 +252,9 @@ def verify_permission_roles(self) -> Self:
"moderate_members",
"manage_messages",
"manage_threads",
"priority_speaker",
"deafen_members",
"mute_members",
],
),
Role(
Expand Down Expand Up @@ -278,6 +304,9 @@ def verify_permission_roles(self) -> Self:
"add_reactions",
"read_message_history",
"use_application_commands",
"connect",
"speak",
"use_voice_activation",
],
),
],
Expand Down Expand Up @@ -510,6 +539,17 @@ def verify_permission_roles(self) -> Self:
PermissionOverwrite(roles=ROLES_REGISTERED, allow=["view_channel"]),
],
),
Category(
name="Remote Attendees",
channels=[
TextChannel(name="remote-text", topic="Text chat for remote attendees"),
VoiceChannel(name="remote-voice"),
],
permission_overwrites=[
PermissionOverwrite(roles=[ROLE_EVERYONE], deny=["view_channel"]),
PermissionOverwrite(roles=ROLES_REGISTERED, allow=["view_channel"]),
],
),
Category(
name="Conference Organization",
channels=[
Expand Down Expand Up @@ -754,10 +794,14 @@ def get_forum(self, name: str) -> discord.ForumChannel:
raise RuntimeError(f"Could not find forum with name '{name}'")
return channel

def get_channel(self, name: str) -> discord.TextChannel | discord.ForumChannel:
def get_channel(
self, name: str
) -> discord.TextChannel | discord.ForumChannel | discord.VoiceChannel:
channel = discord_get(self.guild.channels, name=name)
if channel is None or not isinstance(channel, (discord.TextChannel, discord.ForumChannel)):
raise RuntimeError(f"Could not find text or forum channel with name '{name}'")
if channel is None or not isinstance(
channel, (discord.TextChannel, discord.ForumChannel, discord.VoiceChannel)
):
raise RuntimeError(f"Could not find text, forum, or voice channel with name '{name}'")
return channel

def get_role(self, name: str) -> discord.Role:
Expand All @@ -774,7 +818,7 @@ def get_category(self, name: str) -> discord.CategoryChannel:

async def ensure_channel_permissions(
self,
channel: discord.TextChannel | discord.ForumChannel,
channel: discord.TextChannel | discord.ForumChannel | discord.VoiceChannel,
permission_overwrite_templates: list[PermissionOverwrite],
) -> None:
logger.info("Ensure permissions for channel %s", channel.name)
Expand Down Expand Up @@ -830,6 +874,10 @@ async def ensure_categories_and_channels(self, category_templates: list[Category
await self.ensure_text_channel(
channel_template.name, category=category, position=channel_position
)
elif isinstance(channel_template, VoiceChannel):
await self.ensure_voice_channel(
channel_template.name, category=category, position=channel_position
)
elif isinstance(channel_template, ForumChannel):
await self.ensure_forum_channel(
channel_template.name,
Expand Down Expand Up @@ -873,6 +921,23 @@ async def ensure_text_channel(
logger.debug("Update position")
await channel.edit(position=position)

async def ensure_voice_channel(
self, name: str, *, category: discord.CategoryChannel | None, position: int
) -> None:
logger.info("Ensure voice channel %s at position %d", name, position)
channel = discord_get(self.guild.voice_channels, name=name)
if channel is None:
logger.debug("Create voice channel %s", name)
await self.guild.create_voice_channel(name=name, category=category, position=position)
else:
logger.debug("Found voice channel")
if channel.category != category:
logger.debug("Update category")
await channel.edit(category=category)
if channel.position != position:
logger.debug("Update position")
await channel.edit(position=position)

async def ensure_forum_channel(
self,
name: str,
Expand Down Expand Up @@ -1007,7 +1072,11 @@ async def ensure_channel_topics(self, category_templates: list[Category]) -> Non
logger.info("Ensure channel topics")
for category_template in category_templates:
for channel_template in category_template.channels:
if isinstance(channel_template, VoiceChannel):
continue # voice channels have no topic
channel = self.get_channel(channel_template.name)
if isinstance(channel, discord.VoiceChannel):
continue # voice channels have no topic
expected_topic = channel_template.topic
if channel.topic != expected_topic:
logger.debug("Update topic of channel %s", channel_template.name)
Expand Down Expand Up @@ -1076,7 +1145,7 @@ async def ensure_community_feature(
await self.guild.edit(verification_level=discord.VerificationLevel.medium)

if self.guild.default_notifications != discord.NotificationLevel.only_mentions:
self.guild.edit(default_notifications=discord.NotificationLevel.only_mentions)
await self.guild.edit(default_notifications=discord.NotificationLevel.only_mentions)

if "COMMUNITY" not in self.guild.features:
logger.debug("Enable guild 'COMMUNITY' feature")
Expand Down Expand Up @@ -1105,7 +1174,7 @@ async def on_ready(self) -> None:

await self.close()

async def on_error(self, event: str, /, *args, **kwargs) -> None: # noqa: ANN002,ANN003 (types)
async def on_error(self, event: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 (Any)
"""Event handler for uncaught exceptions."""
exc_type, exc_value, _exc_traceback = sys.exc_info()
if exc_type is None:
Expand Down