Skip to content

Commit de374b0

Browse files
Merge pull request #128 from AikidoSec/add-allowlist-global
Add support for global IP allowlist (e.g. geo-fencing)
2 parents 4f5b31e + f47b434 commit de374b0

File tree

6 files changed

+199
-23
lines changed

6 files changed

+199
-23
lines changed

agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/ReportingApi.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,16 @@ public ReportingApi(int timeoutInSec) {
2929
*/
3030
public abstract Optional<APIResponse> report(String token, APIEvent event);
3131

32-
public record APIListsResponse(List<ListsResponseEntry> blockedIPAddresses, String blockedUserAgents) {}
32+
public record APIListsResponse(
33+
List<ListsResponseEntry> blockedIPAddresses,
34+
List<ListsResponseEntry> allowedIPAddresses,
35+
String blockedUserAgents
36+
) {}
3337
public record ListsResponseEntry(String source, String description, List<String> ips) {}
3438
/**
3539
* Fetch blocked lists using a separate API call, these can include :
3640
* -> blocked IP Addresses (e.g. geo restrictions)
41+
* -> allowed IP Addresses (e.g. geo restrictions)
3742
* -> blocked User-Agents (e.g. bot blocking)
3843
* @param token the authentication token
3944
*/

agent_api/src/main/java/dev/aikido/agent_api/thread_cache/ThreadCacheObject.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import static dev.aikido.agent_api.helpers.IPListBuilder.createIPList;
1616
import static dev.aikido.agent_api.helpers.UnixTimeMS.getUnixTimeMS;
17+
import static dev.aikido.agent_api.vulnerabilities.ssrf.IsPrivateIP.isPrivateIp;
1718

1819
public class ThreadCacheObject {
1920
private final List<Endpoint> endpoints;
@@ -23,9 +24,10 @@ public class ThreadCacheObject {
2324
private final Hostnames hostnames;
2425
private final Routes routes;
2526

26-
// IP Blocking (e.g. Geo-IP Restrictions) :
27-
public record BlockedIpEntry(IPList blocklist, String description) {}
28-
private List<BlockedIpEntry> blockedIps = new ArrayList<>();
27+
// IP restrictions (e.g. Geo-IP Restrictions) :
28+
public record IPListEntry(IPList ipList, String description) {}
29+
private List<IPListEntry> blockedIps = new ArrayList<>();
30+
private List<IPListEntry> allowedIps = new ArrayList<>();
2931
// User-Agent Blocking (e.g. bot blocking) :
3032
private Pattern blockedUserAgentRegex;
3133

@@ -72,8 +74,24 @@ public boolean isBypassedIP(String ip) {
7274
* Check if the IP is blocked (e.g. Geo IP Restrictions)
7375
*/
7476
public BlockedResult isIpBlocked(String ip) {
75-
for (BlockedIpEntry entry: blockedIps) {
76-
if (entry.blocklist.matches(ip)) {
77+
// Check for allowed ip addresses (i.e. only one country is allowed to visit the site)
78+
// Always allow access from private IP addresses (those include local IP addresses)
79+
if(allowedIps != null && allowedIps.size() > 0 && !isPrivateIp(ip)) {
80+
boolean ipAllowed = false;
81+
for (IPListEntry entry: allowedIps) {
82+
if (entry.ipList.matches(ip)) {
83+
ipAllowed = true; // We allow IP addresses as long as they match with one of the lists.
84+
break;
85+
}
86+
}
87+
if (!ipAllowed) {
88+
return new BlockedResult(true, "allowlist");
89+
}
90+
}
91+
92+
// Check for blocked ip addresses
93+
for (IPListEntry entry: blockedIps) {
94+
if (entry.ipList.matches(ip)) {
7795
return new BlockedResult(true, entry.description);
7896
}
7997
}
@@ -87,7 +105,14 @@ public void updateBlockedLists(Optional<ReportingApi.APIListsResponse> blockedLi
87105
if (res.blockedIPAddresses() != null) {
88106
for (ReportingApi.ListsResponseEntry entry : res.blockedIPAddresses()) {
89107
IPList ipList = createIPList(entry.ips());
90-
blockedIps.add(new BlockedIpEntry(ipList, entry.description()));
108+
blockedIps.add(new IPListEntry(ipList, entry.description()));
109+
}
110+
}
111+
// Update allowed IP addresses (e.g. for geo restrictions) :
112+
if (res.allowedIPAddresses() != null) {
113+
for (ReportingApi.ListsResponseEntry entry: res.allowedIPAddresses()) {
114+
IPList ipList = createIPList(entry.ips());
115+
this.allowedIps.add(new IPListEntry(ipList, entry.description()));
91116
}
92117
}
93118
// Update Blocked User-Agents regex

agent_api/src/test/java/background/ServiceConfigurationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void updateConfig_ShouldHandleAllUpdates_WhenApiResponseIsComplete() {
104104
@Test
105105
void testSetForBlockedIpRes() {
106106
assertNull(serviceConfiguration.blockedListsRes);
107-
serviceConfiguration.storeBlockedListsRes(Optional.of(new ReportingApi.APIListsResponse(null, null)));
107+
serviceConfiguration.storeBlockedListsRes(Optional.of(new ReportingApi.APIListsResponse(null, null, null)));
108108
assertNotNull(serviceConfiguration.blockedListsRes);
109109
serviceConfiguration.storeBlockedListsRes(Optional.empty());
110110

agent_api/src/test/java/collectors/WebRequestCollectorTest.java

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
class WebRequestCollectorTest {
2525

26-
private ContextObject contextObject;
26+
private EmptySampleContextObject contextObject;
2727
private ThreadCacheObject threadCacheObject;
2828

2929
@BeforeEach
@@ -105,7 +105,7 @@ void testReport_noThreadCacheObject() {
105105
void testReport_ipBlockedTwice() {
106106
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
107107
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.1.1"))
108-
), "");
108+
), null, "");
109109
// Mock ThreadCache
110110
threadCacheObject = new ThreadCacheObject(List.of(new Endpoint(
111111
"GET", "/api/resource", 100, 100,
@@ -126,7 +126,7 @@ void testReport_ipBlockedTwice() {
126126
void testReport_ipBlockedUsingLists() {
127127
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
128128
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("bullshit.ip", "192.168.1.1"))
129-
), "");
129+
), null, "");
130130
// Mock ThreadCache
131131
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(), Set.of(), new Routes(), Optional.of(blockedListsRes));
132132
ThreadCache.set(threadCacheObject);
@@ -138,13 +138,49 @@ void testReport_ipBlockedUsingLists() {
138138
assertEquals("Your IP address is not allowed to access this resource. (Your IP: 192.168.1.1)", response.msg());
139139
assertEquals(403, response.status());
140140
}
141+
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "test-token")
142+
@Test
143+
void testReport_ipNotAllowedUsingLists() {
144+
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(null, List.of(
145+
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.2.1"))
146+
), "");
147+
// Mock ThreadCache
148+
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(), Set.of(), new Routes(), Optional.of(blockedListsRes));
149+
ThreadCache.set(threadCacheObject);
150+
151+
152+
WebRequestCollector.Res response = WebRequestCollector.report(contextObject);
153+
154+
assertNull(response); // Private IP
155+
156+
contextObject.setIp("4.4.4.4");
157+
response = WebRequestCollector.report(contextObject);
158+
assertNotNull(response);
159+
assertEquals("Your IP address is not allowed to access this resource. (Your IP: 4.4.4.4)", response.msg());
160+
assertEquals(403, response.status());
161+
}
162+
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "test-token")
163+
@Test
164+
void testReport_ipInAllowlist() {
165+
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(null, List.of(
166+
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.1.1", "10.0.0.0/24"))
167+
), "");
168+
// Mock ThreadCache
169+
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(), Set.of(), new Routes(), Optional.of(blockedListsRes));
170+
ThreadCache.set(threadCacheObject);
171+
172+
173+
WebRequestCollector.Res response = WebRequestCollector.report(contextObject);
174+
175+
assertNull(response);
176+
}
141177

142178
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "test-token")
143179
@Test
144180
void testReport_ipNotBlockedUsingListsNorUserAgent() {
145181
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
146182
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.1.2", "192.168.1.3"))
147-
), "Unrelated|random");
183+
), null, "Unrelated|random");
148184
// Mock ThreadCache
149185
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(), Set.of(), new Routes(), Optional.of(blockedListsRes));
150186
ThreadCache.set(threadCacheObject);
@@ -160,7 +196,7 @@ void testReport_ipNotBlockedUsingListsNorUserAgent() {
160196
void testReport_userAgentBlocked() {
161197
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
162198
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.1.2", "192.168.1.3"))
163-
), "AI2Bot|hacker");
199+
), null, "AI2Bot|hacker");
164200
// Mock ThreadCache
165201
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(), Set.of(), new Routes(), Optional.of(blockedListsRes));
166202
ThreadCache.set(threadCacheObject);
@@ -178,7 +214,7 @@ void testReport_userAgentBlocked() {
178214
void testReport_userAgentBlocked_Ip_Bypassed() {
179215
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
180216
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("192.168.1.2", "192.168.1.3"))
181-
), "AI2Bot|hacker");
217+
), null, "AI2Bot|hacker");
182218
// Mock ThreadCache
183219
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(),
184220
/* bypassedIps : */ Set.of("192.168.1.1"), new Routes(), Optional.of(blockedListsRes));
@@ -196,7 +232,27 @@ void testReport_userAgentBlocked_Ip_Bypassed() {
196232
void testReport_ipBlockedUsingLists_Ip_Bypassed() {
197233
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(List.of(
198234
new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("bullshit.ip", "192.168.1.1"))
199-
), "");
235+
), null, "");
236+
// Mock ThreadCache
237+
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(),
238+
/* bypassedIps : */ Set.of("192.168.1.1"), new Routes(), Optional.of(blockedListsRes));
239+
ThreadCache.set(threadCacheObject);
240+
241+
242+
WebRequestCollector.Res response = WebRequestCollector.report(contextObject);
243+
244+
assertNull(response);
245+
assertNull(Context.get());
246+
}
247+
248+
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "test-token")
249+
@Test
250+
void testReport_ipNotAllowedUsingLists_Ip_Bypassed() {
251+
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(
252+
null,
253+
List.of(new ReportingApi.ListsResponseEntry("geoip", "geoip restrictions", List.of("1.2.3.4"))),
254+
""
255+
);
200256
// Mock ThreadCache
201257
threadCacheObject = new ThreadCacheObject(List.of(), Set.of(),
202258
/* bypassedIps : */ Set.of("192.168.1.1"), new Routes(), Optional.of(blockedListsRes));

agent_api/src/test/java/thread_cache/ThreadCacheObjectTest.java

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package thread_cache;
22

3-
import dev.aikido.agent_api.background.ServiceConfiguration;
43
import dev.aikido.agent_api.background.cloud.api.ReportingApi;
54
import dev.aikido.agent_api.thread_cache.ThreadCacheObject;
65
import org.junit.jupiter.api.Test;
7-
import utils.EmtpyThreadCacheObject;
86

97
import java.util.List;
108
import java.util.Optional;
@@ -24,7 +22,7 @@ public void update() {
2422
"fd00:3234:5678:9abc::1/64",
2523
"5.6.7.8/32"
2624
))
27-
), "Test|One")));
25+
), null, "Test|One")));
2826

2927
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("1.2.3.4"));
3028
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("2.3.4.5"));
@@ -59,9 +57,9 @@ public void updateEmpty() {
5957
"fd00:3234:5678:9abc::1/64",
6058
"5.6.7.8/32"
6159
))
62-
), "Test|One")));
60+
), null, "Test|One")));
6361

64-
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, null)));
62+
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, null,null)));
6563

6664
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("1.2.3.4"));
6765
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("2.3.4.5"));
@@ -96,16 +94,16 @@ public void updateRegexes() {
9694
"fd00:3234:5678:9abc::1/64",
9795
"5.6.7.8/32"
9896
))
99-
), "Test|One")));
97+
), null, "Test|One")));
10098

101-
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, "")));
99+
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, null, "")));
102100

103101
assertTrue(tCache.isBlockedUserAgent("This is my TEST user agent"));
104102
assertTrue(tCache.isBlockedUserAgent("Test"));
105103
assertTrue(tCache.isBlockedUserAgent("TEst and ONE"));
106104
assertFalse(tCache.isBlockedUserAgent("Est|On"));
107105
assertFalse(tCache.isBlockedUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"));
108-
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, "Mozilla")));
106+
tCache.updateBlockedLists(Optional.of(new ReportingApi.APIListsResponse(null, null, "Mozilla")));
109107

