Last active
September 14, 2024 18:20
-
-
Save shapeless-space/44886f7d345c0b24e9328bafa14465ae to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "context" | |
| "database/sql" | |
| "fmt" | |
| "log" | |
| "math/rand" | |
| _ "github.com/lib/pq" | |
| ) | |
| type Decoder[T any] func(rows *sql.Rows) (T, error) | |
// NewRandomCursor returns a pseudo-random cursor name of the form
// "cursor_N" with 0 <= N < 1000, drawn from math/rand.
// It is a placeholder for a more sophisticated generator; the error result
// exists for future generators and is currently always nil.
func NewRandomCursor() (string, error) {
	suffix := rand.Intn(1000)
	name := fmt.Sprintf("cursor_%d", suffix)
	return name, nil
}
// Paginate streams the rows produced by query in batches of batchSize, using
// a server-side cursor (DECLARE ... CURSOR / FETCH) inside a dedicated
// transaction on db.
//
// query is trusted SQL and is interpolated verbatim into the DECLARE
// statement. Each row is decoded with decoder.
//
// The returned function is a range-over-func iterator. Every decoded row is
// yielded with a nil error. Any failure (declaring the cursor, fetching, or
// decoding) is yielded once as (zero value, err) and then iteration stops.
//
// The transaction is begun eagerly by Paginate and rolled back when the
// iterator returns. Consequences for callers:
//   - If the iterator is never invoked, the transaction and its pooled
//     connection are never released.
//   - The iterator is single-use. A second run fails to declare the cursor
//     on the finished transaction and yields that error.
//
// Paginate returns an error, and no iterator, if batchSize is not positive,
// the cursor name cannot be generated, or the transaction cannot be started.
func Paginate[T any](
	ctx context.Context,
	db *sql.DB,
	query string,
	batchSize int,
	decoder Decoder[T],
) (func(func(T, error) bool), error) {
	// FETCH with a zero or negative count does not mean "the next N rows",
	// and ReadPage uses the size as a slice capacity, so reject it before
	// touching the database.
	if batchSize <= 0 {
		return nil, fmt.Errorf("invalid batch size %d: must be positive", batchSize)
	}

	// Generate the cursor name before opening the transaction, so a failure
	// here cannot leave an open transaction behind.
	cursor, err := NewRandomCursor()
	if err != nil {
		return nil, fmt.Errorf("error generating cursor: %w", err)
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("error starting transaction: %w", err)
	}

	declare := fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursor, query)
	fetch := fmt.Sprintf("FETCH %d FROM %s", batchSize, cursor)

	return func(yield func(T, error) bool) {
		// Always roll back. This ends the transaction (releasing its
		// connection) and discards the cursor. The iteration only reads
		// through the cursor, so there is nothing to commit.
		defer func() {
			_ = tx.Rollback()
		}()

		var zero T
		if _, err := tx.ExecContext(ctx, declare); err != nil {
			// Surface the failure to the consumer instead of ending the
			// iteration silently.
			yield(zero, fmt.Errorf("error declaring cursor: %w", err))
			return
		}

		for queries := 1; ; queries++ {
			log.Printf("Query %d", queries)
			page, _, err := ReadPage[T](
				func() (*sql.Rows, error) {
					return tx.QueryContext(ctx, fetch)
				},
				decoder,
				batchSize,
			)
			if err != nil {
				// yield must not be called again after reporting the error.
				yield(zero, err)
				return
			}
			for _, row := range page {
				if !yield(row, nil) {
					return
				}
			}
			// FETCH returns fewer than batchSize rows only once the cursor
			// is exhausted (an empty page included), so a short page ends
			// the iteration without another round trip.
			if len(page) < batchSize {
				return
			}
		}
	}, nil
}
// ReadPage obtains one page of rows by calling next, then decodes every row
// with decoder. The result set is closed before ReadPage returns.
//
// size is only a capacity hint for the returned slice. A negative value is
// treated as zero rather than panicking in make.
//
// The boolean result reports whether the page held at least one row, which
// lets callers detect an exhausted cursor. Errors from next, from decoder,
// and from row iteration are wrapped and returned. On any error the slice is
// nil and the boolean is false.
func ReadPage[T any](
	next func() (*sql.Rows, error),
	decoder Decoder[T],
	size int,
) ([]T, bool, error) {
	rows, err := next()
	if err != nil {
		return nil, false, fmt.Errorf("error fetching from cursor: %w", err)
	}
	// Best-effort close: by this point every row has been read or an error
	// is already being returned.
	defer func() {
		_ = rows.Close()
	}()

	if size < 0 {
		size = 0
	}
	results := make([]T, 0, size)
	for rows.Next() {
		decoded, err := decoder(rows)
		if err != nil {
			return nil, false, fmt.Errorf("error decoding row: %w", err)
		}
		results = append(results, decoded)
	}
	if err := rows.Err(); err != nil {
		return nil, false, fmt.Errorf("error iterating rows: %w", err)
	}
	return results, len(results) > 0, nil
}
// Entry is one row of the demo "test" table created by SetupDatabase.
// DecodeEntry fills it from a row selected as "id, text".
type Entry struct {
	ID   int    // value of the id column (SERIAL primary key)
	Text string // value of the text column
}
// DecodeEntry is a Decoder for Entry. It scans a two-column row, in id then
// text order, into an Entry. If the scan fails it returns the zero Entry
// together with the wrapped error.
func DecodeEntry(rows *sql.Rows) (Entry, error) {
	var (
		id   int
		text string
	)
	if scanErr := rows.Scan(&id, &text); scanErr != nil {
		return Entry{}, fmt.Errorf("error scanning row: %w", scanErr)
	}
	return Entry{ID: id, Text: text}, nil
}
| func main() { | |
| ctx := context.Background() | |
| connStr := "user=dev dbname=generators sslmode=disable" | |
| db, err := sql.Open("postgres", connStr) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| defer func() { | |
| _ = db.Close() | |
| }() | |
| err = SetupDatabase(db) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| query := "SELECT id, text FROM test ORDER BY id" | |
| batchSize := 2 | |
| pagination, err := Paginate[Entry](ctx, db, query, batchSize, DecodeEntry) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| for row, err := range pagination { | |
| if err != nil { | |
| log.Printf("Error fetching rows: %v", err) | |
| return | |
| } | |
| log.Printf("Row: %v", row) | |
| } | |
| } | |
| func SetupDatabase(db *sql.DB) error { | |
| ctx := context.Background() | |
| _, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS test") | |
| if err != nil { | |
| return fmt.Errorf("error dropping table: %w", err) | |
| } | |
| _, err = db.ExecContext(ctx, "CREATE TABLE test (id SERIAL PRIMARY KEY, text TEXT)") | |
| if err != nil { | |
| return fmt.Errorf("error creating table: %w", err) | |
| } | |
| for i := 0; i < 10; i++ { | |
| _, err = db.ExecContext(ctx, "INSERT INTO test (text) VALUES ($1)", fmt.Sprintf("row %d", i)) | |
| if err != nil { | |
| return fmt.Errorf("error inserting row: %w", err) | |
| } | |
| } | |
| return nil | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment