package postgrest import ( "bytes" "encoding/json" "errors" "io" "net/http" "net/url" "path" ) var ( version = "v0.1.1" ) type Client struct { ClientError error session http.Client Transport *transport } // NewClient constructs a new client given a URL to a Postgrest instance. func NewClient(rawURL, schema string, headers map[string]string) *Client { // Create URL from rawURL baseURL, err := url.Parse(rawURL) if err != nil { return &Client{ClientError: err} } t := transport{ header: http.Header{}, baseURL: *baseURL, Parent: nil, } c := Client{ session: http.Client{Transport: &t}, Transport: &t, } if schema == "" { schema = "public" } // Set required headers c.Transport.header.Set("Accept", "application/json") c.Transport.header.Set("Content-Type", "application/json") c.Transport.header.Set("Accept-Profile", schema) c.Transport.header.Set("Content-Profile", schema) c.Transport.header.Set("X-Client-Info", "postgrest-go/"+version) // Set optional headers if they exist for key, value := range headers { c.Transport.header.Set(key, value) } return &c } func (c *Client) Ping() bool { req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil) if err != nil { c.ClientError = err return false } resp, err := c.session.Do(req) if err != nil { c.ClientError = err return false } if resp.Status != "200 OK" { c.ClientError = errors.New("ping failed") return false } return true } // SetApiKey sets api key header for subsequent requests. func (c *Client) SetApiKey(apiKey string) *Client { c.Transport.header.Set("apikey", apiKey) return c } // SetAuthToken sets authorization header for subsequent requests. func (c *Client) SetAuthToken(authToken string) *Client { c.Transport.header.Set("Authorization", "Bearer "+authToken) return c } // ChangeSchema modifies the schema for subsequent requests. 
func (c *Client) ChangeSchema(schema string) *Client { c.Transport.header.Set("Accept-Profile", schema) c.Transport.header.Set("Content-Profile", schema) return c } // From sets the table to query from. func (c *Client) From(table string) *QueryBuilder { return &QueryBuilder{client: c, tableName: table, headers: map[string]string{}, params: map[string]string{}} } // Rpc executes a Postgres function (a.k.a., Remote Prodedure Call), given the // function name and, optionally, a body, returning the result as a string. func (c *Client) Rpc(name string, count string, rpcBody interface{}) string { // Get body if it exists var byteBody []byte = nil if rpcBody != nil { jsonBody, err := json.Marshal(rpcBody) if err != nil { c.ClientError = err return "" } byteBody = jsonBody } readerBody := bytes.NewBuffer(byteBody) url := path.Join(c.Transport.baseURL.Path, "rpc", name) req, err := http.NewRequest("POST", url, readerBody) if err != nil { c.ClientError = err return "" } if count != "" && (count == `exact` || count == `planned` || count == `estimated`) { req.Header.Add("Prefer", "count="+count) } resp, err := c.session.Do(req) if err != nil { c.ClientError = err return "" } body, err := io.ReadAll(resp.Body) if err != nil { c.ClientError = err return "" } result := string(body) err = resp.Body.Close() if err != nil { c.ClientError = err return "" } return result } type transport struct { header http.Header baseURL url.URL Parent http.RoundTripper } func (t transport) RoundTrip(req *http.Request) (*http.Response, error) { for headerName, values := range t.header { for _, val := range values { req.Header.Add(headerName, val) } } req.URL = t.baseURL.ResolveReference(req.URL) // This is only needed with usage of httpmock in testing. It would be better to initialize // t.Parent with http.DefaultTransport and then use t.Parent.RoundTrip(req) if t.Parent != nil { return t.Parent.RoundTrip(req) } return http.DefaultTransport.RoundTrip(req) }
package postgrest import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "path" "strconv" "strings" ) // countType is the integer type returned from execute functions when a count // specifier is supplied to a builder. type countType = int64 // ExecuteError is the error response format from postgrest. We really // only use Code and Message, but we'll keep it as a struct for now. type ExecuteError struct { Hint string `json:"hint"` Details string `json:"details"` Code string `json:"code"` Message string `json:"message"` } func executeHelper(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, countType, error) { if client.ClientError != nil { return nil, 0, client.ClientError } readerBody := bytes.NewBuffer(body) baseUrl := path.Join(append([]string{client.Transport.baseURL.Path}, urlFragments...)...) req, err := http.NewRequestWithContext(ctx, method, baseUrl, readerBody) if err != nil { return nil, 0, fmt.Errorf("error creating request: %s", err.Error()) } for key, val := range headers { req.Header.Add(key, val) } q := req.URL.Query() for key, val := range params { q.Add(key, val) } req.URL.RawQuery = q.Encode() resp, err := client.session.Do(req) if err != nil { return nil, 0, err } respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, 0, err } // https://postgrest.org/en/stable/api.html#errors-and-http-status-codes if resp.StatusCode >= 400 { var errmsg *ExecuteError err := json.Unmarshal(respBody, &errmsg) if err != nil { return nil, 0, fmt.Errorf("error parsing error response: %s", err.Error()) } return nil, 0, fmt.Errorf("(%s) %s", errmsg.Code, errmsg.Message) } var count countType contentRange := resp.Header.Get("Content-Range") if contentRange != "" { split := strings.Split(contentRange, "/") if len(split) > 1 && split[1] != "*" { count, err = strconv.ParseInt(split[1], 0, 64) if err != nil { return nil, 0, fmt.Errorf("error 
parsing count from Content-Range header: %s", err.Error()) } } } err = resp.Body.Close() if err != nil { return nil, 0, errors.New("error closing response body") } return respBody, count, nil } func executeString(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) (string, countType, error) { resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params) return string(resp), count, err } func execute(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, countType, error) { return executeHelper(ctx, client, method, body, urlFragments, headers, params) } func executeTo(ctx context.Context, client *Client, method string, body []byte, to interface{}, urlFragments []string, headers map[string]string, params map[string]string) (countType, error) { resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params) if err != nil { return count, err } readableRes := bytes.NewBuffer(resp) err = json.NewDecoder(readableRes).Decode(&to) return count, err }
package postgrest import ( "context" "encoding/json" "fmt" "regexp" "strconv" "strings" ) // FilterBuilder describes a builder for a filtered result set. type FilterBuilder struct { client *Client method string // One of "HEAD", "GET", "POST", "PUT", "DELETE" body []byte tableName string headers map[string]string params map[string]string } // ExecuteString runs the PostgREST query, returning the result as a JSON // string. func (f *FilterBuilder) ExecuteString() (string, int64, error) { return executeString(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params) } // ExecuteStringWithContext runs the PostgREST query, returning the result as // a JSON string. func (f *FilterBuilder) ExecuteStringWithContext(ctx context.Context) (string, int64, error) { return executeString(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params) } // Execute runs the PostgREST query, returning the result as a byte slice. func (f *FilterBuilder) Execute() ([]byte, int64, error) { return execute(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params) } // ExecuteWithContext runs the PostgREST query with the given context, // returning the result as a byte slice. func (f *FilterBuilder) ExecuteWithContext(ctx context.Context) ([]byte, int64, error) { return execute(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params) } // ExecuteTo runs the PostgREST query, encoding the result to the supplied // interface. Note that the argument for the to parameter should always be a // reference to a slice. func (f *FilterBuilder) ExecuteTo(to interface{}) (countType, error) { return executeTo(context.Background(), f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params) } // ExecuteToWithContext runs the PostgREST query with the given context, // encoding the result to the supplied interface. 
Note that the argument for // the to parameter should always be a reference to a slice. func (f *FilterBuilder) ExecuteToWithContext(ctx context.Context, to interface{}) (countType, error) { return executeTo(ctx, f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params) } var filterOperators = []string{"eq", "neq", "gt", "gte", "lt", "lte", "like", "ilike", "is", "in", "cs", "cd", "sl", "sr", "nxl", "nxr", "adj", "ov", "fts", "plfts", "phfts", "wfts"} func isOperator(value string) bool { for _, operator := range filterOperators { if value == operator { return true } } return false } // Filter adds a filtering operator to the query. For a list of available // operators, see: https://postgrest.org/en/stable/api.html#operators func (f *FilterBuilder) Filter(column, operator, value string) *FilterBuilder { if !isOperator(operator) { f.client.ClientError = fmt.Errorf("invalid filter operator") return f } f.params[column] = fmt.Sprintf("%s.%s", operator, value) return f } func (f *FilterBuilder) And(filters, foreignTable string) *FilterBuilder { if foreignTable != "" { f.params[foreignTable+".and"] = fmt.Sprintf("(%s)", filters) } else { f.params[foreignTable+"and"] = fmt.Sprintf("(%s)", filters) } return f } func (f *FilterBuilder) Or(filters, foreignTable string) *FilterBuilder { if foreignTable != "" { f.params[foreignTable+".or"] = fmt.Sprintf("(%s)", filters) } else { f.params[foreignTable+"or"] = fmt.Sprintf("(%s)", filters) } return f } func (f *FilterBuilder) Not(column, operator, value string) *FilterBuilder { if !isOperator(operator) { return f } f.params[column] = fmt.Sprintf("not.%s.%s", operator, value) return f } func (f *FilterBuilder) Match(userQuery map[string]string) *FilterBuilder { for key, value := range userQuery { f.params[key] = "eq." + value } return f } func (f *FilterBuilder) Eq(column, value string) *FilterBuilder { f.params[column] = "eq." 
+ value return f } func (f *FilterBuilder) Neq(column, value string) *FilterBuilder { f.params[column] = "neq." + value return f } func (f *FilterBuilder) Gt(column, value string) *FilterBuilder { f.params[column] = "gt." + value return f } func (f *FilterBuilder) Gte(column, value string) *FilterBuilder { f.params[column] = "gte." + value return f } func (f *FilterBuilder) Lt(column, value string) *FilterBuilder { f.params[column] = "lt." + value return f } func (f *FilterBuilder) Lte(column, value string) *FilterBuilder { f.params[column] = "lte." + value return f } func (f *FilterBuilder) Like(column, value string) *FilterBuilder { f.params[column] = "like." + value return f } func (f *FilterBuilder) Ilike(column, value string) *FilterBuilder { f.params[column] = "ilike." + value return f } func (f *FilterBuilder) Is(column, value string) *FilterBuilder { f.params[column] = "is." + value return f } func (f *FilterBuilder) In(column string, values []string) *FilterBuilder { var cleanedValues []string illegalChars := regexp.MustCompile("[,()]") for _, value := range values { exp := illegalChars.MatchString(value) if exp { cleanedValues = append(cleanedValues, fmt.Sprintf("\"%s\"", value)) } else { cleanedValues = append(cleanedValues, value) } } f.params[column] = fmt.Sprintf("in.(%s)", strings.Join(cleanedValues, ",")) return f } func (f *FilterBuilder) Contains(column string, value []string) *FilterBuilder { newValue := []string{} for _, v := range value { newValue = append(newValue, fmt.Sprintf("%#v", v)) } valueString := fmt.Sprintf("{%s}", strings.Join(newValue, ",")) f.params[column] = "cs." + valueString return f } func (f *FilterBuilder) ContainedBy(column string, value []string) *FilterBuilder { newValue := []string{} for _, v := range value { newValue = append(newValue, fmt.Sprintf("%#v", v)) } valueString := fmt.Sprintf("{%s}", strings.Join(newValue, ",")) f.params[column] = "cd." 
+ valueString return f } func (f *FilterBuilder) ContainsObject(column string, value interface{}) *FilterBuilder { sum, err := json.Marshal(value) if err != nil { f.client.ClientError = err } f.params[column] = "cs." + string(sum) return f } func (f *FilterBuilder) ContainedByObject(column string, value interface{}) *FilterBuilder { sum, err := json.Marshal(value) if err != nil { f.client.ClientError = err } f.params[column] = "cs." + string(sum) return f } func (f *FilterBuilder) RangeLt(column, value string) *FilterBuilder { f.params[column] = "sl." + value return f } func (f *FilterBuilder) RangeGt(column, value string) *FilterBuilder { f.params[column] = "sr." + value return f } func (f *FilterBuilder) RangeGte(column, value string) *FilterBuilder { f.params[column] = "nxl." + value return f } func (f *FilterBuilder) RangeLte(column, value string) *FilterBuilder { f.params[column] = "nxr." + value return f } func (f *FilterBuilder) RangeAdjacent(column, value string) *FilterBuilder { f.params[column] = "adj." + value return f } func (f *FilterBuilder) Overlaps(column string, value []string) *FilterBuilder { newValue := []string{} for _, v := range value { newValue = append(newValue, fmt.Sprintf("%#v", v)) } valueString := fmt.Sprintf("{%s}", strings.Join(newValue, ",")) f.params[column] = "ov." + valueString return f } // TextSearch performs a full-text search filter. For more information, see // https://postgrest.org/en/stable/api.html#fts. func (f *FilterBuilder) TextSearch(column, userQuery, config, tsType string) *FilterBuilder { var typePart, configPart string if tsType == "plain" { typePart = "pl" } else if tsType == "phrase" { typePart = "ph" } else if tsType == "websearch" { typePart = "w" } else if tsType == "" { typePart = "" } else { f.client.ClientError = fmt.Errorf("invalid text search type") return f } if config != "" { configPart = fmt.Sprintf("(%s)", config) } f.params[column] = typePart + "fts" + configPart + "." 
+ userQuery return f } // OrderOpts describes the options to be provided to Order. type OrderOpts struct { Ascending bool NullsFirst bool ForeignTable string } // DefaultOrderOpts is the default set of options used by Order. var DefaultOrderOpts = OrderOpts{ Ascending: false, NullsFirst: false, ForeignTable: "", } // Limit the result to the specified count. func (f *FilterBuilder) Limit(count int, foreignTable string) *FilterBuilder { if foreignTable != "" { f.params[foreignTable+".limit"] = strconv.Itoa(count) } else { f.params["limit"] = strconv.Itoa(count) } return f } // Order the result with the specified column. A pointer to an OrderOpts // object can be supplied to specify ordering options. func (f *FilterBuilder) Order(column string, opts *OrderOpts) *FilterBuilder { if opts == nil { opts = &DefaultOrderOpts } key := "order" if opts.ForeignTable != "" { key = opts.ForeignTable + ".order" } ascendingString := "desc" if opts.Ascending { ascendingString = "asc" } nullsString := "nullslast" if opts.NullsFirst { nullsString = "nullsfirst" } existingOrder, ok := f.params[key] if ok && existingOrder != "" { f.params[key] = fmt.Sprintf("%s,%s.%s.%s", existingOrder, column, ascendingString, nullsString) } else { f.params[key] = fmt.Sprintf("%s.%s.%s", column, ascendingString, nullsString) } return f } // Range Limits the result to rows within the specified range, inclusive. func (f *FilterBuilder) Range(from, to int, foreignTable string) *FilterBuilder { if foreignTable != "" { f.params[foreignTable+".offset"] = strconv.Itoa(from) f.params[foreignTable+".limit"] = strconv.Itoa(to - from + 1) } else { f.params["offset"] = strconv.Itoa(from) f.params["limit"] = strconv.Itoa(to - from + 1) } return f } // Single Retrieves only one row from the result. The total result set must be one row // (e.g., by using Limit). Otherwise, this will result in an error. 
func (f *FilterBuilder) Single() *FilterBuilder { f.headers["Accept"] = "application/vnd.pgrst.object+json" return f }
package postgrest import ( "context" "encoding/json" "fmt" "strings" ) // QueryBuilder describes a builder for a query. type QueryBuilder struct { client *Client method string body []byte tableName string headers map[string]string params map[string]string } // ExecuteString runs the PostgREST query, returning the result as a JSON // string. func (q *QueryBuilder) ExecuteString() (string, int64, error) { return executeString(context.Background(), q.client, q.method, q.body, []string{q.tableName}, q.headers, q.params) } // ExecuteStringWithContext runs the PostgREST query, returning the result as // a JSON string. func (q *QueryBuilder) ExecuteStringWithContext(ctx context.Context) (string, int64, error) { return executeString(ctx, q.client, q.method, q.body, []string{q.tableName}, q.headers, q.params) } // Execute runs the Postgrest query, returning the result as a byte slice. func (q *QueryBuilder) Execute() ([]byte, int64, error) { return execute(context.Background(), q.client, q.method, q.body, []string{q.tableName}, q.headers, q.params) } // ExecuteWithContext runs the PostgREST query with the given context, // returning the result as a byte slice. func (q *QueryBuilder) ExecuteWithContext(ctx context.Context) ([]byte, int64, error) { return execute(ctx, q.client, q.method, q.body, []string{q.tableName}, q.headers, q.params) } // ExecuteTo runs the PostgREST query, encoding the result to the supplied // interface. Note that the argument for the to parameter should always be a // reference to a slice. func (q *QueryBuilder) ExecuteTo(to interface{}) (int64, error) { return executeTo(context.Background(), q.client, q.method, q.body, to, []string{q.tableName}, q.headers, q.params) } // ExecuteToWithContext runs the PostgREST query with the given context, // encoding the result to the supplied interface. Note that the argument for // the to parameter should always be a reference to a slice. 
func (q *QueryBuilder) ExecuteToWithContext(ctx context.Context, to interface{}) (int64, error) {
	return executeTo(ctx, q.client, q.method, q.body, to, []string{q.tableName}, q.headers, q.params)
}

// Select performs vertical filtering.
//
// columns is a comma-separated column list; an empty string selects "*".
// Whitespace outside double quotes is removed, while quoted identifiers are
// kept verbatim. count may be "exact", "planned" or "estimated" to request a
// row count through the Prefer header; other values are ignored. When head is
// true a HEAD request is issued instead of GET.
func (q *QueryBuilder) Select(columns, count string, head bool) *FilterBuilder {
	if head {
		q.method = "HEAD"
	} else {
		q.method = "GET"
	}

	if columns == "" {
		q.params["select"] = "*"
	} else {
		// Iterate bytes: every byte inspected is ASCII, and multi-byte UTF-8
		// sequences never contain ASCII bytes, so they are copied through intact.
		var cleaned strings.Builder
		cleaned.Grow(len(columns))
		quoted := false
		for i := 0; i < len(columns); i++ {
			c := columns[i]
			switch {
			case c == '"':
				quoted = !quoted
			case quoted:
				// Inside a quoted identifier whitespace is part of the name.
			case c == ' ', c == '\t', c == '\n', c == '\r':
				continue
			}
			cleaned.WriteByte(c)
		}
		q.params["select"] = cleaned.String()
	}

	switch count {
	case "exact", "planned", "estimated":
		if currentValue, ok := q.headers["Prefer"]; ok && currentValue != "" {
			q.headers["Prefer"] = fmt.Sprintf("%s,count=%s", currentValue, count)
		} else {
			q.headers["Prefer"] = fmt.Sprintf("count=%s", count)
		}
	}

	return &FilterBuilder{client: q.client, method: q.method, body: q.body, tableName: q.tableName, headers: q.headers, params: q.params}
}

// Insert performs an insertion into the table.
func (q *QueryBuilder) Insert(value interface{}, upsert bool, onConflict, returning, count string) *FilterBuilder { q.method = "POST" if onConflict != "" && upsert { q.params["on_conflict"] = onConflict } var headerList []string if upsert { headerList = append(headerList, "resolution=merge-duplicates") } if returning == "" { returning = "representation" } if returning == "minimal" || returning == "representation" { headerList = append(headerList, "return="+returning) } if count != "" && (count == `exact` || count == `planned` || count == `estimated`) { headerList = append(headerList, "count="+count) } q.headers["Prefer"] = strings.Join(headerList, ",") // Get body if exist var byteBody []byte = nil if value != nil { jsonBody, err := json.Marshal(value) if err != nil { q.client.ClientError = err return &FilterBuilder{} } byteBody = jsonBody } q.body = byteBody return &FilterBuilder{client: q.client, method: q.method, body: q.body, tableName: q.tableName, headers: q.headers, params: q.params} } // Upsert performs an upsert into the table. 
func (q *QueryBuilder) Upsert(value interface{}, onConflict, returning, count string) *FilterBuilder { q.method = "POST" if onConflict != "" { q.params["on_conflict"] = onConflict } headerList := []string{"resolution=merge-duplicates"} if returning == "" { returning = "representation" } if returning == "minimal" || returning == "representation" { headerList = append(headerList, "return="+returning) } if count != "" && (count == `exact` || count == `planned` || count == `estimated`) { headerList = append(headerList, "count="+count) } q.headers["Prefer"] = strings.Join(headerList, ",") // Get body if exist var byteBody []byte = nil if value != nil { jsonBody, err := json.Marshal(value) if err != nil { q.client.ClientError = err return &FilterBuilder{} } byteBody = jsonBody } q.body = byteBody return &FilterBuilder{client: q.client, method: q.method, body: q.body, tableName: q.tableName, headers: q.headers, params: q.params} } // Delete performs a deletion from the table. func (q *QueryBuilder) Delete(returning, count string) *FilterBuilder { q.method = "DELETE" var headerList []string if returning == "" { returning = "representation" } if returning == "minimal" || returning == "representation" { headerList = append(headerList, "return="+returning) } if count != "" && (count == `exact` || count == `planned` || count == `estimated`) { headerList = append(headerList, "count="+count) } q.headers["Prefer"] = strings.Join(headerList, ",") return &FilterBuilder{client: q.client, method: q.method, body: q.body, tableName: q.tableName, headers: q.headers, params: q.params} } // Update performs an update on the table. 
func (q *QueryBuilder) Update(value interface{}, returning, count string) *FilterBuilder { q.method = "PATCH" var headerList []string if returning == "" { returning = "representation" } if returning == "minimal" || returning == "representation" { headerList = append(headerList, "return="+returning) } if count != "" && (count == `exact` || count == `planned` || count == `estimated`) { headerList = append(headerList, "count="+count) } q.headers["Prefer"] = strings.Join(headerList, ",") // Get body if it exists var byteBody []byte = nil if value != nil { jsonBody, err := json.Marshal(value) if err != nil { q.client.ClientError = err return &FilterBuilder{} } byteBody = jsonBody } q.body = byteBody return &FilterBuilder{client: q.client, method: q.method, body: q.body, tableName: q.tableName, headers: q.headers, params: q.params} }
// This is basic example for postgrest-go library usage. // For now this example is represent wanted syntax and bindings for library. // After core development this test files will be used for CI tests. package main import ( "fmt" postgrest "github.com/supabase-community/postgrest-go" ) var ( RestUrl = `http://localhost:3000` headers = map[string]string{} schema = "public" ) func main() { client := postgrest.NewClient(RestUrl, schema, headers) res, _, err := client.From("actor").Select("actor_id,first_name", "", false).ExecuteString() if err != nil { panic(err) } fmt.Println(res) }
package main import ( "fmt" postgrest "github.com/supabase-community/postgrest-go" ) var ( REST_URL = `http://localhost:3000` ) func main() { client := postgrest.NewClient(REST_URL, "", nil) if client.ClientError != nil { panic(client.ClientError) } result := client.Rpc("add_them", "", map[string]int{"a": 9, "b": 3}) if client.ClientError != nil { panic(client.ClientError) } fmt.Println(result) }