Skip to content

Instantly share code, notes, and snippets.

@xDarkicex
Created March 16, 2025 07:26
Show Gist options
  • Select an option

  • Save xDarkicex/cc4ae6e3a6817276533baaae86e5edb7 to your computer and use it in GitHub Desktop.

Select an option

Save xDarkicex/cc4ae6e3a6817276533baaae86e5edb7 to your computer and use it in GitHub Desktop.
codebase
// Package nanite provides a lightweight, high-performance HTTP router for Go.
// It is designed to be developer-friendly, inspired by Express.js, and optimized
// for speed and efficiency in routing, grouping, middleware handling, and WebSocket support.
package nanite
import (
"context"
"fmt"
"net"
"net/http"
"slices"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Core types and data structures.

type (
	// HandlerFunc is the signature of an HTTP request handler. It receives the
	// per-request Context and is responsible for producing the response.
	HandlerFunc func(*Context)

	// WebSocketHandler is the signature of a WebSocket handler. It receives the
	// upgraded connection together with the request's Context.
	WebSocketHandler func(*websocket.Conn, *Context)

	// MiddlewareFunc is the signature of a middleware function. It receives the
	// Context and a next callback; calling next passes control to the rest of
	// the chain.
	MiddlewareFunc func(*Context, func())

	// Param is a single route parameter captured from the request path.
	Param struct {
		Key   string // parameter name taken from the route pattern
		Value string // value extracted from the request path
	}
)
// Context carries the state of a single HTTP request/response cycle.
// Instances are taken from and returned to the router's sync.Pool, so a
// handler must not keep a *Context after it returns.
type Context struct {
	// Writer is the response writer for this request.
	Writer http.ResponseWriter
	// Request is the incoming HTTP request.
	Request *http.Request
	// Params holds captured route parameters; only the first ParamsCount
	// entries are meaningful. A fixed-size array avoids a per-request slice.
	Params [5]Param
	// ParamsCount is the number of valid entries in Params.
	ParamsCount int
	// Values is a general-purpose, request-scoped key/value store.
	Values map[string]interface{}
	// ValidationErrs collects validation failures for this request.
	ValidationErrs ValidationErrors
	// lazyFields holds lazily validated fields, keyed by field name.
	lazyFields map[string]*LazyField
	// aborted is set by Abort to stop further processing.
	aborted bool
}
// ValidationErrors aggregates several ValidationError values so they can be
// reported together.
type ValidationErrors []ValidationError

// Error implements the error interface. An empty list yields the bare
// "validation failed"; otherwise each entry is rendered as
// "<Field>: <entry's Error() text>" and the entries are joined with ", "
// after a "validation failed: " prefix.
func (ve ValidationErrors) Error() string {
	if len(ve) == 0 {
		return "validation failed"
	}
	parts := make([]string, len(ve))
	for i, e := range ve {
		parts[i] = fmt.Sprintf("%s: %s", e.Field, e.Error())
	}
	return "validation failed: " + strings.Join(parts, ", ")
}
// Router configuration.

// Config groups the router's tunable behavior.
type Config struct {
	// NotFoundHandler, when set, is used instead of the default 404 response.
	NotFoundHandler HandlerFunc
	// ErrorHandler, when set, receives the recovered value (as an error) when
	// a handler panics before the response has started.
	ErrorHandler func(*Context, error)
	// Upgrader performs the HTTP to WebSocket upgrade.
	Upgrader *websocket.Upgrader
	// WebSocket holds WebSocket connection settings.
	WebSocket *WebSocketConfig
}
// WebSocketConfig holds WebSocket connection limits and timings.
type WebSocketConfig struct {
	// ReadTimeout bounds how long to wait when reading a message.
	ReadTimeout time.Duration
	// WriteTimeout bounds how long a single write may take.
	WriteTimeout time.Duration
	// PingInterval is the delay between pings.
	PingInterval time.Duration
	// MaxMessageSize is the largest accepted message, in bytes.
	MaxMessageSize int64
	// BufferSize is the read/write buffer size given to the upgrader.
	BufferSize int
}
// Static route structure.

// staticRoute is an entry in the exact-match fast path: the pre-wrapped
// handler for one method+path pair plus the parameter list returned with it.
type staticRoute struct {
	// handler is the route handler with its middleware wrapping applied.
	handler HandlerFunc
	// params is returned as the captured parameters; static paths have none.
	params []Param
}
// Radix tree structure.

// RadixNode is one node of a per-method routing tree. Static continuations
// live in children (keyed by the first byte of their prefix); a node may also
// have one parameter child (":name") and one wildcard child ("*name").
type RadixNode struct {
	prefix        string              // path fragment matched by this node
	handler       HandlerFunc         // non-nil when a route terminates here
	children      map[byte]*RadixNode // static children keyed by prefix[0]
	paramChild    *RadixNode          // child matching a ":param" segment
	wildcardChild *RadixNode          // child matching a "*wildcard" remainder
	paramName     string              // parameter name (parameter nodes only)
	wildcardName  string              // wildcard name (wildcard nodes only)
}
// Router structure.

// Router dispatches HTTP and WebSocket requests to registered handlers.
// Construct it with New: the zero value has nil route maps and cannot
// register routes.
type Router struct {
	trees        map[string]*RadixNode             // per-method radix trees
	pool         sync.Pool                         // recycles *Context values between requests
	mutex        sync.RWMutex                      // guards middleware and the route tables
	middleware   []MiddlewareFunc                  // global middleware in registration order
	config       *Config                           // router-wide configuration
	httpClient   *http.Client                      // shared outbound HTTP client
	staticRoutes map[string]map[string]staticRoute // method -> exact path -> route
}
// Router initialization.

// New returns a Router with empty route tables and a pooled Context
// allocator. It also sets default WebSocket settings: 60s read timeout, 10s
// write timeout, 30s ping interval, 1 MiB max message and 4 KiB buffers. An
// outbound HTTP client with a 30s timeout is created as well.
func New() *Router {
	wsDefaults := &WebSocketConfig{
		ReadTimeout:    60 * time.Second,
		WriteTimeout:   10 * time.Second,
		PingInterval:   30 * time.Second,
		MaxMessageSize: 1024 * 1024, // 1MB
		BufferSize:     4096,
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	transport := &http.Transport{
		MaxIdleConns:        1000,
		MaxIdleConnsPerHost: 100,
		IdleConnTimeout:     120 * time.Second,
		DisableCompression:  false,
		ForceAttemptHTTP2:   false,
		DialContext:         dialer.DialContext,
	}

	r := &Router{
		trees:        make(map[string]*RadixNode),
		staticRoutes: make(map[string]map[string]staticRoute),
		config:       &Config{WebSocket: wsDefaults},
		httpClient:   &http.Client{Transport: transport, Timeout: 30 * time.Second},
	}

	// Pooled Contexts come with their maps pre-allocated.
	r.pool.New = func() interface{} {
		return &Context{
			Params:     [5]Param{},
			Values:     make(map[string]interface{}, 8),
			lazyFields: make(map[string]*LazyField),
			aborted:    false,
		}
	}

	// NOTE(review): CheckOrigin accepts every origin, which disables the
	// cross-origin check for WebSocket upgrades. Consider restricting it for
	// production deployments.
	r.config.Upgrader = &websocket.Upgrader{
		CheckOrigin:     func(*http.Request) bool { return true },
		ReadBufferSize:  wsDefaults.BufferSize,
		WriteBufferSize: wsDefaults.BufferSize,
	}
	return r
}
// Middleware support.

// Use appends one or more middleware functions to the router's global
// middleware stack, preserving their order. It is safe to call concurrently
// with request handling because the stack is guarded by the router's mutex.
func (r *Router) Use(middleware ...MiddlewareFunc) {
	r.mutex.Lock()
	r.middleware = append(r.middleware, middleware...)
	r.mutex.Unlock()
}
// Route registration.

// Get registers handler for GET requests matching path.
func (r *Router) Get(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodGet, path, handler, middleware...)
}

// Post registers handler for POST requests matching path.
func (r *Router) Post(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodPost, path, handler, middleware...)
}

// Put registers handler for PUT requests matching path.
func (r *Router) Put(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodPut, path, handler, middleware...)
}

// Delete registers handler for DELETE requests matching path.
func (r *Router) Delete(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodDelete, path, handler, middleware...)
}

// Patch registers handler for PATCH requests matching path.
func (r *Router) Patch(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodPatch, path, handler, middleware...)
}

// Options registers handler for OPTIONS requests matching path.
func (r *Router) Options(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodOptions, path, handler, middleware...)
}

// Head registers handler for HEAD requests matching path.
func (r *Router) Head(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.Handle(http.MethodHead, path, handler, middleware...)
}

// Handle registers handler for an arbitrary HTTP method and path. Middleware
// passed here applies only to this route.
func (r *Router) Handle(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.addRoute(method, path, handler, middleware...)
}
// Server start methods.

// Start serves plain HTTP on ":"+port with the router as handler.
// The server uses 5s read/write timeouts, a 60s idle timeout and a 1 MiB
// header limit. Each new TCP connection gets its OS socket buffers raised
// to 64 KiB. Start blocks and returns the error from ListenAndServe.
func (r *Router) Start(port string) error {
	tuneSocket := func(conn net.Conn, state http.ConnState) {
		if state != http.StateNew {
			return
		}
		if tcp, ok := conn.(*net.TCPConn); ok {
			// Best effort: a failure to resize buffers is not fatal.
			_ = tcp.SetReadBuffer(65536)
			_ = tcp.SetWriteBuffer(65536)
		}
	}
	srv := &http.Server{
		Addr:           ":" + port,
		Handler:        r,
		ReadTimeout:    5 * time.Second,
		WriteTimeout:   5 * time.Second,
		IdleTimeout:    60 * time.Second,
		MaxHeaderBytes: 1 << 20, // 1MB
		ConnState:      tuneSocket,
	}
	fmt.Printf("Nanite server running on port %s\n", port)
	return srv.ListenAndServe()
}
// StartTLS serves HTTPS on ":"+port using certFile and keyFile. It uses the
// same timeouts and header limit as Start, but without Start's socket-buffer
// tuning. It blocks and returns the error from ListenAndServeTLS.
func (r *Router) StartTLS(port, certFile, keyFile string) error {
	addr := ":" + port
	srv := &http.Server{
		Addr:           addr,
		Handler:        r,
		ReadTimeout:    5 * time.Second,
		WriteTimeout:   5 * time.Second,
		IdleTimeout:    60 * time.Second,
		MaxHeaderBytes: 1 << 20, // 1MB
	}
	fmt.Printf("Nanite server running on port %s with TLS\n", port)
	return srv.ListenAndServeTLS(certFile, keyFile)
}
// WebSocket support.

