Skip to content

Commit 03db52b

Browse files
authored
Merge branch 'main' into ti
2 parents bd225d0 + 99feae7 commit 03db52b

File tree

16 files changed

+962
-228
lines changed

16 files changed

+962
-228
lines changed

.github/workflows/publish-book.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: publish-book
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
8+
jobs:
9+
deploy-book:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
14+
- name: Set up Python 3.10
15+
uses: actions/setup-python@v4
16+
with:
17+
python-version: "3.10"
18+
19+
- name: Install dependencies
20+
run: |
21+
python -m pip install --upgrade pip setuptools
22+
python -m pip install .[docs]
23+
python -m pip install jupyter-book numpydoc
24+
pip install jupyter-book sphinxcontrib-mermaid
25+
26+
- name: Build the book
27+
run: |
28+
jupyter-book build .
29+
30+
- name: GitHub Pages action
31+
uses: peaceiris/[email protected]
32+
with:
33+
github_token: ${{ secrets.GITHUB_TOKEN }}
34+
publish_dir: ./_build/html

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
run: |
3737
python -m pip install --upgrade pip
3838
pip install pytest numpy==1.23.5 tables==3.8.0
39-
pip install deeplabcut==3.0.0rc4
39+
pip install deeplabcut==3.0.0rc8
4040
pip install pytest
4141
pip install pytest-timeout
4242
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ __pycache__
1313
__pycache__/
1414
*.py[cod]
1515
*$py.class
16+
notebooks/.ipynb_checkpoints/
1617

