diff --git a/agent_api/src/main/java/dev/aikido/agent_api/ShouldBlockRequest.java b/agent_api/src/main/java/dev/aikido/agent_api/ShouldBlockRequest.java index 41dbb4fb..0fc9e2e4 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/ShouldBlockRequest.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/ShouldBlockRequest.java @@ -5,6 +5,8 @@ import dev.aikido.agent_api.ratelimiting.ShouldRateLimit; import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.ServiceConfiguration; +import dev.aikido.agent_api.storage.routes.RoutesStore; +import dev.aikido.agent_api.storage.statistics.StatisticsStore; public final class ShouldBlockRequest { private ShouldBlockRequest() { @@ -34,6 +36,14 @@ public static ShouldBlockRequestResult shouldBlockRequest() { context.getRouteMetadata(), context.getUser(), context.getRemoteAddress() ); if (rateLimitDecision.block()) { + // increment rate-limiting stats both globally and on the route : + StatisticsStore.incrementRateLimited(); + // increment routes stats using method & route from the endpoint (store stats for wildcards, in wildcard route) + RoutesStore.addRouteRateLimitedCount( + rateLimitDecision.rateLimitedEndpoint().getMethod(), + rateLimitDecision.rateLimitedEndpoint().getRoute() + ); + BlockedRequestResult blockedRequestResult = new BlockedRequestResult( "ratelimited", rateLimitDecision.trigger(), context.getRemoteAddress() ); diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebResponseCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebResponseCollector.java index d8ccd832..540843ed 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebResponseCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebResponseCollector.java @@ -28,13 +28,13 @@ public static void report(int statusCode) { return; } - RoutesStore.addRouteHits(routeMetadata); + RoutesStore.addRouteHits(context.getMethod(), context.getRoute()); // check if we need to generate api spec - int hits = RoutesStore.getRouteHits(routeMetadata); + int hits = RoutesStore.getRouteHits(context.getMethod(), context.getRoute()); if (hits <= ANALYSIS_ON_FIRST_X_REQUESTS) { APISpec apiSpec = getApiInfo(context); - RoutesStore.updateApiSpec(routeMetadata, apiSpec); + RoutesStore.updateApiSpec(context.getMethod(), context.getRoute(), apiSpec); } } -} \ No newline at end of file +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/ratelimiting/ShouldRateLimit.java b/agent_api/src/main/java/dev/aikido/agent_api/ratelimiting/ShouldRateLimit.java index 31bbd31f..2482314e 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/ratelimiting/ShouldRateLimit.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/ratelimiting/ShouldRateLimit.java @@ -13,13 +13,20 @@ public final class ShouldRateLimit { private ShouldRateLimit() {} - public record RateLimitDecision(boolean block, String trigger) {} + + public record RateLimitDecision( + boolean block, + String trigger, + Endpoint rateLimitedEndpoint + ) { + } + public static RateLimitDecision shouldRateLimit(RouteMetadata routeMetadata, User user, String remoteAddress) { List endpoints = ServiceConfigStore.getConfig().getEndpoints(); List matches = matchEndpoints(routeMetadata, endpoints); Endpoint rateLimitedEndpoint = getRateLimitedEndpoint(matches, routeMetadata.route()); if (rateLimitedEndpoint == null) { - return new RateLimitDecision(/*block*/false, null); + return new RateLimitDecision(/*block*/false, null, null); } long windowSizeInMS = rateLimitedEndpoint.getRateLimiting().windowSizeInMS(); @@ -29,17 +36,17 @@ public static RateLimitDecision shouldRateLimit(RouteMetadata routeMetadata, Use boolean allowed = RateLimiterStore.isAllowed(key, windowSizeInMS, maxRequests); if (allowed) { // Do not continue to check based on IP if user is present: - return new RateLimitDecision(/*block*/false, null); + return new RateLimitDecision(/*block*/false, null, null); } - return new RateLimitDecision(/*block*/ true, /*trigger*/ "user"); + return new RateLimitDecision(/*block*/ true, /*trigger*/ "user", rateLimitedEndpoint); } if (remoteAddress != null && !remoteAddress.isEmpty()) { String key = rateLimitedEndpoint.getMethod() + ":" + rateLimitedEndpoint.getRoute() + ":ip:" + remoteAddress; boolean allowed = RateLimiterStore.isAllowed(key, windowSizeInMS, maxRequests); if (!allowed) { - return new RateLimitDecision(/*block*/ true, /*trigger*/ "ip"); + return new RateLimitDecision(/*block*/ true, /*trigger*/ "ip", rateLimitedEndpoint); } } - return new RateLimitDecision(/*block*/false, null); + return new RateLimitDecision(/*block*/false, null, null); } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteEntry.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteEntry.java index 9bfe2635..d17ebf60 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteEntry.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteEntry.java @@ -1,17 +1,15 @@ package dev.aikido.agent_api.storage.routes; -import com.google.gson.*; import dev.aikido.agent_api.api_discovery.APISpec; import dev.aikido.agent_api.context.RouteMetadata; -import java.lang.reflect.Type; - import static dev.aikido.agent_api.api_discovery.APISpecMerger.mergeAPISpecs; public class RouteEntry { final String method; final String path; private int hits; + private int rateLimitedCount; private APISpec apispec; public RouteEntry(String method, String path) { @@ -32,6 +30,13 @@ public int getHits() { return hits; } + public void incrementRateLimitCount() { + rateLimitedCount++; + } + + public int getRateLimitCount() { + return rateLimitedCount; + } public void updateApiSpec(APISpec newApiSpec) { this.apispec = mergeAPISpecs(newApiSpec, this.apispec); } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteToKeyHelper.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteToKeyHelper.java index 7b74c500..e4835906 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteToKeyHelper.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RouteToKeyHelper.java @@ -1,11 +1,9 @@ package dev.aikido.agent_api.storage.routes; -import dev.aikido.agent_api.context.RouteMetadata; - public final class RouteToKeyHelper { private RouteToKeyHelper() {} - public static String routeToKey(RouteMetadata routeMetadata) { - return routeMetadata.method() + ":" + routeMetadata.route(); + public static String routeToKey(String method, String route) { + return method + ":" + route; } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/Routes.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/Routes.java index 64503b5f..311a4f6d 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/Routes.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/Routes.java @@ -19,26 +19,32 @@ public Routes() { this(1000); // Default max size } - private void initializeRoute(RouteMetadata routeMetadata) { + private void ensureRoute(String method, String route) { manageRoutesSize(); - String key = routeToKey(routeMetadata); - routes.put(key, new RouteEntry(routeMetadata)); + String key = routeToKey(method, route); + if(!routes.containsKey(key)) { + routes.put(key, new RouteEntry(method, route)); + } } - public void incrementRoute(RouteMetadata routeMetadata) { - String key = routeToKey(routeMetadata); - if (!routes.containsKey(key)) { - // if the route does not yet exist, create it. - initializeRoute(routeMetadata); + public void incrementRoute(String method, String route) { + ensureRoute(method, route); + RouteEntry routeEntry = this.get(method, route); + if (routeEntry != null) { + routeEntry.incrementHits(); } - RouteEntry route = routes.get(key); - if (route != null) { - route.incrementHits(); + } + + public void incrementRateLimitCount(String method, String route) { + ensureRoute(method, route); + RouteEntry routeEntry = this.get(method, route); + if (routeEntry != null) { + routeEntry.incrementRateLimitCount(); } } - public RouteEntry get(RouteMetadata routeMetadata) { - String key = routeToKey(routeMetadata); + public RouteEntry get(String method, String route) { + String key = routeToKey(method, route); return routes.get(key); } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java index 4fcb6b31..9f7b2e4d 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java @@ -16,10 +16,10 @@ private RoutesStore() { } - public static int getRouteHits(RouteMetadata routeMetadata) { + public static int getRouteHits(String method, String route) { mutex.lock(); try { - return routes.get(routeMetadata).getHits(); + return routes.get(method, route).getHits(); } finally { mutex.unlock(); } @@ -34,12 +34,12 @@ public static RouteEntry[] getRoutesAsList() { } } - public static void updateApiSpec(RouteMetadata routeMetadata, APISpec apiSpec) { + public static void updateApiSpec(String method, String route, APISpec apiSpec) { mutex.lock(); try { - RouteEntry route = routes.get(routeMetadata); - if (route != null) { - route.updateApiSpec(apiSpec); + RouteEntry routeEntry = routes.get(method, route); + if (routeEntry != null) { + routeEntry.updateApiSpec(apiSpec); } } catch (Throwable e) { logger.debug("Error occurred updating api specs: %s", e.getMessage()); @@ -48,10 +48,10 @@ public static void updateApiSpec(RouteMetadata routeMetadata, APISpec apiSpec) { } } - public static void addRouteHits(RouteMetadata routeMetadata) { + public static void addRouteHits(String method, String route) { mutex.lock(); try { - routes.incrementRoute(routeMetadata); + routes.incrementRoute(method, route); } catch (Throwable e) { logger.debug("Error occurred incrementing route hits: %s", e.getMessage()); } finally { @@ -59,6 +59,17 @@ public static void addRouteHits(RouteMetadata routeMetadata) { } } + public static void addRouteRateLimitedCount(String method, String route) { + mutex.lock(); + try { + routes.incrementRateLimitCount(method, route); + } catch (Throwable e) { + logger.debug("Error occurred incrementing route rate limit count: %s", e.getMessage()); + } finally { + mutex.unlock(); + } + } + public static void clear() { mutex.lock(); try { diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/Statistics.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/Statistics.java index ce53f76a..64e258ac 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/Statistics.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/Statistics.java @@ -10,19 +10,19 @@ public class Statistics { private final Map ipAddressMatches = new HashMap<>(); private final Map userAgentMatches = new HashMap<>(); private int totalHits; + private final int aborted; // We don't use the "aborted" field right now + private int rateLimited; private int attacksDetected; private int attacksBlocked; private long startedAt; - public Statistics(int totalHits, int attacksDetected, int attacksBlocked) { - this.totalHits = totalHits; - this.attacksDetected = attacksDetected; - this.attacksBlocked = attacksBlocked; - this.startedAt = UnixTimeMS.getUnixTimeMS(); - } - public Statistics() { - this(0, 0, 0); + this.totalHits = 0; + this.rateLimited = 0; + this.aborted = 0; + this.attacksDetected = 0; + this.attacksBlocked = 0; + this.startedAt = UnixTimeMS.getUnixTimeMS(); } @@ -35,6 +35,14 @@ public int getTotalHits() { return totalHits; } + public void incrementRateLimited() { + rateLimited += 1; + } + + public int getRateLimited() { + return rateLimited; + } + // attack stats public void incrementAttacksDetected(String operation) { @@ -104,8 +112,7 @@ public void addMatchToUserAgents(String key) { public StatsRecord getRecord() { long endedAt = UnixTimeMS.getUnixTimeMS(); return new StatsRecord(this.startedAt, endedAt, new StatsRequestsRecord( - /* total */ totalHits, - /* aborted */ 0, // Unknown statistic, default to 0, + totalHits, aborted, rateLimited, /* attacksDetected */ Map.of( "total", attacksDetected, "blocked", attacksBlocked @@ -118,6 +125,7 @@ public StatsRecord getRecord() { public void clear() { this.totalHits = 0; + this.rateLimited = 0; this.attacksBlocked = 0; this.attacksDetected = 0; this.startedAt = UnixTimeMS.getUnixTimeMS(); @@ -127,7 +135,8 @@ public void clear() { } // Stats records for sending out the heartbeat : - public record StatsRequestsRecord(long total, long aborted, Map attacksDetected) { + public record StatsRequestsRecord(long total, long aborted, long rateLimited, + Map attacksDetected) { } public record StatsRecord(long startedAt, long endedAt, StatsRequestsRecord requests, diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/StatisticsStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/StatisticsStore.java index 3d1fad42..e611793c 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/StatisticsStore.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/statistics/StatisticsStore.java @@ -35,6 +35,15 @@ public static void incrementHits() { } } + public static void incrementRateLimited() { + mutex.lock(); + try { + stats.incrementRateLimited(); + } finally { + mutex.unlock(); + } + } + public static void incrementAttacksDetected(String operation) { mutex.lock(); try { diff --git a/agent_api/src/test/java/ShouldBlockRequestTest.java b/agent_api/src/test/java/ShouldBlockRequestTest.java index 3b731e03..eeddc6c2 100644 --- a/agent_api/src/test/java/ShouldBlockRequestTest.java +++ b/agent_api/src/test/java/ShouldBlockRequestTest.java @@ -4,7 +4,10 @@ import dev.aikido.agent_api.context.Context; import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.context.User; +import dev.aikido.agent_api.storage.RateLimiterStore; import dev.aikido.agent_api.storage.ServiceConfigStore; +import dev.aikido.agent_api.storage.routes.RoutesStore; +import dev.aikido.agent_api.storage.statistics.StatisticsStore; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -38,12 +41,18 @@ public SampleContextObject() { public static void clean() { Context.set(null); ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse); + StatisticsStore.clear(); + RoutesStore.clear(); + RateLimiterStore.clear(); }; @AfterEach public void tearDown() throws SQLException { Context.set(null); ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse); + StatisticsStore.clear(); + RoutesStore.clear(); + RateLimiterStore.clear(); } @Test @@ -59,6 +68,7 @@ public void testNoContext() throws SQLException { // Test with thread cache not set : var res2 = ShouldBlockRequest.shouldBlockRequest(); assertFalse(res2.block()); + assertEquals(0, StatisticsStore.getStatsRecord().requests().rateLimited()); } @Test @@ -112,7 +122,8 @@ public void testUserSet() throws SQLException { @Test public void testEndpointsExistButNoMatch() throws SQLException { - Context.set(null); + ContextObject ctx = new SampleContextObject(); + Context.set(ctx); setEmptyConfigWithEndpointList(List.of( new Endpoint("POST", "/api2/*", 1, 1000, Collections.emptyList(), false, false, false) )); @@ -121,7 +132,6 @@ public void testEndpointsExistButNoMatch() throws SQLException { var res1 = ShouldBlockRequest.shouldBlockRequest(); assertFalse(res1.block()); - Context.set(null); setEmptyConfigWithEndpointList(List.of( new Endpoint("POST", "/api2/*", 1, 1000, Collections.emptyList(), false, false, true) )); @@ -133,7 +143,8 @@ public void testEndpointsExistButNoMatch() throws SQLException { @Test public void testEndpointsExistWithMatch() throws SQLException { - Context.set(null); + ContextObject ctx = new SampleContextObject(); + Context.set(ctx); setEmptyConfigWithEndpointList(List.of( new Endpoint("GET", "/api/*", 1, 1000, Collections.emptyList(), false, false, false) )); @@ -142,7 +153,6 @@ public void testEndpointsExistWithMatch() throws SQLException { var res1 = ShouldBlockRequest.shouldBlockRequest(); assertFalse(res1.block()); - Context.set(null); setEmptyConfigWithEndpointList(List.of( new Endpoint("GET", "/api/*", 1, 1000, Collections.emptyList(), false, false, true) )); @@ -150,6 +160,19 @@ public void testEndpointsExistWithMatch() throws SQLException { // Test with match & rate-limiting enabled : var res2 = ShouldBlockRequest.shouldBlockRequest(); assertFalse(res2.block()); + assertEquals(0, StatisticsStore.getStatsRecord().requests().rateLimited()); + + + var res3 = ShouldBlockRequest.shouldBlockRequest(); + var res4 = ShouldBlockRequest.shouldBlockRequest(); + assertTrue(res3.block()); + assertTrue(res4.block()); + assertEquals("ip", res3.data().trigger()); + assertEquals("192.168.1.1", res3.data().ip()); + assertEquals("ratelimited", res3.data().type()); + assertEquals(2, StatisticsStore.getStatsRecord().requests().rateLimited()); + assertEquals(2, RoutesStore.getRoutesAsList()[0].getRateLimitCount()); + } @Test diff --git a/agent_api/src/test/java/collectors/WebResponseCollectorTest.java b/agent_api/src/test/java/collectors/WebResponseCollectorTest.java index 657b8514..6ac91354 100644 --- a/agent_api/src/test/java/collectors/WebResponseCollectorTest.java +++ b/agent_api/src/test/java/collectors/WebResponseCollectorTest.java @@ -32,7 +32,6 @@ public SampleContextObject(String method) { this.executedMiddleware = true; // Start with "executed middleware" as true } } - public static RouteMetadata routeMetadata1 = new RouteMetadata("/api/resource", "https://example.com/api/resource", "GET"); @BeforeAll public static void clean() { @@ -53,23 +52,23 @@ public void testResponseCollector1() throws SQLException { assertEquals(0, RoutesStore.getRoutesAsList().length); WebResponseCollector.report(200); assertEquals(1, RoutesStore.getRoutesAsList().length); - assertEquals(1, RoutesStore.getRouteHits(routeMetadata1)); + assertEquals(1, RoutesStore.getRouteHits("GET", "/api/resource")); // Test same route but incremented hits : WebResponseCollector.report(201); assertEquals(1, RoutesStore.getRoutesAsList().length); - assertEquals(2, RoutesStore.getRouteHits(routeMetadata1)); + assertEquals(2, RoutesStore.getRouteHits("GET", "/api/resource")); // Test same route but invalid status code WebResponseCollector.report(0); assertEquals(1, RoutesStore.getRoutesAsList().length); - assertEquals(2, RoutesStore.getRouteHits(routeMetadata1)); + assertEquals(2, RoutesStore.getRouteHits("GET", "/api/resource")); // Test same route but context not set : Context.set(null); WebResponseCollector.report(200); assertEquals(1, RoutesStore.getRoutesAsList().length); - assertEquals(2, RoutesStore.getRouteHits(routeMetadata1)); + assertEquals(2, RoutesStore.getRouteHits("GET", "/api/resource")); RoutesStore.clear(); assertEquals(0, RoutesStore.getRoutesAsList().length); diff --git a/agent_api/src/test/java/storage/RouteEntryTest.java b/agent_api/src/test/java/storage/RouteEntryTest.java index d25fea8a..e77388dd 100644 --- a/agent_api/src/test/java/storage/RouteEntryTest.java +++ b/agent_api/src/test/java/storage/RouteEntryTest.java @@ -32,9 +32,22 @@ public void testGsonWithoutSerializer() throws IOException { Gson gson = new Gson(); String json = gson.toJson(route1); assertEquals( - "{\"method\":\"GET\",\"path\":\"/api/1\",\"hits\":0,\"apispec\":{\"body\":{\"schema\":{\"type\":\"object\",\"properties\":{\"oldProp\":{\"type\":\"string\",\"optional\":false}},\"optional\":false},\"type\":\"oldType\"},\"auth\":[{\"type\":\"apiKey\"}]}}", + "{\"method\":\"GET\",\"path\":\"/api/1\",\"hits\":0,\"rateLimitedCount\":0,\"apispec\":{\"body\":{\"schema\":{\"type\":\"object\",\"properties\":{\"oldProp\":{\"type\":\"string\",\"optional\":false}},\"optional\":false},\"type\":\"oldType\"},\"auth\":[{\"type\":\"apiKey\"}]}}", json ); } + @Test + public void testIncrementRateLimitedCount() { + // Initial count should be 0 + assertEquals(0, route1.getRateLimitCount()); + + // Increment the rate limited count + route1.incrementRateLimitCount(); + assertEquals(1, route1.getRateLimitCount()); + + // Increment again + route1.incrementRateLimitCount(); + assertEquals(2, route1.getRateLimitCount()); + } } diff --git a/agent_api/src/test/java/storage/RoutesTest.java b/agent_api/src/test/java/storage/RoutesTest.java index 90eabb43..1538e868 100644 --- a/agent_api/src/test/java/storage/RoutesTest.java +++ b/agent_api/src/test/java/storage/RoutesTest.java @@ -11,81 +11,89 @@ class RoutesTest { private Routes routes; - private RouteMetadata routeMetadata1; - private RouteMetadata routeMetadata2; - private RouteMetadata routeMetadata3; @BeforeEach void setUp() { routes = new Routes(2); // Set max size to 2 for testing - routeMetadata1 = new RouteMetadata("/api/test1", "/api/test1", "GET"); - routeMetadata2 = new RouteMetadata( "/api/test2", "/api/test2", "POST"); - routeMetadata3 = new RouteMetadata("/api/test3", "/api/test3", "PUT"); } @Test void testInitializeRoute() { - routes.incrementRoute(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); assertEquals(1, routes.size()); - assertNotNull(routes.get(routeMetadata1)); + assertNotNull(routes.get("GET", "/api/test1")); } @Test void testInitializeDuplicateRoute() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata1); // Should not add again + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("GET", "/api/test1"); // Should not add again assertEquals(1, routes.size()); } @Test void testIncrementRouteHits() { - routes.incrementRoute(routeMetadata1); - RouteEntry entry = routes.get(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); + RouteEntry entry = routes.get("GET", "/api/test1"); assertNotNull(entry); assertEquals(1, entry.getHits()); } @Test void testIncrementNonExistentRoute() { - routes.incrementRoute(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); + assertEquals(1, routes.size()); + } + + @Test + void testIncrementRouteRateLimitCount() { + routes.incrementRateLimitCount("GET", "/api/test1"); + RouteEntry entry = routes.get("GET", "/api/test1"); + assertNotNull(entry); + assertEquals(1, entry.getRateLimitCount()); + } + + @Test + void testIncrementNonExistentRouteRateLimit() { + routes.incrementRateLimitCount("GET", "/api/test1"); assertEquals(1, routes.size()); } @Test void testManageRoutesSize() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata2); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("POST", "/api/test2"); assertEquals(2, routes.size()); - routes.incrementRoute(routeMetadata3); // This should evict the least used route + routes.incrementRoute("PUT", "/api/test3"); // This should evict the least used route assertEquals(2, routes.size()); - assertNull(routes.get(routeMetadata1)); // routeMetadata1 should be evicted - assertNotNull(routes.get(routeMetadata2)); // routeMetadata2 should still exist + assertNull(routes.get("GET", "/api/test1")); // routeMetadata1 should be evicted + assertNotNull(routes.get("POST", "/api/test2")); // "POST", "/api/test2" should still exist } @Test void testClearRoutes() { - routes.incrementRoute(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); routes.clear(); assertEquals(0, routes.size()); } @Test void testMultipleInitializations() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata2); - routes.incrementRoute(routeMetadata3); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("POST", "/api/test2"); + routes.incrementRoute("PUT", "/api/test3"); assertEquals(2, routes.size()); // Only 2 should remain } @Test void testIncrementMultipleTimes() { - routes.incrementRoute(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); for (int i = 0; i < 5; i++) { - routes.incrementRoute(routeMetadata1); + routes.incrementRoute("GET", "/api/test1"); } - RouteEntry entry = routes.get(routeMetadata1); + RouteEntry entry = routes.get("GET", "/api/test1"); assertNotNull(entry); assertEquals(6, entry.getHits()); } @@ -93,52 +101,52 @@ void testIncrementMultipleTimes() { @Test void testDefaultConstructor() { routes = new Routes(); - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata1); // Increment hits for routeMetadata1 - routes.incrementRoute(routeMetadata2); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("GET", "/api/test1"); // Increment hits for routeMetadata1 + routes.incrementRoute("POST", "/api/test2"); for (int i = 0; i < (1000 - 1); i++) { - routes.incrementRoute(new RouteMetadata(String.valueOf(i), "api/test3", "GET")); + routes.incrementRoute("GET", String.valueOf(i)); } assertEquals(1000, routes.asList().length); - assertNull(routes.get(routeMetadata2)); // routeMetadata2 should be evicted - assertNotNull(routes.get(routeMetadata1)); // routeMetadata1 should still exist + assertNull(routes.get("POST", "/api/test2")); // "POST", "/api/test2" should be evicted + assertNotNull(routes.get("GET", "/api/test1")); // routeMetadata1 should still exist } @Test void testEvictionOrder() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata1); // Increment hits for routeMetadata1 - routes.incrementRoute(routeMetadata2); - routes.incrementRoute(routeMetadata3); // This should evict routeMetadata2 (routeMetadata1 has more hits) - - assertNull(routes.get(routeMetadata2)); // routeMetadata2 should be evicted - assertNotNull(routes.get(routeMetadata1)); // routeMetadata1 should still exist - assertNotNull(routes.get(routeMetadata3)); // routeMetadata3 should exist + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("GET", "/api/test1"); // Increment hits for routeMetadata1 + routes.incrementRoute("POST", "/api/test2"); + routes.incrementRoute("PUT", "/api/test3"); // This should evict "POST", "/api/test2" (routeMetadata1 has more hits) + + assertNull(routes.get("POST", "/api/test2")); // "POST", "/api/test2" should be evicted + assertNotNull(routes.get("GET", "/api/test1")); // routeMetadata1 should still exist + assertNotNull(routes.get("PUT", "/api/test3")); // "PUT", "/api/test3" should exist } @Test void testSizeAfterEviction() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata2); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("POST", "/api/test2"); assertEquals(2, routes.size()); - routes.incrementRoute(new RouteMetadata("DELETE", "", "/api/test4")); // Evict one + routes.incrementRoute("DELETE", "/api/test4"); assertEquals(2, routes.size()); // Size should remain 2 } @Test void testIterator() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata2); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("POST", "/api/test2"); assertEquals(2, routes.size()); } @Test void testIteratorAfterEviction() { - routes.incrementRoute(routeMetadata1); - routes.incrementRoute(routeMetadata2); + routes.incrementRoute("GET", "/api/test1"); + routes.incrementRoute("POST", "/api/test2"); assertEquals(2, routes.size()); // Should still be 2 after eviction - routes.incrementRoute(routeMetadata3); // Evict one + routes.incrementRoute("PUT", "/api/test3"); // Evict one assertEquals(2, routes.size()); // Should still be 2 after eviction } diff --git a/agent_api/src/test/java/storage/StatisticsTest.java b/agent_api/src/test/java/storage/StatisticsTest.java index 6d0101c9..c66dfbbe 100644 --- a/agent_api/src/test/java/storage/StatisticsTest.java +++ b/agent_api/src/test/java/storage/StatisticsTest.java @@ -34,9 +34,11 @@ public void testClear() { stats.incrementAttacksDetected("test2"); stats.incrementAttacksDetected("test1"); stats.incrementAttacksDetected("test1"); + stats.incrementRateLimited(); assertEquals(3, stats.getAttacksDetected()); assertEquals(2, stats.getAttacksBlocked()); assertEquals(20, stats.getTotalHits()); + assertEquals(1, stats.getRateLimited()); assertEquals(2, stats.getOperations().get("test1").getAttacksDetected().get("total")); assertEquals(1, stats.getOperations().get("test1").getAttacksDetected().get("blocked")); @@ -47,20 +49,29 @@ public void testClear() { assertEquals(0, stats.getAttacksBlocked()); assertEquals(0, stats.getAttacksDetected()); assertEquals(0, stats.getTotalHits()); - + assertEquals(0, stats.getRateLimited()); } @Test public void testConstructor() { - Statistics stats2 = new Statistics(100, 5, 1); - assertEquals(100, stats2.getTotalHits()); - assertEquals(5, stats2.getAttacksDetected()); - assertEquals(1, stats2.getAttacksBlocked()); + Statistics stats2 = new Statistics(); + assertEquals(0, stats2.getTotalHits()); + assertEquals(0, stats2.getRateLimited()); + assertEquals(0, stats2.getAttacksDetected()); + assertEquals(0, stats2.getAttacksBlocked()); } @Test public void testStatsRecord() { - Statistics stats2 = new Statistics(100, 5, 1); + Statistics stats2 = new Statistics(); + stats2.incrementTotalHits(100); + stats2.incrementAttacksDetected("op2"); + stats2.incrementAttacksDetected("op2"); + stats2.incrementAttacksDetected("op2"); + stats2.incrementAttacksDetected("op2"); + stats2.incrementAttacksDetected("op2"); + stats2.incrementAttacksBlocked("op2"); + stats2.registerCall("operation1", OperationKind.FS_OP); Statistics.StatsRecord statsRecord = stats2.getRecord(); assertEquals(5, statsRecord.requests().attacksDetected().get("total"));