Skip to content

Instantly share code, notes, and snippets.

@shapeless-space
Last active September 14, 2024 18:20
Show Gist options
  • Select an option

  • Save shapeless-space/44886f7d345c0b24e9328bafa14465ae to your computer and use it in GitHub Desktop.

Select an option

Save shapeless-space/44886f7d345c0b24e9328bafa14465ae to your computer and use it in GitHub Desktop.
package main
import (
"context"
"database/sql"
"fmt"
"log"
"math/rand"
_ "github.com/lib/pq"
)
// Decoder turns the current row of a *sql.Rows into a value of type T.
// ReadPage invokes it once per row, after rows.Next has reported true; a
// non-nil error stops decoding of that page.
type Decoder[T any] func(*sql.Rows) (T, error)
// NewRandomCursor returns a cursor name of the form "cursor_<n>", where n is
// drawn from math/rand in the range [0, 1000). Paginate splices the name
// directly into its DECLARE and FETCH statements, so it is built only from
// identifier-safe characters (a fixed prefix plus decimal digits).
//
// This is a placeholder for a more sophisticated generator. The error result
// is currently always nil. It exists so that a fallible generator can be
// substituted later without changing callers.
func NewRandomCursor() (string, error) {
	n := rand.Intn(1000)
	name := fmt.Sprintf("cursor_%d", n)
	return name, nil
}
// Paginate streams the rows produced by query in batches of batchSize. It uses
// a server-side cursor (DECLARE ... CURSOR / FETCH) and decodes every row with
// decoder.
//
// A transaction is opened immediately, and its error is returned directly.
// The returned iterator behaves as follows:
//   - It declares the cursor inside that transaction.
//   - It repeatedly issues "FETCH batchSize" and yields each decoded row
//     with a nil error.
//   - Any failure while declaring the cursor, fetching, or decoding is
//     yielded once as (zero value, err), and iteration then stops.
//   - The transaction is rolled back when the iterator returns. PostgreSQL
//     closes non-holdable cursors when their transaction ends.
//
// The iterator is single-use: after the first range finishes the transaction
// is done, so ranging again yields an error wrapping sql.ErrTxDone. Callers
// must range over the iterator at least once; otherwise the transaction (and
// the pooled connection it holds) is never released.
//
// query is interpolated verbatim into the DECLARE statement. It must be
// trusted SQL, never user input.
func Paginate[T any](
	ctx context.Context,
	db *sql.DB,
	query string,
	batchSize int,
	decoder Decoder[T],
) (func(func(T, error) bool), error) {
	// Reject nonsensical batch sizes up front. A negative size would otherwise
	// panic when ReadPage pre-allocates its page, and "FETCH 0" never advances.
	if batchSize <= 0 {
		return nil, fmt.Errorf("invalid batch size %d: must be positive", batchSize)
	}
	// Generate the cursor name before opening the transaction, so a failure
	// here cannot leave an open transaction behind.
	cursor, err := NewRandomCursor()
	if err != nil {
		return nil, fmt.Errorf("error generating cursor: %w", err)
	}
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("error starting transaction: %w", err)
	}
	// The FETCH statement is the same for every page; build it once.
	fetch := fmt.Sprintf("FETCH %d FROM %s", batchSize, cursor)
	return func(yield func(T, error) bool) {
		// The cursor is read-only, so there is nothing to commit. The rollback
		// error is ignored because the work is already done (or has already
		// been reported).
		defer func() {
			_ = tx.Rollback()
		}()
		var zero T
		if _, err := tx.ExecContext(ctx, fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursor, query)); err != nil {
			// Surface the failure to the consumer instead of ending silently.
			yield(zero, fmt.Errorf("error declaring cursor: %w", err))
			return
		}
		for queries := 1; ; queries++ {
			log.Printf("Query %d", queries)
			page, _, err := ReadPage[T](
				func() (*sql.Rows, error) {
					return tx.QueryContext(ctx, fetch)
				},
				decoder,
				batchSize,
			)
			if err != nil {
				// Stop after reporting the error, whatever yield returns.
				yield(zero, err)
				return
			}
			for _, row := range page {
				if !yield(row, nil) {
					return
				}
			}
			// FETCH returns fewer than batchSize rows only once the cursor is
			// exhausted (this includes an empty page), so a short page means
			// there is nothing left to read.
			if len(page) < batchSize {
				return
			}
		}
	}, nil
}
// ReadPage calls next to obtain a result set, decodes every row with decoder,
// and returns the decoded values in order.
//
// size is only a capacity hint for the result slice (Paginate passes its FETCH
// batch size). Negative values are treated as zero instead of panicking in
// make.
//
// The boolean result reports whether at least one row was read. On any error
// the slice is nil and the boolean is false. The error is wrapped with the step
// that failed: fetching, decoding, or iterating. The rows are closed before
// ReadPage returns.
func ReadPage[T any](
	next func() (*sql.Rows, error),
	decoder Decoder[T],
	size int,
) ([]T, bool, error) {
	rows, err := next()
	if err != nil {
		return nil, false, fmt.Errorf("error fetching from cursor: %w", err)
	}
	// The Close error is deliberately ignored. Errors encountered while
	// iterating are reported through rows.Err below.
	defer func() {
		_ = rows.Close()
	}()
	capacity := size
	if capacity < 0 {
		capacity = 0
	}
	results := make([]T, 0, capacity)
	for rows.Next() {
		decoded, err := decoder(rows)
		if err != nil {
			return nil, false, fmt.Errorf("error decoding row: %w", err)
		}
		results = append(results, decoded)
	}
	if err := rows.Err(); err != nil {
		return nil, false, fmt.Errorf("error iterating rows: %w", err)
	}
	return results, len(results) > 0, nil
}
// Entry holds one row of the demo "test" table. Its fields appear in the same
// order in which DecodeEntry scans the row's columns.
type Entry struct {
	ID   int    // "id" column (SERIAL PRIMARY KEY)
	Text string // "text" column
}
// DecodeEntry matches Decoder[Entry]. It scans the current row, which must
// have exactly two columns (id, then text), into a new Entry. If Scan fails,
// it returns the zero Entry together with the wrapped scan error.
func DecodeEntry(rows *sql.Rows) (Entry, error) {
	var (
		id   int
		text string
	)
	if err := rows.Scan(&id, &text); err != nil {
		return Entry{}, fmt.Errorf("error scanning row: %w", err)
	}
	return Entry{ID: id, Text: text}, nil
}
// main seeds the demo table, then pages through it two rows at a time and
// logs every row. Any failure is logged and ends the process with a non-zero
// exit status.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run holds main's logic so that deferred cleanup (closing the database handle)
// executes before the process exits. Calling log.Fatal inside that body would
// skip the defer. run returns the first error encountered.
func run() error {
	ctx := context.Background()
	connStr := "user=dev dbname=generators sslmode=disable"
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("error opening database: %w", err)
	}
	// The Close error is irrelevant at shutdown for this demo.
	defer func() {
		_ = db.Close()
	}()
	if err := SetupDatabase(db); err != nil {
		return err
	}
	query := "SELECT id, text FROM test ORDER BY id"
	batchSize := 2
	pagination, err := Paginate[Entry](ctx, db, query, batchSize, DecodeEntry)
	if err != nil {
		return err
	}
	for row, err := range pagination {
		if err != nil {
			// Returning from the loop body ends the iteration; Paginate then
			// rolls back its transaction before run returns.
			return fmt.Errorf("error fetching rows: %w", err)
		}
		log.Printf("Row: %v", row)
	}
	return nil
}
// SetupDatabase recreates the demo "test" table from scratch and seeds it with
// ten rows whose text is "row 0" through "row 9". Each statement runs
// individually on db under a background context. The first failure is returned,
// wrapped with the step that failed.
func SetupDatabase(db *sql.DB) error {
	ctx := context.Background()

	if _, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS test"); err != nil {
		return fmt.Errorf("error dropping table: %w", err)
	}
	if _, err := db.ExecContext(ctx, "CREATE TABLE test (id SERIAL PRIMARY KEY, text TEXT)"); err != nil {
		return fmt.Errorf("error creating table: %w", err)
	}

	const seedRows = 10
	for n := 0; n < seedRows; n++ {
		label := fmt.Sprintf("row %d", n)
		if _, err := db.ExecContext(ctx, "INSERT INTO test (text) VALUES ($1)", label); err != nil {
			return fmt.Errorf("error inserting row: %w", err)
		}
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment