diff --git a/src/dishka/plotter/mermaid.py b/src/dishka/plotter/mermaid.py index 6d518d5c..169c6789 100644 --- a/src/dishka/plotter/mermaid.py +++ b/src/dishka/plotter/mermaid.py @@ -1,4 +1,4 @@ -import re +import html from dishka.plotter.model import Group, Node, NodeType, Renderer @@ -24,6 +24,12 @@ """ +# these symbols should be escaped additionally to html.escape +MERMAID_SYMBOLS_SUBST = str.maketrans({ + "<": "<", + ">": ">", +}) + class MermaidRenderer(Renderer): def __init__(self) -> None: self.nodes: dict[str, Node] = {} @@ -32,18 +38,20 @@ def _render_node(self, node: Node) -> str: if node.type is NodeType.ALIAS: return "" name = self._node_type(node) + self._escape(node.name) - return ( - f'class {node.id}["{name}"]' - + "{\n" - + (f" {node.source_name}()\n" if node.source_name else " \n") - + "".join( - f" {self.nodes[dep].name}\n" for dep in node.dependencies - ) - + "}\n" - ) + source_name = self._escape(node.source_name) + return "\n".join([ + f'class {node.id}["{name}"]{{', + f"{source_name}()" if source_name else " ", + *( + f"{self._escape(self.nodes[dep].name)}" + for dep in node.dependencies + ), + "}", + ]) def _escape(self, line: str) -> str: - return re.sub(r"[^\w_.\-]", "_", line) + line = line.translate(MERMAID_SYMBOLS_SUBST) + return html.escape(line, quote=True) def _render_node_deps(self, node: Node) -> list[str]: res: list[str] = [] diff --git a/src/dishka/text_rendering/name.py b/src/dishka/text_rendering/name.py index 0c1aeae4..06682410 100644 --- a/src/dishka/text_rendering/name.py +++ b/src/dishka/text_rendering/name.py @@ -13,6 +13,12 @@ def _render_args(hint: Any) -> str: def get_name(hint: Any, *, include_module: bool) -> str: + if isinstance(hint, list): + res = ",".join( + get_name(item, include_module=include_module) + for item in hint + ) + return f"[{res}]" if hint is ...: return "..." if func := getattr(object, "__func__", None): diff --git a/tests/unit/text_rendering/test_name.py b/tests/unit/text_rendering/test_name.py index 5dd665a8..14a97d63 100644 --- a/tests/unit/text_rendering/test_name.py +++ b/tests/unit/text_rendering/test_name.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from typing import Generic, TypeVar import pytest @@ -42,6 +43,7 @@ class GenericA(Generic[T]): (dishka.Scope, True, "dishka.entities.scope.Scope"), (GenericA[int], False, "GenericA[int]"), (GenericA[T], False, "GenericA[T]"), + (Callable[[str], str], False, "Callable[[str], str]"), (GenericA, False, "GenericA"), ], )