1
1
use std:: collections:: HashMap ;
2
2
use std:: net:: IpAddr ;
3
3
use std:: sync:: { Arc , Mutex } ;
4
- use tracing:: { debug, warn} ;
4
+ use tracing:: { debug, error , info , warn} ;
5
5
6
6
use thiserror:: Error ;
7
7
use tokio:: sync:: { OwnedSemaphorePermit , Semaphore } ;
8
8
9
9
use redis:: { Client , Commands , RedisError } ;
10
- use tracing:: error;
11
10
12
11
#[ derive( Error , Debug ) ]
13
12
pub enum RateLimitError {
@@ -138,13 +137,22 @@ impl RedisRateLimit {
138
137
) -> Result < Self , RedisError > {
139
138
let client = Client :: open ( redis_url) ?;
140
139
141
- Ok ( Self {
140
+ let limiter = Self {
142
141
redis_client : client,
143
142
global_limit,
144
143
per_ip_limit,
145
144
semaphore : Arc :: new ( Semaphore :: new ( global_limit) ) ,
146
145
key_prefix : key_prefix. to_string ( ) ,
147
- } )
146
+ } ;
147
+
148
+ if let Err ( e) = limiter. reset_counters ( ) {
149
+ error ! (
150
+ message = "Failed to reset Redis counters on startup" ,
151
+ error = e. to_string( )
152
+ ) ;
153
+ }
154
+
155
+ Ok ( limiter)
148
156
}
149
157
150
158
/// Get Redis key for tracking global connections
@@ -156,6 +164,29 @@ impl RedisRateLimit {
156
164
fn ip_key ( & self , addr : & IpAddr ) -> String {
157
165
format ! ( "{}:ip:{}:connections" , self . key_prefix, addr)
158
166
}
167
+
168
+ /// Reset all Redis counters associated with this rate limiter
169
+ pub fn reset_counters ( & self ) -> Result < ( ) , RedisError > {
170
+ let mut conn = self . redis_client . get_connection ( ) ?;
171
+
172
+ // Delete the global counter
173
+ let _: ( ) = conn. del ( self . global_key ( ) ) ?;
174
+
175
+ // Find and delete all IP-specific counters with this prefix
176
+ let pattern = format ! ( "{}:ip:*:connections" , self . key_prefix) ;
177
+ let keys: Vec < String > = conn. keys ( pattern) ?;
178
+
179
+ if !keys. is_empty ( ) {
180
+ let _: ( ) = conn. del ( keys) ?;
181
+ }
182
+
183
+ info ! (
184
+ message = "Reset all Redis rate limit counters" ,
185
+ prefix = self . key_prefix
186
+ ) ;
187
+
188
+ Ok ( ( ) )
189
+ }
159
190
}
160
191
161
192
impl RateLimit for RedisRateLimit {
0 commit comments