// WebSocket registers a WebSocket endpoint at path. The handler is adapted
// by wrapWebSocketHandler and registered as an ordinary GET route. Any
// middleware given here therefore wraps the adapted handler.
func (r *Router) WebSocket(path string, handler WebSocketHandler, middleware ...MiddlewareFunc) {
	adapted := r.wrapWebSocketHandler(handler)
	r.addRoute(http.MethodGet, path, adapted, middleware...)
}
// Static file serving.

// ServeStatic serves files from the directory root for GET and HEAD requests
// under prefix. A leading "/" is added to prefix if missing. The prefix is
// stripped from the request path before it is resolved against root.
func (r *Router) ServeStatic(prefix, root string) {
	if !strings.HasPrefix(prefix, "/") {
		prefix = "/" + prefix
	}
	// Build the file server and prefix stripper once. Neither holds
	// per-request state, so one instance can serve every request instead of
	// allocating a new StripPrefix wrapper on each call.
	fileHandler := http.StripPrefix(prefix, http.FileServer(http.Dir(root)))
	handler := func(c *Context) {
		fileHandler.ServeHTTP(c.Writer, c.Request)
	}
	r.addRoute(http.MethodGet, prefix+"/*", handler)
	r.addRoute(http.MethodHead, prefix+"/*", handler)
}
// Helper functions.

// addRoute registers handler for method and path.
//
// Only the route-specific middleware is wrapped around handler here. Global
// middleware (added with Use) is applied per request by ServeHTTP, which
// passes a snapshot of r.middleware to executeMiddlewareChain. Baking the
// global stack in here as well made every global middleware run twice for
// routes registered after Use. It also appended onto r.middleware's backing
// array.
//
// Paths are normalized to begin with "/", because request paths always do.
// This lets the exact-match table and the radix tree agree with what
// ServeHTTP looks up. Paths without ":" or "*" are also stored in the
// staticRoutes fast path. Every route is inserted into the method's tree.
func (r *Router) addRoute(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if path == "" || path[0] != '/' {
		path = "/" + path
	}

	// Wrap from the innermost middleware outwards so they run in the order
	// given. Each wrapper stops early once the context has been aborted.
	wrapped := handler
	for i := len(middleware) - 1; i >= 0; i-- {
		mw := middleware[i]
		next := wrapped
		wrapped = func(c *Context) {
			if c.IsAborted() {
				return
			}
			mw(c, func() {
				if !c.IsAborted() {
					next(c)
				}
			})
		}
	}

	if !strings.ContainsAny(path, ":*") {
		routes, ok := r.staticRoutes[method]
		if !ok {
			routes = make(map[string]staticRoute)
			r.staticRoutes[method] = routes
		}
		routes[path] = staticRoute{handler: wrapped, params: []Param{}}
	}

	root, ok := r.trees[method]
	if !ok {
		root = &RadixNode{children: make(map[byte]*RadixNode)}
		r.trees[method] = root
	}
	if path == "/" {
		root.handler = wrapped
		return
	}
	root.insertRoute(path[1:], wrapped) // tree paths omit the leading slash
}
// findHandlerAndMiddleware resolves method+path to a handler and its captured
// route parameters. Exact-match static routes are consulted first; otherwise
// the method's radix tree is searched. It returns (nil, nil) when the method
// has no routes at all.
func (r *Router) findHandlerAndMiddleware(method, path string) (HandlerFunc, []Param) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	// Indexing a missing inner map yields the zero value, so one expression
	// covers both "unknown method" and "unknown path".
	if route, found := r.staticRoutes[method][path]; found {
		return route.handler, route.params
	}

	tree, ok := r.trees[method]
	if !ok {
		return nil, nil
	}
	// Tree paths are inserted without their leading slash (see addRoute).
	return tree.findRoute(strings.TrimPrefix(path, "/"), make([]Param, 0, 5))
}
// ServeHTTP implements http.Handler. It takes a Context from the pool,
// resolves the route, and runs the global middleware plus the route handler.
// Panics, cancellations and aborted-but-unanswered requests are turned into
// error responses.
//
// Cancellation is checked on the serving goroutine only: before dispatch and
// after the handler returns. It used to be watched from a separate goroutine
// that called ctx.Abort and wrote to the response while the handler was
// still running. That was a data race on both the Context and the
// http.ResponseWriter. Because the watcher could fire after the Context was
// back in the pool, it could also abort an unrelated request. Handlers that
// must stop early should watch c.Request.Context() themselves.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	trackedWriter := WrapResponseWriter(w)
	bufferedWriter := newBufferedResponseWriter(trackedWriter, 4096)
	// The single deferred Close flushes the buffer on every exit path.
	defer bufferedWriter.Close()

	ctx := r.pool.Get().(*Context)
	ctx.Writer = bufferedWriter
	ctx.Request = req
	ctx.ParamsCount = 0
	ctx.ClearValues()
	ctx.ClearLazyFields()
	ctx.ValidationErrs = nil
	ctx.aborted = false
	defer func() {
		ctx.CleanupPooledResources()
		r.pool.Put(ctx)
	}()

	reqCtx := req.Context()

	// respondCanceled reports a canceled (499) or timed-out (504) request,
	// unless part of the response has already been sent.
	respondCanceled := func() {
		if trackedWriter.Written() {
			return
		}
		statusCode := http.StatusGatewayTimeout
		if reqCtx.Err() == context.Canceled {
			statusCode = 499 // Client closed request
		}
		http.Error(trackedWriter, fmt.Sprintf("Request %v", reqCtx.Err()), statusCode)
	}
	respondNotFound := func() {
		if r.config.NotFoundHandler != nil {
			r.config.NotFoundHandler(ctx)
		} else {
			http.NotFound(trackedWriter, req)
		}
	}

	// Don't start work for a request that is already canceled or expired.
	if reqCtx.Err() != nil {
		ctx.Abort()
		respondCanceled()
		return
	}

	handler, params := r.findHandlerAndMiddleware(req.Method, req.URL.Path)
	if handler == nil {
		respondNotFound()
		return
	}

	// Params is a fixed-size array. copy stops at its length, so a route with
	// more parameters cannot push ParamsCount past the array bounds (which
	// would make GetParam index out of range).
	ctx.ParamsCount = copy(ctx.Params[:], params)

	// Capture panics from middleware and handlers.
	defer func() {
		err := recover()
		if err == nil {
			return
		}
		ctx.Abort()
		if trackedWriter.Written() {
			fmt.Printf("Panic occurred after response started: %v\n", err)
			return
		}
		if r.config.ErrorHandler != nil {
			r.config.ErrorHandler(ctx, fmt.Errorf("%v", err))
		} else {
			http.Error(trackedWriter, "Internal Server Error", http.StatusInternalServerError)
		}
	}()

	r.mutex.RLock()
	allMiddleware := slices.Clone(r.middleware)
	r.mutex.RUnlock()

	executeMiddlewareChain(ctx, handler, allMiddleware)

	switch {
	case trackedWriter.Written():
		// A response has started; nothing more to add.
	case reqCtx.Err() != nil:
		ctx.Abort()
		respondCanceled()
	case ctx.IsAborted():
		respondNotFound()
	}
}
// Helper types and functions.

// longestCommonPrefix returns the length, in bytes, of the longest common
// prefix of a and b. The comparison is bytewise, which matches how the radix
// tree indexes its children.
func longestCommonPrefix(a, b string) int {
	// Use the built-in min instead of a local named max, which shadowed the
	// predeclared identifier.
	n := min(len(a), len(b))
	for i := 0; i < n; i++ {
		if a[i] != b[i] {
			return i
		}
	}
	return n
}
// findRoute matches path (without its leading slash) against the subtree
// rooted at n, appending captured parameters to params. Candidates are tried
// in priority order: static child, then parameter child, then wildcard. A
// branch that fails falls back to the next candidate. The returned handler is
// nil when nothing matches.
//
// NOTE(review): static prefixes are matched with strings.HasPrefix and nodes
// may be split mid-segment, so a parameter can begin mid-segment. With only
// "/user/:id" registered, "/userx" matches with id="x". insertRoute also
// places parameters mid-segment for patterns such as "/file:name", so
// tightening this needs a matching change on the insert side.
func (n *RadixNode) findRoute(path string, params []Param) (HandlerFunc, []Param) {
	if path == "" {
		return n.handler, params
	}

	// 1. Static child sharing the first byte.
	if child, ok := n.children[path[0]]; ok && strings.HasPrefix(path, child.prefix) {
		rest := strings.TrimPrefix(path[len(child.prefix):], "/")
		if handler, found := child.findRoute(rest, params); handler != nil {
			return handler, found
		}
	}

	// 2. Parameter child: consumes one path segment.
	if pc := n.paramChild; pc != nil {
		value, rest, _ := strings.Cut(path, "/")
		withParam := append(params, Param{Key: pc.paramName, Value: value})
		if rest == "" {
			// The path ends at this parameter. Accept it only if a route
			// actually terminates here. Otherwise fall through so a sibling
			// wildcard still gets a chance; previously a nil handler was
			// returned directly and the wildcard was never tried.
			if pc.handler != nil {
				return pc.handler, withParam
			}
		} else if handler, found := pc.findRoute(rest, withParam); handler != nil {
			return handler, found
		}
	}

	// 3. Wildcard child: captures the whole remaining path.
	if wc := n.wildcardChild; wc != nil {
		key := wc.wildcardName
		if key == "" {
			// An unnamed "*" used to be stored under the empty key. Use "*"
			// so the capture can be looked up with GetParam("*").
			key = "*"
		}
		return wc.handler, append(params, Param{Key: key, Value: path})
	}
	return nil, nil
}
// insertRoute adds handler to the subtree rooted at n for path. addRoute
// passes the route path without its leading "/". Recursive calls may pass a
// path that starts with "/"; that produces an empty segment, which is skipped.
//
// The path is consumed one piece at a time:
//   - ":name..." creates the node's parameter child if it does not exist yet,
//     then continues with whatever follows the next "/".
//   - "*..." installs a wildcard child (replacing any existing one) and stops.
//     Everything after the "*", slashes included, becomes the wildcard name.
//   - Otherwise the leading static run (up to the first "/", ":" or "*") is
//     merged into the static children, which are keyed by their first byte.
//     A child whose prefix only partially overlaps is split at the common
//     prefix.
//
// Because empty segments are skipped, "/a/" and "/a" end on the same node. A
// static run ending at ":" (e.g. "file:name") places the parameter child
// mid-segment.
//
// NOTE(review): when a parameter child already exists, its original name is
// kept and the name in path is ignored. So "/u/:id" and "/u/:uid/x" both
// capture under "id". Registering the same path twice silently replaces the
// earlier handler.
func (n *RadixNode) insertRoute(path string, handler HandlerFunc) {
// Path fully consumed: the route terminates at n.
if path == "" {
n.handler = handler
return
}
// Parameter segment, e.g. ":id" or ":id/rest".
if path[0] == ':' {
// The name runs up to the next "/"; remainingPath keeps that "/".
paramEnd := strings.IndexByte(path, '/')
var paramName, remainingPath string
if paramEnd == -1 {
paramName = path[1:]
remainingPath = ""
} else {
paramName = path[1:paramEnd]
remainingPath = path[paramEnd:]
}
// Create the parameter child only if missing; an existing one keeps its
// original name (see the NOTE in the doc comment).
if n.paramChild == nil {
n.paramChild = &RadixNode{
prefix: ":" + paramName,
paramName: paramName,
children: make(map[byte]*RadixNode),
}
}
// Continue with remaining path (its leading "/" becomes an empty segment
// that the recursive call skips).
if remainingPath == "" {
n.paramChild.handler = handler
} else {
n.paramChild.insertRoute(remainingPath, handler)
}
return
}
// Wildcard: replaces any existing wildcard child and ends insertion.
if path[0] == '*' {
n.wildcardChild = &RadixNode{
prefix: path,
handler: handler,
wildcardName: path[1:],
children: make(map[byte]*RadixNode),
}
return
}
// Scan to the end of the leading static run (stops at '/', ':' or '*').
var i int
for i = 0; i < len(path); i++ {
if path[i] == '/' || path[i] == ':' || path[i] == '*' {
break
}
}
// Extract the current segment
segment := path[:i]
remainingPath := ""
if i < len(path) {
remainingPath = path[i:]
}
// An empty segment means path starts with "/" (':' and '*' were handled
// above). Skip it, since segment[0] below would otherwise panic.
if len(segment) == 0 {
// Skip empty segments and continue with remaining path (the two
// conditions below are equivalent).
if remainingPath != "" && len(remainingPath) > 0 {
// If remainingPath starts with a slash, skip it
if remainingPath[0] == '/' {
remainingPath = remainingPath[1:]
}
n.insertRoute(remainingPath, handler)
return
}
// If no remaining path, set handler on current node
n.handler = handler
return
}
// Static children are keyed by the first byte of their prefix.
c, exists := n.children[segment[0]]
if !exists {
// No child shares the first byte: add a fresh leaf for the segment.
c = &RadixNode{
prefix: segment,
children: make(map[byte]*RadixNode),
}
n.children[segment[0]] = c
// Set handler or continue with remaining path
if remainingPath == "" {
c.handler = handler
} else {
c.insertRoute(remainingPath, handler)
}
return
}
// Find common prefix length (at least 1, since the first bytes match).
commonPrefixLen := longestCommonPrefix(c.prefix, segment)
if commonPrefixLen == len(c.prefix) {
// Child prefix is completely contained in this segment
if commonPrefixLen == len(segment) {
// Exact match, continue with remaining path
if remainingPath == "" {
c.handler = handler
} else {
c.insertRoute(remainingPath, handler)
}
} else {
// Current segment extends beyond child prefix: recurse with the
// unmatched tail of the segment plus the rest of the path.
c.insertRoute(segment[commonPrefixLen:]+remainingPath, handler)
}
} else {
// Partial overlap: split c at the common prefix. Move c's existing
// suffix, handler and children into a new node below it.
child := &RadixNode{
prefix: c.prefix[commonPrefixLen:],
handler: c.handler,
children: c.children,
paramChild: c.paramChild,
wildcardChild: c.wildcardChild,
paramName: c.paramName,
wildcardName: c.wildcardName,
}
// Reset c so it represents only the shared prefix.
c.prefix = c.prefix[:commonPrefixLen]
c.handler = nil
c.children = make(map[byte]*RadixNode)
c.paramChild = nil
c.wildcardChild = nil
c.paramName = ""
c.wildcardName = ""
// Add the split node as a child (its prefix is non-empty because the
// overlap was partial).
c.children[child.prefix[0]] = child
// Handle current path
if commonPrefixLen == len(segment) {
// Current segment matches prefix exactly
if remainingPath == "" {
c.handler = handler
} else {
c.insertRoute(remainingPath, handler)
}
} else {
// Current segment extends beyond common prefix: add a sibling of
// the split node for the new route's tail.
newChild := &RadixNode{
prefix: segment[commonPrefixLen:],
children: make(map[byte]*RadixNode),
}
if remainingPath == "" {
newChild.handler = handler
} else {
newChild.insertRoute(remainingPath, handler)
}
c.children[newChild.prefix[0]] = newChild
}
}
}
package nanite
import (
"encoding/json"
"fmt"
"mime/multipart"
"net/http"
"slices"
)
// Context methods.

// Set stores value under key in the request-scoped value map. The map is
// allocated first if it is nil: writing to a nil map panics, whereas Get
// already tolerates a nil map.
func (c *Context) Set(key string, value interface{}) {
	if c.Values == nil {
		c.Values = make(map[string]interface{}, 8)
	}
	c.Values[key] = value
}
// Get returns the value stored under key. It returns nil if the key is
// absent or no values were ever stored. Reading a nil map is safe in Go and
// yields the zero value, so no explicit nil check is required.
func (c *Context) Get(key string) interface{} {
	return c.Values[key]
}
// Bind decodes the JSON request body into v, which must be a non-nil
// pointer. Decoding errors are wrapped with a "failed to decode JSON" prefix.
func (c *Context) Bind(v interface{}) error {
	dec := json.NewDecoder(c.Request.Body)
	err := dec.Decode(v)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to decode JSON: %w", err)
}
// FormValue returns the first value for the named form field. It delegates
// to http.Request.FormValue, which parses the form on demand and also
// consults the URL query.
func (c *Context) FormValue(key string) string {
	req := c.Request
	return req.FormValue(key)
}

// Query returns the first value of the named URL query parameter, or "" if
// it is not present.
func (c *Context) Query(key string) string {
	values := c.Request.URL.Query()
	return values.Get(key)
}
// GetParam looks up a captured route parameter by name and reports whether it
// was found. Wildcard captures are stored in Params alongside named
// parameters, so they are found the same way.
//
// Only the first ParamsCount entries are searched. The loop is also bounded
// by len(c.Params), so a ParamsCount larger than the fixed array cannot cause
// an index-out-of-range panic.
func (c *Context) GetParam(key string) (string, bool) {
	for i := 0; i < c.ParamsCount && i < len(c.Params); i++ {
		if c.Params[i].Key == key {
			return c.Params[i].Value, true
		}
	}
	return "", false
}
// MustParam returns the named route parameter, or an error when it is missing
// or empty. Despite the name it never panics.
func (c *Context) MustParam(key string) (string, error) {
	val, ok := c.GetParam(key)
	if !ok || val == "" {
		return "", fmt.Errorf("required parameter %s missing or empty", key)
	}
	return val, nil
}
// File returns the header of the first file uploaded under the multipart form
// field key. The form is parsed first if needed, keeping up to 32 MiB in
// memory. Only the header is returned; call its Open method to read the
// content.
func (c *Context) File(key string) (*multipart.FileHeader, error) {
	if c.Request.MultipartForm == nil {
		if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
			return nil, fmt.Errorf("failed to parse multipart form: %w", err)
		}
	}
	f, fh, err := c.Request.FormFile(key)
	if err != nil {
		return nil, fmt.Errorf("failed to get file %s: %w", key, err)
	}
	// FormFile opens the upload, but only the header is handed back, so close
	// the handle here instead of leaking it. Large uploads are backed by a
	// temporary file. A close error is ignored: nothing was read from it.
	_ = f.Close()
	return fh, nil
}
// JSON encodes data as JSON and sends it with the given status code and a
// Content-Type of application/json.
//
// The payload is encoded into the buffer obtained from getJSONEncoder before
// anything is written. An encoding failure therefore still produces a clean
// 500 response. Previously the status was sent first, so the 500 became a
// superfluous WriteHeader and its text was appended to a response already
// labelled as JSON.
func (c *Context) JSON(status int, data interface{}) {
	pair := getJSONEncoder()
	defer putJSONEncoder(pair)
	if err := pair.encoder.Encode(data); err != nil {
		http.Error(c.Writer, "Failed to encode JSON", http.StatusInternalServerError)
		return
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(status)
	c.Writer.Write(pair.buffer.Bytes())
}
// String writes data as a text/plain response with the given status code.
func (c *Context) String(status int, data string) {
	w := c.Writer
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(status)
	w.Write([]byte(data))
}

// HTML writes html as a UTF-8 text/html response with the given status code.
func (c *Context) HTML(status int, html string) {
	w := c.Writer
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(status)
	w.Write([]byte(html))
}
// SetHeader sets the response header key to value, replacing any existing
// values. It only takes effect if called before the status or body is
// written.
func (c *Context) SetHeader(key, value string) {
	h := c.Writer.Header()
	h.Set(key, value)
}

// Status writes the response status code. Headers changed after this call
// are not sent.
func (c *Context) Status(status int) {
	w := c.Writer
	w.WriteHeader(status)
}
// Redirect sends a body-less redirect to url with the given 3xx status. Any
// status outside 300-399 is rejected with a 400 text response instead.
func (c *Context) Redirect(status int, url string) {
	is3xx := status >= 300 && status <= 399
	if !is3xx {
		c.String(http.StatusBadRequest, "redirect status must be 3xx")
		return
	}
	h := c.Writer.Header()
	h.Set("Location", url)
	c.Writer.WriteHeader(status)
}
// Cookie sets a cookie on the response. options is a flat list of name/value
// pairs, for example:
//
//	c.Cookie("sid", v, "Path", "/", "MaxAge", 3600, "HttpOnly", true)
//
// Supported option names and value types:
//
//	"MaxAge"   int
//	"Path"     string
//	"Domain"   string
//	"Secure"   bool
//	"HttpOnly" bool
//	"SameSite" http.SameSite
//
// Unknown names, non-string names, values of the wrong type and a trailing
// unpaired element are ignored. MaxAge and Path behave exactly as before;
// the remaining names extend the previously supported set.
func (c *Context) Cookie(name, value string, options ...interface{}) {
	cookie := &http.Cookie{Name: name, Value: value}
	for i := 0; i+1 < len(options); i += 2 {
		key, ok := options[i].(string)
		if !ok {
			continue
		}
		val := options[i+1]
		switch key {
		case "MaxAge":
			if v, ok := val.(int); ok {
				cookie.MaxAge = v
			}
		case "Path":
			if v, ok := val.(string); ok {
				cookie.Path = v
			}
		case "Domain":
			if v, ok := val.(string); ok {
				cookie.Domain = v
			}
		case "Secure":
			if v, ok := val.(bool); ok {
				cookie.Secure = v
			}
		case "HttpOnly":
			if v, ok := val.(bool); ok {
				cookie.HttpOnly = v
			}
		case "SameSite":
			if v, ok := val.(http.SameSite); ok {
				cookie.SameSite = v
			}
		}
	}
	http.SetCookie(c.Writer, cookie)
}
// Abort flags the request as aborted. The router's middleware wrappers check
// this flag and skip the rest of the chain. Abort itself writes no response.
func (c *Context) Abort() { c.aborted = true }

// IsAborted reports whether Abort has been called for this request.
func (c *Context) IsAborted() bool { return c.aborted }

// ClearValues empties the Values map in place, keeping its storage so the
// pooled Context can reuse it. Clearing a nil map is a no-op.
func (c *Context) ClearValues() {
	clear(c.Values)
}
// CheckValidation runs every registered lazy-field validator and reports
// whether the request is valid. If any validation errors are recorded, from
// this run or added earlier, it responds with 400 and {"errors": [...]} and
// returns false.
func (c *Context) CheckValidation() bool {
	allValid := c.ValidateAllFields()
	if len(c.ValidationErrs) == 0 {
		return allValid
	}
	c.JSON(http.StatusBadRequest, map[string]interface{}{
		"errors": c.ValidationErrs,
	})
	return false
}
// CleanupPooledResources releases request-scoped pooled objects before the
// Context itself is recycled:
//   - every map[string]interface{} value in Values goes to putMap, and Values
//     is emptied;
//   - lazy fields are released via ClearLazyFields;
//   - ValidationErrs, if set, goes to putValidationErrors and is reset to nil.
//
// NOTE(review): every map[string]interface{} in Values is handed to the map
// pool, including maps the handler did not take from that pool and may still
// reference elsewhere. Confirm this is intended.
func (c *Context) CleanupPooledResources() {
	for _, val := range c.Values {
		if m, isMap := val.(map[string]interface{}); isMap {
			putMap(m)
		}
	}
	clear(c.Values)

	c.ClearLazyFields()

	if errs := c.ValidationErrs; errs != nil {
		putValidationErrors(errs)
		c.ValidationErrs = nil
	}
}
// LazyField is a request field whose validation is deferred until its value
// is first requested through Value.
type LazyField struct {
	// name is the field's name, e.g. "username".
	name string
	// getValue fetches the raw value from the request on demand.
	getValue func() string
	// rules are applied in order; the first failure stops validation.
	rules []ValidatorFunc
	// validated records whether the rules have already run.
	validated bool
	// value caches the raw value obtained from getValue.
	value string
	// err caches the first rule failure, or nil if all rules passed.
	err *ValidationError
}
// Value runs the field's rules once, on the first call, and caches the
// outcome. It returns the raw value together with the first rule failure,
// or nil if every rule passed. The raw value is returned even when
// validation fails.
func (lf *LazyField) Value() (string, *ValidationError) {
	if lf.validated {
		return lf.value, lf.err
	}
	raw := lf.getValue()
	lf.value = raw
	for _, rule := range lf.rules {
		if failure := rule(raw); failure != nil {
			lf.err = failure
			break
		}
	}
	lf.validated = true
	return lf.value, lf.err
}
// Field returns the LazyField for name, creating it via getLazyField and
// caching it on first use. Its value is resolved lazily from these sources,
// in order:
//   1. the URL query parameter, if non-empty;
//   2. the route parameter, if present, even when empty;
//   3. the "formData" map in Values, formatted with %v;
//   4. the "body" map in Values, formatted with %v.
// If none of them has the field, the value is "".
func (c *Context) Field(name string) *LazyField {
	if c.lazyFields == nil {
		c.lazyFields = make(map[string]*LazyField)
	}
	if cached, ok := c.lazyFields[name]; ok {
		return cached
	}

	lookup := func() string {
		if v := c.Request.URL.Query().Get(name); v != "" {
			return v
		}
		if v, ok := c.GetParam(name); ok {
			return v
		}
		for _, source := range [...]string{"formData", "body"} {
			m, ok := c.Values[source].(map[string]interface{})
			if !ok {
				continue
			}
			if v, ok := m[name]; ok {
				return fmt.Sprintf("%v", v)
			}
		}
		return ""
	}

	field := getLazyField(name, lookup)
	c.lazyFields[name] = field
	return field
}
// ValidateAllFields validates every lazy field registered on the context.
// For each field that fails, it appends a copy of the ValidationError with
// Field set to the field's name to c.ValidationErrs. It returns true when no
// field failed.
//
// Fields are visited in sorted name order, so the reported errors come out
// in a deterministic order. Ranging over the map directly produced a
// different order from run to run, for example in API responses.
func (c *Context) ValidateAllFields() bool {
	if len(c.lazyFields) == 0 {
		return true
	}

	names := make([]string, 0, len(c.lazyFields))
	for name := range c.lazyFields {
		names = append(names, name)
	}
	slices.Sort(names)

	hasErrors := false
	for _, name := range names {
		_, err := c.lazyFields[name].Value()
		if err == nil {
			continue
		}
		if c.ValidationErrs == nil {
			c.ValidationErrs = getValidationErrors()
			c.ValidationErrs = slices.Grow(c.ValidationErrs, len(c.lazyFields))
		}
		// Copy so the error cached on the LazyField is not mutated.
		errorCopy := *err
		errorCopy.Field = name
		c.ValidationErrs = append(c.ValidationErrs, errorCopy)
		hasErrors = true
	}
	return !hasErrors
}
// ClearLazyFields returns every lazy field to its pool and empties the map in
// place, keeping the map's storage for the next request.
func (c *Context) ClearLazyFields() {
	for _, field := range c.lazyFields {
		putLazyField(field)
	}
	clear(c.lazyFields)
}
// Package nanite provides a lightweight, high-performance HTTP router for Go.
package nanite
import (
"bytes"
"encoding/json"
"io"
"net/http"
"strings"
)
// Group is a set of routes sharing a path prefix and a middleware list. It
// only records configuration; routes are registered on the parent router.
type Group struct {
	// router is the router that routes are registered on.
	router *Router
	// prefix is prepended to every route path in the group.
	prefix string
	// middleware runs before route-specific middleware for every route.
	middleware []MiddlewareFunc
}
// Group returns a route group rooted at prefix. Routes registered on it get
// prefix prepended to their paths. Its middleware runs before any
// route-specific middleware.
//
// The middleware list is copied and its capacity clipped to its length, for
// two reasons. First, later changes to the caller's slice (when called as
// Group(p, mws...)) cannot leak into the group. Second, appending onto the
// group's list always allocates a new array instead of writing into shared
// spare capacity that routes or sub-groups built from it could observe.
//
// Parameters:
//   - prefix: the path prefix for all routes in this group
//   - middleware: optional middleware applied to all routes in this group
//
// Returns:
//   - *Group: a new route group instance
func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Group {
	mws := append([]MiddlewareFunc(nil), middleware...)
	return &Group{
		router:     r,
		prefix:     prefix,
		middleware: mws[:len(mws):len(mws)],
	}
}
// Get registers a handler for GET requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Get(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Get(fullPath, handler, allMiddleware...)
}
// Post registers a handler for POST requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Post(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Post(fullPath, handler, allMiddleware...)
}
// Put registers a handler for PUT requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Put(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Put(fullPath, handler, allMiddleware...)
}
// Delete registers a handler for DELETE requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Delete(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Delete(fullPath, handler, allMiddleware...)
}
// Patch registers a handler for PATCH requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Patch(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Patch(fullPath, handler, allMiddleware...)
}
// Options registers a handler for OPTIONS requests on the group's path prefix.
// The path is normalized and combined with the group's prefix.
//
// Parameters:
// - path: The route path, relative to the group's prefix
// - handler: The handler function to execute for matching requests
// - middleware: Optional route-specific middleware functions
func (g *Group) Options(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
fullPath := normalizePath(g.prefix + path)
allMiddleware := append(g.middleware, middleware...)
g.router.Options(fullPath, handler, allMiddleware...)
}
// Head registers a handler for HEAD requests on the group's path prefix.
// The path is appended to the group's prefix and normalized. The group's
// middleware is placed ahead of the route-specific middleware in the list
// passed to the router.
//
// Parameters:
//   - path: The route path, relative to the group's prefix
//   - handler: The handler function to execute for matching requests
//   - middleware: Optional route-specific middleware functions
func (g *Group) Head(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	fullPath := normalizePath(g.prefix + path)
	// Full slice expression forces append to copy, so routes never share (and
	// overwrite) the spare capacity of the group's middleware slice.
	allMiddleware := append(g.middleware[:len(g.middleware):len(g.middleware)], middleware...)
	g.router.Head(fullPath, handler, allMiddleware...)
}
// Handle registers a handler for the specified HTTP method on the group's path prefix.
// The path is appended to the group's prefix and normalized. The group's
// middleware is placed ahead of the route-specific middleware in the list
// passed to the router.
//
// Parameters:
//   - method: The HTTP method (GET, POST, PUT, etc.)
//   - path: The route path, relative to the group's prefix
//   - handler: The handler function to execute for matching requests
//   - middleware: Optional route-specific middleware functions
func (g *Group) Handle(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
	fullPath := normalizePath(g.prefix + path)
	// Full slice expression forces append to copy, so routes never share (and
	// overwrite) the spare capacity of the group's middleware slice.
	allMiddleware := append(g.middleware[:len(g.middleware):len(g.middleware)], middleware...)
	g.router.Handle(method, fullPath, handler, allMiddleware...)
}
// WebSocket registers a WebSocket handler on the group's path prefix.
// The path is appended to the group's prefix and normalized. The group's
// middleware is placed ahead of the route-specific middleware in the list
// passed to the router.
//
// Parameters:
//   - path: The route path, relative to the group's prefix
//   - handler: The WebSocket handler function to execute for matching requests
//   - middleware: Optional route-specific middleware functions
func (g *Group) WebSocket(path string, handler WebSocketHandler, middleware ...MiddlewareFunc) {
	fullPath := normalizePath(g.prefix + path)
	// Full slice expression forces append to copy, so routes never share (and
	// overwrite) the spare capacity of the group's middleware slice.
	allMiddleware := append(g.middleware[:len(g.middleware):len(g.middleware)], middleware...)
	g.router.WebSocket(fullPath, handler, allMiddleware...)
}
// Group creates a sub-group with an additional prefix and optional middleware.
// The new group starts with a copy of the parent's current middleware followed
// by the given middleware; later Use calls on either group do not affect the
// other.
//
// Parameters:
//   - prefix: The additional path prefix for the sub-group
//   - middleware: Optional middleware functions specific to the sub-group
//
// Returns:
//   - *Group: A new route group instance
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
	fullPrefix := normalizePath(g.prefix + prefix)
	// Capping at len makes the child's slice independent of the parent's spare
	// capacity; otherwise parent.Use would overwrite the child's middleware.
	allMiddleware := append(g.middleware[:len(g.middleware):len(g.middleware)], middleware...)
	return &Group{
		router:     g.router,
		prefix:     fullPrefix,
		middleware: allMiddleware,
	}
}
// Use appends middleware to the group. The added functions apply to routes
// registered on this group after the call; calling Use with no arguments
// leaves the group unchanged.
//
// Parameters:
//   - middleware: The middleware functions to add
func (g *Group) Use(middleware ...MiddlewareFunc) {
	if len(middleware) == 0 {
		return
	}
	g.middleware = append(g.middleware, middleware...)
}
// normalizePath ensures paths start with a slash and don't end with one.
// All trailing slashes are removed (not just the last one), and a path made
// only of slashes, or the empty string, becomes "/". Interior repeated slashes
// are left as-is. When the input is already normalized it is returned without
// allocating; otherwise at most one allocation is made.
//
// Parameters:
//   - path: The path to normalize
//
// Returns:
//   - string: The normalized path
func normalizePath(path string) string {
	if path == "" || path == "/" {
		return "/"
	}
	// Strip every trailing slash so "/a//" yields "/a", honouring the contract.
	end := len(path)
	for end > 0 && path[end-1] == '/' {
		end--
	}
	trimmed := path[:end]
	if trimmed == "" {
		// The input consisted solely of slashes.
		return "/"
	}
	if trimmed[0] == '/' {
		// Substring of the input: no allocation.
		return trimmed
	}
	return "/" + trimmed
}
// ### Validation Middleware

// ValidationMiddleware returns middleware that parses the request body and
// attaches the given validation chains to the request's lazy fields before
// calling the next handler. Aborted contexts are skipped.
//
// For POST, PUT, PATCH and DELETE requests (and only when chains are given):
//   - application/x-www-form-urlencoded and multipart/form-data bodies are
//     parsed and stored in ctx.Values["formData"]; single-valued keys map to a
//     string, multi-valued keys to a []string.
//   - application/json bodies are decoded into ctx.Values["body"] as a
//     map[string]interface{}, and ctx.Request.Body is replaced with a reader
//     over the same bytes so downstream handlers can read it again.
//
// Parse failures record a ValidationError, respond 400 with {"errors": ...}
// and stop the chain.
//
// The chains are built once and shared by every request that passes through
// this middleware, so they are never released here. Releasing them per
// request would hand them back to the pool (clearing their rules) after the
// first request and race with concurrent requests.
func ValidationMiddleware(chains ...*ValidationChain) MiddlewareFunc {
	// Same in-memory limit net/http uses by default for multipart parsing.
	const maxMultipartMemory = 32 << 20

	return func(ctx *Context, next func()) {
		if ctx.IsAborted() {
			return
		}

		// fail records a request-level validation error and replies 400.
		fail := func(msg string) {
			ve := getValidationError("", msg)
			if ctx.ValidationErrs == nil {
				ctx.ValidationErrs = make(ValidationErrors, 0, 1)
			}
			ctx.ValidationErrs = append(ctx.ValidationErrs, *ve)
			putValidationError(ve)
			ctx.JSON(http.StatusBadRequest, map[string]interface{}{"errors": ctx.ValidationErrs})
		}

		method := ctx.Request.Method
		if len(chains) > 0 && (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") {
			contentType := ctx.Request.Header.Get("Content-Type")
			isURLEncoded := strings.HasPrefix(contentType, "application/x-www-form-urlencoded")
			isMultipart := strings.HasPrefix(contentType, "multipart/form-data")

			if isURLEncoded || isMultipart {
				var err error
				if isMultipart {
					// ParseForm skips multipart bodies; ParseMultipartForm parses
					// them and merges the fields into Request.Form.
					err = ctx.Request.ParseMultipartForm(maxMultipartMemory)
				} else {
					err = ctx.Request.ParseForm()
				}
				if err != nil {
					fail("failed to parse form data")
					return
				}
				formData := getMap()
				for key, values := range ctx.Request.Form {
					if len(values) == 1 {
						formData[key] = values[0]
					} else {
						formData[key] = values
					}
				}
				ctx.Values["formData"] = formData
			}

			if strings.HasPrefix(contentType, "application/json") {
				buffer := bufferPool.Get().(*bytes.Buffer)
				buffer.Reset()
				_, copyErr := io.Copy(buffer, ctx.Request.Body)
				// Copy out of the pooled buffer: the replacement Body must stay
				// valid after the buffer is reused by another request. Resetting
				// before Put keeps request data from lingering in the pool.
				bodyBytes := bytes.Clone(buffer.Bytes())
				buffer.Reset()
				bufferPool.Put(buffer)
				if copyErr != nil {
					fail("failed to read request body")
					return
				}
				ctx.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

				var body map[string]interface{}
				if err := json.Unmarshal(bodyBytes, &body); err != nil {
					fail("invalid JSON")
					return
				}
				ctx.Values["body"] = body
			}
		}

		// Attach each chain's rules to the request's lazily validated field.
		for _, chain := range chains {
			field := ctx.Field(chain.field)
			field.rules = append(field.rules, chain.rules...)
		}

		next()
	}
}
// executeMiddlewareChain runs middleware in order and then handler.
//
// Each middleware receives a next callback. Calling next advances to the
// following middleware, or to handler once the list is exhausted. A middleware
// that never calls next stops the chain there. The position is shared by all
// calls, so calling next a second time continues from the current position
// rather than restarting the chain.
func executeMiddlewareChain(c *Context, handler HandlerFunc, middleware []MiddlewareFunc) {
	if len(middleware) == 0 {
		handler(c)
		return
	}

	pos := 0
	var advance func()
	advance = func() {
		if pos >= len(middleware) {
			handler(c)
			return
		}
		current := middleware[pos]
		pos++
		current(c, advance)
	}
	advance()
}
package nanite
// PathPart represents a single path segment as a half-open byte range
// [Start, End) into the path held by a PathParser.
type PathPart struct {
	Start int // index of the segment's first byte
	End   int // index one past the segment's last byte
}
// PathParser provides zero-allocation path parsing.
// Segment boundaries are stored as index pairs into path, so the strings
// returned by its accessors are substrings of path.
type PathParser struct {
	path      string       // the path being parsed
	parts     [12]PathPart // boundaries of the first 12 non-empty segments; parse drops any beyond that
	partsUsed byte         // number of populated entries in parts
}
// NewPathParser returns a parser over path. Segmentation happens eagerly here,
// so the accessors afterwards only slice the stored string.
func NewPathParser(path string) PathParser {
	var p PathParser
	p.path = path
	p.parse()
	return p
}
// parse records the boundaries of each non-empty, slash-separated segment of
// p.path without allocating. A single leading slash is skipped, empty segments
// (from repeated or trailing slashes) are ignored, and only the first
// len(p.parts) segments are kept.
func (p *PathParser) parse() {
	s := p.path
	if s == "" || s == "/" {
		return
	}

	segStart := 0
	if s[0] == '/' {
		segStart = 1
	}

	// Iterate one past the end so the final segment is flushed by the same
	// code path as the slash-terminated ones.
	for i := segStart; i <= len(s); i++ {
		if i < len(s) && s[i] != '/' {
			continue
		}
		if i > segStart && int(p.partsUsed) < len(p.parts) {
			p.parts[p.partsUsed] = PathPart{Start: segStart, End: i}
			p.partsUsed++
		}
		segStart = i + 1
	}
}
// Count reports how many segments parse recorded (at most len(p.parts)).
func (p *PathParser) Count() int { return int(p.partsUsed) }
// Part returns the text of segment index, or "" when index is out of range.
// The result is a substring of the parsed path.
func (p *PathParser) Part(index int) string {
	if index >= 0 && index < int(p.partsUsed) {
		seg := &p.parts[index]
		return p.path[seg.Start:seg.End]
	}
	return ""
}
// IsParam reports whether segment index exists, is non-empty and begins
// with ':'.
func (p *PathParser) IsParam(index int) bool {
	if index < 0 || index >= int(p.partsUsed) {
		return false
	}
	seg := p.parts[index]
	if seg.Start >= seg.End {
		return false
	}
	return p.path[seg.Start] == ':'
}
// IsWildcard reports whether segment index exists, is non-empty and begins
// with '*'.
func (p *PathParser) IsWildcard(index int) bool {
	if index < 0 || index >= int(p.partsUsed) {
		return false
	}
	seg := p.parts[index]
	if seg.Start >= seg.End {
		return false
	}
	return p.path[seg.Start] == '*'
}
// ParamName returns the name of a ':param' or '*wildcard' segment (the
// segment text without its leading marker). It returns "" when index is out
// of range or the segment is neither kind.
func (p *PathParser) ParamName(index int) string {
	if index < 0 || index >= int(p.partsUsed) {
		return ""
	}
	seg := p.parts[index]
	if seg.End <= seg.Start {
		return ""
	}
	switch p.path[seg.Start] {
	case ':', '*':
		return p.path[seg.Start+1 : seg.End]
	}
	return ""
}
package nanite
import (
"bytes"
"net/http"
)
//------------------------------------------------------------------------------
// Buffered Response Writer
//------------------------------------------------------------------------------

