Skip to content

Commit 7cb062a

Browse files
committed
WIP: update to bedrock tag parser and move reusable methods
1 parent a1b764a commit 7cb062a

File tree

5 files changed

+147
-116
lines changed

5 files changed

+147
-116
lines changed

embabel-database-agent/src/main/java/com/embabel/database/agent/util/AWSBedrockTagParser.java

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,21 @@
1515
*/
1616
package com.embabel.database.agent.util;
1717

18-
import java.util.Collections;
18+
import java.util.ArrayList;
1919
import java.util.HashMap;
2020
import java.util.List;
2121
import java.util.Map;
2222

23+
import org.apache.commons.logging.Log;
24+
import org.apache.commons.logging.LogFactory;
2325
import org.springframework.beans.factory.annotation.Autowired;
2426

2527
import com.fasterxml.jackson.databind.ObjectMapper;
2628

2729
public class AWSBedrockTagParser implements TagParser {
2830

31+
static final Log logger = LogFactory.getLog(AWSBedrockTagParser.class);
32+
2933
public static final String INPUT_MODALITY_KEY = "inputModalities";
3034
public static final String OUTPUT_MODALITY_KEY = "outputModalities";
3135

@@ -36,27 +40,29 @@ public class AWSBedrockTagParser implements TagParser {
3640
private static final String OUTPUT_IMAGE_VALUE = "IMAGE";
3741
private static final String OUTPUT_EMBEDDING_VALUE = "EMBEDDING";
3842

39-
static final String[] INPUTS = {INPUT_TEXT_VALUE,INPUT_IMAGE_VALUE,INPUT_EMBEDDING_VALUE};
40-
static final String[] OUTPUTS = {OUTPUT_TEXT_VALUE,OUTPUT_IMAGE_VALUE,OUTPUT_EMBEDDING_VALUE};
43+
static final String[] BEDROCK_INPUTS = {INPUT_TEXT_VALUE,INPUT_IMAGE_VALUE,INPUT_EMBEDDING_VALUE};
44+
static final String[] BEDROCK_OUTPUTS = {OUTPUT_TEXT_VALUE,OUTPUT_IMAGE_VALUE,OUTPUT_EMBEDDING_VALUE};
4145

42-
private List<Map<String,Object>> tags;
46+
private List<Map<String,Object>> tasks;
4347

4448
@Autowired
4549
ObjectMapper objectMapper;
4650

4751
@Override
4852
public List<String> getTags(Map<String, Object> attributes) {
49-
String modelCategory = null;
53+
List<String> tags = new ArrayList<>();
5054
//load the categories
51-
if (tags == null || tags.isEmpty()) {
52-
tags = this.getTasks(objectMapper, RESOURCE_LOCATION);
55+
if (tasks == null || tasks.isEmpty()) {
56+
tasks = this.getTasks(objectMapper, RESOURCE_LOCATION);
5357
} //end if
5458
//map contains 2 keys "inputModalities" and "outputModalities"
5559
//values of which correspond to either inputText, inputImage, outputText, outputImage (true)
60+
//presence of the string is equivalent of "true"
5661
Map<String,Object> matches = new HashMap<>();
5762
List<String> inputValues = (List<String>) attributes.get(INPUT_MODALITY_KEY);
5863
List<String> outputValues = (List<String>) attributes.get(OUTPUT_MODALITY_KEY);
59-
for (String value : INPUTS) {
64+
//build a match map
65+
for (String value : BEDROCK_INPUTS) {
6066
//loop
6167
for (String inputValue : inputValues) {
6268
if (inputValue.equalsIgnoreCase(value)) {
@@ -65,7 +71,7 @@ public List<String> getTags(Map<String, Object> attributes) {
6571
} //end if
6672
} //end if
6773
} //end for
68-
for (String value : OUTPUTS) {
74+
for (String value : BEDROCK_OUTPUTS) {
6975
//loop
7076
for (String outputValue : outputValues) {
7177
if (outputValue.equalsIgnoreCase(value)) {
@@ -81,24 +87,18 @@ public List<String> getTags(Map<String, Object> attributes) {
8187
}//end if
8288
} //end for
8389
//now loop and check
84-
for (Map<String,Object> task : tags) {
85-
int matchedCount = 0;
86-
//check
87-
for (String key : matches.keySet()) {
88-
//go through the matches key and check
89-
if (task.get(key).toString().equalsIgnoreCase(matches.get(key).toString())) {
90-
matchedCount++;
91-
} //end if
92-
} //end for
93-
if (matchedCount == MATCH_COUNT) {
94-
//have a winner
95-
modelCategory = task.get(TAG_LABEL).toString();
96-
break;
97-
} //end if
98-
} //end for
99-
return Collections.singletonList(modelCategory);
90+
for (Map<String,Object> category : tasks) {
91+
String exactMatch = getExactMatchTag(category, matches);
92+
if (exactMatch != null) {
93+
tags.add(exactMatch);
94+
} else {
95+
//now we have "sub" tags
96+
tags = getAlternatives(tags, matches, category);
97+
}//end if
98+
} //end for
99+
return tags;
100100
}
101-
101+
102102
String getKey(String value,boolean input) {
103103
if (input) {
104104
if (value.equalsIgnoreCase(INPUT_IMAGE_VALUE)) {

embabel-database-agent/src/main/java/com/embabel/database/agent/util/LlmLeaderboardTagParser.java

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -83,89 +83,13 @@ public List<String> getTags(Map<String, Object> attributes) {
8383
if (exactMatch != null) {
8484
tags.add(exactMatch);
8585
} else {
86-
//now we have "sub" tags
87-
//this is where an input and an output are true but not ALL of the inputs and outpus match
88-
//inputs determine tag, not outputs (restrictions on what can be uploaded)
89-
//process inputs to limit the options
90-
int inputCounts = 0;
91-
int outputCounts = 0;
92-
//limitations of the tag are what can be uploaded, and the expected outputs
93-
//an image-text-to-text can just also be text-to-text and image-to-text, but it can't be text-to-image
94-
int trueOutputs = 0;
95-
int trueInputs = 0;
96-
boolean noMatch = false;
97-
//need to check how output matches
98-
for (Map.Entry<String, Object> entry : category.entrySet()) {
99-
String key = entry.getKey().toLowerCase();
100-
Object value = entry.getValue();
101-
Object matchValue = matches.get(entry.getKey());
102-
103-
// Normalize values once
104-
boolean isTrue = Boolean.parseBoolean(String.valueOf(value));
105-
boolean matchIsTrue = matchValue != null && Boolean.parseBoolean(String.valueOf(matchValue));
106-
107-
if (key.contains(OUTPUT)) {
108-
if (isTrue && matchIsTrue) {
109-
trueOutputs++;
110-
} else if (isTrue && (matchValue == null || matchIsTrue)) {
111-
noMatch = true;
112-
}
113-
} else if (key.contains(INPUT) && isTrue && matchIsTrue) {
114-
trueInputs++;
115-
} //end if
116-
}
117-
if (noMatch) {
118-
continue; //done here
119-
}
120-
121-
for (String input : INPUTS) {
122-
//check if there are ANY matches
123-
if (matches.get(input).toString().equalsIgnoreCase(Boolean.TRUE.toString().toLowerCase())) {
124-
inputCounts++;
125-
} //end if
126-
}//end for
127-
//check
128-
if (inputCounts <= 0) {
129-
//no matches, skip
130-
continue;
131-
} //end if
132-
//check the outputs
133-
for (String output : OUTPUTS) {
134-
//check if there are ANY matches
135-
if (matches.get(output).toString().equalsIgnoreCase(Boolean.TRUE.toString().toLowerCase())) {
136-
outputCounts++;
137-
} //end if
138-
}//end for
139-
if (outputCounts <= 0) {
140-
//no matches, skip
141-
continue;
142-
}//end if
143-
//match counts
144-
if (trueInputs >= 1 && trueOutputs >= 1) {
145-
//can use this tone
146-
tags.add(category.get(TAG_LABEL).toString());
147-
} //end if
86+
//now we have "sub" tags
87+
tags = getAlternatives(tags, matches, category);
14888
} //end if
14989
} //end for
15090
//return
15191
return tags;
15292
}
153-
154-
String getExactMatchTag(Map<String,Object> category,Map<String,Object> matches) {
155-
int matchedCount = 0;
156-
//check
157-
for (String key : matches.keySet()) {
158-
//go through the matches key and check
159-
if (category.get(key).toString().equalsIgnoreCase(matches.get(key).toString())) {
160-
matchedCount++;
161-
} //end if
162-
} //end for
163-
if (matchedCount == MATCH_COUNT) {
164-
//have a winner
165-
return category.get(TAG_LABEL).toString();
166-
}
167-
return null;
168-
}
16993

17094
void loadMap() {
17195
attributeMap = new HashMap<>();

embabel-database-agent/src/main/java/com/embabel/database/agent/util/TagParser.java

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
public interface TagParser {
3333

3434
static final Log logger = LogFactory.getLog(TagParser.class);
35-
35+
3636
public static final int MATCH_COUNT = 8;
37-
37+
3838
public static final String INPUT = "input";
3939
public static final String OUTPUT = "output";
4040
public static final String TAG_LABEL = "tag";
@@ -75,4 +75,86 @@ default List<Map<String,Object>> getTasks(ObjectMapper objectMapper,String resou
7575
//return
7676
return tasks;
7777
}
78+
79+
/**
80+
* reusable function to search for exact matches
81+
*
82+
* @param category
83+
* @param matches
84+
* @return
85+
*/
86+
default String getExactMatchTag(Map<String,Object> category,Map<String,Object> matches) {
87+
int matchedCount = 0;
88+
//check
89+
for (String key : matches.keySet()) {
90+
//go through the matches key and check
91+
if (category.get(key).toString().equalsIgnoreCase(matches.get(key).toString())) {
92+
matchedCount++;
93+
} //end if
94+
} //end for
95+
if (matchedCount == MATCH_COUNT) {
96+
//have a winner
97+
return category.get(TAG_LABEL).toString();
98+
}
99+
return null;
100+
}
101+
102+
/**
103+
* reusable function to look for near matches
104+
* @param tags
105+
* @param matches
106+
* @param category
107+
* @return
108+
*/
109+
default List<String> getAlternatives(List<String> tags,Map<String,Object> matches,Map<String,Object> category) {
110+
//this is where an input and an output are true but not ALL of the inputs and outpus match
111+
//inputs determine tag, not outputs (restrictions on what can be uploaded)
112+
//process inputs to limit the options
113+
int inputCounts = 0;
114+
int outputCounts = 0;
115+
//limitations of the tag are what can be uploaded, and the expected outputs
116+
//an image-text-to-text can just also be text-to-text and image-to-text, but it can't be text-to-image
117+
int trueOutputs = 0;
118+
int trueInputs = 0;
119+
boolean noMatch = false;
120+
//need to check how output matches
121+
for (Map.Entry<String, Object> entry : category.entrySet()) {
122+
String key = entry.getKey().toLowerCase();
123+
Object value = entry.getValue();
124+
Object matchValue = matches.get(entry.getKey());
125+
126+
// Normalize values once
127+
boolean isTrue = Boolean.parseBoolean(String.valueOf(value));
128+
boolean matchIsTrue = matchValue != null && Boolean.parseBoolean(String.valueOf(matchValue));
129+
130+
if (key.contains(OUTPUT)) {
131+
if (isTrue && matchIsTrue) {
132+
trueOutputs++;
133+
} else if (isTrue && (matchValue == null || matchIsTrue)) {
134+
noMatch = true;
135+
}
136+
} else if (key.contains(INPUT) && isTrue && matchIsTrue) {
137+
trueInputs++;
138+
} //end if
139+
}
140+
if (noMatch) return tags;
141+
//check inputs
142+
for (String input : INPUTS) {
143+
Object mv = matches.get(input);
144+
inputCounts += mv != null && mv.toString().equalsIgnoreCase("true") ? 1 : 0;
145+
} //end for
146+
//check
147+
if (inputCounts == 0) return tags;
148+
//check the outputs
149+
for (String output : OUTPUTS) {
150+
Object mv = matches.get(output);
151+
outputCounts += mv != null && mv.toString().equalsIgnoreCase("true") ? 1 : 0;
152+
}
153+
if (outputCounts == 0) return tags;
154+
//match counts
155+
if (trueInputs >= 1 && trueOutputs >= 1) {
156+
tags.add(String.valueOf(category.get(TAG_LABEL)));
157+
}
158+
return tags;
159+
}
78160
}

embabel-database-agent/src/test/java/com/embabel/database/agent/util/AWSBedrockTaskParserTest.java renamed to embabel-database-agent/src/test/java/com/embabel/database/agent/util/AWSBedrockTagParserTest.java

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,62 @@
1616
package com.embabel.database.agent.util;
1717

1818
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
1920

21+
import java.util.Arrays;
2022
import java.util.Collections;
2123
import java.util.HashMap;
2224
import java.util.List;
2325
import java.util.Map;
2426

27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
2529
import org.junit.jupiter.api.Test;
2630
import org.springframework.test.util.ReflectionTestUtils;
2731

2832
import com.fasterxml.jackson.databind.ObjectMapper;
2933

30-
public class AWSBedrockTaskParserTest {
31-
34+
public class AWSBedrockTagParserTest {
35+
36+
private static Log logger = LogFactory.getLog(LlmLeaderboardTagParserTest.class);
3237

3338
@Test
3439
void testGetCategory() throws Exception {
3540
Map<String,Object> map = new HashMap<>();
36-
map.put(AWSBedrockTagParser.INPUT_MODALITY_KEY,Collections.singletonList("TEXT"));
37-
map.put(AWSBedrockTagParser.OUTPUT_MODALITY_KEY,Collections.singletonList("TEXT"));
41+
map.put(AWSBedrockTagParser.INPUT_MODALITY_KEY,Arrays.asList("TEXT"));
42+
map.put(AWSBedrockTagParser.OUTPUT_MODALITY_KEY,Arrays.asList("TEXT"));
3843
//setup
3944
ObjectMapper objectMapper = new ObjectMapper();
4045
AWSBedrockTagParser parser = new AWSBedrockTagParser();
4146
ReflectionTestUtils.setField(parser, "objectMapper", objectMapper);
4247
String expectedCategory = "text-to-text";
4348
//get category
44-
List<String> result = parser.getTags(map);
49+
List<String> results = parser.getTags(map);
4550
//validate
46-
assertEquals(expectedCategory, result.get(0));
4751

52+
assertTrue(results.size() == 5);
53+
boolean found = false;
54+
for (String result : results) {
55+
if (result.equalsIgnoreCase(expectedCategory)) {
56+
found = true;
57+
break;
58+
} //end if
59+
} //end for
60+
assertTrue(found);
61+
//try again
62+
map = new HashMap<>();
63+
map.put(AWSBedrockTagParser.INPUT_MODALITY_KEY,Arrays.asList("TEXT","IMAGE"));
64+
map.put(AWSBedrockTagParser.OUTPUT_MODALITY_KEY,Arrays.asList("TEXT"));
65+
results = parser.getTags(map);
66+
//now test count
67+
assertTrue(results.size() == 6);
68+
found = false;
69+
for (String result : results) {
70+
if (result.equalsIgnoreCase(expectedCategory)) {
71+
found = true;
72+
break;
73+
} //end if
74+
} //end for
75+
assertTrue(found);
4876
}
4977
}

embabel-database-agent/src/test/java/com/embabel/database/agent/util/LlmLeaderboardTagParserTest.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@
1515
*/
1616
package com.embabel.database.agent.util;
1717

18-
import static org.junit.jupiter.api.Assertions.assertEquals;
1918
import static org.junit.jupiter.api.Assertions.assertFalse;
2019
import static org.junit.jupiter.api.Assertions.assertNotNull;
2120
import static org.junit.jupiter.api.Assertions.assertTrue;
22-
import static org.junit.jupiter.api.Assertions.fail;
23-
2421
import java.util.List;
2522
import java.util.Map;
2623

@@ -125,7 +122,7 @@ void testGetCategory() throws Exception {
125122
" }";
126123
map = objectMapper.readValue(model_json,new TypeReference<Map<String,Object>>(){});
127124
results = parser.getTags(map);
128-
logger.info(results);
125+
logger.debug(results);
129126
//validate
130127
assertTrue(results.size() == 6); //expect 3 results
131128
assertTrue(results.contains(expectedCategory));

0 commit comments

Comments
 (0)