diff --git a/.idea/dictionaries/project.xml b/.idea/dictionaries/project.xml index a7ca38b..5c02685 100644 --- a/.idea/dictionaries/project.xml +++ b/.idea/dictionaries/project.xml @@ -1,8 +1,10 @@ + bigserial httprouter plainto + postgre servemux tsvector diff --git a/internal/data/filters.go b/internal/data/filters.go index 3e4d560..33cbbca 100644 --- a/internal/data/filters.go +++ b/internal/data/filters.go @@ -1,6 +1,9 @@ package data -import "greenlight.craftr.fr/internal/validator" +import ( + "greenlight.craftr.fr/internal/validator" + "strings" +) type Filters struct { Page int @@ -19,3 +22,23 @@ func ValidateFilters(v *validator.Validator, f Filters) { // Check that the sort parameter matches a value in the safelist v.Check(validator.PermittedValue(f.Sort, f.SortSafelist...), "sort", "invalid sort value") } + +// sortColumn : Check that the client-provided Sort field matches one of the entries in our safelist and if it does, extract the column name from the Sort field by stripping the leading hyphen character (if one exists) +func (f Filters) sortColumn() string { + for _, safeValue := range f.SortSafelist { + if f.Sort == safeValue { + return strings.TrimPrefix(f.Sort, "-") + } + } + + // It will panic if the client-provided 'Sort' value doesn't match one of the entries in our safelist. In theory, this shouldn't happen - the 'Sort' value should have already been checked by calling the 'ValidateFilters()' function - but this is a sensible failsafe to help stop a SQL injection attack occurring + panic("unsafe sort parameter: " + f.Sort) +} + +// sortDirection : Return the sort direction ("ASC" or "DESC") depending on the prefix character of the Sort field +func (f Filters) sortDirection() string { + if strings.HasPrefix(f.Sort, "-") { + return "DESC" + } + return "ASC" +} diff --git a/internal/data/movies.go b/internal/data/movies.go index cb91693..be831ab 100644 --- a/internal/data/movies.go +++ b/internal/data/movies.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "github.com/lib/pq" "time" @@ -197,13 +198,14 @@ func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*M // Construct the SQL query to retrieve all movie records. // to_tsvector('simple', title) transforms 'The Breakfast Club' into 'breakfast' 'club' 'the'. The 'simple' parameter's value is the configuration. // plainto_tsquery('simple', $1) takes a search value and turns it into a formatted query term that PostgreSQL full-text search can understand. As an example : "The Club" would result in the query term 'the' & 'club' - // The @@ operator is the matches operator. To continue the example, the query term 'the' & 'club' will match rows which contain both lexemes 'the' and 'club'. - query := ` + // The @@ operator is the matches' operator. To continue the example, the query term 'the' & 'club' will match rows which contain both lexemes 'the' and 'club'. + // Add an ORDER BY clause and interpolate the sort column and direction. Importantly, notice that we also include a secondary sort on the movie ID to ensure a consistent ordering. + query := fmt.Sprintf(` SELECT id, created_at, title, year, runtime, genres, version FROM movies WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '') AND (genres @> $2 OR $2 = '{}') - ORDER BY id` + ORDER BY %s %s, id ASC`, filters.sortColumn(), filters.sortDirection()) // Create a context with a 3-second timeout ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -219,7 +221,7 @@ func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*M defer rows.Close() // Initialize an empty slice to hold the movie data. - movies := []*Movie{} + var movies []*Movie // Use rows.Next to iterate through the rows in the resultset. for rows.Next() {