// BufferedResponseWriter wraps a TrackedResponseWriter and accumulates
// response bytes in a pooled buffer, forwarding them to the wrapped writer
// in larger chunks.
type BufferedResponseWriter struct {
	*TrackedResponseWriter
	buffer     *bytes.Buffer // pooled buffer of not-yet-flushed bytes; set to nil by Close
	bufferSize int           // flush threshold in bytes
	autoFlush  bool          // when true, Write flushes as soon as the buffer reaches bufferSize
}
// newBufferedResponseWriter wraps w with a pooled buffer that holds up to
// bufferSize bytes before flushing. Auto-flush is enabled. Call Close to
// flush the remainder and return the buffer to the pool.
func newBufferedResponseWriter(w *TrackedResponseWriter, bufferSize int) *BufferedResponseWriter {
	buf := bufferPool.Get().(*bytes.Buffer)
	// Pooled buffers may be returned with data still in them; without a Reset
	// those stale bytes would be sent as part of this response.
	buf.Reset()
	return &BufferedResponseWriter{
		TrackedResponseWriter: w,
		buffer:                buf,
		bufferSize:            bufferSize,
		autoFlush:             true,
	}
}
// Write buffers b and commits a 200 OK header first if none was written.
// The buffer is flushed before b is added if b would overflow bufferSize,
// and again afterwards once it reaches bufferSize when autoFlush is set.
//
// Bytes are counted once, by TrackedResponseWriter.Write when the buffer is
// flushed, so BytesWritten is not inflated. After Close has released the
// buffer, writes go directly to the wrapped writer instead of panicking.
func (w *BufferedResponseWriter) Write(b []byte) (int, error) {
	if !w.headerWritten {
		w.WriteHeader(http.StatusOK)
	}
	if w.buffer == nil {
		return w.TrackedResponseWriter.Write(b)
	}
	if w.buffer.Len()+len(b) > w.bufferSize {
		w.Flush()
	}
	n, err := w.buffer.Write(b)
	if w.autoFlush && w.buffer.Len() >= w.bufferSize {
		w.Flush()
	}
	return n, err
}
// Flush sends any buffered bytes through the wrapped TrackedResponseWriter
// and empties the buffer. It does nothing on a nil receiver, when no writer
// is wrapped, after Close, or when the buffer is empty. Errors from the
// underlying write are not reported. It does not invoke http.Flusher.
func (w *BufferedResponseWriter) Flush() {
	if w == nil || w.TrackedResponseWriter == nil || w.buffer == nil || w.buffer.Len() == 0 {
		return
	}
	w.TrackedResponseWriter.Write(w.buffer.Bytes())
	w.buffer.Reset()
}
// Close flushes any remaining bytes and returns the buffer to bufferPool.
// Subsequent calls, and calls on a nil receiver, are no-ops.
func (w *BufferedResponseWriter) Close() {
	if w == nil || w.buffer == nil {
		return
	}
	w.Flush()
	bufferPool.Put(w.buffer)
	w.buffer = nil
}
// Package nanite provides a lightweight, high-performance HTTP router for Go
// with optimized memory management through sync.Pool implementations.
package nanite
import (
"bytes"
"encoding/json"
"sync"
)
//------------------------------------------------------------------------------
// Map Pool
//------------------------------------------------------------------------------

// mapPool recycles map[string]interface{} values to reduce allocation and GC
// pressure. Freshly created maps have room for 16 entries.
var mapPool = sync.Pool{
	New: func() any {
		return make(map[string]interface{}, 16)
	},
}
// getMap takes a map from mapPool, falling back to a fresh map (capacity
// hint 16) if the pool yields nothing usable. Maps are cleared by putMap
// before being pooled.
//
// Returns:
//   - map[string]interface{}: A map ready for use
func getMap() map[string]interface{} {
	if m, ok := mapPool.Get().(map[string]interface{}); ok {
		return m
	}
	return make(map[string]interface{}, 16)
}
// putMap clears m and returns it to mapPool so getMap can hand it out again.
// A nil map is ignored: if it were pooled, getMap would return it and the
// caller's first write would panic.
//
// Parameters:
//   - m: The map to return to the pool
func putMap(m map[string]interface{}) {
	if m == nil {
		return
	}
	clear(m) // drop entries so pooled maps keep no references (Go 1.21 builtin)
	mapPool.Put(m)
}
//------------------------------------------------------------------------------
// Buffer Pool
//------------------------------------------------------------------------------
// bufferPool is a pool of reusable bytes.Buffer objects.
// Used primarily for efficient request body handling.
var bufferPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
//------------------------------------------------------------------------------
// Validation Error Pool
//------------------------------------------------------------------------------

// validationErrorPool recycles *ValidationError values to cut allocations
// while validating requests. New entries are zero-valued.
var validationErrorPool = sync.Pool{
	New: func() any { return new(ValidationError) },
}
// getValidationError takes a ValidationError from validationErrorPool and
// fills in its field name and message.
//
// Parameters:
//   - field: The field name that failed validation
//   - errorMsg: The error message describing the validation failure
//
// Returns:
//   - *ValidationError: The populated error
func getValidationError(field, errorMsg string) *ValidationError {
	ve := validationErrorPool.Get().(*ValidationError)
	*ve = ValidationError{Field: field, Err: errorMsg}
	return ve
}
// putValidationError zeroes ve and returns it to validationErrorPool.
// ve must not be used by the caller afterwards.
//
// Parameters:
//   - ve: The ValidationError to recycle
func putValidationError(ve *ValidationError) {
	*ve = ValidationError{} // drop string references before pooling
	validationErrorPool.Put(ve)
}
//------------------------------------------------------------------------------
// ValidationErrors Slice Pool
//------------------------------------------------------------------------------

// validationErrorsPool recycles ValidationErrors slices. New slices are empty
// with capacity for 8 errors.
var validationErrorsPool = sync.Pool{
	New: func() any { return make(ValidationErrors, 0, 8) },
}
// getValidationErrors takes a slice from validationErrorsPool, truncated to
// length zero while keeping its capacity.
//
// Returns:
//   - ValidationErrors: An empty slice ready for appending
func getValidationErrors() ValidationErrors {
	errs := validationErrorsPool.Get().(ValidationErrors)
	return errs[:0]
}
// putValidationErrors returns ve's backing array to validationErrorsPool,
// truncated to zero length. Slices without capacity are discarded.
//
// Parameters:
//   - ve: The ValidationErrors slice to recycle
func putValidationErrors(ve ValidationErrors) {
	if cap(ve) == 0 {
		return
	}
	validationErrorsPool.Put(ve[:0])
}
//------------------------------------------------------------------------------
// LazyField Pool
//------------------------------------------------------------------------------

// lazyFieldPool recycles *LazyField values, which hold deferred validation
// state for request fields. New entries get room for 5 rules.
var lazyFieldPool = sync.Pool{
	New: func() any {
		lf := &LazyField{}
		lf.rules = make([]ValidatorFunc, 0, 5)
		return lf
	},
}
// getLazyField takes a LazyField from lazyFieldPool and prepares it for a new
// field: it sets the name and value getter and resets the validation state.
// The rules slice is not touched here; putLazyField truncates it on return.
//
// Parameters:
//   - name: The field name
//   - getValue: Function that retrieves the raw value from the request
//
// Returns:
//   - *LazyField: A LazyField ready for validation rules
func getLazyField(name string, getValue func() string) *LazyField {
	lf := lazyFieldPool.Get().(*LazyField)
	lf.name, lf.getValue = name, getValue
	lf.validated, lf.value, lf.err = false, "", nil
	return lf
}
// putLazyField resets lf and returns it to lazyFieldPool.
// Every reference the field holds is dropped, including the validator closures
// stored in the rules backing array. Truncating with rules[:0] alone would keep
// those closures, and whatever they capture, reachable from the pool.
//
// Parameters:
//   - lf: The LazyField to recycle
func putLazyField(lf *LazyField) {
	lf.name = ""
	lf.getValue = nil
	// Zero the whole backing array (up to cap), then keep the capacity for reuse.
	clear(lf.rules[:cap(lf.rules)])
	lf.rules = lf.rules[:0]
	lf.validated = false
	lf.value = ""
	lf.err = nil
	lazyFieldPool.Put(lf)
}
//------------------------------------------------------------------------------
// ValidationChain Pool
//------------------------------------------------------------------------------

