1
+ import ast
2
+ import inspect
3
+ import sys
4
+ import time
5
+ import traceback
6
+ from collections import defaultdict
7
+ import textwrap
8
+ import numpy as np
9
+ from amadeusgpt .analysis_objects .event import Event
10
+ from amadeusgpt .logger import AmadeusLogger
11
+ from IPython .display import Markdown , Video , display , HTML
12
+
13
+ def filter_kwargs_for_function (func , kwargs ):
14
+ sig = inspect .signature (func )
15
+ return {k : v for k , v in kwargs .items () if k in sig .parameters }
16
+
17
+ def timer_decorator (func ):
18
+ def wrapper (* args , ** kwargs ):
19
+ start_time = time .time () # before calling the function
20
+ result = func (* args , ** kwargs ) # call the function
21
+ end_time = time .time () # after calling the function
22
+ AmadeusLogger .debug (
23
+ f"The function { func .__name__ } took { end_time - start_time } seconds to execute."
24
+ )
25
+ print (
26
+ f"The function { func .__name__ } took { end_time - start_time } seconds to execute."
27
+ )
28
+ return result
29
+ return wrapper
30
+
31
+ def parse_error_message_from_python ():
32
+ exc_type , exc_value , exc_traceback = sys .exc_info ()
33
+ traceback_str = "" .join (
34
+ traceback .format_exception (exc_type , exc_value , exc_traceback )
35
+ )
36
+ return traceback_str
37
+
38
+ def validate_openai_api_key (key ):
39
+ import openai
40
+ openai .api_key = key
41
+ try :
42
+ openai .models .list ()
43
+ return True
44
+ except openai .AuthenticationError :
45
+ return False
46
+
47
+ def flatten_tuple (t ):
48
+ """
49
+ Used to handle function returns
50
+ """
51
+ flattened = []
52
+ for item in t :
53
+ if isinstance (item , tuple ):
54
+ flattened .extend (flatten_tuple (item ))
55
+ else :
56
+ flattened .append (item )
57
+ return tuple (flattened )
58
+
59
+ def func2json (func ):
60
+ if isinstance (func , str ):
61
+ func_str = textwrap .dedent (func )
62
+ parsed = ast .parse (func_str )
63
+ func_def = parsed .body [0 ]
64
+ func_name = func_def .name
65
+ docstring = ast .get_docstring (func_def )
66
+ if (
67
+ func_def .body
68
+ and isinstance (func_def .body [0 ], ast .Expr )
69
+ and isinstance (func_def .body [0 ].value , (ast .Str , ast .Constant ))
70
+ ):
71
+ func_def .body .pop (0 )
72
+ func_def .decorator_list = []
73
+ if hasattr (ast , "unparse" ):
74
+ source_without_docstring_or_decorators = ast .unparse (func_def )
75
+ else :
76
+ source_without_docstring_or_decorators = None
77
+ return_annotation = "No return annotation"
78
+ if func_def .returns :
79
+ return_annotation = ast .unparse (func_def .returns )
80
+ json_obj = {
81
+ "name" : func_name ,
82
+ "inputs" : "" ,
83
+ "source_code" : source_without_docstring_or_decorators ,
84
+ "docstring" : docstring ,
85
+ "return" : return_annotation ,
86
+ }
87
+ return json_obj
88
+ else :
89
+ sig = inspect .signature (func )
90
+ inputs = {name : str (param .annotation ) for name , param in sig .parameters .items ()}
91
+ docstring = inspect .getdoc (func )
92
+ if docstring :
93
+ docstring = textwrap .dedent (docstring )
94
+ full_source = inspect .getsource (func )
95
+ parsed = ast .parse (textwrap .dedent (full_source ))
96
+ func_def = parsed .body [0 ]
97
+ if (
98
+ func_def .body
99
+ and isinstance (func_def .body [0 ], ast .Expr )
100
+ and isinstance (func_def .body [0 ].value , (ast .Str , ast .Constant ))
101
+ ):
102
+ func_def .body .pop (0 )
103
+ func_def .decorator_list = []
104
+ if hasattr (ast , "unparse" ):
105
+ source_without_docstring_or_decorators = ast .unparse (func_def )
106
+ else :
107
+ source_without_docstring_or_decorators = None
108
+ json_obj = {
109
+ "name" : func .__name__ ,
110
+ "inputs" : inputs ,
111
+ "source_code" : textwrap .dedent (source_without_docstring_or_decorators ),
112
+ "docstring" : docstring ,
113
+ "return" : str (sig .return_annotation ),
114
+ }
115
+ return json_obj
116
+
117
+ class QA_Message :
118
+ def __init__ (self , query : str , video_file_paths : list [str ]):
119
+ self .query = query
120
+ self .video_file_paths = video_file_paths
121
+ self .code = None
122
+ self .chain_of_thought = None
123
+ self .error_message = defaultdict (list )
124
+ self .plots = defaultdict (list )
125
+ self .out_videos = defaultdict (list )
126
+ self .pose_video = defaultdict (list )
127
+ self .function_rets = defaultdict (list )
128
+ self .meta_info = {}
129
+ def get_masks (self ) -> dict [str , np .ndarray ]:
130
+ ret = {}
131
+ function_rets = self .function_rets
132
+ for video_path , rets in function_rets .items ():
133
+ if isinstance (rets , list ) and len (rets ) > 0 and isinstance (rets [0 ], Event ):
134
+ events = rets
135
+ masks = []
136
+ for event in events :
137
+ masks .append (event .generate_mask ())
138
+ ret [video_path ] = np .array (masks )
139
+ else :
140
+ ret [video_path ] = None
141
+ return ret
142
+ def serialize_qa_message (self ):
143
+ return {
144
+ "query" : self .query ,
145
+ "video_file_paths" : self .video_file_paths ,
146
+ "code" : self .code ,
147
+ "chain_of_thought" : self .chain_of_thought ,
148
+ "error_message" : self .error_message ,
149
+ "plots" : None ,
150
+ "out_videos" : self .out_videos ,
151
+ "pose_video" : self .pose_video ,
152
+ "function_rets" : self .function_rets ,
153
+ "meta_info" : self .meta_info ,
154
+ }
155
+ def create_qa_message (query : str , video_file_paths : list [str ]) -> QA_Message :
156
+ return QA_Message (query , video_file_paths )
157
+ def parse_result (amadeus , qa_message , use_ipython = True , skip_code_execution = False ):
158
+ if use_ipython :
159
+ display (Markdown (qa_message .chain_of_thought ))
160
+ else :
161
+ print (qa_message .chain_of_thought )
162
+ sandbox = amadeus .sandbox
163
+ if not skip_code_execution :
164
+ qa_message = sandbox .code_execution (qa_message )
165
+ qa_message = sandbox .render_qa_message (qa_message )
166
+ if len (qa_message .out_videos ) > 0 :
167
+ print (f"videos generated to { qa_message .out_videos } " )
168
+ print (
169
+ "Open it with media player if it does not properly display in the notebook"
170
+ )
171
+ if use_ipython :
172
+ if len (qa_message .out_videos ) > 0 :
173
+ for identifier , event_videos in qa_message .out_videos .items ():
174
+ for event_video in event_videos :
175
+ display (Video (event_video , embed = True ))
176
+ if use_ipython :
177
+ from matplotlib .animation import FuncAnimation
178
+ if len (qa_message .function_rets ) > 0 :
179
+ for identifier , rets in qa_message .function_rets .items ():
180
+ if not isinstance (rets , (tuple , list )):
181
+ rets = [rets ]
182
+ for ret in rets :
183
+ if isinstance (ret , FuncAnimation ):
184
+ display (HTML (ret .to_jshtml ()))
185
+ else :
186
+ display (Markdown (str (qa_message .function_rets [identifier ])))
187
+ return qa_message
188
+
189
+ def patch_pytorch_weights_only ():
190
+ """
191
+ Patch for PyTorch 2.6 weights_only issue with DeepLabCut SuperAnimal models.
192
+ This adds safe globals to allow loading of ruamel.yaml.scalarfloat.ScalarFloat objects.
193
+ Only applies the patch if torch.serialization.add_safe_globals exists (PyTorch >=2.6).
194
+ """
195
+ try :
196
+ import torch
197
+ from ruamel .yaml .scalarfloat import ScalarFloat
198
+ if hasattr (torch .serialization , "add_safe_globals" ):
199
+ torch .serialization .add_safe_globals ([ScalarFloat ])
200
+ except ImportError :
201
+ pass # If ruamel.yaml is not available, continue without the patch
0 commit comments