Skip to content

Improve Updater signature checks for substantial speedups, add UpdaterWrappers and more Updater aliases #3807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference_index/utilities_misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ Module Index
~utils.tex_file_writing
~utils.tex_templates
typing
~utils.updaters
17 changes: 9 additions & 8 deletions manim/animation/speedmodifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from __future__ import annotations

import inspect
import types
from typing import TYPE_CHECKING, Callable

from numpy import piecewise

from ..animation.animation import Animation, Wait, prepare_animation
from ..animation.composition import AnimationGroup
from ..mobject.mobject import Mobject, _AnimationBuilder
from ..scene.scene import Scene
from manim.animation.animation import Animation, Wait, prepare_animation
from manim.animation.composition import AnimationGroup
from manim.mobject.mobject import Mobject, _AnimationBuilder
from manim.scene.scene import Scene
from manim.utils.updaters import MobjectUpdaterWrapper

if TYPE_CHECKING:
from ..mobject.mobject import Updater
from manim.utils.updaters import MobjectUpdater

__all__ = ["ChangeSpeed"]

Expand Down Expand Up @@ -235,7 +235,7 @@ def get_scaled_total_time(self) -> float:
def add_updater(
cls,
mobject: Mobject,
update_function: Updater,
update_function: MobjectUpdater,
index: int | None = None,
call_updater: bool = False,
):
Expand Down Expand Up @@ -264,7 +264,8 @@ def add_updater(
:class:`.ChangeSpeed`
:meth:`.Mobject.add_updater`
"""
if "dt" in inspect.signature(update_function).parameters:
wrapper = MobjectUpdaterWrapper(update_function)
if wrapper.is_time_based:
mobject.add_updater(
lambda mob, dt: update_function(
mob, ChangeSpeed.dt if ChangeSpeed.is_changing_dt else dt
Expand Down
92 changes: 46 additions & 46 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,24 @@


import copy
import inspect
import itertools as it
import math
import operator as op
import random
import sys
import types
import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from functools import partialmethod, reduce
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal

import numpy as np

from manim import config, logger
from manim.constants import *
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL

from .. import config, logger
from ..constants import *
from ..utils.color import (
from manim.utils.color import (
BLACK,
WHITE,
YELLOW_C,
Expand All @@ -34,14 +32,16 @@
color_gradient,
interpolate_color,
)
from ..utils.exceptions import MultiAnimationOverrideException
from ..utils.iterables import list_update, remove_list_redundancies
from ..utils.paths import straight_path
from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix
from manim.utils.exceptions import MultiAnimationOverrideException
from manim.utils.iterables import list_update, remove_list_redundancies
from manim.utils.paths import straight_path
from manim.utils.space_ops import angle_between_vectors, normalize, rotation_matrix
from manim.utils.updaters import MobjectUpdaterWrapper

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'MobjectUpdaterWrapper' may not be defined if module
manim.utils.updaters
is imported before module
manim.mobject.mobject
, as the
definition
of MobjectUpdaterWrapper occurs after the cyclic
import
of manim.mobject.mobject.

if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from manim.animation.animation import Animation

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'Animation' may not be defined if module
manim.animation.animation
is imported before module
manim.mobject.mobject
, as the
definition
of Animation occurs after the cyclic
import
of manim.mobject.mobject.
'Animation' may not be defined if module
manim.animation.animation
is imported before module
manim.mobject.mobject
, as the
definition
of Animation occurs after the cyclic
import
of manim.mobject.mobject.
'Animation' may not be defined if module
manim.animation.animation
is imported before module
manim.mobject.mobject
, as the
definition
of Animation occurs after the cyclic
import
of manim.mobject.mobject.
'Animation' may not be defined if module
manim.animation.animation
is imported before module
manim.mobject.mobject
, as the
definition
of Animation occurs after the cyclic
import
of manim.mobject.mobject.
'Animation' may not be defined if module
manim.animation.animation
is imported before module
manim.mobject.mobject
, as the
definition
of Animation occurs after the cyclic
import
of manim.mobject.mobject.
from manim.typing import (
FunctionOverride,
ManimFloat,
Expand All @@ -53,12 +53,7 @@
Point3D_Array,
Vector3D,
)

from ..animation.animation import Animation

TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object]
NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object]
Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater
from manim.utils.updaters import MobjectDtUpdater, MobjectUpdater

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'MobjectDtUpdater' may not be defined if module
manim.utils.updaters
is imported before module
manim.mobject.mobject
, as the
definition
of MobjectDtUpdater occurs after the cyclic
import
of manim.mobject.mobject.

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'MobjectUpdater' may not be defined if module
manim.utils.updaters
is imported before module
manim.mobject.mobject
, as the
definition
of MobjectUpdater occurs after the cyclic
import
of manim.mobject.mobject.


class Mobject:
Expand All @@ -71,6 +66,7 @@
Attributes
----------
submobjects : List[:class:`Mobject`]

The contained objects.
points : :class:`numpy.ndarray`
The points of the objects.
Expand All @@ -96,7 +92,7 @@

def __init__(
self,
color: ParsableManimColor | list[ParsableManimColor] = WHITE,
color: ParsableManimColor | Sequence[ParsableManimColor] = WHITE,
name: str | None = None,
dim: int = 3,
target=None,
Expand All @@ -108,7 +104,7 @@
self.z_index = z_index
self.point_hash = None
self.submobjects = []
self.updaters: list[Updater] = []
self.updater_wrappers: Sequence[MobjectUpdaterWrapper] = []
self.updating_suspended = False
self.color = ManimColor.parse(color)

Expand Down Expand Up @@ -865,6 +861,10 @@

# Updating

@property
def updaters(self) -> Sequence[MobjectUpdater]:
return self.get_updaters()

def update(self, dt: float = 0, recursive: bool = True) -> Self:
"""Apply all updaters.

Expand All @@ -891,17 +891,17 @@
"""
if self.updating_suspended:
return self
for updater in self.updaters:
if "dt" in inspect.signature(updater).parameters:
updater(self, dt)
for wrapper in self.updater_wrappers:
if wrapper.is_time_based:
wrapper.updater(self, dt)
else:
updater(self)
wrapper.updater(self)
if recursive:
for submob in self.submobjects:
submob.update(dt, recursive)
return self

def get_time_based_updaters(self) -> list[TimeBasedUpdater]:
def get_time_based_updaters(self) -> Sequence[MobjectDtUpdater]:
"""Return all updaters using the ``dt`` parameter.

The updaters use this parameter as the input for difference in time.
Expand All @@ -918,9 +918,9 @@

"""
return [
updater
for updater in self.updaters
if "dt" in inspect.signature(updater).parameters
wrapper.updater
for wrapper in self.updater_wrappers
if wrapper.is_time_based
]

def has_time_based_updater(self) -> bool:
Expand All @@ -937,11 +937,9 @@
:meth:`get_time_based_updaters`

"""
return any(
"dt" in inspect.signature(updater).parameters for updater in self.updaters
)
return any(wrapper.is_time_based for wrapper in self.updater_wrappers)

def get_updaters(self) -> list[Updater]:
def get_updaters(self) -> Sequence[MobjectUpdater]:
"""Return all updaters.

Returns
Expand All @@ -955,14 +953,14 @@
:meth:`get_time_based_updaters`

"""
return self.updaters
return [wrapper.updater for wrapper in self.updater_wrappers]

def get_family_updaters(self) -> list[Updater]:
def get_family_updaters(self) -> Sequence[MobjectUpdater]:
return list(it.chain(*(sm.get_updaters() for sm in self.get_family())))

def add_updater(
self,
update_function: Updater,
update_function: MobjectUpdater,
index: int | None = None,
call_updater: bool = False,
) -> Self:
Expand Down Expand Up @@ -1026,19 +1024,19 @@
:meth:`remove_updater`
:class:`~.UpdateFromFunc`
"""
wrapper = MobjectUpdaterWrapper(update_function)
if index is None:
self.updaters.append(update_function)
self.updater_wrappers.append(wrapper)
else:
self.updaters.insert(index, update_function)
self.updater_wrappers.insert(index, wrapper)
if call_updater:
parameters = inspect.signature(update_function).parameters
if "dt" in parameters:
update_function(self, 0)
if wrapper.is_time_based:
wrapper.updater(self, 0)
else:
update_function(self)
wrapper.updater(self)
return self

def remove_updater(self, update_function: Updater) -> Self:
def remove_updater(self, update_function: MobjectUpdater) -> Self:
"""Remove an updater.

If the same updater is applied multiple times, every instance gets removed.
Expand All @@ -1061,8 +1059,11 @@
:meth:`get_updaters`

"""
while update_function in self.updaters:
self.updaters.remove(update_function)
self.updater_wrappers = [
wrapper
for wrapper in self.updater_wrappers
if wrapper.updater != update_function
]
return self

def clear_updaters(self, recursive: bool = True) -> Self:
Expand All @@ -1085,7 +1086,7 @@
:meth:`get_updaters`

"""
self.updaters = []
self.updater_wrappers = []
if recursive:
for submob in self.submobjects:
submob.clear_updaters()
Expand Down Expand Up @@ -1116,8 +1117,7 @@

"""
self.clear_updaters()
for updater in mobject.get_updaters():
self.add_updater(updater)
self.updater_wrappers = mobject.updater_wrappers.copy()
return self

def suspend_updating(self, recursive: bool = True) -> Self:
Expand Down
Loading
Loading