110108
assertFalse(tCache.isBlockedUserAgent("This is my TEST user agent"));
111109
assertFalse(tCache.isBlockedUserAgent("Test"));
@@ -155,4 +153,93 @@ public void testThreadCacheBypassedIPsSubnet() {
155153
assertFalse(tCache.isBypassedIP("10.0.1.1"));
156154
assertFalse(tCache.isBypassedIP("1.2.3.4"));
157155
}
156+
157+
@Test
158+
public void testIsIpBlockedWithAllowedAndBlockedIPs() {
159+
// Create a ThreadCacheObject with both allowed and blocked IPs
160+
ThreadCacheObject tCache = new ThreadCacheObject(null, null, null, null, Optional.of(new ReportingApi.APIListsResponse(List.of(
161+
new ReportingApi.ListsResponseEntry("geoip", "description", List.of(
162+
"1.2.3.4", // Blocked IP
163+
"192.168.1.1" // Blocked IP
164+
))
165+
), List.of(
166+
new ReportingApi.ListsResponseEntry("geoip", "description", List.of(
167+
"10.0.0.1", // Allowed IP
168+
"1.2.3.4"
169+
))
170+
), "Test|One")));
171+
172+
// Test blocked IPs
173+
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("1.2.3.4"));
174+
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("192.168.1.1"));
175+
176+
// Test allowed IPs
177+
/// Private IP :
178+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("10.0.0.2"));
179+
/// Not in allowlist
180+
assertEquals(new ThreadCacheObject.BlockedResult(true, "allowlist"), tCache.isIpBlocked("1.2.3.3"));
181+
}
182+
183+
@Test
184+
public void testIsIpBlockedWithOnlyAllowedIPs() {
185+
// Create a ThreadCacheObject with only allowed IPs
186+
ThreadCacheObject tCache = new ThreadCacheObject(null, null, null, null, Optional.of(new ReportingApi.APIListsResponse(null, List.of(
187+
new ReportingApi.ListsResponseEntry("geoip", "description", List.of(
188+
"10.0.0.1" // Allowed IP
189+
))
190+
), "Test|One")));
191+
192+
// Test allowed IP
193+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("10.0.0.1"));
194+
// Test a non-allowed private-IP
195+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("10.0.0.2"));
196+
// Test a non-allowed IP
197+
assertEquals(new ThreadCacheObject.BlockedResult(true, "allowlist"), tCache.isIpBlocked("1.2.3.4"));
198+
}
199+
200+
@Test
201+
public void testIsIpBlockedWithOnlyBlockedIPs() {
202+
// Create a ThreadCacheObject with only blocked IPs
203+
ThreadCacheObject tCache = new ThreadCacheObject(null, null, null, null, Optional.of(new ReportingApi.APIListsResponse(List.of(
204+
new ReportingApi.ListsResponseEntry("geoip", "description", List.of(
205+
"1.2.3.4", // Blocked IP
206+
"192.168.1.1" // Blocked IP
207+
))
208+
), null, "Test|One")));
209+
210+
// Test blocked IPs
211+
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("1.2.3.4"));
212+
assertEquals(new ThreadCacheObject.BlockedResult(true, "description"), tCache.isIpBlocked("192.168.1.1"));
213+
// Test a non-blocked IP
214+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("10.0.0.1"));
215+
}
216+
217+
@Test
218+
public void testIsIpBlockedWithAllowedIPsAndBlockedIPs() {
219+
// Create a ThreadCacheObject with multiple allowed and blocked IPs
220+
ThreadCacheObject tCache = new ThreadCacheObject(null, null, null, null, Optional.of(new ReportingApi.APIListsResponse(null, List.of(
221+
new ReportingApi.ListsResponseEntry("geoip1", "description", List.of(
222+
"1.2.3.4" // Blocked IP
223+
)),
224+
new ReportingApi.ListsResponseEntry("geoip2", "description", List.of(
225+
"8.8.8.0/24"
226+
)),
227+
new ReportingApi.ListsResponseEntry("geoip3", "description", List.of(
228+
"4.4.4.4" // Another allowed IP
229+
))
230+
), "Test|One")));
231+
232+
// Test allowed IPs
233+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("10.0.0.1"));
234+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("4.4.4.4"));
235+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("1.2.3.4"));
236+
assertEquals(new ThreadCacheObject.BlockedResult(false, null), tCache.isIpBlocked("8.8.8.1"));
237+
238+
// Test a non-allowed IP
239+
assertEquals(new ThreadCacheObject.BlockedResult(true, "allowlist"), tCache.isIpBlocked("4.4.4.1"));
240+
assertEquals(new ThreadCacheObject.BlockedResult(true, "allowlist"), tCache.isIpBlocked("8.8.7.8"));
241+
242+
}
243+
244+
158245
}

agent_api/src/test/java/utils/EmptySampleContextObject.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,7 @@ public EmptySampleContextObject(String argument, String route, String method) {
3333
this.route = route;
3434
this.method = method;
3535
}
36+
public void setIp(String ip) {
37+
this.remoteAddress = ip;
38+
}
3639
}

0 commit comments

Comments
 (0)