package postgrest
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
)
// Builder is the base builder for PostgREST queries.
// Similar to PostgrestBuilder in postgrest-js.
//
// A Builder accumulates the HTTP method, target URL, headers and optional
// JSON body of a single request; Execute sends it through the owning
// Client's session and decodes the result into T.
type Builder[T any] struct {
	// HTTP method, e.g. "GET", "POST", "PATCH", "DELETE" or "HEAD".
	method string
	// Request URL including the query parameters added by filters/transforms.
	url *url.URL
	// Request headers; seeded from the client (and options) in NewBuilder.
	headers http.Header
	// When non-empty, Execute sends it as Accept-Profile (GET/HEAD) or
	// Content-Profile (other methods).
	schema string
	// JSON-marshaled into the request body by Execute when non-nil.
	body interface{}
	// When true, Execute returns PostgREST errors as its error value.
	shouldThrowOnError bool
	// When non-nil, Execute uses this context instead of the one passed in.
	signal context.Context
	// Owning client; Execute sends the request through client.session.
	client *Client
	// Enables maybeSingle handling (0, 1 or many rows) in Execute.
	isMaybeSingle bool
}
// NewBuilder creates a Builder for the given HTTP method and target URL.
//
// The builder's headers are seeded from the client's transport headers and
// then overlaid with opts.Headers. A key present in opts.Headers REPLACES the
// client's values for that key instead of being appended to them: QueryBuilder
// and Rpc pass a copy of the client headers in opts.Headers, and appending
// made every header (Authorization, apikey, Accept, ...) go out twice.
// All values listed for a key in opts.Headers are kept (e.g. several Prefer
// values).
//
// A nil opts behaves like a zero-valued BuilderOptions. client may be nil, in
// which case no client headers are copied.
func NewBuilder[T any](client *Client, method string, url *url.URL, opts *BuilderOptions) *Builder[T] {
	if opts == nil {
		opts = &BuilderOptions{}
	}
	b := &Builder[T]{
		method:             method,
		url:                url,
		headers:            make(http.Header),
		schema:             opts.Schema,
		body:               opts.Body,
		shouldThrowOnError: opts.ShouldThrowOnError,
		signal:             opts.Signal,
		client:             client,
		isMaybeSingle:      opts.IsMaybeSingle,
	}

	// Client-wide defaults first.
	if client != nil && client.Transport != nil {
		client.Transport.mu.RLock()
		for key, values := range client.Transport.header {
			for _, val := range values {
				b.headers.Add(key, val)
			}
		}
		client.Transport.mu.RUnlock()
	}

	// Explicit options win per key. Delete in a separate pass so that two
	// differently-cased spellings of one key in opts.Headers cannot wipe
	// each other's values.
	for key := range opts.Headers {
		b.headers.Del(key)
	}
	for key, values := range opts.Headers {
		for _, val := range values {
			b.headers.Add(key, val)
		}
	}
	return b
}
// BuilderOptions contains options for creating a Builder.
// Passing a nil *BuilderOptions to NewBuilder behaves like the zero value.
type BuilderOptions struct {
	// Extra request headers, applied after the client's headers in NewBuilder.
	Headers http.Header
	// Schema sent as Accept-Profile / Content-Profile by Execute when non-empty.
	Schema string
	// Request payload; JSON-marshaled by Execute when non-nil.
	Body interface{}
	// Makes Execute return PostgREST errors as its error value.
	ShouldThrowOnError bool
	// When non-nil, replaces the context passed to Execute.
	Signal context.Context
	// Enables maybeSingle row-count handling in Execute.
	IsMaybeSingle bool
}
// ThrowOnError makes Execute report PostgREST errors through its error
// return value rather than only in PostgrestResponse.Error. It modifies b in
// place and returns it so calls can be chained.
func (b *Builder[T]) ThrowOnError() *Builder[T] {
	b.shouldThrowOnError = true
	return b
}

// SetHeader replaces every existing value of the named request header with
// value. It modifies b in place and returns it for chaining.
func (b *Builder[T]) SetHeader(name, value string) *Builder[T] {
	hdr := b.headers
	hdr.Set(name, value)
	return b
}
// Execute sends the request described by the builder and decodes the response.
//
// ctx governs the request unless a Signal was supplied through
// BuilderOptions, in which case the Signal is used instead. A nil ctx is
// treated as context.Background().
//
// Outcomes:
//   - client not initialized (nil client or session, e.g. a Client returned
//     by NewClient after a construction failure): an error, never a panic;
//   - transport failure: the raw error when ThrowOnError is set, otherwise a
//     response whose Error message starts with "FetchError:";
//   - HTTP status >= 400: a response whose Error is built from the PostgREST
//     error body (message, details, hint, code), or from the raw body text
//     when it is not a JSON object. The error is returned as the error value
//     instead when ThrowOnError is set or the Accept header is
//     application/vnd.pgrst.object+json (Single);
//   - success: Data decoded according to the Accept header and maybeSingle
//     flag, and Count parsed from the Content-Range header when the server
//     reports a total.
func (b *Builder[T]) Execute(ctx context.Context) (*PostgrestResponse[T], error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if b.signal != nil {
		ctx = b.signal
	}
	// Report a missing session instead of panicking on a nil pointer below.
	if b.client == nil || b.client.session == nil {
		if b.client != nil && b.client.ClientError != nil {
			return nil, fmt.Errorf("postgrest client is not initialized: %w", b.client.ClientError)
		}
		return nil, fmt.Errorf("postgrest client is not initialized")
	}

	// Schema header: Accept-Profile for GET/HEAD, Content-Profile otherwise.
	if b.schema != "" {
		if b.method == "GET" || b.method == "HEAD" {
			b.headers.Set("Accept-Profile", b.schema)
		} else {
			b.headers.Set("Content-Profile", b.schema)
		}
	}
	if b.method != "GET" && b.method != "HEAD" {
		b.headers.Set("Content-Type", "application/json")
	}

	var bodyReader io.Reader
	if b.body != nil {
		bodyBytes, err := json.Marshal(b.body)
		if err != nil {
			return nil, fmt.Errorf("error marshaling body: %w", err)
		}
		bodyReader = bytes.NewBuffer(bodyBytes)
	}

	if err := ctx.Err(); err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, b.method, b.url.String(), bodyReader)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	for key, values := range b.headers {
		for _, val := range values {
			req.Header.Add(key, val)
		}
	}

	resp, err := b.client.session.Do(req)
	if err != nil {
		// Prefer the context error when the failure was a cancellation.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		if b.shouldThrowOnError {
			return nil, err
		}
		return &PostgrestResponse[T]{
			Error: NewPostgrestError(
				fmt.Sprintf("FetchError: %s", err.Error()),
				fmt.Sprintf("%v", err),
				"",
				"",
			),
			Data:       *new(T),
			Count:      nil,
			Status:     0,
			StatusText: "",
		}, nil
	}
	defer resp.Body.Close()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading response: %w", err)
	}

	response := &PostgrestResponse[T]{
		Status:     resp.StatusCode,
		StatusText: resp.Status,
	}
	if resp.StatusCode >= 400 {
		return b.handleErrorResponse(resp.StatusCode, bodyBytes, response)
	}

	if b.method != "HEAD" {
		if err := b.decodeSuccessBody(bodyBytes, response); err != nil {
			return nil, err
		}
		// decodeSuccessBody reports a maybeSingle result with several rows as
		// a PostgREST error (PGRST116). Skip the count in that case and honour
		// ThrowOnError, as postgrest-js does.
		if response.Error != nil {
			if b.shouldThrowOnError {
				return nil, response.Error
			}
			return response, nil
		}
	}

	response.Count = parseContentRangeCount(resp.Header.Get("Content-Range"))
	return response, nil
}

// handleErrorResponse builds the result for an HTTP error status (>= 400).
func (b *Builder[T]) handleErrorResponse(status int, body []byte, response *PostgrestResponse[T]) (*PostgrestResponse[T], error) {
	// Workaround for https://github.com/supabase/postgrest-js/issues/295:
	// a 404 whose body is a JSON array (or null) is reported as an empty
	// success. This must be checked before decoding the body as an error
	// object. A JSON array never decodes into a map, so with the previous
	// ordering the workaround was unreachable for "[]" bodies.
	var rows []json.RawMessage
	if status == 404 && json.Unmarshal(body, &rows) == nil {
		var data T
		if json.Unmarshal(body, &data) != nil {
			// T cannot hold an array (e.g. a map type); keep the zero value.
			data = *new(T)
		}
		response.Data = data
		response.Status = 200
		response.StatusText = "OK"
		return response, nil
	}

	var errorData map[string]interface{}
	if err := json.Unmarshal(body, &errorData); err != nil {
		if status == 404 && len(body) == 0 {
			response.Status = 204
			response.StatusText = "No Content"
			return response, nil
		}
		// Not a JSON object: surface the raw body text as the message.
		response.Error = NewPostgrestError(string(body), "", "", "")
	} else {
		field := func(key string) string {
			s, _ := errorData[key].(string)
			return s
		}
		details := field("details")
		response.Error = NewPostgrestError(field("message"), details, field("hint"), field("code"))
		// maybeSingle: an error whose details mention "0 rows" means nothing
		// matched, which is reported as an empty success.
		if b.isMaybeSingle && strings.Contains(details, "0 rows") {
			response.Error = nil
			response.Status = 200
			response.StatusText = "OK"
		}
	}

	// Single() and ThrowOnError() return the error itself. This now also
	// applies to non-JSON error bodies, which previously bypassed ThrowOnError.
	if response.Error != nil &&
		(b.shouldThrowOnError || b.headers.Get("Accept") == "application/vnd.pgrst.object+json") {
		return nil, response.Error
	}
	return response, nil
}

// decodeSuccessBody fills response.Data from a successful (< 400) body,
// choosing the decoding from the Accept header and the maybeSingle flag.
// For a maybeSingle result with several rows it sets response.Error instead.
func (b *Builder[T]) decodeSuccessBody(body []byte, response *PostgrestResponse[T]) error {
	accept := b.headers.Get("Accept")
	switch {
	case accept == "text/csv" || strings.Contains(accept, "application/vnd.pgrst.plan+text"):
		// Text formats: hand back the raw text whenever T can hold a string
		// (string itself or an interface type such as interface{}). Otherwise
		// fall back to best-effort JSON decoding.
		if text, ok := any(string(body)).(T); ok {
			response.Data = text
		} else {
			_ = json.Unmarshal(body, &response.Data)
		}
	case len(body) == 0:
		// Empty body: nothing to decode, keep the zero value.
	case accept == "application/vnd.pgrst.object+json":
		return b.decodeSingleObject(body, &response.Data)
	case b.isMaybeSingle:
		return b.decodeMaybeSingle(body, response)
	default:
		// Best effort: a body that does not fit T leaves Data at its zero
		// value rather than failing the request.
		_ = json.Unmarshal(body, &response.Data)
	}
	return nil
}

// decodeSingleObject decodes one JSON object into dst. When T is a slice type
// (as produced by QueryBuilder.Select), the object is wrapped in a
// one-element array first.
func (b *Builder[T]) decodeSingleObject(body []byte, dst *T) error {
	if t := reflect.TypeOf(*new(T)); t != nil && t.Kind() == reflect.Slice {
		wrapped := make([]byte, 0, len(body)+2)
		wrapped = append(wrapped, '[')
		wrapped = append(wrapped, body...)
		wrapped = append(wrapped, ']')
		if err := json.Unmarshal(wrapped, dst); err != nil {
			return fmt.Errorf("error unmarshaling single object array: %w", err)
		}
		return nil
	}
	if err := json.Unmarshal(body, dst); err != nil {
		return fmt.Errorf("error unmarshaling single object: %w", err)
	}
	return nil
}

// decodeMaybeSingle implements maybeSingle for array bodies:
//   - 0 rows yields the zero value;
//   - 1 row is decoded like Single();
//   - more rows are reported as PGRST116 with status 406.
//
// Non-array bodies are decoded directly (best effort).
func (b *Builder[T]) decodeMaybeSingle(body []byte, response *PostgrestResponse[T]) error {
	// json.RawMessage keeps each row's original bytes. This avoids the lossy
	// float64 round trip of decoding into []interface{} and re-marshaling.
	var rows []json.RawMessage
	if err := json.Unmarshal(body, &rows); err != nil {
		_ = json.Unmarshal(body, &response.Data)
		return nil
	}
	switch len(rows) {
	case 0:
		response.Data = *new(T)
	case 1:
		// A slice-typed T previously failed to decode a bare object and the
		// row was silently dropped; decodeSingleObject wraps it instead.
		return b.decodeSingleObject(rows[0], &response.Data)
	default:
		response.Error = NewPostgrestError(
			"JSON object requested, multiple (or no) rows returned",
			fmt.Sprintf("Results contain %d rows, application/vnd.pgrst.object+json requires 1 row", len(rows)),
			"",
			"PGRST116",
		)
		response.Status = 406
		response.StatusText = "Not Acceptable"
	}
	return nil
}

// parseContentRangeCount extracts the total from a Content-Range header such
// as "0-24/3573458". It returns nil when the header has no "/", the total is
// "*" (unknown), or the total is not an integer.
func parseContentRangeCount(header string) *int64 {
	parts := strings.Split(header, "/")
	if len(parts) < 2 || parts[1] == "*" {
		return nil
	}
	count, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return nil
	}
	return &count
}
// ExecuteTo runs the query and decodes its Data into to (normally a pointer)
// by round-tripping it through JSON. It returns the row count reported by the
// server, if any. Transport errors and PostgREST errors are both returned as
// the error value.
func (b *Builder[T]) ExecuteTo(ctx context.Context, to interface{}) (*int64, error) {
	res, execErr := b.Execute(ctx)
	switch {
	case execErr != nil:
		return nil, execErr
	case res.Error != nil:
		return nil, res.Error
	}

	encoded, marshalErr := json.Marshal(res.Data)
	if marshalErr != nil {
		return nil, fmt.Errorf("error marshaling response data: %w", marshalErr)
	}
	if unmarshalErr := json.Unmarshal(encoded, to); unmarshalErr != nil {
		return nil, fmt.Errorf("error unmarshaling to target: %w", unmarshalErr)
	}
	return res.Count, nil
}
package postgrest
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"sync"
)
// version is this library's version; it is reported in the X-Client-Info
// header as "postgrest-go/<version>".
var version = "v0.1.1"
// Client represents a PostgREST client.
// Similar to PostgrestClient in postgrest-js.
type Client struct {
	// ClientError holds the construction error when NewClient failed, or the
	// last failure recorded by Ping.
	ClientError error
	// session is the HTTP client requests are sent through; NewClientWithError
	// installs Transport as its RoundTripper.
	session *http.Client
	// Transport stores the default headers and base URL of this client.
	Transport *transport
	// schemaName is the schema passed to builders created from this client.
	schemaName string
}
// NewClientWithError constructs a new client given a URL to a Postgrest instance.
//
// schema defaults to "public" when empty. The default headers (Accept,
// Content-Type, Accept-Profile, Content-Profile, X-Client-Info) are set first.
// The caller-supplied headers are applied afterwards, so they override any
// default. The only error returned is a failure to parse rawURL.
func NewClientWithError(rawURL, schema string, headers map[string]string) (*Client, error) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}
	if schema == "" {
		schema = "public"
	}

	tr := &transport{
		header:  http.Header{},
		baseURL: *parsed,
		Parent:  nil,
	}
	client := &Client{
		session:    &http.Client{Transport: tr},
		Transport:  tr,
		schemaName: schema,
	}

	defaults := map[string]string{
		"Accept":          "application/json",
		"Content-Type":    "application/json",
		"Accept-Profile":  schema,
		"Content-Profile": schema,
		"X-Client-Info":   "postgrest-go/" + version,
	}
	client.Transport.SetHeaders(defaults)
	// Caller headers last so they take precedence over the defaults.
	client.Transport.SetHeaders(headers)
	return client, nil
}
// NewClient constructs a new client given a URL to a Postgrest instance.
// It never returns nil. When construction fails, the returned Client carries
// the failure in ClientError and has no session or transport.
func NewClient(rawURL, schema string, headers map[string]string) *Client {
	client, err := NewClientWithError(rawURL, schema, headers)
	if err == nil {
		return client
	}
	return &Client{ClientError: err}
}
// PingWithError sends a GET for the base URL's path through the client's
// session. It returns an error unless the server answers with HTTP 200.
//
// A Client produced by NewClient after a construction failure has neither a
// transport nor a session. For such a client the stored ClientError is
// returned instead of panicking on a nil pointer.
func (c *Client) PingWithError() error {
	if c.Transport == nil || c.session == nil {
		if c.ClientError != nil {
			return c.ClientError
		}
		return errors.New("postgrest client is not initialized")
	}

	// The path is relative; the transport resolves it against the base URL.
	req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil)
	if err != nil {
		return err
	}
	resp, err := c.session.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		return errors.New("ping failed")
	}
	return nil
}
// Ping reports whether PingWithError succeeded. On failure the error is also
// stored in c.ClientError.
func (c *Client) Ping() bool {
	if err := c.PingWithError(); err != nil {
		c.ClientError = err
		return false
	}
	return true
}
// SetApiKey sets the "apikey" header on the client's transport for
// subsequent requests and returns c for chaining.
func (c *Client) SetApiKey(apiKey string) *Client {
	tr := c.Transport
	tr.SetHeader("apikey", apiKey)
	return c
}

// SetAuthToken sets "Authorization: Bearer <authToken>" on the client's
// transport for subsequent requests and returns c for chaining.
func (c *Client) SetAuthToken(authToken string) *Client {
	bearer := "Bearer " + authToken
	c.Transport.SetHeader("Authorization", bearer)
	return c
}
// ChangeSchema switches this client (in place) to schema for subsequent
// requests. It updates both the stored schema name and the transport's
// Accept-Profile/Content-Profile headers, then returns c for chaining.
func (c *Client) ChangeSchema(schema string) *Client {
	profileHeaders := map[string]string{
		"Accept-Profile":  schema,
		"Content-Profile": schema,
	}
	c.schemaName = schema
	c.Transport.SetHeaders(profileHeaders)
	return c
}
// Schema returns a new Client that runs queries and function (rpc) calls
// against schema, leaving the receiver unchanged.
//
// The returned client shares the receiver's session and transport, so headers
// set later on either client (SetApiKey, SetAuthToken, ...) apply to both.
// The schema is carried only in schemaName. Builders created from the
// returned client pass it to Execute, which sets Accept-Profile or
// Content-Profile per request.
//
// The previous implementation rewrote the profile headers of the SHARED
// transport, silently changing the schema headers of the receiver as well.
// With an empty schema, Execute sets no profile header and the transport's
// existing profile headers are used.
func (c *Client) Schema(schema string) *Client {
	return &Client{
		ClientError: c.ClientError,
		session:     c.session,
		Transport:   c.Transport,
		schemaName:  schema,
	}
}
// From starts a query against the table or view named relation. Rows are
// decoded as map[string]interface{}.
func (c *Client) From(relation string) *QueryBuilder[map[string]interface{}] {
	qb := NewQueryBuilder[map[string]interface{}](c, relation)
	return qb
}
// RpcOptions contains options for Rpc.
type RpcOptions struct {
	// Head sends the call as HEAD, with map arguments encoded as query parameters.
	Head bool
	// Get sends the call as GET, with map arguments encoded as query
	// parameters. Head takes precedence when both are set.
	Get bool
	// Count adds "Prefer: count=<Count>" when it is one of the values below;
	// any other value is ignored.
	Count string // "exact", "planned", or "estimated"
}
// Rpc performs a function call to fn (POST /rpc/<fn> by default).
//
// With opts.Head or opts.Get, the call is sent as HEAD or GET. In that case
// args are encoded as query parameters when they are a map[string]interface{}:
// nil values are skipped and []interface{} values become "{a,b,...}".
// Otherwise args is sent as the JSON body. A nil opts behaves like the zero
// value.
//
// The client transport is nil-checked before its base URL is read. The
// previous code dereferenced c.Transport first, so a Client returned by
// NewClient after a failure panicked here instead of failing in Execute.
func (c *Client) Rpc(fn string, args interface{}, opts *RpcOptions) *FilterBuilder[interface{}] {
	if opts == nil {
		opts = &RpcOptions{}
	}

	var base url.URL
	headers := make(http.Header)
	if c.Transport != nil {
		base = c.Transport.baseURL
		c.Transport.mu.RLock()
		for key, values := range c.Transport.header {
			for _, val := range values {
				headers.Add(key, val)
			}
		}
		c.Transport.mu.RUnlock()
	}
	rpcURL := base.JoinPath("rpc", fn)

	var method string
	var body interface{}
	if opts.Head || opts.Get {
		method = "GET"
		if opts.Head {
			method = "HEAD"
		}
		if argsMap, ok := args.(map[string]interface{}); ok {
			query := rpcURL.Query()
			for name, value := range argsMap {
				if value == nil {
					continue
				}
				if arr, ok := value.([]interface{}); ok {
					strValues := make([]string, 0, len(arr))
					for _, v := range arr {
						strValues = append(strValues, fmt.Sprintf("%v", v))
					}
					query.Set(name, fmt.Sprintf("{%s}", strings.Join(strValues, ",")))
					continue
				}
				query.Set(name, fmt.Sprintf("%v", value))
			}
			rpcURL.RawQuery = query.Encode()
		}
	} else {
		method = "POST"
		body = args
	}

	switch opts.Count {
	case "exact", "planned", "estimated":
		headers.Add("Prefer", fmt.Sprintf("count=%s", opts.Count))
	}

	builder := NewBuilder[interface{}](c, method, rpcURL, &BuilderOptions{
		Headers: headers,
		Schema:  c.schemaName,
		Body:    body,
	})
	return &FilterBuilder[interface{}]{Builder: builder}
}
// RpcWithError executes a Postgres function (a.k.a., Remote Procedure Call), given the
// function name and, optionally, a body, returning the result as a JSON string.
//
// count is forwarded as RpcOptions.Count. Transport errors, PostgREST errors
// and a failure to re-encode the result are all returned as the error value.
// The marshal error was previously discarded, which yielded an empty result
// with a nil error.
func (c *Client) RpcWithError(name string, count string, rpcBody interface{}) (string, error) {
	opts := &RpcOptions{Count: count}
	filterBuilder := c.Rpc(name, rpcBody, opts)
	response, err := filterBuilder.Execute(context.Background())
	if err != nil {
		return "", err
	}
	if response.Error != nil {
		return "", response.Error
	}
	dataBytes, err := json.Marshal(response.Data)
	if err != nil {
		return "", fmt.Errorf("error marshaling rpc result: %w", err)
	}
	return string(dataBytes), nil
}
// transport is the http.RoundTripper installed on a Client's session. It
// holds the client-wide default headers and the base URL that request URLs
// are resolved against.
type transport struct {
	baseURL url.URL
	// Parent performs the actual round trip when non-nil; otherwise
	// http.DefaultTransport is used.
	Parent http.RoundTripper
	// mu guards header.
	mu     sync.RWMutex
	header http.Header
}

// SetHeader replaces the default value(s) of the named header.
func (t *transport) SetHeader(key, value string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.header.Set(key, value)
}

// SetHeaders replaces the default value(s) of every header in headers.
func (t *transport) SetHeaders(headers map[string]string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	for key, value := range headers {
		t.header.Set(key, value)
	}
}

// RoundTrip fills in the default headers, resolves the request URL against
// the base URL, and forwards the request to Parent (or http.DefaultTransport).
//
// Two changes from the previous behavior:
//   - The caller's request is not modified. The http.RoundTripper contract
//     forbids that, so a clone is sent.
//   - A default header is only added when the request does not already carry
//     that key. Builders copy the client headers and may override some (for
//     example Accept for Single()). Appending the defaults again sent
//     duplicate or conflicting values.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}

	t.mu.RLock()
	for name, values := range t.header {
		if _, present := out.Header[name]; present {
			continue
		}
		out.Header[name] = append([]string(nil), values...)
	}
	t.mu.RUnlock()

	out.URL = t.baseURL.ResolveReference(req.URL)

	// Parent is only needed with httpmock in testing. It would be cleaner to
	// initialise Parent with http.DefaultTransport and always use it.
	if t.Parent != nil {
		return t.Parent.RoundTrip(out)
	}
	return http.DefaultTransport.RoundTrip(out)
}
package postgrest
import "fmt"
// DefaultHeaders returns a fresh map of the default headers for PostgREST
// requests. Currently it contains only X-Client-Info, which identifies this
// library and its version.
func DefaultHeaders() map[string]string {
	headers := make(map[string]string, 1)
	headers["X-Client-Info"] = fmt.Sprintf("postgrest-go/%s", version)
	return headers
}
package postgrest
// PostgrestError represents an error response from PostgREST.
// https://postgrest.org/en/stable/api.html?highlight=options#errors-and-http-status-codes
type PostgrestError struct {
	Message string
	Details string
	Hint    string
	Code    string
}

// Error implements the error interface. Only Message is included in the
// text; Details, Hint and Code remain available as fields.
func (e *PostgrestError) Error() string {
	msg := e.Message
	return msg
}

// NewPostgrestError builds a *PostgrestError from its four components.
func NewPostgrestError(message, details, hint, code string) *PostgrestError {
	pe := new(PostgrestError)
	pe.Message = message
	pe.Details = details
	pe.Hint = hint
	pe.Code = code
	return pe
}
package postgrest
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
)
// FilterBuilder provides filtering methods for queries.
// Similar to PostgrestFilterBuilder in postgrest-js.
//
// It embeds *Builder. Filters are written into the embedded builder's URL
// query string, and Execute/ExecuteTo send the accumulated request.
type FilterBuilder[T any] struct {
	*Builder[T]
}
// filterOperators lists the operator names Filter accepts (checked by isOperator).
var filterOperators = []string{"eq", "neq", "gt", "gte", "lt", "lte", "like", "ilike", "is", "in", "cs", "cd", "sl", "sr", "nxl", "nxr", "adj", "ov", "fts", "plfts", "phfts", "wfts"}
// appendFilter adds filterValue (an "op.value" string such as "eq.5") as a
// filter on column and returns f.
//
// The first filter on a column is stored as the plain query parameter
// column=op.value. A second filter on the same column cannot share that
// parameter through url.Values.Set. Both are therefore moved into the
// logical and=(...) parameter as column.op.value entries. Once and=(...)
// exists, filters that need combining are appended to it.
//
// Bug fix: when and=(...) already held entries for another column, the
// existing column=... filter was deleted without being carried into
// and=(...), so it silently disappeared from the query.
func (f *FilterBuilder[T]) appendFilter(column, filterValue string) *FilterBuilder[T] {
	query := f.url.Query()
	existing := query.Get(column)
	andValue := query.Get("and")
	entry := column + "." + filterValue

	switch {
	case andValue != "" && strings.Contains(andValue, column+"."):
		// The column (by substring match) already appears in and=(...); extend
		// it. A false substring match is harmless because every entry is ANDed.
		query.Set("and", strings.TrimSuffix(andValue, ")")+","+entry+")")
	case existing != "":
		// Move the existing column=... filter into and=(...) with the new one.
		previous := column + "." + existing
		if andValue != "" {
			andValue = strings.TrimSuffix(andValue, ")") + "," + previous + "," + entry + ")"
		} else {
			andValue = fmt.Sprintf("(%s,%s)", previous, entry)
		}
		query.Set("and", andValue)
		query.Del(column)
	default:
		query.Set(column, filterValue)
	}

	f.url.RawQuery = query.Encode()
	return f
}
// isOperator reports whether value exactly matches (case-sensitively) one of
// the operator names in filterOperators.
func isOperator(value string) bool {
	for i := range filterOperators {
		if filterOperators[i] == value {
			return true
		}
	}
	return false
}
// Filter adds column=operator.value to the query. operator must be one of
// the names in filterOperators; any other operator leaves the query unchanged.
func (f *FilterBuilder[T]) Filter(column, operator, value string) *FilterBuilder[T] {
	if isOperator(operator) {
		return f.appendFilter(column, operator+"."+value)
	}
	return f
}

// Eq matches only rows where column is equal to value.
func (f *FilterBuilder[T]) Eq(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "eq."+fmt.Sprint(value))
}

// Neq matches only rows where column is not equal to value.
func (f *FilterBuilder[T]) Neq(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "neq."+fmt.Sprint(value))
}

// Gt matches only rows where column is greater than value.
func (f *FilterBuilder[T]) Gt(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "gt."+fmt.Sprint(value))
}

// Gte matches only rows where column is greater than or equal to value.
func (f *FilterBuilder[T]) Gte(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "gte."+fmt.Sprint(value))
}

// Lt matches only rows where column is less than value.
func (f *FilterBuilder[T]) Lt(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "lt."+fmt.Sprint(value))
}

// Lte matches only rows where column is less than or equal to value.
func (f *FilterBuilder[T]) Lte(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "lte."+fmt.Sprint(value))
}
// Like matches only rows where column matches pattern case-sensitively.
func (f *FilterBuilder[T]) Like(column, pattern string) *FilterBuilder[T] {
	return f.appendFilter(column, "like."+pattern)
}

// LikeAllOf matches only rows where column matches all of patterns case-sensitively.
func (f *FilterBuilder[T]) LikeAllOf(column string, patterns []string) *FilterBuilder[T] {
	return f.appendFilter(column, "like(all).{"+strings.Join(patterns, ",")+"}")
}

// LikeAnyOf matches only rows where column matches any of patterns case-sensitively.
func (f *FilterBuilder[T]) LikeAnyOf(column string, patterns []string) *FilterBuilder[T] {
	return f.appendFilter(column, "like(any).{"+strings.Join(patterns, ",")+"}")
}

// Ilike matches only rows where column matches pattern case-insensitively.
func (f *FilterBuilder[T]) Ilike(column, pattern string) *FilterBuilder[T] {
	return f.appendFilter(column, "ilike."+pattern)
}

// IlikeAllOf matches only rows where column matches all of patterns case-insensitively.
func (f *FilterBuilder[T]) IlikeAllOf(column string, patterns []string) *FilterBuilder[T] {
	return f.appendFilter(column, "ilike(all).{"+strings.Join(patterns, ",")+"}")
}

// IlikeAnyOf matches only rows where column matches any of patterns case-insensitively.
func (f *FilterBuilder[T]) IlikeAnyOf(column string, patterns []string) *FilterBuilder[T] {
	return f.appendFilter(column, "ilike(any).{"+strings.Join(patterns, ",")+"}")
}

// Is matches only rows where column IS value (e.g. null, true, false).
func (f *FilterBuilder[T]) Is(column string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "is."+fmt.Sprint(value))
}
// inReservedChars matches characters with syntactic meaning inside a
// PostgREST in.(...) list; values containing them must be double-quoted.
// It is compiled once here instead of on every In call.
var inReservedChars = regexp.MustCompile(`[,()]`)

// In matches only rows where column is included in the values array.
//
// Each value is formatted with %v. A value containing ',', '(' or ')' is
// wrapped in double quotes. Inside the quotes, backslashes and double quotes
// are backslash-escaped, because PostgREST treats '\' as an escape character
// in quoted values.
func (f *FilterBuilder[T]) In(column string, values []interface{}) *FilterBuilder[T] {
	cleanedValues := make([]string, 0, len(values))
	for _, v := range values {
		valStr := fmt.Sprintf("%v", v)
		if inReservedChars.MatchString(valStr) {
			escaped := strings.ReplaceAll(valStr, `\`, `\\`)
			escaped = strings.ReplaceAll(escaped, `"`, `\"`)
			valStr = `"` + escaped + `"`
		}
		cleanedValues = append(cleanedValues, valStr)
	}
	return f.appendFilter(column, fmt.Sprintf("in.(%s)", strings.Join(cleanedValues, ",")))
}
// Contains matches only rows where column contains every element appearing in value.
//
// A string value is passed through verbatim (range literals). A []interface{}
// becomes a Postgres array literal {a,b,...}. Anything else is JSON-encoded
// (for json/jsonb columns); a marshal failure yields an empty payload.
func (f *FilterBuilder[T]) Contains(column string, value interface{}) *FilterBuilder[T] {
	if rng, isString := value.(string); isString {
		return f.appendFilter(column, "cs."+rng)
	}
	if items, isArray := value.([]interface{}); isArray {
		parts := make([]string, len(items))
		for i, item := range items {
			parts[i] = fmt.Sprint(item)
		}
		return f.appendFilter(column, "cs.{"+strings.Join(parts, ",")+"}")
	}
	encoded, _ := json.Marshal(value) // best effort, as before
	return f.appendFilter(column, "cs."+string(encoded))
}

// ContainedBy matches only rows where every element appearing in column is contained by value.
//
// value is handled exactly as in Contains: string verbatim, []interface{} as
// an array literal, anything else as JSON.
func (f *FilterBuilder[T]) ContainedBy(column string, value interface{}) *FilterBuilder[T] {
	if rng, isString := value.(string); isString {
		return f.appendFilter(column, "cd."+rng)
	}
	if items, isArray := value.([]interface{}); isArray {
		parts := make([]string, len(items))
		for i, item := range items {
			parts[i] = fmt.Sprint(item)
		}
		return f.appendFilter(column, "cd.{"+strings.Join(parts, ",")+"}")
	}
	encoded, _ := json.Marshal(value) // best effort, as before
	return f.appendFilter(column, "cd."+string(encoded))
}
// RangeGt matches only rows where every element in column is greater than any element in range.
func (f *FilterBuilder[T]) RangeGt(column, rangeValue string) *FilterBuilder[T] {
	return f.appendFilter(column, "sr."+rangeValue)
}

// RangeGte matches only rows where every element in column is either contained in range or greater than any element in range.
func (f *FilterBuilder[T]) RangeGte(column, rangeValue string) *FilterBuilder[T] {
	return f.appendFilter(column, "nxl."+rangeValue)
}

// RangeLt matches only rows where every element in column is less than any element in range.
func (f *FilterBuilder[T]) RangeLt(column, rangeValue string) *FilterBuilder[T] {
	return f.appendFilter(column, "sl."+rangeValue)
}

// RangeLte matches only rows where every element in column is either contained in range or less than any element in range.
func (f *FilterBuilder[T]) RangeLte(column, rangeValue string) *FilterBuilder[T] {
	return f.appendFilter(column, "nxr."+rangeValue)
}

// RangeAdjacent matches only rows where column is mutually exclusive to range.
func (f *FilterBuilder[T]) RangeAdjacent(column, rangeValue string) *FilterBuilder[T] {
	return f.appendFilter(column, "adj."+rangeValue)
}
// overlapsArrayItems formats each element with %v for use inside a Postgres
// array literal.
func overlapsArrayItems[E any](items []E) []string {
	out := make([]string, 0, len(items))
	for _, item := range items {
		out = append(out, fmt.Sprintf("%v", item))
	}
	return out
}

// Overlaps matches only rows where column and value have an element in common.
//
// A string value is passed through verbatim (range literals). Slices of
// interface{}, string, int, int64, float64 or bool become a Postgres array
// literal {a,b,...}. Previously only []interface{} was recognised; typed
// slices such as []string silently dropped the filter, which widens the
// result set. Any other type still leaves the query unchanged.
func (f *FilterBuilder[T]) Overlaps(column string, value interface{}) *FilterBuilder[T] {
	var elems []string
	switch v := value.(type) {
	case string:
		return f.appendFilter(column, fmt.Sprintf("ov.%s", v))
	case []interface{}:
		elems = overlapsArrayItems(v)
	case []string:
		elems = overlapsArrayItems(v)
	case []int:
		elems = overlapsArrayItems(v)
	case []int64:
		elems = overlapsArrayItems(v)
	case []float64:
		elems = overlapsArrayItems(v)
	case []bool:
		elems = overlapsArrayItems(v)
	default:
		return f
	}
	return f.appendFilter(column, fmt.Sprintf("ov.{%s}", strings.Join(elems, ",")))
}
// TextSearchOptions contains options for text search.
type TextSearchOptions struct {
	Config string
	Type   string // "plain", "phrase", or "websearch"
}

// TextSearch matches only rows where column matches the query string.
//
// The operator is built as <type>fts[(<config>)].<query>. Type "plain",
// "phrase" or "websearch" selects plfts, phfts or wfts; any other Type (or a
// nil opts) uses plain fts. Config, when set, names the text search
// configuration.
func (f *FilterBuilder[T]) TextSearch(column, query string, opts *TextSearchOptions) *FilterBuilder[T] {
	typePart, configPart := "", ""
	if opts != nil {
		typePart = map[string]string{
			"plain":     "pl",
			"phrase":    "ph",
			"websearch": "w",
		}[opts.Type]
		if opts.Config != "" {
			configPart = "(" + opts.Config + ")"
		}
	}
	return f.appendFilter(column, typePart+"fts"+configPart+"."+query)
}
// Match matches only rows where each column in query keys is equal to its
// associated value (one eq filter per entry).
func (f *FilterBuilder[T]) Match(query map[string]interface{}) *FilterBuilder[T] {
	for column, value := range query {
		f.appendFilter(column, "eq."+fmt.Sprint(value))
	}
	return f
}

// Not matches only rows that do not satisfy column.operator.value.
func (f *FilterBuilder[T]) Not(column, operator string, value interface{}) *FilterBuilder[T] {
	return f.appendFilter(column, "not."+operator+"."+fmt.Sprint(value))
}
// OrOptions contains options for Or.
type OrOptions struct {
	ReferencedTable string
	// Deprecated: Use ReferencedTable instead
	ForeignTable string
}

// Or matches only rows which satisfy at least one of the filters, given in
// PostgREST syntax such as "id.eq.1,name.eq.foo".
//
// With opts.ReferencedTable (or the deprecated opts.ForeignTable, which was
// previously ignored) the condition applies to that embedded table via the
// "<table>.or" parameter. Repeated calls add further or=(...) parameters, as
// postgrest-js does. Previously each call overwrote the last one.
func (f *FilterBuilder[T]) Or(filters string, opts *OrOptions) *FilterBuilder[T] {
	key := "or"
	if opts != nil {
		table := opts.ReferencedTable
		if table == "" {
			table = opts.ForeignTable
		}
		if table != "" {
			key = table + ".or"
		}
	}
	query := f.url.Query()
	query.Add(key, fmt.Sprintf("(%s)", filters))
	f.url.RawQuery = query.Encode()
	return f
}
// The methods below expose TransformBuilder functionality on a FilterBuilder.
// Each wraps the same *Builder in a TransformBuilder and delegates to it.

// Select delegates to TransformBuilder.Select on the shared builder.
func (f *FilterBuilder[T]) Select(columns string) *FilterBuilder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).Select(columns)
}

// Order delegates to TransformBuilder.Order on the shared builder.
func (f *FilterBuilder[T]) Order(column string, opts *OrderOptions) *TransformBuilder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).Order(column, opts)
}

// Limit delegates to TransformBuilder.Limit on the shared builder.
func (f *FilterBuilder[T]) Limit(count int, opts *LimitOptions) *TransformBuilder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).Limit(count, opts)
}

// Range delegates to TransformBuilder.Range on the shared builder.
func (f *FilterBuilder[T]) Range(from, to int, opts *RangeOptions) *TransformBuilder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).Range(from, to, opts)
}

// Single delegates to TransformBuilder.Single on the shared builder.
func (f *FilterBuilder[T]) Single() *Builder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).Single()
}

// MaybeSingle delegates to TransformBuilder.MaybeSingle on the shared builder.
func (f *FilterBuilder[T]) MaybeSingle() *Builder[T] {
	return (&TransformBuilder[T]{Builder: f.Builder}).MaybeSingle()
}

// Execute executes the query and returns the response.
func (f *FilterBuilder[T]) Execute(ctx context.Context) (*PostgrestResponse[T], error) {
	b := f.Builder
	return b.Execute(ctx)
}

// ExecuteTo executes the query and unmarshals the result into the provided interface.
func (f *FilterBuilder[T]) ExecuteTo(ctx context.Context, to interface{}) (*int64, error) {
	b := f.Builder
	return b.ExecuteTo(ctx, to)
}
package postgrest
import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"sort"
	"strings"
)
// QueryBuilder provides query building methods.
// Similar to PostgrestQueryBuilder in postgrest-js.
//
// Select, Insert, Upsert, Update and Delete hand q.url (the same *url.URL
// value) to the builder they return. Query parameters set through one call
// are therefore visible to later calls on the same QueryBuilder.
type QueryBuilder[T any] struct {
	// <base URL>/<relation>, plus any query parameters added by the methods.
	url *url.URL
	// Copy of the client's headers taken in NewQueryBuilder.
	headers http.Header
	// Client's schemaName at creation time.
	schema string
	client *Client
}
// NewQueryBuilder creates a QueryBuilder for <base URL>/<relation>, with a
// snapshot of the client's current headers and its schema name.
//
// client.Transport is nil-checked before its base URL is read. The previous
// code dereferenced it first, which made its own nil check useless and
// panicked for a Client returned by NewClient after a construction failure.
// Such a client now yields a builder whose Execute reports the failure.
func NewQueryBuilder[T any](client *Client, relation string) *QueryBuilder[T] {
	var base url.URL
	headers := make(http.Header)
	if client.Transport != nil {
		base = client.Transport.baseURL
		client.Transport.mu.RLock()
		for key, values := range client.Transport.header {
			for _, val := range values {
				headers.Add(key, val)
			}
		}
		client.Transport.mu.RUnlock()
	}
	return &QueryBuilder[T]{
		url:     base.JoinPath(relation),
		headers: headers,
		schema:  client.schemaName,
		client:  client,
	}
}
// SelectOptions contains options for Select.
type SelectOptions struct {
	Head  bool
	Count string // "exact", "planned", or "estimated"
}

// Select performs a SELECT query on the table or view.
//
// columns is a PostgREST select list; an empty string selects "*". Whitespace
// (space, tab, CR, LF) outside double quotes is removed, matching
// postgrest-js. Previously only spaces were stripped, so a multi-line column
// list kept its newlines.
//
// With opts.Head the request is sent as HEAD. A valid opts.Count adds
// "Prefer: count=<Count>". A nil opts behaves like the zero value.
//
// The Prefer header is added to a per-call copy of q.headers, like Insert,
// Update and Delete do. Previously it was added to q.headers itself, so
// repeated calls accumulated duplicate Prefer values.
func (q *QueryBuilder[T]) Select(columns string, opts *SelectOptions) *FilterBuilder[[]T] {
	if opts == nil {
		opts = &SelectOptions{}
	}
	method := "GET"
	if opts.Head {
		method = "HEAD"
	}

	cleaned := "*"
	if columns != "" {
		var sb strings.Builder
		quoted := false
		for _, char := range columns {
			if char == '"' {
				quoted = !quoted
			}
			if !quoted && (char == ' ' || char == '\t' || char == '\n' || char == '\r') {
				continue
			}
			sb.WriteRune(char)
		}
		cleaned = sb.String()
	}

	query := q.url.Query()
	query.Set("select", cleaned)
	q.url.RawQuery = query.Encode()

	headers := make(http.Header, len(q.headers))
	for key, values := range q.headers {
		for _, val := range values {
			headers.Add(key, val)
		}
	}
	switch opts.Count {
	case "exact", "planned", "estimated":
		headers.Add("Prefer", fmt.Sprintf("count=%s", opts.Count))
	}

	builder := NewBuilder[[]T](q.client, method, q.url, &BuilderOptions{
		Headers: headers,
		Schema:  q.schema,
	})
	return &FilterBuilder[[]T]{Builder: builder}
}
// InsertOptions contains options for Insert.
//
// Note: a nil *InsertOptions defaults DefaultToNull to true. A non-nil value
// uses the field as given, so &InsertOptions{} sends Prefer: missing=default.
type InsertOptions struct {
	Count         string // "exact", "planned", or "estimated"
	DefaultToNull bool
}

// Insert performs an INSERT into the table or view.
//
// values is sent as the JSON body. When it encodes to a non-empty JSON array
// of objects (a bulk insert), the "columns" parameter is set to the union of
// their keys; see setBulkColumns.
func (q *QueryBuilder[T]) Insert(values interface{}, opts *InsertOptions) *FilterBuilder[interface{}] {
	if opts == nil {
		opts = &InsertOptions{DefaultToNull: true}
	}
	headers := q.cloneHeaders()
	switch opts.Count {
	case "exact", "planned", "estimated":
		headers.Add("Prefer", fmt.Sprintf("count=%s", opts.Count))
	}
	if !opts.DefaultToNull {
		headers.Add("Prefer", "missing=default")
	}
	q.setBulkColumns(values)

	builder := NewBuilder[interface{}](q.client, "POST", q.url, &BuilderOptions{
		Headers: headers,
		Schema:  q.schema,
		Body:    values,
	})
	return &FilterBuilder[interface{}]{Builder: builder}
}

// UpsertOptions contains options for Upsert.
type UpsertOptions struct {
	OnConflict       string
	IgnoreDuplicates bool
	Count            string // "exact", "planned", or "estimated"
	DefaultToNull    bool
}

// Upsert performs an UPSERT on the table or view.
//
// It always sends Prefer: resolution=merge-duplicates, or
// resolution=ignore-duplicates with IgnoreDuplicates. It sets on_conflict
// when OnConflict is non-empty. Count, DefaultToNull and the bulk "columns"
// parameter behave as in Insert. A nil opts defaults DefaultToNull to true.
func (q *QueryBuilder[T]) Upsert(values interface{}, opts *UpsertOptions) *FilterBuilder[interface{}] {
	if opts == nil {
		opts = &UpsertOptions{IgnoreDuplicates: false, DefaultToNull: true}
	}
	headers := q.cloneHeaders()

	resolution := "merge-duplicates"
	if opts.IgnoreDuplicates {
		resolution = "ignore-duplicates"
	}
	headers.Add("Prefer", fmt.Sprintf("resolution=%s", resolution))

	if opts.OnConflict != "" {
		query := q.url.Query()
		query.Set("on_conflict", opts.OnConflict)
		q.url.RawQuery = query.Encode()
	}
	switch opts.Count {
	case "exact", "planned", "estimated":
		headers.Add("Prefer", fmt.Sprintf("count=%s", opts.Count))
	}
	if !opts.DefaultToNull {
		headers.Add("Prefer", "missing=default")
	}
	q.setBulkColumns(values)

	builder := NewBuilder[interface{}](q.client, "POST", q.url, &BuilderOptions{
		Headers: headers,
		Schema:  q.schema,
		Body:    values,
	})
	return &FilterBuilder[interface{}]{Builder: builder}
}

// cloneHeaders returns a copy of q.headers, so per-call additions (such as
// Prefer) do not leak into later calls on the same QueryBuilder.
func (q *QueryBuilder[T]) cloneHeaders() http.Header {
	headers := make(http.Header, len(q.headers))
	for key, vals := range q.headers {
		for _, val := range vals {
			headers.Add(key, val)
		}
	}
	return headers
}

// setBulkColumns sets the "columns" query parameter to the quoted union of
// keys across all rows when values encodes to a non-empty JSON array of
// objects. Names are sorted so the URL no longer depends on Go's random map
// iteration order.
func (q *QueryBuilder[T]) setBulkColumns(values interface{}) {
	raw, err := json.Marshal(values)
	if err != nil {
		// Execute marshals the same value as the request body and reports
		// this error there; there is nothing to derive here.
		return
	}
	var rows []map[string]json.RawMessage
	if err := json.Unmarshal(raw, &rows); err != nil || len(rows) == 0 {
		return
	}
	seen := make(map[string]struct{})
	for _, row := range rows {
		for key := range row {
			seen[key] = struct{}{}
		}
	}
	if len(seen) == 0 {
		return
	}
	names := make([]string, 0, len(seen))
	for key := range seen {
		names = append(names, key)
	}
	sort.Strings(names)
	for i, name := range names {
		names[i] = fmt.Sprintf(`"%s"`, name)
	}
	query := q.url.Query()
	query.Set("columns", strings.Join(names, ","))
	q.url.RawQuery = query.Encode()
}
// UpdateOptions configures Update.
type UpdateOptions struct {
	// Count requests a row count: "exact", "planned", or "estimated".
	// Any other value is ignored.
	Count string
}
// Update performs an UPDATE (HTTP PATCH) on the table or view.
//
// values becomes the request body. The query's headers are copied into the
// new request, and when opts.Count is "exact", "planned", or "estimated" a
// matching "count=" Prefer header is added; any other Count is ignored.
// A nil opts is treated as empty options.
func (q *QueryBuilder[T]) Update(values interface{}, opts *UpdateOptions) *FilterBuilder[interface{}] {
	var count string
	if opts != nil {
		count = opts.Count
	}

	hdr := make(http.Header, len(q.headers)+1)
	for name, vals := range q.headers {
		for _, v := range vals {
			hdr.Add(name, v)
		}
	}

	switch count {
	case "exact", "planned", "estimated":
		hdr.Add("Prefer", fmt.Sprintf("count=%s", count))
	}

	return &FilterBuilder[interface{}]{
		Builder: NewBuilder[interface{}](q.client, "PATCH", q.url, &BuilderOptions{
			Headers: hdr,
			Schema:  q.schema,
			Body:    values,
		}),
	}
}
// DeleteOptions configures Delete.
type DeleteOptions struct {
	// Count requests a row count: "exact", "planned", or "estimated".
	// Any other value is ignored.
	Count string
}
// Delete performs a DELETE on the table or view.
//
// The request carries no body. The query's headers are copied into it, and a
// "count=" Prefer header is added when opts.Count is "exact", "planned", or
// "estimated". A nil opts is treated as empty options.
func (q *QueryBuilder[T]) Delete(opts *DeleteOptions) *FilterBuilder[interface{}] {
	count := ""
	if opts != nil {
		count = opts.Count
	}

	hdr := make(http.Header, len(q.headers)+1)
	for name, vals := range q.headers {
		for _, v := range vals {
			hdr.Add(name, v)
		}
	}

	if count == "exact" || count == "planned" || count == "estimated" {
		hdr.Add("Prefer", fmt.Sprintf("count=%s", count))
	}

	b := NewBuilder[interface{}](q.client, "DELETE", q.url, &BuilderOptions{
		Headers: hdr,
		Schema:  q.schema,
	})
	return &FilterBuilder[interface{}]{Builder: b}
}
package postgrest
import (
	"context"
	"fmt"
	"strconv"
	"strings"
)
// TransformBuilder adds result-shaping methods (select, ordering, paging,
// response format, explain) on top of Builder. It mirrors
// PostgrestTransformBuilder from postgrest-js.
type TransformBuilder[T any] struct {
	// The embedded Builder holds the request state these methods modify.
	*Builder[T]
}
// Select performs a SELECT on the query result by setting the "select" query
// parameter to columns, and adds "Prefer: return=representation" so that
// write requests return the affected rows.
//
// ASCII whitespace (space, tab, newline, carriage return, vertical tab, form
// feed) outside double-quoted identifiers is removed, so column lists written
// as multi-line raw strings work. This matches postgrest-js, which strips all
// whitespace rather than only spaces. If nothing remains, "*" is selected.
func (t *TransformBuilder[T]) Select(columns string) *FilterBuilder[T] {
	quoted := false
	var cleanedColumns strings.Builder
	cleanedColumns.Grow(len(columns))
	for _, char := range columns {
		if char == '"' {
			quoted = !quoted
		}
		if !quoted {
			switch char {
			case ' ', '\t', '\n', '\r', '\v', '\f':
				continue
			}
		}
		cleanedColumns.WriteRune(char)
	}
	cleaned := cleanedColumns.String()
	if cleaned == "" {
		cleaned = "*"
	}
	query := t.url.Query()
	query.Set("select", cleaned)
	t.url.RawQuery = query.Encode()
	t.headers.Add("Prefer", "return=representation")
	return &FilterBuilder[T]{Builder: t.Builder}
}
// Order appends "column.asc" or "column.desc" to the query's "order"
// parameter, or to "<table>.order" when a referenced table is given.
// Repeated calls accumulate as a comma-separated list, earlier calls first.
//
// A nil opts orders ascending with no explicit null placement. Because
// Ascending is a plain bool, a non-nil opts that leaves it unset orders
// descending. A non-nil NullsFirst appends ".nullsfirst" or ".nullslast".
// The deprecated ForeignTable is used when ReferencedTable is empty, matching
// postgrest-js.
func (t *TransformBuilder[T]) Order(column string, opts *OrderOptions) *TransformBuilder[T] {
	if opts == nil {
		opts = &OrderOptions{Ascending: true}
	}
	referencedTable := opts.ReferencedTable
	if referencedTable == "" {
		// Honor the deprecated field so callers that still set it keep working.
		referencedTable = opts.ForeignTable
	}
	key := "order"
	if referencedTable != "" {
		key = referencedTable + ".order"
	}
	direction := "desc"
	if opts.Ascending {
		direction = "asc"
	}
	nulls := ""
	if opts.NullsFirst != nil {
		if *opts.NullsFirst {
			nulls = ".nullsfirst"
		} else {
			nulls = ".nullslast"
		}
	}
	query := t.url.Query()
	orderValue := fmt.Sprintf("%s.%s%s", column, direction, nulls)
	if existing := query.Get(key); existing != "" {
		orderValue = existing + "," + orderValue
	}
	query.Set(key, orderValue)
	t.url.RawQuery = query.Encode()
	return t
}
// OrderOptions configures Order.
type OrderOptions struct {
	// Ascending selects ".asc"; false selects ".desc".
	Ascending bool
	// NullsFirst, when non-nil, selects ".nullsfirst" (true) or ".nullslast" (false).
	NullsFirst *bool
	// ReferencedTable scopes the ordering to "<ReferencedTable>.order".
	ReferencedTable string
	// Deprecated: Use ReferencedTable instead
	ForeignTable string
}
// Limit sets the "limit" query parameter (or "<table>.limit" for a referenced
// table) to count, replacing any earlier limit for the same key.
//
// The deprecated ForeignTable is used when ReferencedTable is empty, matching
// postgrest-js. A nil opts is treated as empty options.
func (t *TransformBuilder[T]) Limit(count int, opts *LimitOptions) *TransformBuilder[T] {
	if opts == nil {
		opts = &LimitOptions{}
	}
	referencedTable := opts.ReferencedTable
	if referencedTable == "" {
		// Honor the deprecated field so callers that still set it keep working.
		referencedTable = opts.ForeignTable
	}
	key := "limit"
	if referencedTable != "" {
		key = referencedTable + ".limit"
	}
	query := t.url.Query()
	query.Set(key, strconv.Itoa(count))
	t.url.RawQuery = query.Encode()
	return t
}
// LimitOptions configures Limit.
type LimitOptions struct {
	// ReferencedTable scopes the limit to "<ReferencedTable>.limit".
	ReferencedTable string
	// Deprecated: Use ReferencedTable instead
	ForeignTable string
}
// Range restricts the result to rows from..to by setting "offset" to from and
// "limit" to to-from+1, so both bounds are inclusive. With a referenced table
// the keys become "<table>.offset" and "<table>.limit". Earlier values for
// those keys are replaced. No validation is done: to < from yields a limit
// of zero or less.
//
// The deprecated ForeignTable is used when ReferencedTable is empty, matching
// postgrest-js. A nil opts is treated as empty options.
func (t *TransformBuilder[T]) Range(from, to int, opts *RangeOptions) *TransformBuilder[T] {
	if opts == nil {
		opts = &RangeOptions{}
	}
	referencedTable := opts.ReferencedTable
	if referencedTable == "" {
		// Honor the deprecated field so callers that still set it keep working.
		referencedTable = opts.ForeignTable
	}
	offsetKey := "offset"
	limitKey := "limit"
	if referencedTable != "" {
		offsetKey = referencedTable + ".offset"
		limitKey = referencedTable + ".limit"
	}
	query := t.url.Query()
	query.Set(offsetKey, strconv.Itoa(from))
	// Range is inclusive, so the row count is to-from+1.
	query.Set(limitKey, strconv.Itoa(to-from+1))
	t.url.RawQuery = query.Encode()
	return t
}
// RangeOptions configures Range.
type RangeOptions struct {
	// ReferencedTable scopes the range to "<ReferencedTable>.offset" and
	// "<ReferencedTable>.limit".
	ReferencedTable string
	// Deprecated: Use ReferencedTable instead
	ForeignTable string
}
// AbortSignal is the Go counterpart of postgrest-js's abortSignal. When ctx
// is a context.Context, it is stored as the builder's signal, which Execute
// then uses in place of the context passed to Execute. Any other value,
// including nil, is ignored. The parameter stays interface{} for backward
// compatibility.
func (t *TransformBuilder[T]) AbortSignal(ctx interface{}) *TransformBuilder[T] {
	if c, ok := ctx.(context.Context); ok {
		t.signal = c
	}
	return t
}
// Single asks PostgREST for a single object instead of an array by setting
// Accept to its singular-object media type. It returns the underlying
// Builder, so no further transforms can be chained.
func (t *TransformBuilder[T]) Single() *Builder[T] {
	const objectMediaType = "application/vnd.pgrst.object+json"
	b := t.Builder
	b.headers.Set("Accept", objectMediaType)
	return b
}
// MaybeSingle requests zero or one row. GET requests keep the plain JSON
// array format; other methods use PostgREST's singular-object media type.
// In both cases the builder is flagged with isMaybeSingle (presumably so
// Execute can reduce an array to at most one row — TODO confirm against
// Execute). The underlying Builder is returned.
func (t *TransformBuilder[T]) MaybeSingle() *Builder[T] {
	accept := "application/vnd.pgrst.object+json"
	if t.method == "GET" {
		accept = "application/json"
	}
	t.headers.Set("Accept", accept)
	t.isMaybeSingle = true
	return t.Builder
}
// CSV requests the result as CSV text (Accept: text/csv). It returns a new
// Builder[string] that shares this builder's URL, header map, and other
// request state, so later changes to either are visible to both.
func (t *TransformBuilder[T]) CSV() *Builder[string] {
	src := t.Builder
	src.headers.Set("Accept", "text/csv")
	out := new(Builder[string])
	out.method = src.method
	out.url = src.url
	out.headers = src.headers
	out.schema = src.schema
	out.body = src.body
	out.shouldThrowOnError = src.shouldThrowOnError
	out.signal = src.signal
	out.client = src.client
	out.isMaybeSingle = src.isMaybeSingle
	return out
}
// GeoJSON requests the result as GeoJSON (Accept: application/geo+json). It
// returns a Builder decoding into map[string]interface{} that shares this
// builder's URL, header map, and other request state.
func (t *TransformBuilder[T]) GeoJSON() *Builder[map[string]interface{}] {
	src := t.Builder
	src.headers.Set("Accept", "application/geo+json")
	return &Builder[map[string]interface{}]{
		client:             src.client,
		method:             src.method,
		url:                src.url,
		schema:             src.schema,
		headers:            src.headers,
		body:               src.body,
		signal:             src.signal,
		shouldThrowOnError: src.shouldThrowOnError,
		isMaybeSingle:      src.isMaybeSingle,
	}
}
// ExplainOptions configures Explain. Each true flag adds the matching
// lower-case option ("analyze", "verbose", "settings", "buffers", "wal") to
// the plan request.
type ExplainOptions struct {
	Analyze  bool
	Verbose  bool
	Settings bool
	Buffers  bool
	WAL      bool
	// Format is the plan format: "text" or "json".
	Format string
}
// Explain requests the EXPLAIN plan for the query instead of its data.
//
// It replaces the Accept header with
// `application/vnd.pgrst.plan+<format>; for="<previous Accept>"; options=<a|b|…>;`.
// The previous Accept defaults to application/json when unset. Options are
// the enabled flags joined with "|". A nil opts, or an empty Format, uses the
// "text" format. It returns a Builder[interface{}] that shares this builder's
// URL, header map, and other request state.
func (t *TransformBuilder[T]) Explain(opts *ExplainOptions) *Builder[interface{}] {
	if opts == nil {
		opts = &ExplainOptions{Format: "text"}
	}
	format := opts.Format
	if format == "" {
		// A zero-valued Format would otherwise produce the malformed media
		// type "application/vnd.pgrst.plan+"; use the same default as nil opts.
		format = "text"
	}
	var options []string
	if opts.Analyze {
		options = append(options, "analyze")
	}
	if opts.Verbose {
		options = append(options, "verbose")
	}
	if opts.Settings {
		options = append(options, "settings")
	}
	if opts.Buffers {
		options = append(options, "buffers")
	}
	if opts.WAL {
		options = append(options, "wal")
	}
	optionsStr := strings.Join(options, "|")
	forMediatype := t.headers.Get("Accept")
	if forMediatype == "" {
		forMediatype = "application/json"
	}
	acceptValue := fmt.Sprintf("application/vnd.pgrst.plan+%s; for=\"%s\"; options=%s;", format, forMediatype, optionsStr)
	t.headers.Set("Accept", acceptValue)
	return &Builder[interface{}]{
		method:             t.method,
		url:                t.url,
		headers:            t.headers,
		schema:             t.schema,
		body:               t.body,
		shouldThrowOnError: t.shouldThrowOnError,
		signal:             t.signal,
		client:             t.client,
		isMaybeSingle:      t.isMaybeSingle,
	}
}
// Rollback adds "Prefer: tx=rollback", asking PostgREST to roll back the
// transaction after running the request.
func (t *TransformBuilder[T]) Rollback() *TransformBuilder[T] {
	h := t.Builder.headers
	h.Add("Prefer", "tx=rollback")
	return t
}
// MaxAffected caps the number of rows the request may affect. It adds two
// Prefer values, "handling=strict" and "max-affected=<value>", in that order.
func (t *TransformBuilder[T]) MaxAffected(value int) *TransformBuilder[T] {
	prefs := []string{"handling=strict", "max-affected=" + strconv.Itoa(value)}
	for _, p := range prefs {
		t.headers.Add("Prefer", p)
	}
	return t
}