
Commit 4535bb8

committed
evaluate
1 parent 40c6fe4 commit 4535bb8

7 files changed, +894 -0 lines changed

evaluate/gpt_evaluation_script.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import re
import json
import time
import argparse

from tqdm import tqdm
from openai import OpenAI

client = OpenAI()


def load_jsonl(file_path):
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]


def get_gpt_scores(prediction_jsonl_path, ground_truth_jsonl_path, output_jsonl_path, gpt_model):
    # Load the ground truths
    ground_truths = load_jsonl(ground_truth_jsonl_path)

    # Create a dictionary for easy access to ground truths
    gt_dict = {item['question_id']: item for item in ground_truths}

    # Process each prediction
    predictions = load_jsonl(prediction_jsonl_path)

    with open(output_jsonl_path, 'w') as out_file:
        for item in tqdm(predictions, desc='Evaluating (press Ctrl+C if stuck)', dynamic_ncols=True):
            question_id = item['question_id']
            prediction_text = str(item.get('model_output', ""))

            gt_item = gt_dict.get(question_id, {})
            gt_answer = str(gt_item.get('answer', ""))
            gt_question = gt_item.get('prompt')

            print(f"question_id: {question_id}, prediction_text: {prediction_text}, gt_answer: {gt_answer}")
            if not prediction_text or not gt_answer:
                print(f"Skipping question_id {question_id} due to empty prediction_text or gt_answer.")
                continue

            retries = 0
            max_retries = 3
            while retries < max_retries:
                # Build the grading prompt for the GPT judge.
                question = f"""Compare the ground truth and prediction from AI models, to give a correctness score for the prediction. Ignore case and singular/plural grammar problems, and consider whether the meaning is similar. If the meaning is similar, it deserves full marks. A '/' in the ground truth indicates that there are multiple valid responses to the question, with full marks for any one answer. The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right).
Example:
Question | Ground truth | Prediction | Correctness
--- | --- | --- | ---
How many apples are here? | 10 | 7 | 0.0
How many apples are here? | 10 | 10 | 1.0
What are keeping the elephants in their area? | bars / fence / fences / cage | fence | 1.0
What are keeping the elephants in their area? | bars / fence / fences / cage | They are stuck in the cage. | 1.0
Identify the relevant traffic signal for the ego-vehicle's current path | None | Green | 0.0
Identify the relevant traffic signal for the ego-vehicle's current path | Green Light | Red | 0.0
Identify the relevant traffic signal for the ego-vehicle's current path | Green Light | Green | 1.0
What can the organ with black color in this image be used for? | breathe | Breathing. | 1.0

Here is the QA pair you need to compare and score
Question: {gt_question}
Ground Truth: {gt_answer}
Prediction: {prediction_text}
Score:

Provide only the numerical correctness score as the output.
"""

                try:
                    response = client.chat.completions.create(
                        model=gpt_model,
                        max_tokens=64,
                        messages=[{"role": "user", "content": question}],
                        timeout=10,
                    )
                except Exception:
                    # Request failed (e.g. rate limit or timeout); back off before retrying.
                    print("Request failed, sleeping 30s before retrying.")
                    time.sleep(30)
                else:
                    # Parse the numeric score from the judge's reply and write the result out.
                    model_response = response.choices[0].message.content
                    print(f"model_response: {model_response}")
                    try:
                        score_matches = re.findall(r"(\d+(\.\d+)?)", model_response)
                        if score_matches:
                            if len(score_matches) > 1:
                                raise ValueError(f"Multiple numbers detected: {model_response}")

                            score = float(score_matches[0][0])
                            print(f"score: {score}")
                            if 0 <= score <= 1:
                                result = {
                                    'question_id': question_id,
                                    'image': gt_item.get('image', ''),
                                    'model_response': score
                                }
                                out_file.write(json.dumps(result) + '\n')
                                break
                            else:
                                raise ValueError(f"Invalid response format: {model_response}")
                    except ValueError:
                        # Unparseable or out-of-range score; fall through to retry.
                        pass

                retries += 1
                if retries == max_retries:
                    print(f"Failed to get a valid score after {max_retries} attempts for question_id {question_id}.")


# Example invocation:
# get_gpt_scores("/workspace/LLaVA/Zirui/Results/llava_1.5/llava_1.5_13B_orignal.jsonl", "/workspace/LLaVA/Zirui/jsonl/llava/Benckmark_LLaVA_style.jsonl", "/workspace/LLaVA/Zirui/evaluate/score/oringal_score_LLaVA_1.5_13B.jsonl", "gpt-4-0613")


def main():
    parser = argparse.ArgumentParser(description='Evaluate predictions using GPT.')
    parser.add_argument('--prediction_jsonl_path', type=str, required=True, help='Path to the prediction JSONL file.')
    parser.add_argument('--ground_truth_jsonl_path', type=str, required=True, help='Path to the ground truth JSONL file.')
    parser.add_argument('--output_jsonl_path', type=str, required=True, help='Path to save the output JSONL file.')
    parser.add_argument('--gpt_model', type=str, required=True, help='GPT model to use for evaluation.')

    args = parser.parse_args()
    get_gpt_scores(args.prediction_jsonl_path, args.ground_truth_jsonl_path, args.output_jsonl_path, args.gpt_model)


if __name__ == '__main__':
    main()
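
The expected record shapes follow from the fields the script reads: each prediction line needs question_id and model_output, and each ground-truth line needs question_id, answer, prompt, and optionally image. Below is a minimal sketch of preparing matching files and calling the evaluator programmatically; the file names and the sample record are hypothetical, the model name is taken from the commented example above, and OPENAI_API_KEY is assumed to be set so that OpenAI() can initialize.

import json

# Assumes this runs from the evaluate/ directory so the module is importable.
from gpt_evaluation_script import get_gpt_scores

# Hypothetical sample records matching the fields get_gpt_scores reads.
ground_truths = [{"question_id": 1, "prompt": "How many apples are here?", "answer": "10", "image": "apples.jpg"}]
predictions = [{"question_id": 1, "model_output": "10"}]

with open("ground_truth.jsonl", "w") as f:
    f.writelines(json.dumps(row) + "\n" for row in ground_truths)
with open("predictions.jsonl", "w") as f:
    f.writelines(json.dumps(row) + "\n" for row in predictions)

# Equivalent to the CLI:
#   python gpt_evaluation_script.py --prediction_jsonl_path predictions.jsonl \
#     --ground_truth_jsonl_path ground_truth.jsonl --output_jsonl_path scores.jsonl --gpt_model gpt-4-0613
get_gpt_scores("predictions.jsonl", "ground_truth.jsonl", "scores.jsonl", "gpt-4-0613")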

evaluate/gpt_evaluation_script_AD.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import re
import json
import time
import argparse

from tqdm import tqdm
from openai import OpenAI

client = OpenAI()


def load_jsonl(file_path):
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]


def get_gpt_scores(prediction_jsonl_path, ground_truth_jsonl_path, output_jsonl_path, gpt_model):
    # Load the ground truths
    ground_truths = load_jsonl(ground_truth_jsonl_path)

    # Create a dictionary for easy access to ground truths
    gt_dict = {item['question_id']: item for item in ground_truths}

    # Process each prediction
    predictions = load_jsonl(prediction_jsonl_path)

    with open(output_jsonl_path, 'w') as out_file:
        for item in tqdm(predictions, desc='Evaluating (press Ctrl+C if stuck)', dynamic_ncols=True):
            question_id = item['question_id']
            prediction_text = str(item.get('model_output', ""))

            gt_item = gt_dict.get(question_id, {})
            gt_answer = str(gt_item.get('answer', ""))
            gt_question = gt_item.get('prompt')

            print(f"question_id: {question_id}, prediction_text: {prediction_text}, gt_answer: {gt_answer}")
            if not prediction_text or not gt_answer:
                print(f"Skipping question_id {question_id} due to empty prediction_text or gt_answer.")
                continue

            retries = 0
            max_retries = 3
            while retries < max_retries:
                # Build the grading prompt, with extra autonomous-driving (AD) few-shot examples.
                question = f"""Compare the ground truth and prediction from AI models, to give a correctness score for the prediction. Ignore case and singular/plural grammar problems, and consider whether the meaning is similar. If the meaning is similar, it deserves full marks. A '/' in the ground truth indicates that there are multiple valid responses to the question, with full marks for any one answer. The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right).
Example:
Question | Ground truth | Prediction | Correctness
--- | --- | --- | ---
Identify the relevant traffic signal for the ego-vehicle's current path\nAnswer the question using a single word or phrase. | Green | None | 0.0
Identify the relevant traffic signal for the ego-vehicle's current path\nAnswer the question using a single word or phrase. | green. | Green Light | 1.0
Identify the relevant traffic signal for the ego-vehicle's current path\nAnswer the question using a single word or phrase. | Green Light | Green | 1.0
Identify the relevant traffic signal for the ego-vehicle's current path\nAnswer the question using a single word or phrase. | green. | Green | 1.0
Considering the objects such as car and traffic light from the visual language dataset categories, what is the intended behavior or action for the main vehicle in an autonomous driving scenario?\nAnswer the question using a single word or phrase. | stop | Stop | 1.0
How many apples are here? | 10 | 10 | 1.0
What are keeping the elephants in their area? | bars / fence / fences / cage | fence | 1.0
What are keeping the elephants in their area? | bars / fence / fences / cage | They are stuck in the cage. | 1.0
Identify the relevant traffic signal for the ego-vehicle's current path | None | Green | 0.0
Identify the relevant traffic signal for the ego-vehicle's current path | Green Light | Red | 0.0
Identify the relevant traffic signal for the ego-vehicle's current path | Green Light | Green | 1.0
What can the organ with black color in this image be used for? | breathe | Breathing. | 1.0

Here is the QA pair you need to compare and score
Question: {gt_question}
Ground Truth: {gt_answer}
Prediction: {prediction_text}
Score:

Provide only the numerical correctness score as the output.
"""

                try:
                    response = client.chat.completions.create(
                        model=gpt_model,
                        max_tokens=64,
                        messages=[{"role": "user", "content": question}],
                        timeout=10,
                    )
                except Exception:
                    # Request failed (e.g. rate limit or timeout); back off before retrying.
                    print("Request failed, sleeping 30s before retrying.")
                    time.sleep(30)
                else:
                    # Parse the numeric score from the judge's reply and write the result out.
                    model_response = response.choices[0].message.content
                    print(f"model_response: {model_response}")
                    try:
                        score_matches = re.findall(r"(\d+(\.\d+)?)", model_response)
                        if score_matches:
                            if len(score_matches) > 1:
                                raise ValueError(f"Multiple numbers detected: {model_response}")

                            score = float(score_matches[0][0])
                            print(f"score: {score}")
                            if 0 <= score <= 1:
                                result = {
                                    'question_id': question_id,
                                    'image': gt_item.get('image', ''),
                                    'model_response': score
                                }
                                out_file.write(json.dumps(result) + '\n')
                                break
                            else:
                                raise ValueError(f"Invalid response format: {model_response}")
                    except ValueError:
                        # Unparseable or out-of-range score; fall through to retry.
                        pass

                retries += 1
                if retries == max_retries:
                    print(f"Failed to get a valid score after {max_retries} attempts for question_id {question_id}.")


# Example invocation:
# get_gpt_scores("/workspace/LLaVA/Zirui/Results/llava_1.5/llava_1.5_13B_orignal.jsonl", "/workspace/LLaVA/Zirui/jsonl/llava/Benckmark_LLaVA_style.jsonl", "/workspace/LLaVA/Zirui/evaluate/score/oringal_score_LLaVA_1.5_13B.jsonl", "gpt-4-0613")


def main():
    parser = argparse.ArgumentParser(description='Evaluate predictions using GPT.')
    parser.add_argument('--prediction_jsonl_path', type=str, required=True, help='Path to the prediction JSONL file.')
    parser.add_argument('--ground_truth_jsonl_path', type=str, required=True, help='Path to the ground truth JSONL file.')
    parser.add_argument('--output_jsonl_path', type=str, required=True, help='Path to save the output JSONL file.')
    parser.add_argument('--gpt_model', type=str, required=True, help='GPT model to use for evaluation.')

    args = parser.parse_args()
    get_gpt_scores(args.prediction_jsonl_path, args.ground_truth_jsonl_path, args.output_jsonl_path, args.gpt_model)


if __name__ == '__main__':
    main()
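
Both scripts write one JSON line per question with the numeric score stored under model_response, but neither aggregates the results. A small sketch of how the per-question scores might be averaged afterwards; the scores.jsonl path is hypothetical and should match whatever --output_jsonl_path was used.

import json

def average_score(score_jsonl_path):
    # Average the per-question scores written by get_gpt_scores.
    with open(score_jsonl_path) as f:
        scores = [json.loads(line)["model_response"] for line in f]
    return sum(scores) / len(scores) if scores else 0.0

print(f"mean GPT correctness score: {average_score('scores.jsonl'):.3f}")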
