From 76bf365b91cea27b851029f97d3181990e77f1e0 Mon Sep 17 00:00:00 2001 From: MrDiver Date: Thu, 29 Dec 2022 07:44:23 +0100 Subject: [PATCH 1/4] ported functionality of Mobject from 3b1b to OpenGLMobject --- manim/event_handler/__init__.py | 5 + manim/event_handler/event_dispatcher.py | 91 ++ manim/event_handler/event_listener.py | 34 + manim/event_handler/event_type.py | 11 + manim/mobject/opengl/opengl_mobject.py | 1587 +++++++++++++++-------- manim/utils/color.py | 27 + manim/utils/space_ops.py | 16 + 7 files changed, 1221 insertions(+), 550 deletions(-) create mode 100644 manim/event_handler/__init__.py create mode 100644 manim/event_handler/event_dispatcher.py create mode 100644 manim/event_handler/event_listener.py create mode 100644 manim/event_handler/event_type.py diff --git a/manim/event_handler/__init__.py b/manim/event_handler/__init__.py new file mode 100644 index 0000000000..1a9a247106 --- /dev/null +++ b/manim/event_handler/__init__.py @@ -0,0 +1,5 @@ +from manim.event_handler.event_dispatcher import EventDispatcher + +# This is supposed to be a Singleton +# i.e., during runtime there should be only one object of Event Dispatcher +EVENT_DISPATCHER = EventDispatcher() diff --git a/manim/event_handler/event_dispatcher.py b/manim/event_handler/event_dispatcher.py new file mode 100644 index 0000000000..050379b5e1 --- /dev/null +++ b/manim/event_handler/event_dispatcher.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import numpy as np + +from manim.event_handler.event_listener import EventListener +from manim.event_handler.event_type import EventType + + +class EventDispatcher: + def __init__(self): + self.event_listeners: dict[EventType, list[EventListener]] = { + event_type: [] for event_type in EventType + } + self.mouse_point = np.array((0.0, 0.0, 0.0)) + self.mouse_drag_point = np.array((0.0, 0.0, 0.0)) + self.pressed_keys: set[int] = set() + self.draggable_object_listeners: list[EventListener] = [] + + def add_listener(self, event_listener: EventListener): + assert isinstance(event_listener, EventListener) + self.event_listeners[event_listener.event_type].append(event_listener) + return self + + def remove_listener(self, event_listener: EventListener): + assert isinstance(event_listener, EventListener) + try: + while event_listener in self.event_listeners[event_listener.event_type]: + self.event_listeners[event_listener.event_type].remove(event_listener) + except Exception: + # raise ValueError("Handler is not handling this event, so cannot remove it.") + pass + return self + + def dispatch(self, event_type: EventType, **event_data): + if event_type == EventType.MouseMotionEvent: + self.mouse_point = event_data["point"] + elif event_type == EventType.MouseDragEvent: + self.mouse_drag_point = event_data["point"] + elif event_type == EventType.KeyPressEvent: + self.pressed_keys.add(event_data["symbol"]) # Modifiers? + elif event_type == EventType.KeyReleaseEvent: + self.pressed_keys.difference_update({event_data["symbol"]}) # Modifiers? + elif event_type == EventType.MousePressEvent: + self.draggable_object_listeners = [ + listener + for listener in self.event_listeners[EventType.MouseDragEvent] + if listener.mobject.is_point_touching(self.mouse_point) + ] + elif event_type == EventType.MouseReleaseEvent: + self.draggable_object_listeners = [] + + propagate_event = None + + if event_type == EventType.MouseDragEvent: + for listener in self.draggable_object_listeners: + assert isinstance(listener, EventListener) + propagate_event = listener.callback(listener.mobject, event_data) + if propagate_event is not None and propagate_event is False: + return propagate_event + + elif event_type.value.startswith("mouse"): + for listener in self.event_listeners[event_type]: + if listener.mobject.is_point_touching(self.mouse_point): + propagate_event = listener.callback(listener.mobject, event_data) + if propagate_event is not None and propagate_event is False: + return propagate_event + + elif event_type.value.startswith("key"): + for listener in self.event_listeners[event_type]: + propagate_event = listener.callback(listener.mobject, event_data) + if propagate_event is not None and propagate_event is False: + return propagate_event + + return propagate_event + + def get_listeners_count(self) -> int: + return sum([len(value) for key, value in self.event_listeners.items()]) + + def get_mouse_point(self) -> np.ndarray: + return self.mouse_point + + def get_mouse_drag_point(self) -> np.ndarray: + return self.mouse_drag_point + + def is_key_pressed(self, symbol: int) -> bool: + return symbol in self.pressed_keys + + __iadd__ = add_listener + __isub__ = remove_listener + __call__ = dispatch + __len__ = get_listeners_count diff --git a/manim/event_handler/event_listener.py b/manim/event_handler/event_listener.py new file mode 100644 index 0000000000..9c923d4d57 --- /dev/null +++ b/manim/event_handler/event_listener.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + + import manim.mobject.opengl.opengl_mobject as glmob + from manim.event_handler.event_type import EventType + + +class EventListener: + def __init__( + self, + mobject: glmob.OpenGLMobject, + event_type: EventType, + event_callback: Callable[[glmob.OpenGLMobject, dict[str, str]], None], + ): + self.mobject = mobject + self.event_type = event_type + self.callback = event_callback + + def __eq__(self, o: object) -> bool: + return_val = False + if isinstance(o, EventListener): + try: + return_val = ( + self.callback == o.callback + and self.mobject == o.mobject + and self.event_type == o.event_type + ) + except Exception: + pass + return return_val diff --git a/manim/event_handler/event_type.py b/manim/event_handler/event_type.py new file mode 100644 index 0000000000..b9a9eb2da3 --- /dev/null +++ b/manim/event_handler/event_type.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class EventType(Enum): + MouseMotionEvent = "mouse_motion_event" + MousePressEvent = "mouse_press_event" + MouseReleaseEvent = "mouse_release_event" + MouseDragEvent = "mouse_drag_event" + MouseScrollEvent = "mouse_scroll_event" + KeyPressEvent = "key_press_event" + KeyReleaseEvent = "key_release_event" diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 25159ab23c..7af7439b3a 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -2,22 +2,30 @@ import copy import itertools as it +import numbers +import os +import pickle import random import sys from functools import partialmethod, wraps from math import ceil -from typing import Iterable, Sequence +from typing import TYPE_CHECKING import moderngl import numpy as np from colour import Color +from typing_extensions import TypedDict from manim import config, logger from manim.constants import * +from manim.event_handler import EVENT_DISPATCHER +from manim.event_handler.event_listener import EventListener +from manim.event_handler.event_type import EventType +from manim.renderer.shader_wrapper import ShaderWrapper, get_colormap_code from manim.utils.bezier import integer_interpolate, interpolate from manim.utils.color import * -from manim.utils.color import Colors -from manim.utils.config_ops import _Data, _Uniforms +from manim.utils.color import Colors, get_colormap_list +from manim.utils.deprecation import deprecated # from ..utils.iterables import batch_by_property from manim.utils.iterables import ( @@ -34,20 +42,57 @@ from manim.utils.simple_functions import get_parameters from manim.utils.space_ops import ( angle_between_vectors, + angle_of_vector, + get_norm, normalize, rotation_matrix_transpose, ) +if TYPE_CHECKING: + from typing import Callable, Iterable, Sequence, Tuple, Union -def affects_shader_info_id(func): - @wraps(func) - def wrapper(self): - for mob in self.get_family(): - func(mob) - mob.refresh_shader_wrapper_id() - return self + from typing_extensions import TypeAlias + + TimeBasedUpdater: TypeAlias = Callable[[OpenGLMobject, float], OpenGLMobject | None] + NonTimeUpdater: TypeAlias = Callable[[OpenGLMobject], OpenGLMobject | None] + Updater: TypeAlias = Union[TimeBasedUpdater, NonTimeUpdater] + PointUpdateFunction: TypeAlias = Callable[[np.ndarray], np.ndarray] + + +class MobjectData(TypedDict, total=False): + points: np.ndarray + bounding_box: np.ndarray + rgbas: np.ndarray + + +def to_mobject_data(values: dict[str, np.ndarray] | MobjectData) -> MobjectData: + result = MobjectData( + points=np.array(values["points"]), + bounding_box=np.array(values["bounding_box"]), + rgbas=np.array(values["rgbas"]), + ) + return result - return wrapper + +class MobjectUniforms(TypedDict): + is_fixed_in_frame: np.ndarray + is_fixed_orientation: np.ndarray + gloss: np.ndarray + shadow: np.ndarray + reflectiveness: np.ndarray + + +def to_mobject_uniforms( + values: dict[str, np.ndarray] | MobjectUniforms +) -> MobjectUniforms: + result = MobjectUniforms( + is_fixed_in_frame=np.array(values["is_fixed_in_frame"]), + is_fixed_orientation=np.array(values["is_fixed_orientation"]), + gloss=np.array(values["gloss"]), + shadow=np.array(values["shadow"]), + reflectiveness=np.array(values["reflectiveness"]), + ) + return result class OpenGLMobject: @@ -66,96 +111,60 @@ class OpenGLMobject: """ - shader_dtype = [ + dim: int = 3 + shader_folder: str = "" + render_primitive: int = moderngl.TRIANGLE_STRIP + shader_dtype: Sequence[tuple[str, type, tuple[int]]] = [ ("point", np.float32, (3,)), ] - shader_folder = "" - - # _Data and _Uniforms are set as class variables to tell manim how to handle setting/getting these attributes later. - points = _Data() - bounding_box = _Data() - rgbas = _Data() - - is_fixed_in_frame = _Uniforms() - is_fixed_orientation = _Uniforms() - fixed_orientation_center = _Uniforms() # for fixed orientation reference - gloss = _Uniforms() - shadow = _Uniforms() def __init__( self, color=WHITE, - opacity=1, - dim=3, # TODO, get rid of this - # Lighting parameters - # Positive gloss up to 1 makes it reflect the light. - gloss=0.0, - # Positive shadow up to 1 makes a side opposite the light darker - shadow=0.0, - # For shaders - render_primitive=moderngl.TRIANGLES, - texture_paths=None, - depth_test=False, - # If true, the mobject will not get rotated according to camera position - is_fixed_in_frame=False, - is_fixed_orientation=False, - # Must match in attributes of vert shader - # Event listener - listen_to_events=False, - model_matrix=None, - should_render=True, + opacity: float = 1.0, + reflectiveness: float = 0.0, + shadow: float = 0.0, + gloss: float = 0.0, + texture_paths: dict[str, str] | None = None, + is_fixed_in_frame: bool = False, + depth_test: bool = False, name: str | None = None, **kwargs, ): - self.name = self.__class__.__name__ if name is None else name - # getattr in case data/uniforms are already defined in parent classes. - self.data = getattr(self, "data", {}) - self.uniforms = getattr(self, "uniforms", {}) - + self.color = color self.opacity = opacity - self.dim = dim # TODO, get rid of this - # Lighting parameters - # Positive gloss up to 1 makes it reflect the light. - self.gloss = gloss - # Positive shadow up to 1 makes a side opposite the light darker + self.reflectiveness = reflectiveness self.shadow = shadow - # For shaders - self.render_primitive = render_primitive + self.gloss = gloss self.texture_paths = texture_paths + self.is_fixed_in_frame = is_fixed_in_frame self.depth_test = depth_test - # If true, the mobject will not get rotated according to camera position - self.is_fixed_in_frame = float(is_fixed_in_frame) - self.is_fixed_orientation = float(is_fixed_orientation) - self.fixed_orientation_center = (0, 0, 0) - # Must match in attributes of vert shader - # Event listener - self.listen_to_events = listen_to_events - - self._submobjects = [] - self.parents = [] - self.parent = None - self.family = [self] - self.locked_data_keys = set() - self.needs_new_bounding_box = True - if model_matrix is None: - self.model_matrix = np.eye(4) - else: - self.model_matrix = model_matrix + self.name = self.__class__.__name__ if name is None else name + + # internal_state + self.submobjects: list[OpenGLMobject] = [] + self.parents: list[OpenGLMobject] = [] + self.family: list[OpenGLMobject] = [self] + self.locked_data_keys: set[str] = set() + self.needs_new_bounding_box: bool = True + self._is_animating: bool = False + self.saved_state: OpenGLMobject | None = None + self.target: OpenGLMobject | None = None + + self.data: MobjectData + self.uniforms: MobjectUniforms self.init_data() + self.init_uniforms() self.init_updaters() - # self.init_event_listners() + self.init_event_listeners() self.init_points() - self.color = Color(color) if color else None self.init_colors() - - self.shader_indices = None + self.init_shader_data() if self.depth_test: self.apply_depth_test() - self.should_render = should_render - @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -167,17 +176,15 @@ def __str__(self): def __repr__(self): return str(self.name) - def __sub__(self, other): - return NotImplemented - - def __isub__(self, other): - return NotImplemented + def __add__(self, other: OpenGLMobject) -> OpenGLMobject: + if not isinstance(other, OpenGLMobject): + raise TypeError(f"Only Mobjects can be added to Mobjects not {type(other)}") + return self.get_group_class()(self, other) - def __add__(self, mobject): - return NotImplemented - - def __iadd__(self, mobject): - return NotImplemented + def __mul__(self, other: int) -> OpenGLMobject: + if not isinstance(other, int): + raise TypeError(f"Only int can be multiplied to Mobjects not {type(other)}") + return self.replicate(other) @classmethod def set_default(cls, **kwargs): @@ -227,12 +234,49 @@ def construct(self): else: cls.__init__ = cls._original__init__ + @property + def points(self): + return self.data["points"] + + @points.setter + def points(self, value): + self.data["points"] = value + + @property + def bounding_box(self): + return self.data["bounding_box"] + + @bounding_box.setter + def bounding_box(self, value): + self.data["bounding_box"] = value + + @property + def rgbas(self): + return self.data["rgbas"] + + @rgbas.setter + def rgbas(self, value): + self.data["rgbas"] = value + def init_data(self): """Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data. Subclasses can inherit and overwrite this method to extend `self.data`.""" - self.points = np.zeros((0, 3)) - self.bounding_box = np.zeros((3, 3)) - self.rgbas = np.zeros((1, 4)) + self.data = MobjectData( + points=np.zeros((0, 3)), + bounding_box=np.zeros((3, 3)), + rgbas=np.zeros((1, 4)), + ) + + def init_uniforms(self): + """Initializes the uniforms. + + Gets called upon creation""" + self.uniforms = MobjectUniforms( + is_fixed_in_frame=np.array(float(self.is_fixed_in_frame)), + gloss=np.array(self.gloss), + shadow=np.array(self.shadow), + reflectiveness=np.array(self.reflectiveness), + ) def init_colors(self): """Initializes the colors. @@ -248,52 +292,20 @@ def init_points(self): # Typically implemented in subclass, unless purposefully left blank pass - def set(self, **kwargs) -> OpenGLMobject: - """Sets attributes. - - Mainly to be used along with :attr:`animate` to - animate setting attributes. - - Examples - -------- - :: - - >>> mob = OpenGLMobject() - >>> mob.set(foo=0) - OpenGLMobject - >>> mob.foo - 0 - - Parameters - ---------- - **kwargs - The attributes and corresponding values to set. - - Returns - ------- - :class:`OpenGLMobject` - ``self`` - - - """ - - for attr, value in kwargs.items(): - setattr(self, attr, value) - - return self - def set_data(self, data): for key in data: - self.data[key] = data[key].copy() + self.data[key] = data[key] return self def set_uniforms(self, uniforms): - for key in uniforms: - self.uniforms[key] = uniforms[key] # Copy? + for key, value in uniforms.items(): + if isinstance(value, np.ndarray): + value = value.copy() + self.uniforms[key] = value return self @property - def animate(self): + def animate(self) -> _AnimationBuilder: """Used to animate the application of a method. .. warning:: @@ -380,6 +392,101 @@ def construct(self): """ return _AnimationBuilder(self) + def resize_points(self, new_length, resize_func=resize_array): + if new_length != len(self.points): + self.points = resize_func(self.points, new_length) + self.refresh_bounding_box() + return self + + def set_points(self, points): + if len(points) == len(self.points): + self.points[:] = points + elif isinstance(points, np.ndarray): + self.points = points.copy() + else: + self.points = np.array(points) + self.refresh_bounding_box() + return self + + def append_points(self, new_points): + self.points = np.vstack([self.points, new_points]) + self.refresh_bounding_box() + return self + + def reverse_points(self): + for mob in self.get_family(): + for key in mob.data: + mob.data[key] = mob.data[key][::-1] + return self + + def apply_points_function( + self, + func: PointUpdateFunction, + about_point=None, + about_edge=ORIGIN, + works_on_bounding_box=False, + ): + if about_point is None and about_edge is not None: + about_point = self.get_bounding_box_point(about_edge) + + for mob in self.get_family(): + arrs = [] + if mob.has_points(): + arrs.append(mob.points) + if works_on_bounding_box: + arrs.append(mob.get_bounding_box()) + + for arr in arrs: + if about_point is None: + arr[:] = func(arr) + else: + arr[:] = func(arr - about_point) + about_point + + if not works_on_bounding_box: + self.refresh_bounding_box(recurse_down=True) + else: + for parent in self.parents: + parent.refresh_bounding_box() + return self + + # ce only + def get_array_attrs(self): + """This method is used to determine which attributes of the :class:`~.OpenGLMobject` are arrays. + These can be used to apply functions to all of them at once. + """ + return ["points"] + + # ce only + def apply_over_attr_arrays(self, func): + """This method is used to apply a function to all attributes of the :class:`~.OpenGLMobject` that are arrays.""" + for attr in self.get_array_attrs(): + setattr(self, attr, func(getattr(self, attr))) + return self + + # ce only + def get_midpoint(self) -> np.ndarray: + """Get coordinates of the middle of the path that forms the :class:`~.OpenGLMobject`. + + Examples + -------- + + .. manim:: AngleMidPoint + :save_last_frame: + + class AngleMidPoint(Scene): + def construct(self): + line1 = Line(ORIGIN, 2*RIGHT) + line2 = Line(ORIGIN, 2*RIGHT).rotate_about_origin(80*DEGREES) + + a = Angle(line1, line2, radius=1.5, other_angle=False) + d = Dot(a.get_midpoint()).set_color(RED) + + self.add(line1, line2, a, d) + self.wait() + + """ + return self.point_from_proportion(0.5) + @property def width(self): """The width of the mobject. @@ -476,91 +583,6 @@ def depth(self): def depth(self, value): self.rescale_to_fit(value, 2, stretch=False) - def resize_points(self, new_length, resize_func=resize_array): - if new_length != len(self.points): - self.points = resize_func(self.points, new_length) - self.refresh_bounding_box() - return self - - def set_points(self, points): - if len(points) == len(self.points): - self.points[:] = points - elif isinstance(points, np.ndarray): - self.points = points.copy() - else: - self.points = np.array(points) - self.refresh_bounding_box() - return self - - def apply_over_attr_arrays(self, func): - for attr in self.get_array_attrs(): - setattr(self, attr, func(getattr(self, attr))) - return self - - def append_points(self, new_points): - self.points = np.vstack([self.points, new_points]) - self.refresh_bounding_box() - return self - - def reverse_points(self): - for mob in self.get_family(): - for key in mob.data: - mob.data[key] = mob.data[key][::-1] - return self - - def get_midpoint(self) -> np.ndarray: - """Get coordinates of the middle of the path that forms the :class:`~.OpenGLMobject`. - - Examples - -------- - - .. manim:: AngleMidPoint - :save_last_frame: - - class AngleMidPoint(Scene): - def construct(self): - line1 = Line(ORIGIN, 2*RIGHT) - line2 = Line(ORIGIN, 2*RIGHT).rotate_about_origin(80*DEGREES) - - a = Angle(line1, line2, radius=1.5, other_angle=False) - d = Dot(a.get_midpoint()).set_color(RED) - - self.add(line1, line2, a, d) - self.wait() - - """ - return self.point_from_proportion(0.5) - - def apply_points_function( - self, - func, - about_point=None, - about_edge=ORIGIN, - works_on_bounding_box=False, - ): - if about_point is None and about_edge is not None: - about_point = self.get_bounding_box_point(about_edge) - - for mob in self.get_family(): - arrs = [] - if mob.has_points(): - arrs.append(mob.points) - if works_on_bounding_box: - arrs.append(mob.get_bounding_box()) - - for arr in arrs: - if about_point is None: - arr[:] = func(arr) - else: - arr[:] = func(arr - about_point) + about_point - - if not works_on_bounding_box: - self.refresh_bounding_box(recurse_down=True) - else: - for parent in self.parents: - parent.refresh_bounding_box() - return self - # Others related to points def match_points(self, mobject): @@ -581,6 +603,7 @@ def construct(self): self.wait(0.5) """ self.set_points(mobject.points) + return self def clear_points(self): self.points = np.empty((0, 3)) @@ -631,11 +654,28 @@ def refresh_bounding_box(self, recurse_down=False, recurse_up=True): parent.refresh_bounding_box() return self - def is_point_touching(self, point, buff=MED_SMALL_BUFF): + def are_points_touching(self, points, buff: float = 0) -> np.ndarray: bb = self.get_bounding_box() mins = bb[0] - buff maxs = bb[2] + buff - return (point >= mins).all() and (point <= maxs).all() + return ((points >= mins) * (points <= maxs)).all(1) + + def is_point_touching(self, point, buff=MED_SMALL_BUFF): + return self.are_points_touching(np.array(point, ndmin=2), buff)[0] + + def is_touching(self, mobject: OpenGLMobject, buff: float = 1e-2) -> bool: + bb1 = self.get_bounding_box() + bb2 = mobject.get_bounding_box() + return not any( + ( + ( + bb2[2] < bb1[0] - buff + ).any(), # E.g. Right of mobject is left of self's left + ( + bb2[0] > bb1[2] + buff + ).any(), # E.g. Left of mobject is right of self's right + ) + ) # Family matters @@ -648,13 +688,13 @@ def __getitem__(self, value): def __iter__(self): return iter(self.split()) - def __len__(self): + def __len__(self) -> int: return len(self.split()) - def split(self): + def split(self) -> list[OpenGLMobject]: return self.submobjects - def assemble_family(self): + def assemble_family(self) -> OpenGLMobject: sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *uniq_chain(*sub_families)] self.refresh_has_updater_status() @@ -663,17 +703,39 @@ def assemble_family(self): parent.assemble_family() return self - def get_family(self, recurse=True): + def get_family(self, recurse=True) -> list[OpenGLMobject]: if recurse and hasattr(self, "family"): return self.family else: return [self] - def family_members_with_points(self): + def family_members_with_points(self) -> list[OpenGLMobject]: return [m for m in self.get_family() if m.has_points()] + def get_ancestors(self, extended: bool = False) -> list[OpenGLMobject]: + """ + Returns parents, grandparents, etc. + Order of result should be from higher members of the hierarchy down. + + If extended is set to true, it includes the ancestors of all family members, + e.g. any other parents of a submobject + """ + ancestors = [] + to_process = list(self.get_family(recurse=extended)) + excluded = set(to_process) + while to_process: + for p in to_process.pop().parents: + if p not in excluded: + ancestors.append(p) + to_process.append(p) + # Ensure mobjects highest in the hierarchy show up first + ancestors.reverse() + # Remove list redundancies while preserving order + return list(dict.fromkeys(ancestors)) + def add( - self, *mobjects: OpenGLMobject, update_parent: bool = False + self, + *mobjects: OpenGLMobject, ) -> OpenGLMobject: """Add mobjects as submobjects. @@ -733,17 +795,6 @@ def add( ValueError: OpenGLMobject cannot contain self """ - if update_parent: - assert len(mobjects) == 1, "Can't set multiple parents." - mobjects[0].parent = self - - if self in mobjects: - raise ValueError("OpenGLMobject cannot contain self") - if any(mobjects.count(elem) > 1 for elem in mobjects): - logger.warning( - "Attempted adding some Mobject as a child more than once, " - "this is not possible. Repetitions are ignored.", - ) for mobject in mobjects: if not isinstance(mobject, OpenGLMobject): raise TypeError("All submobjects must be of type OpenGLMobject") @@ -754,44 +805,8 @@ def add( self.assemble_family() return self - def insert(self, index: int, mobject: OpenGLMobject, update_parent: bool = False): - """Inserts a mobject at a specific position into self.submobjects - - Effectively just calls ``self.submobjects.insert(index, mobject)``, - where ``self.submobjects`` is a list. - - Highly adapted from ``OpenGLMobject.add``. - - Parameters - ---------- - index - The index at which - mobject - The mobject to be inserted. - update_parent - Whether or not to set ``mobject.parent`` to ``self``. - """ - - if update_parent: - mobject.parent = self - - if mobject is self: - raise ValueError("OpenGLMobject cannot contain self") - - if not isinstance(mobject, OpenGLMobject): - raise TypeError("All submobjects must be of type OpenGLMobject") - - if mobject not in self.submobjects: - self.submobjects.insert(index, mobject) - - if self not in mobject.parents: - mobject.parents.append(self) - - self.assemble_family() - return self - def remove( - self, *mobjects: OpenGLMobject, update_parent: bool = False + self, *mobjects: OpenGLMobject, reassemble: bool = True ) -> OpenGLMobject: """Remove :attr:`submobjects`. @@ -814,16 +829,13 @@ def remove( :meth:`add` """ - if update_parent: - assert len(mobjects) == 1, "Can't remove multiple parents." - mobjects[0].parent = None - for mobject in mobjects: if mobject in self.submobjects: self.submobjects.remove(mobject) if self in mobject.parents: mobject.parents.remove(self) - self.assemble_family() + if reassemble: + self.assemble_family() return self def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: @@ -871,7 +883,7 @@ def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: :meth:`add` """ - self.submobjects = list_update(mobjects, self.submobjects) + self.set_submobjects(list_update(mobjects, self.submobjects)) return self def replace_submobject(self, index, new_submob): @@ -882,32 +894,39 @@ def replace_submobject(self, index, new_submob): self.assemble_family() return self - def invert(self, recursive=False): - """Inverts the list of :attr:`submobjects`. + def insert_submobject(self, index: int, mobject: OpenGLMobject): + """Inserts a mobject at a specific position into self.submobjects + + Effectively just calls ``self.submobjects.insert(index, mobject)``, + where ``self.submobjects`` is a list. + + Highly adapted from ``OpenGLMobject.add``. Parameters ---------- - recursive - If ``True``, all submobject lists of this mobject's family are inverted. + index + The index at which + mobject + The mobject to be inserted. + update_parent + Whether or not to set ``mobject.parent`` to ``self``. + """ + if mobject is self: + raise ValueError("OpenGLMobject cannot contain self") - Examples - -------- + if not isinstance(mobject, OpenGLMobject): + raise TypeError("All submobjects must be of type OpenGLMobject") - .. manim:: InvertSumobjectsExample + if mobject not in self.submobjects: + self.submobjects.insert(index, mobject) - class InvertSumobjectsExample(Scene): - def construct(self): - s = VGroup(*[Dot().shift(i*0.1*RIGHT) for i in range(-20,20)]) - s2 = s.copy() - s2.invert() - s2.shift(DOWN) - self.play(Write(s), Write(s2)) - """ - if recursive: - for submob in self.submobjects: - submob.invert(recursive=True) - list.reverse(self.submobjects) self.assemble_family() + return self + + def set_submobjects(self, submobject_list: list[OpenGLMobject]): + self.remove(*self.submobjects, reassemble=False) + self.add(*submobject_list) + return self # Submobject organization @@ -935,7 +954,8 @@ def construct(self): self.center() return self - def arrange_in_grid( + # !TODO this differs a lot from 3b1b/manim + def arrange_in_grid_legacy( self, rows: int | None = None, cols: int | None = None, @@ -1062,7 +1082,7 @@ def init_size(num, alignments, sizes): # This is favored over rows>cols since in general # the sceene is wider than high. if rows is None: - rows = ceil(len(mobs) / cols) + rows = ceil(len(mobs) / cols) # type: ignore if cols is None: cols = ceil(len(mobs) / rows) if rows * cols < len(mobs): @@ -1117,7 +1137,7 @@ def init_alignments(alignments, num, mapping, name, dir): raise ValueError( 'flow_order must be one of the following values: "dr", "rd", "ld" "dl", "ru", "ur", "lu", "ul".', ) - flow_order = mapper[flow_order] + flow_order = mapper[flow_order] # type: ignore # Reverse row_alignments and row_heights. Necessary since the # grid filling is handled bottom up for simplicity reasons. @@ -1136,7 +1156,11 @@ def reverse(maybe_list): # properties of 0. mobs.extend([placeholder] * (rows * cols - len(mobs))) - grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)] + + grid = [ + [mobs[flow_order(r, c)] for c in range(cols)] # type:ignore + for r in range(rows) + ] measured_heigths = [ max(grid[r][c].height for c in range(cols)) for r in range(rows) @@ -1163,7 +1187,7 @@ def init_sizes(sizes, num, measures, name): x = 0 for c in range(cols): if grid[r][c] is not placeholder: - alignment = row_alignments[r] + col_alignments[c] + alignment = row_alignments[r] + col_alignments[c] # type:ignore line = Line( x * RIGHT + y * UP, (x + widths[c]) * RIGHT + (y + heights[r]) * UP, @@ -1179,20 +1203,76 @@ def init_sizes(sizes, num, measures, name): self.move_to(start_pos) return self - def get_grid(self, n_rows, n_cols, height=None, **kwargs): - """ - Returns a new mobject containing multiple copies of this one - arranged in a grid - """ - grid = self.duplicate(n_rows * n_cols) - grid.arrange_in_grid(n_rows, n_cols, **kwargs) - if height is not None: - grid.set_height(height) - return grid + def arrange_in_grid( + self, + n_rows: int | None = None, + n_cols: int | None = None, + buff: float | None = None, + h_buff: float | None = None, + v_buff: float | None = None, + buff_ratio: float | None = None, + h_buff_ratio: float = 0.5, + v_buff_ratio: float = 0.5, + aligned_edge: np.ndarray = ORIGIN, + fill_rows_first: bool = True, + ): + submobs = self.submobjects + if n_rows is None and n_cols is None: + n_rows = int(np.sqrt(len(submobs))) + if n_rows is None and n_cols is not None: + n_rows = len(submobs) // n_cols + if n_cols is None and n_rows is not None: + n_cols = len(submobs) // n_rows + + if buff is not None: + h_buff = buff + v_buff = buff + else: + if buff_ratio is not None: + v_buff_ratio = buff_ratio + h_buff_ratio = buff_ratio + if h_buff is None: + h_buff = h_buff_ratio * self[0].get_width() + if v_buff is None: + v_buff = v_buff_ratio * self[0].get_height() + + x_unit = h_buff + max([sm.get_width() for sm in submobs]) + y_unit = v_buff + max([sm.get_height() for sm in submobs]) + + for index, sm in enumerate(submobs): + if fill_rows_first: + x, y = index % n_cols, index // n_cols # type: ignore + else: + x, y = index // n_rows, index % n_rows # type: ignore + sm.move_to(ORIGIN, aligned_edge) + sm.shift(x * x_unit * RIGHT + y * y_unit * DOWN) + self.center() + return self - def duplicate(self, n: int): - """Returns an :class:`~.OpenGLVGroup` containing ``n`` copies of the mobject.""" - return self.get_group_class()(*[self.copy() for _ in range(n)]) + def arrange_to_fit_dim(self, length: float, dim: int, about_edge=ORIGIN): + ref_point = self.get_bounding_box_point(about_edge) + n_submobs = len(self.submobjects) + if n_submobs <= 1: + return + total_length = sum(sm.length_over_dim(dim) for sm in self.submobjects) + buff = (length - total_length) / (n_submobs - 1) + vect = np.zeros(self.dim) + vect[dim] = 1 + x = 0 + for submob in self.submobjects: + submob.set_coord(x, dim, -vect) + x += submob.length_over_dim(dim) + buff + self.move_to(ref_point, about_edge) + return self + + def arrange_to_fit_width(self, width: float, about_edge=ORIGIN): + return self.arrange_to_fit_dim(width, 0, about_edge) + + def arrange_to_fit_height(self, height: float, about_edge=ORIGIN): + return self.arrange_to_fit_dim(height, 1, about_edge) + + def arrange_to_fit_depth(self, depth: float, about_edge=ORIGIN): + return self.arrange_to_fit_dim(depth, 2, about_edge) def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): """Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" @@ -1225,8 +1305,8 @@ def construct(self): self.assemble_family() return self - def invert(self, recursive=False): - """Inverts the list of :attr:`submobjects`. + def reverse_submobjects(self, recursive=False): + """Reverses the list of :attr:`submobjects`. Parameters ---------- @@ -1236,24 +1316,58 @@ def invert(self, recursive=False): Examples -------- - .. manim:: InvertSumobjectsExample + .. manim:: ReverseSumobjectsExample - class InvertSumobjectsExample(Scene): + class ReverseSumobjectsExample(Scene): def construct(self): s = VGroup(*[Dot().shift(i*0.1*RIGHT) for i in range(-20,20)]) s2 = s.copy() - s2.invert() + s2.reverse_submobjects() s2.shift(DOWN) self.play(Write(s), Write(s2)) """ if recursive: for submob in self.submobjects: - submob.invert(recursive=True) - list.reverse(self.submobjects) + submob.reverse_submobjects(recursive=True) + self.submobjects.reverse() + self.assemble_family() # Copying - def copy(self, shallow: bool = False): + @staticmethod + def stash_mobject_pointers(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + uncopied_attrs = ["parents", "target", "saved_state"] + stash = {} + for attr in uncopied_attrs: + if hasattr(self, attr): + value = getattr(self, attr) + stash[attr] = value + null_value = [] if isinstance(value, list) else None + setattr(self, attr, null_value) + result = func(self, *args, **kwargs) + self.__dict__.update(stash) + return result + + return wrapper + + @stash_mobject_pointers + def serialize(self) -> bytes: + return pickle.dumps(self) + + def deserialize(self, data: bytes) -> OpenGLMobject: + self.become(pickle.loads(data)) + return self + + def deepcopy(self) -> OpenGLMobject: + try: + return pickle.loads(pickle.dumps(self)) + except AttributeError: + return copy.deepcopy(self) + + @stash_mobject_pointers + def copy(self, deep: bool = False) -> OpenGLMobject: """Create and return an identical copy of the :class:`OpenGLMobject` including all :attr:`submobjects`. @@ -1271,70 +1385,58 @@ def copy(self, shallow: bool = False): ---- The clone is initially not visible in the Scene, even if the original was. """ - if not shallow: + if deep: return self.deepcopy() - # TODO, either justify reason for shallow copy, or - # remove this redundancy everywhere - # return self.deepcopy() - - parents = self.parents - self.parents = [] - copy_mobject = copy.copy(self) - self.parents = parents - - copy_mobject.data = dict(self.data) - for key in self.data: - copy_mobject.data[key] = self.data[key].copy() - - # TODO, are uniforms ever numpy arrays? - copy_mobject.uniforms = dict(self.uniforms) - - copy_mobject.submobjects = [] - copy_mobject.add(*(sm.copy() for sm in self.submobjects)) - copy_mobject.match_updaters(self) + result = copy.copy(self) + + # The line above is only a shallow copy, so the internal + # data which are numpyu arrays or other mobjects still + # need to be further copied. + result.data = to_mobject_data(self.data) + result.uniforms = to_mobject_uniforms(self.uniforms) + + # Instead of adding using result.add, which does some checks for updating + # updater statues and bounding box, just directly modify the family-related + # lists + result.submobjects = [sm.copy() for sm in self.submobjects] + for sm in result.submobjects: + sm.parents = [result] + result.family = [ + result, + *it.chain(*(sm.get_family() for sm in result.submobjects)), + ] - copy_mobject.needs_new_bounding_box = self.needs_new_bounding_box + # Similarly, instead of calling match_updaters, since we know the status + # won't have changed, just directly match. + result.non_time_updaters = list(self.non_time_updaters) + result.time_based_updaters = list(self.time_based_updaters) - # Make sure any mobject or numpy array attributes are copied family = self.get_family() for attr, value in list(self.__dict__.items()): if ( isinstance(value, OpenGLMobject) - and value in family and value is not self + and value in family ): - setattr(copy_mobject, attr, value.copy()) + setattr(result, attr, result.family[self.family.index(value)]) if isinstance(value, np.ndarray): - setattr(copy_mobject, attr, value.copy()) - # if isinstance(value, ShaderWrapper): - # setattr(copy_mobject, attr, value.copy()) - return copy_mobject - - def deepcopy(self): - parents = self.parents - self.parents = [] - result = copy.deepcopy(self) - self.parents = parents + setattr(result, attr, value.copy()) + if isinstance(value, ShaderWrapper): + setattr(result, attr, value.copy()) return result def generate_target(self, use_deepcopy: bool = False): - self.target = None # Prevent exponential explosion - if use_deepcopy: - self.target = self.deepcopy() - else: - self.target = self.copy() + target: OpenGLMobject = self.copy(use_deepcopy=use_deepcopy) + target.saved_state = self.saved_state + self.target = target return self.target def save_state(self, use_deepcopy: bool = False): """Save the current state (position, color & size). Can be restored with :meth:`~.OpenGLMobject.restore`.""" - if hasattr(self, "saved_state"): - # Prevent exponential growth of data - self.saved_state = None - if use_deepcopy: - self.saved_state = self.deepcopy() - else: - self.saved_state = self.copy() + saved_state: OpenGLMobject = self.copy(deep=use_deepcopy) + saved_state.target = self.target + self.saved_state = saved_state return self def restore(self): @@ -1344,43 +1446,113 @@ def restore(self): self.become(self.saved_state) return self + def save_to_file(self, file_path: str): + with open(file_path, "wb") as fp: + fp.write(self.serialize()) + logger.info(f"Saved mobject to {file_path}") + return self + + @staticmethod + def load(file_path): + if not os.path.exists(file_path): + logger.error(f"No file found at {file_path}") + sys.exit(2) + with open(file_path, "rb") as fp: + mobject = pickle.load(fp) + return mobject + + # Creating new Mobjects from this one + + def replicate(self, n: int) -> OpenGLGroup: + """Returns an :class:`~.OpenGLVGroup` containing ``n`` copies of the mobject.""" + group_class = self.get_group_class() + return group_class(*(self.copy() for _ in range(n))) + + def get_grid_legacy(self, n_rows, n_cols, height=None, **kwargs): + """ + Returns a new mobject containing multiple copies of this one + arranged in a grid + """ + grid = self.duplicate(n_rows * n_cols) + grid.arrange_in_grid(n_rows, n_cols, **kwargs) + if height is not None: + grid.set_height(height) + return grid + + def get_grid( + self, + n_rows: int, + n_cols: int, + height: float | None = None, + width: float | None = None, + group_by_rows: bool = False, + group_by_cols: bool = False, + **kwargs, + ) -> OpenGLGroup: + """ + Returns a new mobject containing multiple copies of this one + arranged in a grid + """ + total = n_rows * n_cols + grid = self.replicate(total) + if group_by_cols: + kwargs["fill_rows_first"] = False + grid.arrange_in_grid(n_rows, n_cols, **kwargs) + if height is not None: + grid.set_height(height) + if width is not None: + grid.set_height(width) + + group_class = self.get_group_class() + if group_by_rows: + return group_class(*(grid[n : n + n_cols] for n in range(0, total, n_cols))) + elif group_by_cols: + return group_class(*(grid[n : n + n_rows] for n in range(0, total, n_rows))) + else: + return grid + # Updating - def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] - self.has_updaters = False - self.updating_suspended = False + def init_updaters(self) -> None: + self.time_based_updaters: list[TimeBasedUpdater] = [] + self.non_time_updaters: list[NonTimeUpdater] = [] + self.has_updaters: bool = False + self.updating_suspended: bool = False - def update(self, dt=0, recurse=True): + def update(self, dt: float = 0, recurse: bool = True) -> OpenGLMobject: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for time_updater in self.time_based_updaters: + time_updater(self, dt) + for non_time_updater in self.non_time_updaters: + non_time_updater(self) if recurse: for submob in self.submobjects: submob.update(dt, recurse) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> list[TimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> list[Updater]: return self.time_based_updaters + self.non_time_updaters - def get_family_updaters(self): + def get_family_updaters(self) -> list[Updater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) - def add_updater(self, update_function, index=None, call_updater=False): + def add_updater( + self, + update_function: Updater, + index: int | None = None, + call_updater: bool = False, + ) -> OpenGLMobject: if "dt" in get_parameters(update_function): - updater_list = self.time_based_updaters + updater_list: list[Updater] = self.time_based_updaters # type: ignore else: - updater_list = self.non_time_updaters + updater_list: list[Updater] = self.non_time_updaters # type: ignore if index is None: updater_list.append(update_function) @@ -1392,14 +1564,18 @@ def add_updater(self, update_function, index=None, call_updater=False): self.update() return self - def remove_updater(self, update_function): - for updater_list in [self.time_based_updaters, self.non_time_updaters]: + def remove_updater(self, update_function: Updater) -> OpenGLMobject: + updater_lists: list[list[Updater]] = [ + self.time_based_updaters, # type: ignore + self.non_time_updaters, # type: ignore + ] + for updater_list in updater_lists: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): + def clear_updaters(self, recurse: bool = True) -> OpenGLMobject: self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -1408,20 +1584,22 @@ def clear_updaters(self, recurse=True): submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: OpenGLMobject) -> OpenGLMobject: self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True) -> OpenGLMobject: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating( + self, recurse: bool = True, call_updater: bool = True + ) -> OpenGLMobject: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -1432,13 +1610,25 @@ def resume_updating(self, recurse=True, call_updater=True): self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> OpenGLMobject: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self + # Check if mark as static or not for camera + + def is_changing(self) -> bool: + return self.has_updaters or self._is_animating + + def set_animating_status( + self, is_animating: bool, recurse: bool = True + ) -> OpenGLMobject: + for mob in (*self.get_family(recurse), *self.get_ancestors(extended=True)): + mob._is_animating = is_animating + return self + # Transforming operations - def shift(self, vector): + def shift(self, vector) -> OpenGLMobject: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -1448,9 +1638,10 @@ def shift(self, vector): def scale( self, - scale_factor: float, - about_point: Sequence[float] | None = None, - about_edge: Sequence[float] = ORIGIN, + scale_factor: float | np.ndarray, + min_scale_factor: float = 1e-8, + about_point: Sequence[float] | np.ndarray | None = None, + about_edge: Sequence[float] | np.ndarray = ORIGIN, **kwargs, ) -> OpenGLMobject: r"""Scale the size by a factor. @@ -1469,6 +1660,10 @@ def scale( The scaling factor :math:`\alpha`. If :math:`0 < |\alpha| < 1`, the mobject will shrink, and for :math:`|\alpha| > 1` it will grow. Furthermore, if :math:`\alpha < 0`, the mobject is also flipped. + + min_scale_factor + The minimum scaling factor that is used such that the mobject is not scaled to zero. + kwargs Additional keyword arguments passed to :meth:`apply_points_function_about_point`. @@ -1499,16 +1694,33 @@ def construct(self): :meth:`move_to` """ + if isinstance(scale_factor, numbers.Number): + scale_factor = max(scale_factor, min_scale_factor) + else: + scale_factor = np.array(scale_factor).clip(min=min_scale_factor) # type: ignore self.apply_points_function( lambda points: scale_factor * points, about_point=about_point, about_edge=about_edge, works_on_bounding_box=True, - **kwargs, ) + for mob in self.get_family(): + mob._handle_scale_side_effects(scale_factor) return self - def stretch(self, factor, dim, **kwargs): + def _handle_scale_side_effects(self, scale_factor: float | np.ndarray) -> None: + """In case subclasses, such as DecimalNumber, need to make + any other changes when the size gets altered by scaling. + This method can be overridden in subclasses. + + Parameters + ---------- + scale_factor + The scaling factor :math:`\alpha`. If :math:`0 < |\alpha| < 1` + """ + pass + + def stretch(self, factor: float, dim: int, **kwargs) -> OpenGLMobject: def func(points): points[:, dim] *= factor return points @@ -1516,16 +1728,16 @@ def func(points): self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle, axis=OUT): - return self.rotate(angle, axis, about_point=ORIGIN) + def rotate_about_origin(self, angle: float, axis=OUT) -> OpenGLMobject: + return self.rotate(angle, axis, about_point=ORIGIN) # type: ignore def rotate( self, - angle, + angle: float, axis=OUT, about_point: Sequence[float] | None = None, **kwargs, - ): + ) -> OpenGLMobject: """Rotates the :class:`~.OpenGLMobject` about a certain point.""" rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( @@ -1535,7 +1747,7 @@ def rotate( ) return self - def flip(self, axis=UP, **kwargs): + def flip(self, axis=UP, **kwargs) -> OpenGLMobject: """Flips/Mirrors an mobject about its center. Examples @@ -1554,7 +1766,7 @@ def construct(self): """ return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function, **kwargs): + def apply_function(self, function: PointUpdateFunction, **kwargs) -> OpenGLMobject: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -1563,16 +1775,20 @@ def apply_function(self, function, **kwargs): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position( + self, function: PointUpdateFunction + ) -> OpenGLMobject: self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions( + self, function: PointUpdateFunction + ) -> OpenGLMobject: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self, matrix, **kwargs) -> OpenGLMobject: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -1584,7 +1800,7 @@ def apply_matrix(self, matrix, **kwargs): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function(self, function, **kwargs) -> OpenGLMobject: """Applies a complex function to a :class:`OpenGLMobject`. The x and y coordinates correspond to the real and imaginary parts respectively. @@ -1629,7 +1845,7 @@ def hierarchical_model_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(model_matrices))) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0) -> OpenGLMobject: for mob in self.family_members_with_points(): alphas = np.dot(mob.points, np.transpose(axis)) alphas -= min(alphas) @@ -1646,19 +1862,21 @@ def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): # Positioning methods - def center(self): + def center(self) -> OpenGLMobject: """Moves the mobject to the center of the Scene.""" self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER + ) -> OpenGLMobject: """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. """ target_point = np.sign(direction) * ( - config["frame_x_radius"], - config["frame_y_radius"], + config.frame_x_radius, + config.frame_y_radius, 0, ) point_to_align = self.get_bounding_box_point(direction) @@ -1667,10 +1885,12 @@ def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER + ) -> OpenGLMobject: return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER) -> OpenGLMobject: return self.align_on_border(edge, buff) def next_to( @@ -1682,7 +1902,7 @@ def next_to( submobject_to_align=None, index_of_submobject_to_align=None, coor_mask=np.array([1, 1, 1]), - ): + ) -> OpenGLMobject: """Move this :class:`~.OpenGLMobject` next to another's :class:`~.OpenGLMobject` or coordinate. Examples @@ -1725,7 +1945,7 @@ def construct(self): return self def shift_onto_screen(self, **kwargs): - space_lengths = [config["frame_x_radius"], config["frame_y_radius"]] + space_lengths = [config.frame_x_radius, config.frame_y_radius] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) buff = kwargs.get("buff", DEFAULT_MOBJECT_TO_EDGE_BUFFER) @@ -1831,6 +2051,36 @@ def set_depth(self, depth, stretch=False, **kwargs): scale_to_fit_depth = set_depth + def set_max_width(self, max_width: float, **kwargs): + if self.get_width() > max_width: + self.set_width(max_width, **kwargs) + return self + + def set_max_height(self, max_height: float, **kwargs): + if self.get_height() > max_height: + self.set_height(max_height, **kwargs) + return self + + def set_max_depth(self, max_depth: float, **kwargs): + if self.get_depth() > max_depth: + self.set_depth(max_depth, **kwargs) + return self + + def set_min_width(self, min_width: float, **kwargs): + if self.get_width() < min_width: + self.set_width(min_width, **kwargs) + return self + + def set_min_height(self, min_height: float, **kwargs): + if self.get_height() < min_height: + self.set_height(min_height, **kwargs) + return self + + def set_min_depth(self, min_depth: float, **kwargs): + if self.get_depth() < min_depth: + self.set_depth(min_depth, **kwargs) + return self + def set_coord(self, value, dim, direction=ORIGIN): curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) @@ -1860,7 +2110,7 @@ def move_to( self, point_or_mobject, aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1]), + coor_mask=np.array((1, 1, 1)), ): """Move center of the :class:`~.OpenGLMobject` to certain coordinate.""" if isinstance(point_or_mobject, OpenGLMobject): @@ -1899,12 +2149,13 @@ def surround( self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + # ! TODO: Check implementation of 3b1b for this method + def put_start_and_end_on_legacy(self, start, end): curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): raise Exception("Cannot position endpoints of closed loop") - target_vect = np.array(end) - np.array(start) + target_vect = np.asarray(end) - np.asarray(start) axis = ( normalize(np.cross(curr_vect, target_vect)) if np.linalg.norm(np.cross(curr_vect, target_vect)) != 0 @@ -1922,9 +2173,32 @@ def put_start_and_end_on(self, start, end): self.shift(start - curr_start) return self + def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray): + curr_start, curr_end = self.get_start_and_end() + curr_vect = curr_end - curr_start + if np.all(curr_vect == 0): + raise Exception("Cannot position endpoints of closed loop") + target_vect = end - start + self.scale( + get_norm(target_vect) / get_norm(curr_vect), + about_point=curr_start, + ) + self.rotate( + angle_of_vector(target_vect) - angle_of_vector(curr_vect), + ) + self.rotate( + np.arctan2(curr_vect[2], get_norm(curr_vect[:2])) + - np.arctan2(target_vect[2], get_norm(target_vect[:2])), + axis=np.array([-target_vect[1], target_vect[0], 0]), + ) + self.shift(start - self.get_start()) + return self + # Color functions - def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): + def set_rgba_array_legacy( + self, color=None, opacity=None, name="rgbas", recurse=True + ): if color is not None: rgbs = np.array([color_to_rgb(c) for c in listify(color)]) if opacity is not None: @@ -1954,7 +2228,9 @@ def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): mob.data[name] = rgbas.copy() return self - def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): + def set_rgba_array( + self, rgba_array: np.ndarray, name: str = "rgbas", recurse: bool = False + ): """Directly set rgba data from `rgbas` and optionally do the same recursively with submobjects. This can be used if the `rgbas` have already been generated with the correct shape and simply need to be set. @@ -1969,7 +2245,58 @@ def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): set to true to recursively apply this method to submobjects """ for mob in self.get_family(recurse): - mob.data[name] = rgbas.copy() + mob.data[name] = np.array(rgba_array) # type: ignore + return self + + def set_color_by_rgba_func( + self, func: Callable[[np.ndarray], np.ndarray], recurse: bool = True + ): + """ + Func should take in a point in R3 and output an rgba value + """ + for mob in self.get_family(recurse): + rgba_array = np.asarray([func(point) for point in mob.points]) + mob.set_rgba_array(rgba_array) + return self + + def set_color_by_rgb_func( + self, + func: Callable[[np.ndarray], np.ndarray], + opacity: float = 1, + recurse: bool = True, + ): + """ + Func should take in a point in R3 and output an rgb value + """ + for mob in self.get_family(recurse): + rgba_array = np.asarray([[*func(point), opacity] for point in mob.points]) + mob.set_rgba_array(rgba_array) + return self + + def set_rgba_array_by_color( + self, + color=None, + opacity: float | Iterable[float] | None = None, + name: str = "rgbas", + recurse: bool = True, + ): + max_len = 0 + if color is not None: + rgbs = np.array([color_to_rgb(c) for c in listify(color)]) + max_len = len(rgbs) + if opacity is not None: + opacities = np.array(listify(opacity)) + max_len = max(max_len, len(opacities)) + + for mob in self.get_family(recurse): + if max_len > len(mob.data[name]): # type: ignore + mob.data[name] = resize_array(mob.data[name], max_len) # type: ignore + size = len(mob.data[name]) # type: ignore + if color is not None: + mob.data[name][:, :3] = resize_array(rgbs, size) # type: ignore + if opacity is not None: + mob.data[name][:, 3] = resize_array(opacities, size) # type: ignore + return self def set_color(self, color, opacity=None, recurse=True): self.set_rgba_array(color, opacity, recurse=False) @@ -1995,10 +2322,13 @@ def get_color(self): return rgb_to_hex(self.rgbas[0, :3]) def get_opacity(self): - return self.rgbas[0, 3] + return self.data["rgbas"][0, 3] - def set_color_by_gradient(self, *colors): - self.set_submobject_colors_by_gradient(*colors) + def set_color_by_gradient(self, *colors: Color): + if self.has_points(): + self.set_color(colors) + else: + self.set_submobject_colors_by_gradient(*colors) return self def set_submobject_colors_by_gradient(self, *colors): @@ -2018,20 +2348,28 @@ def set_submobject_colors_by_gradient(self, *colors): def fade(self, darkness=0.5, recurse=True): self.set_opacity(1.0 - darkness, recurse=recurse) - def get_gloss(self): - return self.gloss + def get_reflectiveness(self) -> np.ndarray: + return self.uniforms["reflectiveness"] + + def set_reflectiveness(self, reflectiveness: float, recurse: bool = True): + for mob in self.get_family(recurse): + mob.uniforms["reflectiveness"] = np.asarray(reflectiveness) + return self + + def get_shadow(self) -> np.ndarray: + return self.uniforms["shadow"] - def set_gloss(self, gloss, recurse=True): + def set_shadow(self, shadow: float, recurse: bool = True): for mob in self.get_family(recurse): - mob.gloss = gloss + mob.uniforms["shadow"] = np.asarray(shadow) return self - def get_shadow(self): - return self.shadow + def get_gloss(self) -> np.ndarray: + return self.uniforms["gloss"] - def set_shadow(self, shadow, recurse=True): + def set_gloss(self, gloss: float, recurse: bool = True): for mob in self.get_family(recurse): - mob.shadow = shadow + mob.uniforms["gloss"] = np.asarray(gloss) return self # Background rectangle @@ -2073,7 +2411,7 @@ def add_background_rectangle( self.background_rectangle = BackgroundRectangle( self, color=color, fill_opacity=opacity, **kwargs ) - self.add_to_back(self.background_rectangle) + self.add_to_back(self.background_rectangle) # type: ignore return self def add_background_rectangle_to_submobjects(self, **kwargs): @@ -2101,6 +2439,15 @@ def get_corner(self, direction) -> np.ndarray: """Get corner coordinates for certain direction.""" return self.get_bounding_box_point(direction) + def get_all_corners(self): + bb = self.get_bounding_box() + return np.array( + [ + [bb[indices[-i + 1]][i] for i in range(3)] + for indices in it.product([0, 2], repeat=3) + ] + ) + def get_center(self) -> np.ndarray: """Get center coordinates.""" return self.get_bounding_box()[1] @@ -2205,9 +2552,8 @@ def point_from_proportion(self, alpha): i, subalpha = integer_interpolate(0, len(points) - 1, alpha) return interpolate(points[i], points[i + 1], subalpha) - def pfp(self, alpha): - """Abbreviation for point_from_proportion""" - return self.point_from_proportion(alpha) + pfp = point_from_proportion + """Abbreviation for point_from_proportion""" def get_pieces(self, n_pieces): template = self.copy() @@ -2247,25 +2593,25 @@ def match_depth(self, mobject: OpenGLMobject, **kwargs): """Match the depth with the depth of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject: OpenGLMobject, dim, direction=ORIGIN): + def match_coord(self, mobject_or_point: OpenGLMobject, dim, direction=ORIGIN): """Match the coordinates with the coordinates of another :class:`~.OpenGLMobject`.""" - return self.set_coord( - mobject.get_coord(dim, direction), - dim=dim, - direction=direction, - ) + if isinstance(mobject_or_point, OpenGLMobject): + coord = mobject_or_point.get_coord(dim, direction) + else: + coord = mobject_or_point[dim] + return self.set_coord(coord, dim=dim, direction=direction) - def match_x(self, mobject, direction=ORIGIN): + def match_x(self, mobject_or_point, direction=ORIGIN): """Match x coord. to the x coord. of another :class:`~.OpenGLMobject`.""" - return self.match_coord(mobject, 0, direction) + return self.match_coord(mobject_or_point, 0, direction) - def match_y(self, mobject, direction=ORIGIN): + def match_y(self, mobject_or_point, direction=ORIGIN): """Match y coord. to the x coord. of another :class:`~.OpenGLMobject`.""" - return self.match_coord(mobject, 1, direction) + return self.match_coord(mobject_or_point, 1, direction) - def match_z(self, mobject, direction=ORIGIN): + def match_z(self, mobject_or_point, direction=ORIGIN): """Match z coord. to the x coord. of another :class:`~.OpenGLMobject`.""" - return self.match_coord(mobject, 2, direction) + return self.match_coord(mobject_or_point, 2, direction) def align_to( self, @@ -2305,9 +2651,9 @@ def align_data_and_family(self, mobject): self.align_family(mobject) self.align_data(mobject) - def align_data(self, mobject): + def align_data(self, mobject) -> None: # In case any data arrays get resized when aligned to shader data - # self.refresh_shader_data() + self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): # Separate out how points are treated so that subclasses # can handle that case differently if they choose @@ -2315,20 +2661,20 @@ def align_data(self, mobject): for key in mob1.data.keys() & mob2.data.keys(): if key == "points": continue - arr1 = mob1.data[key] + arr1 = mob1.data[key] # type: ignore arr2 = mob2.data[key] if len(arr2) > len(arr1): - mob1.data[key] = resize_preserving_order(arr1, len(arr2)) + mob1.data[key] = resize_preserving_order(arr1, len(arr2)) # type: ignore elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) - def align_points(self, mobject): + def align_points(self, mobject) -> OpenGLMobject: max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject): + def align_family(self, mobject) -> OpenGLMobject: mob1 = self mob2 = mobject n1 = len(mob1) @@ -2341,14 +2687,14 @@ def align_family(self, mobject): sm1.align_family(sm2) return self - def push_self_into_submobjects(self): + def push_self_into_submobjects(self) -> OpenGLMobject: copy = self.deepcopy() copy.submobjects = [] - self.resize_points(0) + self.clear_points() self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self, n) -> OpenGLMobject: if n == 0: return self @@ -2377,7 +2723,9 @@ def add_n_more_submobjects(self, n): # Interpolate - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()): + def interpolate( + self, mobject1, mobject2, alpha, path_func=straight_path + ) -> OpenGLMobject: """Turns this :class:`~.OpenGLMobject` into an interpolation between ``mobject1`` and ``mobject2``. @@ -2401,7 +2749,7 @@ def construct(self): for key in self.data: if key in self.locked_data_keys: continue - if len(self.data[key]) == 0: + if len(self.data[key]) == 0: # type: ignore continue if key not in mobject1.data or key not in mobject2.data: continue @@ -2411,23 +2759,11 @@ def construct(self): else: func = interpolate - self.data[key][:] = func(mobject1.data[key], mobject2.data[key], alpha) - + self.data[key][:] = func(mobject1.data[key], mobject2.data[key], alpha) # type: ignore for key in self.uniforms: - if key != "fixed_orientation_center": - self.uniforms[key] = interpolate( - mobject1.uniforms[key], - mobject2.uniforms[key], - alpha, - ) - else: - self.uniforms["fixed_orientation_center"] = tuple( - interpolate( - np.array(mobject1.uniforms["fixed_orientation_center"]), - np.array(mobject2.uniforms["fixed_orientation_center"]), - alpha, - ) - ) + self.uniforms[key] = interpolate( # type: ignore + mobject1.uniforms[key], mobject2.uniforms[key], alpha + ) return self def pointwise_become_partial(self, mobject, a, b): @@ -2439,9 +2775,45 @@ def pointwise_become_partial(self, mobject, a, b): """ pass # To implement in subclass + # Locking data + + def lock_data(self, keys: Iterable[str]): + """ + To speed up some animations, particularly transformations, + it can be handy to acknowledge which pieces of data + won't change during the animation so that calls to + interpolate can skip this, and so that it's not + read into the shader_wrapper objects needlessly + """ + if self.has_updaters: + return + # Be sure shader data has most up to date information + self.refresh_shader_data() + self.locked_data_keys = set(keys) + + def lock_matching_data(self, mobject1: OpenGLMobject, mobject2: OpenGLMobject): + for sm, sm1, sm2 in zip( + self.get_family(), mobject1.get_family(), mobject2.get_family() + ): + keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() + sm.lock_data( + list( + filter( + lambda key: np.all(sm1.data[key] == sm2.data[key]), # type: ignore + keys, + ) + ) + ) + return self + + def unlock_data(self): + for mob in self.get_family(): + mob.locked_data_keys = set() + def become( self, mobject: OpenGLMobject, + match_updaters=False, match_height: bool = False, match_width: bool = False, match_depth: bool = False, @@ -2482,7 +2854,7 @@ def construct(self): circ.become(square) self.wait(0.5) """ - + # Manim CE Weird stretching thing which also modifies the original mobject if stretch: mobject.stretch_to_fit_height(self.height) mobject.stretch_to_fit_width(self.width) @@ -2498,74 +2870,62 @@ def construct(self): if match_center: mobject.move_to(self.get_center()) + # Original 3b1b/manim behaviour self.align_family(mobject) - for sm1, sm2 in zip(self.get_family(), mobject.get_family()): + family1 = self.get_family() + family2 = mobject.get_family() + for sm1, sm2 in zip(family1, family2): sm1.set_data(sm2.data) sm1.set_uniforms(sm2.uniforms) + sm1.shader_folder = sm2.shader_folder + sm1.texture_paths = sm2.texture_paths + sm1.depth_test = sm2.depth_test + sm1.render_primitive = sm2.render_primitive + # Make sure named family members carry over + for attr, value in list(mobject.__dict__.items()): + if isinstance(value, OpenGLMobject) and value in family2: + setattr(self, attr, family1[family2.index(value)]) self.refresh_bounding_box(recurse_down=True) + if match_updaters: + self.match_updaters(mobject) return self - # Locking data - - def lock_data(self, keys): - """ - To speed up some animations, particularly transformations, - it can be handy to acknowledge which pieces of data - won't change during the animation so that calls to - interpolate can skip this, and so that it's not - read into the shader_wrapper objects needlessly - """ - if self.has_updaters: - return - # Be sure shader data has most up to date information - self.refresh_shader_data() - self.locked_data_keys = set(keys) - - def lock_matching_data(self, mobject1, mobject2): - for sm, sm1, sm2 in zip( - self.get_family(), - mobject1.get_family(), - mobject2.get_family(), - ): - keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() - sm.lock_data( - list( - filter( - lambda key: np.all(sm1.data[key] == sm2.data[key]), - keys, - ), - ), - ) - return self + # Operations touching shader uniforms - def unlock_data(self): - for mob in self.get_family(): - mob.locked_data_keys = set() + @staticmethod + def affects_shader_info_id(func): + @wraps(func) + def wrapper(self): + for mob in self.get_family(): + func(mob) + mob.refresh_shader_wrapper_id() + return self - # Operations touching shader uniforms + return wrapper @affects_shader_info_id - def fix_in_frame(self): - self.is_fixed_in_frame = 1.0 + def fix_in_frame(self) -> OpenGLMobject: + self.uniforms["is_fixed_in_frame"] = np.asarray(1.0) + self.is_fixed_in_frame = True return self @affects_shader_info_id - def fix_orientation(self): - self.is_fixed_orientation = 1.0 + def fix_orientation(self) -> OpenGLMobject: + self.uniforms["is_fixed_orientation"] = np.asarray(1.0) + self.is_fixed_orientation = True self.fixed_orientation_center = tuple(self.get_center()) - self.depth_test = True return self @affects_shader_info_id - def unfix_from_frame(self): - self.is_fixed_in_frame = 0.0 + def unfix_from_frame(self) -> OpenGLMobject: + self.uniforms["is_fixed_in_frame"] = np.asarray(0.0) + self.is_fixed_in_frame = False return self @affects_shader_info_id def unfix_orientation(self): self.is_fixed_orientation = 0.0 self.fixed_orientation_center = (0, 0, 0) - self.depth_test = False return self @affects_shader_info_id @@ -2625,26 +2985,27 @@ def set_color_by_xyz_func( return self # For shader data + def init_shader_data(self): + # TODO, only call this when needed? + self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype) + self.shader_indices = None + self.shader_wrapper = ShaderWrapper( + vert_data=self.shader_data, + shader_folder=self.shader_folder, + texture_paths=self.texture_paths, + depth_test=self.depth_test, + render_primitive=self.render_primitive, + ) def refresh_shader_wrapper_id(self): - self.get_shader_wrapper().refresh_id() + self.shader_wrapper.refresh_id() return self def get_shader_wrapper(self): - from manim.renderer.shader_wrapper import ShaderWrapper - - # if hasattr(self, "__shader_wrapper"): - # return self.__shader_wrapper - - self.shader_wrapper = ShaderWrapper( - vert_data=self.get_shader_data(), - vert_indices=self.get_shader_vert_indices(), - uniforms=self.get_shader_uniforms(), - depth_test=self.depth_test, - texture_paths=self.texture_paths, - render_primitive=self.render_primitive, - shader_folder=self.__class__.shader_folder, - ) + self.shader_wrapper.vert_data = self.get_shader_data() + self.shader_wrapper.vert_indices = self.get_shader_vert_indices() + self.shader_wrapper.uniforms = self.get_shader_uniforms() + self.shader_wrapper.depth_test = self.depth_test return self.shader_wrapper def get_shader_wrapper_list(self): @@ -2676,12 +3037,12 @@ def check_data_alignment(self, array, data_key): ) return self - def get_resized_shader_data_array(self, length): + def get_resized_shader_data_array(self, length: int) -> np.ndarray: # If possible, try to populate an existing array, rather # than recreating it each frame - points = self.points - shader_data = np.zeros(len(points), dtype=self.shader_dtype) - return shader_data + if len(self.shader_data) != length: + self.shader_data = resize_array(self.shader_data, length) + return self.shader_data def read_data_to_shader(self, shader_data, shader_data_key, data_key): if data_key in self.locked_data_keys: @@ -2703,14 +3064,98 @@ def get_shader_uniforms(self): def get_shader_vert_indices(self): return self.shader_indices - @property - def submobjects(self): - return self._submobjects if hasattr(self, "_submobjects") else [] + # Event Handlers + """ + Event handling follows the Event Bubbling model of DOM in javascript. + Return false to stop the event bubbling. + To learn more visit https://www.quirksmode.org/js/events_order.html - @submobjects.setter - def submobjects(self, submobject_list): - self.remove(*self.submobjects) - self.add(*submobject_list) + Event Callback Argument is a callable function taking two arguments: + 1. Mobject + 2. EventData + """ + + def init_event_listeners(self): + self.event_listeners: list[EventListener] = [] + + def add_event_listener( + self, + event_type: EventType, + event_callback: Callable[[OpenGLMobject, dict[str, str]], None], + ): + event_listener = EventListener(self, event_type, event_callback) + self.event_listeners.append(event_listener) + EVENT_DISPATCHER.add_listener(event_listener) + return self + + def remove_event_listener( + self, + event_type: EventType, + event_callback: Callable[[OpenGLMobject, dict[str, str]], None], + ): + event_listener = EventListener(self, event_type, event_callback) + while event_listener in self.event_listeners: + self.event_listeners.remove(event_listener) + EVENT_DISPATCHER.remove_listener(event_listener) + return self + + def clear_event_listeners(self, recurse: bool = True): + self.event_listeners = [] + if recurse: + for submob in self.submobjects: + submob.clear_event_listeners(recurse=recurse) + return self + + def get_event_listeners(self): + return self.event_listeners + + def get_family_event_listeners(self): + return list(it.chain(*[sm.get_event_listeners() for sm in self.get_family()])) + + def get_has_event_listener(self): + return any(mob.get_event_listeners() for mob in self.get_family()) + + def add_mouse_motion_listener(self, callback): + self.add_event_listener(EventType.MouseMotionEvent, callback) + + def remove_mouse_motion_listener(self, callback): + self.remove_event_listener(EventType.MouseMotionEvent, callback) + + def add_mouse_press_listener(self, callback): + self.add_event_listener(EventType.MousePressEvent, callback) + + def remove_mouse_press_listener(self, callback): + self.remove_event_listener(EventType.MousePressEvent, callback) + + def add_mouse_release_listener(self, callback): + self.add_event_listener(EventType.MouseReleaseEvent, callback) + + def remove_mouse_release_listener(self, callback): + self.remove_event_listener(EventType.MouseReleaseEvent, callback) + + def add_mouse_drag_listener(self, callback): + self.add_event_listener(EventType.MouseDragEvent, callback) + + def remove_mouse_drag_listener(self, callback): + self.remove_event_listener(EventType.MouseDragEvent, callback) + + def add_mouse_scroll_listener(self, callback): + self.add_event_listener(EventType.MouseScrollEvent, callback) + + def remove_mouse_scroll_listener(self, callback): + self.remove_event_listener(EventType.MouseScrollEvent, callback) + + def add_key_press_listener(self, callback): + self.add_event_listener(EventType.KeyPressEvent, callback) + + def remove_key_press_listener(self, callback): + self.remove_event_listener(EventType.KeyPressEvent, callback) + + def add_key_release_listener(self, callback): + self.add_event_listener(EventType.KeyReleaseEvent, callback) + + def remove_key_release_listener(self, callback): + self.remove_event_listener(EventType.KeyReleaseEvent, callback) # Errors @@ -2722,6 +3167,44 @@ def throw_error_if_no_points(self): caller_name = sys._getframe(1).f_code.co_name raise Exception(message.format(caller_name)) + @deprecated( + since="v0.17.2", + message="The usage of this method is discouraged please set attributes directly", + ) + def set(self, **kwargs) -> OpenGLMobject: + """Sets attributes. + + Mainly to be used along with :attr:`animate` to + animate setting attributes. + + Examples + -------- + :: + + >>> mob = OpenGLMobject() + >>> mob.set(foo=0) + OpenGLMobject + >>> mob.foo + 0 + + Parameters + ---------- + **kwargs + The attributes and corresponding values to set. + + Returns + ------- + :class:`OpenGLMobject` + ``self`` + + + """ + + for attr, value in kwargs.items(): + setattr(self, attr, value) + + return self + class OpenGLGroup(OpenGLMobject): def __init__(self, *mobjects, **kwargs): @@ -2730,6 +3213,10 @@ def __init__(self, *mobjects, **kwargs): super().__init__(**kwargs) self.add(*mobjects) + def __add__(self, other: OpenGLMobject | OpenGLGroup): + assert isinstance(other, OpenGLMobject) + return self.add(other) + class OpenGLPoint(OpenGLMobject): def __init__( @@ -2769,7 +3256,7 @@ def __init__(self, mobject): self.cannot_pass_args = False self.anim_args = {} - def __call__(self, **kwargs): + def __call__(self, **kwargs) -> _AnimationBuilder: if self.cannot_pass_args: raise ValueError( "Animation arguments must be passed before accessing methods and can only be passed once", diff --git a/manim/utils/color.py b/manim/utils/color.py index 6356fb3f04..c48525a4ed 100644 --- a/manim/utils/color.py +++ b/manim/utils/color.py @@ -2,6 +2,8 @@ from __future__ import annotations +from manim.utils.iterables import resize_with_interpolation + __all__ = [ "color_to_rgb", "color_to_rgba", @@ -550,3 +552,28 @@ def get_shaded_rgb( factor *= 0.5 result = rgb + factor return result + + +COLORMAP_3B1B: list[Color] = [BLUE_E, GREEN, YELLOW, RED] + + +def get_colormap_list(map_name: str = "viridis", n_colors: int = 9) -> np.ndarray: + """ + Options for map_name: + 3b1b_colormap + magma + inferno + plasma + viridis + cividis + twilight + twilight_shifted + turbo + """ + from matplotlib.cm import get_cmap + + if map_name == "3b1b_colormap": + rgbs = np.array([color_to_rgb(color) for color in COLORMAP_3B1B]) + else: + rgbs = get_cmap(map_name).colors # Make more general? + return resize_with_interpolation(np.array(rgbs), n_colors) diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index 203462057a..cc98b169ba 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -285,6 +285,22 @@ def rotation_about_z(angle: float) -> np.ndarray: ) +def get_norm(vector: np.ndarray) -> float: + """Returns the norm of the vector. + + Parameters + ---------- + vector + The vector for which you want to find the norm. + + Returns + ------- + float + The norm of the vector. + """ + return np.linalg.norm(vector) + + def z_to_vector(vector: np.ndarray) -> np.ndarray: """ Returns some matrix in SO(3) which takes the z-axis to the From d9941461e2068622e904ee2eb30ad88ae6207fef Mon Sep 17 00:00:00 2001 From: MrDiver Date: Thu, 29 Dec 2022 20:58:13 +0100 Subject: [PATCH 2/4] ported functionality of VMobject from 3b1b to OpenVGLMobject --- manim/_config/utils.py | 2 + manim/mobject/opengl/opengl_mobject.py | 70 +- .../opengl/opengl_vectorized_mobject.py | 836 ++++++++++-------- manim/mobject/types/vectorized_mobject.py | 1 - manim/renderer/shader_wrapper.py | 65 +- manim/utils/bezier.py | 31 + manim/utils/directories.py | 50 ++ manim/utils/file_ops.py | 3 + 8 files changed, 602 insertions(+), 456 deletions(-) create mode 100644 manim/utils/directories.py diff --git a/manim/_config/utils.py b/manim/_config/utils.py index cdbd62fce4..98229ac4c6 100644 --- a/manim/_config/utils.py +++ b/manim/_config/utils.py @@ -15,6 +15,8 @@ import configparser import copy import errno +import importlib +import inspect import logging import os import re diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 7af7439b3a..6ac029ea2f 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -58,41 +58,7 @@ Updater: TypeAlias = Union[TimeBasedUpdater, NonTimeUpdater] PointUpdateFunction: TypeAlias = Callable[[np.ndarray], np.ndarray] - -class MobjectData(TypedDict, total=False): - points: np.ndarray - bounding_box: np.ndarray - rgbas: np.ndarray - - -def to_mobject_data(values: dict[str, np.ndarray] | MobjectData) -> MobjectData: - result = MobjectData( - points=np.array(values["points"]), - bounding_box=np.array(values["bounding_box"]), - rgbas=np.array(values["rgbas"]), - ) - return result - - -class MobjectUniforms(TypedDict): - is_fixed_in_frame: np.ndarray - is_fixed_orientation: np.ndarray - gloss: np.ndarray - shadow: np.ndarray - reflectiveness: np.ndarray - - -def to_mobject_uniforms( - values: dict[str, np.ndarray] | MobjectUniforms -) -> MobjectUniforms: - result = MobjectUniforms( - is_fixed_in_frame=np.array(values["is_fixed_in_frame"]), - is_fixed_orientation=np.array(values["is_fixed_orientation"]), - gloss=np.array(values["gloss"]), - shadow=np.array(values["shadow"]), - reflectiveness=np.array(values["reflectiveness"]), - ) - return result +UNIFORM_DTYPE = np.float64 class OpenGLMobject: @@ -151,8 +117,8 @@ def __init__( self.saved_state: OpenGLMobject | None = None self.target: OpenGLMobject | None = None - self.data: MobjectData - self.uniforms: MobjectUniforms + self.data: dict[str, np.ndarray] = {} + self.uniforms: dict[str, np.ndarray] = {} self.init_data() self.init_uniforms() @@ -261,22 +227,24 @@ def rgbas(self, value): def init_data(self): """Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data. Subclasses can inherit and overwrite this method to extend `self.data`.""" - self.data = MobjectData( - points=np.zeros((0, 3)), - bounding_box=np.zeros((3, 3)), - rgbas=np.zeros((1, 4)), - ) + self.data = { + "points": np.zeros((0, 3)), + "bounding_box": np.zeros((3, 3)), + "rgbas": np.zeros((1, 4)), + } def init_uniforms(self): """Initializes the uniforms. Gets called upon creation""" - self.uniforms = MobjectUniforms( - is_fixed_in_frame=np.array(float(self.is_fixed_in_frame)), - gloss=np.array(self.gloss), - shadow=np.array(self.shadow), - reflectiveness=np.array(self.reflectiveness), - ) + self.uniforms = { + "is_fixed_in_frame": np.array( + float(self.is_fixed_in_frame), dtype=UNIFORM_DTYPE + ), + "gloss": np.array(self.gloss, dtype=UNIFORM_DTYPE), + "shadow": np.array(self.shadow, dtype=UNIFORM_DTYPE), + "reflectiveness": np.array(self.reflectiveness, dtype=UNIFORM_DTYPE), + } def init_colors(self): """Initializes the colors. @@ -1393,8 +1361,8 @@ def copy(self, deep: bool = False) -> OpenGLMobject: # The line above is only a shallow copy, so the internal # data which are numpyu arrays or other mobjects still # need to be further copied. - result.data = to_mobject_data(self.data) - result.uniforms = to_mobject_uniforms(self.uniforms) + result.data = {k: np.array(v) for k, v in self.data.items()} + result.uniforms = {k: np.array(v) for k, v in self.uniforms.items()} # Instead of adding using result.add, which does some checks for updating # updater statues and bounding box, just directly modify the family-related @@ -1427,7 +1395,7 @@ def copy(self, deep: bool = False) -> OpenGLMobject: return result def generate_target(self, use_deepcopy: bool = False): - target: OpenGLMobject = self.copy(use_deepcopy=use_deepcopy) + target: OpenGLMobject = self.copy(deep=use_deepcopy) target.saved_state = self.saved_state self.target = target return self.target diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index 8dc174e12d..207d404e69 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -11,55 +11,54 @@ from manim import config from manim.constants import * -from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint +from manim.mobject.opengl.opengl_mobject import ( + UNIFORM_DTYPE, + OpenGLMobject, + OpenGLPoint, +) from manim.renderer.shader_wrapper import ShaderWrapper from manim.utils.bezier import ( bezier, get_quadratic_approximation_of_cubic, get_smooth_cubic_bezier_handle_points, + get_smooth_quadratic_bezier_handle_points, integer_interpolate, interpolate, + inverse_interpolate, partial_quadratic_bezier_points, proportions_along_bezier_curve_for_point, quadratic_bezier_remap, ) from manim.utils.color import * -from manim.utils.config_ops import _Data -from manim.utils.iterables import listify, make_even, resize_with_interpolation +from manim.utils.iterables import ( + listify, + make_even, + resize_array, + resize_with_interpolation, +) from manim.utils.space_ops import ( angle_between_vectors, cross2d, earclip_triangulation, + get_norm, get_unit_normal, shoelace_direction, z_to_vector, ) - -def triggers_refreshed_triangulation(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - old_points = np.empty((0, 3)) - for mob in self.family_members_with_points(): - old_points = np.concatenate((old_points, mob.points), axis=0) - func(self, *args, **kwargs) - new_points = np.empty((0, 3)) - for mob in self.family_members_with_points(): - new_points = np.concatenate((new_points, mob.points), axis=0) - if not np.array_equal(new_points, old_points): - self.refresh_triangulation() - self.refresh_unit_normal() - return self - - return wrapper +DEFAULT_STROKE_COLOR = GREY_A +DEFAULT_FILL_COLOR = GREY_C class OpenGLVMobject(OpenGLMobject): """A vectorized mobject.""" + n_points_per_curve: int = 3 + stroke_shader_folder = "quadratic_bezier_stroke" + fill_shader_folder = "quadratic_bezier_fill" fill_dtype = [ ("point", np.float32, (3,)), - ("unit_normal", np.float32, (3,)), + ("orientation", np.float32, (3,)), ("color", np.float32, (4,)), ("vert_index", np.float32, (1,)), ] @@ -67,89 +66,48 @@ class OpenGLVMobject(OpenGLMobject): ("point", np.float32, (3,)), ("prev_point", np.float32, (3,)), ("next_point", np.float32, (3,)), - ("unit_normal", np.float32, (3,)), ("stroke_width", np.float32, (1,)), ("color", np.float32, (4,)), ] - stroke_shader_folder = "quadratic_bezier_stroke" - fill_shader_folder = "quadratic_bezier_fill" + render_primitive: int = moderngl.TRIANGLES - fill_rgba = _Data() - stroke_rgba = _Data() - stroke_width = _Data() - unit_normal = _Data() + pre_function_handle_to_anchor_scale_factor: float = 0.01 + make_smooth_after_applying_functions: bool = False + tolerance_for_point_equality: float = 1e-8 def __init__( self, + color: Color | None = None, fill_color: Color | None = None, fill_opacity: float = 0.0, stroke_color: Color | None = None, stroke_opacity: float = 1.0, stroke_width: float = DEFAULT_STROKE_WIDTH, draw_stroke_behind_fill: bool = False, - # Indicates that it will not be displayed, but - # that it should count in parent mobject's path - pre_function_handle_to_anchor_scale_factor: float = 0.01, - make_smooth_after_applying_functions: float = False, background_image_file: str | None = None, - # This is within a pixel - # TODO, do we care about accounting for - # varying zoom levels? - tolerance_for_point_equality: float = 1e-8, - n_points_per_curve: int = 3, long_lines: bool = False, - should_subdivide_sharp_curves: bool = False, - should_remove_null_curves: bool = False, - # Could also be "bevel", "miter", "round" - joint_type: LineJointType | None = None, + joint_type: LineJointType = LineJointType.AUTO, flat_stroke: bool = True, - render_primitive=moderngl.TRIANGLES, - triangulation_locked: bool = False, + # Measured in pixel widths + anti_alias_width: float = 1.0, **kwargs, ): - self.data = {} + self.fill_color = fill_color or color or DEFAULT_FILL_COLOR self.fill_opacity = fill_opacity + self.stroke_color = stroke_color or color or DEFAULT_STROKE_COLOR self.stroke_opacity = stroke_opacity self.stroke_width = stroke_width self.draw_stroke_behind_fill = draw_stroke_behind_fill - # Indicates that it will not be displayed, but - # that it should count in parent mobject's path - self.pre_function_handle_to_anchor_scale_factor = ( - pre_function_handle_to_anchor_scale_factor - ) - self.make_smooth_after_applying_functions = make_smooth_after_applying_functions self.background_image_file = background_image_file - # This is within a pixel - # TODO, do we care about accounting for - # varying zoom levels? - self.tolerance_for_point_equality = tolerance_for_point_equality - self.n_points_per_curve = n_points_per_curve self.long_lines = long_lines - self.should_subdivide_sharp_curves = should_subdivide_sharp_curves - self.should_remove_null_curves = should_remove_null_curves - if joint_type is None: - joint_type = LineJointType.AUTO self.joint_type = joint_type self.flat_stroke = flat_stroke - self.render_primitive = render_primitive - self.triangulation_locked = triangulation_locked + self.anti_alias_width = anti_alias_width self.needs_new_triangulation = True self.triangulation = np.zeros(0, dtype="i4") - self.orientation = 1 - self.fill_data = None - self.stroke_data = None - self.fill_shader_wrapper = None - self.stroke_shader_wrapper = None - self.init_shader_data() super().__init__(**kwargs) - self.refresh_unit_normal() - - if fill_color: - self.fill_color = Color(fill_color) - if stroke_color: - self.stroke_color = Color(stroke_color) def get_group_class(self): return OpenGLVGroup @@ -158,13 +116,67 @@ def get_group_class(self): def get_mobject_type_class(): return OpenGLVMobject + @property + def rgbas(self): + raise NotImplementedError( + "rgbas is not implemented for OpenGLVMobject. please use fill_rgba and stroke_rgba." + ) + + @rgbas.setter + def rgbas(self, value): + raise NotImplementedError( + "rgbas is not implemented for OpenGLVMobject. please use fill_rgba and stroke_rgba." + ) + def init_data(self): super().init_data() self.data.pop("rgbas") - self.fill_rgba = np.zeros((1, 4)) - self.stroke_rgba = np.zeros((1, 4)) - self.unit_normal = np.zeros((1, 3)) - # stroke_width belongs to self.data, but is defined through init_colors+set_stroke + self.data.update( + { + "fill_rgba": np.zeros((1, 4)), + "stroke_rgba": np.zeros((1, 4)), + "stroke_width": np.zeros((1, 1)), + "orientation": np.zeros((1, 1)), + } + ) + + def init_uniforms(self): + super().init_uniforms() + self.uniforms["anti_alias_width"] = np.asarray( + self.anti_alias_width, dtype=UNIFORM_DTYPE + ) + self.uniforms["joint_type"] = np.asarray( + self.joint_type.value, dtype=UNIFORM_DTYPE + ) + self.uniforms["flat_stroke"] = np.asarray( + float(self.flat_stroke), dtype=UNIFORM_DTYPE + ) + + # These are here just to make type checkers happy + def get_family(self, recurse: bool = True) -> list[OpenGLVMobject]: # type: ignore + return super().get_family(recurse) # type: ignore + + def family_members_with_points(self) -> list[OpenGLVMobject]: # type: ignore + return super().family_members_with_points() # type: ignore + + def replicate(self, n: int) -> OpenGLVGroup: # type: ignore + return super().replicate(n) # type: ignore + + def get_grid(self, *args, **kwargs) -> OpenGLVGroup: # type: ignore + return super().get_grid(*args, **kwargs) # type: ignore + + def __getitem__(self, value: int | slice) -> OpenGLVMobject: # type: ignore + return super().__getitem__(value) # type: ignore + + def add(self, *vmobjects: OpenGLVMobject): # type: ignore + if not all(isinstance(m, OpenGLVMobject) for m in vmobjects): + raise Exception("All submobjects must be of type OpenGLVMobject") + super().add(*vmobjects) + + def copy(self, deep: bool = False) -> OpenGLVMobject: + result = super().copy(deep) + result.shader_wrapper_list = [sw.copy() for sw in self.shader_wrapper_list] + return result # Colors def init_colors(self): @@ -180,12 +192,28 @@ def init_colors(self): ) self.set_gloss(self.gloss) self.set_flat_stroke(self.flat_stroke) + self.color = self.get_color() + return self + + def set_rgba_array( + self, rgba_array: np.ndarray, name: str | None = None, recurse: bool = False + ) -> OpenGLVMobject: + if name is None: + names = ["fill_rgba", "stroke_rgba"] + else: + names = [name] + + for name in names: + if name in self.data: + self.data[name] = rgba_array + else: + raise Exception(f"{name} is not a valid data name.") return self def set_fill( self, - color: Color | None = None, - opacity: float | None = None, + color: Color | Iterable[Color] | None = None, + opacity: float | Iterable[float] | None = None, recurse: bool = True, ) -> OpenGLVMobject: """Set the fill color and fill opacity of a :class:`OpenGLVMobject`. @@ -223,13 +251,7 @@ def construct(self): -------- :meth:`~.OpenGLVMobject.set_style` """ - if opacity is not None: - self.fill_opacity = opacity - if recurse: - for submobject in self.submobjects: - submobject.set_fill(color, opacity, recurse) - - self.set_rgba_array(color, opacity, "fill_rgba", recurse) + self.set_rgba_array_by_color(color, opacity, "fill_rgba", recurse) return self def set_stroke( @@ -240,78 +262,98 @@ def set_stroke( background=None, recurse=True, ): - if opacity is not None: - self.stroke_opacity = opacity - if recurse: - for submobject in self.submobjects: - submobject.set_stroke( - color=color, - width=width, - opacity=opacity, - background=background, - recurse=recurse, - ) - - self.set_rgba_array(color, opacity, "stroke_rgba", recurse) + self.set_rgba_array_by_color(color, opacity, "stroke_rgba", recurse) if width is not None: for mob in self.get_family(recurse): - mob.stroke_width = np.array([[width] for width in listify(width)]) + if isinstance(width, np.ndarray): + arr = width.reshape((len(width), 1)) + else: + arr = np.array([[w] for w in listify(width)], dtype=float) + mob.data["stroke_width"] = arr if background is not None: for mob in self.get_family(recurse): mob.draw_stroke_behind_fill = background return self - def set_style( + def set_backstroke( self, - fill_color=None, - fill_opacity=None, - fill_rgba=None, - stroke_color=None, - stroke_opacity=None, - stroke_rgba=None, - stroke_width=None, - gloss=None, - shadow=None, - recurse=True, - ): - if fill_rgba is not None: - self.fill_rgba = resize_with_interpolation(fill_rgba, len(fill_rgba)) - else: - self.set_fill(color=fill_color, opacity=fill_opacity, recurse=recurse) + color: Color | Iterable[Color] | None = None, + width: float | Iterable[float] = 3, + background: bool = True, + ) -> OpenGLVMobject: + self.set_stroke(color, width, background=background) + return self - if stroke_rgba is not None: - self.stroke_rgba = resize_with_interpolation(stroke_rgba, len(fill_rgba)) - self.set_stroke(width=stroke_width) - else: - self.set_stroke( - color=stroke_color, - width=stroke_width, - opacity=stroke_opacity, - recurse=recurse, + def align_stroke_width_data_to_points(self, recurse: bool = True) -> None: + for mob in self.get_family(recurse): + mob.data["stroke_width"] = resize_with_interpolation( + mob.data["stroke_width"], len(mob.points) ) - if gloss is not None: - self.set_gloss(gloss, recurse=recurse) - if shadow is not None: - self.set_shadow(shadow, recurse=recurse) + def set_style( + self, + fill_color: Color | Iterable[Color] | None = None, + fill_opacity: float | Iterable[float] | None = None, + fill_rgba: np.ndarray | None = None, + stroke_color: Color | Iterable[Color] | None = None, + stroke_opacity: float | Iterable[float] | None = None, + stroke_rgba: np.ndarray | None = None, + stroke_width: float | Iterable[float] | None = None, + stroke_background: bool = True, + reflectiveness: float | None = None, + gloss: float | None = None, + shadow: float | None = None, + recurse: bool = True, + ) -> OpenGLVMobject: + for mob in self.get_family(recurse): + if fill_rgba is not None: + mob.data["fill_rgba"] = resize_with_interpolation( + fill_rgba, len(fill_rgba) + ) + else: + mob.set_fill(color=fill_color, opacity=fill_opacity, recurse=False) + + if stroke_rgba is not None: + mob.data["stroke_rgba"] = resize_with_interpolation( + stroke_rgba, len(stroke_rgba) + ) + mob.set_stroke( + width=stroke_width, + background=stroke_background, + recurse=False, + ) + else: + mob.set_stroke( + color=stroke_color, + width=stroke_width, + opacity=stroke_opacity, + recurse=False, + background=stroke_background, + ) + + if reflectiveness is not None: + mob.set_reflectiveness(reflectiveness, recurse=False) + if gloss is not None: + mob.set_gloss(gloss, recurse=False) + if shadow is not None: + mob.set_shadow(shadow, recurse=False) return self def get_style(self): return { - "fill_rgba": self.fill_rgba, - "stroke_rgba": self.stroke_rgba, - "stroke_width": self.stroke_width, - "gloss": self.gloss, - "shadow": self.shadow, + "fill_rgba": self.data["fill_rgba"].copy(), + "stroke_rgba": self.data["stroke_rgba"].copy(), + "stroke_width": self.data["stroke_width"].copy(), + "stroke_background": self.draw_stroke_behind_fill, + "reflectiveness": self.get_reflectiveness(), + "gloss": self.get_gloss(), + "shadow": self.get_shadow(), } - def match_style(self, vmobject, recurse=True): - vmobject_style = vmobject.get_style() - if config.renderer == RendererType.OPENGL: - vmobject_style["stroke_width"] = vmobject_style["stroke_width"][0][0] - self.set_style(**vmobject_style, recurse=False) + def match_style(self, vmobject: OpenGLVMobject, recurse: bool = True): + self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and # match styles accordingly @@ -324,112 +366,101 @@ def match_style(self, vmobject, recurse=True): sm1.match_style(sm2) return self - def set_color(self, color, opacity=None, recurse=True): - if opacity is not None: - self.opacity = opacity - + def set_color(self, color, opacity=None, recurse=True) -> OpenGLVMobject: self.set_fill(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity, recurse=True) -> OpenGLVMobject: self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def fade(self, darkness=0.5, recurse=True): - factor = 1.0 - darkness - self.set_fill( - opacity=factor * self.get_fill_opacity(), - recurse=False, - ) - self.set_stroke( - opacity=factor * self.get_stroke_opacity(), - recurse=False, - ) - super().fade(darkness, recurse) + def fade(self, darkness=0.5, recurse=True) -> OpenGLVMobject: + mobs = self.get_family() if recurse else [self] + for mob in mobs: + factor = 1.0 - darkness + mob.set_fill( + opacity=factor * mob.get_fill_opacity(), + recurse=False, + ) + mob.set_stroke( + opacity=factor * mob.get_stroke_opacity(), + recurse=False, + ) return self - def get_fill_colors(self): - return [Color(rgb_to_hex(rgba[:3])) for rgba in self.fill_rgba] + def get_fill_colors(self) -> list[str]: + return [rgb_to_hex(rgba[:3]) for rgba in self.data["fill_rgba"]] - def get_fill_opacities(self): - return self.fill_rgba[:, 3] + def get_fill_opacities(self) -> np.ndarray: + return self.data["fill_rgba"][:, 3] - def get_stroke_colors(self): - return [Color(rgb_to_hex(rgba[:3])) for rgba in self.stroke_rgba] + def get_stroke_colors(self) -> list[str]: + return [rgb_to_hex(rgba[:3]) for rgba in self.data["stroke_rgba"]] - def get_stroke_opacities(self): - return self.stroke_rgba[:, 3] + def get_stroke_opacities(self) -> np.ndarray: + return self.data["stroke_rgba"][:, 3] - def get_stroke_widths(self): - return self.stroke_width + def get_stroke_widths(self) -> np.ndarray: + return self.data["stroke_width"][:, 0] # TODO, it's weird for these to return the first of various lists # rather than the full information - def get_fill_color(self): + def get_fill_color(self) -> str: """ If there are multiple colors (for gradient) this returns the first one """ return self.get_fill_colors()[0] - def get_fill_opacity(self): + def get_fill_opacity(self) -> float: """ If there are multiple opacities, this returns the first """ return self.get_fill_opacities()[0] - def get_stroke_color(self): + def get_stroke_color(self) -> str: return self.get_stroke_colors()[0] - def get_stroke_width(self): + def get_stroke_width(self) -> float | np.ndarray: return self.get_stroke_widths()[0] - def get_stroke_opacity(self): + def get_stroke_opacity(self) -> float: return self.get_stroke_opacities()[0] - def get_color(self): - if self.has_stroke(): - return self.get_stroke_color() - return self.get_fill_color() - - def get_colors(self): - if self.has_stroke(): - return self.get_stroke_colors() - return self.get_fill_colors() + def get_color(self) -> str: + if self.has_fill(): + return self.get_fill_color() + return self.get_stroke_color() - stroke_color = property(get_stroke_color, set_stroke) - color = property(get_color, set_color) - fill_color = property(get_fill_color, set_fill) + def has_stroke(self) -> bool: + return any(self.data["stroke_width"]) and any(self.data["stroke_rgba"][:, 3]) - def has_stroke(self): - stroke_widths = self.get_stroke_widths() - stroke_opacities = self.get_stroke_opacities() - return ( - stroke_widths is not None - and stroke_opacities is not None - and any(stroke_widths) - and any(stroke_opacities) - ) + def has_fill(self) -> bool: + return any(self.data["fill_rgba"][:, 3]) - def has_fill(self): - fill_opacities = self.get_fill_opacities() - return fill_opacities is not None and any(fill_opacities) - - def get_opacity(self): + def get_opacity(self) -> float: if self.has_fill(): return self.get_fill_opacity() return self.get_stroke_opacity() - def set_flat_stroke(self, flat_stroke=True, recurse=True): + def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): for mob in self.get_family(recurse): - mob.flat_stroke = flat_stroke + mob.uniforms["flat_stroke"] = np.asarray(float(flat_stroke)) return self - def get_flat_stroke(self): - return self.flat_stroke + def get_flat_stroke(self) -> bool: + return self.uniforms["flat_stroke"] == 1.0 + + def set_joint_type(self, joint_type: LineJointType, recurse: bool = True): + for mob in self.get_family(recurse): + mob.uniforms["joint_type"] = np.asarray(joint_type.value) + return self + + def get_joint_type(self) -> LineJointType: + return LineJointType(int(self.uniforms["joint_type"][0])) # Points def set_anchors_and_handles(self, anchors1, handles, anchors2): @@ -511,9 +542,12 @@ def add_smooth_curve_to(self, point): self.add_quadratic_bezier_curve_to(new_handle, point) return self - def add_smooth_cubic_curve_to(self, handle, point): + def add_smooth_cubic_curve_to(self, handle: np.ndarray, point: np.ndarray): self.throw_error_if_no_points() - new_handle = self.get_reflection_of_last_handle() + if self.get_num_points() == 1: + new_handle = self.points[-1] + else: + new_handle = self.get_reflection_of_last_handle() self.add_cubic_bezier_curve_to(new_handle, handle, point) def has_new_path_started(self): @@ -581,12 +615,15 @@ def set_points_as_corners(self, points: Iterable[float]) -> OpenGLVMobject: ) return self - def set_points_smoothly(self, points, true_smooth=False): + def set_points_smoothly(self, points, true_smooth=False) -> OpenGLVMobject: self.set_points_as_corners(points) - self.make_smooth() + if true_smooth: + self.make_smooth() + else: + self.make_approximately_smooth() return self - def change_anchor_mode(self, mode): + def change_anchor_mode(self, mode) -> OpenGLVMobject: """Changes the anchor mode of the bezier curves. This will modify the handles. There can be only three modes, "jagged", "approx_smooth" and "true_smooth". @@ -605,17 +642,13 @@ def change_anchor_mode(self, mode): anchors = np.vstack([subpath[::nppc], subpath[-1:]]) new_subpath = np.array(subpath) if mode == "approx_smooth": - # TODO: get_smooth_quadratic_bezier_handle_points is not defined new_subpath[1::nppc] = get_smooth_quadratic_bezier_handle_points( - anchors, + anchors ) elif mode == "true_smooth": h1, h2 = get_smooth_cubic_bezier_handle_points(anchors) new_subpath = get_quadratic_approximation_of_cubic( - anchors[:-1], - h1, - h2, - anchors[1:], + anchors[:-1], h1, h2, anchors[1:] ) elif mode == "jagged": new_subpath[1::nppc] = 0.5 * (anchors[:-1] + anchors[1:]) @@ -823,6 +856,55 @@ def get_num_curves(self) -> int: """ return self.get_num_points() // self.n_points_per_curve + def quick_point_from_proportion(self, alpha: float) -> np.ndarray: + # Assumes all curves have the same length, so is inaccurate + num_curves = self.get_num_curves() + n, residue = integer_interpolate(0, num_curves, alpha) + curve_func = self.get_nth_curve_function(n) + return curve_func(residue) + + def point_from_proportion(self, alpha: float) -> np.ndarray: + """Gets the point at a proportion along the path of the :class:`OpenGLVMobject`. + + Parameters + ---------- + alpha + The proportion along the the path of the :class:`OpenGLVMobject`. + + Returns + ------- + :class:`numpy.ndarray` + The point on the :class:`OpenGLVMobject`. + + Raises + ------ + :exc:`ValueError` + If ``alpha`` is not between 0 and 1. + :exc:`Exception` + If the :class:`OpenGLVMobject` has no points. + """ + + if alpha <= 0: + return self.get_start() + elif alpha >= 1: + return self.get_end() + + partials = [0.0] + for tup in self.get_bezier_tuples(): + # Approximate length with straight line from start to end + arclen = get_norm(tup[0] - tup[-1]) + partials.append(partials[-1] + arclen) + full = partials[-1] + if full == 0: + return self.get_start() + # First index where the partial length is more alpha times the full length + i = next( + (i for i, x in enumerate(partials) if x >= full * alpha), + len(partials), # Default + ) + residue = inverse_interpolate(partials[i - 1] / full, partials[i] / full, alpha) + return self.get_nth_curve_function(i - 1)(residue) # type: ignore + def get_nth_curve_length( self, n: int, @@ -888,7 +970,7 @@ def get_nth_curve_length_pieces( curve = self.get_nth_curve_function(n) points = np.array([curve(a) for a in np.linspace(0, 1, sample_points)]) diffs = points[1:] - points[:-1] - norms = np.apply_along_axis(np.linalg.norm, 1, diffs) + norms = np.apply_along_axis(np.linalg.norm, 1, diffs) # type: ignore return norms @@ -913,50 +995,6 @@ def get_curve_functions_with_lengths( for n in range(num_curves): yield self.get_nth_curve_function_with_length(n, **kwargs) - def point_from_proportion(self, alpha: float) -> np.ndarray: - """Gets the point at a proportion along the path of the :class:`OpenGLVMobject`. - - Parameters - ---------- - alpha - The proportion along the the path of the :class:`OpenGLVMobject`. - - Returns - ------- - :class:`numpy.ndarray` - The point on the :class:`OpenGLVMobject`. - - Raises - ------ - :exc:`ValueError` - If ``alpha`` is not between 0 and 1. - :exc:`Exception` - If the :class:`OpenGLVMobject` has no points. - """ - - if alpha < 0 or alpha > 1: - raise ValueError(f"Alpha {alpha} not between 0 and 1.") - - self.throw_error_if_no_points() - if alpha == 1: - return self.points[-1] - - curves_and_lengths = tuple(self.get_curve_functions_with_lengths()) - - target_length = alpha * np.sum(length for _, length in curves_and_lengths) - current_length = 0 - - for curve, length in curves_and_lengths: - if current_length + length >= target_length: - if length != 0: - residue = (target_length - current_length) / length - else: - residue = 0 - - return curve(residue) - - current_length += length - def proportion_from_point( self, point: Iterable[float | int], @@ -993,7 +1031,7 @@ def proportion_from_point( num_curves = self.get_num_curves() total_length = self.get_arc_length() - target_length = 0 + target_length = 0.0 for n in range(num_curves): control_points = self.get_nth_curve_points(n) length = self.get_nth_curve_length(n) @@ -1063,8 +1101,8 @@ def get_anchors(self) -> np.ndarray: self.get_start_anchors(), self.get_end_anchors(), ) - ), - ), + ) + ) ) def get_points_without_null_curves(self, atol=1e-9): @@ -1079,13 +1117,14 @@ def get_points_without_null_curves(self, atol=1e-9): ) return points[distinct_curves.repeat(nppc)] - def get_arc_length(self, sample_points_per_curve: int | None = None) -> float: + def get_arc_length(self, n_sample_points: int | None = None) -> float: """Return the approximated length of the whole curve. Parameters ---------- - sample_points_per_curve - Number of sample points per curve used to approximate the length. More points result in a better approximation. + n_sample_points + The number of points to sample. If ``None``, the number of points is calculated automatically. + Takes points on the outline of the :class:`OpenGLVMobject` and calculates the distance between them. Returns ------- @@ -1093,12 +1132,14 @@ def get_arc_length(self, sample_points_per_curve: int | None = None) -> float: The length of the :class:`OpenGLVMobject`. """ - return np.sum( - length - for _, length in self.get_curve_functions_with_lengths( - sample_points=sample_points_per_curve, - ) + if n_sample_points is None: + n_sample_points = 4 * self.get_num_curves() + 1 + points = np.array( + [self.point_from_proportion(a) for a in np.linspace(0, 1, n_sample_points)] ) + diffs = points[1:] - points[:-1] + norms = np.array([get_norm(d) for d in diffs]) + return norms.sum() def get_area_vector(self): # Returns a vector whose length is the area bound by @@ -1113,6 +1154,11 @@ def get_area_vector(self): p0 = points[0::nppc] p1 = points[nppc - 1 :: nppc] + if len(p0) != len(p1): + m = min(len(p0), len(p1)) + p0 = p0[:m] + p1 = p1[:m] + # Each term goes through all edges [(x1, y1, z1), (x2, y2, z2)] return 0.5 * np.array( [ @@ -1148,28 +1194,21 @@ def get_direction(self): """ return shoelace_direction(self.get_start_anchors()) - def get_unit_normal(self, recompute=False): - if not recompute: - return self.unit_normal[0] - - if len(self.points) < 3: + def get_unit_normal(self) -> np.ndarray: + if self.get_num_points() < 3: return OUT area_vect = self.get_area_vector() - area = np.linalg.norm(area_vect) + area = get_norm(area_vect) if area > 0: - return area_vect / area + normal = area_vect / area else: points = self.points - return get_unit_normal( + normal = get_unit_normal( points[1] - points[0], points[2] - points[1], ) - - def refresh_unit_normal(self): - for mob in self.get_family(): - mob.unit_normal[:] = mob.get_unit_normal(recompute=True) - return self + return normal # Alignment def align_points(self, vmobject): @@ -1202,14 +1241,14 @@ def get_nth_subpath(path_list, n): # Create a null path at the very end return [path_list[-1][-1]] * nppc path = path_list[n] - # Check for useless points at the end of the path and remove them - # https://github.com/ManimCommunity/manim/issues/1959 - while len(path) > nppc: - # If the last nppc points are all equal to the preceding point - if self.consider_points_equals(path[-nppc:], path[-nppc - 1]): - path = path[:-nppc] - else: - break + # # Check for useless points at the end of the path and remove them + # # https://github.com/ManimCommunity/manim/issues/1959 + # while len(path) > nppc: + # # If the last nppc points are all equal to the preceding point + # if self.consider_points_equals(path[-nppc:], path[-nppc - 1]): + # path = path[:-nppc] + # else: + # break return path for n in range(n_subpaths): @@ -1269,7 +1308,7 @@ def insert_n_curves_to_point_list(self, n: int, points: np.ndarray) -> np.ndarra return np.repeat(points, nppc * n, 0) bezier_groups = self.get_bezier_tuples_from_points(points) - norms = np.array([np.linalg.norm(bg[nppc - 1] - bg[0]) for bg in bezier_groups]) + norms = np.array([get_norm(bg[nppc - 1] - bg[0]) for bg in bezier_groups]) total_norm = sum(norms) # Calculate insertions per curve (ipc) if total_norm < 1e-6: @@ -1295,6 +1334,7 @@ def insert_n_curves_to_point_list(self, n: int, points: np.ndarray) -> np.ndarra def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) + # TODO: Do we still need this? Because for many scenes it just doesn't work if config["use_projection_fill_shaders"]: self.refresh_triangulation() else: @@ -1305,6 +1345,7 @@ def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): self.refresh_triangulation() return self + # TODO: compare to 3b1b/manim again check if something changed so we don't need the cairo interpolation anymore def pointwise_become_partial( self, vmobject: OpenGLVMobject, a: float, b: float, remap: bool = True ) -> OpenGLVMobject: @@ -1399,45 +1440,43 @@ def get_subcurve(self, a: float, b: float) -> OpenGLVMobject: def refresh_triangulation(self): for mob in self.get_family(): mob.needs_new_triangulation = True + mob.data["orientation"] = resize_array( + mob.data["orientation"], mob.get_num_points() + ) return self - def get_triangulation(self, normal_vector=None): + def get_triangulation(self): # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points - if normal_vector is None: - normal_vector = self.get_unit_normal() - if not self.needs_new_triangulation: return self.triangulation - points = self.points + points = self.get_points() if len(points) <= 1: self.triangulation = np.zeros(0, dtype="i4") self.needs_new_triangulation = False return self.triangulation - if not np.isclose(normal_vector, OUT).all(): - # Rotate points such that unit normal vector is OUT - points = np.dot(points, z_to_vector(normal_vector)) + normal_vector = self.get_unit_normal() indices = np.arange(len(points), dtype=int) - b0s = points[0::3] - b1s = points[1::3] - b2s = points[2::3] - v01s = b1s - b0s - v12s = b2s - b1s - - crosses = cross2d(v01s, v12s) - convexities = np.sign(crosses) + # Rotate points such that unit normal vector is OUT + if not np.isclose(normal_vector, OUT).all(): + points = np.dot(points, z_to_vector(normal_vector)) atol = self.tolerance_for_point_equality - end_of_loop = np.zeros(len(b0s), dtype=bool) - end_of_loop[:-1] = (np.abs(b2s[:-1] - b0s[1:]) > atol).any(1) + end_of_loop = np.zeros(len(points) // 3, dtype=bool) + end_of_loop[:-1] = (np.abs(points[2:-3:3] - points[3::3]) > atol).any(1) end_of_loop[-1] = True - concave_parts = convexities < 0 + v01s = points[1::3] - points[0::3] + v12s = points[2::3] - points[1::3] + curve_orientations = np.sign(cross2d(v01s, v12s)) + self.data["orientation"] = np.transpose([curve_orientations.repeat(3)]) + + concave_parts = curve_orientations < 0 # These are the vertices to which we'll apply a polygon triangulation inner_vert_indices = np.hstack( @@ -1445,7 +1484,7 @@ def get_triangulation(self, normal_vector=None): indices[0::3], indices[1::3][concave_parts], indices[2::3][end_of_loop], - ], + ] ) inner_vert_indices.sort() rings = np.arange(1, len(inner_vert_indices) + 1)[inner_vert_indices % 3 == 2] @@ -1461,11 +1500,28 @@ def get_triangulation(self, normal_vector=None): self.needs_new_triangulation = False return tri_indices + @staticmethod + def triggers_refreshed_triangulation(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + self.refresh_triangulation() + + return wrapper + @triggers_refreshed_triangulation def set_points(self, points): super().set_points(points) return self + @triggers_refreshed_triangulation + def append_points(self, points): + return super().append_points(points) + + @triggers_refreshed_triangulation + def reverse_points(self): + return super().reverse_points() + @triggers_refreshed_triangulation def set_data(self, data): super().set_data(data) @@ -1496,36 +1552,42 @@ def init_shader_data(self): self.fill_shader_wrapper = ShaderWrapper( vert_data=self.fill_data, vert_indices=np.zeros(0, dtype="i4"), + uniforms=self.uniforms, shader_folder=self.fill_shader_folder, render_primitive=self.render_primitive, ) self.stroke_shader_wrapper = ShaderWrapper( vert_data=self.stroke_data, + uniforms=self.uniforms, shader_folder=self.stroke_shader_folder, render_primitive=self.render_primitive, ) + self.shader_wrapper_list = [ + self.stroke_shader_wrapper.copy(), # Use for back stroke + self.fill_shader_wrapper.copy(), + self.stroke_shader_wrapper.copy(), + ] + for sw in self.shader_wrapper_list: + sw.uniforms = self.uniforms + def refresh_shader_wrapper_id(self): for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]: wrapper.refresh_id() return self - def get_fill_shader_wrapper(self): - self.update_fill_shader_wrapper() - return self.fill_shader_wrapper - - def update_fill_shader_wrapper(self): + def get_fill_shader_wrapper(self) -> ShaderWrapper: + self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices() self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() - self.fill_shader_wrapper.vert_indices = self.get_triangulation() - self.fill_shader_wrapper.uniforms = self.get_fill_uniforms() - - def get_stroke_shader_wrapper(self): - self.update_stroke_shader_wrapper() - return self.stroke_shader_wrapper + self.fill_shader_wrapper.uniforms = self.get_shader_uniforms() + self.fill_shader_wrapper.depth_test = self.depth_test + return self.fill_shader_wrapper - def update_stroke_shader_wrapper(self): + def get_stroke_shader_wrapper(self) -> ShaderWrapper: self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() - self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() + self.stroke_shader_wrapper.uniforms = self.get_shader_uniforms() + self.stroke_shader_wrapper.depth_test = self.depth_test + return self.stroke_shader_wrapper def get_shader_wrapper_list(self): # Build up data lists @@ -1533,9 +1595,9 @@ def get_shader_wrapper_list(self): stroke_shader_wrappers = [] back_stroke_shader_wrappers = [] for submob in self.family_members_with_points(): - if submob.has_fill() and not config["use_projection_fill_shaders"]: + if submob.has_fill(): fill_shader_wrappers.append(submob.get_fill_shader_wrapper()) - if submob.has_stroke() and not config["use_projection_stroke_shaders"]: + if submob.has_stroke(): ssw = submob.get_stroke_shader_wrapper() if submob.draw_stroke_behind_fill: back_stroke_shader_wrappers.append(ssw) @@ -1543,38 +1605,23 @@ def get_shader_wrapper_list(self): stroke_shader_wrappers.append(ssw) # Combine data lists - wrapper_lists = [ + sw_lists = [ back_stroke_shader_wrappers, fill_shader_wrappers, stroke_shader_wrappers, ] - result = [] - for wlist in wrapper_lists: - if wlist: - wrapper = wlist[0] - wrapper.combine_with(*wlist[1:]) - result.append(wrapper) - return result - - def get_stroke_uniforms(self): - result = dict(super().get_shader_uniforms()) - result["joint_type"] = self.joint_type.value - result["flat_stroke"] = float(self.flat_stroke) - return result - - def get_fill_uniforms(self): - return { - "is_fixed_in_frame": float(self.is_fixed_in_frame), - "is_fixed_orientation": float(self.is_fixed_orientation), - "fixed_orientation_center": self.fixed_orientation_center, - "gloss": self.gloss, - "shadow": self.shadow, - } - - def get_stroke_shader_data(self): - points = self.points + for sw, sw_list in zip(self.shader_wrapper_list, sw_lists): + if not sw_list: + continue + sw.read_in(*sw_list) + sw.depth_test = any(sw.depth_test for sw in sw_list) + sw.uniforms.update(sw_list[0].uniforms) + return list(filter(lambda sw: len(sw.vert_data) > 0, self.shader_wrapper_list)) + + def get_stroke_shader_data(self) -> np.ndarray: + points = self.get_points() if len(self.stroke_data) != len(points): - self.stroke_data = np.zeros(len(points), dtype=OpenGLVMobject.stroke_dtype) + self.stroke_data = resize_array(self.stroke_data, len(points)) if "points" not in self.locked_data_keys: nppc = self.n_points_per_curve @@ -1586,19 +1633,18 @@ def get_stroke_shader_data(self): self.read_data_to_shader(self.stroke_data, "color", "stroke_rgba") self.read_data_to_shader(self.stroke_data, "stroke_width", "stroke_width") - self.read_data_to_shader(self.stroke_data, "unit_normal", "unit_normal") return self.stroke_data - def get_fill_shader_data(self): - points = self.points + def get_fill_shader_data(self) -> np.ndarray: + points = self.get_points() if len(self.fill_data) != len(points): - self.fill_data = np.zeros(len(points), dtype=OpenGLVMobject.fill_dtype) + self.fill_data = resize_array(self.fill_data, len(points)) self.fill_data["vert_index"][:, 0] = range(len(points)) self.read_data_to_shader(self.fill_data, "point", "points") self.read_data_to_shader(self.fill_data, "color", "fill_rgba") - self.read_data_to_shader(self.fill_data, "unit_normal", "unit_normal") + self.read_data_to_shader(self.fill_data, "orientation", "orientation") return self.fill_data @@ -1606,7 +1652,7 @@ def refresh_shader_data(self): self.get_fill_shader_data() self.get_stroke_shader_data() - def get_fill_shader_vert_indices(self): + def get_fill_shader_vert_indices(self) -> np.ndarray: return self.get_triangulation() @@ -1693,7 +1739,7 @@ def __str__(self): f"submobject{'s' if len(self.submobjects) > 0 else ''}" ) - def add(self, *vmobjects: OpenGLVMobject): + def add(self, *vmobjects: OpenGLVMobject): # type: ignore """Checks if all passed elements are an instance of OpenGLVMobject and then add them to submobjects Parameters @@ -1790,7 +1836,7 @@ def __setitem__(self, key: int, value: OpenGLVMobject | Sequence[OpenGLVMobject] """ if not all(isinstance(m, OpenGLVMobject) for m in value): raise TypeError("All submobjects must be of type OpenGLVMobject") - self.submobjects[key] = value + self.submobjects[key] = value # type: ignore class OpenGLVectorizedPoint(OpenGLPoint, OpenGLVMobject): @@ -1800,15 +1846,15 @@ def __init__( color=BLACK, fill_opacity=0, stroke_width=0, - artificial_width=0.01, - artificial_height=0.01, **kwargs, ): - self.artificial_width = artificial_width - self.artificial_height = artificial_height - - super().__init__( - color=color, fill_opacity=fill_opacity, stroke_width=stroke_width, **kwargs + OpenGLPoint.__init__(self, location, **kwargs) + OpenGLVMobject.__init__( + self, + color=color, + fill_opacity=fill_opacity, + stroke_width=stroke_width, + **kwargs, ) self.set_points(np.array([location])) @@ -1878,32 +1924,50 @@ def __init__( self, vmobject: OpenGLVMobject, num_dashes: int = 15, - dashed_ratio: float = 0.5, - color: Color = WHITE, + positive_space_ratio: float = 0.5, **kwargs, ): - self.dashed_ratio = dashed_ratio - self.num_dashes = num_dashes - super().__init__(color=color, **kwargs) - r = self.dashed_ratio - n = self.num_dashes + super().__init__(**kwargs) + if num_dashes > 0: - # Assuming total length is 1 - dash_len = r / n - if vmobject.is_closed(): - void_len = (1 - r) / n - else: - void_len = (1 - r) / (n - 1) + # End points of the unit interval for division + alphas = np.linspace(0, 1, num_dashes + 1) + + # This determines the length of each "dash" + full_d_alpha = 1.0 / num_dashes + partial_d_alpha = full_d_alpha * positive_space_ratio + + # Rescale so that the last point of vmobject will + # be the end of the last dash + alphas /= 1 - full_d_alpha + partial_d_alpha self.add( - *( - vmobject.get_subcurve( - i * (dash_len + void_len), - i * (dash_len + void_len) + dash_len, - ) - for i in range(n) - ) + *[ + vmobject.get_subcurve(alpha, alpha + partial_d_alpha) + for alpha in alphas[:-1] + ] ) # Family is already taken care of by get_subcurve # implementation self.match_style(vmobject, recurse=False) + + +class VHighlight(OpenGLVGroup): + def __init__( + self, + vmobject: OpenGLVMobject, + n_layers: int = 5, + color_bounds: tuple[Color, Color] = (GREY_C, GREY_E), + max_stroke_addition: float = 5.0, + ): + outline = vmobject.replicate(n_layers) + outline.set_fill(opacity=0) + added_widths = np.linspace(0, max_stroke_addition, n_layers + 1)[1:] + colors = color_gradient(color_bounds, n_layers) + for part, added_width, color in zip(reversed(outline), added_widths, colors): + for sm in part.family_members_with_points(): + sm.set_stroke( + width=sm.get_stroke_width() + added_width, + color=color, + ) + super().__init__(*outline) diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index 36bf146263..b5cf76ec08 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -2410,7 +2410,6 @@ def __init__( equal_lengths=True, **kwargs, ): - self.dashed_ratio = dashed_ratio self.num_dashes = num_dashes super().__init__(color=color, **kwargs) diff --git a/manim/renderer/shader_wrapper.py b/manim/renderer/shader_wrapper.py index 8eff1772c1..0c32441ad0 100644 --- a/manim/renderer/shader_wrapper.py +++ b/manim/renderer/shader_wrapper.py @@ -7,6 +7,8 @@ import moderngl import numpy as np +from manim.utils.iterables import resize_array + from .. import logger # Mobjects that should be rendered with @@ -55,6 +57,29 @@ def __init__( self.init_program_code() self.refresh_id() + def __eq__(self, shader_wrapper: object): + if not isinstance(shader_wrapper, ShaderWrapper): + raise TypeError( + f"Cannot compare ShaderWrapper with non-ShaderWrapper object of type {type(shader_wrapper)}" + ) + return all( + ( + np.all(self.vert_data == shader_wrapper.vert_data), + np.all(self.vert_indices == shader_wrapper.vert_indices), + self.shader_folder == shader_wrapper.shader_folder, + all( + np.all(self.uniforms[key] == shader_wrapper.uniforms[key]) + for key in self.uniforms + ), + all( + self.texture_paths[key] == shader_wrapper.texture_paths[key] + for key in self.texture_paths + ), + self.depth_test == shader_wrapper.depth_test, + self.render_primitive == shader_wrapper.render_primitive, + ) + ) + def copy(self): result = copy.copy(self) result.vert_data = np.array(self.vert_data) @@ -125,30 +150,34 @@ def get_program_code(self): def replace_code(self, old, new): code_map = self.program_code - for (name, _code) in code_map.items(): + for name, _code in code_map.items(): if code_map[name] is None: continue code_map[name] = re.sub(old, new, code_map[name]) self.refresh_id() - def combine_with(self, *shader_wrappers): - # Assume they are of the same type - if len(shader_wrappers) == 0: - return + def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: + self.read_in(self.copy(), *shader_wrappers) + return self + + def read_in(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: + # Assume all are of the same type + total_len = sum(len(sw.vert_data) for sw in shader_wrappers) + self.vert_data = resize_array(self.vert_data, total_len) if self.vert_indices is not None: - num_verts = len(self.vert_data) - indices_list = [self.vert_indices] - data_list = [self.vert_data] - for sw in shader_wrappers: - indices_list.append(sw.vert_indices + num_verts) - data_list.append(sw.vert_data) - num_verts += len(sw.vert_data) - self.vert_indices = np.hstack(indices_list) - self.vert_data = np.hstack(data_list) - else: - self.vert_data = np.hstack( - [self.vert_data, *(sw.vert_data for sw in shader_wrappers)], - ) + total_verts = sum(len(sw.vert_indices) for sw in shader_wrappers) + self.vert_indices = resize_array(self.vert_indices, total_verts) + + n_points = 0 + n_verts = 0 + for sw in shader_wrappers: + new_n_points = n_points + len(sw.vert_data) + self.vert_data[n_points:new_n_points] = sw.vert_data + if self.vert_indices is not None and sw.vert_indices is not None: + new_n_verts = n_verts + len(sw.vert_indices) + self.vert_indices[n_verts:new_n_verts] = sw.vert_indices + n_points + n_verts = new_n_verts + n_points = new_n_points return self diff --git a/manim/utils/bezier.py b/manim/utils/bezier.py index ce5fff437f..8943bfd2a5 100644 --- a/manim/utils/bezier.py +++ b/manim/utils/bezier.py @@ -288,6 +288,37 @@ def match_interpolate( # Figuring out which bezier curves most smoothly connect a sequence of points +def get_smooth_quadratic_bezier_handle_points(points: FloatArray) -> FloatArray: + """ + Figuring out which bezier curves most smoothly connect a sequence of points. + + Given three successive points, P0, P1 and P2, you can compute that by defining + h = (1/4) P0 + P1 - (1/4)P2, the bezier curve defined by (P0, h, P1) will pass + through the point P2. + + So for a given set of four successive points, P0, P1, P2, P3, if we want to add + a handle point h between P1 and P2 so that the quadratic bezier (P1, h, P2) is + part of a smooth curve passing through all four points, we calculate one solution + for h that would produce a parbola passing through P3, call it smooth_to_right, and + another that would produce a parabola passing through P0, call it smooth_to_left, + and use the midpoint between the two. + """ + if len(points) == 2: + return midpoint(*points) + smooth_to_right, smooth_to_left = ( + 0.25 * ps[0:-2] + ps[1:-1] - 0.25 * ps[2:] for ps in (points, points[::-1]) + ) + if np.isclose(points[0], points[-1]).all(): + last_str = 0.25 * points[-2] + points[-1] - 0.25 * points[1] + last_stl = 0.25 * points[1] + points[0] - 0.25 * points[-2] + else: + last_str = smooth_to_left[0] + last_stl = smooth_to_right[0] + handles = 0.5 * np.vstack([smooth_to_right, [last_str]]) + handles += 0.5 * np.vstack([last_stl, smooth_to_left[::-1]]) + return handles + + def get_smooth_cubic_bezier_handle_points(points): points = np.array(points) num_handles = len(points) - 1 diff --git a/manim/utils/directories.py b/manim/utils/directories.py new file mode 100644 index 0000000000..e51f3c3c59 --- /dev/null +++ b/manim/utils/directories.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os + +from manim._config import config +from manim.utils.file_ops import guarantee_existence + + +def get_directories() -> dict[str, str]: + return config["directories"] + + +def get_temp_dir() -> str: + return get_directories()["temporary_storage"] + + +def get_tex_dir() -> str: + return guarantee_existence(os.path.join(get_temp_dir(), "Tex")) + + +def get_text_dir() -> str: + return guarantee_existence(os.path.join(get_temp_dir(), "Text")) + + +def get_mobject_data_dir() -> str: + return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data")) + + +def get_downloads_dir() -> str: + return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads")) + + +def get_output_dir() -> str: + return guarantee_existence(get_directories()["output"]) + + +def get_raster_image_dir() -> str: + return get_directories()["raster_images"] + + +def get_vector_image_dir() -> str: + return get_directories()["vector_images"] + + +def get_sound_dir() -> str: + return get_directories()["sounds"] + + +def get_shader_dir() -> str: + return get_directories()["shaders"] diff --git a/manim/utils/file_ops.py b/manim/utils/file_ops.py index 1bbb048214..21bda0b8c6 100644 --- a/manim/utils/file_ops.py +++ b/manim/utils/file_ops.py @@ -27,8 +27,11 @@ from shutil import copyfile from typing import TYPE_CHECKING +import validators + if TYPE_CHECKING: from ..scene.scene_file_writer import SceneFileWriter + from typing import Iterable from manim import __version__, config, logger From f76c07bffe3c13be221c96b60058fe7be566d551 Mon Sep 17 00:00:00 2001 From: MrDiver Date: Fri, 30 Dec 2022 00:36:42 +0100 Subject: [PATCH 3/4] first working render --- manim/constants.py | 3 + manim/mobject/opengl/opengl_mobject.py | 2 +- .../opengl/opengl_vectorized_mobject.py | 4 +- manim/renderer/opengl_renderer.py | 505 +++++++++++++++++- manim/utils/simple_functions.py | 13 + manim/utils/space_ops.py | 10 + 6 files changed, 528 insertions(+), 9 deletions(-) diff --git a/manim/constants.py b/manim/constants.py index d81d85902c..efbb86e1d4 100644 --- a/manim/constants.py +++ b/manim/constants.py @@ -194,6 +194,9 @@ DEGREES: float = TAU / 360 """The exchange rate between radians and degrees.""" +RADIANS: float = 1.0 +"""Just a default to select for camera.""" + # Video qualities QUALITIES: dict[str, dict[str, str | int | None]] = { "fourk_quality": { diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 6ac029ea2f..724f6ceffd 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -118,7 +118,7 @@ def __init__( self.target: OpenGLMobject | None = None self.data: dict[str, np.ndarray] = {} - self.uniforms: dict[str, np.ndarray] = {} + self.uniforms: dict[str, float | np.ndarray] = {} self.init_data() self.init_uniforms() diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index 207d404e69..7fd491f9c7 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -1619,7 +1619,7 @@ def get_shader_wrapper_list(self): return list(filter(lambda sw: len(sw.vert_data) > 0, self.shader_wrapper_list)) def get_stroke_shader_data(self) -> np.ndarray: - points = self.get_points() + points = self.points if len(self.stroke_data) != len(points): self.stroke_data = resize_array(self.stroke_data, len(points)) @@ -1637,7 +1637,7 @@ def get_stroke_shader_data(self) -> np.ndarray: return self.stroke_data def get_fill_shader_data(self) -> np.ndarray: - points = self.get_points() + points = self.points if len(self.fill_data) != len(points): self.fill_data = resize_array(self.fill_data, len(points)) self.fill_data["vert_index"][:, 0] = range(len(points)) diff --git a/manim/renderer/opengl_renderer.py b/manim/renderer/opengl_renderer.py index aa0bdbebf4..cdc6924be5 100644 --- a/manim/renderer/opengl_renderer.py +++ b/manim/renderer/opengl_renderer.py @@ -1,9 +1,14 @@ from __future__ import annotations import itertools as it +import math import sys import time -from typing import Any +from typing import Any, Iterable + +from manim.renderer.shader_wrapper import ShaderWrapper + +from ..constants import RADIANS if sys.version_info < (3, 8): from backports.cached_property import cached_property @@ -12,22 +17,25 @@ import moderngl import numpy as np +import OpenGL.GL as gl from PIL import Image +from scipy.spatial.transform import Rotation from manim import config, logger from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject from manim.utils.caching import handle_caching_play -from manim.utils.color import color_to_rgba +from manim.utils.color import BLACK, color_to_rgba from manim.utils.exceptions import EndSceneEarlyException from ..constants import * from ..scene.scene_file_writer import SceneFileWriter from ..utils import opengl from ..utils.config_ops import _Data -from ..utils.simple_functions import clip +from ..utils.simple_functions import clip, fdiv from ..utils.space_ops import ( angle_of_vector, + normalize, quaternion_from_angle_axis, quaternion_mult, rotation_matrix_transpose, @@ -40,7 +48,494 @@ ) -class OpenGLCamera(OpenGLMobject): +class OpenGLCameraFrame(OpenGLMobject): + def __init__( + self, + frame_shape: tuple[float, float] = (config.frame_width, config.frame_height), + center_point: np.ndarray = ORIGIN, + focal_dist_to_height: float = 2.0, + **kwargs, + ): + self.frame_shape = frame_shape + self.center_point = center_point + self.focal_dist_to_height = focal_dist_to_height + super().__init__(**kwargs) + + def init_uniforms(self): + super().init_uniforms() + # as a quarternion + self.uniforms["orientation"] = Rotation.identity().as_quat() + self.uniforms["focal_dist_to_height"] = self.focal_dist_to_height + + def init_points(self) -> None: + self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP]) + self.set_width(self.frame_shape[0], stretch=True) + self.set_height(self.frame_shape[1], stretch=True) + self.move_to(self.center_point) + + def set_orientation(self, rotation: Rotation): + self.uniforms["orientation"] = rotation.as_quat() + return self + + def get_orientation(self): + return Rotation.from_quat(self.uniforms["orientation"]) + + def to_default_state(self): + self.center() + self.set_height(config.frame_width) + self.set_width(config.frame_height) + self.set_orientation(Rotation.identity()) + return self + + def get_euler_angles(self): + return self.get_orientation().as_euler("zxz")[::-1] + + def get_theta(self): + return self.get_euler_angles()[0] + + def get_phi(self): + return self.get_euler_angles()[1] + + def get_gamma(self): + return self.get_euler_angles()[2] + + def get_inverse_camera_rotation_matrix(self): + return self.get_orientation().as_matrix().T + + def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs): # type: ignore + rot = Rotation.from_rotvec(axis * normalize(axis)) # type: ignore + self.set_orientation(rot * self.get_orientation()) + + def set_euler_angles( + self, + theta: float | None = None, + phi: float | None = None, + gamma: float | None = None, + units: float = RADIANS, + ): + eulers = self.get_euler_angles() # theta, phi, gamma + for i, var in enumerate([theta, phi, gamma]): + if var is not None: + eulers[i] = var * units + self.set_orientation(Rotation.from_euler("zxz", eulers[::-1])) + return self + + def reorient( + self, + theta_degrees: float | None = None, + phi_degrees: float | None = None, + gamma_degrees: float | None = None, + ): + """ + Shortcut for set_euler_angles, defaulting to taking + in angles in degrees + """ + self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES) + return self + + def set_theta(self, theta: float): + return self.set_euler_angles(theta=theta) + + def set_phi(self, phi: float): + return self.set_euler_angles(phi=phi) + + def set_gamma(self, gamma: float): + return self.set_euler_angles(gamma=gamma) + + def increment_theta(self, dtheta: float): + self.rotate(dtheta, OUT) + return self + + def increment_phi(self, dphi: float): + self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0]) + return self + + def increment_gamma(self, dgamma: float): + self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2]) + return self + + def set_focal_distance(self, focal_distance: float): + self.uniforms["focal_dist_to_height"] = focal_distance / self.get_height() + return self + + def set_field_of_view(self, field_of_view: float): + self.uniforms["focal_dist_to_height"] = 2 * math.tan(field_of_view / 2) + return self + + def get_shape(self): + return (self.get_width(), self.get_height()) + + def get_center(self) -> np.ndarray: + # Assumes first point is at the center + return self.points[0] + + def get_width(self) -> float: + points = self.points + return points[2, 0] - points[1, 0] + + def get_height(self) -> float: + points = self.points + return points[4, 1] - points[3, 1] + + def get_focal_distance(self) -> float: + return self.uniforms["focal_dist_to_height"] * self.get_height() # type: ignore + + def get_field_of_view(self) -> float: + return 2 * math.atan(self.uniforms["focal_dist_to_height"] / 2) + + def get_implied_camera_location(self) -> np.ndarray: + to_camera = self.get_inverse_camera_rotation_matrix()[2] + dist = self.get_focal_distance() + return self.get_center() + dist * to_camera + + +class OpenGLCamera: + def __init__( + self, + ctx: moderngl.Context | None = None, + background_image: str | None = None, + frame_config: dict = {}, + pixel_width: int = config.pixel_width, + pixel_height: int = config.pixel_height, + fps: int = config.frame_rate, + # Note: frame height and width will be resized to match the pixel aspect rati + background_color=BLACK, + background_opacity: float = 1.0, + # Points in vectorized mobjects with norm greater + # than this value will be rescaled + max_allowable_norm: float = 1.0, + image_mode: str = "RGBA", + n_channels: int = 4, + pixel_array_dtype: type = np.uint8, + light_source_position: np.ndarray = np.array([-10, 10, 10]), + # Although vector graphics handle antialiasing fine + # without multisampling, for 3d scenes one might want + # to set samples to be greater than 0. + samples: int = 0, + ) -> None: + self.background_image = background_image + self.pixel_width = pixel_width + self.pixel_height = pixel_height + self.fps = fps + self.max_allowable_norm = max_allowable_norm + self.image_mode = image_mode + self.n_channels = n_channels + self.pixel_array_dtype = pixel_array_dtype + self.light_source_position = light_source_position + self.samples = samples + + self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max + self.background_color: list[float] = list( + color_to_rgba(background_color, background_opacity) + ) + self.init_frame(**frame_config) + self.init_context(ctx) + self.init_shaders() + self.init_textures() + self.init_light_source() + self.refresh_perspective_uniforms() + # A cached map from mobjects to their associated list of render groups + # so that these render groups are not regenerated unnecessarily for static + # mobjects + self.mob_to_render_groups: dict = {} + + def init_frame(self, **config) -> None: + self.frame = OpenGLCameraFrame(**config) + + def init_context(self, ctx: moderngl.Context | None = None) -> None: + if ctx is None: + ctx = moderngl.create_standalone_context() + fbo = self.get_fbo(ctx, 0) + else: + fbo = ctx.detect_framebuffer() + + self.ctx = ctx + self.fbo = fbo + self.set_ctx_blending() + + # For multisample antisampling + fbo_msaa = self.get_fbo(ctx, self.samples) + fbo_msaa.use() + self.fbo_msaa = fbo_msaa + + def set_ctx_blending(self, enable: bool = True) -> None: + if enable: + self.ctx.enable(moderngl.BLEND) + else: + self.ctx.disable(moderngl.BLEND) + + def set_ctx_depth_test(self, enable: bool = True) -> None: + if enable: + self.ctx.enable(moderngl.DEPTH_TEST) + else: + self.ctx.disable(moderngl.DEPTH_TEST) + + def init_light_source(self) -> None: + self.light_source = OpenGLPoint(self.light_source_position) + + # Methods associated with the frame buffer + def get_fbo(self, ctx: moderngl.Context, samples: int = 0) -> moderngl.Framebuffer: + pw = self.pixel_width + ph = self.pixel_height + return ctx.framebuffer( + color_attachments=ctx.texture( + (pw, ph), components=self.n_channels, samples=samples + ), + depth_attachment=ctx.depth_renderbuffer((pw, ph), samples=samples), + ) + + def clear(self) -> None: + self.fbo.clear(*self.background_color) + self.fbo_msaa.clear(*self.background_color) + + def reset_pixel_shape(self, new_width: int, new_height: int) -> None: + self.pixel_width = new_width + self.pixel_height = new_height + self.refresh_perspective_uniforms() + + def get_raw_fbo_data(self, dtype: str = "f1") -> bytes: + # Copy blocks from the fbo_msaa to the drawn fbo using Blit + pw, ph = (self.pixel_width, self.pixel_height) + gl.glBindFrameBuffer(gl.GL_READ_FRAMEBUFFER, self.fbo_msaa.glo) + gl.glBindFrameBuffer(gl.GL_DRAW_FRAMEBUFFER, self.fbo.glo) + gl.glBlitFramebuffer( + 0, 0, pw, ph, 0, 0, pw, ph, gl.GL_COLOR_BUFFER_BIT, gl.GL_LINEAR + ) + return self.fbo.read( + viewport=self.fbo.viewport, + components=self.n_channels, + dtype=dtype, + ) + + def get_image(self) -> Image.Image: + return Image.frombytes( + "RGBA", + self.get_pixel_shape(), + self.get_raw_fbo_data(), + "raw", + "RGBA", + 0, + -1, + ) + + def get_pixel_array(self) -> np.ndarray: + raw = self.get_raw_fbo_data(dtype="f4") + flat_arr = np.frombuffer(raw, dtype="f4") + arr = flat_arr.reshape([*self.fbo.size, self.n_channels]) + # Convert from float + return (self.rgb_max_val * arr).astype(self.pixel_array_dtype) + + def get_texture(self): + texture = self.ctx.texture( + size=self.fbo.size, components=4, data=self.get_raw_fbo_data(), dtype="f4" + ) + return texture + + # Getting camera attributes + def get_pixel_shape(self) -> tuple[int, int]: + return self.fbo.viewport[2:4] + # return (self.pixel_width, self.pixel_height) + + def get_pixel_width(self) -> int: + return self.get_pixel_shape()[0] + + def get_pixel_height(self) -> int: + return self.get_pixel_shape()[1] + + def get_frame_height(self) -> float: + return self.frame.get_height() + + def get_frame_width(self) -> float: + return self.frame.get_width() + + def get_frame_shape(self) -> tuple[float, float]: + return (self.get_frame_width(), self.get_frame_height()) + + def get_frame_center(self) -> np.ndarray: + return self.frame.get_center() + + def get_location(self) -> tuple[float, float, float] | np.ndarray: + return self.frame.get_implied_camera_location() + + def resize_frame_shape(self, fixed_dimension: bool = False) -> None: + """ + Changes frame_shape to match the aspect ratio + of the pixels, where fixed_dimension determines + whether frame_height or frame_width + remains fixed while the other changes accordingly. + """ + pixel_height = self.get_pixel_height() + pixel_width = self.get_pixel_width() + frame_height = self.get_frame_height() + frame_width = self.get_frame_width() + aspect_ratio = fdiv(pixel_width, pixel_height) + if not fixed_dimension: + frame_height = frame_width / aspect_ratio + else: + frame_width = aspect_ratio * frame_height + self.frame.set_height(frame_height) + self.frame.set_width(frame_width) + + # Rendering + def capture(self, *mobjects: OpenGLMobject) -> None: + self.refresh_perspective_uniforms() + for mobject in mobjects: + for render_group in self.get_render_group_list(mobject): + self.render(render_group) + + def render(self, render_group: dict[str, Any]) -> None: + shader_wrapper: ShaderWrapper = render_group["shader_wrapper"] + shader_program = render_group["prog"] + self.set_shader_uniforms(shader_program, shader_wrapper) + self.set_ctx_depth_test(shader_wrapper.depth_test) + render_group["vao"].render(int(shader_wrapper.render_primitive)) + if render_group["single_use"]: + self.release_render_group(render_group) + + def get_render_group_list(self, mobject: OpenGLMobject) -> Iterable[dict[str, Any]]: + if mobject.is_changing(): + return self.generate_render_group_list(mobject) + + # Otherwise, cache result for later use + key = id(mobject) + if key not in self.mob_to_render_groups: + self.mob_to_render_groups[key] = list( + self.generate_render_group_list(mobject) + ) + return self.mob_to_render_groups[key] + + def generate_render_group_list( + self, mobject: OpenGLMobject + ) -> Iterable[dict[str, Any]]: + return ( + self.get_render_group(sw, single_use=mobject.is_changing()) + for sw in mobject.get_shader_wrapper_list() + ) + + def get_render_group( + self, shader_wrapper: ShaderWrapper, single_use: bool = True + ) -> dict[str, Any]: + # Data buffers + vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes()) + if shader_wrapper.vert_indices is None: + ibo = None + else: + vert_index_data = shader_wrapper.vert_indices.astype("i4").tobytes() + if vert_index_data: + ibo = self.ctx.buffer(vert_index_data) + else: + ibo = None + + # Program an vertex array + shader_program, vert_format = self.get_shader_program(shader_wrapper) # type: ignore + vao = self.ctx.vertex_array( + program=shader_program, + content=[(vbo, vert_format, *shader_wrapper.vert_attributes)], + index_buffer=ibo, + ) + return { + "vbo": vbo, + "ibo": ibo, + "vao": vao, + "prog": shader_program, + "shader_wrapper": shader_wrapper, + "single_use": single_use, + } + + def release_render_group(self, render_group: dict[str, Any]) -> None: + for key in ["vbo", "ibo", "vao"]: + if render_group[key] is not None: + render_group[key].release() + + def refresh_static_mobjects(self) -> None: + for render_group in it.chain(*self.mob_to_render_groups.values()): + self.release_render_group(render_group) + self.mob_to_render_groups = {} + + # Shaders + def init_shaders(self) -> None: + # Initialize with the null id going to None + self.id_to_shader_program: dict[int, tuple[moderngl.Program, str] | None] = { + hash(""): None + } + + def get_shader_program( + self, shader_wrapper: ShaderWrapper + ) -> tuple[moderngl.Program, str] | None: + sid = shader_wrapper.get_program_id() + if sid not in self.id_to_shader_program: + # Create shader program for the first time, then cache + # in the id_to_shader_program dictionary + program = self.ctx.program(**shader_wrapper.get_program_code()) + vert_format = moderngl.detect_format( + program, shader_wrapper.vert_attributes + ) + self.id_to_shader_program[sid] = (program, vert_format) + + return self.id_to_shader_program[sid] + + def set_shader_uniforms( + self, + shader: moderngl.Program, + shader_wrapper: ShaderWrapper, + ) -> None: + for name, path in shader_wrapper.texture_paths.items(): + tid = self.get_texture_id(path) + shader[name].value = tid + for name, value in it.chain( + self.perspective_uniforms.items(), shader_wrapper.uniforms.items() + ): + if name in shader: + if isinstance(value, np.ndarray) and value.ndim > 0: + value = tuple(value) + shader[name].value = value + + def refresh_perspective_uniforms(self) -> None: + frame = self.frame + # Orient light + rotation = frame.get_inverse_camera_rotation_matrix() + offset = frame.get_center() + light_pos = np.dot(rotation, self.light_source.get_location() + offset) + cam_pos = self.frame.get_implied_camera_location() # TODO + + self.perspective_uniforms = { + "frame_shape": frame.get_shape(), + "pixel_shape": self.get_pixel_shape(), + "camera_offset": tuple(offset), + "camera_rotation": tuple(np.array(rotation).T.flatten()), + "camera_position": tuple(cam_pos), + "light_source_position": tuple(light_pos), + "focal_distance": frame.get_focal_distance(), + } + + def init_textures(self) -> None: + self.n_textures: int = 0 + self.path_to_texture: dict[str, tuple[int, moderngl.Texture]] = {} + + def get_texture_id(self, path: str) -> int: + if path not in self.path_to_texture: + if self.n_textures == 15: # I have no clue why this is needed + self.n_textures += 1 + tid = self.n_textures + self.n_textures += 1 + im = Image.open(path).convert("RGBA") + texture = self.ctx.texture( + size=im.size, + components=len(im.getbands()), + data=im.tobytes(), + ) + texture.use(location=tid) + self.path_to_texture[path] = (tid, texture) + return self.path_to_texture[path][0] + + def release_texture(self, path: str): + tid_and_texture = self.path_to_texture.pop(path, None) + if tid_and_texture: + tid_and_texture[1].release() + return self + + +class OpenGLCameraLegacy(OpenGLMobject): euler_angles = _Data() def __init__( @@ -465,8 +960,6 @@ def update_frame(self, scene): self.refresh_perspective_uniforms(scene.camera) for mobject in scene.mobjects: - if not mobject.should_render: - continue self.render_mobject(mobject) for obj in scene.meshes: diff --git a/manim/utils/simple_functions.py b/manim/utils/simple_functions.py index 2a4a7afed6..c7f617bf59 100644 --- a/manim/utils/simple_functions.py +++ b/manim/utils/simple_functions.py @@ -117,6 +117,19 @@ def clip(a, min_a, max_a): return a +def fdiv( + a: Scalable, b: Scalable, zero_over_zero_value: Scalable | None = None +) -> Scalable: + if zero_over_zero_value is not None: + out = np.full_like(a, zero_over_zero_value) + where = np.logical_or(a != 0, b != 0) + else: + out = None + where = True + + return np.true_divide(a, b, out=out, where=where) + + def get_parameters(function: Callable) -> MappingProxyType[str, inspect.Parameter]: """Return the parameters of ``function`` as an ordered mapping of parameters' names to their corresponding ``Parameter`` objects. diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index cc98b169ba..43b8b140c1 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -301,6 +301,16 @@ def get_norm(vector: np.ndarray) -> float: return np.linalg.norm(vector) +def normalize(vect: list[float], fall_back: list[float] | None = None) -> np.ndarray: + norm = get_norm(vect) + if norm > 0: + return np.array(vect) / norm + elif fall_back is not None: + return np.array(fall_back) + else: + return np.zeros(len(vect)) + + def z_to_vector(vector: np.ndarray) -> np.ndarray: """ Returns some matrix in SO(3) which takes the z-axis to the From 862bb338efa9e30d120f7538e9eeb66480fe6fc1 Mon Sep 17 00:00:00 2001 From: MrDiver Date: Mon, 2 Jan 2023 15:46:05 +0100 Subject: [PATCH 4/4] first step to dump old scene structure --- manim/mobject/opengl/opengl_mobject.py | 136 +- .../opengl/opengl_vectorized_mobject.py | 39 +- manim/renderer/opengl_renderer.py | 3 +- manim/renderer/shader_wrapper.py | 2 + manim/scene/scene.py | 2254 ++++++----------- 5 files changed, 868 insertions(+), 1566 deletions(-) diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 724f6ceffd..568c040ea5 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -14,7 +14,6 @@ import moderngl import numpy as np from colour import Color -from typing_extensions import TypedDict from manim import config, logger from manim.constants import * @@ -51,7 +50,7 @@ if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, Tuple, Union - from typing_extensions import TypeAlias + from typing_extensions import Self, TypeAlias TimeBasedUpdater: TypeAlias = Callable[[OpenGLMobject, float], OpenGLMobject | None] NonTimeUpdater: TypeAlias = Callable[[OpenGLMobject], OpenGLMobject | None] @@ -142,12 +141,12 @@ def __str__(self): def __repr__(self): return str(self.name) - def __add__(self, other: OpenGLMobject) -> OpenGLMobject: + def __add__(self, other: OpenGLMobject) -> Self: if not isinstance(other, OpenGLMobject): raise TypeError(f"Only Mobjects can be added to Mobjects not {type(other)}") return self.get_group_class()(self, other) - def __mul__(self, other: int) -> OpenGLMobject: + def __mul__(self, other: int) -> Self: if not isinstance(other, int): raise TypeError(f"Only int can be multiplied to Mobjects not {type(other)}") return self.replicate(other) @@ -662,7 +661,7 @@ def __len__(self) -> int: def split(self) -> list[OpenGLMobject]: return self.submobjects - def assemble_family(self) -> OpenGLMobject: + def assemble_family(self) -> Self: sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *uniq_chain(*sub_families)] self.refresh_has_updater_status() @@ -704,7 +703,7 @@ def get_ancestors(self, extended: bool = False) -> list[OpenGLMobject]: def add( self, *mobjects: OpenGLMobject, - ) -> OpenGLMobject: + ) -> Self: """Add mobjects as submobjects. The mobjects are added to :attr:`submobjects`. @@ -773,9 +772,7 @@ def add( self.assemble_family() return self - def remove( - self, *mobjects: OpenGLMobject, reassemble: bool = True - ) -> OpenGLMobject: + def remove(self, *mobjects: OpenGLMobject, reassemble: bool = True) -> Self: """Remove :attr:`submobjects`. The mobjects are removed from :attr:`submobjects`, if they exist. @@ -806,7 +803,7 @@ def remove( self.assemble_family() return self - def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: + def add_to_back(self, *mobjects: OpenGLMobject) -> Self: # NOTE: is the note true OpenGLMobjects? """Add all passed mobjects to the back of the submobjects. @@ -859,6 +856,7 @@ def replace_submobject(self, index, new_submob): if self in old_submob.parents: old_submob.parents.remove(self) self.submobjects[index] = new_submob + new_submob.parents.append(self) self.assemble_family() return self @@ -935,7 +933,7 @@ def arrange_in_grid_legacy( col_widths: Iterable[float | None] | None = None, flow_order: str = "rd", **kwargs, - ) -> OpenGLMobject: + ) -> Self: """Arrange submobjects in a grid. Parameters @@ -1324,18 +1322,18 @@ def wrapper(self, *args, **kwargs): def serialize(self) -> bytes: return pickle.dumps(self) - def deserialize(self, data: bytes) -> OpenGLMobject: + def deserialize(self, data: bytes) -> Self: self.become(pickle.loads(data)) return self - def deepcopy(self) -> OpenGLMobject: + def deepcopy(self) -> Self: try: return pickle.loads(pickle.dumps(self)) except AttributeError: return copy.deepcopy(self) @stash_mobject_pointers - def copy(self, deep: bool = False) -> OpenGLMobject: + def copy(self, deep: bool = False) -> Self: """Create and return an identical copy of the :class:`OpenGLMobject` including all :attr:`submobjects`. @@ -1487,7 +1485,7 @@ def init_updaters(self) -> None: self.has_updaters: bool = False self.updating_suspended: bool = False - def update(self, dt: float = 0, recurse: bool = True) -> OpenGLMobject: + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self for time_updater in self.time_based_updaters: @@ -1516,7 +1514,7 @@ def add_updater( update_function: Updater, index: int | None = None, call_updater: bool = False, - ) -> OpenGLMobject: + ) -> Self: if "dt" in get_parameters(update_function): updater_list: list[Updater] = self.time_based_updaters # type: ignore else: @@ -1532,7 +1530,7 @@ def add_updater( self.update() return self - def remove_updater(self, update_function: Updater) -> OpenGLMobject: + def remove_updater(self, update_function: Updater) -> Self: updater_lists: list[list[Updater]] = [ self.time_based_updaters, # type: ignore self.non_time_updaters, # type: ignore @@ -1543,7 +1541,7 @@ def remove_updater(self, update_function: Updater) -> OpenGLMobject: self.refresh_has_updater_status() return self - def clear_updaters(self, recurse: bool = True) -> OpenGLMobject: + def clear_updaters(self, recurse: bool = True) -> Self: self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -1552,22 +1550,20 @@ def clear_updaters(self, recurse: bool = True) -> OpenGLMobject: submob.clear_updaters() return self - def match_updaters(self, mobject: OpenGLMobject) -> OpenGLMobject: + def match_updaters(self, mobject: OpenGLMobject) -> Self: self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse: bool = True) -> OpenGLMobject: + def suspend_updating(self, recurse: bool = True) -> Self: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating( - self, recurse: bool = True, call_updater: bool = True - ) -> OpenGLMobject: + def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -1578,7 +1574,7 @@ def resume_updating( self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self) -> OpenGLMobject: + def refresh_has_updater_status(self) -> Self: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self @@ -1587,16 +1583,14 @@ def refresh_has_updater_status(self) -> OpenGLMobject: def is_changing(self) -> bool: return self.has_updaters or self._is_animating - def set_animating_status( - self, is_animating: bool, recurse: bool = True - ) -> OpenGLMobject: + def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self: for mob in (*self.get_family(recurse), *self.get_ancestors(extended=True)): mob._is_animating = is_animating return self # Transforming operations - def shift(self, vector) -> OpenGLMobject: + def shift(self, vector) -> Self: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -1611,7 +1605,7 @@ def scale( about_point: Sequence[float] | np.ndarray | None = None, about_edge: Sequence[float] | np.ndarray = ORIGIN, **kwargs, - ) -> OpenGLMobject: + ) -> Self: r"""Scale the size by a factor. Default behavior is to scale about the center of the mobject. @@ -1688,7 +1682,7 @@ def _handle_scale_side_effects(self, scale_factor: float | np.ndarray) -> None: """ pass - def stretch(self, factor: float, dim: int, **kwargs) -> OpenGLMobject: + def stretch(self, factor: float, dim: int, **kwargs) -> Self: def func(points): points[:, dim] *= factor return points @@ -1696,7 +1690,7 @@ def func(points): self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle: float, axis=OUT) -> OpenGLMobject: + def rotate_about_origin(self, angle: float, axis=OUT) -> Self: return self.rotate(angle, axis, about_point=ORIGIN) # type: ignore def rotate( @@ -1705,7 +1699,7 @@ def rotate( axis=OUT, about_point: Sequence[float] | None = None, **kwargs, - ) -> OpenGLMobject: + ) -> Self: """Rotates the :class:`~.OpenGLMobject` about a certain point.""" rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( @@ -1715,7 +1709,7 @@ def rotate( ) return self - def flip(self, axis=UP, **kwargs) -> OpenGLMobject: + def flip(self, axis=UP, **kwargs) -> Self: """Flips/Mirrors an mobject about its center. Examples @@ -1734,7 +1728,7 @@ def construct(self): """ return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function: PointUpdateFunction, **kwargs) -> OpenGLMobject: + def apply_function(self, function: PointUpdateFunction, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -1743,20 +1737,18 @@ def apply_function(self, function: PointUpdateFunction, **kwargs) -> OpenGLMobje ) return self - def apply_function_to_position( - self, function: PointUpdateFunction - ) -> OpenGLMobject: + def apply_function_to_position(self, function: PointUpdateFunction) -> Self: self.move_to(function(self.get_center())) return self def apply_function_to_submobject_positions( self, function: PointUpdateFunction - ) -> OpenGLMobject: + ) -> Self: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs) -> OpenGLMobject: + def apply_matrix(self, matrix, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -1768,7 +1760,7 @@ def apply_matrix(self, matrix, **kwargs) -> OpenGLMobject: ) return self - def apply_complex_function(self, function, **kwargs) -> OpenGLMobject: + def apply_complex_function(self, function, **kwargs) -> Self: """Applies a complex function to a :class:`OpenGLMobject`. The x and y coordinates correspond to the real and imaginary parts respectively. @@ -1813,7 +1805,7 @@ def hierarchical_model_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(model_matrices))) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0) -> OpenGLMobject: + def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0) -> Self: for mob in self.family_members_with_points(): alphas = np.dot(mob.points, np.transpose(axis)) alphas -= min(alphas) @@ -1830,14 +1822,12 @@ def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0) -> OpenGLMobject: # Positioning methods - def center(self) -> OpenGLMobject: + def center(self) -> Self: """Moves the mobject to the center of the Scene.""" self.shift(-self.get_center()) return self - def align_on_border( - self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER - ) -> OpenGLMobject: + def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER) -> Self: """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -1855,10 +1845,10 @@ def align_on_border( def to_corner( self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER - ) -> OpenGLMobject: + ) -> Self: return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER) -> OpenGLMobject: + def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER) -> Self: return self.align_on_border(edge, buff) def next_to( @@ -1870,7 +1860,7 @@ def next_to( submobject_to_align=None, index_of_submobject_to_align=None, coor_mask=np.array([1, 1, 1]), - ) -> OpenGLMobject: + ) -> Self: """Move this :class:`~.OpenGLMobject` next to another's :class:`~.OpenGLMobject` or coordinate. Examples @@ -2636,13 +2626,13 @@ def align_data(self, mobject) -> None: elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) - def align_points(self, mobject) -> OpenGLMobject: + def align_points(self, mobject) -> Self: max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject) -> OpenGLMobject: + def align_family(self, mobject) -> Self: mob1 = self mob2 = mobject n1 = len(mob1) @@ -2655,14 +2645,14 @@ def align_family(self, mobject) -> OpenGLMobject: sm1.align_family(sm2) return self - def push_self_into_submobjects(self) -> OpenGLMobject: + def push_self_into_submobjects(self) -> Self: copy = self.deepcopy() copy.submobjects = [] self.clear_points() self.add(copy) return self - def add_n_more_submobjects(self, n) -> OpenGLMobject: + def add_n_more_submobjects(self, n) -> Self: if n == 0: return self @@ -2691,9 +2681,7 @@ def add_n_more_submobjects(self, n) -> OpenGLMobject: # Interpolate - def interpolate( - self, mobject1, mobject2, alpha, path_func=straight_path - ) -> OpenGLMobject: + def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path) -> Self: """Turns this :class:`~.OpenGLMobject` into an interpolation between ``mobject1`` and ``mobject2``. @@ -2858,6 +2846,36 @@ def construct(self): self.match_updaters(mobject) return self + def looks_identical(self, mobject: OpenGLMobject) -> bool: + fam1 = self.family_members_with_points() + fam2 = mobject.family_members_with_points() + if len(fam1) != len(fam2): + return False + for m1, m2 in zip(fam1, fam2): + for d1, d2 in [(m1.data, m2.data), (m1.uniforms, m2.uniforms)]: + if set(d1).difference(d2): + return False + for key in d1: + if ( + isinstance(d1[key], np.ndarray) + and isinstance(d2[key], np.ndarray) + and (d1[key].size != d2[key].size) + ): + return False + if not np.isclose(d1[key], d2[key]).all(): + return False + return True + + def has_same_shape_as(self, mobject: OpenGLMobject) -> bool: + # Normalize both point sets by centering and making height 1 + points1, points2 = ( + (m.get_all_points() - m.get_center()) / m.get_height() + for m in (self, mobject) + ) + if len(points1) != len(points2): + return False + return bool(np.isclose(points1, points2).all()) + # Operations touching shader uniforms @staticmethod @@ -2872,20 +2890,20 @@ def wrapper(self): return wrapper @affects_shader_info_id - def fix_in_frame(self) -> OpenGLMobject: + def fix_in_frame(self) -> Self: self.uniforms["is_fixed_in_frame"] = np.asarray(1.0) self.is_fixed_in_frame = True return self @affects_shader_info_id - def fix_orientation(self) -> OpenGLMobject: + def fix_orientation(self) -> Self: self.uniforms["is_fixed_orientation"] = np.asarray(1.0) self.is_fixed_orientation = True self.fixed_orientation_center = tuple(self.get_center()) return self @affects_shader_info_id - def unfix_from_frame(self) -> OpenGLMobject: + def unfix_from_frame(self) -> Self: self.uniforms["is_fixed_in_frame"] = np.asarray(0.0) self.is_fixed_in_frame = False return self @@ -3139,7 +3157,7 @@ def throw_error_if_no_points(self): since="v0.17.2", message="The usage of this method is discouraged please set attributes directly", ) - def set(self, **kwargs) -> OpenGLMobject: + def set(self, **kwargs) -> Self: """Sets attributes. Mainly to be used along with :attr:`animate` to diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index 7fd491f9c7..169c277714 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -3,7 +3,7 @@ import itertools as it import operator as op from functools import reduce, wraps -from typing import Callable, Iterable, Optional, Sequence +from typing import TYPE_CHECKING import moderngl import numpy as np @@ -46,6 +46,11 @@ z_to_vector, ) +if TYPE_CHECKING: + from typing import Callable, Iterable, Optional, Sequence + + from typing_extensions import Self + DEFAULT_STROKE_COLOR = GREY_A DEFAULT_FILL_COLOR = GREY_C @@ -165,7 +170,7 @@ def replicate(self, n: int) -> OpenGLVGroup: # type: ignore def get_grid(self, *args, **kwargs) -> OpenGLVGroup: # type: ignore return super().get_grid(*args, **kwargs) # type: ignore - def __getitem__(self, value: int | slice) -> OpenGLVMobject: # type: ignore + def __getitem__(self, value: int | slice) -> Self: # type: ignore return super().__getitem__(value) # type: ignore def add(self, *vmobjects: OpenGLVMobject): # type: ignore @@ -173,7 +178,7 @@ def add(self, *vmobjects: OpenGLVMobject): # type: ignore raise Exception("All submobjects must be of type OpenGLVMobject") super().add(*vmobjects) - def copy(self, deep: bool = False) -> OpenGLVMobject: + def copy(self, deep: bool = False) -> Self: result = super().copy(deep) result.shader_wrapper_list = [sw.copy() for sw in self.shader_wrapper_list] return result @@ -197,7 +202,7 @@ def init_colors(self): def set_rgba_array( self, rgba_array: np.ndarray, name: str | None = None, recurse: bool = False - ) -> OpenGLVMobject: + ) -> Self: if name is None: names = ["fill_rgba", "stroke_rgba"] else: @@ -215,7 +220,7 @@ def set_fill( color: Color | Iterable[Color] | None = None, opacity: float | Iterable[float] | None = None, recurse: bool = True, - ) -> OpenGLVMobject: + ) -> Self: """Set the fill color and fill opacity of a :class:`OpenGLVMobject`. Parameters @@ -282,7 +287,7 @@ def set_backstroke( color: Color | Iterable[Color] | None = None, width: float | Iterable[float] = 3, background: bool = True, - ) -> OpenGLVMobject: + ) -> Self: self.set_stroke(color, width, background=background) return self @@ -306,7 +311,7 @@ def set_style( gloss: float | None = None, shadow: float | None = None, recurse: bool = True, - ) -> OpenGLVMobject: + ) -> Self: for mob in self.get_family(recurse): if fill_rgba is not None: mob.data["fill_rgba"] = resize_with_interpolation( @@ -366,17 +371,17 @@ def match_style(self, vmobject: OpenGLVMobject, recurse: bool = True): sm1.match_style(sm2) return self - def set_color(self, color, opacity=None, recurse=True) -> OpenGLVMobject: + def set_color(self, color, opacity=None, recurse=True) -> Self: self.set_fill(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse) return self - def set_opacity(self, opacity, recurse=True) -> OpenGLVMobject: + def set_opacity(self, opacity, recurse=True) -> Self: self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def fade(self, darkness=0.5, recurse=True) -> OpenGLVMobject: + def fade(self, darkness=0.5, recurse=True) -> Self: mobs = self.get_family() if recurse else [self] for mob in mobs: factor = 1.0 - darkness @@ -510,7 +515,7 @@ def add_quadratic_bezier_curve_to(self, handle, anchor): else: self.append_points([self.get_last_point(), handle, anchor]) - def add_line_to(self, point: Sequence[float]) -> OpenGLVMobject: + def add_line_to(self, point: Sequence[float]) -> Self: """Add a straight line from the last point of OpenGLVMobject to the given point. Parameters @@ -592,7 +597,7 @@ def add_points_as_corners(self, points): self.add_line_to(point) return points - def set_points_as_corners(self, points: Iterable[float]) -> OpenGLVMobject: + def set_points_as_corners(self, points: Iterable[float]) -> Self: """Given an array of points, set them as corner of the vmobject. To achieve that, this algorithm sets handles aligned with the anchors such that the resultant bezier curve will be the segment @@ -615,7 +620,7 @@ def set_points_as_corners(self, points: Iterable[float]) -> OpenGLVMobject: ) return self - def set_points_smoothly(self, points, true_smooth=False) -> OpenGLVMobject: + def set_points_smoothly(self, points, true_smooth=False) -> Self: self.set_points_as_corners(points) if true_smooth: self.make_smooth() @@ -623,7 +628,7 @@ def set_points_smoothly(self, points, true_smooth=False) -> OpenGLVMobject: self.make_approximately_smooth() return self - def change_anchor_mode(self, mode) -> OpenGLVMobject: + def change_anchor_mode(self, mode) -> Self: """Changes the anchor mode of the bezier curves. This will modify the handles. There can be only three modes, "jagged", "approx_smooth" and "true_smooth". @@ -1264,7 +1269,7 @@ def get_nth_subpath(path_list, n): vmobject.set_points(np.vstack(new_subpaths2)) return self - def insert_n_curves(self, n: int, recurse=True) -> OpenGLVMobject: + def insert_n_curves(self, n: int, recurse=True) -> Self: """Inserts n curves to the bezier curves of the vmobject. Parameters @@ -1348,7 +1353,7 @@ def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): # TODO: compare to 3b1b/manim again check if something changed so we don't need the cairo interpolation anymore def pointwise_become_partial( self, vmobject: OpenGLVMobject, a: float, b: float, remap: bool = True - ) -> OpenGLVMobject: + ) -> Self: """Given two bounds a and b, transforms the points of the self vmobject into the points of the vmobject passed as parameter with respect to the bounds. Points here stand for control points of the bezier curves (anchors and handles) @@ -1414,7 +1419,7 @@ def pointwise_become_partial( ) return self - def get_subcurve(self, a: float, b: float) -> OpenGLVMobject: + def get_subcurve(self, a: float, b: float) -> Self: """Returns the subcurve of the OpenGLVMobject between the interval [a, b]. The curve is a OpenGLVMobject itself. diff --git a/manim/renderer/opengl_renderer.py b/manim/renderer/opengl_renderer.py index cdc6924be5..837f9aa819 100644 --- a/manim/renderer/opengl_renderer.py +++ b/manim/renderer/opengl_renderer.py @@ -321,7 +321,8 @@ def get_image(self) -> Image.Image: def get_pixel_array(self) -> np.ndarray: raw = self.get_raw_fbo_data(dtype="f4") flat_arr = np.frombuffer(raw, dtype="f4") - arr = flat_arr.reshape([*self.fbo.size, self.n_channels]) + arr = flat_arr.reshape([*reversed(self.fbo.size), self.n_channels]) + arr = arr[::-1] # Convert from float return (self.rgb_max_val * arr).astype(self.pixel_array_dtype) diff --git a/manim/renderer/shader_wrapper.py b/manim/renderer/shader_wrapper.py index 0c32441ad0..f80c29c474 100644 --- a/manim/renderer/shader_wrapper.py +++ b/manim/renderer/shader_wrapper.py @@ -2,6 +2,7 @@ import copy import re +from functools import lru_cache from pathlib import Path import moderngl @@ -185,6 +186,7 @@ def read_in(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: filename_to_code_map: dict = {} +@lru_cache(maxsize=12) def get_shader_code_from_file(filename: Path) -> str | None: if filename in filename_to_code_map: return filename_to_code_map[filename] diff --git a/manim/scene/scene.py b/manim/scene/scene.py index bf5753750e..a21f40a64b 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -1,1667 +1,943 @@ -"""Basic canvas for animations.""" - from __future__ import annotations -__all__ = ["Scene"] - -import copy -import datetime import inspect +import os import platform import random -import threading import time -import types -from queue import Queue -from typing import Callable - -import srt - -from manim.scene.section import DefaultSectionType - -try: - import dearpygui.dearpygui as dpg +from collections import OrderedDict +from typing import TYPE_CHECKING - dearpygui_imported = True -except ImportError: - dearpygui_imported = False import numpy as np -from tqdm import tqdm -from watchdog.events import FileSystemEventHandler -from watchdog.observers import Observer - -from manim.mobject.mobject import Mobject -from manim.mobject.opengl.opengl_mobject import OpenGLPoint - -from .. import config, logger -from ..animation.animation import Animation, Wait, prepare_animation -from ..camera.camera import Camera -from ..constants import * -from ..gui.gui import configure_pygui -from ..renderer.cairo_renderer import CairoRenderer -from ..renderer.opengl_renderer import OpenGLRenderer -from ..renderer.shader import Object3D -from ..utils import opengl, space_ops -from ..utils.exceptions import EndSceneEarlyException, RerunSceneException -from ..utils.family import extract_mobject_family_members -from ..utils.family_ops import restructure_list_to_exclude_certain_family_members -from ..utils.file_ops import open_media_file -from ..utils.iterables import list_difference_update, list_update - - -class RerunSceneHandler(FileSystemEventHandler): - """A class to handle rerunning a Scene after the input file is modified.""" - - def __init__(self, queue): - super().__init__() - self.queue = queue - - def on_modified(self, event): - self.queue.put(("rerun_file", [], {})) +import pyperclip +from IPython.core.getipython import get_ipython +from IPython.terminal import pt_inputhooks +from IPython.terminal.embed import InteractiveShellEmbed +from tqdm import tqdm as ProgressDisplay + +from manim._config import logger as log +from manim.animation.animation import prepare_animation +from manim.animation.fading import VFadeInThenOut +from manim.camera.camera import Camera +from manim.config import get_module +from manim.constants import ( + ARROW_SYMBOLS, + COMMAND_MODIFIER, + DEFAULT_WAIT_TIME, + RED, + SHIFT_MODIFIER, +) +from manim.event_handler import EVENT_DISPATCHER +from manim.event_handler.event_type import EventType +from manim.mobject.frame import FullScreenRectangle +from manim.mobject.mobject import Group, Mobject, Point, _AnimationBuilder +from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.scene.scene_file_writer import SceneFileWriter +from manim.utils.family_ops import ( + extract_mobject_family_members, + recursive_mobject_remove, +) + +if TYPE_CHECKING: + from typing import Callable, Iterable + + from PIL.Image import Image + + from manim.animation.animation import Animation + + +PAN_3D_KEY = "d" +FRAME_SHIFT_KEY = "f" +ZOOM_KEY = "z" +RESET_FRAME_KEY = "r" +QUIT_KEY = "q" class Scene: - """A Scene is the canvas of your animation. - - The primary role of :class:`Scene` is to provide the user with tools to manage - mobjects and animations. Generally speaking, a manim script consists of a class - that derives from :class:`Scene` whose :meth:`Scene.construct` method is overridden - by the user's code. - - Mobjects are displayed on screen by calling :meth:`Scene.add` and removed from - screen by calling :meth:`Scene.remove`. All mobjects currently on screen are kept - in :attr:`Scene.mobjects`. Animations are played by calling :meth:`Scene.play`. - - A :class:`Scene` is rendered internally by calling :meth:`Scene.render`. This in - turn calls :meth:`Scene.setup`, :meth:`Scene.construct`, and - :meth:`Scene.tear_down`, in that order. - - It is not recommended to override the ``__init__`` method in user Scenes. For code - that should be ran before a Scene is rendered, use :meth:`Scene.setup` instead. - - Examples - -------- - Override the :meth:`Scene.construct` method with your code. - - .. code-block:: python - - class MyScene(Scene): - def construct(self): - self.play(Write(Text("Hello World!"))) - - """ + random_seed: int = 0 + pan_sensitivity: float = 3.0 + max_num_saved_states: int = 50 + default_camera_config: dict = {} + default_window_config: dict = {} + default_file_writer_config: dict = {} def __init__( self, - renderer=None, - camera_class=Camera, - always_update_mobjects=False, - random_seed=None, - skip_animations=False, + window_config: dict = {}, + camera_config: dict = {}, + file_writer_config: dict = {}, + skip_animations: bool = False, + always_update_mobjects: bool = False, + start_at_animation_number: int | None = None, + end_at_animation_number: int | None = None, + leave_progress_bars: bool = False, + preview: bool = True, + presenter_mode: bool = False, + show_animation_progress: bool = False, + embed_exception_mode: str = "", + embed_error_sound: bool = False, ): - self.camera_class = camera_class - self.always_update_mobjects = always_update_mobjects - self.random_seed = random_seed self.skip_animations = skip_animations - - self.animations = None - self.stop_condition = None - self.moving_mobjects = [] - self.static_mobjects = [] - self.time_progression = None - self.duration = None - self.last_t = None - self.queue = Queue() - self.skip_animation_preview = False - self.meshes = [] - self.camera_target = ORIGIN - self.widgets = [] - self.dearpygui_imported = dearpygui_imported - self.updaters = [] - self.point_lights = [] - self.ambient_light = None - self.key_to_function_map = {} - self.mouse_press_callbacks = [] - self.interactive_mode = False - - if config.renderer == RendererType.OPENGL: - # Items associated with interaction - self.mouse_point = OpenGLPoint() - self.mouse_drag_point = OpenGLPoint() - if renderer is None: - renderer = OpenGLRenderer() - - if renderer is None: - self.renderer = CairoRenderer( - camera_class=self.camera_class, - skip_animations=self.skip_animations, - ) + self.always_update_mobjects = always_update_mobjects + self.start_at_animation_number = start_at_animation_number + self.end_at_animation_number = end_at_animation_number + self.leave_progress_bars = leave_progress_bars + self.preview = preview + self.presenter_mode = presenter_mode + self.show_animation_progress = show_animation_progress + self.embed_exception_mode = embed_exception_mode + self.embed_error_sound = embed_error_sound + + self.camera_config = {**self.default_camera_config, **camera_config} + self.window_config = {**self.default_window_config, **window_config} + self.file_writer_config = { + **self.default_file_writer_config, + **file_writer_config, + } + + # Initialize window, if applicable + if self.preview: + from manimlib.window import Window + + self.window = Window(scene=self, **self.window_config) + self.camera_config["ctx"] = self.window.ctx + self.camera_config["fps"] = 30 # Where's that 30 from? else: - self.renderer = renderer - self.renderer.init_scene(self) + self.window = None + + # Core state of the scene + self.camera: Camera = Camera(**self.camera_config) + self.file_writer = SceneFileWriter(self, **self.file_writer_config) + self.mobjects: list[Mobject] = [self.camera.frame] + self.id_to_mobject_map: dict[int, Mobject] = {} + self.num_plays: int = 0 + self.time: float = 0 + self.skip_time: float = 0 + self.original_skipping_status: bool = self.skip_animations + self.checkpoint_states: dict[str, list[tuple[Mobject, Mobject]]] = {} + self.undo_stack = [] + self.redo_stack = [] + + if self.start_at_animation_number is not None: + self.skip_animations = True + if self.file_writer.has_progress_display(): + self.show_animation_progress = False + + # Items associated with interaction + self.mouse_point = Point() + self.mouse_drag_point = Point() + self.hold_on_wait = self.presenter_mode + self.quit_interaction = False - self.mobjects = [] - # TODO, remove need for foreground mobjects - self.foreground_mobjects = [] + # Much nicer to work with deterministic scenes if self.random_seed is not None: random.seed(self.random_seed) np.random.seed(self.random_seed) - @property - def camera(self): - return self.renderer.camera - - def __deepcopy__(self, clone_from_id): - cls = self.__class__ - result = cls.__new__(cls) - clone_from_id[id(self)] = result - for k, v in self.__dict__.items(): - if k in ["renderer", "time_progression"]: - continue - if k == "camera_class": - setattr(result, k, v) - setattr(result, k, copy.deepcopy(v, clone_from_id)) - result.mobject_updater_lists = [] - - # Update updaters - for mobject in self.mobjects: - cloned_updaters = [] - for updater in mobject.updaters: - # Make the cloned updater use the cloned Mobjects as free variables - # rather than the original ones. Analyzing function bytecode with the - # dis module will help in understanding this. - # https://docs.python.org/3/library/dis.html - # TODO: Do the same for function calls recursively. - free_variable_map = inspect.getclosurevars(updater).nonlocals - cloned_co_freevars = [] - cloned_closure = [] - for free_variable_name in updater.__code__.co_freevars: - free_variable_value = free_variable_map[free_variable_name] - - # If the referenced variable has not been cloned, raise. - if id(free_variable_value) not in clone_from_id: - raise Exception( - f"{free_variable_name} is referenced from an updater " - "but is not an attribute of the Scene, which isn't " - "allowed.", - ) - - # Add the cloned object's name to the free variable list. - cloned_co_freevars.append(free_variable_name) - - # Add a cell containing the cloned object's reference to the - # closure list. - cloned_closure.append( - types.CellType(clone_from_id[id(free_variable_value)]), - ) - - cloned_updater = types.FunctionType( - updater.__code__.replace(co_freevars=tuple(cloned_co_freevars)), - updater.__globals__, - updater.__name__, - updater.__defaults__, - tuple(cloned_closure), - ) - cloned_updaters.append(cloned_updater) - mobject_clone = clone_from_id[id(mobject)] - mobject_clone.updaters = cloned_updaters - if len(cloned_updaters) > 0: - result.mobject_updater_lists.append((mobject_clone, cloned_updaters)) - return result - - def render(self, preview: bool = False): - """ - Renders this Scene. + def __str__(self) -> str: + return self.__class__.__name__ + + def run(self) -> None: + self.virtual_animation_start_time: float = 0 + self.real_animation_start_time: float = time.time() + self.file_writer.begin() - Parameters - --------- - preview - If true, opens scene in a file viewer. - """ self.setup() try: self.construct() - except EndSceneEarlyException: + self.interact() + except EndScene: pass - except RerunSceneException as e: - self.remove(*self.mobjects) - self.renderer.clear_screen() - self.renderer.num_plays = 0 - return True + except KeyboardInterrupt: + # Get rid keyboard interrupt symbols + print("", end="\r") + self.file_writer.ended_with_interrupt = True self.tear_down() - # We have to reset these settings in case of multiple renders. - self.renderer.scene_finished(self) - # Show info only if animations are rendered or to get image - if ( - self.renderer.num_plays - or config["format"] == "png" - or config["save_last_frame"] - ): - logger.info( - f"Rendered {str(self)}\nPlayed {self.renderer.num_plays} animations", - ) - - # If preview open up the render after rendering. - if preview: - config["preview"] = True - - if config["preview"] or config["show_in_file_browser"]: - open_media_file(self.renderer.file_writer) - - def setup(self): + def setup(self) -> None: """ - This is meant to be implemented by any scenes which - are commonly subclassed, and have some common setup + This is meant to be implement by any scenes which + are comonly subclassed, and have some common setup involved before the construct method is called. """ pass - def tear_down(self): - """ - This is meant to be implemented by any scenes which - are commonly subclassed, and have some common method - to be invoked before the scene ends. - """ + def construct(self) -> None: + # Where all the animation happens + # To be implemented in subclasses pass - def construct(self): - """Add content to the Scene. - - From within :meth:`Scene.construct`, display mobjects on screen by calling - :meth:`Scene.add` and remove them from screen by calling :meth:`Scene.remove`. - All mobjects currently on screen are kept in :attr:`Scene.mobjects`. Play - animations by calling :meth:`Scene.play`. - - Notes - ----- - Initialization code should go in :meth:`Scene.setup`. Termination code should - go in :meth:`Scene.tear_down`. - - Examples - -------- - A typical manim script includes a class derived from :class:`Scene` with an - overridden :meth:`Scene.contruct` method: - - .. code-block:: python - - class MyScene(Scene): - def construct(self): - self.play(Write(Text("Hello World!"))) - - See Also - -------- - :meth:`Scene.setup` - :meth:`Scene.render` - :meth:`Scene.tear_down` + def tear_down(self) -> None: + self.stop_skipping() + self.file_writer.finish() + if self.window: + self.window.destroy() + self.window = None + def interact(self) -> None: + """ + If there is a window, enter a loop + which updates the frame while under + the hood calling the pyglet event loop """ - pass # To be implemented in subclasses + if self.window is None: + return + log.info( + "\nTips: Using the keys `d`, `f`, or `z` " + + "you can interact with the scene. " + + "Press `command + q` or `esc` to quit" + ) + self.skip_animations = False + self.refresh_static_mobjects() + while not self.is_window_closing(): + self.update_frame(1 / self.camera.fps) - def next_section( + def embed( self, - name: str = "unnamed", - type: str = DefaultSectionType.NORMAL, - skip_animations: bool = False, + close_scene_on_exit: bool = True, + show_animation_progress: bool = True, ) -> None: - """Create separation here; the last section gets finished and a new one gets created. - ``skip_animations`` skips the rendering of all animations in this section. - Refer to :doc:`the documentation` on how to use sections. - """ - self.renderer.file_writer.next_section(name, type, skip_animations) + if not self.preview: + return # Embed is only relevant with a preview + self.stop_skipping() + self.update_frame() + self.save_state() + self.show_animation_progress = show_animation_progress + + # Create embedded IPython terminal to be configured + shell = InteractiveShellEmbed.instance() + + # Use the locals namespace of the caller + caller_frame = inspect.currentframe().f_back + local_ns = dict(caller_frame.f_locals) + + # Add a few custom shortcuts + local_ns.update( + play=self.play, + wait=self.wait, + add=self.add, + remove=self.remove, + clear=self.clear, + save_state=self.save_state, + undo=self.undo, + redo=self.redo, + i2g=self.i2g, + i2m=self.i2m, + checkpoint_paste=self.checkpoint_paste, + ) - def __str__(self): - return self.__class__.__name__ + # Enables gui interactions during the embed + def inputhook(context): + while not context.input_is_ready(): + if not self.is_window_closing(): + self.update_frame(dt=0) + if self.is_window_closing(): + shell.ask_exit() + + pt_inputhooks.register("manim", inputhook) + shell.enable_gui("manim") + + # This is hacky, but there's an issue with ipython which is that + # when you define lambda's or list comprehensions during a shell session, + # they are not aware of local variables in the surrounding scope. Because + # That comes up a fair bit during scene construction, to get around this, + # we (admittedly sketchily) update the global namespace to match the local + # namespace, since this is just a shell session anyway. + shell.events.register( + "pre_run_cell", lambda: shell.user_global_ns.update(shell.user_ns) + ) - def get_attrs(self, *keys: str): - """ - Gets attributes of a scene given the attribute's identifier/name. + # Operation to run after each ipython command + def post_cell_func(): + self.refresh_static_mobjects() + if not self.is_window_closing(): + self.update_frame(dt=0, ignore_skipping=True) + self.save_state() + + shell.events.register("post_run_cell", post_cell_func) + + # Flash border, and potentially play sound, on exceptions + def custom_exc(shell, etype, evalue, tb, tb_offset=None): + # still show the error don't just swallow it + shell.showtraceback((etype, evalue, tb), tb_offset=tb_offset) + if self.embed_error_sound: + os.system("printf '\a'") + rect = FullScreenRectangle().set_stroke(RED, 30).set_fill(opacity=0) + rect.fix_in_frame() + self.play(VFadeInThenOut(rect, run_time=0.5)) + + shell.set_custom_exc((Exception,), custom_exc) + + # Set desired exception mode + shell.magic(f"xmode {self.embed_exception_mode}") + + # Launch shell + shell( + local_ns=local_ns, + # Pretend like we're embeding in the caller function, not here + stack_depth=2, + # Specify that the present module is the caller's, not here + module=get_module(caller_frame.f_globals["__file__"]), + ) - Parameters - ---------- - *keys - Name(s) of the argument(s) to return the attribute of. + # End scene when exiting an embed + if close_scene_on_exit: + raise EndScene() - Returns - ------- - list - List of attributes of the passed identifiers. - """ - return [getattr(self, key) for key in keys] + # Only these methods should touch the camera - def update_mobjects(self, dt: float): - """ - Begins updating all mobjects in the Scene. + def get_image(self) -> Image: + return self.camera.get_image() - Parameters - ---------- - dt - Change in time between updates. Defaults (mostly) to 1/frames_per_second - """ - for mobject in self.mobjects: - mobject.update(dt) + def show(self) -> None: + self.update_frame(ignore_skipping=True) + self.get_image().show() - def update_meshes(self, dt): - for obj in self.meshes: - for mesh in obj.get_family(): - mesh.update(dt) + def update_frame(self, dt: float = 0, ignore_skipping: bool = False) -> None: + self.increment_time(dt) + self.update_mobjects(dt) + if self.skip_animations and not ignore_skipping: + return - def update_self(self, dt: float): - """Run all scene updater functions. + if self.is_window_closing(): + raise EndScene() - Among all types of update functions (mobject updaters, mesh updaters, - scene updaters), scene update functions are called last. + if self.window: + self.window.clear() + self.camera.clear() + self.camera.capture(*self.mobjects) - Parameters - ---------- - dt - Scene time since last update. + if self.window: + self.window.swap_buffers() + vt = self.time - self.virtual_animation_start_time + rt = time.time() - self.real_animation_start_time + if rt < vt: + self.update_frame(0) - See Also - -------- - :meth:`.Scene.add_updater` - :meth:`.Scene.remove_updater` - """ - for func in self.updaters: - func(dt) + def emit_frame(self) -> None: + if not self.skip_animations: + self.file_writer.write_frame(self.camera) + + # Related to updating + + def update_mobjects(self, dt: float) -> None: + for mobject in self.mobjects: + mobject.update(dt) def should_update_mobjects(self) -> bool: - """ - Returns True if the mobjects of this scene should be updated. + return self.always_update_mobjects or any( + [len(mob.get_family_updaters()) > 0 for mob in self.mobjects] + ) - In particular, this checks whether + def has_time_based_updaters(self) -> bool: + return any( + [ + sm.has_time_based_updater() + for mob in self.mobjects() + for sm in mob.get_family() + ] + ) - - the :attr:`always_update_mobjects` attribute of :class:`.Scene` - is set to ``True``, - - the :class:`.Scene` itself has time-based updaters attached, - - any mobject in this :class:`.Scene` has time-based updaters attached. + # Related to time - This is only called when a single Wait animation is played. - """ - wait_animation = self.animations[0] - if wait_animation.is_static_wait is None: - should_update = ( - self.always_update_mobjects - or self.updaters - or any( - [ - mob.has_time_based_updater() - for mob in self.get_mobject_family_members() - ], - ) - ) - wait_animation.is_static_wait = not should_update - return not wait_animation.is_static_wait + def get_time(self) -> float: + return self.time - def get_top_level_mobjects(self): - """ - Returns all mobjects which are not submobjects. + def increment_time(self, dt: float) -> None: + self.time += dt - Returns - ------- - list - List of top level mobjects. - """ + # Related to internal mobject organization + + def get_top_level_mobjects(self) -> list[Mobject]: # Return only those which are not in the family # of another mobject from the scene - families = [m.get_family() for m in self.mobjects] + mobjects = self.get_mobjects() + families = [m.get_family() for m in mobjects] def is_top_level(mobject): - num_families = sum((mobject in family) for family in families) + num_families = sum([(mobject in family) for family in families]) return num_families == 1 - return list(filter(is_top_level, self.mobjects)) + return list(filter(is_top_level, mobjects)) - def get_mobject_family_members(self): - """ - Returns list of family-members of all mobjects in scene. - If a Circle() and a VGroup(Rectangle(),Triangle()) were added, - it returns not only the Circle(), Rectangle() and Triangle(), but - also the VGroup() object. - - Returns - ------- - list - List of mobject family members. - """ - if config.renderer == RendererType.OPENGL: - family_members = [] - for mob in self.mobjects: - family_members.extend(mob.get_family()) - return family_members - elif config.renderer == RendererType.CAIRO: - return extract_mobject_family_members( - self.mobjects, - use_z_index=self.renderer.camera.use_z_index, - ) + def get_mobject_family_members(self) -> list[Mobject]: + return extract_mobject_family_members(self.mobjects) - def add(self, *mobjects: Mobject): + def add(self, *new_mobjects: Mobject): """ Mobjects will be displayed, from background to foreground in the order with which they are added. - - Parameters - --------- - *mobjects - Mobjects to add. - - Returns - ------- - Scene - The same scene after adding the Mobjects in. - """ - if config.renderer == RendererType.OPENGL: - new_mobjects = [] - new_meshes = [] - for mobject_or_mesh in mobjects: - if isinstance(mobject_or_mesh, Object3D): - new_meshes.append(mobject_or_mesh) - else: - new_mobjects.append(mobject_or_mesh) - self.remove(*new_mobjects) - self.mobjects += new_mobjects - self.remove(*new_meshes) - self.meshes += new_meshes - elif config.renderer == RendererType.CAIRO: - mobjects = [*mobjects, *self.foreground_mobjects] - self.restructure_mobjects(to_remove=mobjects) - self.mobjects += mobjects - if self.moving_mobjects: - self.restructure_mobjects( - to_remove=mobjects, - mobject_list_name="moving_mobjects", - ) - self.moving_mobjects += mobjects + self.remove(*new_mobjects) + self.mobjects += new_mobjects + self.id_to_mobject_map.update( + {id(sm): sm for m in new_mobjects for sm in m.get_family()} + ) return self - def add_mobjects_from_animations(self, animations): - curr_mobjects = self.get_mobject_family_members() - for animation in animations: - if animation.is_introducer(): - continue - # Anything animated that's not already in the - # scene gets added to the scene - mob = animation.mobject - if mob is not None and mob not in curr_mobjects: - self.add(mob) - curr_mobjects += mob.get_family() - - def remove(self, *mobjects: Mobject): - """ - Removes mobjects in the passed list of mobjects - from the scene and the foreground, by removing them - from "mobjects" and "foreground_mobjects" - - Parameters - ---------- - *mobjects - The mobjects to remove. + def add_mobjects_among(self, values: Iterable): """ - if config.renderer == RendererType.OPENGL: - mobjects_to_remove = [] - meshes_to_remove = set() - for mobject_or_mesh in mobjects: - if isinstance(mobject_or_mesh, Object3D): - meshes_to_remove.add(mobject_or_mesh) - else: - mobjects_to_remove.append(mobject_or_mesh) - self.mobjects = restructure_list_to_exclude_certain_family_members( - self.mobjects, - mobjects_to_remove, - ) - self.meshes = list( - filter(lambda mesh: mesh not in set(meshes_to_remove), self.meshes), - ) - return self - elif config.renderer == RendererType.CAIRO: - for list_name in "mobjects", "foreground_mobjects": - self.restructure_mobjects(mobjects, list_name, False) - return self - - def add_updater(self, func: Callable[[float], None]) -> None: - """Add an update function to the scene. - - The scene updater functions are run every frame, - and they are the last type of updaters to run. - - .. WARNING:: - - When using the Cairo renderer, scene updaters that - modify mobjects are not detected in the same way - that mobject updaters are. To be more concrete, - a mobject only modified via a scene updater will - not necessarily be added to the list of *moving - mobjects* and thus might not be updated every frame. - - TL;DR: Use mobject updaters to update mobjects. - - Parameters - ---------- - func - The updater function. It takes a float, which is the - time difference since the last update (usually equal - to the frame rate). - - See also - -------- - :meth:`.Scene.remove_updater` - :meth:`.Scene.update_self` - """ - self.updaters.append(func) - - def remove_updater(self, func: Callable[[float], None]) -> None: - """Remove an update function from the scene. - - Parameters - ---------- - func - The updater function to be removed. - - See also - -------- - :meth:`.Scene.add_updater` - :meth:`.Scene.update_self` + This is meant mostly for quick prototyping, + e.g. to add all mobjects defined up to a point, + call self.add_mobjects_among(locals().values()) """ - self.updaters = [f for f in self.updaters if f is not func] - - def restructure_mobjects( - self, - to_remove: Mobject, - mobject_list_name: str = "mobjects", - extract_families: bool = True, - ): - """ - tl:wr - If your scene has a Group(), and you removed a mobject from the Group, - this dissolves the group and puts the rest of the mobjects directly - in self.mobjects or self.foreground_mobjects. - - In cases where the scene contains a group, e.g. Group(m1, m2, m3), but one - of its submobjects is removed, e.g. scene.remove(m1), the list of mobjects - will be edited to contain other submobjects, but not m1, e.g. it will now - insert m2 and m3 to where the group once was. - - Parameters - ---------- - to_remove - The Mobject to remove. - - mobject_list_name - The list of mobjects ("mobjects", "foreground_mobjects" etc) to remove from. - - extract_families - Whether the mobject's families should be recursively extracted. - - Returns - ------- - Scene - The Scene mobject with restructured Mobjects. - """ - if extract_families: - to_remove = extract_mobject_family_members( - to_remove, - use_z_index=self.renderer.camera.use_z_index, - ) - _list = getattr(self, mobject_list_name) - new_list = self.get_restructured_mobject_list(_list, to_remove) - setattr(self, mobject_list_name, new_list) + self.add(*filter(lambda m: isinstance(m, Mobject), values)) return self - def get_restructured_mobject_list(self, mobjects: list, to_remove: list): - """ - Given a list of mobjects and a list of mobjects to be removed, this - filters out the removable mobjects from the list of mobjects. - - Parameters - ---------- - - mobjects - The Mobjects to check. - - to_remove - The list of mobjects to remove. - - Returns - ------- - list - The list of mobjects with the mobjects to remove removed. - """ - - new_mobjects = [] - - def add_safe_mobjects_from_list(list_to_examine, set_to_remove): - for mob in list_to_examine: - if mob in set_to_remove: - continue - intersect = set_to_remove.intersection(mob.get_family()) - if intersect: - add_safe_mobjects_from_list(mob.submobjects, intersect) - else: - new_mobjects.append(mob) - - add_safe_mobjects_from_list(mobjects, set(to_remove)) - return new_mobjects - - # TODO, remove this, and calls to this - def add_foreground_mobjects(self, *mobjects: Mobject): - """ - Adds mobjects to the foreground, and internally to the list - foreground_mobjects, and mobjects. - - Parameters - ---------- - *mobjects - The Mobjects to add to the foreground. - - Returns - ------ - Scene - The Scene, with the foreground mobjects added. - """ - self.foreground_mobjects = list_update(self.foreground_mobjects, mobjects) - self.add(*mobjects) + def replace(self, mobject: Mobject, *replacements: Mobject): + if mobject in self.mobjects: + index = self.mobjects.index(mobject) + self.mobjects = [ + *self.mobjects[:index], + *replacements, + *self.mobjects[index + 1 :], + ] return self - def add_foreground_mobject(self, mobject: Mobject): - """ - Adds a single mobject to the foreground, and internally to the list - foreground_mobjects, and mobjects. - - Parameters - ---------- - mobject - The Mobject to add to the foreground. - - Returns - ------ - Scene - The Scene, with the foreground mobject added. - """ - return self.add_foreground_mobjects(mobject) - - def remove_foreground_mobjects(self, *to_remove: Mobject): + def remove(self, *mobjects_to_remove: Mobject): """ - Removes mobjects from the foreground, and internally from the list - foreground_mobjects. - - Parameters - ---------- - *to_remove - The mobject(s) to remove from the foreground. - - Returns - ------ - Scene - The Scene, with the foreground mobjects removed. - """ - self.restructure_mobjects(to_remove, "foreground_mobjects") - return self + Removes anything in mobjects from scenes mobject list, but in the event that one + of the items to be removed is a member of the family of an item in mobject_list, + the other family members are added back into the list. - def remove_foreground_mobject(self, mobject: Mobject): - """ - Removes a single mobject from the foreground, and internally from the list - foreground_mobjects. - - Parameters - ---------- - mobject - The mobject to remove from the foreground. - - Returns - ------ - Scene - The Scene, with the foreground mobject removed. + For example, if the scene includes Group(m1, m2, m3), and we call scene.remove(m1), + the desired behavior is for the scene to then include m2 and m3 (ungrouped). """ - return self.remove_foreground_mobjects(mobject) + to_remove = set(extract_mobject_family_members(mobjects_to_remove)) + new_mobjects, _ = recursive_mobject_remove(self.mobjects, to_remove) + self.mobjects = new_mobjects def bring_to_front(self, *mobjects: Mobject): - """ - Adds the passed mobjects to the scene again, - pushing them to he front of the scene. - - Parameters - ---------- - *mobjects - The mobject(s) to bring to the front of the scene. - - Returns - ------ - Scene - The Scene, with the mobjects brought to the front - of the scene. - """ self.add(*mobjects) return self def bring_to_back(self, *mobjects: Mobject): - """ - Removes the mobject from the scene and - adds them to the back of the scene. - - Parameters - ---------- - *mobjects - The mobject(s) to push to the back of the scene. - - Returns - ------ - Scene - The Scene, with the mobjects pushed to the back - of the scene. - """ self.remove(*mobjects) self.mobjects = list(mobjects) + self.mobjects return self def clear(self): - """ - Removes all mobjects present in self.mobjects - and self.foreground_mobjects from the scene. - - Returns - ------ - Scene - The Scene, with all of its mobjects in - self.mobjects and self.foreground_mobjects - removed. - """ self.mobjects = [] - self.foreground_mobjects = [] return self - def get_moving_mobjects(self, *animations: Animation): - """ - Gets all moving mobjects in the passed animation(s). - - Parameters - ---------- - *animations - The animations to check for moving mobjects. - - Returns - ------ - list - The list of mobjects that could be moving in - the Animation(s) - """ - # Go through mobjects from start to end, and - # as soon as there's one that needs updating of - # some kind per frame, return the list from that - # point forward. - animation_mobjects = [anim.mobject for anim in animations] - mobjects = self.get_mobject_family_members() - for i, mob in enumerate(mobjects): - update_possibilities = [ - mob in animation_mobjects, - len(mob.get_family_updaters()) > 0, - mob in self.foreground_mobjects, - ] - if any(update_possibilities): - return mobjects[i:] - return [] - - def get_moving_and_static_mobjects(self, animations): - all_mobjects = list_update(self.mobjects, self.foreground_mobjects) - all_mobject_families = extract_mobject_family_members( - all_mobjects, - use_z_index=self.renderer.camera.use_z_index, - only_those_with_points=True, - ) - moving_mobjects = self.get_moving_mobjects(*animations) - all_moving_mobject_families = extract_mobject_family_members( - moving_mobjects, - use_z_index=self.renderer.camera.use_z_index, - ) - static_mobjects = list_difference_update( - all_mobject_families, - all_moving_mobject_families, - ) - return all_moving_mobject_families, static_mobjects + def get_mobjects(self) -> list[Mobject]: + return list(self.mobjects) - def compile_animations(self, *args: Animation, **kwargs): - """ - Creates _MethodAnimations from any _AnimationBuilders and updates animation - kwargs with kwargs passed to play(). - - Parameters - ---------- - *args - Animations to be played. - **kwargs - Configuration for the call to play(). - - Returns - ------- - Tuple[:class:`Animation`] - Animations to be played. - """ - animations = [] - for arg in args: - try: - animations.append(prepare_animation(arg)) - except TypeError: - if inspect.ismethod(arg): - raise TypeError( - "Passing Mobject methods to Scene.play is no longer" - " supported. Use Mobject.animate instead.", - ) - else: - raise TypeError( - f"Unexpected argument {arg} passed to Scene.play().", - ) - - for animation in animations: - for k, v in kwargs.items(): - setattr(animation, k, v) - - return animations - - def _get_animation_time_progression( - self, animations: list[Animation], duration: float - ): - """ - You will hardly use this when making your own animations. - This method is for Manim's internal use. - - Uses :func:`~.get_time_progression` to obtain a - CommandLine ProgressBar whose ``fill_time`` is - dependent on the qualities of the passed Animation, - - Parameters - ---------- - animations - The list of animations to get - the time progression for. - - duration - duration of wait time - - Returns - ------- - time_progression - The CommandLine Progress Bar. - """ - if len(animations) == 1 and isinstance(animations[0], Wait): - stop_condition = animations[0].stop_condition - if stop_condition is not None: - time_progression = self.get_time_progression( - duration, - f"Waiting for {stop_condition.__name__}", - n_iterations=-1, # So it doesn't show % progress - override_skip_animations=True, - ) - else: - time_progression = self.get_time_progression( - duration, - f"Waiting {self.renderer.num_plays}", - ) - else: - time_progression = self.get_time_progression( - duration, - "".join( - [ - f"Animation {self.renderer.num_plays}: ", - str(animations[0]), - (", etc." if len(animations) > 1 else ""), - ], - ), - ) - return time_progression + def get_mobject_copies(self) -> list[Mobject]: + return [m.copy() for m in self.mobjects] - def get_time_progression( + def point_to_mobject( self, - run_time: float, - description, - n_iterations: int | None = None, - override_skip_animations: bool = False, - ): - """ - You will hardly use this when making your own animations. - This method is for Manim's internal use. - - Returns a CommandLine ProgressBar whose ``fill_time`` - is dependent on the ``run_time`` of an animation, - the iterations to perform in that animation - and a bool saying whether or not to consider - the skipped animations. - - Parameters - ---------- - run_time - The ``run_time`` of the animation. - - n_iterations - The number of iterations in the animation. - - override_skip_animations - Whether or not to show skipped animations in the progress bar. - - Returns - ------- - time_progression - The CommandLine Progress Bar. - """ - if self.renderer.skip_animations and not override_skip_animations: - times = [run_time] + point: np.ndarray, + search_set: Iterable[Mobject] | None = None, + buff: float = 0, + ) -> Mobject | None: + """ + E.g. if clicking on the scene, this returns the top layer mobject + under a given point + """ + if search_set is None: + search_set = self.mobjects + for mobject in reversed(search_set): + if mobject.is_point_touching(point, buff=buff): + return mobject + return None + + def get_group(self, *mobjects): + if all(isinstance(m, VMobject) for m in mobjects): + return VGroup(*mobjects) else: - step = 1 / config["frame_rate"] - times = np.arange(0, run_time, step) - time_progression = tqdm( - times, - desc=description, - total=n_iterations, - leave=config["progress_bar"] == "leave", - ascii=True if platform.system() == "Windows" else None, - disable=config["progress_bar"] == "none", + return Group(*mobjects) + + def id_to_mobject(self, id_value): + return self.id_to_mobject_map[id_value] + + def ids_to_group(self, *id_values): + return self.get_group( + *filter(lambda x: x is not None, map(self.id_to_mobject, id_values)) ) - return time_progression - def get_run_time(self, animations: list[Animation]): - """ - Gets the total run time for a list of animations. - - Parameters - ---------- - animations - A list of the animations whose total - ``run_time`` is to be calculated. - - Returns - ------- - float - The total ``run_time`` of all of the animations in the list. - """ + def i2g(self, *id_values): + return self.ids_to_group(*id_values) - if len(animations) == 1 and isinstance(animations[0], Wait): - if animations[0].stop_condition is not None: - return 0 - else: - return animations[0].duration + def i2m(self, id_value): + return self.id_to_mobject(id_value) - else: - return np.max([animation.run_time for animation in animations]) + # Related to skipping - def play( - self, - *args, - subcaption=None, - subcaption_duration=None, - subcaption_offset=0, - **kwargs, - ): - r"""Plays an animation in this scene. - - Parameters - ---------- - - args - Animations to be played. - subcaption - The content of the external subcaption that should - be added during the animation. - subcaption_duration - The duration for which the specified subcaption is - added. If ``None`` (the default), the run time of the - animation is taken. - subcaption_offset - An offset (in seconds) for the start time of the - added subcaption. - kwargs - All other keywords are passed to the renderer. + def update_skipping_status(self) -> None: + if (self.start_at_animation_number is not None) and ( + self.num_plays == self.start_at_animation_number + ): + self.skip_time = self.time + if not self.original_skipping_status: + self.stop_skipping() + if (self.end_at_animation_number is not None) and ( + self.num_plays >= self.end_at_animation_number + ): + raise EndScene() - """ - # Make sure this is running on the main thread - if threading.current_thread().name != "MainThread": - kwargs.update( - { - "subcaption": subcaption, - "subcaption_duration": subcaption_duration, - "subcaption_offset": subcaption_offset, - } - ) - self.queue.put( - ( - "play", - args, - kwargs, - ) - ) - return + def stop_skipping(self) -> None: + self.virtual_animation_start_time = self.time + self.skip_animations = False - start_time = self.renderer.time - self.renderer.play(self, *args, **kwargs) - run_time = self.renderer.time - start_time - if subcaption: - if subcaption_duration is None: - subcaption_duration = run_time - # The start of the subcaption needs to be offset by the - # run_time of the animation because it is added after - # the animation has already been played (and Scene.renderer.time - # has already been updated). - self.add_subcaption( - content=subcaption, - duration=subcaption_duration, - offset=-run_time + subcaption_offset, - ) + # Methods associated with running animations - def wait( + def get_time_progression( self, - duration: float = DEFAULT_WAIT_TIME, - stop_condition: Callable[[], bool] | None = None, - frozen_frame: bool | None = None, - ): - """Plays a "no operation" animation. - - Parameters - ---------- - duration - The run time of the animation. - stop_condition - A function without positional arguments that is evaluated every time - a frame is rendered. The animation only stops when the return value - of the function is truthy. Overrides any value passed to ``duration``. - frozen_frame - If True, updater functions are not evaluated, and the animation outputs - a frozen frame. If False, updater functions are called and frames - are rendered as usual. If None (the default), the scene tries to - determine whether or not the frame is frozen on its own. - - See also - -------- - :class:`.Wait`, :meth:`.should_mobjects_update` - """ - self.play( - Wait( - run_time=duration, - stop_condition=stop_condition, - frozen_frame=frozen_frame, - ) - ) + run_time: float, + n_iterations: int | None = None, + desc: str = "", + override_skip_animations: bool = False, + ) -> list[float] | np.ndarray | ProgressDisplay: + if self.skip_animations and not override_skip_animations: + return [run_time] - def pause(self, duration: float = DEFAULT_WAIT_TIME): - """Pauses the scene (i.e., displays a frozen frame). + times = np.arange(0, run_time, 1 / self.camera.fps) - This is an alias for :meth:`.wait` with ``frozen_frame`` - set to ``True``. + self.file_writer.set_progress_display_description(sub_desc=desc) - Parameters - ---------- - duration - The duration of the pause. + if self.show_animation_progress: + return ProgressDisplay( + times, + total=n_iterations, + leave=self.leave_progress_bars, + ascii=True if platform.system() == "Windows" else None, + desc=desc, + ) + else: + return times + + def get_run_time(self, animations: Iterable[Animation]) -> float: + return np.max([animation.get_run_time() for animation in animations]) + + def get_animation_time_progression( + self, animations: Iterable[Animation] + ) -> list[float] | np.ndarray | ProgressDisplay: + animations = list(animations) + run_time = self.get_run_time(animations) + description = f"{self.num_plays} {animations[0]}" + if len(animations) > 1: + description += ", etc." + time_progression = self.get_time_progression(run_time, desc=description) + return time_progression - See also - -------- - :meth:`.wait`, :class:`.Wait` - """ - self.wait(duration=duration, frozen_frame=True) + def get_wait_time_progression( + self, duration: float, stop_condition: Callable[[], bool] | None = None + ) -> list[float] | np.ndarray | ProgressDisplay: + kw = {"desc": f"{self.num_plays} Waiting"} + if stop_condition is not None: + kw["n_iterations"] = -1 # So it doesn't show % progress + kw["override_skip_animations"] = True + return self.get_time_progression(duration, **kw) - def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60): - """ - Like a wrapper for wait(). - You pass a function that determines whether to continue waiting, - and a max wait time if that is never fulfilled. + def pre_play(self): + if self.presenter_mode and self.num_plays == 0: + self.hold_loop() - Parameters - ---------- - stop_condition - The function whose boolean return value determines whether to continue waiting + self.update_skipping_status() - max_time - The maximum wait time in seconds, if the stop_condition is never fulfilled. - """ - self.wait(max_time, stop_condition=stop_condition) + if not self.skip_animations: + self.file_writer.begin_animation() - def compile_animation_data(self, *animations: Animation, **play_kwargs): - """Given a list of animations, compile the corresponding - static and moving mobjects, and gather the animation durations. + if self.window: + self.real_animation_start_time = time.time() + self.virtual_animation_start_time = self.time - This also begins the animations. + self.refresh_static_mobjects() - Parameters - ---------- - animations - Animation or mobject with mobject method and params - play_kwargs - Named parameters affecting what was passed in ``animations``, - e.g. ``run_time``, ``lag_ratio`` and so on. + def post_play(self): + if not self.skip_animations: + self.file_writer.end_animation() - Returns - ------- - self, None - None if there is nothing to play, or self otherwise. - """ - # NOTE TODO : returns statement of this method are wrong. It should return nothing, as it makes a little sense to get any information from this method. - # The return are kept to keep webgl renderer from breaking. - if len(animations) == 0: - raise ValueError("Called Scene.play with no animations") - - self.animations = self.compile_animations(*animations, **play_kwargs) - self.add_mobjects_from_animations(self.animations) - - self.last_t = 0 - self.stop_condition = None - self.moving_mobjects = [] - self.static_mobjects = [] - - if len(self.animations) == 1 and isinstance(self.animations[0], Wait): - if self.should_update_mobjects(): - self.update_mobjects(dt=0) # Any problems with this? - self.stop_condition = self.animations[0].stop_condition - else: - self.duration = self.animations[0].duration - # Static image logic when the wait is static is done by the renderer, not here. - self.animations[0].is_static_wait = True - return None - self.duration = self.get_run_time(self.animations) - return self + if self.skip_animations and self.window is not None: + # Show some quick frames along the way + self.update_frame(dt=0, ignore_skipping=True) - def begin_animations(self) -> None: - """Start the animations of the scene.""" - for animation in self.animations: - animation._setup_scene(self) - animation.begin() + self.num_plays += 1 - if config.renderer == RendererType.CAIRO: - # Paint all non-moving objects onto the screen, so they don't - # have to be rendered every frame - ( - self.moving_mobjects, - self.static_mobjects, - ) = self.get_moving_and_static_mobjects(self.animations) - - def is_current_animation_frozen_frame(self) -> bool: - """Returns whether the current animation produces a static frame (generally a Wait).""" - return ( - isinstance(self.animations[0], Wait) - and len(self.animations) == 1 - and self.animations[0].is_static_wait - ) + def refresh_static_mobjects(self) -> None: + self.camera.refresh_static_mobjects() - def play_internal(self, skip_rendering: bool = False): - """ - This method is used to prep the animations for rendering, - apply the arguments and parameters required to them, - render them, and write them to the video file. - - Parameters - ---------- - skip_rendering - Whether the rendering should be skipped, by default False - """ - self.duration = self.get_run_time(self.animations) - self.time_progression = self._get_animation_time_progression( - self.animations, - self.duration, - ) - for t in self.time_progression: - self.update_to_time(t) - if not skip_rendering and not self.skip_animation_preview: - self.renderer.render(self, t, self.moving_mobjects) - if self.stop_condition is not None and self.stop_condition(): - self.time_progression.close() - break - - for animation in self.animations: + def begin_animations(self, animations: Iterable[Animation]) -> None: + for animation in animations: + animation.begin() + # Anything animated that's not already in the + # scene gets added to the scene. Note, for + # animated mobjects that are in the family of + # those on screen, this can result in a restructuring + # of the scene.mobjects list, which is usually desired. + if animation.mobject not in self.mobjects: + self.add(animation.mobject) + + def progress_through_animations(self, animations: Iterable[Animation]) -> None: + last_t = 0 + for t in self.get_animation_time_progression(animations): + dt = t - last_t + last_t = t + for animation in animations: + animation.update_mobjects(dt) + alpha = t / animation.run_time + animation.interpolate(alpha) + self.update_frame(dt) + self.emit_frame() + + def finish_animations(self, animations: Iterable[Animation]) -> None: + for animation in animations: animation.finish() animation.clean_up_from_scene(self) - if not self.renderer.skip_animations: + if self.skip_animations: + self.update_mobjects(self.get_run_time(animations)) + else: self.update_mobjects(0) - self.renderer.static_image = None - # Closing the progress bar at the end of the play. - self.time_progression.close() - - def check_interactive_embed_is_valid(self): - if config["force_window"]: - return True - if self.skip_animation_preview: - logger.warning( - "Disabling interactive embed as 'skip_animation_preview' is enabled", - ) - return False - elif config["write_to_movie"]: - logger.warning("Disabling interactive embed as 'write_to_movie' is enabled") - return False - elif config["format"]: - logger.warning( - "Disabling interactive embed as '--format' is set as " - + config["format"], - ) - return False - elif not self.renderer.window: - logger.warning("Disabling interactive embed as no window was created") - return False - elif config.dry_run: - logger.warning("Disabling interactive embed as dry_run is enabled") - return False - return True - - def interactive_embed(self): - """ - Like embed(), but allows for screen interaction. - """ - if not self.check_interactive_embed_is_valid(): + + def play( + self, + *proto_animations: Animation | _AnimationBuilder, + run_time: float | None = None, + rate_func: Callable[[float], float] | None = None, + lag_ratio: float | None = None, + ) -> None: + if len(proto_animations) == 0: + log.warning("Called Scene.play with no animations") return - self.interactive_mode = True + animations = list(map(prepare_animation, proto_animations)) + for anim in animations: + anim.update_rate_info(run_time, rate_func, lag_ratio) + self.pre_play() + self.begin_animations(animations) + self.progress_through_animations(animations) + self.finish_animations(animations) + self.post_play() - def ipython(shell, namespace): - import manim.opengl + def wait( + self, + duration: float = DEFAULT_WAIT_TIME, + stop_condition: Callable[[], bool] = None, + note: str = None, + ignore_presenter_mode: bool = False, + ): + self.pre_play() + self.update_mobjects(dt=0) # Any problems with this? + if ( + self.presenter_mode + and not self.skip_animations + and not ignore_presenter_mode + ): + if note: + log.info(note) + self.hold_loop() + else: + time_progression = self.get_wait_time_progression(duration, stop_condition) + last_t = 0 + for t in time_progression: + dt = t - last_t + last_t = t + self.update_frame(dt) + self.emit_frame() + if stop_condition is not None and stop_condition(): + break + self.refresh_static_mobjects() + self.post_play() - def load_module_into_namespace(module, namespace): - for name in dir(module): - namespace[name] = getattr(module, name) + def hold_loop(self): + while self.hold_on_wait: + self.update_frame(dt=1 / self.camera.fps) + self.hold_on_wait = True - load_module_into_namespace(manim, namespace) - load_module_into_namespace(manim.opengl, namespace) + def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60): + self.wait(max_time, stop_condition=stop_condition) - def embedded_rerun(*args, **kwargs): - self.queue.put(("rerun_keyboard", args, kwargs)) - shell.exiter() + def force_skipping(self): + self.original_skipping_status = self.skip_animations + self.skip_animations = True + return self - namespace["rerun"] = embedded_rerun + def revert_to_original_skipping_status(self): + if hasattr(self, "original_skipping_status"): + self.skip_animations = self.original_skipping_status + return self - shell(local_ns=namespace) - self.queue.put(("exit_keyboard", [], {})) + def add_sound( + self, + sound_file: str, + time_offset: float = 0, + gain: float | None = None, + gain_to_background: float | None = None, + ): + if self.skip_animations: + return + time = self.get_time() + time_offset + self.file_writer.add_sound(sound_file, time, gain, gain_to_background) - def get_embedded_method(method_name): - return lambda *args, **kwargs: self.queue.put((method_name, args, kwargs)) + # Helpers for interactive development - local_namespace = inspect.currentframe().f_back.f_locals - for method in ("play", "wait", "add", "remove"): - embedded_method = get_embedded_method(method) - # Allow for calling scene methods without prepending 'self.'. - local_namespace[method] = embedded_method + def get_state(self) -> SceneState: + return SceneState(self) - from IPython.terminal.embed import InteractiveShellEmbed - from traitlets.config import Config + def restore_state(self, scene_state: SceneState): + scene_state.restore_scene(self) - cfg = Config() - cfg.TerminalInteractiveShell.confirm_exit = False - shell = InteractiveShellEmbed(config=cfg) + def save_state(self) -> None: + if not self.preview: + return + state = self.get_state() + if self.undo_stack and state.mobjects_match(self.undo_stack[-1]): + return + self.redo_stack = [] + self.undo_stack.append(state) + if len(self.undo_stack) > self.max_num_saved_states: + self.undo_stack.pop(0) + + def undo(self): + if self.undo_stack: + self.redo_stack.append(self.get_state()) + self.restore_state(self.undo_stack.pop()) + self.refresh_static_mobjects() + + def redo(self): + if self.redo_stack: + self.undo_stack.append(self.get_state()) + self.restore_state(self.redo_stack.pop()) + self.refresh_static_mobjects() + + def checkpoint_paste(self, skip: bool = False): + """ + Used during interactive development to run (or re-run) + a block of scene code. + + If the copied selection starts with a comment, this will + revert to the state of the scene the first time this function + was called on a block of code starting with that comment. + """ + shell = get_ipython() + if shell is None: + raise Exception( + "Scene.checkpoint_paste cannot be called outside of " + + "an ipython shell" + ) - keyboard_thread = threading.Thread( - target=ipython, - args=(shell, local_namespace), - ) - # run as daemon to kill thread when main thread exits - if not shell.pt_app: - keyboard_thread.daemon = True - keyboard_thread.start() - - if self.dearpygui_imported and config["enable_gui"]: - if not dpg.is_dearpygui_running(): - gui_thread = threading.Thread( - target=configure_pygui, - args=(self.renderer, self.widgets), - kwargs={"update": False}, - ) - gui_thread.start() + pasted = pyperclip.paste() + line0 = pasted.lstrip().split("\n")[0] + if line0.startswith("#"): + if line0 not in self.checkpoint_states: + self.checkpoint(line0) else: - configure_pygui(self.renderer, self.widgets, update=True) + self.revert_to_checkpoint(line0) - self.camera.model_matrix = self.camera.default_model_matrix + prev_skipping = self.skip_animations + self.skip_animations = skip - self.interact(shell, keyboard_thread) + shell.run_cell(pasted) - def interact(self, shell, keyboard_thread): - event_handler = RerunSceneHandler(self.queue) - file_observer = Observer() - file_observer.schedule(event_handler, config["input_file"], recursive=True) - file_observer.start() + self.skip_animations = prev_skipping - self.quit_interaction = False - keyboard_thread_needs_join = shell.pt_app is not None - assert self.queue.qsize() == 0 - - last_time = time.time() - while not (self.renderer.window.is_closing or self.quit_interaction): - if not self.queue.empty(): - tup = self.queue.get_nowait() - if tup[0].startswith("rerun"): - # Intentionally skip calling join() on the file thread to save time. - if not tup[0].endswith("keyboard"): - if shell.pt_app: - shell.pt_app.app.exit(exception=EOFError) - file_observer.unschedule_all() - raise RerunSceneException - keyboard_thread.join() - - kwargs = tup[2] - if "from_animation_number" in kwargs: - config["from_animation_number"] = kwargs[ - "from_animation_number" - ] - # # TODO: This option only makes sense if interactive_embed() is run at the - # # end of a scene by default. - # if "upto_animation_number" in kwargs: - # config["upto_animation_number"] = kwargs[ - # "upto_animation_number" - # ] - - keyboard_thread.join() - file_observer.unschedule_all() - raise RerunSceneException - elif tup[0].startswith("exit"): - # Intentionally skip calling join() on the file thread to save time. - if not tup[0].endswith("keyboard") and shell.pt_app: - shell.pt_app.app.exit(exception=EOFError) - keyboard_thread.join() - # Remove exit_keyboard from the queue if necessary. - while self.queue.qsize() > 0: - self.queue.get() - keyboard_thread_needs_join = False - break - else: - method, args, kwargs = tup - getattr(self, method)(*args, **kwargs) - else: - self.renderer.animation_start_time = 0 - dt = time.time() - last_time - last_time = time.time() - self.renderer.render(self, dt, self.moving_mobjects) - self.update_mobjects(dt) - self.update_meshes(dt) - self.update_self(dt) - - # Join the keyboard thread if necessary. - if shell is not None and keyboard_thread_needs_join: - shell.pt_app.app.exit(exception=EOFError) - keyboard_thread.join() - # Remove exit_keyboard from the queue if necessary. - while self.queue.qsize() > 0: - self.queue.get() - - file_observer.stop() - file_observer.join() - - if self.dearpygui_imported and config["enable_gui"]: - dpg.stop_dearpygui() - - if self.renderer.window.is_closing: - self.renderer.window.destroy() - - def embed(self): - if not config["preview"]: - logger.warning("Called embed() while no preview window is available.") - return - if config["write_to_movie"]: - logger.warning("embed() is skipped while writing to a file.") - return + def checkpoint(self, key: str): + self.checkpoint_states[key] = self.get_state() - self.renderer.animation_start_time = 0 - self.renderer.render(self, -1, self.moving_mobjects) + def revert_to_checkpoint(self, key: str): + if key not in self.checkpoint_states: + log.error(f"No checkpoint at {key}") + return + all_keys = list(self.checkpoint_states.keys()) + index = all_keys.index(key) + for later_key in all_keys[index + 1 :]: + self.checkpoint_states.pop(later_key) - # Configure IPython shell. - from IPython.terminal.embed import InteractiveShellEmbed + self.restore_state(self.checkpoint_states[key]) - shell = InteractiveShellEmbed() + def clear_checkpoints(self): + self.checkpoint_states = {} - # Have the frame update after each command - shell.events.register( - "post_run_cell", - lambda *a, **kw: self.renderer.render(self, -1, self.moving_mobjects), - ) - - # Use the locals of the caller as the local namespace - # once embedded, and add a few custom shortcuts. - local_ns = inspect.currentframe().f_back.f_locals - # local_ns["touch"] = self.interact - for method in ( - "play", - "wait", - "add", - "remove", - "interact", - # "clear", - # "save_state", - # "restore", - ): - local_ns[method] = getattr(self, method) - shell(local_ns=local_ns, stack_depth=2) - - # End scene when exiting an embed. - raise Exception("Exiting scene.") - - def update_to_time(self, t): - dt = t - self.last_t - self.last_t = t - for animation in self.animations: - animation.update_mobjects(dt) - alpha = t / animation.run_time - animation.interpolate(alpha) - self.update_mobjects(dt) - self.update_meshes(dt) - self.update_self(dt) - - def add_subcaption( - self, content: str, duration: float = 1, offset: float = 0 + def save_mobject_to_file( + self, mobject: Mobject, file_path: str | None = None ) -> None: - r"""Adds an entry in the corresponding subcaption file - at the current time stamp. - - The current time stamp is obtained from ``Scene.renderer.time``. + if file_path is None: + file_path = self.file_writer.get_saved_mobject_path(mobject) + if file_path is None: + return + mobject.save_to_file(file_path) + + def load_mobject(self, file_name): + if os.path.exists(file_name): + path = file_name + else: + directory = self.file_writer.get_saved_mobject_directory() + path = os.path.join(directory, file_name) + return Mobject.load(path) - Parameters - ---------- + def is_window_closing(self): + return self.window and (self.window.is_closing or self.quit_interaction) - content - The subcaption content. - duration - The duration (in seconds) for which the subcaption is shown. - offset - This offset (in seconds) is added to the starting time stamp - of the subcaption. + # Event handling - Examples - -------- + def on_mouse_motion(self, point: np.ndarray, d_point: np.ndarray) -> None: + self.mouse_point.move_to(point) - This example illustrates both possibilities for adding - subcaptions to Manimations:: + event_data = {"point": point, "d_point": d_point} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.MouseMotionEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: + return - class SubcaptionExample(Scene): - def construct(self): - square = Square() - circle = Circle() + frame = self.camera.frame + # Handle perspective changes + if self.window.is_key_pressed(ord(PAN_3D_KEY)): + frame.increment_theta(-self.pan_sensitivity * d_point[0]) + frame.increment_phi(self.pan_sensitivity * d_point[1]) + # Handle frame movements + elif self.window.is_key_pressed(ord(FRAME_SHIFT_KEY)): + shift = -d_point + shift[0] *= frame.get_width() / 2 + shift[1] *= frame.get_height() / 2 + transform = frame.get_inverse_camera_rotation_matrix() + shift = np.dot(np.transpose(transform), shift) + frame.shift(shift) - # first option: via the add_subcaption method - self.add_subcaption("Hello square!", duration=1) - self.play(Create(square)) + def on_mouse_drag( + self, point: np.ndarray, d_point: np.ndarray, buttons: int, modifiers: int + ) -> None: + self.mouse_drag_point.move_to(point) - # second option: within the call to Scene.play - self.play( - Transform(square, circle), - subcaption="The square transforms." - ) + event_data = { + "point": point, + "d_point": d_point, + "buttons": buttons, + "modifiers": modifiers, + } + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.MouseDragEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: + return - """ - subtitle = srt.Subtitle( - index=len(self.renderer.file_writer.subcaptions), - content=content, - start=datetime.timedelta(seconds=self.renderer.time + offset), - end=datetime.timedelta(seconds=self.renderer.time + offset + duration), + def on_mouse_press(self, point: np.ndarray, button: int, mods: int) -> None: + self.mouse_drag_point.move_to(point) + event_data = {"point": point, "button": button, "mods": mods} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.MousePressEvent, **event_data ) - self.renderer.file_writer.subcaptions.append(subtitle) + if propagate_event is not None and propagate_event is False: + return - def add_sound( - self, - sound_file: str, - time_offset: float = 0, - gain: float | None = None, - **kwargs, - ): - """ - This method is used to add a sound to the animation. - - Parameters - ---------- - - sound_file - The path to the sound file. - time_offset - The offset in the sound file after which - the sound can be played. - gain - Amplification of the sound. - - Examples - -------- - .. manim:: SoundExample - :no_autoplay: - - class SoundExample(Scene): - # Source of sound under Creative Commons 0 License. https://freesound.org/people/Druminfected/sounds/250551/ - def construct(self): - dot = Dot().set_color(GREEN) - self.add_sound("click.wav") - self.add(dot) - self.wait() - self.add_sound("click.wav") - dot.set_color(BLUE) - self.wait() - self.add_sound("click.wav") - dot.set_color(RED) - self.wait() - - Download the resource for the previous example `here `_ . - """ - if self.renderer.skip_animations: + def on_mouse_release(self, point: np.ndarray, button: int, mods: int) -> None: + event_data = {"point": point, "button": button, "mods": mods} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.MouseReleaseEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: return - time = self.renderer.time + time_offset - self.renderer.file_writer.add_sound(sound_file, time, gain, **kwargs) - def on_mouse_motion(self, point, d_point): - self.mouse_point.move_to(point) - if SHIFT_VALUE in self.renderer.pressed_keys: - shift = -d_point - shift[0] *= self.camera.get_width() / 2 - shift[1] *= self.camera.get_height() / 2 - transform = self.camera.inverse_rotation_matrix - shift = np.dot(np.transpose(transform), shift) - self.camera.shift(shift) + def on_mouse_scroll(self, point: np.ndarray, offset: np.ndarray) -> None: + event_data = {"point": point, "offset": offset} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.MouseScrollEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: + return - def on_mouse_scroll(self, point, offset): - if not config.use_projection_stroke_shaders: - factor = 1 + np.arctan(-2.1 * offset[1]) - self.camera.scale(factor, about_point=self.camera_target) - self.mouse_scroll_orbit_controls(point, offset) + frame = self.camera.frame + if self.window.is_key_pressed(ord(ZOOM_KEY)): + factor = 1 + np.arctan(10 * offset[1]) + frame.scale(1 / factor, about_point=point) + else: + transform = frame.get_inverse_camera_rotation_matrix() + shift = np.dot(np.transpose(transform), offset) + frame.shift(-20.0 * shift) + + def on_key_release(self, symbol: int, modifiers: int) -> None: + event_data = {"symbol": symbol, "modifiers": modifiers} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.KeyReleaseEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: + return - def on_key_press(self, symbol, modifiers): + def on_key_press(self, symbol: int, modifiers: int) -> None: try: char = chr(symbol) except OverflowError: - logger.warning("The value of the pressed key is too large.") + log.warning("The value of the pressed key is too large.") return - if char == "r": - self.camera.to_default_state() - self.camera_target = np.array([0, 0, 0], dtype=np.float32) - elif char == "q": + event_data = {"symbol": symbol, "modifiers": modifiers} + propagate_event = EVENT_DISPATCHER.dispatch( + EventType.KeyPressEvent, **event_data + ) + if propagate_event is not None and propagate_event is False: + return + + if char == RESET_FRAME_KEY: + self.play(self.camera.frame.animate.to_default_state()) + elif char == "z" and modifiers == COMMAND_MODIFIER: + self.undo() + elif char == "z" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: + self.redo() + # command + q + elif char == QUIT_KEY and modifiers == COMMAND_MODIFIER: self.quit_interaction = True - else: - if char in self.key_to_function_map: - self.key_to_function_map[char]() + # Space or right arrow + elif char == " " or symbol == ARROW_SYMBOLS[2]: + self.hold_on_wait = False - def on_key_release(self, symbol, modifiers): - pass + def on_resize(self, width: int, height: int) -> None: + self.camera.reset_pixel_shape(width, height) - def on_mouse_drag(self, point, d_point, buttons, modifiers): - self.mouse_drag_point.move_to(point) - if buttons == 1: - self.camera.increment_theta(-d_point[0]) - self.camera.increment_phi(d_point[1]) - elif buttons == 4: - camera_x_axis = self.camera.model_matrix[:3, 0] - horizontal_shift_vector = -d_point[0] * camera_x_axis - vertical_shift_vector = -d_point[1] * np.cross(OUT, camera_x_axis) - total_shift_vector = horizontal_shift_vector + vertical_shift_vector - self.camera.shift(1.1 * total_shift_vector) - - self.mouse_drag_orbit_controls(point, d_point, buttons, modifiers) - - def mouse_scroll_orbit_controls(self, point, offset): - camera_to_target = self.camera_target - self.camera.get_position() - camera_to_target *= np.sign(offset[1]) - shift_vector = 0.01 * camera_to_target - self.camera.model_matrix = ( - opengl.translation_matrix(*shift_vector) @ self.camera.model_matrix - ) + def on_show(self) -> None: + pass - def mouse_drag_orbit_controls(self, point, d_point, buttons, modifiers): - # Left click drag. - if buttons == 1: - # Translate to target the origin and rotate around the z axis. - self.camera.model_matrix = ( - opengl.rotation_matrix(z=-d_point[0]) - @ opengl.translation_matrix(*-self.camera_target) - @ self.camera.model_matrix - ) + def on_hide(self) -> None: + pass - # Rotation off of the z axis. - camera_position = self.camera.get_position() - camera_y_axis = self.camera.model_matrix[:3, 1] - axis_of_rotation = space_ops.normalize( - np.cross(camera_y_axis, camera_position), - ) - rotation_matrix = space_ops.rotation_matrix( - d_point[1], - axis_of_rotation, - homogeneous=True, - ) + def on_close(self) -> None: + pass - maximum_polar_angle = self.camera.maximum_polar_angle - minimum_polar_angle = self.camera.minimum_polar_angle - potential_camera_model_matrix = rotation_matrix @ self.camera.model_matrix - potential_camera_location = potential_camera_model_matrix[:3, 3] - potential_camera_y_axis = potential_camera_model_matrix[:3, 1] - sign = ( - np.sign(potential_camera_y_axis[2]) - if potential_camera_y_axis[2] != 0 - else 1 - ) - potential_polar_angle = sign * np.arccos( - potential_camera_location[2] - / np.linalg.norm(potential_camera_location), - ) - if minimum_polar_angle <= potential_polar_angle <= maximum_polar_angle: - self.camera.model_matrix = potential_camera_model_matrix +class SceneState: + def __init__(self, scene: Scene, ignore: list[Mobject] | None = None): + self.time = scene.time + self.num_plays = scene.num_plays + self.mobjects_to_copies = OrderedDict.fromkeys(scene.mobjects) + if ignore: + for mob in ignore: + self.mobjects_to_copies.pop(mob, None) + + last_m2c = scene.undo_stack[-1].mobjects_to_copies if scene.undo_stack else {} + for mob in self.mobjects_to_copies: + # If it hasn't changed since the last state, just point to the + # same copy as before + if mob in last_m2c and last_m2c[mob].looks_identical(mob): + self.mobjects_to_copies[mob] = last_m2c[mob] else: - sign = np.sign(camera_y_axis[2]) if camera_y_axis[2] != 0 else 1 - current_polar_angle = sign * np.arccos( - camera_position[2] / np.linalg.norm(camera_position), - ) - if potential_polar_angle > maximum_polar_angle: - polar_angle_delta = maximum_polar_angle - current_polar_angle - else: - polar_angle_delta = minimum_polar_angle - current_polar_angle - rotation_matrix = space_ops.rotation_matrix( - polar_angle_delta, - axis_of_rotation, - homogeneous=True, - ) - self.camera.model_matrix = rotation_matrix @ self.camera.model_matrix - - # Translate to target the original target. - self.camera.model_matrix = ( - opengl.translation_matrix(*self.camera_target) - @ self.camera.model_matrix - ) - # Right click drag. - elif buttons == 4: - camera_x_axis = self.camera.model_matrix[:3, 0] - horizontal_shift_vector = -d_point[0] * camera_x_axis - vertical_shift_vector = -d_point[1] * np.cross(OUT, camera_x_axis) - total_shift_vector = horizontal_shift_vector + vertical_shift_vector - - self.camera.model_matrix = ( - opengl.translation_matrix(*total_shift_vector) - @ self.camera.model_matrix + self.mobjects_to_copies[mob] = mob.copy() + + def __eq__(self, state: SceneState): + return all( + ( + self.time == state.time, + self.num_plays == state.num_plays, + self.mobjects_to_copies == state.mobjects_to_copies, ) - self.camera_target += total_shift_vector + ) + + def mobjects_match(self, state: SceneState): + return self.mobjects_to_copies == state.mobjects_to_copies + + def n_changes(self, state: SceneState): + m2c = state.mobjects_to_copies + return sum( + 1 - int(mob in m2c and mob.looks_identical(m2c[mob])) + for mob in self.mobjects_to_copies + ) + + def restore_scene(self, scene: Scene): + scene.time = self.time + scene.num_plays = self.num_plays + scene.mobjects = [ + mob.become(mob_copy) for mob, mob_copy in self.mobjects_to_copies.items() + ] - def set_key_function(self, char, func): - self.key_to_function_map[char] = func - def on_mouse_press(self, point, button, modifiers): - for func in self.mouse_press_callbacks: - func() +class EndScene(Exception): + pass