// validationChainPool recycles *ValidationChain values. New chains get room
// for 10 rules.
var validationChainPool = sync.Pool{
	New: func() any {
		vc := &ValidationChain{}
		vc.rules = make([]ValidatorFunc, 0, 10)
		return vc
	},
}
// getValidationChain takes a ValidationChain from validationChainPool, names
// it after field and empties its rule list (keeping the slice's capacity).
//
// Parameters:
//   - field: The field name to validate
//
// Returns:
//   - *ValidationChain: A chain ready for rules
func getValidationChain(field string) *ValidationChain {
	vc := validationChainPool.Get().(*ValidationChain)
	vc.field, vc.rules = field, vc.rules[:0]
	return vc
}
// putValidationChain clears vc's field name, empties its rule list and
// returns it to validationChainPool. Pre-allocated errors are not touched
// here; Release returns those to their pool before calling this.
//
// Parameters:
//   - vc: The ValidationChain to recycle
func putValidationChain(vc *ValidationChain) {
	vc.field, vc.rules = "", vc.rules[:0]
	validationChainPool.Put(vc)
}
//------------------------------------------------------------------------------
// Utility Functions
//------------------------------------------------------------------------------

// cleanupNestedMaps walks value and every map[string]interface{} or
// []interface{} nested inside it, returning each map to mapPool via putMap
// (which also clears it). Slices are traversed but not pooled. The walk uses
// an explicit stack instead of recursion. Not currently used; kept for
// future use.
//
// NOTE(review): a map reachable twice (shared between containers) would be
// pooled twice. This cannot happen for json.Unmarshal output, but verify
// for other callers.
//
// Parameters:
//   - value: The root object to clean up (typically a map or slice)
func cleanupNestedMaps(value interface{}) {
	pending := make([]interface{}, 0, 8)
	pending = append(pending, value)

	for len(pending) > 0 {
		last := len(pending) - 1
		item := pending[last]
		pending = pending[:last]

		switch node := item.(type) {
		case map[string]interface{}:
			for _, child := range node {
				switch child.(type) {
				case map[string]interface{}, []interface{}:
					pending = append(pending, child)
				}
			}
			putMap(node)
		case []interface{}:
			for _, child := range node {
				switch child.(type) {
				case map[string]interface{}, []interface{}:
					pending = append(pending, child)
				}
			}
		}
	}
}
//------------------------------------------------------------------------------
// JSON Encoder Pool
//------------------------------------------------------------------------------

// jsonEncoderBufferPair couples a json.Encoder with the bytes.Buffer it
// writes into, so both can be recycled together through jsonEncoderPool.
type jsonEncoderBufferPair struct {
	encoder *json.Encoder // encoder whose output goes to buffer
	buffer  *bytes.Buffer // destination for encoded JSON
}
// jsonEncoderPool recycles encoder/buffer pairs. Each new pair's encoder
// writes into that pair's own buffer.
var jsonEncoderPool = sync.Pool{
	New: func() any {
		pair := &jsonEncoderBufferPair{buffer: new(bytes.Buffer)}
		pair.encoder = json.NewEncoder(pair.buffer)
		return pair
	},
}
// getJSONEncoder takes a pair from jsonEncoderPool, empties its buffer and
// binds a freshly constructed encoder to it. The encoder is rebuilt on every
// call, so settings such as SetIndent or SetEscapeHTML applied by a previous
// user do not carry over. This costs one small allocation per call.
func getJSONEncoder() *jsonEncoderBufferPair {
	p := jsonEncoderPool.Get().(*jsonEncoderBufferPair)
	buf := p.buffer
	buf.Reset()
	p.encoder = json.NewEncoder(buf)
	return p
}
// putJSONEncoder returns pair to jsonEncoderPool. A nil pair is ignored:
// sync.Pool would store the typed nil pointer, and getJSONEncoder would later
// panic dereferencing it.
func putJSONEncoder(pair *jsonEncoderBufferPair) {
	if pair == nil {
		return
	}
	jsonEncoderPool.Put(pair)
}
package nanite
import (
"bufio"
"fmt"
"net"
"net/http"
)
// TrackedResponseWriter wraps http.ResponseWriter to track if headers have
// been sent, which status code was used, and how many body bytes were written.
type TrackedResponseWriter struct {
	http.ResponseWriter
	statusCode    int   // status passed to WriteHeader; WrapResponseWriter starts it at 200
	headerWritten bool  // set once WriteHeader has forwarded a status
	bytesWritten  int64 // body bytes reported written by the wrapped writer
}
// WrapResponseWriter returns a TrackedResponseWriter around w. Its status
// starts at 200 OK, no header is marked written and no bytes are counted.
func WrapResponseWriter(w http.ResponseWriter) *TrackedResponseWriter {
	tw := new(TrackedResponseWriter)
	tw.ResponseWriter = w
	tw.statusCode = http.StatusOK
	return tw
}
// WriteHeader forwards statusCode to the wrapped writer and records it.
// Only the first call has any effect; later calls are ignored.
func (w *TrackedResponseWriter) WriteHeader(statusCode int) {
	if w.headerWritten {
		return
	}
	w.statusCode = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
	w.headerWritten = true
}
// Write sends b to the wrapped writer, first committing a 200 OK header if
// none has been written. The number of bytes the wrapped writer reports is
// added to the running total.
func (w *TrackedResponseWriter) Write(b []byte) (n int, err error) {
	if !w.headerWritten {
		w.WriteHeader(http.StatusOK)
	}
	n, err = w.ResponseWriter.Write(b)
	w.bytesWritten += int64(n)
	return n, err
}
// Status reports the recorded status code: the one passed to the first
// WriteHeader call, or the initial value if no header has been written yet.
func (w *TrackedResponseWriter) Status() int { return w.statusCode }
// Written reports whether a status header has already been forwarded.
func (w *TrackedResponseWriter) Written() bool { return w.headerWritten }
// BytesWritten reports the total number of body bytes accepted by the
// wrapped writer so far.
func (w *TrackedResponseWriter) BytesWritten() int64 { return w.bytesWritten }
// Unwrap exposes the wrapped http.ResponseWriter, which lets helpers such as
// http.ResponseController reach the underlying writer's optional interfaces.
func (w *TrackedResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter }
// Flush implements http.Flusher. It flushes the wrapped writer when that
// writer supports flushing and does nothing otherwise.
func (w *TrackedResponseWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	f.Flush()
}
// Hijack implements http.Hijacker by delegating to the wrapped writer.
// If the wrapped writer cannot be hijacked, the returned error wraps
// http.ErrNotSupported so callers can detect it with errors.Is, which is the
// same convention http.ResponseController uses.
func (w *TrackedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker: %w", http.ErrNotSupported)
	}
	return hijacker.Hijack()
}
// Push implements http.Pusher by delegating to the wrapped writer.
// If the wrapped writer does not support server push, the returned error wraps
// http.ErrNotSupported, the error http.Pusher documents for that case.
func (w *TrackedResponseWriter) Push(target string, opts *http.PushOptions) error {
	pusher, ok := w.ResponseWriter.(http.Pusher)
	if !ok {
		return fmt.Errorf("underlying ResponseWriter does not implement http.Pusher: %w", http.ErrNotSupported)
	}
	return pusher.Push(target, opts)
}
package nanite
import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"unicode/utf8"
)
// Shared messages used by the built-in validation rules.
var (
	// errRequired is reported by Required when the value is empty.
	errRequired = "field is required"
	// errInvalidFormat is reported by Matches when the value does not match.
	errInvalidFormat = "invalid format"
	// errMustBeNumber is reported by Min, Max and IsFloat for non-numeric input.
	errMustBeNumber = "must be a number"
	// errMustBeBoolean is reported by IsBoolean.
	errMustBeBoolean = "must be a boolean value"
)
// ValidationError represents a single validation failure: the field that
// failed and a human-readable message. It serialises to JSON as
// {"field": ..., "error": ...}.
type ValidationError struct {
	Field string `json:"field"` // Field name that failed validation ("" for request-level errors)
	Err   string `json:"error"` // Error message describing the failure
}
// Error implements the error interface, formatting the failure as
// "<field>: <message>".
func (ve *ValidationError) Error() string {
	// Sprint inserts no separators between string operands.
	return fmt.Sprint(ve.Field, ": ", ve.Err)
}
// ValidatorFunc defines the signature for validation functions.
// It checks a string value and returns nil when the value is acceptable.
// On failure the built-in rules return a ValidationError prepared when the
// rule was built (Custom allocates one at failure time instead).
type ValidatorFunc func(string) *ValidationError
// ValidationChain represents a chain of validation rules for a field.
// Builder methods append rules and return the chain for chaining.
type ValidationChain struct {
	field              string              // name of the field the rules apply to
	rules              []ValidatorFunc     // rules in the order they were added
	preAllocatedErrors []*ValidationError // errors built with the rules; closures index into this slice; Release returns them to the pool
}
// ### Validation Support