1718
# Binary files
1819
*.jpg
@@ -23,6 +24,7 @@ __pycache__/
2324
# Distribution / packaging
2425
.Python
2526
build/
27+
_build/
2628
develop-eggs/
2729
dist/
2830
downloads/

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
[🛠️ Installation](https://github.com/AdaptiveMotorControlLab/AmadeusGPT?tab=readme-ov-file#install--run-amadeusgpt) |
1515
[🌎 Home Page](http://www.mackenziemathislab.org/amadeusgpt) |
16+
[🚀 Demos & Docs](https://adaptivemotorcontrollab.github.io/AmadeusGPT/README.html) |
1617
[🚨 News](https://github.com/AdaptiveMotorControlLab/AmadeusGPT?tab=readme-ov-file#news) |
1718
[🪲 Reporting Issues](https://github.com/AdaptiveMotorControlLab/AmadeusGPT/issues) |
1819
[💬 Discussions!](https://github.com/AdaptiveMotorControlLab/AmadeusGPT/discussions)
@@ -157,6 +158,7 @@ AmadeusGPT is license under the Apache-2.0 license.
157158
- If you already have keypoint file corresponding to the video file, look how we set-up the config file in the Notebooks. Right now we only support keypoint output from DeepLabCut.
158159

159160
## News
161+
- June 2025 [v0.1.3](https://pypi.org/project/amadeusgpt/0.1.3/) is out, and we introduce new demo docs!
160162
- July 2024 [v0.1.1](https://github.com/AdaptiveMotorControlLab/AmadeusGPT/releases/tag/v0.1.1) is released! This is a major code update ...
161163
- June 2024 as part of the CZI EOSS, The Kavli Foundation now supports this work! ✨
162164
- 🤩 Dec 2023, code released!

_config.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
title: AmadeusGPT
2+
author: MLAI
3+
logo: docs/logo.png
4+
only_build_toc_files: true
5+
6+
sphinx:
7+
config:
8+
autodoc_mock_imports: list #["wx"]
9+
extra_extensions:
10+
- numpydoc
11+
12+
execute:
13+
execute_notebooks: "off"
14+
15+
html:
16+
extra_navbar: ""
17+
use_issues_button: true
18+
use_repository_button: true
19+
extra_footer: |
20+
<div>Powered by <a href="https://jupyterbook.org/">Jupyter Book</a>.</div>
21+
22+
repository:
23+
url: https://github.com/AdaptiveMotorControlLab/AmadeusGPT
24+
path_to_book: main
25+
branch: main
26+
27+
launch_buttons:
28+
colab_url: "https://colab.research.google.com/github/AdaptiveMotorControlLab/AmadeusGPT/blob/main/examples/yourdemo.ipynb"

_toc.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
format: jb-book
2+
root: README
3+
parts:
4+
- caption: Using AmadeusGPT
5+
chapters:
6+
- file: notebooks/EPM_demo
7+
- file: notebooks/Horse_demo
8+
- file: notebooks/MABe_demo
9+
- file: notebooks/MausHaus_demo
10+
- file: notebooks/Use_Task_Program
11+
- file: notebooks/YourData

amadeusgpt/analysis_objects/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ def _superanimal_inference(
2020
):
2121
import deeplabcut
2222

23+
# Patch for PyTorch 2.6 weights_only issue
24+
from amadeusgpt.utils import patch_pytorch_weights_only
25+
patch_pytorch_weights_only()
26+
2327
progress_obj = st.progress(0)
2428
deeplabcut.video_inference_superanimal(
2529
[video_file_path],

amadeusgpt/managers/animal_manager.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def __init__(self, identifier: Identifier):
9696
self.full_keypoint_names = []
9797
self.superanimal_predicted_video = None
9898
self.superanimal_name = None
99+
self.model_name = None
100+
self.detector_name = None
99101
self.init_pose()
100102

101103
def configure_animal_from_meta(self, meta_info):
@@ -106,11 +108,17 @@ def configure_animal_from_meta(self, meta_info):
106108
self.max_individuals = int(meta_info["individuals"])
107109
species = meta_info["species"]
108110
if species == "topview_mouse":
109-
self.superanimal_name = "superanimal_topviewmouse_hrnetw32"
111+
self.superanimal_name = "superanimal_topviewmouse"
112+
self.model_name = "hrnet_w32"
113+
self.detector_name = "fasterrcnn_resnet50_fpn_v2"
110114
elif species == "sideview_quadruped":
111-
self.superanimal_name = "superanimal_quadruped_hrnetw32"
115+
self.superanimal_name = "superanimal_quadruped"
116+
self.model_name = "hrnet_w32"
117+
self.detector_name = "fasterrcnn_resnet50_fpn_v2"
112118
else:
113119
self.superanimal_name = None
120+
self.model_name = None
121+
self.detector_name = None
114122

115123
def init_pose(self):
116124

@@ -304,20 +312,30 @@ def get_keypoints(self) -> ndarray:
304312
from deeplabcut.modelzoo.video_inference import \
305313
video_inference_superanimal
306314

315+
# Patch for PyTorch 2.6+ weights_only issue
316+
from amadeusgpt.utils import patch_pytorch_weights_only
317+
patch_pytorch_weights_only()
318+
307319
video_suffix = Path(self.video_file_path).suffix
308320

309321
self.keypoint_file_path = self.video_file_path.replace(
310-
video_suffix, "_" + self.superanimal_name + ".h5"
322+
video_suffix, f"_superanimal_{self.superanimal_name.split('_', 1)[1]}_{self.detector_name}_{self.model_name}.h5"
311323
)
312324
self.superanimal_predicted_video = self.keypoint_file_path.replace(
313325
".h5", "_labeled.mp4"
314326
)
315327

316328
if not os.path.exists(self.keypoint_file_path):
317329
print(f"going to inference video with {self.superanimal_name}")
330+
if self.model_name is None:
331+
raise ValueError("Model name not set. Please call configure_animal_from_meta first.")
332+
if self.detector_name is None:
333+
raise ValueError("Detector name not set. Please call configure_animal_from_meta first.")
318334
video_inference_superanimal(
319335
videos=[self.video_file_path],
320336
superanimal_name=self.superanimal_name,
337+
model_name=self.model_name,
338+
detector_name=self.detector_name,
321339
max_individuals=self.max_individuals,
322340
video_adapt=False,
323341
)

amadeusgpt/utils/__init__.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import ast
2+
import inspect
3+
import sys
4+
import time
5+
import traceback
6+
from collections import defaultdict
7+
import textwrap
8+
import numpy as np
9+
from amadeusgpt.analysis_objects.event import Event
10+
from amadeusgpt.logger import AmadeusLogger
11+
from IPython.display import Markdown, Video, display, HTML
12+
13+
def filter_kwargs_for_function(func, kwargs):
    """Return only the entries of *kwargs* that name a parameter of *func*."""
    accepted = inspect.signature(func).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}
16+
17+
def timer_decorator(func):
    """Decorator that logs and prints the wall-clock execution time of *func*.

    The timing message is both sent to ``AmadeusLogger.debug`` and printed,
    matching the original behavior. ``functools.wraps`` is applied so the
    wrapped function keeps its ``__name__``/``__doc__`` (the original wrapper
    hid them, which breaks introspection such as ``inspect.signature``).
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # before calling the function
        result = func(*args, **kwargs)  # call the function
        end_time = time.time()  # after calling the function
        # Build the message once instead of duplicating the f-string.
        message = (
            f"The function {func.__name__} took {end_time - start_time} seconds to execute."
        )
        AmadeusLogger.debug(message)
        print(message)
        return result

    return wrapper
30+
31+
def parse_error_message_from_python():
    """Format the exception currently being handled as a full traceback string."""
    exc_info = sys.exc_info()
    return "".join(traceback.format_exception(*exc_info))
37+
38+
def validate_openai_api_key(key):
    """Return True iff *key* authenticates against the OpenAI API.

    Non-authentication errors (e.g. network failures) propagate to the caller.
    """
    import openai

    openai.api_key = key
    try:
        openai.models.list()
    except openai.AuthenticationError:
        return False
    else:
        return True
46+
47+
def flatten_tuple(t):
    """Recursively flatten nested tuples into one flat tuple.

    Used to handle function returns. Non-tuple elements (including lists)
    are kept as-is.
    """

    def _walk(items):
        for item in items:
            if isinstance(item, tuple):
                yield from _walk(item)
            else:
                yield item

    return tuple(_walk(t))
58+
59+
def _strip_doc_and_decorators(func_def):
    """Remove the leading docstring expression and all decorators from an
    ``ast.FunctionDef`` in place, then return its unparsed source.

    Returns None on Python < 3.9 where ``ast.unparse`` is unavailable.
    """
    if (
        func_def.body
        and isinstance(func_def.body[0], ast.Expr)
        # ast.Constant covers string constants on 3.8+; the deprecated
        # ast.Str check in the original was redundant.
        and isinstance(func_def.body[0].value, ast.Constant)
    ):
        func_def.body.pop(0)
    func_def.decorator_list = []
    if hasattr(ast, "unparse"):
        return ast.unparse(func_def)
    return None


def func2json(func):
    """Summarize a function as a JSON-serializable dict.

    Parameters
    ----------
    func : str or callable
        Either the source code of a single function definition, or a
        function object whose source is retrievable via ``inspect``.

    Returns
    -------
    dict
        Keys: ``name``, ``inputs`` (empty string for the source-string path,
        since parameter annotations are not resolved there), ``source_code``
        (docstring and decorators stripped; None on Python < 3.9),
        ``docstring``, and ``return`` (the return annotation).
    """
    if isinstance(func, str):
        func_str = textwrap.dedent(func)
        func_def = ast.parse(func_str).body[0]
        docstring = ast.get_docstring(func_def)
        source = _strip_doc_and_decorators(func_def)
        return_annotation = "No return annotation"
        if func_def.returns:
            return_annotation = ast.unparse(func_def.returns)
        return {
            "name": func_def.name,
            "inputs": "",
            "source_code": source,
            "docstring": docstring,
            "return": return_annotation,
        }

    sig = inspect.signature(func)
    inputs = {name: str(param.annotation) for name, param in sig.parameters.items()}
    docstring = inspect.getdoc(func)
    if docstring:
        docstring = textwrap.dedent(docstring)
    func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
    source = _strip_doc_and_decorators(func_def)
    return {
        "name": func.__name__,
        # Guard against source being None (original called dedent(None) and
        # would crash on Python < 3.9).
        "source_code": textwrap.dedent(source) if source is not None else None,
        "inputs": inputs,
        "docstring": docstring,
        "return": str(sig.return_annotation),
    }
116+
117+
class QA_Message:
    """Container for a single query/answer round trip.

    Holds the user query, the videos it refers to, the generated code and
    chain of thought, plus per-video collections of outputs (errors, plots,
    videos, function returns).
    """

    def __init__(self, query: str, video_file_paths: list[str]):
        self.query = query
        self.video_file_paths = video_file_paths
        self.code = None
        self.chain_of_thought = None
        # Per-video collections, keyed by video file path.
        self.error_message = defaultdict(list)
        self.plots = defaultdict(list)
        self.out_videos = defaultdict(list)
        self.pose_video = defaultdict(list)
        self.function_rets = defaultdict(list)
        self.meta_info = {}

    def get_masks(self) -> dict[str, np.ndarray]:
        """Return per-video boolean mask arrays built from Event returns.

        Videos whose function returns are not a non-empty list of Event
        objects map to None.
        """
        masks_by_video = {}
        for video_path, rets in self.function_rets.items():
            holds_events = (
                isinstance(rets, list) and len(rets) > 0 and isinstance(rets[0], Event)
            )
            if holds_events:
                masks_by_video[video_path] = np.array(
                    [event.generate_mask() for event in rets]
                )
            else:
                masks_by_video[video_path] = None
        return masks_by_video

    def serialize_qa_message(self):
        """Return a plain-dict view of this message (plots are not serialized)."""
        return {
            "query": self.query,
            "video_file_paths": self.video_file_paths,
            "code": self.code,
            "chain_of_thought": self.chain_of_thought,
            "error_message": self.error_message,
            "plots": None,
            "out_videos": self.out_videos,
            "pose_video": self.pose_video,
            "function_rets": self.function_rets,
            "meta_info": self.meta_info,
        }
155+
def create_qa_message(query: str, video_file_paths: list[str]) -> QA_Message:
    """Factory helper: build a fresh QA_Message for *query* over *video_file_paths*."""
    return QA_Message(query=query, video_file_paths=video_file_paths)
157+
def parse_result(amadeus, qa_message, use_ipython=True, skip_code_execution=False):
    """Execute the code attached to a QA message and display its results.

    Parameters
    ----------
    amadeus : object
        Must expose a ``sandbox`` attribute providing ``code_execution`` and
        ``render_qa_message`` (semantics defined elsewhere in the project).
    qa_message : QA_Message
        Message holding the chain of thought, generated code and outputs.
    use_ipython : bool, optional
        When True, render via IPython ``display`` (Markdown/Video/HTML);
        otherwise fall back to plain ``print``.
    skip_code_execution : bool, optional
        When True, neither execute nor re-render the generated code.

    Returns
    -------
    QA_Message
        The (possibly updated) message.
    """
    # Show the LLM's chain of thought first.
    if use_ipython:
        display(Markdown(qa_message.chain_of_thought))
    else:
        print(qa_message.chain_of_thought)
    sandbox = amadeus.sandbox
    if not skip_code_execution:
        # NOTE(review): rendering is skipped together with execution here —
        # confirm that is intended when skip_code_execution is True.
        qa_message = sandbox.code_execution(qa_message)
        qa_message = sandbox.render_qa_message(qa_message)
    if len(qa_message.out_videos) > 0:
        print(f"videos generated to {qa_message.out_videos}")
        print(
            "Open it with media player if it does not properly display in the notebook"
        )
    # Embed each generated event video inline in the notebook.
    if use_ipython:
        if len(qa_message.out_videos) > 0:
            for identifier, event_videos in qa_message.out_videos.items():
                for event_video in event_videos:
                    display(Video(event_video, embed=True))
    if use_ipython:
        # Imported lazily so matplotlib is only required in notebook contexts.
        from matplotlib.animation import FuncAnimation
        if len(qa_message.function_rets) > 0:
            for identifier, rets in qa_message.function_rets.items():
                # Normalize a scalar return into a list for uniform handling.
                if not isinstance(rets, (tuple, list)):
                    rets = [rets]
                for ret in rets:
                    if isinstance(ret, FuncAnimation):
                        # Animations need jshtml embedding to play inline.
                        display(HTML(ret.to_jshtml()))
                    else:
                        display(Markdown(str(qa_message.function_rets[identifier])))
    return qa_message
188+
189+
def patch_pytorch_weights_only():
    """
    Patch for PyTorch 2.6 weights_only issue with DeepLabCut SuperAnimal models.
    Registers ruamel.yaml's ScalarFloat as a safe global so checkpoints
    containing it can be loaded.
    Only applies the patch if torch.serialization.add_safe_globals exists
    (PyTorch >= 2.6); silently does nothing when torch or ruamel.yaml is
    missing (best-effort).
    """
    try:
        import torch
        from ruamel.yaml.scalarfloat import ScalarFloat
    except ImportError:
        # ruamel.yaml (or torch) not available — continue without the patch.
        return
    if hasattr(torch.serialization, "add_safe_globals"):
        torch.serialization.add_safe_globals([ScalarFloat])

docs/logo.png

295 KB
Loading

0 commit comments

Comments
 (0)