|
| 1 | +// Copyright (c) Microsoft Corporation. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +using Microsoft.Extensions.Configuration; |
| 5 | +using Microsoft.Graph.DeveloperProxy.Abstractions; |
| 6 | +using System.Net; |
| 7 | +using System.Text.Json; |
| 8 | +using System.Text.Json.Serialization; |
| 9 | +using System.Text.RegularExpressions; |
| 10 | +using Titanium.Web.Proxy.Http; |
| 11 | +using Titanium.Web.Proxy.Models; |
| 12 | + |
| 13 | +namespace Microsoft.Graph.DeveloperProxy.Plugins.Behavior; |
| 14 | + |
| 15 | +public class RateLimitConfiguration { |
| 16 | + public string HeaderLimit { get; set; } = "RateLimit-Limit"; |
| 17 | + public string HeaderRemaining { get; set; } = "RateLimit-Remaining"; |
| 18 | + public string HeaderReset { get; set; } = "RateLimit-Reset"; |
| 19 | + public string HeaderRetryAfter { get; set; } = "Retry-After"; |
| 20 | + public int CostPerRequest { get; set; } = 2; |
| 21 | + public int ResetTimeWindowSeconds { get; set; } = 60; |
| 22 | + public int WarningThresholdPercent { get; set; } = 80; |
| 23 | + public int RateLimit { get; set; } = 120; |
| 24 | + public int RetryAfterSeconds { get; set; } = 5; |
| 25 | +} |
| 26 | + |
| 27 | +public class RateLimitingPlugin : BaseProxyPlugin { |
| 28 | + public override string Name => nameof(RateLimitingPlugin); |
| 29 | + private readonly RateLimitConfiguration _configuration = new(); |
| 30 | + private readonly Dictionary<string, DateTime> _throttledRequests = new(); |
| 31 | + // initial values so that we know when we intercept the |
| 32 | + // first request and can set the initial values |
| 33 | + private int _resourcesRemaining = -1; |
| 34 | + private DateTime _resetTime = DateTime.MinValue; |
| 35 | + |
| 36 | + private bool ShouldForceThrottle(ProxyRequestArgs e) { |
| 37 | + var r = e.Session.HttpClient.Request; |
| 38 | + string key = BuildThrottleKey(r); |
| 39 | + if (_throttledRequests.TryGetValue(key, out DateTime retryAfterDate)) { |
| 40 | + if (retryAfterDate > DateTime.Now) { |
| 41 | + _logger?.LogRequest(new[] { $"Calling {r.Url} again before waiting for the Retry-After period.", "Request will be throttled" }, MessageType.Failed, new LoggingContext(e.Session)); |
| 42 | + // update the retryAfterDate to extend the throttling window to ensure that brute forcing won't succeed. |
| 43 | + _throttledRequests[key] = retryAfterDate.AddSeconds(_configuration.RetryAfterSeconds); |
| 44 | + return true; |
| 45 | + } |
| 46 | + else { |
| 47 | + // clean up expired throttled request and ensure that this request is passed through. |
| 48 | + _throttledRequests.Remove(key); |
| 49 | + return false; |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + return false; |
| 54 | + } |
| 55 | + |
| 56 | + private void ForceThrottleResponse(ProxyRequestArgs e) => UpdateProxyResponse(e, HttpStatusCode.TooManyRequests); |
| 57 | + |
| 58 | + private bool ShouldThrottle(ProxyRequestArgs e) { |
| 59 | + if (_resourcesRemaining > 0) { |
| 60 | + return false; |
| 61 | + } |
| 62 | + |
| 63 | + var r = e.Session.HttpClient.Request; |
| 64 | + string key = BuildThrottleKey(r); |
| 65 | + |
| 66 | + _logger?.LogRequest(new[] { $"Exceeded resource limit when calling {r.Url}.", "Request will be throttled" }, MessageType.Failed, new LoggingContext(e.Session)); |
| 67 | + // update the retryAfterDate to extend the throttling window to ensure that brute forcing won't succeed. |
| 68 | + _throttledRequests[key] = DateTime.Now.AddSeconds(_configuration.RetryAfterSeconds); |
| 69 | + return true; |
| 70 | + } |
| 71 | + |
| 72 | + private void ThrottleResponse(ProxyRequestArgs e) => UpdateProxyResponse(e, HttpStatusCode.TooManyRequests); |
| 73 | + |
| 74 | + private void UpdateProxyResponse(ProxyHttpEventArgsBase e, HttpStatusCode errorStatus) { |
| 75 | + var headers = new List<HttpHeader>(); |
| 76 | + var body = string.Empty; |
| 77 | + var request = e.Session.HttpClient.Request; |
| 78 | + |
| 79 | + // override the response body and headers for the error response |
| 80 | + if (errorStatus != HttpStatusCode.OK && |
| 81 | + ProxyUtils.IsGraphRequest(request)) { |
| 82 | + string requestId = Guid.NewGuid().ToString(); |
| 83 | + string requestDate = DateTime.Now.ToString(); |
| 84 | + headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); |
| 85 | + |
| 86 | + body = JsonSerializer.Serialize(new GraphErrorResponseBody( |
| 87 | + new GraphErrorResponseError { |
| 88 | + Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), |
| 89 | + Message = BuildApiErrorMessage(request), |
| 90 | + InnerError = new GraphErrorResponseInnerError { |
| 91 | + RequestId = requestId, |
| 92 | + Date = requestDate |
| 93 | + } |
| 94 | + }) |
| 95 | + ); |
| 96 | + } |
| 97 | + |
| 98 | + // add rate limiting headers if reached the threshold percentage |
| 99 | + if (_resourcesRemaining <= _configuration.RateLimit - (_configuration.RateLimit * _configuration.WarningThresholdPercent / 100)) { |
| 100 | + headers.AddRange(new List<HttpHeader> { |
| 101 | + new HttpHeader(_configuration.HeaderLimit, _configuration.RateLimit.ToString()), |
| 102 | + new HttpHeader(_configuration.HeaderRemaining, _resourcesRemaining.ToString()), |
| 103 | + new HttpHeader(_configuration.HeaderReset, (_resetTime - DateTime.Now).TotalSeconds.ToString("N0")) // drop decimals |
| 104 | + }); |
| 105 | + } |
| 106 | + |
| 107 | + // send an error response if we are (forced) throttling |
| 108 | + if (errorStatus == HttpStatusCode.TooManyRequests) { |
| 109 | + headers.Add(new HttpHeader(_configuration.HeaderRetryAfter, _configuration.RetryAfterSeconds.ToString())); |
| 110 | + |
| 111 | + e.Session.GenericResponse(body ?? string.Empty, errorStatus, headers); |
| 112 | + return; |
| 113 | + } |
| 114 | + |
| 115 | + if (errorStatus == HttpStatusCode.OK) { |
| 116 | + // add headers to the original API response |
| 117 | + e.Session.HttpClient.Response.Headers.AddHeaders(headers); |
| 118 | + } |
| 119 | + } |
| 120 | + private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : String.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; |
| 121 | + |
| 122 | + private string BuildThrottleKey(Request r) => $"{r.Method}-{r.Url}"; |
| 123 | + |
| 124 | + public override void Register(IPluginEvents pluginEvents, |
| 125 | + IProxyContext context, |
| 126 | + ISet<Regex> urlsToWatch, |
| 127 | + IConfigurationSection? configSection = null) { |
| 128 | + base.Register(pluginEvents, context, urlsToWatch, configSection); |
| 129 | + |
| 130 | + configSection?.Bind(_configuration); |
| 131 | + pluginEvents.BeforeRequest += OnRequest; |
| 132 | + pluginEvents.BeforeResponse += OnResponse; |
| 133 | + } |
| 134 | + |
| 135 | + // add rate limiting headers to the response from the API |
| 136 | + private async Task OnResponse(object? sender, ProxyResponseArgs e) { |
| 137 | + var session = e.Session; |
| 138 | + var state = e.ResponseState; |
| 139 | + if (_urlsToWatch is null || |
| 140 | + !e.HasRequestUrlMatch(_urlsToWatch)) { |
| 141 | + return; |
| 142 | + } |
| 143 | + |
| 144 | + UpdateProxyResponse(e, HttpStatusCode.OK); |
| 145 | + } |
| 146 | + |
| 147 | + private async Task OnRequest(object? sender, ProxyRequestArgs e) { |
| 148 | + var session = e.Session; |
| 149 | + var state = e.ResponseState; |
| 150 | + if (e.ResponseState.HasBeenSet || |
| 151 | + _urlsToWatch is null || |
| 152 | + !e.ShouldExecute(_urlsToWatch)) { |
| 153 | + return; |
| 154 | + } |
| 155 | + |
| 156 | + // set the initial values for the first request |
| 157 | + if (_resetTime == DateTime.MinValue) { |
| 158 | + _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); |
| 159 | + } |
| 160 | + if (_resourcesRemaining == -1) { |
| 161 | + _resourcesRemaining = _configuration.RateLimit; |
| 162 | + } |
| 163 | + |
| 164 | + // see if we passed the reset time window |
| 165 | + if (DateTime.Now > _resetTime) { |
| 166 | + _resourcesRemaining = _configuration.RateLimit; |
| 167 | + _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); |
| 168 | + } |
| 169 | + |
| 170 | + // subtract the cost of the request |
| 171 | + _resourcesRemaining -= _configuration.CostPerRequest; |
| 172 | + // avoid communicating negative values |
| 173 | + if (_resourcesRemaining < 0) { |
| 174 | + _resourcesRemaining = 0; |
| 175 | + } |
| 176 | + |
| 177 | + if (ShouldForceThrottle(e)) { |
| 178 | + ForceThrottleResponse(e); |
| 179 | + state.HasBeenSet = true; |
| 180 | + } |
| 181 | + else if (ShouldThrottle(e)) { |
| 182 | + ThrottleResponse(e); |
| 183 | + state.HasBeenSet = true; |
| 184 | + } |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | + |
| 189 | +internal class GraphErrorResponseBody { |
| 190 | + [JsonPropertyName("error")] |
| 191 | + public GraphErrorResponseError Error { get; set; } |
| 192 | + |
| 193 | + public GraphErrorResponseBody(GraphErrorResponseError error) { |
| 194 | + Error = error; |
| 195 | + } |
| 196 | +} |
| 197 | + |
| 198 | +internal class GraphErrorResponseError { |
| 199 | + [JsonPropertyName("code")] |
| 200 | + public string Code { get; set; } = string.Empty; |
| 201 | + [JsonPropertyName("message")] |
| 202 | + public string Message { get; set; } = string.Empty; |
| 203 | + [JsonPropertyName("innerError")] |
| 204 | + public GraphErrorResponseInnerError? InnerError { get; set; } |
| 205 | +} |
| 206 | + |
| 207 | +internal class GraphErrorResponseInnerError { |
| 208 | + [JsonPropertyName("request-id")] |
| 209 | + public string RequestId { get; set; } = string.Empty; |
| 210 | + [JsonPropertyName("date")] |
| 211 | + public string Date { get; set; } = string.Empty; |
| 212 | +} |
0 commit comments