// IsObject adds a rule requiring a non-empty value to start with '{' and end
// with '}'. The check is purely lexical; the content is not parsed. Empty
// values pass (use Required to reject them).
func (vc *ValidationChain) IsObject() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, "must be an object"))
	idx := len(vc.preAllocatedErrors) - 1

	rule := func(value string) *ValidationError {
		switch {
		case value == "":
			return nil
		case strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}"):
			return nil
		default:
			return vc.preAllocatedErrors[idx]
		}
	}
	vc.rules = append(vc.rules, rule)
	return vc
}
// IsArray adds a rule requiring a non-empty value to start with '[' and end
// with ']'. The check is purely lexical; the content is not parsed. Empty
// values pass (use Required to reject them).
func (vc *ValidationChain) IsArray() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, "must be an array"))
	idx := len(vc.preAllocatedErrors) - 1

	rule := func(value string) *ValidationError {
		switch {
		case value == "":
			return nil
		case strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]"):
			return nil
		default:
			return vc.preAllocatedErrors[idx]
		}
	}
	vc.rules = append(vc.rules, rule)
	return vc
}
// Custom adds a rule that runs fn on non-empty values. When fn returns an
// error, its message becomes the ValidationError text for this field.
// Empty values skip fn entirely.
func (vc *ValidationChain) Custom(fn func(string) error) *ValidationChain {
	check := func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		err := fn(value)
		if err == nil {
			return nil
		}
		// The message is only known at run time, so this error cannot be
		// prepared in advance like the other rules' errors.
		return getValidationError(vc.field, err.Error())
	}
	vc.rules = append(vc.rules, check)
	return vc
}
// OneOf adds a rule requiring a non-empty value to equal one of options
// exactly (case-sensitive). Empty values pass. The error message lists the
// options and is formatted once, when the rule is built.
func (vc *ValidationChain) OneOf(options ...string) *ValidationChain {
	msg := fmt.Sprintf("must be one of: %s", strings.Join(options, ", "))
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, msg))
	idx := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		for i := range options {
			if options[i] == value {
				return nil
			}
		}
		return vc.preAllocatedErrors[idx]
	})
	return vc
}
// Matches adds a rule requiring a non-empty value to match the regular
// expression pattern, which is compiled once, when the rule is built. If the
// pattern does not compile, the added rule fails every value, including empty
// ones, with "invalid validation pattern".
func (vc *ValidationChain) Matches(pattern string) *ValidationChain {
	re, compileErr := regexp.Compile(pattern)
	if compileErr != nil {
		vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, "invalid validation pattern"))
		badIdx := len(vc.preAllocatedErrors) - 1
		vc.rules = append(vc.rules, func(string) *ValidationError {
			return vc.preAllocatedErrors[badIdx]
		})
		return vc
	}

	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, errInvalidFormat))
	idx := len(vc.preAllocatedErrors) - 1
	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value != "" && !re.MatchString(value) {
			return vc.preAllocatedErrors[idx]
		}
		return nil
	})
	return vc
}
// Length adds a rule requiring a non-empty value to be between min and
// maxLength characters long, inclusive. Length is measured in Unicode code
// points (runes), not bytes, so "héllo" counts as 5. Empty values pass (use
// Required to reject them). Both error messages are built once, when the
// rule is created.
func (vc *ValidationChain) Length(min, maxLength int) *ValidationChain {
	tooShortErr := getValidationError(vc.field, fmt.Sprintf("must be at least %d characters", min))
	tooLongErr := getValidationError(vc.field, fmt.Sprintf("must be at most %d characters", maxLength))
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, tooShortErr, tooLongErr)
	tooShortIndex := len(vc.preAllocatedErrors) - 2
	tooLongIndex := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		// Count runes so multi-byte UTF-8 characters match the "characters"
		// wording of the messages.
		length := utf8.RuneCountInString(value)
		if length < min {
			return vc.preAllocatedErrors[tooShortIndex]
		}
		if length > maxLength {
			return vc.preAllocatedErrors[tooLongIndex]
		}
		return nil
	})
	return vc
}
// Max adds a rule requiring a non-empty value to parse as a base-10 integer
// (strconv.Atoi) that is no greater than max. Empty values pass.
// Non-integers fail with errMustBeNumber.
func (vc *ValidationChain) Max(max int) *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors,
		getValidationError(vc.field, errMustBeNumber),
		getValidationError(vc.field, fmt.Sprintf("must be at most %d", max)),
	)
	nanIdx, overIdx := len(vc.preAllocatedErrors)-2, len(vc.preAllocatedErrors)-1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		n, err := strconv.Atoi(value)
		switch {
		case err != nil:
			return vc.preAllocatedErrors[nanIdx]
		case n > max:
			return vc.preAllocatedErrors[overIdx]
		}
		return nil
	})
	return vc
}
// Min adds a rule requiring a non-empty value to parse as a base-10 integer
// (strconv.Atoi) that is no less than min. Empty values pass.
// Non-integers fail with errMustBeNumber.
func (vc *ValidationChain) Min(min int) *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors,
		getValidationError(vc.field, errMustBeNumber),
		getValidationError(vc.field, fmt.Sprintf("must be at least %d", min)),
	)
	nanIdx, underIdx := len(vc.preAllocatedErrors)-2, len(vc.preAllocatedErrors)-1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		n, err := strconv.Atoi(value)
		switch {
		case err != nil:
			return vc.preAllocatedErrors[nanIdx]
		case n < min:
			return vc.preAllocatedErrors[underIdx]
		}
		return nil
	})
	return vc
}
// IsBoolean adds a rule accepting "true", "false", "1" or "0" (letters
// compared after strings.ToLower). Empty values pass.
func (vc *ValidationChain) IsBoolean() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, errMustBeBoolean))
	idx := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		switch strings.ToLower(value) {
		case "true", "false", "1", "0":
			return nil
		}
		return vc.preAllocatedErrors[idx]
	})
	return vc
}
// IsFloat adds a rule requiring a non-empty value to parse with
// strconv.ParseFloat(value, 64). Empty values pass.
func (vc *ValidationChain) IsFloat() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, errMustBeNumber))
	idx := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		if _, err := strconv.ParseFloat(value, 64); err == nil {
			return nil
		}
		return vc.preAllocatedErrors[idx]
	})
	return vc
}
// IsInt adds a rule requiring a non-empty value to parse with strconv.Atoi.
// Empty values pass.
func (vc *ValidationChain) IsInt() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, "must be an integer"))
	idx := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		if _, err := strconv.Atoi(value); err == nil {
			return nil
		}
		return vc.preAllocatedErrors[idx]
	})
	return vc
}
// IsEmail adds a lightweight email-shape rule for non-empty values. The
// value must contain an '@' with a non-empty local part before the last '@',
// and the domain after it must contain a '.' that is neither its first nor its
// last character. This is a sanity check, not full RFC 5322 validation.
// Empty values pass (use Required to reject them).
func (vc *ValidationChain) IsEmail() *ValidationChain {
	errObj := getValidationError(vc.field, "invalid email format")
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, errObj)
	errIndex := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value == "" {
			return nil
		}
		at := strings.LastIndexByte(value, '@')
		if at <= 0 {
			// No '@' at all, or nothing before it.
			return vc.preAllocatedErrors[errIndex]
		}
		domain := value[at+1:]
		if !strings.Contains(domain, ".") || strings.HasPrefix(domain, ".") || strings.HasSuffix(domain, ".") {
			return vc.preAllocatedErrors[errIndex]
		}
		return nil
	})
	return vc
}
// Required adds a rule that rejects the empty string with errRequired.
// Any non-empty value passes.
func (vc *ValidationChain) Required() *ValidationChain {
	vc.preAllocatedErrors = append(vc.preAllocatedErrors, getValidationError(vc.field, errRequired))
	idx := len(vc.preAllocatedErrors) - 1

	vc.rules = append(vc.rules, func(value string) *ValidationError {
		if value != "" {
			return nil
		}
		return vc.preAllocatedErrors[idx]
	})
	return vc
}
// NewValidationChain returns an empty ValidationChain for field, taken from
// the chain pool. Add rules with the builder methods (Required, IsInt, ...).
func NewValidationChain(field string) *ValidationChain {
	chain := getValidationChain(field)
	return chain
}
// Release returns every pre-allocated error to its pool, empties the chain
// and returns the chain itself to the pool. The chain (and any rule closures
// taken from it) must not be used afterwards.
func (vc *ValidationChain) Release() {
	for i := range vc.preAllocatedErrors {
		putValidationError(vc.preAllocatedErrors[i])
	}
	vc.preAllocatedErrors = vc.preAllocatedErrors[:0]
	vc.rules = vc.rules[:0]
	putValidationChain(vc)
}
package nanite
import (
"context"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// ### WebSocket Wrapper

// wrapWebSocketHandler adapts a WebSocketHandler to a HandlerFunc.
//
// Requests that are not WebSocket upgrades get a 400 response. Otherwise the
// connection is upgraded with r.config.Upgrader, the configured read limit
// and read deadline are applied, and two helper goroutines are started. One
// sends a ping every PingInterval; the other sends a "going away" close frame
// if the request context is cancelled first.
//
// After handler returns, a normal-closure frame is sent. The helpers are then
// cancelled and awaited, the connection is closed, and
// ctx.CleanupPooledResources is called.
//
// Pings use WriteControl, which gorilla/websocket allows concurrently with the
// handler's own writes. WriteMessage/SetWriteDeadline from a second goroutine
// would race with the handler.
func (r *Router) wrapWebSocketHandler(handler WebSocketHandler) HandlerFunc {
	return func(ctx *Context) {
		// Answer plain HTTP requests ourselves. For any other failure, Upgrade
		// has already replied to the client (or hijacked the connection), so
		// writing a second response here would be wrong.
		if !websocket.IsWebSocketUpgrade(ctx.Request) {
			http.Error(ctx.Writer, "Failed to upgrade to WebSocket", http.StatusBadRequest)
			return
		}
		conn, err := r.config.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
		if err != nil {
			return
		}
		conn.SetReadLimit(r.config.WebSocket.MaxMessageSize)

		wsCtx, cancel := context.WithCancel(context.Background())
		defer cancel()

		var wg sync.WaitGroup
		cleanup := func() {
			cancel()     // signal the helper goroutines to stop
			conn.Close() // release the connection
			wg.Wait()    // wait for the helpers to exit
			ctx.CleanupPooledResources()
		}
		defer cleanup()

		// Each pong extends the read deadline, keeping healthy connections open.
		conn.SetPongHandler(func(string) error {
			conn.SetReadDeadline(time.Now().Add(r.config.WebSocket.ReadTimeout))
			return nil
		})

		// Keepalive pings.
		wg.Add(1)
		go func() {
			defer wg.Done()
			pingTicker := time.NewTicker(r.config.WebSocket.PingInterval)
			defer pingTicker.Stop()
			for {
				select {
				case <-pingTicker.C:
					deadline := time.Now().Add(r.config.WebSocket.WriteTimeout)
					if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
						return
					}
				case <-wsCtx.Done():
					return
				}
			}
		}()

		// Send a "going away" close frame if the request context ends first.
		// NOTE(review): after Hijack, net/http may not cancel this context on
		// server shutdown; confirm this path fires as intended.
		wg.Add(1)
		go func() {
			defer wg.Done()
			select {
			case <-ctx.Request.Context().Done():
				conn.WriteControl(
					websocket.CloseMessage,
					websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down"),
					time.Now().Add(time.Second),
				)
			case <-wsCtx.Done():
			}
		}()

		conn.SetReadDeadline(time.Now().Add(r.config.WebSocket.ReadTimeout))

		handler(conn, ctx)

		conn.WriteControl(
			websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
			time.Now().Add(time.Second),
		)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment