Files
greenlight/cmd/api/middleware.go
2025-11-18 20:10:55 +01:00

102 lines
3.6 KiB
Go

package main
import (
"fmt"
"golang.org/x/time/rate"
"net"
"net/http"
"sync"
"time"
)
// recoverPanic wraps next so that any panic raised while handling a request
// is converted into a logged 500 Internal Server Error response instead of
// crashing the goroutine (and, for HTTP/1.x, the connection).
func (app *application) recoverPanic(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		// The deferred function always runs as Go unwinds the stack after a
		// panic. Note recover() only works directly inside a deferred
		// function in the same goroutine.
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			// Ask the server to close this connection once the response has
			// been written — the connection state is unknown after a panic.
			w.Header().Set("Connection", "close")
			// recover() returns an `any`, so normalize it into an error
			// before handing it to the 500-response helper (which also logs
			// it at ERROR level).
			app.serverErrorResponse(w, r, fmt.Errorf("%s", rec))
		}()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(handler)
}
// rateLimit wraps next with a per-client-IP token-bucket rate limiter using
// the rps/burst settings from app.config.limiter. When the limiter is
// disabled via config, requests pass straight through. Clients idle for more
// than three minutes are evicted by a background cleanup goroutine (which
// lives for the lifetime of the process).
func (app *application) rateLimit(next http.Handler) http.Handler {
	// client holds the rate limiter and last-seen time for one IP address.
	type client struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}
	var (
		mu sync.Mutex
		// Map of client IP address to its rate-limiter state.
		clients = make(map[string]*client)
	)
	// Launch a background goroutine which removes old entries from the
	// clients map once every minute.
	go func() {
		for {
			time.Sleep(time.Minute)
			// Lock the mutex to prevent any rate limiter checks from
			// happening while the cleanup is taking place.
			mu.Lock()
			// Loop through all clients. If they haven't been seen within the
			// last three minutes, delete the corresponding entry from the map.
			for ip, c := range clients {
				if time.Since(c.lastSeen) > 3*time.Minute {
					delete(clients, ip)
				}
			}
			// Importantly, unlock the mutex when the cleanup is complete.
			mu.Unlock()
		}
	}()
	// The function we are returning is a closure which 'closes over' the
	// mu and clients variables above.
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only carry out the check if rate limiting is enabled.
		if app.config.limiter.enabled {
			// Extract the client's IP address from the request.
			ip, _, err := net.SplitHostPort(r.RemoteAddr)
			if err != nil {
				app.serverErrorResponse(w, r, err)
				return
			}
			// Lock the mutex to prevent this code from being executed
			// concurrently.
			mu.Lock()
			// Look the client up once and reuse the pointer, rather than
			// re-hashing the IP for every subsequent access. If the IP isn't
			// in the map yet, initialize a new rate limiter for it using the
			// requests-per-second and burst values from the config struct.
			c, found := clients[ip]
			if !found {
				c = &client{
					limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst),
				}
				clients[ip] = c
			}
			// Update the last-seen time so the cleanup goroutine keeps this
			// entry alive.
			c.lastSeen = time.Now()
			// If the request isn't allowed by the token bucket, unlock the
			// mutex and send a 429 Too Many Requests response.
			if !c.limiter.Allow() {
				mu.Unlock()
				app.rateLimitExceededResponse(w, r)
				return
			}
			// Very importantly, unlock the mutex before calling the next
			// handler in the chain. We deliberately DON'T use defer here, as
			// that would keep the mutex held until every downstream handler
			// had also returned.
			mu.Unlock()
		}
		next.ServeHTTP(w, r)
	})
}