diff --git a/cmd/api/middleware.go b/cmd/api/middleware.go index 3abc444..710317c 100644 --- a/cmd/api/middleware.go +++ b/cmd/api/middleware.go @@ -3,7 +3,10 @@ package main import ( "fmt" "golang.org/x/time/rate" + "net" "net/http" + "sync" + "time" ) func (app *application) recoverPanic(next http.Handler) http.Handler { @@ -25,17 +28,68 @@ func (app *application) recoverPanic(next http.Handler) http.Handler { } func (app *application) rateLimit(next http.Handler) http.Handler { - // Initialize a new rate limiter which allows an overage of 2 requests per second, with a maximum of 4 requests in a single `burst` - limiter := rate.NewLimiter(2, 4) + // Define a client struct to hold the rate limiter and last seen time for each client + type client struct { + limiter *rate.Limiter + lastSeen time.Time + } + + var ( + mu sync.Mutex + // Update the map so the values are pointers to a client struct + clients = make(map[string]*client) + ) + + // Launch a background goroutine which removes old entries from the clients map once every minute + go func() { + for { + time.Sleep(time.Minute) + + // Lock the mutex to prevent any rate limiter checks from happening while the cleanup is taking place + mu.Lock() + + // Loop through all clients. If they haven't been see within the last three minutes, delete the corresponding entry from the mp + for ip, client := range clients { + if time.Since(client.lastSeen) > 3*time.Minute { + delete(clients, ip) + } + } + + // Importantly, unlock the mutex when the cleanup is complete + mu.Unlock() + } + }() // The function we are returning is a closure, which 'closes over' the limiter variable return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Call limiter.Allow() to see if the request is permitted, and if it's not, then we call the rateLimitExceededResponse() helper to return a 429 Too Many Requests response (we will create this helper in a minute). - if !limiter.Allow() { + // Extract the client's IP address from request. + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + // Lock the mutex to prevent this code from being executed concurrently + mu.Lock() + + // Check to see if the IP address already exists in the map. If it doesn't, then initialize a new rate limiter and add the IP address and limiter to he map. + if _, found := clients[ip]; !found { + clients[ip] = &client{limiter: rate.NewLimiter(2, 4)} + } + + // Update the last see time for the client. + clients[ip].lastSeen = time.Now() + + // Call the Allow() method on the rate limiter for the current IP address. If the request isn't allowed, unlock the mutex and send a 429 Too Many Requests response + if !clients[ip].limiter.Allow() { + mu.Unlock() app.rateLimitExceededResponse(w, r) return } + // Very importantly, unlock the mutex before calling the next handler in the chain. Notice that we DON'T use defer to unlock the mutex, as that would mean that the mutex isn't unlocked until all the handlers downstream of this middleware have also returned + mu.Unlock() + next.ServeHTTP(w, r) }) }